Trénování modelu strojového učení pomocí křížového ověřování
Naučte se používat křížové ověřování k trénování robustnějších modelů strojového učení v ML.NET.
Křížové ověření je technika trénování a vyhodnocení modelu, která rozdělí data do několika oddílů a trénuje více algoritmů v těchto oddílech. Tato technika zlepšuje odolnost modelu tím, že z trénovacího procesu vylučuje data. Kromě zlepšení výkonu u nezoznaných pozorování může být v prostředích s omezenými daty efektivním nástrojem pro trénování modelů s menší datovou sadou.
Data a datový model
Data ze souboru, který má následující formát:
Size (Sq. ft.), HistoricalPrice1 ($), HistoricalPrice2 ($), HistoricalPrice3 ($), Current Price ($)
620.00, 148330.32, 140913.81, 136686.39, 146105.37
550.00, 557033.46, 529181.78, 513306.33, 548677.95
1127.00, 479320.99, 455354.94, 441694.30, 472131.18
1120.00, 47504.98, 45129.73, 43775.84, 46792.41
Data mohou být modelována třídou, jako je HousingData
a načtena do IDataView
.
public class HousingData
{
[LoadColumn(0)]
public float Size { get; set; }
[LoadColumn(1, 3)]
[VectorType(3)]
public float[] HistoricalPrices { get; set; }
[LoadColumn(4)]
[ColumnName("Label")]
public float CurrentPrice { get; set; }
}
Příprava dat
Před použitím dat před sestavením modelu strojového učení je předzpracujte. V této ukázce se sloupce Size
a HistoricalPrices
zkombinují do jednoho vektoru funkce, který je výstupem nového sloupce, který se nazývá Features
pomocí metody Concatenate
. Kromě převedení dat do formátu požadovaného algoritmy ML.NET, zřetězení sloupců optimalizuje následné operace v potrubí tím, že operaci použije pouze jednou pro zřetězený sloupec, místo jednotlivých samostatných sloupců.
Jakmile se sloupce zkombinují do jednoho vektoru, NormalizeMinMax
se použije u sloupce Features
, aby se Size
a HistoricalPrices
dostaly do stejného rozsahu mezi 0 a 1.
// Define data prep estimator
IEstimator<ITransformer> dataPrepEstimator =
mlContext.Transforms.Concatenate("Features", new string[] { "Size", "HistoricalPrices" })
.Append(mlContext.Transforms.NormalizeMinMax("Features"));
// Create data prep transformer
ITransformer dataPrepTransformer = dataPrepEstimator.Fit(data);
// Transform data
IDataView transformedData = dataPrepTransformer.Transform(data);
Trénujte model s křížovou validací
Po předběžném zpracování dat je čas model vytrénovat. Nejprve vyberte algoritmus, který je nejvíce v souladu s úlohou strojového učení, který se má provést. Vzhledem k tomu, že predikovaná hodnota je číselně souvislá hodnota, je úkol regresní. Jedním z regresních algoritmů implementovaných ML.NET je algoritmus StochasticDualCoordinateAscentCoordinator
. K trénování modelu pomocí křížového ověření použijte metodu CrossValidate
.
Poznámka
I když tato ukázka používá lineární regresní model, crossValidate se vztahuje na všechny ostatní úlohy strojového učení v ML.NET s výjimkou detekce anomálií.
// Define StochasticDualCoordinateAscent algorithm estimator
IEstimator<ITransformer> sdcaEstimator = mlContext.Regression.Trainers.Sdca();
// Apply 5-fold cross validation
var cvResults = mlContext.Regression.CrossValidate(transformedData, sdcaEstimator, numberOfFolds: 5);
CrossValidate
provádí následující operace:
- Rozdělí data do několika oddílů, které se rovnají hodnotě zadané v parametru
numberOfFolds
. Výsledkem každého oddílu je objektTrainTestData
. - Model se trénuje na každé části pomocí určeného odhadce algoritmu strojového učení v tréninkové datové sadě.
- Výkon každého modelu se vyhodnocuje pomocí metody
Evaluate
v testovací sadě dat. - Model spolu s jeho metrikami se vrátí pro každý z těchto modelů.
Výsledek uložený v cvResults
je kolekce CrossValidationResult
objektů. Tento objekt zahrnuje natrénovaný model i metriky, které jsou přístupné z vlastností Model
a Metrics
. V této ukázce je vlastnost Model
typu ITransformer
a vlastnost Metrics
je typu RegressionMetrics
.
Vyhodnocení modelu
K metrikám pro různé natrénované modely je možné přistupovat prostřednictvím vlastnosti Metrics
jednotlivého objektu CrossValidationResult
. V tomto případě je metrika R-Squared přístupná a uložená v proměnné rSquared
.
IEnumerable<double> rSquared =
cvResults
.Select(fold => fold.Metrics.RSquared);
Pokud zkontrolujete obsah proměnné rSquared
, výstup by měl obsahovat pět hodnot v rozsahu od 0 do 1, kde hodnota blíže k 1 znamená lepší. Pomocí metrik, jako je R-Squared, vyberte modely od nejlepších po nejhorší výkon. Pak výběrem horního modelu proveďte předpovědi nebo proveďte další operace.
// Select all models
ITransformer[] models =
cvResults
.OrderByDescending(fold => fold.Metrics.RSquared)
.Select(fold => fold.Model)
.ToArray();
// Get Top Model
ITransformer topModel = models[0];