Interpretieren von Modellvorhersagen mit Permutation Feature Importance
In diesem Artikel erfahren Sie, wie Sie mithilfe von Permutation Feature Importance (PFI) in ML.NET Vorhersagen von Machine Learning-Modellen interpretieren können. PFI stellt Informationen dazu bereit, welchen relativen Anteil jedes Feature an einer Vorhersage hat.
Machine Learning-Modelle werden oft als Blackboxes betrachtet, die aus Eingaben eine Ausgabe generieren. Die Zwischenschritte oder Interaktionen zwischen den Features, die die Ausgabe beeinflussen, werden nur selten verstanden. Da das maschinelle Lernen in immer mehr Bereichen des täglichen Lebens, wie beispielsweise im Gesundheitswesen, zum Einsatz kommt, ist es von größter Bedeutung zu verstehen, warum ein Machine Learning-Modell die Entscheidungen trifft, die es trifft. Wenn die Diagnosen zum Beispiel durch ein Machine Learning-Modell gestellt werden, brauchen die Mediziner eine Möglichkeit, die Faktoren zu untersuchen, die bei der Erstellung dieser Diagnosen berücksichtigt wurden. Die richtige Diagnose könnte einen großen Unterschied machen, ob die Genesung eines Patienten schnell verläuft oder nicht. Je höher also der Grad der Erklärbarkeit in einem Modell ist, desto größer ist das Vertrauen der Mediziner, die die vom Modell getroffenen Entscheidungen akzeptieren oder ablehnen müssen.
Zur Erklärung von Modellen werden verschiedene Techniken verwendet, darunter PFI. PFI ist eine Technik zur Erklärung von Klassifizierungs- und Regressionsmodellen, die von Leo Breimans Schrift Random Forests inspiriert ist (siehe Abschnitt 10). Allgemein funktioniert die Technik so, das Daten für das gesamte Dataset einzeln zufällig gemischt werden und anschließend berechnet wird, wie stark die Leistungsmetrik von Interesse abnimmt. Je größer die Änderung, desto wichtiger ist dieses Feature.
Darüber hinaus können sich Modellersteller durch die Hervorhebung der wichtigsten Features auf die Verwendung einer Teilmenge sinnvoller Features konzentrieren, die Rauschen und Trainingszeiten reduzieren können.
Laden der Daten
Die Features im Dataset, das für dieses Beispiel verwendet wird, befinden sich in den Spalten 1-12. Das Ziel ist die Vorhersage von Price
.
Spalte | Feature | Beschreibung |
---|---|---|
1 | CrimeRate | Pro-Kopf-Kriminalitätsrate |
2 | ResidentialZones | Wohngebiete in der Stadt |
3 | CommercialZones | Nicht-Wohngebiete in der Stadt |
4 | NearWater | Nähe zu einem Gewässer |
5 | ToxicWasteLevels | Toxizitätswerte (PPM) |
6 | AverageRoomNumber | Durchschnittliche Anzahl von Räumen in einem Haus |
7 | HomeAge | Alter des Hauses |
8 | BusinessCenterDistance | Entfernung zum nächsten Geschäftsviertel |
9 | HighwayAccess | Geografischer Nähe zu Autobahnen |
10 | TaxRate | Grundsteuersatz |
11 | StudentTeacherRatio | Verhältnis zwischen Schülern/Studenten und Lehrkräften |
12 | PercentPopulationBelowPoverty | Prozentsatz der Bevölkerung, der unter der Armutsgrenze lebt |
13 | Preis | Preis des Hauses |
Ein Beispiel für das Dataset wird unten dargestellt:
1,24,13,1,0.59,3,96,11,23,608,14,13,32
4,80,18,1,0.37,5,14,7,4,346,19,13,41
2,98,16,1,0.25,10,5,1,8,689,13,36,12
Die Daten in diesem Beispiel können durch eine Klasse wie HousingPriceData
modelliert und in eine IDataView
geladen werden.
class HousingPriceData
{
[LoadColumn(0)]
public float CrimeRate { get; set; }
[LoadColumn(1)]
public float ResidentialZones { get; set; }
[LoadColumn(2)]
public float CommercialZones { get; set; }
[LoadColumn(3)]
public float NearWater { get; set; }
[LoadColumn(4)]
public float ToxicWasteLevels { get; set; }
[LoadColumn(5)]
public float AverageRoomNumber { get; set; }
[LoadColumn(6)]
public float HomeAge { get; set; }
[LoadColumn(7)]
public float BusinessCenterDistance { get; set; }
[LoadColumn(8)]
public float HighwayAccess { get; set; }
[LoadColumn(9)]
public float TaxRate { get; set; }
[LoadColumn(10)]
public float StudentTeacherRatio { get; set; }
[LoadColumn(11)]
public float PercentPopulationBelowPoverty { get; set; }
[LoadColumn(12)]
[ColumnName("Label")]
public float Price { get; set; }
}
Trainieren des Modells
Das folgende Codebeispiel veranschaulicht den Prozess des Trainings eines linearen Regressionsmodells zur Vorhersage von Hauspreisen.
// 1. Get the column name of input features.
string[] featureColumnNames =
data.Schema
.Select(column => column.Name)
.Where(columnName => columnName != "Label").ToArray();
// 2. Define training pipeline
IEstimator<ITransformer> sdcaEstimator =
mlContext.Transforms.Concatenate("Features", featureColumnNames)
.Append(mlContext.Transforms.NormalizeMinMax("Features"))
.Append(mlContext.Regression.Trainers.Sdca());
// 3. Train machine learning model
var sdcaModel = sdcaEstimator.Fit(data);
Erläutern des Modells mit Permutation Feature Importance (PFI)
ML.NET verwendet die PermutationFeatureImportance
-Methode für die jeweilige Aufgabe.
// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);
// Calculate feature importance
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
mlContext
.Regression
.PermutationFeatureImportance(sdcaModel, transformedData, permutationCount:3);
Das Ergebnis der Verwendung von PermutationFeatureImportance
auf das Trainingsdataset ist ein ImmutableArray
von RegressionMetricsStatistics
-Objekten. RegressionMetricsStatistics
liefert zusammenfassende Statistiken wie Mittelwert und Standardabweichung für mehrere Beobachtungen der RegressionMetrics
, die der Anzahl der durch den permutationCount
-Parameter angegebenen Permutationen entsprechen.
Die Metrik, die zum Messen der Wichtigkeit von Features verwendet wird, hängt von der Machine Learning-Aufgabe ab, die zum Lösen Ihres Problems verwendet wird. Beispielsweise können Regressionsaufgaben eine allgemeine Auswertungsmetrik wie R-squared verwenden, um die Wichtigkeit zu messen. Weitere Informationen zu Modellauswertungsmetriken finden Sie unter Auswerten des ML.NET-Modells mit Metriken.
Die Bedeutung, oder in diesem Fall die aus PermutationFeatureImportance
berechnete absolute durchschnittliche Abnahme der Metrik für das Bestimmtheitsmaß, kann dann vom wichtigsten zum unwichtigsten Feature geordnet werden.
// Order features by importance
var featureImportanceMetrics =
permutationFeatureImportance
.Select((metric, index) => new { index, metric.RSquared })
.OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));
Console.WriteLine("Feature\tPFI");
foreach (var feature in featureImportanceMetrics)
{
Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}
Wenn Sie die Werte für jedes Feature in featureImportanceMetrics
ausdrucken, würde das in etwa wie folgt aussehen. Beachten Sie, dass mit unterschiedlichen Ergebnissen zu rechnen ist, da diese Werte je nach den Daten, die sie erhalten, variieren.
Feature | Änderung im Bestimmtheitsmaß |
---|---|
HighwayAccess | -0,042731 |
StudentTeacherRatio | -0,012730 |
BusinessCenterDistance | -0,010491 |
TaxRate | -0,008545 |
AverageRoomNumber | -0,003949 |
CrimeRate | -0,003665 |
CommercialZones | 0,002749 |
HomeAge | -0,002426 |
ResidentialZones | -0,002319 |
NearWater | 0,000203 |
PercentPopulationLivingBelowPoverty | 0,000031 |
ToxicWasteLevels | -0,000019 |
Wenn Sie sich die fünf wichtigsten Features dieses Datasets ansehen, wird der Preis eines Hauses, der von diesem Modell vorhergesagt wird, durch die Nähe zu Autobahnen, das Verhältnis zwischen Schülern/Studenten und Lehrkräften in der Region, die Nähe zu den wichtigsten Beschäftigungszentren, den Grundsteuersatz und die durchschnittliche Anzahl der Räume im Haus beeinflusst.