Sdílet prostřednictvím


Trénování modelu strojového učení pomocí křížového ověřování

Naučte se používat křížové ověřování k trénování robustnějších modelů strojového učení v ML.NET.

Křížové ověření je technika trénování a vyhodnocení modelu, která rozdělí data do několika oddílů a trénuje více algoritmů v těchto oddílech. Tato technika zlepšuje odolnost modelu tím, že z trénovacího procesu vylučuje data. Kromě zlepšení výkonu u nezoznaných pozorování může být v prostředích s omezenými daty efektivním nástrojem pro trénování modelů s menší datovou sadou.

Data a datový model

Data ze souboru, který má následující formát:

Size (Sq. ft.), HistoricalPrice1 ($), HistoricalPrice2 ($), HistoricalPrice3 ($), Current Price ($)
620.00, 148330.32, 140913.81, 136686.39, 146105.37
550.00, 557033.46, 529181.78, 513306.33, 548677.95
1127.00, 479320.99, 455354.94, 441694.30, 472131.18
1120.00, 47504.98, 45129.73, 43775.84, 46792.41

Data mohou být modelována třídou, jako je HousingData a načtena do IDataView.

public class HousingData
{
    [LoadColumn(0)]
    public float Size { get; set; }

    [LoadColumn(1, 3)]
    [VectorType(3)]
    public float[] HistoricalPrices { get; set; }

    [LoadColumn(4)]
    [ColumnName("Label")]
    public float CurrentPrice { get; set; }
}

Příprava dat

Před použitím dat před sestavením modelu strojového učení je předzpracujte. V této ukázce se sloupce Size a HistoricalPrices zkombinují do jednoho vektoru funkce, který je výstupem nového sloupce, který se nazývá Features pomocí metody Concatenate. Kromě převedení dat do formátu požadovaného algoritmy ML.NET, zřetězení sloupců optimalizuje následné operace v potrubí tím, že operaci použije pouze jednou pro zřetězený sloupec, místo jednotlivých samostatných sloupců.

Jakmile se sloupce zkombinují do jednoho vektoru, NormalizeMinMax se použije u sloupce Features, aby se Size a HistoricalPrices dostaly do stejného rozsahu mezi 0 a 1.

// Define data prep estimator
IEstimator<ITransformer> dataPrepEstimator =
    mlContext.Transforms.Concatenate("Features", new string[] { "Size", "HistoricalPrices" })
        .Append(mlContext.Transforms.NormalizeMinMax("Features"));

// Create data prep transformer
ITransformer dataPrepTransformer = dataPrepEstimator.Fit(data);

// Transform data
IDataView transformedData = dataPrepTransformer.Transform(data);

Trénujte model s křížovou validací

Po předběžném zpracování dat je čas model vytrénovat. Nejprve vyberte algoritmus, který je nejvíce v souladu s úlohou strojového učení, který se má provést. Vzhledem k tomu, že predikovaná hodnota je číselně souvislá hodnota, je úkol regresní. Jedním z regresních algoritmů implementovaných ML.NET je algoritmus StochasticDualCoordinateAscentCoordinator. K trénování modelu pomocí křížového ověření použijte metodu CrossValidate.

Poznámka

I když tato ukázka používá lineární regresní model, crossValidate se vztahuje na všechny ostatní úlohy strojového učení v ML.NET s výjimkou detekce anomálií.

// Define StochasticDualCoordinateAscent algorithm estimator
IEstimator<ITransformer> sdcaEstimator = mlContext.Regression.Trainers.Sdca();

// Apply 5-fold cross validation
var cvResults = mlContext.Regression.CrossValidate(transformedData, sdcaEstimator, numberOfFolds: 5);

CrossValidate provádí následující operace:

  1. Rozdělí data do několika oddílů, které se rovnají hodnotě zadané v parametru numberOfFolds. Výsledkem každého oddílu je objekt TrainTestData.
  2. Model se trénuje na každé části pomocí určeného odhadce algoritmu strojového učení v tréninkové datové sadě.
  3. Výkon každého modelu se vyhodnocuje pomocí metody Evaluate v testovací sadě dat.
  4. Model spolu s jeho metrikami se vrátí pro každý z těchto modelů.

Výsledek uložený v cvResults je kolekce CrossValidationResult objektů. Tento objekt zahrnuje natrénovaný model i metriky, které jsou přístupné z vlastností Model a Metrics. V této ukázce je vlastnost Model typu ITransformer a vlastnost Metrics je typu RegressionMetrics.

Vyhodnocení modelu

K metrikám pro různé natrénované modely je možné přistupovat prostřednictvím vlastnosti Metrics jednotlivého objektu CrossValidationResult. V tomto případě je metrika R-Squared přístupná a uložená v proměnné rSquared.

IEnumerable<double> rSquared =
    cvResults
        .Select(fold => fold.Metrics.RSquared);

Pokud zkontrolujete obsah proměnné rSquared, výstup by měl obsahovat pět hodnot v rozsahu od 0 do 1, kde hodnota blíže k 1 znamená lepší. Pomocí metrik, jako je R-Squared, vyberte modely od nejlepších po nejhorší výkon. Pak výběrem horního modelu proveďte předpovědi nebo proveďte další operace.

// Select all models
ITransformer[] models =
    cvResults
        .OrderByDescending(fold => fold.Metrics.RSquared)
        .Select(fold => fold.Model)
        .ToArray();

// Get Top Model
ITransformer topModel = models[0];