Trenowanie modelu uczenia maszynowego przy użyciu krzyżowej walidacji
Dowiedz się, jak używać krzyżowej walidacji do trenowania bardziej niezawodnych modeli uczenia maszynowego w ML.NET.
Krzyżowa walidacja to technika trenowania i oceny modelu, która dzieli dane na kilka partycji i trenuje wiele algorytmów na tych partycjach. Ta technika poprawia niezawodność modelu, przechowując dane z procesu trenowania. Oprócz poprawy wydajności na nieznanych obserwacjach, w środowiskach z ograniczoną ilością danych może stanowić skuteczne narzędzie do treningu modeli z mniejszym zestawem danych.
Dane i model danych
Podane dane z pliku, który ma następujący format:
Size (Sq. ft.), HistoricalPrice1 ($), HistoricalPrice2 ($), HistoricalPrice3 ($), Current Price ($)
620.00, 148330.32, 140913.81, 136686.39, 146105.37
550.00, 557033.46, 529181.78, 513306.33, 548677.95
1127.00, 479320.99, 455354.94, 441694.30, 472131.18
1120.00, 47504.98, 45129.73, 43775.84, 46792.41
Dane mogą być modelowane przez klasę, na przykład HousingData
i ładowane do IDataView
.
public class HousingData
{
[LoadColumn(0)]
public float Size { get; set; }
[LoadColumn(1, 3)]
[VectorType(3)]
public float[] HistoricalPrices { get; set; }
[LoadColumn(4)]
[ColumnName("Label")]
public float CurrentPrice { get; set; }
}
Przygotowywanie danych
Wstępne przetwarzanie danych przed użyciem ich do utworzenia modelu uczenia maszynowego. W tym przykładzie kolumny Size
i HistoricalPrices
są łączone w jeden wektor funkcji, który jest wynikiem nowej kolumny o nazwie Features
przy użyciu metody Concatenate
. Oprócz pobierania danych do formatu oczekiwanego przez algorytmy ML.NET łączenie kolumn optymalizuje kolejne operacje w potoku, stosując operację raz dla kolumny łączonej zamiast każdej z oddzielnych kolumn.
Po połączeniu kolumn w jednym wektorze NormalizeMinMax
jest stosowany do kolumny Features
w celu uzyskania Size
i HistoricalPrices
w tym samym zakresie od 0 do 1.
// Define data prep estimator
IEstimator<ITransformer> dataPrepEstimator =
mlContext.Transforms.Concatenate("Features", new string[] { "Size", "HistoricalPrices" })
.Append(mlContext.Transforms.NormalizeMinMax("Features"));
// Create data prep transformer
ITransformer dataPrepTransformer = dataPrepEstimator.Fit(data);
// Transform data
IDataView transformedData = dataPrepTransformer.Transform(data);
Trenowanie modelu za pomocą kroswalidacji
Gdy dane zostały wstępnie przetworzone, nadszedł czas, aby wytrenować model. Najpierw wybierz algorytm, który najlepiej pasuje do zadania uczenia maszynowego do wykonania. Ponieważ przewidywana wartość jest wartością liczbowo ciągłą, zadanie to regresja. Jednym z algorytmów regresji implementowanych przez ML.NET jest algorytm StochasticDualCoordinateAscentCoordinator
. Aby wytrenować model za pomocą krzyżowego sprawdzania poprawności, użyj metody CrossValidate
.
Notatka
Mimo że w tym przykładzie użyto modelu regresji liniowej, funkcja CrossValidate ma zastosowanie do wszystkich innych zadań uczenia maszynowego w ML.NET z wyjątkiem wykrywania anomalii.
// Define StochasticDualCoordinateAscent algorithm estimator
IEstimator<ITransformer> sdcaEstimator = mlContext.Regression.Trainers.Sdca();
// Apply 5-fold cross validation
var cvResults = mlContext.Regression.CrossValidate(transformedData, sdcaEstimator, numberOfFolds: 5);
CrossValidate
wykonuje następujące operacje:
- Partycjonuje dane w kilka partycji równych wartości określonej w parametrze
numberOfFolds
. Wynikiem każdej partycji jest obiektTrainTestData
. - Model jest trenowany na poszczególnych partycjach przy użyciu określonego narzędzia do szacowania algorytmu uczenia maszynowego w zestawie danych treningowych.
- Wydajność każdego modelu jest oceniana przy użyciu metody
Evaluate
w zestawie danych testowych. - Model wraz z jego metrykami jest zwracany dla każdego z modeli.
Wynikiem przechowywanym w cvResults
jest kolekcja obiektów CrossValidationResult
. Ten obiekt zawiera wytrenowany model, a także metryki, które są dostępne odpowiednio w ramach właściwości Model
i Metrics
. W tym przykładzie właściwość Model
jest typu ITransformer
, a właściwość Metrics
jest typu RegressionMetrics
.
Ocena modelu
Dostęp do metryk dla różnych wytrenowanych modeli można uzyskać za pośrednictwem właściwości Metrics
pojedynczego obiektu CrossValidationResult
. W tym przypadku dostęp do metryki
IEnumerable<double> rSquared =
cvResults
.Select(fold => fold.Metrics.RSquared);
Jeśli sprawdzisz zawartość zmiennej rSquared
, dane wyjściowe powinny mieć pięć wartości z zakresu od 0 do 1, gdzie bliżej 1 oznacza najlepiej. Używając metryk, takich jak R-Squared, wybierz modele od najlepszych do najgorszych. Następnie wybierz górny model, aby przewidywać lub wykonywać dodatkowe operacje.
// Select all models
ITransformer[] models =
cvResults
.OrderByDescending(fold => fold.Metrics.RSquared)
.Select(fold => fold.Model)
.ToArray();
// Get Top Model
ITransformer topModel = models[0];