Interpret model predictions using Permutation Feature Importance

Learn how to interpret ML.NET machine learning model predictions using Permutation Feature Importance (PFI). PFI gives the relative contribution each feature makes to a prediction.

Machine learning models are often thought of as opaque boxes that take inputs and generate an output. The intermediate steps or interactions among the features that influence the output are rarely understood. As machine learning is introduced into more aspects of everyday life, such as healthcare, it's of utmost importance to understand why a machine learning model makes the decisions it does. For example, if diagnoses are made by a machine learning model, healthcare professionals need a way to look into the factors that went into making that diagnosis. Providing the right diagnosis could make a great difference in whether a patient has a speedy recovery. Therefore, the more explainable a model is, the more confidently healthcare professionals can accept or reject the decisions it makes.

Various techniques are used to explain models, one of which is PFI. PFI is a technique for explaining classification and regression models, inspired by Breiman's Random Forests paper (see section 10). At a high level, it works by randomly shuffling the values of one feature at a time across the entire dataset and calculating how much the performance metric of interest decreases. The larger the change, the more important that feature is.
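
The shuffle-and-rescore idea can be sketched without ML.NET. The following is a minimal illustration under stated assumptions, not the ML.NET implementation: the toy data, the hard-coded `predict` function standing in for a trained model, and the helper names `RSquared`, `Shuffled`, and `PermutationImportance` are all hypothetical. It reports the mean change in R-squared (permuted score minus baseline) for each feature.

```csharp
using System;
using System.Linq;

// Toy data (assumption): the label depends strongly on feature 0,
// weakly on feature 1, and not at all on feature 2.
var rng = new Random(0);
int n = 200, featureCount = 3;
double[][] X = Enumerable.Range(0, n)
    .Select(_ => Enumerable.Range(0, featureCount).Select(__ => rng.NextDouble()).ToArray())
    .ToArray();
double[] y = X.Select(row => 5.0 * row[0] + 0.5 * row[1]).ToArray();

// Stand-in for a trained model (hypothetical): it uses the true weights.
Func<double[], double> predict = row => 5.0 * row[0] + 0.5 * row[1];

// R-squared of the stand-in model on the given features and labels.
double RSquared(double[][] features, double[] labels)
{
    double mean = labels.Average();
    double ssRes = features.Zip(labels, (row, label) => Math.Pow(label - predict(row), 2)).Sum();
    double ssTot = labels.Sum(label => Math.Pow(label - mean, 2));
    return 1.0 - ssRes / ssTot;
}

// Fisher-Yates shuffle of a copy of the input values.
double[] Shuffled(double[] values)
{
    var copy = (double[])values.Clone();
    for (int i = copy.Length - 1; i > 0; i--)
    {
        int swap = rng.Next(i + 1);
        (copy[i], copy[swap]) = (copy[swap], copy[i]);
    }
    return copy;
}

// For each feature, shuffle only that column, rescore, and average the
// change in R-squared over the requested number of permutations.
double[] PermutationImportance(int permutationCount)
{
    double baseline = RSquared(X, y);
    var importance = new double[featureCount];
    for (int j = 0; j < featureCount; j++)
    {
        double total = 0;
        for (int k = 0; k < permutationCount; k++)
        {
            double[] column = Shuffled(X.Select(row => row[j]).ToArray());
            double[][] permuted = X.Select((row, i) =>
            {
                var copy = (double[])row.Clone();
                copy[j] = column[i];
                return copy;
            }).ToArray();
            total += RSquared(permuted, y) - baseline;
        }
        importance[j] = total / permutationCount;
    }
    return importance;
}

double[] pfi = PermutationImportance(permutationCount: 3);
for (int j = 0; j < featureCount; j++)
{
    Console.WriteLine($"Feature {j}: {pfi[j]:F6}");
}
```

In this sketch, shuffling feature 0 should lower R-squared the most, shuffling feature 1 should lower it slightly, and shuffling feature 2 should leave it unchanged because the stand-in model ignores that column.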

Additionally, by highlighting the most important features, model builders can focus on using a subset of more meaningful features, which can potentially reduce noise and training time.

Load the data

The features in the dataset used for this sample are in columns 1-12. The goal is to predict Price.

| Column | Feature | Description |
|---|---|---|
| 1 | CrimeRate | Per capita crime rate |
| 2 | ResidentialZones | Residential zones in town |
| 3 | CommercialZones | Nonresidential zones in town |
| 4 | NearWater | Proximity to body of water |
| 5 | ToxicWasteLevels | Toxicity levels (PPM) |
| 6 | AverageRoomNumber | Average number of rooms in house |
| 7 | HomeAge | Age of home |
| 8 | BusinessCenterDistance | Distance to nearest business district |
| 9 | HighwayAccess | Proximity to highways |
| 10 | TaxRate | Property tax rate |
| 11 | StudentTeacherRatio | Ratio of students to teachers |
| 12 | PercentPopulationBelowPoverty | Percent of population living below poverty |
| 13 | Price | Price of the home |

A sample of the dataset is shown here:

1,24,13,1,0.59,3,96,11,23,608,14,13,32
4,80,18,1,0.37,5,14,7,4,346,19,13,41
2,98,16,1,0.25,10,5,1,8,689,13,36,12

The data in this sample can be modeled by a class like HousingPriceData and loaded into an IDataView. LoadColumn indexes are zero-based, so column 1 in the table maps to LoadColumn(0).

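// The LoadColumn and ColumnName attributes are defined in Microsoft.ML.Data.
using Microsoft.ML.Data;

class HousingPriceData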
{
    [LoadColumn(0)]
    public float CrimeRate { get; set; }

    [LoadColumn(1)]
    public float ResidentialZones { get; set; }

    [LoadColumn(2)]
    public float CommercialZones { get; set; }

    [LoadColumn(3)]
    public float NearWater { get; set; }

    [LoadColumn(4)]
    public float ToxicWasteLevels { get; set; }

    [LoadColumn(5)]
    public float AverageRoomNumber { get; set; }

    [LoadColumn(6)]
    public float HomeAge { get; set; }

    [LoadColumn(7)]
    public float BusinessCenterDistance { get; set; }

    [LoadColumn(8)]
    public float HighwayAccess { get; set; }

    [LoadColumn(9)]
    public float TaxRate { get; set; }

    [LoadColumn(10)]
    public float StudentTeacherRatio { get; set; }

    [LoadColumn(11)]
    public float PercentPopulationBelowPoverty { get; set; }

    // Price is the value to predict; ColumnName exposes it as the "Label" column.
    [LoadColumn(12)]
    [ColumnName("Label")]
    public float Price { get; set; }
}
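
The `mlContext` and `data` variables used in the rest of this article have to come from somewhere. A minimal sketch of one way to create them follows. It assumes the HousingPriceData class above is defined in the same project and that the Microsoft.ML package is referenced. The temporary file name is hypothetical, and the three rows are the sample rows shown earlier, so this is far too little data for real training.

```csharp
using System;
using System.IO;
using Microsoft.ML;

// Write the three sample rows to a temporary file (hypothetical location).
string dataPath = Path.Combine(Path.GetTempPath(), "housing-sample.csv");
File.WriteAllLines(dataPath, new[]
{
    "1,24,13,1,0.59,3,96,11,23,608,14,13,32",
    "4,80,18,1,0.37,5,14,7,4,346,19,13,41",
    "2,98,16,1,0.25,10,5,1,8,689,13,36,12"
});

// A fixed seed is an assumption made here for repeatable runs.
var mlContext = new MLContext(seed: 0);

// Map each comma-separated row onto HousingPriceData (assumed to be in scope).
IDataView data = mlContext.Data.LoadFromTextFile<HousingPriceData>(
    dataPath, separatorChar: ',', hasHeader: false);

// List the loaded columns; Price appears under the name "Label".
foreach (var column in data.Schema)
{
    Console.WriteLine(column.Name);
}
```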

Train the model

The following code sample illustrates the process of training a linear regression model to predict house prices.

// 1. Get the names of the input feature columns (every column except the label).
//    Select, Where, and ToArray require `using System.Linq;`.
string[] featureColumnNames =
    data.Schema
        .Select(column => column.Name)
        .Where(columnName => columnName != "Label").ToArray();

// 2. Define training pipeline.
IEstimator<ITransformer> sdcaEstimator =
    mlContext.Transforms.Concatenate("Features", featureColumnNames)
        .Append(mlContext.Transforms.NormalizeMinMax("Features"))
        .Append(mlContext.Regression.Trainers.Sdca());

// 3. Train machine learning model.
var sdcaModel = sdcaEstimator.Fit(data);

Explain the model with Permutation Feature Importance (PFI)

In ML.NET, use the PermutationFeatureImportance method of the catalog that matches your task. This sample is a regression task, so it uses mlContext.Regression.

// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);

// Calculate feature importance
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
    mlContext
        .Regression
        .PermutationFeatureImportance(sdcaModel, transformedData, permutationCount:3);

The result of using PermutationFeatureImportance on the training dataset is an ImmutableArray of RegressionMetricsStatistics objects. RegressionMetricsStatistics provides summary statistics, such as mean and standard deviation, over multiple observations of RegressionMetrics. The number of observations equals the number of permutations specified by the permutationCount parameter.

The metric used to measure feature importance depends on the machine learning task used to solve your problem. For example, regression tasks might use a common evaluation metric such as R-squared to measure importance. For more information on model evaluation metrics, see "Evaluate your ML.NET model with metrics."

The importance (in this case, the absolute value of the mean change in R-squared calculated by PermutationFeatureImportance) can then be ordered from most important to least important.
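
In symbols, the ranked quantity can be sketched as follows. The notation is an assumption introduced here for illustration: $K$ is permutationCount, $R^2_{\text{base}}$ is the score on the unshuffled data, and $R^2_{j,k}$ is the score after the $k$-th shuffle of feature $j$. The sign convention (permuted score minus baseline) is also assumed; it agrees with the mostly negative values in the sample output.

```latex
\Delta_j = \frac{1}{K} \sum_{k=1}^{K} \left( R^2_{j,k} - R^2_{\text{base}} \right),
\qquad
\text{importance}_j = \left| \Delta_j \right|
```

Here $\Delta_j$ corresponds to the value read from RSquared.Mean for feature $j$, and features are sorted by $|\Delta_j|$ in descending order.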

// Order features by importance.
var featureImportanceMetrics =
    permutationFeatureImportance
        .Select((metric, index) => new { index, metric.RSquared })
        .OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));

Console.WriteLine("Feature\tPFI");

foreach (var feature in featureImportanceMetrics)
{
    Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}

Printing the values for each of the features in featureImportanceMetrics generates output similar to the following. Expect different results, because these values depend on the data the model is given and on the random shuffles.

| Feature | Change to R-Squared |
|---|---|
| HighwayAccess | -0.042731 |
| StudentTeacherRatio | -0.012730 |
| BusinessCenterDistance | -0.010491 |
| TaxRate | -0.008545 |
| AverageRoomNumber | -0.003949 |
| CrimeRate | -0.003665 |
| CommercialZones | 0.002749 |
| HomeAge | -0.002426 |
| ResidentialZones | -0.002319 |
| NearWater | 0.000203 |
| PercentPopulationBelowPoverty | 0.000031 |
| ToxicWasteLevels | -0.000019 |

If you look at the five most important features for this dataset, ranked by the absolute change in R-squared, the price of a house predicted by this model is influenced by its proximity to highways, the student-teacher ratio of schools in the area, proximity to major employment centers, property tax rate, and average number of rooms in the home.
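
As noted earlier, the ranking can guide the choice of a smaller feature set. The following sketch reproduces the top-five selection using the values from the sample output. The dictionary is a hand-copied stand-in for the PFI result, and the cutoff of five is an arbitrary choice for illustration.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Mean change in R-squared per feature, copied from the sample output.
var meanChange = new Dictionary<string, double>
{
    ["HighwayAccess"] = -0.042731,
    ["StudentTeacherRatio"] = -0.012730,
    ["BusinessCenterDistance"] = -0.010491,
    ["TaxRate"] = -0.008545,
    ["AverageRoomNumber"] = -0.003949,
    ["CrimeRate"] = -0.003665,
    ["CommercialZones"] = 0.002749,
    ["HomeAge"] = -0.002426,
    ["ResidentialZones"] = -0.002319,
    ["NearWater"] = 0.000203,
    ["PercentPopulationBelowPoverty"] = 0.000031,
    ["ToxicWasteLevels"] = -0.000019
};

// Rank by the magnitude of the change and keep the five largest.
string[] topFeatures = meanChange
    .OrderByDescending(pair => Math.Abs(pair.Value))
    .Take(5)
    .Select(pair => pair.Key)
    .ToArray();

Console.WriteLine(string.Join(", ", topFeatures));
// prints "HighwayAccess, StudentTeacherRatio, BusinessCenterDistance, TaxRate, AverageRoomNumber"
```

An array like topFeatures could then be passed to mlContext.Transforms.Concatenate in place of featureColumnNames to train on the reduced feature set. Whether that improves the model would need to be checked by retraining and evaluating.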

Next steps