Interpretowanie przewidywań modelu przy użyciu ważności funkcji permutacji
Korzystając z znaczenia funkcji permutacji (PFI), dowiedz się, jak interpretować przewidywania modelu uczenia maszynowego ML.NET. PFI przedstawia względny wkład każdej cechy w przewidywanie.
Modele uczenia maszynowego są często uważane za nieprzezroczyste pola, które pobierają dane wejściowe i generują dane wyjściowe. Pośrednie kroki lub interakcje między funkcjami, które mają wpływ na dane wyjściowe, są rzadko zrozumiałe. Wraz z wprowadzeniem uczenia maszynowego do większej liczby aspektów codziennego życia, takich jak opieka zdrowotna, niezwykle ważne jest zrozumienie, dlaczego model uczenia maszynowego podejmuje decyzje, które wykonuje. Jeśli na przykład diagnozy są wykonywane przez model uczenia maszynowego, pracownicy służby zdrowia potrzebują sposobu, aby przyjrzeć się czynnikom, które przeszły do tej diagnozy. Zapewnienie właściwej diagnozy może mieć duży wpływ na to, czy pacjent ma szybkie wyzdrowienie, czy nie. W związku z tym, im wyższy jest poziom wyjaśnialności modelu, tym większe zaufanie mają pracownicy służby zdrowia do akceptowania lub odrzucania decyzji podejmowanych przez model.
Różne techniki służą do wyjaśnienia modeli, z których jeden to PFI. PfI to technika służąca do wyjaśnienia modeli klasyfikacji i regresji inspirowanych lasami losowymi Breimana papieru (patrz sekcja 10). Ogólnie rzecz biorąc, jego działanie polega na losowym przetasowaniu danych jednej cechy naraz dla całego zestawu danych i obliczeniu, o ile zmniejsza się interesująca nas miara wydajności. Im większa zmiana, tym ważniejsza jest ta cecha.
Ponadto, podkreślając najważniejsze cechy, konstruktorzy modeli mogą skupić się na używaniu podzestawu bardziej znaczących cech, co może zmniejszyć ilość szumów i czas nauki.
Ładowanie danych
Funkcje w zestawie danych używanym dla tego przykładu znajdują się w kolumnach 1–12. Celem jest przewidywanie Price
.
Kolumna | Cecha | Opis |
---|---|---|
1 | CrimeRate | Wskaźnik przestępczości na mieszkańca |
2 | Strefy mieszkalne | Strefy mieszkalne w mieście |
3 | Komercyjne strefy | Strefy nieresidentialne w mieście |
4 | NearWater | Bliskość ciała wody |
5 | PoziomyToksycznychOdpadów | Poziomy toksyczności (PPM) |
6 | ŚredniaLiczbaPokoi | Średnia liczba pomieszczeń w domu |
7 | HomeAge | Wiek domu |
8 | BusinessCenterDistance | Odległość do najbliższej dzielnicy biznesowej |
9 | Dostęp do autostrady | Bliskość autostrad |
10 | Stawka podatku | Stawka podatku od nieruchomości |
11 | Stosunek studentów do nauczycieli | Stosunek uczniów do nauczycieli |
12 | ProcentPopulacjiPoniżejGraniczyUbóstwa | Procent ludności żyjących poniżej ubóstwa |
13 | Cena | Cena domu |
Poniżej przedstawiono przykładowy zestaw danych:
1,24,13,1,0.59,3,96,11,23,608,14,13,32
4,80,18,1,0.37,5,14,7,4,346,19,13,41
2,98,16,1,0.25,10,5,1,8,689,13,36,12
Dane w tym przykładzie mogą być modelowane przez klasę, taką jak HousingPriceData
i ładowane do IDataView
.
class HousingPriceData
{
[LoadColumn(0)]
public float CrimeRate { get; set; }
[LoadColumn(1)]
public float ResidentialZones { get; set; }
[LoadColumn(2)]
public float CommercialZones { get; set; }
[LoadColumn(3)]
public float NearWater { get; set; }
[LoadColumn(4)]
public float ToxicWasteLevels { get; set; }
[LoadColumn(5)]
public float AverageRoomNumber { get; set; }
[LoadColumn(6)]
public float HomeAge { get; set; }
[LoadColumn(7)]
public float BusinessCenterDistance { get; set; }
[LoadColumn(8)]
public float HighwayAccess { get; set; }
[LoadColumn(9)]
public float TaxRate { get; set; }
[LoadColumn(10)]
public float StudentTeacherRatio { get; set; }
[LoadColumn(11)]
public float PercentPopulationBelowPoverty { get; set; }
[LoadColumn(12)]
[ColumnName("Label")]
public float Price { get; set; }
}
Trenowanie modelu
Poniższy przykładowy kod ilustruje proces trenowania modelu regresji liniowej w celu przewidywania cen domów.
// 1. Get the column name of input features.
string[] featureColumnNames =
data.Schema
.Select(column => column.Name)
.Where(columnName => columnName != "Label").ToArray();
// 2. Define training pipeline.
IEstimator<ITransformer> sdcaEstimator =
mlContext.Transforms.Concatenate("Features", featureColumnNames)
.Append(mlContext.Transforms.NormalizeMinMax("Features"))
.Append(mlContext.Regression.Trainers.Sdca());
// 3. Train machine learning model.
var sdcaModel = sdcaEstimator.Fit(data);
Wyjaśnij model przy użyciu ważności cechy permutacji (PFI)
W ML.NET użyj metody PermutationFeatureImportance
dla odpowiedniego zadania.
// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);
// Calculate feature importance
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
mlContext
.Regression
.PermutationFeatureImportance(sdcaModel, transformedData, permutationCount:3);
Wynikiem używania PermutationFeatureImportance
w zestawie danych szkoleniowych jest ImmutableArray
obiektów RegressionMetricsStatistics
.
RegressionMetricsStatistics
zawiera podsumowanie statystyk, takich jak średnia i odchylenie standardowe dla wielu obserwacji RegressionMetrics
równych liczbie permutacji określonych przez parametr permutationCount
.
Metryka używana do mierzenia ważności funkcji zależy od zadania uczenia maszynowego używanego do rozwiązania problemu. Na przykład zadania regresji mogą używać typowej metryki oceny, takiej jak R-squared, aby mierzyć znaczenie. Aby uzyskać więcej informacji na temat metryk oceny modelu, zobacz ocena modelu ML.NET za pomocą metryk.
Znaczenie, lub w tym przypadku, bezwzględny średni spadek metryki R kwadrat, obliczony przez PermutationFeatureImportance
, może zostać uporządkowany od najważniejszych do najmniej ważnych.
// Order features by importance.
var featureImportanceMetrics =
permutationFeatureImportance
.Select((metric, index) => new { index, metric.RSquared })
.OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));
Console.WriteLine("Feature\tPFI");
foreach (var feature in featureImportanceMetrics)
{
Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}
Drukowanie wartości dla każdej z funkcji w featureImportanceMetrics
generuje dane wyjściowe podobne do poniższych danych wyjściowych. Powinny pojawić się różne wyniki, ponieważ te wartości różnią się w zależności od podanych danych.
Cecha | Zmień na R-Squared |
---|---|
Dostęp do autostrady | -0.042731 |
Wskaźnik Uczniowie-Nauczyciele | -0.012730 |
BusinessCenterDistance | -0.010491 |
Stawka Podatkowa | -0.008545 |
ŚredniaLiczbaPokoi | -0.003949 |
CrimeRate | -0.003665 |
Komercyjne strefy | 0.002749 |
HomeAge | -0.002426 |
Strefy mieszkalne | -0.002319 |
NearWater | 0.000203 |
Procent populacji żyjącej poniżej granicy ubóstwa | 0.000031 |
PoziomyToksycznychOdpadów | -0.000019 |
Jeśli spojrzysz na pięć najważniejszych funkcji tego zestawu danych, cena domu przewidywanego przez ten model ma wpływ na jego bliskość do autostrad, stosunek nauczycieli do szkół w okolicy, bliskość głównych ośrodków zatrudnienia, stawka podatku od nieruchomości i średnia liczba pomieszczeń w domu.