DataOperationsCatalog.TrainTestSplit 메서드
정의
중요
일부 정보는 릴리스되기 전에 상당 부분 수정될 수 있는 시험판 제품과 관련이 있습니다. Microsoft는 여기에 제공된 정보에 대해 어떠한 명시적이거나 묵시적인 보증도 하지 않습니다.
데이터 세트를 학습 집합으로 분할하고 지정된 분수에 따라 테스트 집합을 분할합니다.
제공된 경우를 samplingKeyColumnName
존중합니다.
public Microsoft.ML.DataOperationsCatalog.TrainTestData TrainTestSplit (Microsoft.ML.IDataView data, double testFraction = 0.1, string samplingKeyColumnName = default, int? seed = default);
member this.TrainTestSplit : Microsoft.ML.IDataView * double * string * Nullable<int> -> Microsoft.ML.DataOperationsCatalog.TrainTestData
Public Function TrainTestSplit (data As IDataView, Optional testFraction As Double = 0.1, Optional samplingKeyColumnName As String = Nothing, Optional seed As Nullable(Of Integer) = Nothing) As DataOperationsCatalog.TrainTestData
매개 변수
- data
- IDataView
분할할 데이터 세트입니다.
- testFraction
- Double
테스트 집합에 들어갈 데이터의 소수입니다.
- samplingKeyColumnName
- String
행 그룹화에 사용할 열의 이름입니다. 두 예제가 동일한 값을 samplingKeyColumnName
공유하는 경우 동일한 하위 집합(학습 또는 테스트)에 표시되도록 보장됩니다. 이는 기차에서 테스트 세트로 레이블이 누출되지 않도록 하는 데 사용할 수 있습니다.
순위 실험을 samplingKeyColumnName
수행할 때 GroupId 열이어야 합니다.
행 그룹화가 수행되지 않는 경우 null
반환
예제
using System;
using System.Collections.Generic;
using Microsoft.ML;
namespace Samples.Dynamic
{
/// <summary>
/// Sample class showing how to use TrainTestSplit.
/// </summary>
public static class TrainTestSplit
{
public static void Example()
{
// Creating the ML.Net IHostEnvironment object, needed for the pipeline.
var mlContext = new MLContext();
// Generate some data points.
var examples = GenerateRandomDataPoints(10);
// Convert the examples list to an IDataView object, which is consumable
// by ML.NET API.
var dataview = mlContext.Data.LoadFromEnumerable(examples);
// Leave out 10% of the dataset for testing.For some types of problems,
// for example for ranking or anomaly detection, we must ensure that the
// split leaves the rows with the same value in a particular column, in
// one of the splits. So below, we specify Group column as the column
// containing the sampling keys. Notice how keeping the rows with the
// same value in the Group column overrides the testFraction definition.
var split = mlContext.Data
.TrainTestSplit(dataview, testFraction: 0.1,
samplingKeyColumnName: "Group");
var trainSet = mlContext.Data
.CreateEnumerable<DataPoint>(split.TrainSet, reuseRowObject: false);
var testSet = mlContext.Data
.CreateEnumerable<DataPoint>(split.TestSet, reuseRowObject: false);
PrintPreviewRows(trainSet, testSet);
// The data in the Train split.
// [Group, 1], [Features, 0.8173254]
// [Group, 1], [Features, 0.5581612]
// [Group, 1], [Features, 0.5588848]
// [Group, 1], [Features, 0.4421779]
// [Group, 1], [Features, 0.2737045]
// The data in the Test split.
// [Group, 0], [Features, 0.7262433]
// [Group, 0], [Features, 0.7680227]
// [Group, 0], [Features, 0.2060332]
// [Group, 0], [Features, 0.9060271]
// [Group, 0], [Features, 0.9775497]
// Example of a split without specifying a sampling key column.
split = mlContext.Data.TrainTestSplit(dataview, testFraction: 0.2);
trainSet = mlContext.Data
.CreateEnumerable<DataPoint>(split.TrainSet, reuseRowObject: false);
testSet = mlContext.Data
.CreateEnumerable<DataPoint>(split.TestSet, reuseRowObject: false);
PrintPreviewRows(trainSet, testSet);
// The data in the Train split.
// [Group, 0], [Features, 0.7262433]
// [Group, 1], [Features, 0.8173254]
// [Group, 0], [Features, 0.7680227]
// [Group, 1], [Features, 0.5581612]
// [Group, 0], [Features, 0.2060332]
// [Group, 1], [Features, 0.4421779]
// [Group, 0], [Features, 0.9775497]
// [Group, 1], [Features, 0.2737045]
// The data in the Test split.
// [Group, 1], [Features, 0.5588848]
// [Group, 0], [Features, 0.9060271]
}
private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
int seed = 0)
{
var random = new Random(seed);
for (int i = 0; i < count; i++)
{
yield return new DataPoint
{
Group = i % 2,
// Create random features that are correlated with label.
Features = (float)random.NextDouble()
};
}
}
// Example with label and group column. A data set is a collection of such
// examples.
private class DataPoint
{
public float Group { get; set; }
public float Features { get; set; }
}
// print helper
private static void PrintPreviewRows(IEnumerable<DataPoint> trainSet,
IEnumerable<DataPoint> testSet)
{
Console.WriteLine($"The data in the Train split.");
foreach (var row in trainSet)
Console.WriteLine($"{row.Group}, {row.Features}");
Console.WriteLine($"\nThe data in the Test split.");
foreach (var row in testSet)
Console.WriteLine($"{row.Group}, {row.Features}");
}
}
}