StandardTrainersCatalog.OneVersusAll<TModel> メソッド
定義
重要
一部の情報は、リリース前に大きく変更される可能性があるプレリリースされた製品に関するものです。 Microsoft は、ここに記載されている情報について、明示または黙示を問わず、一切保証しません。
OneVersusAllTrainerで指定された二項分類推定器を使用して、一対全戦略を使用して多クラスターゲットを予測する、複数クラスのターゲットを作成しますbinaryEstimator
。
public static Microsoft.ML.Trainers.OneVersusAllTrainer OneVersusAll<TModel> (this Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, Microsoft.ML.Trainers.ITrainerEstimator<Microsoft.ML.Data.BinaryPredictionTransformer<TModel>,TModel> binaryEstimator, string labelColumnName = "Label", bool imputeMissingLabelsAsNegative = false, Microsoft.ML.IEstimator<Microsoft.ML.ISingleFeaturePredictionTransformer<Microsoft.ML.Calibrators.ICalibrator>> calibrator = default, int maximumCalibrationExampleCount = 1000000000, bool useProbabilities = true) where TModel : class;
static member OneVersusAll : Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers * Microsoft.ML.Trainers.ITrainerEstimator<Microsoft.ML.Data.BinaryPredictionTransformer<'Model>, 'Model (requires 'Model : null)> * string * bool * Microsoft.ML.IEstimator<Microsoft.ML.ISingleFeaturePredictionTransformer<Microsoft.ML.Calibrators.ICalibrator>> * int * bool -> Microsoft.ML.Trainers.OneVersusAllTrainer (requires 'Model : null)
<Extension()>
Public Function OneVersusAll(Of TModel As Class) (catalog As MulticlassClassificationCatalog.MulticlassClassificationTrainers, binaryEstimator As ITrainerEstimator(Of BinaryPredictionTransformer(Of TModel), TModel), Optional labelColumnName As String = "Label", Optional imputeMissingLabelsAsNegative As Boolean = false, Optional calibrator As IEstimator(Of ISingleFeaturePredictionTransformer(Of ICalibrator)) = Nothing, Optional maximumCalibrationExampleCount As Integer = 1000000000, Optional useProbabilities As Boolean = true) As OneVersusAllTrainer
型パラメーター
- TModel
モデルの型。 この型パラメーターは、通常から自動的 binaryEstimator
に推論されます。
パラメーター
多クラス分類カタログ トレーナー オブジェクト。
- binaryEstimator
- ITrainerEstimator<BinaryPredictionTransformer<TModel>,TModel>
ベース トレーナーとして使用されるバイナリ ITrainerEstimator<TTransformer,TModel> のインスタンス。
- labelColumnName
- String
ラベル列の名前。
- imputeMissingLabelsAsNegative
- Boolean
欠落しているラベルを、欠落したままにするのではなく、負のラベルを持つものとして扱うかどうか。
- calibrator
- IEstimator<ISingleFeaturePredictionTransformer<ICalibrator>>
校正器。 校正器が明示的に指定されていない場合は、既定で Microsoft.ML.Calibrators.PlattCalibratorTrainer
- maximumCalibrationExampleCount
- Int32
校正器をトレーニングするインスタンスの数。
- useProbabilities
- Boolean
確率 (生の出力と比較) を使用して、トップ スコア カテゴリを識別します。
戻り値
例
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
namespace Samples.Dynamic.Trainers.MulticlassClassification
{
public static class OneVersusAll
{
public static void Example()
{
// Create a new context for ML.NET operations. It can be used for
// exception tracking and logging, as a catalog of available operations
// and as the source of randomness. Setting the seed to a fixed number
// in this example to make outputs deterministic.
var mlContext = new MLContext(seed: 0);
// Create a list of training data points.
var dataPoints = GenerateRandomDataPoints(1000);
// Convert the list of data points to an IDataView object, which is
// consumable by ML.NET API.
var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);
// Define the trainer.
var pipeline =
// Convert the string labels into key types.
mlContext.Transforms.Conversion.MapValueToKey("Label")
// Apply OneVersusAll multiclass meta trainer on top of
// binary trainer.
.Append(mlContext.MulticlassClassification.Trainers
.OneVersusAll(
mlContext.BinaryClassification.Trainers.SdcaLogisticRegression()));
// Train the model.
var model = pipeline.Fit(trainingData);
// Create testing data. Use different random seed to make it different
// from training data.
var testData = mlContext.Data
.LoadFromEnumerable(GenerateRandomDataPoints(500, seed: 123));
// Run the model on test data set.
var transformedTestData = model.Transform(testData);
// Convert IDataView object to a list.
var predictions = mlContext.Data
.CreateEnumerable<Prediction>(transformedTestData,
reuseRowObject: false).ToList();
// Look at 5 predictions
foreach (var p in predictions.Take(5))
Console.WriteLine($"Label: {p.Label}, " +
$"Prediction: {p.PredictedLabel}");
// Expected output:
// Label: 1, Prediction: 1
// Label: 2, Prediction: 2
// Label: 3, Prediction: 2
// Label: 2, Prediction: 2
// Label: 3, Prediction: 2
// Evaluate the overall metrics
var metrics = mlContext.MulticlassClassification
.Evaluate(transformedTestData);
PrintMetrics(metrics);
// Expected output:
// Micro Accuracy: 0.90
// Macro Accuracy: 0.90
// Log Loss: 0.36
// Log Loss Reduction: 0.68
// Confusion table
// ||========================
// PREDICTED || 0 | 1 | 2 | Recall
// TRUTH ||========================
// 0 || 152 | 0 | 8 | 0.9500
// 1 || 0 | 168 | 9 | 0.9492
// 2 || 17 | 15 | 131 | 0.8037
// ||========================
// Precision ||0.8994 |0.9180 |0.8851 |
}
// Generates random uniform doubles in [-0.5, 0.5)
// range with labels 1, 2 or 3.
private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
int seed = 0)
{
var random = new Random(seed);
float randomFloat() => (float)(random.NextDouble() - 0.5);
for (int i = 0; i < count; i++)
{
// Generate Labels that are integers 1, 2 or 3
var label = random.Next(1, 4);
yield return new DataPoint
{
Label = (uint)label,
// Create random features that are correlated with the label.
// The feature values are slightly increased by adding a
// constant multiple of label.
Features = Enumerable.Repeat(label, 20)
.Select(x => randomFloat() + label * 0.2f).ToArray()
};
}
}
// Example with label and 20 feature values. A data set is a collection of
// such examples.
private class DataPoint
{
public uint Label { get; set; }
[VectorType(20)]
public float[] Features { get; set; }
}
// Class used to capture predictions.
private class Prediction
{
// Original label.
public uint Label { get; set; }
// Predicted label from the trainer.
public uint PredictedLabel { get; set; }
}
// Pretty-print MulticlassClassificationMetrics objects.
public static void PrintMetrics(MulticlassClassificationMetrics metrics)
{
Console.WriteLine($"Micro Accuracy: {metrics.MicroAccuracy:F2}");
Console.WriteLine($"Macro Accuracy: {metrics.MacroAccuracy:F2}");
Console.WriteLine($"Log Loss: {metrics.LogLoss:F2}");
Console.WriteLine(
$"Log Loss Reduction: {metrics.LogLossReduction:F2}\n");
Console.WriteLine(metrics.ConfusionMatrix.GetFormattedConfusionTable());
}
}
}
注釈
一対全戦略では、二項分類アルゴリズムを使用してクラスごとに 1 つの分類子をトレーニングします。このアルゴリズムは、そのクラスを他のすべてのクラスと区別します。 次に、これらの二項分類子を実行し、信頼度スコアが最も高い予測を選択することで、予測が行われます。