DataOperationsCatalog.TrainTestSplit メソッド
定義
重要
一部の情報は、リリース前に大きく変更される可能性があるプレリリースされた製品に関するものです。 Microsoft は、ここに記載されている情報について、明示または黙示を問わず、一切保証しません。
指定された分数に従って、データセットをトレーニング セットとテスト セットに分割します。
指定された場合を尊重します samplingKeyColumnName
。
public Microsoft.ML.DataOperationsCatalog.TrainTestData TrainTestSplit (Microsoft.ML.IDataView data, double testFraction = 0.1, string samplingKeyColumnName = default, int? seed = default);
member this.TrainTestSplit : Microsoft.ML.IDataView * double * string * Nullable<int> -> Microsoft.ML.DataOperationsCatalog.TrainTestData
Public Function TrainTestSplit (data As IDataView, Optional testFraction As Double = 0.1, Optional samplingKeyColumnName As String = Nothing, Optional seed As Nullable(Of Integer) = Nothing) As DataOperationsCatalog.TrainTestData
パラメーター
- data
- IDataView
分割するデータセット。
- testFraction
- Double
テスト セットに入るデータの割合。
- samplingKeyColumnName
- String
行のグループ化に使用する列の名前。 2 つの例が同じ値 samplingKeyColumnName
を共有している場合は、同じサブセット (トレーニングまたはテスト) に表示されます。 これを使用して、列車からテスト セットへのラベル漏れがないことを確認できます。
ランク付け実験を実行する場合は、 samplingKeyColumnName
GroupId 列である必要があることに注意してください。
行のグループ化が実行されない場合 null
。
戻り値
例
using System;
using System.Collections.Generic;
using Microsoft.ML;
namespace Samples.Dynamic
{
/// <summary>
/// Sample class showing how to use TrainTestSplit.
/// </summary>
public static class TrainTestSplit
{
public static void Example()
{
// Creating the ML.Net IHostEnvironment object, needed for the pipeline.
var mlContext = new MLContext();
// Generate some data points.
var examples = GenerateRandomDataPoints(10);
// Convert the examples list to an IDataView object, which is consumable
// by ML.NET API.
var dataview = mlContext.Data.LoadFromEnumerable(examples);
// Leave out 10% of the dataset for testing.For some types of problems,
// for example for ranking or anomaly detection, we must ensure that the
// split leaves the rows with the same value in a particular column, in
// one of the splits. So below, we specify Group column as the column
// containing the sampling keys. Notice how keeping the rows with the
// same value in the Group column overrides the testFraction definition.
var split = mlContext.Data
.TrainTestSplit(dataview, testFraction: 0.1,
samplingKeyColumnName: "Group");
var trainSet = mlContext.Data
.CreateEnumerable<DataPoint>(split.TrainSet, reuseRowObject: false);
var testSet = mlContext.Data
.CreateEnumerable<DataPoint>(split.TestSet, reuseRowObject: false);
PrintPreviewRows(trainSet, testSet);
// The data in the Train split.
// [Group, 1], [Features, 0.8173254]
// [Group, 1], [Features, 0.5581612]
// [Group, 1], [Features, 0.5588848]
// [Group, 1], [Features, 0.4421779]
// [Group, 1], [Features, 0.2737045]
// The data in the Test split.
// [Group, 0], [Features, 0.7262433]
// [Group, 0], [Features, 0.7680227]
// [Group, 0], [Features, 0.2060332]
// [Group, 0], [Features, 0.9060271]
// [Group, 0], [Features, 0.9775497]
// Example of a split without specifying a sampling key column.
split = mlContext.Data.TrainTestSplit(dataview, testFraction: 0.2);
trainSet = mlContext.Data
.CreateEnumerable<DataPoint>(split.TrainSet, reuseRowObject: false);
testSet = mlContext.Data
.CreateEnumerable<DataPoint>(split.TestSet, reuseRowObject: false);
PrintPreviewRows(trainSet, testSet);
// The data in the Train split.
// [Group, 0], [Features, 0.7262433]
// [Group, 1], [Features, 0.8173254]
// [Group, 0], [Features, 0.7680227]
// [Group, 1], [Features, 0.5581612]
// [Group, 0], [Features, 0.2060332]
// [Group, 1], [Features, 0.4421779]
// [Group, 0], [Features, 0.9775497]
// [Group, 1], [Features, 0.2737045]
// The data in the Test split.
// [Group, 1], [Features, 0.5588848]
// [Group, 0], [Features, 0.9060271]
}
private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
int seed = 0)
{
var random = new Random(seed);
for (int i = 0; i < count; i++)
{
yield return new DataPoint
{
Group = i % 2,
// Create random features that are correlated with label.
Features = (float)random.NextDouble()
};
}
}
// Example with label and group column. A data set is a collection of such
// examples.
private class DataPoint
{
public float Group { get; set; }
public float Features { get; set; }
}
// print helper
private static void PrintPreviewRows(IEnumerable<DataPoint> trainSet,
IEnumerable<DataPoint> testSet)
{
Console.WriteLine($"The data in the Train split.");
foreach (var row in trainSet)
Console.WriteLine($"{row.Group}, {row.Features}");
Console.WriteLine($"\nThe data in the Test split.");
foreach (var row in testSet)
Console.WriteLine($"{row.Group}, {row.Features}");
}
}
}