Обучение модели машинного обучения с помощью перекрестной проверки
Узнайте, как использовать перекрестную проверку для обучения более надежных моделей машинного обучения в ML.NET.
Перекрестная проверка — это метод обучения и оценки модели, который разбивает данные на несколько секций и обучает несколько алгоритмов на этих секциях. Этот метод повышает надежность модели, удерживая данные из процесса обучения. Помимо повышения производительности при невиденных наблюдениях, в ограниченных данными средах это может быть эффективным инструментом для обучения моделей с меньшим набором данных.
Данные и модель данных
Данные из файла, имеющего следующий формат:
Size (Sq. ft.), HistoricalPrice1 ($), HistoricalPrice2 ($), HistoricalPrice3 ($), Current Price ($)
620.00, 148330.32, 140913.81, 136686.39, 146105.37
550.00, 557033.46, 529181.78, 513306.33, 548677.95
1127.00, 479320.99, 455354.94, 441694.30, 472131.18
1120.00, 47504.98, 45129.73, 43775.84, 46792.41
Данные можно моделировать классом, например HousingData
, и загружать их в IDataView
.
public class HousingData
{
[LoadColumn(0)]
public float Size { get; set; }
[LoadColumn(1, 3)]
[VectorType(3)]
public float[] HistoricalPrices { get; set; }
[LoadColumn(4)]
[ColumnName("Label")]
public float CurrentPrice { get; set; }
}
Подготовка данных
Предварительно обработайте данные перед его использованием для создания модели машинного обучения. В этом примере столбцы Size
и HistoricalPrices
объединяются в один вектор признаков, который выводится в новый столбец с именем Features
с помощью метода Concatenate
. Помимо получения данных в формат, ожидаемый алгоритмами ML.NET, объединение столбцов оптимизирует последующие операции в конвейере, применяя операцию один раз для сцепленного столбца вместо каждого из отдельных столбцов.
После объединения столбцов в один вектор NormalizeMinMax
применяется к столбцу Features
, чтобы получить Size
и HistoricalPrices
в одном диапазоне от 0 до 1.
// Define data prep estimator
IEstimator<ITransformer> dataPrepEstimator =
mlContext.Transforms.Concatenate("Features", new string[] { "Size", "HistoricalPrices" })
.Append(mlContext.Transforms.NormalizeMinMax("Features"));
// Create data prep transformer
ITransformer dataPrepTransformer = dataPrepEstimator.Fit(data);
// Transform data
IDataView transformedData = dataPrepTransformer.Transform(data);
Обучайте модель с перекрестной проверкой
После предварительной обработки данных пришло время обучить модель. Сначала выберите алгоритм, который наиболее тесно соответствует выполняемой задаче машинного обучения. Поскольку прогнозируемое значение является числовым непрерывным значением, задача - регрессия. Одним из алгоритмов регрессии, реализованных ML.NET, является алгоритм StochasticDualCoordinateAscentCoordinator
. Чтобы обучить модель с перекрестной проверкой, используйте метод CrossValidate
.
Заметка
Хотя в этом примере используется модель линейной регрессии, CrossValidate применимо ко всем остальным задачам машинного обучения в ML.NET кроме обнаружения аномалий.
// Define StochasticDualCoordinateAscent algorithm estimator
IEstimator<ITransformer> sdcaEstimator = mlContext.Regression.Trainers.Sdca();
// Apply 5-fold cross validation
var cvResults = mlContext.Regression.CrossValidate(transformedData, sdcaEstimator, numberOfFolds: 5);
CrossValidate
выполняет следующие операции:
- Секционирует данные в ряд секций, равных значению, указанному в параметре
numberOfFolds
. Результатом каждого раздела является объектTrainTestData
. - Модель обучается для каждой части с помощью указанного оценочного алгоритма машинного обучения на обучающем наборе данных.
- Производительность каждой модели оценивается с помощью метода
Evaluate
в тестовом наборе данных. - Для каждой из моделей возвращаются сама модель и ее метрики.
Результатом, хранящимся в cvResults
, является коллекция объектов CrossValidationResult
. Этот объект включает в себя обученную модель, а также метрики, которые доступны через свойства Model
и Metrics
соответственно. В этом примере свойство Model
имеет тип ITransformer
, а свойство Metrics
имеет тип RegressionMetrics
.
Оценка модели
Метрики для различных обученных моделей можно получить через свойство Metrics
отдельного объекта CrossValidationResult
. В этом случае доступ к метрики
IEnumerable<double> rSquared =
cvResults
.Select(fold => fold.Metrics.RSquared);
Если вы проверяете содержимое переменной rSquared
, выходные данные должны иметь пять значений от 0 до 1, где ближе к 1 означает лучшее. Используя такие метрики, как R-Squared, выберите модели из лучших до худших показателей. Затем выберите верхнюю модель для прогнозирования или выполнения дополнительных операций.
// Select all models
ITransformer[] models =
cvResults
.OrderByDescending(fold => fold.Metrics.RSquared)
.Select(fold => fold.Model)
.ToArray();
// Get Top Model
ITransformer topModel = models[0];