Interpretowanie przewidywań modelu przy użyciu ważności funkcji permutacji
Korzystając z znaczenia funkcji permutacji (PFI), dowiedz się, jak interpretować przewidywania modelu uczenia maszynowego ML.NET. Interfejs PFI zapewnia względny wkład każdej funkcji w przewidywanie.
Modele uczenia maszynowego są często uważane za nieprzezroczyste pola, które pobierają dane wejściowe i generują dane wyjściowe. Pośrednie kroki lub interakcje między funkcjami, które mają wpływ na dane wyjściowe, są rzadko zrozumiałe. Ponieważ uczenie maszynowe jest wprowadzane do większej liczby aspektów codziennego życia, takich jak opieka zdrowotna, niezwykle ważne jest zrozumienie, dlaczego model uczenia maszynowego podejmuje decyzje, które wykonuje. Jeśli na przykład diagnozy są wykonywane przez model uczenia maszynowego, pracownicy służby zdrowia potrzebują sposobu na przyjrzenie się czynnikom, które zostały wprowadzone w celu zdiagnozowania. Zapewnienie właściwej diagnozy może mieć duży wpływ na to, czy pacjent ma szybkie wyzdrowienie, czy nie. W związku z tym im wyższy poziom możliwości wyjaśnienia w modelu, tym większe zaufanie pracowników służby zdrowia musi zaakceptować lub odrzucić decyzje podjęte przez model.
Różne techniki służą do wyjaśnienia modeli, z których jeden to PFI. PFI to technika służąca do wyjaśnienia modeli klasyfikacji i regresji, które są inspirowane dokumentem Lasy losowe Breimana (patrz sekcja 10). Na wysokim poziomie sposób, w jaki działa, polega na losowym przetasowaniu danych jednej funkcji naraz dla całego zestawu danych i obliczeniu, ile metryki wydajności odsetek spada. Większa zmiana, tym ważniejsze jest to, że funkcja jest.
Ponadto, podkreślając najważniejsze funkcje, konstruktorzy modeli mogą skupić się na używaniu podzestawu bardziej znaczących funkcji, które mogą potencjalnie zmniejszyć czas uczenia i szumu.
Ładowanie danych
Funkcje zestawu danych używanego dla tego przykładu znajdują się w kolumnach 1–12. Celem jest przewidywanie Price
wartości .
Kolumna | Funkcja | Opis |
---|---|---|
1 | CrimeRate | Wskaźnik przestępczości na mieszkańca |
2 | Strefy mieszkalne | Strefy mieszkalne w mieście |
3 | Komercyjna strefa | Strefy niemiejętne w mieście |
100 | NearWater | Bliskość ciała wody |
5 | ToxicWasteLevels | Poziomy toksyczności (PPM) |
6 | AverageRoomNumber | Średnia liczba pomieszczeń w domu |
7 | HomeAge | Wiek domu |
8 | BusinessCenterDistance | Odległość do najbliższej dzielnicy biznesowej |
9 | HighwayAccess | Bliskość autostrad |
10 | TaxRate | Stawka podatku od nieruchomości |
11 | StudentTeacherRatio | Stosunek uczniów do nauczycieli |
12 | PercentPopulationBelowPoverty | Procent ludności żyjących poniżej ubóstwa |
13 | Cena | Cena domu |
Poniżej przedstawiono przykładowy zestaw danych:
1,24,13,1,0.59,3,96,11,23,608,14,13,32
4,80,18,1,0.37,5,14,7,4,346,19,13,41
2,98,16,1,0.25,10,5,1,8,689,13,36,12
Dane w tym przykładzie mogą być modelowane przez klasę, taką jak HousingPriceData
i załadowaną do klasy IDataView
.
class HousingPriceData
{
[LoadColumn(0)]
public float CrimeRate { get; set; }
[LoadColumn(1)]
public float ResidentialZones { get; set; }
[LoadColumn(2)]
public float CommercialZones { get; set; }
[LoadColumn(3)]
public float NearWater { get; set; }
[LoadColumn(4)]
public float ToxicWasteLevels { get; set; }
[LoadColumn(5)]
public float AverageRoomNumber { get; set; }
[LoadColumn(6)]
public float HomeAge { get; set; }
[LoadColumn(7)]
public float BusinessCenterDistance { get; set; }
[LoadColumn(8)]
public float HighwayAccess { get; set; }
[LoadColumn(9)]
public float TaxRate { get; set; }
[LoadColumn(10)]
public float StudentTeacherRatio { get; set; }
[LoadColumn(11)]
public float PercentPopulationBelowPoverty { get; set; }
[LoadColumn(12)]
[ColumnName("Label")]
public float Price { get; set; }
}
Szkolenie modelu
Poniższy przykład kodu ilustruje proces trenowania modelu regresji liniowej w celu przewidywania cen domów.
// 1. Get the column name of input features.
string[] featureColumnNames =
data.Schema
.Select(column => column.Name)
.Where(columnName => columnName != "Label").ToArray();
// 2. Define training pipeline
IEstimator<ITransformer> sdcaEstimator =
mlContext.Transforms.Concatenate("Features", featureColumnNames)
.Append(mlContext.Transforms.NormalizeMinMax("Features"))
.Append(mlContext.Regression.Trainers.Sdca());
// 3. Train machine learning model
var sdcaModel = sdcaEstimator.Fit(data);
Wyjaśnienie modelu przy użyciu ważności funkcji permutacji (PFI)
W ML.NET użyj metody dla odpowiedniego PermutationFeatureImportance
zadania.
// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);
// Calculate feature importance
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
mlContext
.Regression
.PermutationFeatureImportance(sdcaModel, transformedData, permutationCount:3);
Wynikiem użycia PermutationFeatureImportance
zestawu danych trenowania jest ImmutableArray
RegressionMetricsStatistics
obiekt. RegressionMetricsStatistics
Zawiera podsumowanie statystyk, takich jak średnia i odchylenie standardowe dla wielu obserwacji RegressionMetrics
równych liczbie permutacji określonych przez permutationCount
parametr .
Metryka używana do mierzenia ważności funkcji zależy od zadania uczenia maszynowego używanego do rozwiązania problemu. Na przykład zadania regresji mogą używać typowej metryki oceny, takiej jak R-squared, aby mierzyć znaczenie. Aby uzyskać więcej informacji na temat metryk oceny modelu, zobacz ocena modelu ML.NET za pomocą metryk.
Znaczenie, lub w tym przypadku, bezwzględny spadek metryki R kwadrat obliczony przez PermutationFeatureImportance
może następnie być uporządkowany od najważniejszego do najmniej ważnego.
// Order features by importance
var featureImportanceMetrics =
permutationFeatureImportance
.Select((metric, index) => new { index, metric.RSquared })
.OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));
Console.WriteLine("Feature\tPFI");
foreach (var feature in featureImportanceMetrics)
{
Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}
Drukowanie wartości dla każdej z funkcji w pliku featureImportanceMetrics
spowoduje wygenerowanie danych wyjściowych podobnych do poniższych. Należy pamiętać, że powinny być widoczne różne wyniki, ponieważ te wartości różnią się w zależności od podanych danych.
Funkcja | Zmień na R-Squared |
---|---|
HighwayAccess | -0.042731 |
StudentTeacherRatio | -0.012730 |
BusinessCenterDistance | -0.010491 |
TaxRate | -0.008545 |
AverageRoomNumber | -0.003949 |
CrimeRate | -0.003665 |
Komercyjna strefa | 0.002749 |
HomeAge | -0.002426 |
Strefy mieszkalne | -0.002319 |
NearWater | 0.000203 |
PercentPopulationLivingBelowPoverty | 0.000031 |
ToxicWasteLevels | -0.000019 |
Patrząc na pięć najważniejszych funkcji tego zestawu danych, cena domu przewidywanego przez ten model jest pod wpływem jego sąsiedztwa do autostrad, współczynnik nauczycieli nauczycieli szkół w okolicy, bliskość głównych ośrodków zatrudnienia, stawka podatku od nieruchomości i średnia liczba pomieszczeń w domu.