Interpréter les prédictions de modèle à l’aide de l’importance des fonctionnalités de permutation
À l’aide de l’importance des caractéristiques par permutation (PFI), découvrez comment interpréter les prédictions de modèles d'apprentissage automatique de ML.NET. PFI fournit la contribution relative que chaque fonctionnalité apporte à une prédiction.
Les modèles Machine Learning sont souvent considérés comme des zones opaques qui prennent des entrées et génèrent une sortie. Les étapes intermédiaires ou les interactions entre les caractéristiques qui influencent la sortie sont rarement comprises. Étant donné que le Machine Learning est introduit dans d’autres aspects de la vie quotidienne, tels que la santé, il est de la plus grande importance de comprendre pourquoi un modèle Machine Learning prend les décisions qu’il prend. Par exemple, si des diagnostics sont effectués par un modèle Machine Learning, les professionnels de la santé ont besoin d’un moyen de déterminer les facteurs qui ont été pris en compte dans la réalisation de ce diagnostic. Fournir le bon diagnostic pourrait faire une grande différence sur le fait qu’un patient ait une récupération rapide ou non. Par conséquent, plus le modèle offre de transparence, plus les professionnels de la santé ont confiance dans leur capacité d'accepter ou de rejeter les décisions prises par le modèle.
Différentes techniques sont utilisées pour expliquer les modèles, dont l’une est PFI. PFI est une technique utilisée pour expliquer les modèles de classification et de régression, qui s’inspire du document Random Forests de Leo Breiman (voir la section 10). De manière générale, la manière dont cela fonctionne consiste à mélanger de façon aléatoire les données une fonction à la fois pour l'ensemble du jeu de données et à calculer à quel point la métrique de performance d'intérêt diminue. Plus la modification est importante, plus cette fonctionnalité est importante.
En outre, en mettant en évidence les fonctionnalités les plus importantes, les générateurs de modèles peuvent se concentrer sur l’utilisation d’un sous-ensemble de fonctionnalités plus significatives, ce qui peut potentiellement réduire le bruit et le temps d’entraînement.
Charger les données
Les fonctionnalités du jeu de données utilisé pour cet exemple se trouvent dans les colonnes 1 à 12. L’objectif est de prédire Price
.
Colonne | Caractéristique | Description |
---|---|---|
1 | CrimeRate | Taux de criminalité par habitant |
2 | ResidentialZones | Zones résidentielles en ville |
3 | CommercialZones | Zones non résidentielles en ville |
4 | NearWater | Proximité d'un plan d'eau |
5 | Niveaux de déchets toxiques | Niveaux de toxicité (PPM) |
6 | AverageRoomNumber | Nombre moyen de chambres dans la maison |
7 | HomeAge | Âge de la maison |
8 | Distance du centre d'affaires | Distance jusqu’au quartier d’affaires le plus proche |
9 | HighwayAccess | Proximité des autoroutes |
10 | Taux d'imposition | Taux d’imposition des propriétés |
11 | Rapport élèves-enseignant | Ratio des étudiants aux enseignants |
12 | Pourcentage de la population en dessous du seuil de pauvreté | Pourcentage de la population vivant en dessous de la pauvreté |
13 | Prix | Prix de la maison |
Voici un exemple de jeu de données :
1,24,13,1,0.59,3,96,11,23,608,14,13,32
4,80,18,1,0.37,5,14,7,4,346,19,13,41
2,98,16,1,0.25,10,5,1,8,689,13,36,12
Les données de cet exemple peuvent être modélisées par une classe comme HousingPriceData
et chargées dans un IDataView
.
class HousingPriceData
{
[LoadColumn(0)]
public float CrimeRate { get; set; }
[LoadColumn(1)]
public float ResidentialZones { get; set; }
[LoadColumn(2)]
public float CommercialZones { get; set; }
[LoadColumn(3)]
public float NearWater { get; set; }
[LoadColumn(4)]
public float ToxicWasteLevels { get; set; }
[LoadColumn(5)]
public float AverageRoomNumber { get; set; }
[LoadColumn(6)]
public float HomeAge { get; set; }
[LoadColumn(7)]
public float BusinessCenterDistance { get; set; }
[LoadColumn(8)]
public float HighwayAccess { get; set; }
[LoadColumn(9)]
public float TaxRate { get; set; }
[LoadColumn(10)]
public float StudentTeacherRatio { get; set; }
[LoadColumn(11)]
public float PercentPopulationBelowPoverty { get; set; }
[LoadColumn(12)]
[ColumnName("Label")]
public float Price { get; set; }
}
Entraîner le modèle
L’exemple de code suivant illustre le processus d’apprentissage d’un modèle de régression linéaire pour prédire les prix des maisons.
// 1. Get the column name of input features.
string[] featureColumnNames =
data.Schema
.Select(column => column.Name)
.Where(columnName => columnName != "Label").ToArray();
// 2. Define training pipeline.
IEstimator<ITransformer> sdcaEstimator =
mlContext.Transforms.Concatenate("Features", featureColumnNames)
.Append(mlContext.Transforms.NormalizeMinMax("Features"))
.Append(mlContext.Regression.Trainers.Sdca());
// 3. Train machine learning model.
var sdcaModel = sdcaEstimator.Fit(data);
Expliquer le modèle avec la technique PFI
Dans ML.NET, utilisez la méthode PermutationFeatureImportance
pour votre tâche respective.
// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);
// Calculate feature importance
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
mlContext
.Regression
.PermutationFeatureImportance(sdcaModel, transformedData, permutationCount:3);
Le résultat de l’utilisation de PermutationFeatureImportance
sur le jeu de données d’entraînement est un ImmutableArray
d’objets RegressionMetricsStatistics
. RegressionMetricsStatistics
fournit des statistiques récapitulatives telles que la moyenne et l’écart type pour plusieurs observations de RegressionMetrics
égales au nombre de permutations spécifiées par le paramètre permutationCount
.
La métrique utilisée pour mesurer l’importance des fonctionnalités dépend de la tâche Machine Learning utilisée pour résoudre votre problème. Par exemple, les tâches de régression peuvent utiliser une métrique d’évaluation courante telle que R-squared pour mesurer l’importance. Pour plus d’informations sur les métriques d’évaluation de modèle, consultez évaluer votre modèle ML.NET avec des métriques.
L’importance, ou dans ce cas, la diminution moyenne absolue de la métrique R-squared, calculée par PermutationFeatureImportance
, peut ensuite être ordonnée de la plus importante à la moins importante.
// Order features by importance.
var featureImportanceMetrics =
permutationFeatureImportance
.Select((metric, index) => new { index, metric.RSquared })
.OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));
Console.WriteLine("Feature\tPFI");
foreach (var feature in featureImportanceMetrics)
{
Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}
L’impression des valeurs pour chacune des fonctionnalités de featureImportanceMetrics
génère une sortie similaire à la sortie qui suit. Vous devez vous attendre à voir des résultats différents, car ces valeurs varient en fonction des données qu’elles reçoivent.
Caractéristique | Passer à R-Squared |
---|---|
HighwayAccess | -0.042731 |
Ratio Élève-Enseignant | -0.012730 |
BusinessCenterDistance | -0.010491 |
Taux d'imposition | -0.008545 |
Nombre moyen de chambres | -0.003949 |
CrimeRate | -0.003665 |
CommercialZones | 0.002749 |
HomeAge | -0.002426 |
ResidentialZones | -0.002319 |
NearWater | 0.000203 |
Pourcentage de la population vivant sous le seuil de pauvreté | 0.000031 |
Niveaux de Déchets Toxiques | -0.000019 |
Si vous examinez les cinq caractéristiques les plus importantes pour ce jeu de données, le prix d’une maison prédite par ce modèle est influencé par sa proximité avec les autoroutes, le ratio des enseignants des écoles dans la région, la proximité des grands centres d’emploi, le taux de taxe sur les propriétés et le nombre moyen de chambres dans la maison.