순열 기능 중요도를 사용하여 모델 예측 해석
PFI(순열 기능 중요도)를 사용하여 ML.NET 기계 학습 모델 예측을 해석하는 방법을 알아봅니다. PFI를 통해 각 기능이 예측에 공헌하는 상대적 기여도를 알 수 있습니다.
기계 학습 모델은 종종 입력을 받아 출력을 생성하는 불투명한 상자처럼 여겨지곤 합니다. 출력에 영향을 미치는 중간 단계 또는 기능 간 상호 작용은 거의 해석되지 않습니다. 의료 등, 다양한 일상에서 기계 학습의 도입이 증가하면서 기계 학습 모델이 그러한 의사 결정을 내리게 되는 이유를 해석하는 것이 매우 중요해졌습니다. 예를 들어, 기계 학습 모델에서 진단을 수행할 경우 의료 전문가가 해당 진단에 반영된 요소를 살펴볼 방법이 필요합니다. 올바른 진단을 제공하면 환자의 빠른 회복 여부에 큰 차이를 낼 수 있습니다. 따라서 모델의 설명 가능한 수준이 높을수록 의료 전문가는 더 자신 있게 모델의 의사 결정을 수락 또는 거부할 수 있습니다.
다양한 기법을 사용하여 모델을 설명하며, 그 방법 중 하나가 PFI입니다. PFI는 Breiman의 랜덤 포리스트 논문(섹션 10 참조)에서 착안한 분류 및 회귀 모델을 설명하는 데 사용되는 기법입니다. 높은 수준에서는, 전체 데이터 세트에 대해 한 번에 한 기능씩 임의로 데이터를 섞고 해당 성능 메트릭이 얼마나 감소하는지를 산출하는 방식으로 작동합니다. 변화가 클수록 해당 기능이 중요한 것입니다.
또한 가장 중요한 기능을 강조 표시하므로 모델을 빌드하는 사람이 노이즈 및 학습 시간을 줄일 수 있는 더 의미 있는 기능의 하위 집합에 주력할 수 있습니다.
데이터 로드
이 샘플에 사용되는 데이터 세트의 기능은 1-12열에 있습니다. 목표는 Price
예측입니다.
Column | 기능 | 설명 |
---|---|---|
1 | CrimeRate | 인당 범죄율 |
2 | ResidentialZones | 도시 내 주거지 |
3 | CommercialZones | 도시 내 비주거지 |
4 | NearWater | 수원 근접성 |
5 | ToxicWasteLevels | 독성 물질 수준(PPM) |
6 | AverageRoomNumber | 가구 내 평균 방 수 |
7 | HomeAge | 가구 연령 |
8 | BusinessCenterDistance | 가장 가까운 비즈니스 지구까지 거리 |
9 | HighwayAccess | 고속도로 근접성 |
10 | TaxRate | 재산세율 |
11 | StudentTeacherRatio | 교사 학생 비율 |
12 | PercentPopulationBelowPoverty | 빈곤 인구 비율 |
13 | 가격 | 주택 가격 |
데이터 세트의 샘플은 다음과 같습니다.
1,24,13,1,0.59,3,96,11,23,608,14,13,32
4,80,18,1,0.37,5,14,7,4,346,19,13,41
2,98,16,1,0.25,10,5,1,8,689,13,36,12
이 샘플의 데이터는 HousingPriceData
같은 클래스로 모델링하고 IDataView
에 로드할 수 있습니다.
class HousingPriceData
{
[LoadColumn(0)]
public float CrimeRate { get; set; }
[LoadColumn(1)]
public float ResidentialZones { get; set; }
[LoadColumn(2)]
public float CommercialZones { get; set; }
[LoadColumn(3)]
public float NearWater { get; set; }
[LoadColumn(4)]
public float ToxicWasteLevels { get; set; }
[LoadColumn(5)]
public float AverageRoomNumber { get; set; }
[LoadColumn(6)]
public float HomeAge { get; set; }
[LoadColumn(7)]
public float BusinessCenterDistance { get; set; }
[LoadColumn(8)]
public float HighwayAccess { get; set; }
[LoadColumn(9)]
public float TaxRate { get; set; }
[LoadColumn(10)]
public float StudentTeacherRatio { get; set; }
[LoadColumn(11)]
public float PercentPopulationBelowPoverty { get; set; }
[LoadColumn(12)]
[ColumnName("Label")]
public float Price { get; set; }
}
모델 학습
다음 코드 샘플은 선형 회귀 모델을 학습하여 주택 가격을 예측하는 프로세스를 보여 줍니다.
// 1. Get the column name of input features.
string[] featureColumnNames =
data.Schema
.Select(column => column.Name)
.Where(columnName => columnName != "Label").ToArray();
// 2. Define training pipeline
IEstimator<ITransformer> sdcaEstimator =
mlContext.Transforms.Concatenate("Features", featureColumnNames)
.Append(mlContext.Transforms.NormalizeMinMax("Features"))
.Append(mlContext.Regression.Trainers.Sdca());
// 3. Train machine learning model
var sdcaModel = sdcaEstimator.Fit(data);
PFI(순열 기능 중요도)를 사용하여 모델 설명
ML.NET에서는 해당 작업에 PermutationFeatureImportance
메서드를 사용합니다.
// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);
// Calculate feature importance
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
mlContext
.Regression
.PermutationFeatureImportance(sdcaModel, transformedData, permutationCount:3);
학습 데이터 세트에 PermutationFeatureImportance
를 사용한 결과는 개체의 RegressionMetricsStatistics
ImmutableArray
입니다. RegressionMetricsStatistics
는 매개 변수에서 지정한 순열 수에 해당하는 여러 permutationCount
RegressionMetrics
관찰에 대해 평균, 표준 편차 같은 요약 통계를 제공합니다.
기능 중요도를 측정하는 데 사용되는 메트릭은 문제를 해결하는 데 사용되는 기계 학습 작업에 따라 다릅니다. 예를 들어 회귀 작업은 R 제곱과 같은 일반적인 평가 메트릭을 사용하여 중요도를 측정할 수 있습니다. 모델 평가 메트릭에 대한 자세한 내용은 메트릭을 사용하여 ML.NET 모델 평가를 참조하세요.
중요도, 이 경우 PermutationFeatureImportance
에서 계산한 R 제곱 메트릭의 절대 평균 감소를 가장 중요함에서 가장 중요하지 않음의 순서로 지정할 수 있습니다.
// Order features by importance
var featureImportanceMetrics =
permutationFeatureImportance
.Select((metric, index) => new { index, metric.RSquared })
.OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));
Console.WriteLine("Feature\tPFI");
foreach (var feature in featureImportanceMetrics)
{
Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}
featureImportanceMetrics
에서 각 기능에 대한 값을 출력하면 다음과 유사한 출력이 생성됩니다. 이 값은 제공된 데이터에 따라 달라지므로 결과가 다르게 보일 수 있습니다.
기능 | R 제곱으로 변경 |
---|---|
HighwayAccess | -0.042731 |
StudentTeacherRatio | -0.012730 |
BusinessCenterDistance | -0.010491 |
TaxRate | -0.008545 |
AverageRoomNumber | -0.003949 |
CrimeRate | -0.003665 |
CommercialZones | 0.002749 |
HomeAge | -0.002426 |
ResidentialZones | -0.002319 |
NearWater | 0.000203 |
PercentPopulationLivingBelowPoverty | 0.000031 |
ToxicWasteLevels | -0.000019 |
이 데이터 세트에 가장 중요한 기능 5개를 살펴보면 이 모델이 예측한 주택 가격은 고속도로 근접성, 지역 내 학교의 학생-교사 비율, 주요 업무 지구 근접성, 재산세율, 주택의 평균 방 수의 영향을 받습니다.
다음 단계
.NET