Delen via


Een machine learning-model trainen met kruisvalidatie

Meer informatie over het gebruik van kruisvalidatie voor het trainen van robuustere machine learning-modellen in ML.NET.

Kruisvalidatie is een trainings- en modelevaluatietechniek waarmee de gegevens worden gesplitst in verschillende partities en meerdere algoritmen worden getraind op deze partities. Deze techniek verbetert de robuustheid van het model door gegevens uit het trainingsproces vast te houden. Naast het verbeteren van de prestaties van ongeziene waarnemingen, kan het in gegevensbeperkingsomgevingen een effectief hulpmiddel zijn voor het trainen van modellen met een kleinere gegevensset.

Het gegevens- en gegevensmodel

Gegeven gegevens uit een bestand met de volgende indeling:

Size (Sq. ft.), HistoricalPrice1 ($), HistoricalPrice2 ($), HistoricalPrice3 ($), Current Price ($)
620.00, 148330.32, 140913.81, 136686.39, 146105.37
550.00, 557033.46, 529181.78, 513306.33, 548677.95
1127.00, 479320.99, 455354.94, 441694.30, 472131.18
1120.00, 47504.98, 45129.73, 43775.84, 46792.41

De gegevens kunnen worden gemodelleerd door een klasse zoals HousingData en in een IDataViewgeladen.

public class HousingData
{
    [LoadColumn(0)]
    public float Size { get; set; }

    [LoadColumn(1, 3)]
    [VectorType(3)]
    public float[] HistoricalPrices { get; set; }

    [LoadColumn(4)]
    [ColumnName("Label")]
    public float CurrentPrice { get; set; }
}

De gegevens voorbereiden

De gegevens vooraf verwerken voordat u deze gebruikt om het machine learning-model te bouwen. In dit voorbeeld worden de kolommen Size en HistoricalPrices gecombineerd tot één functievector. Dit is uitvoer naar een nieuwe kolom met de naam Features met behulp van de methode Concatenate. Naast het ophalen van de gegevens in de indeling die wordt verwacht door ML.NET algoritmen, optimaliseert het samenvoegen van kolommen de volgende bewerkingen in de pijplijn door de bewerking eenmaal toe te passen voor de samengevoegde kolom in plaats van elk van de afzonderlijke kolommen.

Zodra de kolommen in één vector zijn gecombineerd, wordt NormalizeMinMax toegepast op de Features kolom om Size en HistoricalPrices in hetzelfde bereik tussen 0 en 1 op te halen.

// Define data prep estimator
IEstimator<ITransformer> dataPrepEstimator =
    mlContext.Transforms.Concatenate("Features", new string[] { "Size", "HistoricalPrices" })
        .Append(mlContext.Transforms.NormalizeMinMax("Features"));

// Create data prep transformer
ITransformer dataPrepTransformer = dataPrepEstimator.Fit(data);

// Transform data
IDataView transformedData = dataPrepTransformer.Transform(data);

Model trainen met kruisvalidatie

Zodra de gegevens vooraf zijn verwerkt, is het tijd om het model te trainen. Selecteer eerst het algoritme dat het meest overeenkomt met de machine learning-taak die moet worden uitgevoerd. Omdat de voorspelde waarde een numerieke doorlopende waarde is, is de taak regressie. Een van de regressiealgoritmen die door ML.NET zijn geïmplementeerd, is het StochasticDualCoordinateAscentCoordinator algoritme. Als u het model wilt trainen met kruisvalidatie, gebruikt u de methode CrossValidate.

Notitie

Hoewel in dit voorbeeld een lineair regressiemodel wordt gebruikt, is CrossValidate van toepassing op alle andere machine learning-taken in ML.NET behalve anomaliedetectie.

// Define StochasticDualCoordinateAscent algorithm estimator
IEstimator<ITransformer> sdcaEstimator = mlContext.Regression.Trainers.Sdca();

// Apply 5-fold cross validation
var cvResults = mlContext.Regression.CrossValidate(transformedData, sdcaEstimator, numberOfFolds: 5);

CrossValidate voert de volgende bewerkingen uit:

  1. Partitioneert de gegevens in een aantal partities die gelijk zijn aan de waarde die is opgegeven in de parameter numberOfFolds. Het resultaat van elke partitie is een TrainTestData-object.
  2. Een model wordt getraind op elk van de partities met behulp van de opgegeven machine learning-algoritmeschatting voor de trainingsdataset.
  3. De prestaties van elk model worden geëvalueerd met behulp van de Evaluate methode in de testgegevensset.
  4. Het model wordt samen met de metrische gegevens geretourneerd voor elk van de modellen.

Het resultaat dat is opgeslagen in cvResults is een verzameling CrossValidationResult objecten. Dit object bevat het getrainde model en de metrics, die respectievelijk toegankelijk zijn via de eigenschappen Model en Metrics. In dit voorbeeld is de eigenschap Model van het type ITransformer en is de eigenschap Metrics van het type RegressionMetrics.

Het model evalueren

Metrische gegevens voor de verschillende getrainde modellen kunnen worden geopend via de eigenschap Metrics van het afzonderlijke CrossValidationResult-object. In dit geval wordt de R-Squared-metrische geopend en opgeslagen in de variabele rSquared.

IEnumerable<double> rSquared =
    cvResults
        .Select(fold => fold.Metrics.RSquared);

Als u de inhoud van de variabele rSquared inspecteert, zou de uitvoer vijf waarden moeten zijn tussen 0 en 1, waarbij dichter bij 1 beter is. Gebruik metrische gegevens zoals R-Squared om de modellen te rangschikken van beste naar slechtste prestaties. Selecteer vervolgens het bovenste model om voorspellingen te doen of extra bewerkingen uit te voeren.

// Select all models
ITransformer[] models =
    cvResults
        .OrderByDescending(fold => fold.Metrics.RSquared)
        .Select(fold => fold.Model)
        .ToArray();

// Get Top Model
ITransformer topModel = models[0];