AutoCatalog.CreateSweepableEstimator<T> 方法
定义
重要
一些信息与预发行产品相关,相应产品在发行之前可能会进行重大修改。 对于此处提供的信息,Microsoft 不作任何明示或暗示的担保。
使用自定义工厂和搜索空间创建可扫描估算器。
public Microsoft.ML.AutoML.SweepableEstimator CreateSweepableEstimator<T> (Func<Microsoft.ML.MLContext,T,Microsoft.ML.IEstimator<Microsoft.ML.ITransformer>> factory, Microsoft.ML.SearchSpace.SearchSpace<T> ss = default) where T : class, new();
member this.CreateSweepableEstimator : Func<Microsoft.ML.MLContext, 'T, Microsoft.ML.IEstimator<Microsoft.ML.ITransformer> (requires 'T : null and 'T : (new : unit -> 'T))> * Microsoft.ML.SearchSpace.SearchSpace<'T (requires 'T : null and 'T : (new : unit -> 'T))> -> Microsoft.ML.AutoML.SweepableEstimator (requires 'T : null and 'T : (new : unit -> 'T))
Public Function CreateSweepableEstimator(Of T As {Class, New}) (factory As Func(Of MLContext, T, IEstimator(Of ITransformer)), Optional ss As SearchSpace(Of T) = Nothing) As SweepableEstimator
类型参数
- T
  搜索空间参数对象的类型。必须是具有公共无参数构造函数的引用类型(`where T : class, new()`)。
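参数
- factory
  - Func<MLContext,T,IEstimator<ITransformer>>
  工厂方法:接收 MLContext 和一个 T 类型的参数对象,返回要扫描的估算器。
- ss
  - SearchSpace<T>
  可选。描述 T 中各参数取值范围的搜索空间;默认值为 `default`(VB 中为 `Nothing`)。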
返回
- SweepableEstimator
  由 factory 与搜索空间创建的可扫描估算器。
示例
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML.Data;
using Microsoft.ML.SearchSpace;
namespace Microsoft.ML.AutoML.Samples
{
    /// <summary>
    /// Sample: hyper-parameter optimization of a LightGBM binary classifier through a
    /// sweepable estimator built from a custom factory and a custom search space.
    /// </summary>
    public static class SweepableLightGBMBinaryExperiment
    {
        // Number of feature values per generated data point. Shared by the data generator
        // and the [VectorType] schema annotation so the two can never drift apart.
        private const int FeatureCount = 50;

        /// <summary>
        /// Search space for the sweepable LightGBM trainer. An instance of this type is
        /// handed to the estimator factory in <see cref="RunAsync"/>, and each property is
        /// bounded by its <c>[Range]</c> attribute.
        /// </summary>
        private class LightGBMOption
        {
            [Range(4, 32768, init: 4, logBase: false)]
            public int NumberOfLeaves { get; set; } = 4;

            [Range(4, 32768, init: 4, logBase: false)]
            public int NumberOfTrees { get; set; } = 4;
        }

        /// <summary>
        /// Runs the sample end to end. It generates a synthetic dataset and sweeps LightGBM's
        /// number of leaves and number of trees with AutoML, using 5-fold cross validation on
        /// the training split and at most 100 trials. It then evaluates the best model on the
        /// held-out test split and prints the metrics to the console.
        /// </summary>
        public static async Task RunAsync()
        {
            // This example shows how to use Sweepable API to run hyper-parameter optimization over
            // LightGBM trainer with a customized search space.

            // Create a new context for ML.NET operations. It can be used for
            // exception tracking and logging, as a catalog of available operations
            // and as the source of randomness. Setting the seed to a fixed number
            // in this example to make outputs deterministic.
            var seed = 0;
            var context = new MLContext(seed);

            // Create a list of training data points and convert it to IDataView.
            var data = GenerateRandomBinaryClassificationDataPoints(100, seed);
            var dataView = context.Data.LoadFromEnumerable(data);

            // Split the dataset into train and test sets with 10% of the data used for testing.
            var trainTestSplit = context.Data.TrainTestSplit(dataView, testFraction: 0.1);

            // Define a customized search space for LightGBM.
            var lgbmSearchSpace = new SearchSpace<LightGBMOption>();

            // Define the sweepable LightGBM estimator. The factory builds the trainer from
            // the hyper-parameter values carried by `option`.
            var lgbm = context.Auto().CreateSweepableEstimator((_context, option) =>
            {
                return _context.BinaryClassification.Trainers.LightGbm(
                    "Label",
                    "Features",
                    numberOfLeaves: option.NumberOfLeaves,
                    numberOfIterations: option.NumberOfTrees);
            }, lgbmSearchSpace);

            // Create sweepable pipeline.
            var pipeline = new EstimatorChain<ITransformer>().Append(lgbm);

            // Create an AutoML experiment.
            var experiment = context.Auto().CreateExperiment();

            // Redirect AutoML log to console, keeping only messages from the experiment
            // itself that are above Trace level.
            context.Log += (object o, LoggingEventArgs e) =>
            {
                if (e.Source == nameof(AutoMLExperiment) && e.Kind > Runtime.ChannelMessageKind.Trace)
                {
                    Console.WriteLine(e.RawMessage);
                }
            };

            // Configure the experiment to optimize the "Accuracy" metric on the given dataset.
            // This experiment will run hyper-parameter optimization on the given pipeline.
            experiment.SetPipeline(pipeline)
                .SetDataset(trainTestSplit.TrainSet, fold: 5) // use 5-fold cross validation to evaluate each trial
                .SetBinaryClassificationMetric(BinaryClassificationMetric.Accuracy, "Label")
                .SetMaxModelToExplore(100); // explore 100 trials

            // Start the AutoML experiment.
            var result = await experiment.RunAsync();

            // Expected output samples during training. The pipeline will be unknown because it's created using
            // customized sweepable estimator, therefore AutoML doesn't have the knowledge of the exact type of the estimator.
            // Update Running Trial - Id: 0
            // Update Completed Trial - Id: 0 - Metric: 0.5105967259285338 - Pipeline: Unknown=>Unknown - Duration: 616 - Peak CPU: 0.00% - Peak Memory in MB: 35.54
            // Update Best Trial - Id: 0 - Metric: 0.5105967259285338 - Pipeline: Unknown=>Unknown

            // Evaluate the test dataset on the best model.
            var bestModel = result.Model;
            var eval = bestModel.Transform(trainTestSplit.TestSet);
            var metrics = context.BinaryClassification.Evaluate(eval);
            PrintMetrics(metrics);

            // Expected output:
            // Accuracy: 0.67
            // AUC: 0.75
            // F1 Score: 0.33
            // Negative Precision: 0.88
            // Negative Recall: 0.70
            // Positive Precision: 0.25
            // Positive Recall: 0.50

            // TEST POSITIVE RATIO: 0.1667(2.0 / (2.0 + 10.0))
            // Confusion table
            // ||======================
            // PREDICTED || positive | negative | Recall
            // TRUTH ||======================
            // positive || 1 | 1 | 0.5000
            // negative || 3 | 7 | 0.7000
            // ||======================
            // Precision || 0.2500 | 0.8750 |
        }

        /// <summary>
        /// Lazily yields <paramref name="count"/> synthetic data points. A point's label is
        /// true when a uniform draw exceeds 0.5. Each of its <see cref="FeatureCount"/>
        /// features is another uniform draw, plus 0.1 when the label is false.
        /// </summary>
        /// <param name="count">Number of data points to produce.</param>
        /// <param name="seed">Seed for <see cref="Random"/>; a fixed seed reproduces the same sequence.</param>
        private static IEnumerable<BinaryClassificationDataPoint> GenerateRandomBinaryClassificationDataPoints(int count,
            int seed = 0)
        {
            var random = new Random(seed);
            float randomFloat() => (float)random.NextDouble();
            for (int i = 0; i < count; i++)
            {
                var label = randomFloat() > 0.5f;
                yield return new BinaryClassificationDataPoint
                {
                    Label = label,
                    // Create random features that are correlated with the label.
                    // For data points with false label, the feature values are
                    // slightly increased by adding a constant.
                    Features = Enumerable.Repeat(label, FeatureCount)
                        .Select(x => x ? randomFloat() : randomFloat() + 0.1f)
                        .ToArray()
                };
            }
        }

        /// <summary>
        /// Example with a label and <see cref="FeatureCount"/> feature values. A data set is
        /// a collection of such examples.
        /// </summary>
        private class BinaryClassificationDataPoint
        {
            public bool Label { get; set; }

            [VectorType(FeatureCount)]
            public float[] Features { get; set; }
        }

        /// <summary>
        /// Pretty-prints <see cref="BinaryClassificationMetrics"/> to the console with two
        /// decimal places, followed by the formatted confusion table.
        /// </summary>
        private static void PrintMetrics(BinaryClassificationMetrics metrics)
        {
            Console.WriteLine($"Accuracy: {metrics.Accuracy:F2}");
            Console.WriteLine($"AUC: {metrics.AreaUnderRocCurve:F2}");
            Console.WriteLine($"F1 Score: {metrics.F1Score:F2}");
            Console.WriteLine($"Negative Precision: " +
                $"{metrics.NegativePrecision:F2}");
            Console.WriteLine($"Negative Recall: {metrics.NegativeRecall:F2}");
            Console.WriteLine($"Positive Precision: " +
                $"{metrics.PositivePrecision:F2}");
            Console.WriteLine($"Positive Recall: {metrics.PositiveRecall:F2}\n");
            Console.WriteLine(metrics.ConfusionMatrix.GetFormattedConfusionTable());
        }
    }
}