Interpretace predikcí modelu pomocí důležitosti funkce Permutation
Pomocí funkce Permutation Feature Importance (PFI) se naučíte interpretovat ML.NET predikce modelu strojového učení. PFI dává relativní příspěvek každé funkci k predikci.
Modely strojového učení se často považují za neprůžné rámečky, které přebírají vstupy a generují výstup. Přechodné kroky nebo interakce mezi funkcemi, které ovlivňují výstup, jsou zřídka srozumitelné. Vzhledem k tomu, že strojové učení je zavedeno do více aspektů každodenního života, jako je zdravotní péče, je nanejvýš důležité pochopit, proč model strojového učení dělá rozhodnutí. Pokud jsou například diagnostiky vytvořené modelem strojového učení, potřebují odborníci na zdravotnictví způsob, jak se podívat na faktory, které se dostaly do provádění diagnostiky. Poskytnutí správné diagnózy může mít velký rozdíl na tom, zda má pacient rychlé obnovení, nebo ne. Proto čím vyšší je úroveň vysvětlitelnosti modelu, tím větší spolehlivost zdravotnických odborníků musí přijmout nebo odmítnout rozhodnutí, která model udělal.
K vysvětlení modelů se používají různé techniky, z nichž jedna je PFI. PFI je technika používaná k vysvětlení klasifikačních a regresních modelů inspirovaných dokumentem Random Forests společnosti Breiman (viz část 10). Na vysoké úrovni je způsob, jakým funguje, náhodným náhodném prohazování jedné funkce pro celou datovou sadu a výpočtu toho, kolik metrik výkonu zájmu klesá. Čím větší je změna, tím důležitější je tato funkce.
Kromě toho se tvůrci modelů můžou zaměřit na používání podmnožina smysluplnějších funkcí, které můžou potenciálně snížit šum a dobu trénování.
Načtení dat
Funkce v datové sadě používané pro tuto ukázku jsou ve sloupcích 1–12. Cílem je předpovědět Price
.
Column | Funkce | Popis |
---|---|---|
1 | CrimeRate | Míra trestné činnosti na obyvatele |
2 | ResidentialZones | Obytné zóny ve městě |
3 | CommercialZones | Nebytové zóny ve městě |
4 | NearWater | Blízkost vodního těla |
5 | ToxicWasteLevels | Úrovně toxicity (PPM) |
6 | AverageRoomNumber | Průměrný počet místností v domě |
7 | HomeAge | Věk domu |
8 | BusinessCenterDistance | Vzdálenost k nejbližší obchodní čtvrti |
9 | HighwayAccess | Blízkost dálnic |
10 | TaxRate | Sazba daně z nemovitostí |
11 | StudentTeacherRatio | Poměr studentů k učitelům |
12 | PercentPopulationBelowPoverty | Procento populace žijící pod chudobou |
13 | Cena | Cena domu |
Ukázka datové sady je znázorněná níže:
1,24,13,1,0.59,3,96,11,23,608,14,13,32
4,80,18,1,0.37,5,14,7,4,346,19,13,41
2,98,16,1,0.25,10,5,1,8,689,13,36,12
Data v této ukázce mohou být modelována třídou jako HousingPriceData
a načtena do .IDataView
class HousingPriceData
{
[LoadColumn(0)]
public float CrimeRate { get; set; }
[LoadColumn(1)]
public float ResidentialZones { get; set; }
[LoadColumn(2)]
public float CommercialZones { get; set; }
[LoadColumn(3)]
public float NearWater { get; set; }
[LoadColumn(4)]
public float ToxicWasteLevels { get; set; }
[LoadColumn(5)]
public float AverageRoomNumber { get; set; }
[LoadColumn(6)]
public float HomeAge { get; set; }
[LoadColumn(7)]
public float BusinessCenterDistance { get; set; }
[LoadColumn(8)]
public float HighwayAccess { get; set; }
[LoadColumn(9)]
public float TaxRate { get; set; }
[LoadColumn(10)]
public float StudentTeacherRatio { get; set; }
[LoadColumn(11)]
public float PercentPopulationBelowPoverty { get; set; }
[LoadColumn(12)]
[ColumnName("Label")]
public float Price { get; set; }
}
Trénování modelu
Následující ukázka kódu znázorňuje proces trénování lineárního regresního modelu, který předpovídá ceny domů.
// 1. Get the column name of input features.
string[] featureColumnNames =
data.Schema
.Select(column => column.Name)
.Where(columnName => columnName != "Label").ToArray();
// 2. Define training pipeline
IEstimator<ITransformer> sdcaEstimator =
mlContext.Transforms.Concatenate("Features", featureColumnNames)
.Append(mlContext.Transforms.NormalizeMinMax("Features"))
.Append(mlContext.Regression.Trainers.Sdca());
// 3. Train machine learning model
var sdcaModel = sdcaEstimator.Fit(data);
Vysvětlení modelu s důležitostí funkcí permutace (PFI)
V ML.NET použijte metodu PermutationFeatureImportance
pro příslušný úkol.
// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);
// Calculate feature importance
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
mlContext
.Regression
.PermutationFeatureImportance(sdcaModel, transformedData, permutationCount:3);
Výsledkem použití PermutationFeatureImportance
na trénovací datové sadě je ImmutableArray
RegressionMetricsStatistics
objekt. RegressionMetricsStatistics
poskytuje souhrnné statistiky, jako jsou střední a směrodatná odchylka pro více pozorování RegressionMetrics
rovnajících se počtu permutací určených parametrem permutationCount
.
Metrika použitá k měření důležitosti funkcí závisí na úloze strojového učení použité k vyřešení vašeho problému. Například regresní úlohy můžou k měření důležitosti použít běžnou metriku vyhodnocení, například R-squared. Další informace o metrikách vyhodnocení modelu najdete v tématu vyhodnocení modelu ML.NET metrikami.
Důležitost nebo v tomto případě je možné absolutní průměrné snížení metriky R na druhou mocninu vypočítané podle toho, co PermutationFeatureImportance
je poté seřazeno od nejdůležitějších po nejméně důležité.
// Order features by importance
var featureImportanceMetrics =
permutationFeatureImportance
.Select((metric, index) => new { index, metric.RSquared })
.OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));
Console.WriteLine("Feature\tPFI");
foreach (var feature in featureImportanceMetrics)
{
Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}
Při tisku hodnot jednotlivých funkcí by featureImportanceMetrics
se výstup podobný následujícímu vygeneroval. Mějte na paměti, že byste měli očekávat různé výsledky, protože tyto hodnoty se liší v závislosti na zadaných datech.
Funkce | Změna na R-Squared |
---|---|
HighwayAccess | -0.042731 |
StudentTeacherRatio | -0.012730 |
BusinessCenterDistance | -0.010491 |
TaxRate | -0.008545 |
AverageRoomNumber | -0.003949 |
CrimeRate | -0.003665 |
CommercialZones | 0.002749 |
HomeAge | -0.002426 |
ResidentialZones | -0.002319 |
NearWater | 0.000203 |
PercentPopulationLivingBelowPoverty | 0.000031 |
ToxicWasteLevels | -0.000019 |
Když se podíváme na pět nejdůležitějších funkcí pro tuto datovou sadu, cena domu předpovězená tímto modelem je ovlivněna jeho blízkostí k dálnicím, poměrem učitelů studentů v dané oblasti, blízkost hlavních pracovních center, sazbou daně z nemovitostí a průměrným počtem místností v domácnosti.