DataOperationsCatalog.CrossValidationSplit 方法
定义
重要
一些信息与预发行产品相关,相应产品在发行之前可能会进行重大修改。 对于此处提供的信息,Microsoft 不作任何明示或暗示的担保。
将数据集拆分为交叉验证折叠,每个折叠包含一个训练集和一个测试集。
如果提供了 samplingKeyColumnName,则会按该列对行进行分组。
public System.Collections.Generic.IReadOnlyList<Microsoft.ML.DataOperationsCatalog.TrainTestData> CrossValidationSplit (Microsoft.ML.IDataView data, int numberOfFolds = 5, string samplingKeyColumnName = default, int? seed = default);
member this.CrossValidationSplit : Microsoft.ML.IDataView * int * string * Nullable<int> -> System.Collections.Generic.IReadOnlyList<Microsoft.ML.DataOperationsCatalog.TrainTestData>
Public Function CrossValidationSplit (data As IDataView, Optional numberOfFolds As Integer = 5, Optional samplingKeyColumnName As String = Nothing, Optional seed As Nullable(Of Integer) = Nothing) As IReadOnlyList(Of DataOperationsCatalog.TrainTestData)
参数
- data
- IDataView
要拆分的数据集。
- numberOfFolds
- Int32
交叉验证折叠数。
- samplingKeyColumnName
- String
用于对行进行分组的列的名称。 如果两个示例在 samplingKeyColumnName 列中的值相同,则保证它们出现在同一个子集(训练集或测试集)中。 这可用于确保标签不会从训练集泄漏到测试集。
请注意,执行排名实验时,samplingKeyColumnName 必须是 GroupId 列。
如果为 null,则不执行行分组。
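返回
IReadOnlyList<DataOperationsCatalog.TrainTestData>
交叉验证折叠的列表;每个元素包含该折叠的训练集 (TrainSet) 和测试集 (TestSet)。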
示例
using System;
using System.Collections.Generic;
using Microsoft.ML;
namespace Samples.Dynamic
{
    /// <summary>
    /// Sample class showing how to use CrossValidationSplit.
    /// </summary>
    public static class CrossValidationSplit
    {
        /// <summary>
        /// Splits a small synthetic dataset into 3 cross-validation folds twice:
        /// once grouping rows by the "Group" column (so every row of a group lands
        /// in the same split), and once without a sampling key. The train and test
        /// rows of every fold are printed to the console.
        /// </summary>
        public static void Example()
        {
            // Create the MLContext, the common entry point for all ML.NET
            // operations (data loading, transforms, training, ...).
            var mlContext = new MLContext();

            // Generate some data points.
            var examples = GenerateRandomDataPoints(10);

            // Convert the examples list to an IDataView object, which is consumable
            // by ML.NET API.
            var dataview = mlContext.Data.LoadFromEnumerable(examples);

            // Cross validation splits your data into a set of "folds" and creates
            // one Train/Test pair per fold: that fold is the Test set and all the
            // remaining folds form the Train set. Below we pass the "Group" column
            // as the sampling key, so rows that share a Group value are always kept
            // together in the same split (here each Test set is exactly one group).
            var folds = mlContext.Data
                .CrossValidationSplit(dataview, numberOfFolds: 3,
                samplingKeyColumnName: "Group");

            var trainSet = mlContext.Data
                .CreateEnumerable<DataPoint>(folds[0].TrainSet,
                reuseRowObject: false);

            var testSet = mlContext.Data
                .CreateEnumerable<DataPoint>(folds[0].TestSet,
                reuseRowObject: false);

            PrintPreviewRows(trainSet, testSet);

            // The data in the Train split.
            // [Group, 1], [Features, 0.8173254]
            // [Group, 2], [Features, 0.7680227]
            // [Group, 1], [Features, 0.2060332]
            // [Group, 2], [Features, 0.5588848]
            // [Group, 1], [Features, 0.4421779]
            // [Group, 2], [Features, 0.9775497]
            //
            // The data in the Test split.
            // [Group, 0], [Features, 0.7262433]
            // [Group, 0], [Features, 0.5581612]
            // [Group, 0], [Features, 0.9060271]
            // [Group, 0], [Features, 0.2737045]

            trainSet = mlContext.Data
                .CreateEnumerable<DataPoint>(folds[1].TrainSet,
                reuseRowObject: false);

            testSet = mlContext.Data
                .CreateEnumerable<DataPoint>(folds[1].TestSet,
                reuseRowObject: false);

            PrintPreviewRows(trainSet, testSet);

            // The data in the Train split.
            // [Group, 0], [Features, 0.7262433]
            // [Group, 2], [Features, 0.7680227]
            // [Group, 0], [Features, 0.5581612]
            // [Group, 2], [Features, 0.5588848]
            // [Group, 0], [Features, 0.9060271]
            // [Group, 2], [Features, 0.9775497]
            // [Group, 0], [Features, 0.2737045]
            //
            // The data in the Test split.
            // [Group, 1], [Features, 0.8173254]
            // [Group, 1], [Features, 0.2060332]
            // [Group, 1], [Features, 0.4421779]

            trainSet = mlContext.Data
                .CreateEnumerable<DataPoint>(folds[2].TrainSet,
                reuseRowObject: false);

            testSet = mlContext.Data
                .CreateEnumerable<DataPoint>(folds[2].TestSet,
                reuseRowObject: false);

            PrintPreviewRows(trainSet, testSet);

            // The data in the Train split.
            // [Group, 0], [Features, 0.7262433]
            // [Group, 1], [Features, 0.8173254]
            // [Group, 0], [Features, 0.5581612]
            // [Group, 1], [Features, 0.2060332]
            // [Group, 0], [Features, 0.9060271]
            // [Group, 1], [Features, 0.4421779]
            // [Group, 0], [Features, 0.2737045]
            //
            // The data in the Test split.
            // [Group, 2], [Features, 0.7680227]
            // [Group, 2], [Features, 0.5588848]
            // [Group, 2], [Features, 0.9775497]

            // Example of a split without specifying a sampling key column: rows
            // are assigned to folds individually, so groups get mixed across
            // Train and Test.
            folds = mlContext.Data.CrossValidationSplit(dataview, numberOfFolds: 3);

            trainSet = mlContext.Data
                .CreateEnumerable<DataPoint>(folds[0].TrainSet,
                reuseRowObject: false);

            testSet = mlContext.Data
                .CreateEnumerable<DataPoint>(folds[0].TestSet,
                reuseRowObject: false);

            PrintPreviewRows(trainSet, testSet);

            // The data in the Train split.
            // [Group, 0], [Features, 0.7262433]
            // [Group, 1], [Features, 0.8173254]
            // [Group, 2], [Features, 0.7680227]
            // [Group, 0], [Features, 0.5581612]
            // [Group, 1], [Features, 0.2060332]
            // [Group, 1], [Features, 0.4421779]
            // [Group, 2], [Features, 0.9775497]
            // [Group, 0], [Features, 0.2737045]
            //
            // The data in the Test split.
            // [Group, 2], [Features, 0.5588848]
            // [Group, 0], [Features, 0.9060271]

            trainSet = mlContext.Data
                .CreateEnumerable<DataPoint>(folds[1].TrainSet,
                reuseRowObject: false);

            testSet = mlContext.Data
                .CreateEnumerable<DataPoint>(folds[1].TestSet,
                reuseRowObject: false);

            PrintPreviewRows(trainSet, testSet);

            // The data in the Train split.
            // [Group, 2], [Features, 0.7680227]
            // [Group, 0], [Features, 0.5581612]
            // [Group, 1], [Features, 0.2060332]
            // [Group, 2], [Features, 0.5588848]
            // [Group, 0], [Features, 0.9060271]
            // [Group, 1], [Features, 0.4421779]
            //
            // The data in the Test split.
            // [Group, 0], [Features, 0.7262433]
            // [Group, 1], [Features, 0.8173254]
            // [Group, 2], [Features, 0.9775497]
            // [Group, 0], [Features, 0.2737045]

            trainSet = mlContext.Data
                .CreateEnumerable<DataPoint>(folds[2].TrainSet,
                reuseRowObject: false);

            testSet = mlContext.Data
                .CreateEnumerable<DataPoint>(folds[2].TestSet,
                reuseRowObject: false);

            PrintPreviewRows(trainSet, testSet);

            // The data in the Train split.
            // [Group, 0], [Features, 0.7262433]
            // [Group, 1], [Features, 0.8173254]
            // [Group, 2], [Features, 0.5588848]
            // [Group, 0], [Features, 0.9060271]
            // [Group, 2], [Features, 0.9775497]
            // [Group, 0], [Features, 0.2737045]
            //
            // The data in the Test split.
            // [Group, 2], [Features, 0.7680227]
            // [Group, 0], [Features, 0.5581612]
            // [Group, 1], [Features, 0.2060332]
            // [Group, 1], [Features, 0.4421779]
        }

        /// <summary>
        /// Lazily produces <paramref name="count"/> synthetic data points.
        /// </summary>
        /// <param name="count">Number of points to generate.</param>
        /// <param name="seed">Seed for <see cref="Random"/>, so the output is
        /// reproducible.</param>
        /// <returns>Points whose Group cycles 0, 1, 2, 0, ... and whose Features
        /// value is a random number in [0, 1).</returns>
        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
            int seed = 0)
        {
            var random = new Random(seed);
            for (int i = 0; i < count; i++)
            {
                yield return new DataPoint
                {
                    // Cycle through three groups so each group gets several rows.
                    Group = i % 3,
                    // A random feature value; it is not related to Group.
                    Features = (float)random.NextDouble()
                };
            }
        }

        // Example with features and group column. A data set is a collection of
        // such examples.
        private class DataPoint
        {
            public float Group { get; set; }

            public float Features { get; set; }
        }

        /// <summary>
        /// Prints the rows of a train/test pair in the same
        /// "[Group, g], [Features, f]" format used by the expected-output
        /// comments in <see cref="Example"/>.
        /// </summary>
        private static void PrintPreviewRows(IEnumerable<DataPoint> trainSet,
            IEnumerable<DataPoint> testSet)
        {
            Console.WriteLine("The data in the Train split.");
            foreach (var row in trainSet)
            {
                Console.WriteLine($"[Group, {row.Group}], [Features, {row.Features}]");
            }

            Console.WriteLine("\nThe data in the Test split.");
            foreach (var row in testSet)
            {
                Console.WriteLine($"[Group, {row.Group}], [Features, {row.Features}]");
            }
        }
    }
}