Sdílet prostřednictvím


Přetrénování modelu

Naučte se přetrénovat model strojového učení v ML.NET.

Svět a její data se neustále mění. Modely se proto musí také měnit a aktualizovat. ML.NET poskytuje funkce pro přetrénování modelů, které používají naučené parametry modelu jako výchozí bod pro neustálé sestavování na předchozích zkušenostech, a ne vždy od začátku.

V ML.NET se dají přetrénovat následující algoritmy:

Načtení předem natrénovaného modelu

Nejprve načtěte předem natrénovaný model do aplikace. Další informace o načítání trénovacích kanálů a modelů najdete v tématu Uložení a načtení natrénovaného modelu.

// Create MLContext
MLContext mlContext = new MLContext();

// Define DataViewSchema of data prep pipeline and trained model
DataViewSchema dataPrepPipelineSchema, modelSchema;

// Load data preparation pipeline
ITransformer dataPrepPipeline = mlContext.Model.Load("data_preparation_pipeline.zip", out dataPrepPipelineSchema);

// Load trained model
ITransformer trainedModel = mlContext.Model.Load("ogd_model.zip", out modelSchema);

Extrahování předtrénovaných parametrů modelu

Po načtení modelu extrahujte naučené parametry modelu přístupem k atributu Model předtrénovaného modelu. Předtrénovaný model byl natrénován pomocí lineárního regresního modelu OnlineGradientDescentTrainer, který vytvoří RegressionPredictionTransformer, který vypíše LinearRegressionModelParameters. Tyto parametry modelu obsahují naučené předsudky a váhy nebo koeficienty modelu. Tyto hodnoty se používají jako výchozí bod nového přetrénovaného modelu.

// Extract trained model parameters
LinearRegressionModelParameters originalModelParameters =
    ((ISingleFeaturePredictionTransformer<object>)trainedModel).Model as LinearRegressionModelParameters;

Poznámka

Výstup parametrů modelu závisí na použitém algoritmu. Například OnlineGradientDescentTrainer používá LinearRegressionModelParameters, zatímco LbfgsMaximumEntropyMulticlassTrainer výstupy MaximumEntropyModelParameters. Při extrahování parametrů modelu přetypujte na příslušný typ.

Přetrénování modelu

Proces opětovného trénování modelu se neliší od trénování modelu. Jediným rozdílem je, že předáte další argument metodě Fit(IDataView, LinearModelParameters): původní naučené parametry modelu. Fit() je používá jako výchozí bod procesu opětovného vytrénování.

// New Data
HousingData[] housingData = new HousingData[]
{
    new HousingData
    {
        Size = 850f,
        HistoricalPrices = new float[] { 150000f,175000f,210000f },
        CurrentPrice = 205000f
    },
    new HousingData
    {
        Size = 900f,
        HistoricalPrices = new float[] { 155000f, 190000f, 220000f },
        CurrentPrice = 210000f
    },
    new HousingData
    {
        Size = 550f,
        HistoricalPrices = new float[] { 99000f, 98000f, 130000f },
        CurrentPrice = 180000f
    }
};

//Load New Data
IDataView newData = mlContext.Data.LoadFromEnumerable<HousingData>(housingData);

// Preprocess Data
IDataView transformedNewData = dataPrepPipeline.Transform(newData);

// Retrain model
RegressionPredictionTransformer<LinearRegressionModelParameters> retrainedModel =
    mlContext.Regression.Trainers.OnlineGradientDescent()
        .Fit(transformedNewData, originalModelParameters);

V tomto okamžiku můžete znovu natrénovaný model uložit a použít jej ve své aplikaci. Další informace najdete v tématu Uložení a načtení natrénovaného modelu a Vytváření předpovědí pomocí vytrénovaného modelu.

Porovnání parametrů modelu

Jak poznáte, jestli k přetrénování skutečně došlo? Jedním ze způsobů je porovnat, jestli se parametry přetrénovaného modelu liší od parametrů původního modelu. Následující ukázka kódu porovná původní hodnoty s přetrénovanými hmotnostmi modelu a vypíše je do konzoly.

// Extract Model Parameters of re-trained model
LinearRegressionModelParameters retrainedModelParameters = retrainedModel.Model as LinearRegressionModelParameters;

// Inspect Change in Weights
var weightDiffs =
    originalModelParameters.Weights.Zip(
        retrainedModelParameters.Weights, (original, retrained) => original - retrained).ToArray();

Console.WriteLine("Original | Retrained | Difference");
for(int i=0;i < weightDiffs.Count();i++)
{
    Console.WriteLine($"{originalModelParameters.Weights[i]} | {retrainedModelParameters.Weights[i]} | {weightDiffs[i]}");
}

Následující tabulka ukazuje, jak může výstup vypadat.

Původní Přeškoleno Rozdíl
33039.86 56293.76 -23253.9
29099.14 49586.03 -20486.89
28 938,38 48609.23 -19670.85
30484.02 53745.43 -23261.41