Interpretar previsões do modelo usando Importância do Recurso de Permutação
Usando a PFI (Importância do Recurso de Permutação), saiba como interpretar ML.NET previsões de modelo de machine learning. A PFI informa a contribuição relativa que cada recurso faz a uma previsão.
Modelos de machine learning geralmente são considerados caixas opacas que pegam entradas e geram uma saída. As etapas intermediárias ou as interações entre os recursos que influenciam a saída raramente são compreendidas. Conforme o aprendizado de máquina é introduzido em mais aspectos da vida diária, como serviços de saúde, é de extrema importância entender por que um modelo de machine learning toma as decisões que ele toma. Por exemplo, se os diagnósticos forem feitos por um modelo de machine learning, os profissionais de saúde precisarão de uma maneira de examinar os fatores que contribuíram para esse diagnóstico. Fornecer o diagnóstico certo pode fazer uma grande diferença em se um paciente tem uma recuperação rápida ou não. Portanto, quanto maior o nível de capacidade de explicação de um modelo, mais confiança os profissionais de saúde terão em aceitar ou rejeitar as decisões tomadas pelo modelo.
Várias técnicas são usadas para explicar os modelos, uma delas é a PFI. PFI é uma técnica usada para explicar os modelos de classificação e regressão inspirados pelo artigo de Breiman chamado Random Forests (Florestas aleatórias) (confira a seção 10). Em um alto nível, a maneira como eles funcionam é embaralhando aleatoriamente um recurso de dados por vez para todo o conjunto de dados e calculando o quanto a métrica de desempenho de interesse diminui. Quanto maior a alteração, mais importante é esse recurso.
Além disso, ao realçar os recursos mais importantes, construtores de modelo podem se concentrar no uso de um subconjunto de recursos mais significativos que pode reduzir o ruído e tempo de treinamento.
Carregar os dados
Os recursos no conjunto de dados que está sendo usado para este exemplo estão nas colunas 1 a 12. A meta é prever Price
.
Coluna | Recurso | Descrição |
---|---|---|
1 | CrimeRate | Taxa de criminalidade per capita |
2 | ResidentialZones | Zonas residenciais da cidade |
3 | CommercialZones | Zonas não residenciais da cidade |
4 | NearWater | Proximidade a recursos hídricos |
5 | ToxicWasteLevels | Níveis de toxicidade (PPM) |
6 | AverageRoomNumber | Número médio de ambientes na casa |
7 | HomeAge | Idade da casa |
8 | BusinessCenterDistance | Distância até o bairro comercial mais próximo |
9 | HighwayAccess | Proximidade de rodovias |
10 | TaxRate | Taxa de imposto sobre propriedade |
11 | StudentTeacherRatio | Taxa de alunos para professores |
12 | PercentPopulationBelowPoverty | Percentual da população vivendo abaixo da linha de pobreza |
13 | Preço | Preço da casa |
Um exemplo do conjunto de dados é mostrado abaixo:
1,24,13,1,0.59,3,96,11,23,608,14,13,32
4,80,18,1,0.37,5,14,7,4,346,19,13,41
2,98,16,1,0.25,10,5,1,8,689,13,36,12
Os dados desta amostra podem ser modelados por uma classe como HousingPriceData
e carregados em uma IDataView
.
class HousingPriceData
{
[LoadColumn(0)]
public float CrimeRate { get; set; }
[LoadColumn(1)]
public float ResidentialZones { get; set; }
[LoadColumn(2)]
public float CommercialZones { get; set; }
[LoadColumn(3)]
public float NearWater { get; set; }
[LoadColumn(4)]
public float ToxicWasteLevels { get; set; }
[LoadColumn(5)]
public float AverageRoomNumber { get; set; }
[LoadColumn(6)]
public float HomeAge { get; set; }
[LoadColumn(7)]
public float BusinessCenterDistance { get; set; }
[LoadColumn(8)]
public float HighwayAccess { get; set; }
[LoadColumn(9)]
public float TaxRate { get; set; }
[LoadColumn(10)]
public float StudentTeacherRatio { get; set; }
[LoadColumn(11)]
public float PercentPopulationBelowPoverty { get; set; }
[LoadColumn(12)]
[ColumnName("Label")]
public float Price { get; set; }
}
Treinar o modelo
O exemplo de código a seguir ilustra o processo de treinamento de um modelo de regressão linear para prever preços de casa.
// 1. Get the column name of input features.
string[] featureColumnNames =
data.Schema
.Select(column => column.Name)
.Where(columnName => columnName != "Label").ToArray();
// 2. Define training pipeline
IEstimator<ITransformer> sdcaEstimator =
mlContext.Transforms.Concatenate("Features", featureColumnNames)
.Append(mlContext.Transforms.NormalizeMinMax("Features"))
.Append(mlContext.Regression.Trainers.Sdca());
// 3. Train machine learning model
var sdcaModel = sdcaEstimator.Fit(data);
Explicar o modelo com PFI (Importância de Recurso de Permutação)
No ML.NET, use o método PermutationFeatureImportance
para suas respectivas tarefas.
// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);
// Calculate feature importance
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
mlContext
.Regression
.PermutationFeatureImportance(sdcaModel, transformedData, permutationCount:3);
O resultado de usar PermutationFeatureImportance
no conjunto de dados de treinamento é um ImmutableArray
de objetos RegressionMetricsStatistics
. RegressionMetricsStatistics
fornece estatísticas resumidas, como média e desvio padrão para várias observações de RegressionMetrics
igual ao número de permutações especificado pelo parâmetro permutationCount
.
A métrica usada para medir a importância do recurso depende da tarefa de machine learning usada para resolver o problema. Por exemplo, as tarefas de regressão podem usar uma métrica de avaliação comum, como R ao quadrado, para medir a importância. Para mais informações sobre métricas de avaliação de modelo, confira Avaliar seu modelo de ML.NET com métricas.
A importância ou, neste caso, a redução média absoluta de métrica R ao quadrado calculada por PermutationFeatureImportance
pode ser ordenada da mais importante para a menos importante.
// Order features by importance
var featureImportanceMetrics =
permutationFeatureImportance
.Select((metric, index) => new { index, metric.RSquared })
.OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));
Console.WriteLine("Feature\tPFI");
foreach (var feature in featureImportanceMetrics)
{
Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}
Imprimir os valores para cada um dos recursos em featureImportanceMetrics
geraria saída semelhante à abaixo. Lembre-se de que você deve esperar ver resultados diferentes, pois esses valores variam conforme os dados apresentados.
Recurso | Alterar para R ao quadrado |
---|---|
HighwayAccess | -0,042731 |
StudentTeacherRatio | -0,012730 |
BusinessCenterDistance | -0,010491 |
TaxRate | -0,008545 |
AverageRoomNumber | -0,003949 |
CrimeRate | -0,003665 |
CommercialZones | 0,002749 |
HomeAge | -0,002426 |
ResidentialZones | -0,002319 |
NearWater | 0,000203 |
PercentPopulationLivingBelowPoverty | 0,000031 |
ToxicWasteLevels | -0,000019 |
Vamos analisar os cinco recursos mais importantes para este conjunto de dados, o preço de uma casa previsto por esse modelo é influenciado pela sua proximidade a rodovias, pela proporção de alunos para professor das escolas na área, pela proximidade com centros de emprego importantes, pela taxa de impostos sobre propriedade e pelo número médio de ambientes na casa.