使用排列特征重要性解释模型预测
使用排列特征重要性 (PFI),了解如何解释 ML.NET 机器学习模型预测。 PFI 可表示每个特征对预测的相对贡献。
机器学习模型通常被视为不透明盒,它们会接收输入并生成输出。 人们对影响输出的中间步骤或特征之间的交互了解甚少。 随着机器学习被引入日常生活的更多方面(例如医疗保健),理解机器学习模型为何做出其决策变得至关重要。 例如,如果诊断由机器学习模型做出,则医疗保健专业人员需要查看影响做出诊断的因素的方法。 提供正确的诊断可以对患者是否快速康复产生重大影响。 因此,模型的可解释性水平越高,医疗保健专业人员就越有信心接受或拒绝模型做出的决策。
有各种技术被用于解释模型,其中之一是 PFI。 PFI 是一种用于解释分类和回归模型的技术,其灵感来自 Breima 的 Random Forests(随机森林)论文(参见第 10 部分)。 概括而言,其工作原理是一次随机为整个数据集随机抽取数据的一个特征,并计算关注性能指标的下降程度。 变化越大,特征就越重要。
此外,通过突出显示最重要的特征,模型生成器可以专注于使用一组更有意义的特征,这可能会减少干扰和训练时间。
加载数据
数据集中用于此示例的特征位于列 1-12 中。 目标在于预测 Price
。
列 | 功能 | 描述 |
---|---|---|
1 | CrimeRate | 人均犯罪率 |
2 | ResidentialZones | 城镇住宅区 |
3 | CommercialZones | 城镇非住宅区 |
4 | NearWater | 距水体的距离 |
5 | ToxicWasteLevels | 毒性程度 (PPM) |
6 | AverageRoomNumber | 房屋平均房间数量 |
7 | HomeAge | 楼龄 |
8 | BusinessCenterDistance | 与最近商业区的距离 |
9 | HighwayAccess | 距公路的距离 |
10 | TaxRate | 财产税率 |
11 | StudentTeacherRatio | 师生比率 |
12 | PercentPopulationBelowPoverty | 贫困线以下人口百分比 |
13 | Price | 住宅价格 |
数据集的示例如下所示:
1,24,13,1,0.59,3,96,11,23,608,14,13,32
4,80,18,1,0.37,5,14,7,4,346,19,13,41
2,98,16,1,0.25,10,5,1,8,689,13,36,12
此示例中的数据可以通过 HousingPriceData
等类进行建模并加载到 IDataView
中。
class HousingPriceData
{
[LoadColumn(0)]
public float CrimeRate { get; set; }
[LoadColumn(1)]
public float ResidentialZones { get; set; }
[LoadColumn(2)]
public float CommercialZones { get; set; }
[LoadColumn(3)]
public float NearWater { get; set; }
[LoadColumn(4)]
public float ToxicWasteLevels { get; set; }
[LoadColumn(5)]
public float AverageRoomNumber { get; set; }
[LoadColumn(6)]
public float HomeAge { get; set; }
[LoadColumn(7)]
public float BusinessCenterDistance { get; set; }
[LoadColumn(8)]
public float HighwayAccess { get; set; }
[LoadColumn(9)]
public float TaxRate { get; set; }
[LoadColumn(10)]
public float StudentTeacherRatio { get; set; }
[LoadColumn(11)]
public float PercentPopulationBelowPoverty { get; set; }
[LoadColumn(12)]
[ColumnName("Label")]
public float Price { get; set; }
}
定型模型
下面的代码示例演示了训练线性回归模型用于预测房屋价格的过程。
// 1. Get the column name of input features.
string[] featureColumnNames =
data.Schema
.Select(column => column.Name)
.Where(columnName => columnName != "Label").ToArray();
// 2. Define training pipeline
IEstimator<ITransformer> sdcaEstimator =
mlContext.Transforms.Concatenate("Features", featureColumnNames)
.Append(mlContext.Transforms.NormalizeMinMax("Features"))
.Append(mlContext.Regression.Trainers.Sdca());
// 3. Train machine learning model
var sdcaModel = sdcaEstimator.Fit(data);
使用排列特征重要性 (PFI) 解释模型
在 ML.NET 中,为相应的任务使用 PermutationFeatureImportance
方法。
// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);
// Calculate feature importance
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
mlContext
.Regression
.PermutationFeatureImportance(sdcaModel, transformedData, permutationCount:3);
在训练数据集上使用 PermutationFeatureImportance
的结果是 对象的 RegressionMetricsStatistics
ImmutableArray
。 RegressionMetricsStatistics
提供 RegressionMetrics
的多个观测值的均值和标准差等摘要统计信息,观测值数量等于 permutationCount
参数指定的排列数。
用于度量特征重要性的指标取决于用于解决问题的机器学习任务。 例如,回归任务可以使用诸如 R 平方这样的常见评估指标来度量重要性。 有关模型评估指标详细信息,请参阅使用指标评估 ML.NET 模型。
重要性(在本例中,由 PermutationFeatureImportance
计算的 R 平方指标的绝对平均下降)可随后按从最重要到最不重要的顺序排序。
// Order features by importance
var featureImportanceMetrics =
permutationFeatureImportance
.Select((metric, index) => new { index, metric.RSquared })
.OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));
Console.WriteLine("Feature\tPFI");
foreach (var feature in featureImportanceMetrics)
{
Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}
打印 featureImportanceMetrics
中每个特征的值将生成类似如下的输出。 请记住,应该预期看到不同的结果,因为这些值根据其获得的数据而有所不同。
功能 | R 平方的变化 |
---|---|
HighwayAccess | -0.042731 |
StudentTeacherRatio | -0.012730 |
BusinessCenterDistance | -0.010491 |
TaxRate | -0.008545 |
AverageRoomNumber | -0.003949 |
CrimeRate | -0.003665 |
CommercialZones | 0.002749 |
HomeAge | -0.002426 |
ResidentialZones | -0.002319 |
NearWater | 0.000203 |
PercentPopulationLivingBelowPoverty | 0.000031 |
ToxicWasteLevels | -0.000019 |
查看此数据集最重要的五个特征,此模型预测的房屋价格受其与公路的距离、该区域学校的师生比率、与主要就业中心的距离、资产税率和房屋平均房间数量的影响。