Udostępnij za pośrednictwem


Trenowanie modelu uczenia maszynowego przy użyciu krzyżowej walidacji

Dowiedz się, jak używać krzyżowej walidacji do trenowania bardziej niezawodnych modeli uczenia maszynowego w ML.NET.

Krzyżowa walidacja to technika trenowania i oceny modelu, która dzieli dane na kilka partycji i trenuje wiele algorytmów na tych partycjach. Ta technika poprawia niezawodność modelu, przechowując dane z procesu trenowania. Oprócz poprawy wydajności na nieznanych obserwacjach, w środowiskach z ograniczoną ilością danych może stanowić skuteczne narzędzie do treningu modeli z mniejszym zestawem danych.

Dane i model danych

Podane dane z pliku, który ma następujący format:

Size (Sq. ft.), HistoricalPrice1 ($), HistoricalPrice2 ($), HistoricalPrice3 ($), Current Price ($)
620.00, 148330.32, 140913.81, 136686.39, 146105.37
550.00, 557033.46, 529181.78, 513306.33, 548677.95
1127.00, 479320.99, 455354.94, 441694.30, 472131.18
1120.00, 47504.98, 45129.73, 43775.84, 46792.41

Dane mogą być modelowane przez klasę, na przykład HousingData i ładowane do IDataView.

public class HousingData
{
    [LoadColumn(0)]
    public float Size { get; set; }

    [LoadColumn(1, 3)]
    [VectorType(3)]
    public float[] HistoricalPrices { get; set; }

    [LoadColumn(4)]
    [ColumnName("Label")]
    public float CurrentPrice { get; set; }
}

Przygotowywanie danych

Wstępne przetwarzanie danych przed użyciem ich do utworzenia modelu uczenia maszynowego. W tym przykładzie kolumny Size i HistoricalPrices są łączone w jeden wektor funkcji, który jest wynikiem nowej kolumny o nazwie Features przy użyciu metody Concatenate. Oprócz pobierania danych do formatu oczekiwanego przez algorytmy ML.NET łączenie kolumn optymalizuje kolejne operacje w potoku, stosując operację raz dla kolumny łączonej zamiast każdej z oddzielnych kolumn.

Po połączeniu kolumn w jednym wektorze NormalizeMinMax jest stosowany do kolumny Features w celu uzyskania Size i HistoricalPrices w tym samym zakresie od 0 do 1.

// Define data prep estimator
IEstimator<ITransformer> dataPrepEstimator =
    mlContext.Transforms.Concatenate("Features", new string[] { "Size", "HistoricalPrices" })
        .Append(mlContext.Transforms.NormalizeMinMax("Features"));

// Create data prep transformer
ITransformer dataPrepTransformer = dataPrepEstimator.Fit(data);

// Transform data
IDataView transformedData = dataPrepTransformer.Transform(data);

Trenowanie modelu za pomocą kroswalidacji

Gdy dane zostały wstępnie przetworzone, nadszedł czas, aby wytrenować model. Najpierw wybierz algorytm, który najlepiej pasuje do zadania uczenia maszynowego do wykonania. Ponieważ przewidywana wartość jest wartością liczbowo ciągłą, zadanie to regresja. Jednym z algorytmów regresji implementowanych przez ML.NET jest algorytm StochasticDualCoordinateAscentCoordinator. Aby wytrenować model za pomocą krzyżowego sprawdzania poprawności, użyj metody CrossValidate.

Notatka

Mimo że w tym przykładzie użyto modelu regresji liniowej, funkcja CrossValidate ma zastosowanie do wszystkich innych zadań uczenia maszynowego w ML.NET z wyjątkiem wykrywania anomalii.

// Define StochasticDualCoordinateAscent algorithm estimator
IEstimator<ITransformer> sdcaEstimator = mlContext.Regression.Trainers.Sdca();

// Apply 5-fold cross validation
var cvResults = mlContext.Regression.CrossValidate(transformedData, sdcaEstimator, numberOfFolds: 5);

CrossValidate wykonuje następujące operacje:

  1. Partycjonuje dane w kilka partycji równych wartości określonej w parametrze numberOfFolds. Wynikiem każdej partycji jest obiekt TrainTestData.
  2. Model jest trenowany na poszczególnych partycjach przy użyciu określonego narzędzia do szacowania algorytmu uczenia maszynowego w zestawie danych treningowych.
  3. Wydajność każdego modelu jest oceniana przy użyciu metody Evaluate w zestawie danych testowych.
  4. Model wraz z jego metrykami jest zwracany dla każdego z modeli.

Wynikiem przechowywanym w cvResults jest kolekcja obiektów CrossValidationResult. Ten obiekt zawiera wytrenowany model, a także metryki, które są dostępne odpowiednio w ramach właściwości Model i Metrics. W tym przykładzie właściwość Model jest typu ITransformer, a właściwość Metrics jest typu RegressionMetrics.

Ocena modelu

Dostęp do metryk dla różnych wytrenowanych modeli można uzyskać za pośrednictwem właściwości Metrics pojedynczego obiektu CrossValidationResult. W tym przypadku dostęp do metryki R-Squared jest przechowywany w zmiennej .

IEnumerable<double> rSquared =
    cvResults
        .Select(fold => fold.Metrics.RSquared);

Jeśli sprawdzisz zawartość zmiennej rSquared, dane wyjściowe powinny mieć pięć wartości z zakresu od 0 do 1, gdzie bliżej 1 oznacza najlepiej. Używając metryk, takich jak R-Squared, wybierz modele od najlepszych do najgorszych. Następnie wybierz górny model, aby przewidywać lub wykonywać dodatkowe operacje.

// Select all models
ITransformer[] models =
    cvResults
        .OrderByDescending(fold => fold.Metrics.RSquared)
        .Select(fold => fold.Model)
        .ToArray();

// Get Top Model
ITransformer topModel = models[0];