Trenowanie modelu uczenia maszynowego przy użyciu krzyżowej walidacji
Dowiedz się, jak używać krzyżowej walidacji do trenowania bardziej niezawodnych modeli uczenia maszynowego w ML.NET.
Krzyżowa walidacja to technika trenowania i oceny modelu, która dzieli dane na kilka partycji i trenuje wiele algorytmów na tych partycjach. Ta technika poprawia niezawodność modelu, przechowując dane z procesu trenowania. Oprócz poprawy wydajności obserwacji, w środowiskach ograniczonych danymi może to być skuteczne narzędzie do trenowania modeli z mniejszym zestawem danych.
Model danych i danych
Podane dane z pliku, który ma następujący format:
Size (Sq. ft.), HistoricalPrice1 ($), HistoricalPrice2 ($), HistoricalPrice3 ($), Current Price ($)
620.00, 148330.32, 140913.81, 136686.39, 146105.37
550.00, 557033.46, 529181.78, 513306.33, 548677.95
1127.00, 479320.99, 455354.94, 441694.30, 472131.18
1120.00, 47504.98, 45129.73, 43775.84, 46792.41
Dane mogą być modelowane przez klasę, na HousingData
przykład i ładowane do klasy IDataView
.
public class HousingData
{
[LoadColumn(0)]
public float Size { get; set; }
[LoadColumn(1, 3)]
[VectorType(3)]
public float[] HistoricalPrices { get; set; }
[LoadColumn(4)]
[ColumnName("Label")]
public float CurrentPrice { get; set; }
}
Przygotowywanie danych
Wstępne przetwarzanie danych przed użyciem ich do utworzenia modelu uczenia maszynowego. W tym przykładzie Size
kolumny i HistoricalPrices
są łączone w jeden wektor funkcji, który jest wynikiem nowej kolumny o nazwie Features
przy użyciu Concatenate
metody . Oprócz pobierania danych do formatu oczekiwanego przez algorytmy ML.NET łączenie kolumn optymalizuje kolejne operacje w potoku, stosując operację raz dla kolumny łączonej zamiast każdej z oddzielnych kolumn.
Po połączeniu kolumn w jednym wektorze NormalizeMinMax
zostanie zastosowany do kolumny w celu uzyskania Size
i HistoricalPrices
w tym samym zakresie z zakresu od 0 do Features
1.
// Define data prep estimator
IEstimator<ITransformer> dataPrepEstimator =
mlContext.Transforms.Concatenate("Features", new string[] { "Size", "HistoricalPrices" })
.Append(mlContext.Transforms.NormalizeMinMax("Features"));
// Create data prep transformer
ITransformer dataPrepTransformer = dataPrepEstimator.Fit(data);
// Transform data
IDataView transformedData = dataPrepTransformer.Transform(data);
Trenowanie modelu za pomocą krzyżowego sprawdzania poprawności
Po wstępnym przetworzeniu danych nadszedł czas na trenowanie modelu. Najpierw wybierz algorytm, który najlepiej pasuje do zadania uczenia maszynowego do wykonania. Ponieważ przewidywana wartość jest wartością liczbowo ciągłą, zadanie to regresja. Jednym z algorytmów regresji implementowanych przez ML.NET jest StochasticDualCoordinateAscentCoordinator
algorytm. Aby wytrenować model za pomocą krzyżowej walidacji, użyj CrossValidate
metody .
Uwaga
Mimo że w tym przykładzie użyto modelu regresji liniowej, funkcja CrossValidate ma zastosowanie do wszystkich innych zadań uczenia maszynowego w ML.NET z wyjątkiem wykrywania anomalii.
// Define StochasticDualCoordinateAscent algorithm estimator
IEstimator<ITransformer> sdcaEstimator = mlContext.Regression.Trainers.Sdca();
// Apply 5-fold cross validation
var cvResults = mlContext.Regression.CrossValidate(transformedData, sdcaEstimator, numberOfFolds: 5);
CrossValidate
wykonuje następujące operacje:
- Partycjonuje dane na liczbę partycji równych wartości określonej w parametrze
numberOfFolds
. Wynikiem każdej partycji jestTrainTestData
obiekt. - Model jest trenowany na poszczególnych partycjach przy użyciu określonego narzędzia do szacowania algorytmu uczenia maszynowego w zestawie danych treningowych.
- Wydajność każdego modelu jest oceniana przy użyciu
Evaluate
metody w zestawie danych testowych. - Model wraz z jego metrykami jest zwracany dla każdego z modeli.
Wynikiem przechowywanym w cvResults
pliku jest kolekcja CrossValidationResult
obiektów. Ten obiekt zawiera wytrenowany model, a także metryki, które są odpowiednio dostępne dla Model
właściwości i Metrics
. W tym przykładzie Model
właściwość jest typu ITransformer
, a Metrics
właściwość jest typu RegressionMetrics
.
Ocenianie modelu
Metryki dla różnych wytrenowanych modeli można uzyskać za pośrednictwem Metrics
właściwości pojedynczego CrossValidationResult
obiektu. W takim przypadku metryka R-Squared jest uzyskiwana i przechowywana w zmiennej rSquared
.
IEnumerable<double> rSquared =
cvResults
.Select(fold => fold.Metrics.RSquared);
Jeśli sprawdzisz zawartość zmiennej rSquared
, dane wyjściowe powinny mieć pięć wartości z zakresu od 0 do 1, gdzie bliżej 1 oznacza, że najlepiej. Używając metryk, takich jak R-Squared, wybierz modele od najlepszych do najgorszych. Następnie wybierz górny model, aby przewidywać lub wykonywać dodatkowe operacje.
// Select all models
ITransformer[] models =
cvResults
.OrderByDescending(fold => fold.Metrics.RSquared)
.Select(fold => fold.Model)
.ToArray();
// Get Top Model
ITransformer topModel = models[0];