Compartir a través de


TextClassificationTrainer Clase

Definición

IEstimator<TTransformer> para entrenar una red neuronal profunda (DNN) para clasificar texto.

public class TextClassificationTrainer : Microsoft.ML.TorchSharp.NasBert.NasBertTrainer<uint,long>
type TextClassificationTrainer = class
    inherit NasBertTrainer<uint32, int64>
Public Class TextClassificationTrainer
Inherits NasBertTrainer(Of UInteger, Long)
Herencia

Comentarios

Para crear este instructor, use TextClassification.

Columnas de entrada y salida

Los datos de columna de la etiqueta de entrada deben ser de tipo clave y las columnas de oración deben ser de tipoTextDataViewType .

Este instructor genera las siguientes columnas:

Nombre de columna de salida Tipo de columna Descripción
PredictedLabel Tipo de clave Índice de la etiqueta de predicción. Si su valor es i, la etiqueta real sería la categoría de i-th en el tipo de etiqueta de entrada con valores de clave.
Score Vector deSingle Puntuaciones de todas las clases. Un valor más alto indica mayor probabilidad de que caigan en la clase asociada. Si el elemento i-th tiene el valor más grande, el índice de la etiqueta de predicción sería i. Tenga en cuenta que i es el índice de base cero.

Características del entrenador

Tarea de Machine Learning Clasificación multiclase
¿Se requiere normalización? No
¿Se requiere el almacenamiento en caché? No
NuGet necesario además de Microsoft.ML Microsoft.ML.TorchSharp y libtorch-cpu o libtorch-cuda-11.3 o cualquiera de las variantes específicas del sistema operativo.
Exportable a ONNX No

Detalles del algoritmo de entrenamiento

Entrena una red neuronal profunda (DNN) aprovechando un modelo NAS-BERT roBERTa previamente entrenado para clasificar texto.

Métodos

Fit(IDataView)

IEstimator<TTransformer> para entrenar una red neuronal profunda (DNN) para clasificar texto.

(Heredado de TorchSharpBaseTrainer<TLabelCol,TTargetsCol>)
GetOutputSchema(SchemaShape)

IEstimator<TTransformer> para entrenar una red neuronal profunda (DNN) para clasificar texto.

(Heredado de NasBertTrainer<TLabelCol,TTargetsCol>)

Se aplica a