Udostępnij za pośrednictwem


Interpretowanie przewidywań modelu przy użyciu ważności funkcji permutacji

Korzystając z znaczenia funkcji permutacji (PFI), dowiedz się, jak interpretować przewidywania modelu uczenia maszynowego ML.NET. PFI przedstawia względny wkład każdej cechy w przewidywanie.

Modele uczenia maszynowego są często uważane za nieprzezroczyste pola, które pobierają dane wejściowe i generują dane wyjściowe. Pośrednie kroki lub interakcje między funkcjami, które mają wpływ na dane wyjściowe, są rzadko zrozumiałe. Wraz z wprowadzeniem uczenia maszynowego do większej liczby aspektów codziennego życia, takich jak opieka zdrowotna, niezwykle ważne jest zrozumienie, dlaczego model uczenia maszynowego podejmuje decyzje, które wykonuje. Jeśli na przykład diagnozy są wykonywane przez model uczenia maszynowego, pracownicy służby zdrowia potrzebują sposobu, aby przyjrzeć się czynnikom, które przeszły do tej diagnozy. Zapewnienie właściwej diagnozy może mieć duży wpływ na to, czy pacjent ma szybkie wyzdrowienie, czy nie. W związku z tym, im wyższy jest poziom wyjaśnialności modelu, tym większe zaufanie mają pracownicy służby zdrowia do akceptowania lub odrzucania decyzji podejmowanych przez model.

Różne techniki służą do wyjaśnienia modeli, z których jeden to PFI. PfI to technika służąca do wyjaśnienia modeli klasyfikacji i regresji inspirowanych lasami losowymi Breimana papieru (patrz sekcja 10). Ogólnie rzecz biorąc, jego działanie polega na losowym przetasowaniu danych jednej cechy naraz dla całego zestawu danych i obliczeniu, o ile zmniejsza się interesująca nas miara wydajności. Im większa zmiana, tym ważniejsza jest ta cecha.

Ponadto, podkreślając najważniejsze cechy, konstruktorzy modeli mogą skupić się na używaniu podzestawu bardziej znaczących cech, co może zmniejszyć ilość szumów i czas nauki.

Ładowanie danych

Funkcje w zestawie danych używanym dla tego przykładu znajdują się w kolumnach 1–12. Celem jest przewidywanie Price.

Kolumna Cecha Opis
1 CrimeRate Wskaźnik przestępczości na mieszkańca
2 Strefy mieszkalne Strefy mieszkalne w mieście
3 Komercyjne strefy Strefy nieresidentialne w mieście
4 NearWater Bliskość ciała wody
5 PoziomyToksycznychOdpadów Poziomy toksyczności (PPM)
6 ŚredniaLiczbaPokoi Średnia liczba pomieszczeń w domu
7 HomeAge Wiek domu
8 BusinessCenterDistance Odległość do najbliższej dzielnicy biznesowej
9 Dostęp do autostrady Bliskość autostrad
10 Stawka podatku Stawka podatku od nieruchomości
11 Stosunek studentów do nauczycieli Stosunek uczniów do nauczycieli
12 ProcentPopulacjiPoniżejGraniczyUbóstwa Procent ludności żyjących poniżej ubóstwa
13 Cena Cena domu

Poniżej przedstawiono przykładowy zestaw danych:

1,24,13,1,0.59,3,96,11,23,608,14,13,32
4,80,18,1,0.37,5,14,7,4,346,19,13,41
2,98,16,1,0.25,10,5,1,8,689,13,36,12

Dane w tym przykładzie mogą być modelowane przez klasę, taką jak HousingPriceData i ładowane do IDataView.

class HousingPriceData
{
    [LoadColumn(0)]
    public float CrimeRate { get; set; }

    [LoadColumn(1)]
    public float ResidentialZones { get; set; }

    [LoadColumn(2)]
    public float CommercialZones { get; set; }

    [LoadColumn(3)]
    public float NearWater { get; set; }

    [LoadColumn(4)]
    public float ToxicWasteLevels { get; set; }

    [LoadColumn(5)]
    public float AverageRoomNumber { get; set; }

    [LoadColumn(6)]
    public float HomeAge { get; set; }

    [LoadColumn(7)]
    public float BusinessCenterDistance { get; set; }

    [LoadColumn(8)]
    public float HighwayAccess { get; set; }

    [LoadColumn(9)]
    public float TaxRate { get; set; }

    [LoadColumn(10)]
    public float StudentTeacherRatio { get; set; }

    [LoadColumn(11)]
    public float PercentPopulationBelowPoverty { get; set; }

    [LoadColumn(12)]
    [ColumnName("Label")]
    public float Price { get; set; }
}

Trenowanie modelu

Poniższy przykładowy kod ilustruje proces trenowania modelu regresji liniowej w celu przewidywania cen domów.

// 1. Get the column name of input features.
string[] featureColumnNames =
    data.Schema
        .Select(column => column.Name)
        .Where(columnName => columnName != "Label").ToArray();

// 2. Define training pipeline.
IEstimator<ITransformer> sdcaEstimator =
    mlContext.Transforms.Concatenate("Features", featureColumnNames)
        .Append(mlContext.Transforms.NormalizeMinMax("Features"))
        .Append(mlContext.Regression.Trainers.Sdca());

// 3. Train machine learning model.
var sdcaModel = sdcaEstimator.Fit(data);

Wyjaśnij model przy użyciu ważności cechy permutacji (PFI)

W ML.NET użyj metody PermutationFeatureImportance dla odpowiedniego zadania.

// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);

// Calculate feature importance
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
    mlContext
        .Regression
        .PermutationFeatureImportance(sdcaModel, transformedData, permutationCount:3);

Wynikiem używania PermutationFeatureImportance w zestawie danych szkoleniowych jest ImmutableArray obiektów RegressionMetricsStatistics. RegressionMetricsStatistics zawiera podsumowanie statystyk, takich jak średnia i odchylenie standardowe dla wielu obserwacji RegressionMetrics równych liczbie permutacji określonych przez parametr permutationCount.

Metryka używana do mierzenia ważności funkcji zależy od zadania uczenia maszynowego używanego do rozwiązania problemu. Na przykład zadania regresji mogą używać typowej metryki oceny, takiej jak R-squared, aby mierzyć znaczenie. Aby uzyskać więcej informacji na temat metryk oceny modelu, zobacz ocena modelu ML.NET za pomocą metryk.

Znaczenie, lub w tym przypadku, bezwzględny średni spadek metryki R kwadrat, obliczony przez PermutationFeatureImportance, może zostać uporządkowany od najważniejszych do najmniej ważnych.

// Order features by importance.
var featureImportanceMetrics =
    permutationFeatureImportance
        .Select((metric, index) => new { index, metric.RSquared })
        .OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));

Console.WriteLine("Feature\tPFI");

foreach (var feature in featureImportanceMetrics)
{
    Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}

Drukowanie wartości dla każdej z funkcji w featureImportanceMetrics generuje dane wyjściowe podobne do poniższych danych wyjściowych. Powinny pojawić się różne wyniki, ponieważ te wartości różnią się w zależności od podanych danych.

Cecha Zmień na R-Squared
Dostęp do autostrady -0.042731
Wskaźnik Uczniowie-Nauczyciele -0.012730
BusinessCenterDistance -0.010491
Stawka Podatkowa -0.008545
ŚredniaLiczbaPokoi -0.003949
CrimeRate -0.003665
Komercyjne strefy 0.002749
HomeAge -0.002426
Strefy mieszkalne -0.002319
NearWater 0.000203
Procent populacji żyjącej poniżej granicy ubóstwa 0.000031
PoziomyToksycznychOdpadów -0.000019

Jeśli spojrzysz na pięć najważniejszych funkcji tego zestawu danych, cena domu przewidywanego przez ten model ma wpływ na jego bliskość do autostrad, stosunek nauczycieli do szkół w okolicy, bliskość głównych ośrodków zatrudnienia, stawka podatku od nieruchomości i średnia liczba pomieszczeń w domu.

Następne kroki