StandardTrainersCatalog.OneVersusAll<TModel> 方法

定义

创建一个 OneVersusAllTrainer,它采用一对多策略,使用 binaryEstimator 指定的二元分类估算器来预测多类目标。

public static Microsoft.ML.Trainers.OneVersusAllTrainer OneVersusAll<TModel> (this Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, Microsoft.ML.Trainers.ITrainerEstimator<Microsoft.ML.Data.BinaryPredictionTransformer<TModel>,TModel> binaryEstimator, string labelColumnName = "Label", bool imputeMissingLabelsAsNegative = false, Microsoft.ML.IEstimator<Microsoft.ML.ISingleFeaturePredictionTransformer<Microsoft.ML.Calibrators.ICalibrator>> calibrator = default, int maximumCalibrationExampleCount = 1000000000, bool useProbabilities = true) where TModel : class;
static member OneVersusAll : Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers * Microsoft.ML.Trainers.ITrainerEstimator<Microsoft.ML.Data.BinaryPredictionTransformer<'Model>, 'Model (requires 'Model : null)> * string * bool * Microsoft.ML.IEstimator<Microsoft.ML.ISingleFeaturePredictionTransformer<Microsoft.ML.Calibrators.ICalibrator>> * int * bool -> Microsoft.ML.Trainers.OneVersusAllTrainer (requires 'Model : null)
<Extension()>
Public Function OneVersusAll(Of TModel As Class) (catalog As MulticlassClassificationCatalog.MulticlassClassificationTrainers, binaryEstimator As ITrainerEstimator(Of BinaryPredictionTransformer(Of TModel), TModel), Optional labelColumnName As String = "Label", Optional imputeMissingLabelsAsNegative As Boolean = false, Optional calibrator As IEstimator(Of ISingleFeaturePredictionTransformer(Of ICalibrator)) = Nothing, Optional maximumCalibrationExampleCount As Integer = 1000000000, Optional useProbabilities As Boolean = true) As OneVersusAllTrainer

类型参数

TModel

模型的类型。此类型参数通常会从 binaryEstimator 自动推断得出。

参数

catalog
MulticlassClassificationCatalog.MulticlassClassificationTrainers

多类分类目录训练器对象。

binaryEstimator
ITrainerEstimator<BinaryPredictionTransformer<TModel>,TModel>

用作基础训练器的二元 ITrainerEstimator<TTransformer,TModel> 实例。

labelColumnName
String

标签列的名称。

imputeMissingLabelsAsNegative
Boolean

是否将缺失的标签视为负标签,而不是将其保留为缺失状态。

calibrator
IEstimator<ISingleFeaturePredictionTransformer<ICalibrator>>

校准器。如果未显式提供校准器,则默认为 Microsoft.ML.Calibrators.PlattCalibratorTrainer。

maximumCalibrationExampleCount
Int32

用于训练校准器的最大样本数。

useProbabilities
Boolean

使用概率(而不是原始输出)来确定得分最高的类别。

返回

OneVersusAllTrainer

示例

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;

namespace Samples.Dynamic.Trainers.MulticlassClassification
{
    /// <summary>
    /// Sample showing the OneVersusAll multiclass meta-trainer wrapped around
    /// a binary SDCA logistic-regression trainer.
    /// </summary>
    public static class OneVersusAll
    {
        // Number of feature values per example. Shared by the data generator
        // and the VectorType annotation on DataPoint.Features so the two
        // cannot drift apart.
        private const int FeatureCount = 20;

        /// <summary>
        /// Trains a OneVersusAll model on synthetic data, prints the first few
        /// predictions on a held-out set, and prints evaluation metrics.
        /// </summary>
        public static void Example()
        {
            // Create a new context for ML.NET operations. It can be used for
            // exception tracking and logging, as a catalog of available operations
            // and as the source of randomness. Setting the seed to a fixed number
            // in this example to make outputs deterministic.
            var mlContext = new MLContext(seed: 0);

            // Create a list of training data points.
            var dataPoints = GenerateRandomDataPoints(1000);

            // Convert the list of data points to an IDataView object, which is
            // consumable by ML.NET API.
            var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);

            // Define the trainer.
            var pipeline =
                // Convert the string labels into key types.
                mlContext.Transforms.Conversion.MapValueToKey("Label")
                // Apply OneVersusAll multiclass meta trainer on top of
                // binary trainer.
                .Append(mlContext.MulticlassClassification.Trainers
                .OneVersusAll(
                mlContext.BinaryClassification.Trainers.SdcaLogisticRegression()));

            // Train the model.
            var model = pipeline.Fit(trainingData);

            // Create testing data. Use different random seed to make it different
            // from training data.
            var testData = mlContext.Data
                .LoadFromEnumerable(GenerateRandomDataPoints(500, seed: 123));

            // Run the model on test data set.
            var transformedTestData = model.Transform(testData);

            // Convert IDataView object to a list.
            var predictions = mlContext.Data
                .CreateEnumerable<Prediction>(transformedTestData,
                reuseRowObject: false).ToList();

            // Look at 5 predictions
            foreach (var p in predictions.Take(5))
            {
                Console.WriteLine($"Label: {p.Label}, " +
                    $"Prediction: {p.PredictedLabel}");
            }

            // Expected output:
            //   Label: 1, Prediction: 1
            //   Label: 2, Prediction: 2
            //   Label: 3, Prediction: 2
            //   Label: 2, Prediction: 2
            //   Label: 3, Prediction: 2

            // Evaluate the overall metrics
            var metrics = mlContext.MulticlassClassification
                .Evaluate(transformedTestData);

            PrintMetrics(metrics);

            // Expected output:
            //   Micro Accuracy: 0.90
            //   Macro Accuracy: 0.90
            //   Log Loss: 0.36
            //   Log Loss Reduction: 0.68

            //   Confusion table
            //             ||========================
            //   PREDICTED ||     0 |     1 |     2 | Recall
            //   TRUTH     ||========================
            //           0 ||   152 |     0 |     8 | 0.9500
            //           1 ||     0 |   168 |     9 | 0.9492
            //           2 ||    17 |    15 |   131 | 0.8037
            //             ||========================
            //   Precision ||0.8994 |0.9180 |0.8851 |
        }

        // Lazily generates `count` examples whose labels are drawn uniformly
        // from {1, 2, 3}. Each of the FeatureCount feature values is a uniform
        // float in [-0.5, 0.5) shifted by label * 0.2, so the features are
        // correlated with the label. The same `seed` yields the same sequence.
        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
            int seed = 0)
        {
            var random = new Random(seed);
            float randomFloat() => (float)(random.NextDouble() - 0.5);
            for (int i = 0; i < count; i++)
            {
                // Random.Next's upper bound is exclusive: yields 1, 2 or 3.
                var label = random.Next(1, 4);
                yield return new DataPoint
                {
                    Label = (uint)label,
                    // One noise draw per feature, offset by a constant multiple
                    // of the label so that the classes are separable.
                    Features = Enumerable.Range(0, FeatureCount)
                        .Select(_ => randomFloat() + label * 0.2f).ToArray()
                };
            }
        }

        // Example with label and FeatureCount feature values. A data set is a
        // collection of such examples.
        private class DataPoint
        {
            public uint Label { get; set; }
            [VectorType(FeatureCount)]
            public float[] Features { get; set; }
        }

        // Class used to capture predictions.
        private class Prediction
        {
            // Original label.
            public uint Label { get; set; }
            // Predicted label from the trainer.
            public uint PredictedLabel { get; set; }
        }

        // Pretty-print MulticlassClassificationMetrics objects: accuracy and
        // log-loss figures to two decimals, followed by the confusion table.
        public static void PrintMetrics(MulticlassClassificationMetrics metrics)
        {
            Console.WriteLine($"Micro Accuracy: {metrics.MicroAccuracy:F2}");
            Console.WriteLine($"Macro Accuracy: {metrics.MacroAccuracy:F2}");
            Console.WriteLine($"Log Loss: {metrics.LogLoss:F2}");
            Console.WriteLine(
                $"Log Loss Reduction: {metrics.LogLossReduction:F2}\n");

            Console.WriteLine(metrics.ConfusionMatrix.GetFormattedConfusionTable());
        }
    }
}

注解

在一对多策略中,二元分类算法用于为每个类训练一个分类器,该分类器将该类与其他所有类区分开来。然后,通过运行这些二元分类器并选择置信度分数最高的预测来执行预测。

适用于