Een machine learning-model trainen met kruisvalidatie
Meer informatie over het gebruik van kruisvalidatie voor het trainen van robuustere machine learning-modellen in ML.NET.
Kruisvalidatie is een trainings- en modelevaluatietechniek waarmee de gegevens worden gesplitst in verschillende partities en meerdere algoritmen worden getraind op deze partities. Deze techniek verbetert de robuustheid van het model door gegevens uit het trainingsproces vast te houden. Naast het verbeteren van de prestaties van ongeziene waarnemingen, kan het in gegevensbeperkingsomgevingen een effectief hulpmiddel zijn voor het trainen van modellen met een kleinere gegevensset.
Het gegevens- en gegevensmodel
Gegeven gegevens uit een bestand met de volgende indeling:
Size (Sq. ft.), HistoricalPrice1 ($), HistoricalPrice2 ($), HistoricalPrice3 ($), Current Price ($)
620.00, 148330.32, 140913.81, 136686.39, 146105.37
550.00, 557033.46, 529181.78, 513306.33, 548677.95
1127.00, 479320.99, 455354.94, 441694.30, 472131.18
1120.00, 47504.98, 45129.73, 43775.84, 46792.41
De gegevens kunnen worden gemodelleerd door een klasse zoals HousingData
en in een IDataView
geladen.
public class HousingData
{
[LoadColumn(0)]
public float Size { get; set; }
[LoadColumn(1, 3)]
[VectorType(3)]
public float[] HistoricalPrices { get; set; }
[LoadColumn(4)]
[ColumnName("Label")]
public float CurrentPrice { get; set; }
}
De gegevens voorbereiden
De gegevens vooraf verwerken voordat u deze gebruikt om het machine learning-model te bouwen. In dit voorbeeld worden de kolommen Size
en HistoricalPrices
gecombineerd tot één functievector. Dit is uitvoer naar een nieuwe kolom met de naam Features
met behulp van de methode Concatenate
. Naast het ophalen van de gegevens in de indeling die wordt verwacht door ML.NET algoritmen, optimaliseert het samenvoegen van kolommen de volgende bewerkingen in de pijplijn door de bewerking eenmaal toe te passen voor de samengevoegde kolom in plaats van elk van de afzonderlijke kolommen.
Zodra de kolommen in één vector zijn gecombineerd, wordt NormalizeMinMax
toegepast op de Features
kolom om Size
en HistoricalPrices
in hetzelfde bereik tussen 0 en 1 op te halen.
// Define data prep estimator
IEstimator<ITransformer> dataPrepEstimator =
mlContext.Transforms.Concatenate("Features", new string[] { "Size", "HistoricalPrices" })
.Append(mlContext.Transforms.NormalizeMinMax("Features"));
// Create data prep transformer
ITransformer dataPrepTransformer = dataPrepEstimator.Fit(data);
// Transform data
IDataView transformedData = dataPrepTransformer.Transform(data);
Model trainen met kruisvalidatie
Zodra de gegevens vooraf zijn verwerkt, is het tijd om het model te trainen. Selecteer eerst het algoritme dat het meest overeenkomt met de machine learning-taak die moet worden uitgevoerd. Omdat de voorspelde waarde een numerieke doorlopende waarde is, is de taak regressie. Een van de regressiealgoritmen die door ML.NET zijn geïmplementeerd, is het StochasticDualCoordinateAscentCoordinator
algoritme. Als u het model wilt trainen met kruisvalidatie, gebruikt u de methode CrossValidate
.
Notitie
Hoewel in dit voorbeeld een lineair regressiemodel wordt gebruikt, is CrossValidate van toepassing op alle andere machine learning-taken in ML.NET behalve anomaliedetectie.
// Define StochasticDualCoordinateAscent algorithm estimator
IEstimator<ITransformer> sdcaEstimator = mlContext.Regression.Trainers.Sdca();
// Apply 5-fold cross validation
var cvResults = mlContext.Regression.CrossValidate(transformedData, sdcaEstimator, numberOfFolds: 5);
CrossValidate
voert de volgende bewerkingen uit:
- Partitioneert de gegevens in een aantal partities die gelijk zijn aan de waarde die is opgegeven in de parameter
numberOfFolds
. Het resultaat van elke partitie is eenTrainTestData
-object. - Een model wordt getraind op elk van de partities met behulp van de opgegeven machine learning-algoritmeschatting voor de trainingsdataset.
- De prestaties van elk model worden geëvalueerd met behulp van de
Evaluate
methode in de testgegevensset. - Het model wordt samen met de metrische gegevens geretourneerd voor elk van de modellen.
Het resultaat dat is opgeslagen in cvResults
is een verzameling CrossValidationResult
objecten. Dit object bevat het getrainde model en de metrics, die respectievelijk toegankelijk zijn via de eigenschappen Model
en Metrics
. In dit voorbeeld is de eigenschap Model
van het type ITransformer
en is de eigenschap Metrics
van het type RegressionMetrics
.
Het model evalueren
Metrische gegevens voor de verschillende getrainde modellen kunnen worden geopend via de eigenschap Metrics
van het afzonderlijke CrossValidationResult
-object. In dit geval wordt de R-Squared-metrische geopend en opgeslagen in de variabele rSquared
.
IEnumerable<double> rSquared =
cvResults
.Select(fold => fold.Metrics.RSquared);
Als u de inhoud van de variabele rSquared
inspecteert, zou de uitvoer vijf waarden moeten zijn tussen 0 en 1, waarbij dichter bij 1 beter is. Gebruik metrische gegevens zoals R-Squared om de modellen te rangschikken van beste naar slechtste prestaties. Selecteer vervolgens het bovenste model om voorspellingen te doen of extra bewerkingen uit te voeren.
// Select all models
ITransformer[] models =
cvResults
.OrderByDescending(fold => fold.Metrics.RSquared)
.Select(fold => fold.Model)
.ToArray();
// Get Top Model
ITransformer topModel = models[0];