使用排列特征重要性解释模型预测

使用排列特征重要性 (PFI),了解如何解释 ML.NET 机器学习模型预测。 PFI 可表示每个特征对预测的相对贡献。

机器学习模型通常被视为不透明盒,它们会接收输入并生成输出。 人们对影响输出的中间步骤或特征之间的交互了解甚少。 随着机器学习被引入日常生活的更多方面(例如医疗保健),理解机器学习模型为何做出其决策变得至关重要。 例如,如果诊断由机器学习模型做出,则医疗保健专业人员需要查看影响做出诊断的因素的方法。 提供正确的诊断可以对患者是否快速康复产生重大影响。 因此,模型的可解释性水平越高,医疗保健专业人员就越有信心接受或拒绝模型做出的决策。

有各种技术被用于解释模型,其中之一是 PFI。 PFI 是一种用于解释分类和回归模型的技术,其灵感来自 Breima 的 Random Forests(随机森林)论文(参见第 10 部分)。 概括而言,其工作原理是一次随机为整个数据集随机抽取数据的一个特征,并计算关注性能指标的下降程度。 变化越大,特征就越重要。

此外,通过突出显示最重要的特征,模型生成器可以专注于使用一组更有意义的特征,这可能会减少干扰和训练时间。

加载数据

数据集中用于此示例的特征位于列 1-12 中。 目标在于预测 Price

功能 描述
1 CrimeRate 人均犯罪率
2 ResidentialZones 城镇住宅区
3 CommercialZones 城镇非住宅区
4 NearWater 距水体的距离
5 ToxicWasteLevels 毒性程度 (PPM)
6 AverageRoomNumber 房屋平均房间数量
7 HomeAge 楼龄
8 BusinessCenterDistance 与最近商业区的距离
9 HighwayAccess 距公路的距离
10 TaxRate 财产税率
11 StudentTeacherRatio 师生比率
12 PercentPopulationBelowPoverty 贫困线以下人口百分比
13 Price 住宅价格

数据集的示例如下所示:

1,24,13,1,0.59,3,96,11,23,608,14,13,32
4,80,18,1,0.37,5,14,7,4,346,19,13,41
2,98,16,1,0.25,10,5,1,8,689,13,36,12

此示例中的数据可以通过 HousingPriceData 等类进行建模并加载到 IDataView 中。

class HousingPriceData
{
    [LoadColumn(0)]
    public float CrimeRate { get; set; }

    [LoadColumn(1)]
    public float ResidentialZones { get; set; }

    [LoadColumn(2)]
    public float CommercialZones { get; set; }

    [LoadColumn(3)]
    public float NearWater { get; set; }

    [LoadColumn(4)]
    public float ToxicWasteLevels { get; set; }

    [LoadColumn(5)]
    public float AverageRoomNumber { get; set; }

    [LoadColumn(6)]
    public float HomeAge { get; set; }

    [LoadColumn(7)]
    public float BusinessCenterDistance { get; set; }

    [LoadColumn(8)]
    public float HighwayAccess { get; set; }

    [LoadColumn(9)]
    public float TaxRate { get; set; }

    [LoadColumn(10)]
    public float StudentTeacherRatio { get; set; }

    [LoadColumn(11)]
    public float PercentPopulationBelowPoverty { get; set; }

    [LoadColumn(12)]
    [ColumnName("Label")]
    public float Price { get; set; }
}

定型模型

下面的代码示例演示了训练线性回归模型用于预测房屋价格的过程。

// 1. Get the column name of input features.
string[] featureColumnNames =
    data.Schema
        .Select(column => column.Name)
        .Where(columnName => columnName != "Label").ToArray();

// 2. Define training pipeline
IEstimator<ITransformer> sdcaEstimator =
    mlContext.Transforms.Concatenate("Features", featureColumnNames)
        .Append(mlContext.Transforms.NormalizeMinMax("Features"))
        .Append(mlContext.Regression.Trainers.Sdca());

// 3. Train machine learning model
var sdcaModel = sdcaEstimator.Fit(data);

使用排列特征重要性 (PFI) 解释模型

在 ML.NET 中,为相应的任务使用 PermutationFeatureImportance 方法。

// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);

// Calculate feature importance
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
    mlContext
        .Regression
        .PermutationFeatureImportance(sdcaModel, transformedData, permutationCount:3);

在训练数据集上使用 PermutationFeatureImportance 的结果是 对象的 RegressionMetricsStatisticsImmutableArrayRegressionMetricsStatistics 提供 RegressionMetrics 的多个观测值的均值和标准差等摘要统计信息,观测值数量等于 permutationCount 参数指定的排列数。

用于度量特征重要性的指标取决于用于解决问题的机器学习任务。 例如,回归任务可以使用诸如 R 平方这样的常见评估指标来度量重要性。 有关模型评估指标详细信息,请参阅使用指标评估 ML.NET 模型

重要性(在本例中,由 PermutationFeatureImportance 计算的 R 平方指标的绝对平均下降)可随后按从最重要到最不重要的顺序排序。

// Order features by importance
var featureImportanceMetrics =
    permutationFeatureImportance
        .Select((metric, index) => new { index, metric.RSquared })
        .OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));

Console.WriteLine("Feature\tPFI");

foreach (var feature in featureImportanceMetrics)
{
    Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}

打印 featureImportanceMetrics 中每个特征的值将生成类似如下的输出。 请记住,应该预期看到不同的结果,因为这些值根据其获得的数据而有所不同。

功能 R 平方的变化
HighwayAccess -0.042731
StudentTeacherRatio -0.012730
BusinessCenterDistance -0.010491
TaxRate -0.008545
AverageRoomNumber -0.003949
CrimeRate -0.003665
CommercialZones 0.002749
HomeAge -0.002426
ResidentialZones -0.002319
NearWater 0.000203
PercentPopulationLivingBelowPoverty 0.000031
ToxicWasteLevels -0.000019

查看此数据集最重要的五个特征,此模型预测的房屋价格受其与公路的距离、该区域学校的师生比率、与主要就业中心的距离、资产税率和房屋平均房间数量的影响。

后续步骤