Train a machine learning model using cross validation
Learn how to use cross validation to train more robust machine learning models in ML.NET.
Cross-validation is a training and model evaluation technique that splits the data into several partitions and trains multiple models, each on a different combination of those partitions. Holding out part of the data from each training run makes the evaluation more robust, because every model is scored on observations it did not see during training. In addition to improving performance on unseen observations, cross-validation can be an effective tool for training models with a smaller dataset in data-constrained environments.
The data and data model
Given data from a file that has the following format:
Size (Sq. ft.), HistoricalPrice1 ($), HistoricalPrice2 ($), HistoricalPrice3 ($), Current Price ($)
620.00, 148330.32, 140913.81, 136686.39, 146105.37
550.00, 557033.46, 529181.78, 513306.33, 548677.95
1127.00, 479320.99, 455354.94, 441694.30, 472131.18
1120.00, 47504.98, 45129.73, 43775.84, 46792.41
The data can be modeled by a class like HousingData and loaded into an IDataView.
public class HousingData
{
[LoadColumn(0)]
public float Size { get; set; }
[LoadColumn(1, 3)]
[VectorType(3)]
public float[] HistoricalPrices { get; set; }
[LoadColumn(4)]
[ColumnName("Label")]
public float CurrentPrice { get; set; }
}
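As a minimal sketch of the loading step, the following assumes the data sits in a comma-separated text file with one header row. The HousingDataLoader helper and the file name in the usage note are hypothetical and only illustrate the idea; the HousingData class is repeated so the block stands on its own.

```csharp
using Microsoft.ML;
using Microsoft.ML.Data;

// Same shape as the HousingData class above, repeated for self-containment
public class HousingData
{
    [LoadColumn(0)]
    public float Size { get; set; }

    [LoadColumn(1, 3)]
    [VectorType(3)]
    public float[] HistoricalPrices { get; set; }

    [LoadColumn(4)]
    [ColumnName("Label")]
    public float CurrentPrice { get; set; }
}

// Hypothetical helper, not part of ML.NET
public static class HousingDataLoader
{
    // Reads a comma-separated file whose first row is a header (assumption)
    public static IDataView Load(MLContext mlContext, string path) =>
        mlContext.Data.LoadFromTextFile<HousingData>(
            path,
            separatorChar: ',',
            hasHeader: true);
}
```

With an MLContext named mlContext, a call such as HousingDataLoader.Load(mlContext, "housing-data.csv") would produce the data variable used in the following sections; the file name is a placeholder.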
Prepare the data
Pre-process the data before using it to build the machine learning model. In this sample, the Size and HistoricalPrices columns are combined into a single feature vector, which is output to a new column called Features using the Concatenate method. In addition to getting the data into the format expected by ML.NET algorithms, concatenating columns optimizes subsequent operations in the pipeline by applying the operation once for the concatenated column instead of each of the separate columns.
Once the columns are combined into a single vector, NormalizeMinMax is applied to the Features column to scale Size and HistoricalPrices into the same range, between 0 and 1.
// Define data prep estimator
IEstimator<ITransformer> dataPrepEstimator =
mlContext.Transforms.Concatenate("Features", new string[] { "Size", "HistoricalPrices" })
.Append(mlContext.Transforms.NormalizeMinMax("Features"));
// Create data prep transformer
ITransformer dataPrepTransformer = dataPrepEstimator.Fit(data);
// Transform data
IDataView transformedData = dataPrepTransformer.Transform(data);
Train model with cross validation
Once the data has been preprocessed, it's time to train the model. First, select the algorithm that most closely aligns with the machine learning task to be performed. Because the predicted value is numerically continuous, the task is regression. One of the regression algorithms implemented by ML.NET is the Stochastic Dual Coordinate Ascent (SDCA) algorithm. To train the model with cross-validation, use the CrossValidate method.
Note
Although this sample uses a linear regression model, CrossValidate is applicable to all other machine learning tasks in ML.NET except Anomaly Detection.
// Define StochasticDualCoordinateAscent algorithm estimator
IEstimator<ITransformer> sdcaEstimator = mlContext.Regression.Trainers.Sdca();
// Apply 5-fold cross validation
var cvResults = mlContext.Regression.CrossValidate(transformedData, sdcaEstimator, numberOfFolds: 5);
CrossValidate performs the following operations:
- Partitions the data into a number of partitions equal to the value specified in the numberOfFolds parameter. The result of each partition is a TrainTestData object.
- Trains a model on each of the partitions, using the specified machine learning algorithm estimator on the training data set.
- Evaluates each model's performance using the Evaluate method on the test data set.
- Returns each model along with its metrics.
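The partitioning step above can be sketched without ML.NET. The following is a simplified illustration, not the library's implementation: it assigns row indices to folds round-robin, whereas ML.NET's actual assignment of rows to folds may differ (for example, it may be randomized). KFoldSketch and Split are hypothetical names.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical illustration of k-fold partitioning over row indices
public static class KFoldSketch
{
    // Returns one (Train, Test) pair of row indices per fold.
    // Row i is held out in fold (i % numberOfFolds) and used for training in every other fold.
    public static List<(int[] Train, int[] Test)> Split(int rowCount, int numberOfFolds)
    {
        if (numberOfFolds < 2 || numberOfFolds > rowCount)
            throw new ArgumentOutOfRangeException(nameof(numberOfFolds));

        var folds = new List<(int[] Train, int[] Test)>();
        for (int fold = 0; fold < numberOfFolds; fold++)
        {
            int[] test = Enumerable.Range(0, rowCount)
                .Where(i => i % numberOfFolds == fold)
                .ToArray();
            int[] train = Enumerable.Range(0, rowCount)
                .Where(i => i % numberOfFolds != fold)
                .ToArray();
            folds.Add((train, test));
        }
        return folds;
    }
}
```

Each row lands in exactly one test set, so every observation is used for evaluation once and for training in the remaining folds.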
The result stored in cvResults is a collection of CrossValidationResult objects. Each object includes the trained model and its metrics, which are accessible from the Model and Metrics properties respectively. In this sample, the Model property is of type ITransformer and the Metrics property is of type RegressionMetrics.
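One way to look at what comes back is to format a line per fold. The sketch below assumes the regression case from this sample, where each result carries RegressionMetrics; FoldReport and Describe are hypothetical names, and the choice of metrics to print is only an example.

```csharp
using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.Data;

// Hypothetical helper, not part of ML.NET
public static class FoldReport
{
    // Builds one human-readable line per cross-validation fold
    public static List<string> Describe(
        IReadOnlyList<TrainCatalogBase.CrossValidationResult<RegressionMetrics>> cvResults)
    {
        var lines = new List<string>();
        foreach (var result in cvResults)
        {
            // Fold is the index of the fold; Metrics holds the evaluation on its held-out data
            lines.Add(
                $"Fold {result.Fold}: " +
                $"R-Squared = {result.Metrics.RSquared:F4}, " +
                $"RMSE = {result.Metrics.RootMeanSquaredError:F2}");
        }
        return lines;
    }
}
```

Passing cvResults from the snippet above to FoldReport.Describe yields five lines, one for each fold.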
Evaluate the model
Metrics for the different trained models can be accessed through the Metrics property of the individual CrossValidationResult object. In this case, the R-Squared metric is accessed and stored in the variable rSquared.
IEnumerable<double> rSquared =
cvResults
.Select(fold => fold.Metrics.RSquared);
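If you inspect the contents of the rSquared variable, the output should be five values, one per fold. R-Squared values typically fall between 0 and 1, where values closer to 1 indicate a better fit; a negative value is possible when a model fits worse than always predicting the mean. Using metrics like R-Squared, order the models from best to worst performing. Then, select the top model to make predictions or perform additional operations with.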
// Order all models from best to worst R-Squared
ITransformer[] models =
cvResults
.OrderByDescending(fold => fold.Metrics.RSquared)
.Select(fold => fold.Model)
.ToArray();
// Get Top Model
ITransformer topModel = models[0];
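Because a single fold's score depends on which rows happened to be held out, it can help to look at the mean and spread of R-Squared across folds before relying on the top model. The following is a minimal sketch of that summary; FoldSummary and Summarize are hypothetical names, and the population standard deviation is used here as one reasonable choice.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical helper, not part of ML.NET
public static class FoldSummary
{
    // Returns the mean and the population standard deviation of the per-fold scores
    public static (double Mean, double StdDev) Summarize(IEnumerable<double> scores)
    {
        double[] values = scores.ToArray();
        if (values.Length == 0)
            throw new ArgumentException("At least one score is required.", nameof(scores));

        double mean = values.Average();
        double variance = values.Sum(v => (v - mean) * (v - mean)) / values.Length;
        return (mean, Math.Sqrt(variance));
    }
}
```

For example, FoldSummary.Summarize(rSquared) returns both values for the five folds above. A large standard deviation relative to the mean suggests that the model's quality is sensitive to how the data was split.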