Modelvoorspellingen interpreteren met behulp van het belang van permutatiefuncties
Met PFI (Permutation Feature Importance) leert u hoe u voorspellingen van ML.NET machine learning-modellen interpreteert. PFI geeft de relatieve bijdrage die elke functie levert aan een voorspelling.
Machine learning-modellen worden vaak beschouwd als ondoorzichtige vakken die invoer nemen en een uitvoer genereren. De tussenliggende stappen of interacties tussen de functies die van invloed zijn op de uitvoer, worden zelden begrepen. Omdat machine learning wordt geïntroduceerd in meer aspecten van het dagelijkse leven, zoals gezondheidszorg, is het van het grootste belang om te begrijpen waarom een machine learning-model de beslissingen neemt die het doet. Als er bijvoorbeeld diagnoses worden uitgevoerd door een machine learning-model, hebben professionals in de gezondheidszorg een manier nodig om te kijken naar de factoren die zijn gegaan bij het maken van die diagnoses. Het verstrekken van de juiste diagnose kan een groot verschil maken over of een patiënt snel herstelt of niet. Hoe hoger het niveau van uitleg in een model, hoe groter het vertrouwen dat professionals in de gezondheidszorg de beslissingen moeten accepteren of afwijzen die door het model zijn genomen.
Verschillende technieken worden gebruikt om modellen uit te leggen, een daarvan is PFI. PFI is een techniek die wordt gebruikt om classificatie- en regressiemodellen uit te leggen die zijn geïnspireerd op het artikel Random Forests van Breiman (zie sectie 10). Op een hoog niveau is de manier waarop het werkt door willekeurig gegevens één functie tegelijk te shuffling voor de hele gegevensset en door te berekenen hoeveel de metrische prestatiegegevens van belang afnemen. Hoe groter de wijziging, hoe belangrijker die functie is.
Door de belangrijkste functies te benadrukken, kunnen modelbouwers zich richten op het gebruik van een subset van zinvollere functies die mogelijk ruis en trainingstijd kunnen verminderen.
De gegevens laden
De functies in de gegevensset die voor dit voorbeeld worden gebruikt, bevinden zich in kolommen 1-12. Het doel is om te voorspellen Price
.
Kolom | Functie | Beschrijving |
---|---|---|
1 | CrimeRate | Criminaliteit per hoofd van de bevolking |
2 | ResidentialZones | Woonwijken in de stad |
3 | CommercialZones | Niet-woonzones in de stad |
4 | NearWater | Nabijheid van waterlichaam |
5 | ToxicWasteLevels | Toxiciteitsniveaus (PPM) |
6 | AverageRoomNumber | Gemiddeld aantal kamers in huis |
7 | HomeAge | Leeftijd van thuis |
8 | BusinessCenterDistance | Afstand tot dichtstbijzijnde zakenwijk |
9 | HighwayAccess | Nabijheid van snelwegen |
10 | TaxRate | Belastingtarief voor onroerend goed |
11 | StudentTeacherRatio | Verhouding van leerlingen/studenten tot docenten |
12 | PercentPopulationBelowPoverty | Percentage van de bevolking dat onder armoede leeft |
13 | Prijs | Prijs van het huis |
Hieronder ziet u een voorbeeld van de gegevensset:
1,24,13,1,0.59,3,96,11,23,608,14,13,32
4,80,18,1,0.37,5,14,7,4,346,19,13,41
2,98,16,1,0.25,10,5,1,8,689,13,36,12
De gegevens in dit voorbeeld kunnen worden gemodelleerd door een klasse zoals HousingPriceData
en geladen in een IDataView
.
class HousingPriceData
{
[LoadColumn(0)]
public float CrimeRate { get; set; }
[LoadColumn(1)]
public float ResidentialZones { get; set; }
[LoadColumn(2)]
public float CommercialZones { get; set; }
[LoadColumn(3)]
public float NearWater { get; set; }
[LoadColumn(4)]
public float ToxicWasteLevels { get; set; }
[LoadColumn(5)]
public float AverageRoomNumber { get; set; }
[LoadColumn(6)]
public float HomeAge { get; set; }
[LoadColumn(7)]
public float BusinessCenterDistance { get; set; }
[LoadColumn(8)]
public float HighwayAccess { get; set; }
[LoadColumn(9)]
public float TaxRate { get; set; }
[LoadColumn(10)]
public float StudentTeacherRatio { get; set; }
[LoadColumn(11)]
public float PercentPopulationBelowPoverty { get; set; }
[LoadColumn(12)]
[ColumnName("Label")]
public float Price { get; set; }
}
Het model trainen
Het onderstaande codevoorbeeld illustreert het proces van het trainen van een lineair regressiemodel om de huizenprijzen te voorspellen.
// 1. Get the column name of input features.
string[] featureColumnNames =
data.Schema
.Select(column => column.Name)
.Where(columnName => columnName != "Label").ToArray();
// 2. Define training pipeline
IEstimator<ITransformer> sdcaEstimator =
mlContext.Transforms.Concatenate("Features", featureColumnNames)
.Append(mlContext.Transforms.NormalizeMinMax("Features"))
.Append(mlContext.Regression.Trainers.Sdca());
// 3. Train machine learning model
var sdcaModel = sdcaEstimator.Fit(data);
Het model uitleggen met PFI (Permutation Feature Importance)
Gebruik in ML.NET de PermutationFeatureImportance
methode voor uw respectieve taak.
// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);
// Calculate feature importance
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
mlContext
.Regression
.PermutationFeatureImportance(sdcaModel, transformedData, permutationCount:3);
Het resultaat van het gebruik PermutationFeatureImportance
van de trainingsgegevensset is een ImmutableArray
van RegressionMetricsStatistics
de objecten. RegressionMetricsStatistics
biedt samenvattingsstatistieken zoals gemiddelde en standaarddeviatie voor meerdere waarnemingen die RegressionMetrics
gelijk zijn aan het aantal permutaties dat is opgegeven door de permutationCount
parameter.
De metrische waarde die wordt gebruikt om het belang van functies te meten, is afhankelijk van de machine learning-taak die wordt gebruikt om uw probleem op te lossen. Regressietaken kunnen bijvoorbeeld gebruikmaken van een algemene metrische evaluatiewaarde, zoals R-kwadraat om het belang te meten. Zie uw ML.NET model evalueren met metrische gegevens voor meer informatie over metrische gegevens voor modelevaluatie.
Het belang, of in dit geval, de absolute gemiddelde afname van de R-kwadratische metrische waarde die wordt PermutationFeatureImportance
berekend, kan vervolgens worden geordend van het belangrijkste naar het minst belangrijk.
// Order features by importance
var featureImportanceMetrics =
permutationFeatureImportance
.Select((metric, index) => new { index, metric.RSquared })
.OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));
Console.WriteLine("Feature\tPFI");
foreach (var feature in featureImportanceMetrics)
{
Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}
Als u de waarden afdrukt voor elk van de functies in, featureImportanceMetrics
wordt uitvoer gegenereerd die vergelijkbaar is met die hieronder. Houd er rekening mee dat u verschillende resultaten kunt verwachten, omdat deze waarden variëren op basis van de gegevens die ze krijgen.
Functie | Wijzigen in R-Kwadraat |
---|---|
HighwayAccess | -0.042731 |
StudentTeacherRatio | -0.012730 |
BusinessCenterDistance | -0.010491 |
TaxRate | -0.008545 |
AverageRoomNumber | -0.003949 |
CrimeRate | -0.003665 |
CommercialZones | 0.002749 |
HomeAge | -0.002426 |
ResidentialZones | -0.002319 |
NearWater | 0.000203 |
PercentPopulationLivingBelowPoverty | 0.000031 |
ToxicWasteLevels | -0.000019 |
Bekijk de vijf belangrijkste functies voor deze gegevensset, de prijs van een huis dat door dit model wordt voorspeld, wordt beïnvloed door de nabijheid van snelwegen, de docentverhouding van studenten van scholen in het gebied, de nabijheid van belangrijke werkgelegenheidscentra, het belastingtarief voor onroerend goed en het gemiddelde aantal kamers in het huis.