Delen via


Modelvoorspellingen interpreteren met behulp van het belang van permutatiefuncties

Met PFI (Permutation Feature Importance) leert u hoe u voorspellingen van ML.NET machine learning-modellen interpreteert. PFI geeft de relatieve bijdrage die elke functie levert aan een voorspelling.

Machine learning-modellen worden vaak beschouwd als ondoorzichtige vakken die invoer nemen en een uitvoer genereren. De tussenliggende stappen of interacties tussen de functies die van invloed zijn op de uitvoer, worden zelden begrepen. Omdat machine learning wordt geïntroduceerd in meer aspecten van het dagelijkse leven, zoals gezondheidszorg, is het van het grootste belang om te begrijpen waarom een machine learning-model de beslissingen neemt die het doet. Als er bijvoorbeeld diagnoses worden uitgevoerd door een machine learning-model, hebben professionals in de gezondheidszorg een manier nodig om te kijken naar de factoren die zijn gegaan bij het maken van die diagnoses. Het verstrekken van de juiste diagnose kan een groot verschil maken over of een patiënt snel herstelt of niet. Hoe hoger het niveau van uitleg in een model, hoe groter het vertrouwen dat professionals in de gezondheidszorg de beslissingen moeten accepteren of afwijzen die door het model zijn genomen.

Verschillende technieken worden gebruikt om modellen uit te leggen, een daarvan is PFI. PFI is een techniek die wordt gebruikt om classificatie- en regressiemodellen uit te leggen die zijn geïnspireerd op het artikel Random Forests van Breiman (zie sectie 10). Op een hoog niveau is de manier waarop het werkt door willekeurig gegevens één functie tegelijk te shuffling voor de hele gegevensset en door te berekenen hoeveel de metrische prestatiegegevens van belang afnemen. Hoe groter de wijziging, hoe belangrijker die functie is.

Door de belangrijkste functies te benadrukken, kunnen modelbouwers zich richten op het gebruik van een subset van zinvollere functies die mogelijk ruis en trainingstijd kunnen verminderen.

De gegevens laden

De functies in de gegevensset die voor dit voorbeeld worden gebruikt, bevinden zich in kolommen 1-12. Het doel is om te voorspellen Price.

Kolom Functie Beschrijving
1 CrimeRate Criminaliteit per hoofd van de bevolking
2 ResidentialZones Woonwijken in de stad
3 CommercialZones Niet-woonzones in de stad
4 NearWater Nabijheid van waterlichaam
5 ToxicWasteLevels Toxiciteitsniveaus (PPM)
6 AverageRoomNumber Gemiddeld aantal kamers in huis
7 HomeAge Leeftijd van thuis
8 BusinessCenterDistance Afstand tot dichtstbijzijnde zakenwijk
9 HighwayAccess Nabijheid van snelwegen
10 TaxRate Belastingtarief voor onroerend goed
11 StudentTeacherRatio Verhouding van leerlingen/studenten tot docenten
12 PercentPopulationBelowPoverty Percentage van de bevolking dat onder armoede leeft
13 Prijs Prijs van het huis

Hieronder ziet u een voorbeeld van de gegevensset:

1,24,13,1,0.59,3,96,11,23,608,14,13,32
4,80,18,1,0.37,5,14,7,4,346,19,13,41
2,98,16,1,0.25,10,5,1,8,689,13,36,12

De gegevens in dit voorbeeld kunnen worden gemodelleerd door een klasse zoals HousingPriceData en geladen in een IDataView.

class HousingPriceData
{
    [LoadColumn(0)]
    public float CrimeRate { get; set; }

    [LoadColumn(1)]
    public float ResidentialZones { get; set; }

    [LoadColumn(2)]
    public float CommercialZones { get; set; }

    [LoadColumn(3)]
    public float NearWater { get; set; }

    [LoadColumn(4)]
    public float ToxicWasteLevels { get; set; }

    [LoadColumn(5)]
    public float AverageRoomNumber { get; set; }

    [LoadColumn(6)]
    public float HomeAge { get; set; }

    [LoadColumn(7)]
    public float BusinessCenterDistance { get; set; }

    [LoadColumn(8)]
    public float HighwayAccess { get; set; }

    [LoadColumn(9)]
    public float TaxRate { get; set; }

    [LoadColumn(10)]
    public float StudentTeacherRatio { get; set; }

    [LoadColumn(11)]
    public float PercentPopulationBelowPoverty { get; set; }

    [LoadColumn(12)]
    [ColumnName("Label")]
    public float Price { get; set; }
}

Het model trainen

Het onderstaande codevoorbeeld illustreert het proces van het trainen van een lineair regressiemodel om de huizenprijzen te voorspellen.

// 1. Get the column name of input features.
string[] featureColumnNames =
    data.Schema
        .Select(column => column.Name)
        .Where(columnName => columnName != "Label").ToArray();

// 2. Define training pipeline
IEstimator<ITransformer> sdcaEstimator =
    mlContext.Transforms.Concatenate("Features", featureColumnNames)
        .Append(mlContext.Transforms.NormalizeMinMax("Features"))
        .Append(mlContext.Regression.Trainers.Sdca());

// 3. Train machine learning model
var sdcaModel = sdcaEstimator.Fit(data);

Het model uitleggen met PFI (Permutation Feature Importance)

Gebruik in ML.NET de PermutationFeatureImportance methode voor uw respectieve taak.

// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);

// Calculate feature importance
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
    mlContext
        .Regression
        .PermutationFeatureImportance(sdcaModel, transformedData, permutationCount:3);

Het resultaat van het gebruik PermutationFeatureImportance van de trainingsgegevensset is een ImmutableArray van RegressionMetricsStatistics de objecten. RegressionMetricsStatistics biedt samenvattingsstatistieken zoals gemiddelde en standaarddeviatie voor meerdere waarnemingen die RegressionMetrics gelijk zijn aan het aantal permutaties dat is opgegeven door de permutationCount parameter.

De metrische waarde die wordt gebruikt om het belang van functies te meten, is afhankelijk van de machine learning-taak die wordt gebruikt om uw probleem op te lossen. Regressietaken kunnen bijvoorbeeld gebruikmaken van een algemene metrische evaluatiewaarde, zoals R-kwadraat om het belang te meten. Zie uw ML.NET model evalueren met metrische gegevens voor meer informatie over metrische gegevens voor modelevaluatie.

Het belang, of in dit geval, de absolute gemiddelde afname van de R-kwadratische metrische waarde die wordt PermutationFeatureImportance berekend, kan vervolgens worden geordend van het belangrijkste naar het minst belangrijk.

// Order features by importance
var featureImportanceMetrics =
    permutationFeatureImportance
        .Select((metric, index) => new { index, metric.RSquared })
        .OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));

Console.WriteLine("Feature\tPFI");

foreach (var feature in featureImportanceMetrics)
{
    Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}

Als u de waarden afdrukt voor elk van de functies in, featureImportanceMetrics wordt uitvoer gegenereerd die vergelijkbaar is met die hieronder. Houd er rekening mee dat u verschillende resultaten kunt verwachten, omdat deze waarden variëren op basis van de gegevens die ze krijgen.

Functie Wijzigen in R-Kwadraat
HighwayAccess -0.042731
StudentTeacherRatio -0.012730
BusinessCenterDistance -0.010491
TaxRate -0.008545
AverageRoomNumber -0.003949
CrimeRate -0.003665
CommercialZones 0.002749
HomeAge -0.002426
ResidentialZones -0.002319
NearWater 0.000203
PercentPopulationLivingBelowPoverty 0.000031
ToxicWasteLevels -0.000019

Bekijk de vijf belangrijkste functies voor deze gegevensset, de prijs van een huis dat door dit model wordt voorspeld, wordt beïnvloed door de nabijheid van snelwegen, de docentverhouding van studenten van scholen in het gebied, de nabijheid van belangrijke werkgelegenheidscentra, het belastingtarief voor onroerend goed en het gemiddelde aantal kamers in het huis.

Volgende stappen