다음을 통해 공유


DataOperationsCatalog.TrainTestSplit 메서드

정의

지정된 비율에 따라 데이터 세트를 학습 집합과 테스트 집합으로 분할합니다. samplingKeyColumnName이 제공된 경우 이를 준수합니다.

public Microsoft.ML.DataOperationsCatalog.TrainTestData TrainTestSplit (Microsoft.ML.IDataView data, double testFraction = 0.1, string samplingKeyColumnName = default, int? seed = default);
member this.TrainTestSplit : Microsoft.ML.IDataView * double * string * Nullable<int> -> Microsoft.ML.DataOperationsCatalog.TrainTestData
Public Function TrainTestSplit (data As IDataView, Optional testFraction As Double = 0.1, Optional samplingKeyColumnName As String = Nothing, Optional seed As Nullable(Of Integer) = Nothing) As DataOperationsCatalog.TrainTestData

매개 변수

data
IDataView

분할할 데이터 세트입니다.

testFraction
Double

테스트 집합에 포함될 데이터의 비율입니다.

samplingKeyColumnName
String

행 그룹화에 사용할 열의 이름입니다. 두 예제가 samplingKeyColumnName 열에서 동일한 값을 공유하는 경우 두 예제는 동일한 하위 집합(학습 또는 테스트)에 포함되도록 보장됩니다. 이를 사용하면 학습 집합에서 테스트 집합으로 레이블이 누출되지 않도록 할 수 있습니다. 순위 실험을 수행할 때는 samplingKeyColumnName이 GroupId 열이어야 합니다. 행 그룹화를 수행하지 않으려면 null을 지정합니다.

seed
Nullable<Int32>

학습/테스트 분할에 사용할 행을 선택하는 난수 생성기의 시드(초기값)입니다.

반환

DataOperationsCatalog.TrainTestData

학습 집합(TrainSet)과 테스트 집합(TestSet)을 포함하는 분할 결과입니다.

예제

using System;
using System.Collections.Generic;
using Microsoft.ML;

namespace Samples.Dynamic
{
    /// <summary>
    /// Sample class showing how to use <c>DataOperationsCatalog.TrainTestSplit</c>
    /// to split an <c>IDataView</c> into train and test sets, both with and
    /// without a sampling key column.
    /// </summary>
    public static class TrainTestSplit
    {
        /// <summary>
        /// Splits a small generated dataset twice: once keeping rows that share a
        /// <c>Group</c> value in the same split, and once without a sampling key.
        /// Each resulting train/test split is printed to the console.
        /// </summary>
        public static void Example()
        {
            // Create the MLContext, which provides the data-operations catalog
            // (mlContext.Data) used throughout this sample.
            var mlContext = new MLContext();

            // Generate some data points.
            var examples = GenerateRandomDataPoints(10);

            // Convert the examples list to an IDataView object, which is consumable
            // by ML.NET API.
            var dataview = mlContext.Data.LoadFromEnumerable(examples);

            // Leave out 10% of the dataset for testing. For some types of problems,
            // for example ranking or anomaly detection, we must ensure that all rows
            // sharing the same value in a particular column end up in the same
            // split. So below, we specify the Group column as the column containing
            // the sampling keys. Notice how keeping the rows with the same value in
            // the Group column overrides the testFraction definition: a whole group
            // (here, half of the rows) ends up in the test set.
            var split = mlContext.Data
                .TrainTestSplit(dataview, testFraction: 0.1,
                samplingKeyColumnName: "Group");

            var trainSet = mlContext.Data
                .CreateEnumerable<DataPoint>(split.TrainSet, reuseRowObject: false);

            var testSet = mlContext.Data
                .CreateEnumerable<DataPoint>(split.TestSet, reuseRowObject: false);

            PrintPreviewRows(trainSet, testSet);

            // Expected output (each line is "Group, Features"):
            //  The data in the Train split.
            //  1, 0.8173254
            //  1, 0.5581612
            //  1, 0.5588848
            //  1, 0.4421779
            //  1, 0.2737045
            //
            //  The data in the Test split.
            //  0, 0.7262433
            //  0, 0.7680227
            //  0, 0.2060332
            //  0, 0.9060271
            //  0, 0.9775497

            // Example of a split without specifying a sampling key column. Rows are
            // assigned to the splits randomly, and no seed is given here (nor to the
            // MLContext above), so the exact rows in each split may differ between
            // runs from the sample output below. Pass the seed argument of
            // TrainTestSplit to get a reproducible split.
            split = mlContext.Data.TrainTestSplit(dataview, testFraction: 0.2);
            trainSet = mlContext.Data
                .CreateEnumerable<DataPoint>(split.TrainSet, reuseRowObject: false);

            testSet = mlContext.Data
                .CreateEnumerable<DataPoint>(split.TestSet, reuseRowObject: false);

            PrintPreviewRows(trainSet, testSet);

            // Sample output (each line is "Group, Features"):
            //  The data in the Train split.
            //  0, 0.7262433
            //  1, 0.8173254
            //  0, 0.7680227
            //  1, 0.5581612
            //  0, 0.2060332
            //  1, 0.4421779
            //  0, 0.9775497
            //  1, 0.2737045
            //
            //  The data in the Test split.
            //  1, 0.5588848
            //  0, 0.9060271
        }

        /// <summary>
        /// Lazily generates <paramref name="count"/> data points whose
        /// <c>Group</c> alternates between 0 and 1 and whose <c>Features</c> value
        /// is a random number in [0, 1).
        /// </summary>
        /// <param name="count">Number of data points to generate.</param>
        /// <param name="seed">
        /// Seed for the random number generator, so the generated features are
        /// the same on every run.
        /// </param>
        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
            int seed = 0)
        {
            var random = new Random(seed);
            for (int i = 0; i < count; i++)
            {
                yield return new DataPoint
                {
                    // Alternate rows between group 0 and group 1.
                    Group = i % 2,

                    // Random feature value, independent of the group.
                    Features = (float)random.NextDouble()
                };
            }
        }

        // A single example with a group column and a feature column. A data set is
        // a collection of such examples.
        private class DataPoint
        {
            // Group identifier; used as the sampling key in the first split.
            public float Group { get; set; }

            // Single random feature value.
            public float Features { get; set; }
        }

        /// <summary>
        /// Prints every row of the train split and then every row of the test
        /// split, one "Group, Features" pair per line, with a blank line between
        /// the two sections.
        /// </summary>
        /// <param name="trainSet">Rows of the training split.</param>
        /// <param name="testSet">Rows of the test split.</param>
        private static void PrintPreviewRows(IEnumerable<DataPoint> trainSet,
            IEnumerable<DataPoint> testSet)
        {
            Console.WriteLine("The data in the Train split.");
            foreach (var row in trainSet)
            {
                Console.WriteLine($"{row.Group}, {row.Features}");
            }

            Console.WriteLine("\nThe data in the Test split.");
            foreach (var row in testSet)
            {
                Console.WriteLine($"{row.Group}, {row.Features}");
            }
        }
    }
}

적용 대상