Tolka modellprediktioner med hjälp av permutationsbetydelse för attribut
Med hjälp av PFI (Permutation Feature Importance) får du lära dig hur du tolkar ML.NET förutsägelser för maskininlärningsmodellen. PFI ger det relativa bidrag som varje funktion ger till en förutsägelse.
Maskininlärningsmodeller betraktas ofta som ogenomskinliga rutor som tar indata och genererar utdata. De mellanliggande stegen eller interaktionerna mellan de funktioner som påverkar utdata förstås sällan. När maskininlärning introduceras i fler aspekter av vardagen, till exempel sjukvård, är det av yttersta vikt att förstå varför en maskininlärningsmodell fattar de beslut den gör. Om diagnoser till exempel görs av en maskininlärningsmodell behöver vårdpersonalen ett sätt att undersöka de faktorer som gick till att ställa den diagnosen. Att tillhandahålla rätt diagnos kan göra stor skillnad på om en patient har en snabb återhämtning eller inte. Ju högre förklaringsnivå en modell har, desto större förtroende har vårdpersonalen att acceptera eller avvisa de beslut som fattas av modellen.
Olika tekniker används för att förklara modeller, varav en är PFI. PFI är en teknik som används för att förklara klassificerings- och regressionsmodeller som är inspirerade av Breimans Random Forests paper (se avsnitt 10). På en hög nivå fungerar det genom att slumpmässigt blanda data en funktion i taget för hela datamängden och beräkna hur mycket prestandamåttet av intresse minskar. Ju större förändring, desto viktigare är funktionen.
Genom att markera de viktigaste funktionerna kan modellbyggare dessutom fokusera på att använda en delmängd av mer meningsfulla funktioner, vilket potentiellt kan minska brus och träningstid.
Läs in data
Funktionerna i datamängden som används för det här exemplet finns i kolumnerna 1–12. Målet är att förutsäga Price
.
Kolumn | Funktion | Beskrivning |
---|---|---|
1 | CrimeRate | Brottsfrekvens per capita |
2 | ResidentialZones | Bostadsområden i stan |
3 | CommercialZones | Icke-bostadsområden i staden |
4 | NearWater | Närhet till vattenmassa |
5 | Nivåer av giftigt avfall | Toxicitetsnivåer (PPM) |
6 | GenomsnittligtAntalRum | Genomsnittligt antal rum i huset |
7 | HomeAge | Hemmets ålder |
8 | BusinessCenterDistance | Avstånd till närmaste affärsdistrikt |
9 | Motorvägsanslutning | Närhet till motorvägar |
10 | Skattesats | Fastighetsskattesats |
11 | Elever-till-lärare-kvot | Förhållandet mellan elever och lärare |
12 | AndelAvBefolkningenUnderFattigdom | Procent av befolkningen som lever under fattigdom |
13 | Pris | Priset på hemmet |
Ett exempel på datamängden visas här:
1,24,13,1,0.59,3,96,11,23,608,14,13,32
4,80,18,1,0.37,5,14,7,4,346,19,13,41
2,98,16,1,0.25,10,5,1,8,689,13,36,12
Data i det här exemplet kan modelleras av en klass som HousingPriceData
och läsas in i en IDataView
.
class HousingPriceData
{
[LoadColumn(0)]
public float CrimeRate { get; set; }
[LoadColumn(1)]
public float ResidentialZones { get; set; }
[LoadColumn(2)]
public float CommercialZones { get; set; }
[LoadColumn(3)]
public float NearWater { get; set; }
[LoadColumn(4)]
public float ToxicWasteLevels { get; set; }
[LoadColumn(5)]
public float AverageRoomNumber { get; set; }
[LoadColumn(6)]
public float HomeAge { get; set; }
[LoadColumn(7)]
public float BusinessCenterDistance { get; set; }
[LoadColumn(8)]
public float HighwayAccess { get; set; }
[LoadColumn(9)]
public float TaxRate { get; set; }
[LoadColumn(10)]
public float StudentTeacherRatio { get; set; }
[LoadColumn(11)]
public float PercentPopulationBelowPoverty { get; set; }
[LoadColumn(12)]
[ColumnName("Label")]
public float Price { get; set; }
}
Träna modellen
Följande kodexempel illustrerar processen att träna en linjär regressionsmodell för att förutsäga huspriser.
// 1. Get the column name of input features.
string[] featureColumnNames =
data.Schema
.Select(column => column.Name)
.Where(columnName => columnName != "Label").ToArray();
// 2. Define training pipeline.
IEstimator<ITransformer> sdcaEstimator =
mlContext.Transforms.Concatenate("Features", featureColumnNames)
.Append(mlContext.Transforms.NormalizeMinMax("Features"))
.Append(mlContext.Regression.Trainers.Sdca());
// 3. Train machine learning model.
var sdcaModel = sdcaEstimator.Fit(data);
Förklara modellen med permutationsfunktionens betydelse (PFI - Permutation Feature Importance)
I ML.NET använder du metoden PermutationFeatureImportance
för din respektive uppgift.
// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);
// Calculate feature importance
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
mlContext
.Regression
.PermutationFeatureImportance(sdcaModel, transformedData, permutationCount:3);
Resultatet av att använda PermutationFeatureImportance
på träningsdatamängden är ett ImmutableArray
av RegressionMetricsStatistics
objekt.
RegressionMetricsStatistics
innehåller sammanfattningsstatistik som medelvärde och standardavvikelse för flera observationer av RegressionMetrics
lika med antalet permutationer som anges av parametern permutationCount
.
Måttet som används för att mäta funktionsvikt beror på vilken maskininlärningsuppgift som används för att lösa problemet. Regressionsaktiviteter kan till exempel använda ett vanligt utvärderingsmått, till exempel R-kvadrat för att mäta prioritet. Mer information om modellutvärderingsmått finns i utvärdera din ML.NET modell med mått.
Vikten, eller i det här fallet den absoluta genomsnittliga minskningen i R-kvadratmått, som beräknas av PermutationFeatureImportance
, kan sedan sorteras från viktigast till minst viktigt.
// Order features by importance.
var featureImportanceMetrics =
permutationFeatureImportance
.Select((metric, index) => new { index, metric.RSquared })
.OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));
Console.WriteLine("Feature\tPFI");
foreach (var feature in featureImportanceMetrics)
{
Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}
Om du skriver ut värdena för var och en av funktionerna i featureImportanceMetrics
genereras utdata som liknar de utdata som följer. Du bör förvänta dig att se olika resultat eftersom dessa värden varierar beroende på de data som de ges.
Funktion | Ändra till R-Kvadrat |
---|---|
HighwayAccess | -0.042731 |
Elev-lärarrelation | -0.012730 |
BusinessCenterDistance | -0.010491 |
Skattesats | -0.008545 |
GenomsnittligtRumsantal | -0.003949 |
CrimeRate | -0.003665 |
CommercialZones | 0.002749 |
HomeAge | -0.002426 |
ResidentialZones | -0.002319 |
NearWater | 0.000203 |
ProcentAvBefolkningenSomLeverUnderFattigdom | 0.000031 |
Giftavfallsnivåer | -0.000019 |
Om du tittar på de fem viktigaste funktionerna för den här datamängden påverkas priset på ett hus som förutspås av den här modellen av dess närhet till motorvägar, elevlärarens förhållande mellan skolor i området, närhet till stora arbetsförmedlingar, fastighetsskattesats och genomsnittligt antal rum i hemmet.