Classificação multiclasse

Concluído

A classificação multiclasse é usada para prever a qual das várias classes possíveis uma observação pertence. Como uma técnica de aprendizado de máquina supervisionada, ela segue o mesmo processo de treinamento iterativo , validação e avaliação como regressão e classificação binária na qual um subconjunto dos dados de treinamento é retido para validar o modelo treinado.

Exemplo - classificação multiclasse

Algoritmos de classificação multiclasse são usados para calcular valores de probabilidade para rótulos de várias classes, permitindo que um modelo preveja a classe mais provável para uma determinada observação.

Vamos explorar um exemplo em que temos algumas observações de pinguins, em que o comprimento da barbatana (x) de cada pinguim é registrado. Para cada observação, os dados incluem a espécie de pinguim (y), que é codificada da seguinte forma:

  • 0: Adélia
  • 1: Gentoo
  • 2: Chinstrap

Nota

Como nos exemplos anteriores neste módulo, um cenário real incluiria vários valores de recurso (x). Usaremos um único recurso para manter as coisas simples.

Diagram of a measuring ruler. Diagram of three penguins.
Comprimento da barbatana (x) Espécie (y)
167 0
172 0
225 2
197 1
189 1
232 2
158 0

Treinamento de um modelo de classificação multiclasse

Para treinar um modelo de classificação multiclasse, precisamos usar um algoritmo para ajustar os dados de treinamento a uma função que calcula um valor de probabilidade para cada classe possível. Há dois tipos de algoritmo que você pode usar para fazer isso:

  • Algoritmos One-vs-Rest (OvR)
  • Algoritmos multinomiais

Algoritmos One-vs-Rest (OvR)

Os algoritmos One-vs-Rest treinam uma função de classificação binária para cada classe, cada um calculando a probabilidade de que a observação seja um exemplo da classe alvo. Cada função calcula a probabilidade da observação ser uma classe específica em comparação com qualquer outra classe. Para o nosso modelo de classificação de espécies de pinguins, o algoritmo criaria essencialmente três funções binárias de classificação:

  • f0(x) = P(y=0 | x)
  • f1(x) = P(y=1 | x)
  • f2(x) = P(y=2 | x)

Cada algoritmo produz uma função sigmoide que calcula um valor de probabilidade entre 0,0 e 1,0. Um modelo treinado usando este tipo de algoritmo prevê a classe para a função que produz a saída de maior probabilidade.

Algoritmos multinomiais

Como uma abordagem alternativa é usar um algoritmo multinomial, que cria uma única função que retorna uma saída de vários valores. A saída é um vetor (uma matriz de valores) que contém a distribuição de probabilidade para todas as classes possíveis - com uma pontuação de probabilidade para cada classe que, quando totalizada, soma 1,0:

f(x) =[P(y=0|x), P(y=1|x), P(y=2|x)]

Um exemplo deste tipo de função é uma função softmax , que poderia produzir uma saída como o exemplo a seguir:

[0.2, 0.3, 0.5]

Os elementos no vetor representam as probabilidades para as classes 0, 1 e 2, respectivamente; Assim, neste caso, a classe com maior probabilidade é 2.

Independentemente do tipo de algoritmo usado, o modelo usa a função resultante para determinar a classe mais provável para um determinado conjunto de recursos (x) e prevê o rótulo de classe correspondente (y).

Avaliação de um modelo de classificação multiclasse

Você pode avaliar um classificador multiclasse calculando métricas de classificação binária para cada classe individual. Como alternativa, você pode calcular métricas agregadas que levam todas as classes em consideração.

Vamos supor que validamos nosso classificador multiclasse e obtivemos os seguintes resultados:

Comprimento da barbatana (x) Espécies reais (y) Espécies previstas (ŷ)
165 0 0
171 0 0
205 2 1
195 1 1
183 1 1
221 2 2
214 2 2

A matriz de confusão para um classificador multiclasse é semelhante à de um classificador binário, exceto que mostra o número de previsões para cada combinação de rótulos de classe previstos (ŷ) e reais (y):

Diagram of a multiclass confusion matrix.

A partir desta matriz de confusão, podemos determinar as métricas para cada classe individual da seguinte forma:

Classe TP TN FP FN Precisão Recuperar Precisão Pontuação F1
0 2 5 0 0 1.0 1.0 1.0 1.0
1 2 4 1 0 0.86 1.0 0.67 0.8
2 2 4 0 1 0.86 0.67 1.0 0.8

Para calcular as métricas gerais de precisão, recuperação e precisão, use o total das métricas TP, TN, FP e FN:

  • Precisão total = (13+6)÷(13+6+1+1) = 0,90
  • Recordação global = 6÷(6+1) = 0,86
  • Precisão total = 6÷(6+1) = 0,86

A pontuação geral de F1 é calculada usando as métricas gerais de recall e precisão:

  • Pontuação F1 geral = (2x0,86x0,86)÷(0,86+0,86) = 0,86