Partage via


Interpréter les prédictions de modèle à l’aide de l’importance des fonctionnalités de permutation

À l’aide de l’importance des caractéristiques par permutation (PFI), découvrez comment interpréter les prédictions de modèles d'apprentissage automatique de ML.NET. PFI fournit la contribution relative que chaque fonctionnalité apporte à une prédiction.

Les modèles Machine Learning sont souvent considérés comme des zones opaques qui prennent des entrées et génèrent une sortie. Les étapes intermédiaires ou les interactions entre les caractéristiques qui influencent la sortie sont rarement comprises. Étant donné que le Machine Learning est introduit dans d’autres aspects de la vie quotidienne, tels que la santé, il est de la plus grande importance de comprendre pourquoi un modèle Machine Learning prend les décisions qu’il prend. Par exemple, si des diagnostics sont effectués par un modèle Machine Learning, les professionnels de la santé ont besoin d’un moyen de déterminer les facteurs qui ont été pris en compte dans la réalisation de ce diagnostic. Fournir le bon diagnostic pourrait faire une grande différence sur le fait qu’un patient ait une récupération rapide ou non. Par conséquent, plus le modèle offre de transparence, plus les professionnels de la santé ont confiance dans leur capacité d'accepter ou de rejeter les décisions prises par le modèle.

Différentes techniques sont utilisées pour expliquer les modèles, dont l’une est PFI. PFI est une technique utilisée pour expliquer les modèles de classification et de régression, qui s’inspire du document Random Forests de Leo Breiman (voir la section 10). De manière générale, la manière dont cela fonctionne consiste à mélanger de façon aléatoire les données une fonction à la fois pour l'ensemble du jeu de données et à calculer à quel point la métrique de performance d'intérêt diminue. Plus la modification est importante, plus cette fonctionnalité est importante.

En outre, en mettant en évidence les fonctionnalités les plus importantes, les générateurs de modèles peuvent se concentrer sur l’utilisation d’un sous-ensemble de fonctionnalités plus significatives, ce qui peut potentiellement réduire le bruit et le temps d’entraînement.

Charger les données

Les fonctionnalités du jeu de données utilisé pour cet exemple se trouvent dans les colonnes 1 à 12. L’objectif est de prédire Price.

Colonne Caractéristique Description
1 CrimeRate Taux de criminalité par habitant
2 ResidentialZones Zones résidentielles en ville
3 CommercialZones Zones non résidentielles en ville
4 NearWater Proximité d'un plan d'eau
5 Niveaux de déchets toxiques Niveaux de toxicité (PPM)
6 AverageRoomNumber Nombre moyen de chambres dans la maison
7 HomeAge Âge de la maison
8 Distance du centre d'affaires Distance jusqu’au quartier d’affaires le plus proche
9 HighwayAccess Proximité des autoroutes
10 Taux d'imposition Taux d’imposition des propriétés
11 Rapport élèves-enseignant Ratio des étudiants aux enseignants
12 Pourcentage de la population en dessous du seuil de pauvreté Pourcentage de la population vivant en dessous de la pauvreté
13 Prix Prix de la maison

Voici un exemple de jeu de données :

1,24,13,1,0.59,3,96,11,23,608,14,13,32
4,80,18,1,0.37,5,14,7,4,346,19,13,41
2,98,16,1,0.25,10,5,1,8,689,13,36,12

Les données de cet exemple peuvent être modélisées par une classe comme HousingPriceData et chargées dans un IDataView.

class HousingPriceData
{
    [LoadColumn(0)]
    public float CrimeRate { get; set; }

    [LoadColumn(1)]
    public float ResidentialZones { get; set; }

    [LoadColumn(2)]
    public float CommercialZones { get; set; }

    [LoadColumn(3)]
    public float NearWater { get; set; }

    [LoadColumn(4)]
    public float ToxicWasteLevels { get; set; }

    [LoadColumn(5)]
    public float AverageRoomNumber { get; set; }

    [LoadColumn(6)]
    public float HomeAge { get; set; }

    [LoadColumn(7)]
    public float BusinessCenterDistance { get; set; }

    [LoadColumn(8)]
    public float HighwayAccess { get; set; }

    [LoadColumn(9)]
    public float TaxRate { get; set; }

    [LoadColumn(10)]
    public float StudentTeacherRatio { get; set; }

    [LoadColumn(11)]
    public float PercentPopulationBelowPoverty { get; set; }

    [LoadColumn(12)]
    [ColumnName("Label")]
    public float Price { get; set; }
}

Entraîner le modèle

L’exemple de code suivant illustre le processus d’apprentissage d’un modèle de régression linéaire pour prédire les prix des maisons.

// 1. Get the column name of input features.
string[] featureColumnNames =
    data.Schema
        .Select(column => column.Name)
        .Where(columnName => columnName != "Label").ToArray();

// 2. Define training pipeline.
IEstimator<ITransformer> sdcaEstimator =
    mlContext.Transforms.Concatenate("Features", featureColumnNames)
        .Append(mlContext.Transforms.NormalizeMinMax("Features"))
        .Append(mlContext.Regression.Trainers.Sdca());

// 3. Train machine learning model.
var sdcaModel = sdcaEstimator.Fit(data);

Expliquer le modèle avec la technique PFI

Dans ML.NET, utilisez la méthode PermutationFeatureImportance pour votre tâche respective.

// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);

// Calculate feature importance
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
    mlContext
        .Regression
        .PermutationFeatureImportance(sdcaModel, transformedData, permutationCount:3);

Le résultat de l’utilisation de PermutationFeatureImportance sur le jeu de données d’entraînement est un ImmutableArray d’objets RegressionMetricsStatistics. RegressionMetricsStatistics fournit des statistiques récapitulatives telles que la moyenne et l’écart type pour plusieurs observations de RegressionMetrics égales au nombre de permutations spécifiées par le paramètre permutationCount.

La métrique utilisée pour mesurer l’importance des fonctionnalités dépend de la tâche Machine Learning utilisée pour résoudre votre problème. Par exemple, les tâches de régression peuvent utiliser une métrique d’évaluation courante telle que R-squared pour mesurer l’importance. Pour plus d’informations sur les métriques d’évaluation de modèle, consultez évaluer votre modèle ML.NET avec des métriques.

L’importance, ou dans ce cas, la diminution moyenne absolue de la métrique R-squared, calculée par PermutationFeatureImportance, peut ensuite être ordonnée de la plus importante à la moins importante.

// Order features by importance.
var featureImportanceMetrics =
    permutationFeatureImportance
        .Select((metric, index) => new { index, metric.RSquared })
        .OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));

Console.WriteLine("Feature\tPFI");

foreach (var feature in featureImportanceMetrics)
{
    Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}

L’impression des valeurs pour chacune des fonctionnalités de featureImportanceMetrics génère une sortie similaire à la sortie qui suit. Vous devez vous attendre à voir des résultats différents, car ces valeurs varient en fonction des données qu’elles reçoivent.

Caractéristique Passer à R-Squared
HighwayAccess -0.042731
Ratio Élève-Enseignant -0.012730
BusinessCenterDistance -0.010491
Taux d'imposition -0.008545
Nombre moyen de chambres -0.003949
CrimeRate -0.003665
CommercialZones 0.002749
HomeAge -0.002426
ResidentialZones -0.002319
NearWater 0.000203
Pourcentage de la population vivant sous le seuil de pauvreté 0.000031
Niveaux de Déchets Toxiques -0.000019

Si vous examinez les cinq caractéristiques les plus importantes pour ce jeu de données, le prix d’une maison prédite par ce modèle est influencé par sa proximité avec les autoroutes, le ratio des enseignants des écoles dans la région, la proximité des grands centres d’emploi, le taux de taxe sur les propriétés et le nombre moyen de chambres dans la maison.

Étapes suivantes