Dela via


Tolka modellprediktioner med hjälp av permutationsbetydelse för attribut

Med hjälp av PFI (Permutation Feature Importance) får du lära dig hur du tolkar ML.NET förutsägelser för maskininlärningsmodellen. PFI ger det relativa bidrag som varje funktion ger till en förutsägelse.

Maskininlärningsmodeller betraktas ofta som ogenomskinliga rutor som tar indata och genererar utdata. De mellanliggande stegen eller interaktionerna mellan de funktioner som påverkar utdata förstås sällan. När maskininlärning introduceras i fler aspekter av vardagen, till exempel sjukvård, är det av yttersta vikt att förstå varför en maskininlärningsmodell fattar de beslut den gör. Om diagnoser till exempel görs av en maskininlärningsmodell behöver vårdpersonalen ett sätt att undersöka de faktorer som gick till att ställa den diagnosen. Att tillhandahålla rätt diagnos kan göra stor skillnad på om en patient har en snabb återhämtning eller inte. Ju högre förklaringsnivå en modell har, desto större förtroende har vårdpersonalen att acceptera eller avvisa de beslut som fattas av modellen.

Olika tekniker används för att förklara modeller, varav en är PFI. PFI är en teknik som används för att förklara klassificerings- och regressionsmodeller som är inspirerade av Breimans Random Forests paper (se avsnitt 10). På en hög nivå fungerar det genom att slumpmässigt blanda data en funktion i taget för hela datamängden och beräkna hur mycket prestandamåttet av intresse minskar. Ju större förändring, desto viktigare är funktionen.

Genom att markera de viktigaste funktionerna kan modellbyggare dessutom fokusera på att använda en delmängd av mer meningsfulla funktioner, vilket potentiellt kan minska brus och träningstid.

Läs in data

Funktionerna i datamängden som används för det här exemplet finns i kolumnerna 1–12. Målet är att förutsäga Price.

Kolumn Funktion Beskrivning
1 CrimeRate Brottsfrekvens per capita
2 ResidentialZones Bostadsområden i stan
3 CommercialZones Icke-bostadsområden i staden
4 NearWater Närhet till vattenmassa
5 Nivåer av giftigt avfall Toxicitetsnivåer (PPM)
6 GenomsnittligtAntalRum Genomsnittligt antal rum i huset
7 HomeAge Hemmets ålder
8 BusinessCenterDistance Avstånd till närmaste affärsdistrikt
9 Motorvägsanslutning Närhet till motorvägar
10 Skattesats Fastighetsskattesats
11 Elever-till-lärare-kvot Förhållandet mellan elever och lärare
12 AndelAvBefolkningenUnderFattigdom Procent av befolkningen som lever under fattigdom
13 Pris Priset på hemmet

Ett exempel på datamängden visas här:

1,24,13,1,0.59,3,96,11,23,608,14,13,32
4,80,18,1,0.37,5,14,7,4,346,19,13,41
2,98,16,1,0.25,10,5,1,8,689,13,36,12

Data i det här exemplet kan modelleras av en klass som HousingPriceData och läsas in i en IDataView.

class HousingPriceData
{
    [LoadColumn(0)]
    public float CrimeRate { get; set; }

    [LoadColumn(1)]
    public float ResidentialZones { get; set; }

    [LoadColumn(2)]
    public float CommercialZones { get; set; }

    [LoadColumn(3)]
    public float NearWater { get; set; }

    [LoadColumn(4)]
    public float ToxicWasteLevels { get; set; }

    [LoadColumn(5)]
    public float AverageRoomNumber { get; set; }

    [LoadColumn(6)]
    public float HomeAge { get; set; }

    [LoadColumn(7)]
    public float BusinessCenterDistance { get; set; }

    [LoadColumn(8)]
    public float HighwayAccess { get; set; }

    [LoadColumn(9)]
    public float TaxRate { get; set; }

    [LoadColumn(10)]
    public float StudentTeacherRatio { get; set; }

    [LoadColumn(11)]
    public float PercentPopulationBelowPoverty { get; set; }

    [LoadColumn(12)]
    [ColumnName("Label")]
    public float Price { get; set; }
}

Träna modellen

Följande kodexempel illustrerar processen att träna en linjär regressionsmodell för att förutsäga huspriser.

// 1. Get the column name of input features.
string[] featureColumnNames =
    data.Schema
        .Select(column => column.Name)
        .Where(columnName => columnName != "Label").ToArray();

// 2. Define training pipeline.
IEstimator<ITransformer> sdcaEstimator =
    mlContext.Transforms.Concatenate("Features", featureColumnNames)
        .Append(mlContext.Transforms.NormalizeMinMax("Features"))
        .Append(mlContext.Regression.Trainers.Sdca());

// 3. Train machine learning model.
var sdcaModel = sdcaEstimator.Fit(data);

Förklara modellen med permutationsfunktionens betydelse (PFI - Permutation Feature Importance)

I ML.NET använder du metoden PermutationFeatureImportance för din respektive uppgift.

// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);

// Calculate feature importance
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
    mlContext
        .Regression
        .PermutationFeatureImportance(sdcaModel, transformedData, permutationCount:3);

Resultatet av att använda PermutationFeatureImportance på träningsdatamängden är ett ImmutableArray av RegressionMetricsStatistics objekt. RegressionMetricsStatistics innehåller sammanfattningsstatistik som medelvärde och standardavvikelse för flera observationer av RegressionMetrics lika med antalet permutationer som anges av parametern permutationCount.

Måttet som används för att mäta funktionsvikt beror på vilken maskininlärningsuppgift som används för att lösa problemet. Regressionsaktiviteter kan till exempel använda ett vanligt utvärderingsmått, till exempel R-kvadrat för att mäta prioritet. Mer information om modellutvärderingsmått finns i utvärdera din ML.NET modell med mått.

Vikten, eller i det här fallet den absoluta genomsnittliga minskningen i R-kvadratmått, som beräknas av PermutationFeatureImportance, kan sedan sorteras från viktigast till minst viktigt.

// Order features by importance.
var featureImportanceMetrics =
    permutationFeatureImportance
        .Select((metric, index) => new { index, metric.RSquared })
        .OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));

Console.WriteLine("Feature\tPFI");

foreach (var feature in featureImportanceMetrics)
{
    Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}

Om du skriver ut värdena för var och en av funktionerna i featureImportanceMetrics genereras utdata som liknar de utdata som följer. Du bör förvänta dig att se olika resultat eftersom dessa värden varierar beroende på de data som de ges.

Funktion Ändra till R-Kvadrat
HighwayAccess -0.042731
Elev-lärarrelation -0.012730
BusinessCenterDistance -0.010491
Skattesats -0.008545
GenomsnittligtRumsantal -0.003949
CrimeRate -0.003665
CommercialZones 0.002749
HomeAge -0.002426
ResidentialZones -0.002319
NearWater 0.000203
ProcentAvBefolkningenSomLeverUnderFattigdom 0.000031
Giftavfallsnivåer -0.000019

Om du tittar på de fem viktigaste funktionerna för den här datamängden påverkas priset på ett hus som förutspås av den här modellen av dess närhet till motorvägar, elevlärarens förhållande mellan skolor i området, närhet till stora arbetsförmedlingar, fastighetsskattesats och genomsnittligt antal rum i hemmet.

Nästa steg