Partilhar via


TensorFlowEstimator Classe

Definição

Ele TensorFlowTransformer é usado nos dois cenários a seguir.

  1. Pontuação com o modelo TensorFlow pré-treinado: nesse modo, a transformação extrai os valores das camadas ocultas de um modelo tensorflow pré-treinado e usa saídas como recursos em ML.Net pipeline.
  2. Retreinamento do modelo TensorFlow : nesse modo, a transformação treina novamente um modelo TensorFlow usando os dados do usuário passados por ML.Net pipeline. Depois que o modelo é treinado, suas saídas podem ser usadas como recursos para pontuação.
public sealed class TensorFlowEstimator : Microsoft.ML.IEstimator<Microsoft.ML.Transforms.TensorFlowTransformer>
type TensorFlowEstimator = class
    interface IEstimator<TensorFlowTransformer>
Public NotInheritable Class TensorFlowEstimator
Implements IEstimator(Of TensorFlowTransformer)
Herança
TensorFlowEstimator
Implementações

Comentários

O TensorFlowTransform extrai saídas especificadas usando um modelo tensorflow pré-treinado. Opcionalmente, ele pode treinar ainda mais o modelo TensorFlow em dados do usuário para ajustar parâmetros de modelo nos dados do usuário ( também conhecido como "Transferir Aprendizado").

Para pontuação, a transformação usa como entradas o modelo tensorflow pré-treinado, os nomes dos nós de entrada e os nomes dos nós de saída cujos valores queremos extrair. Para treinar novamente, a transformação também requer parâmetros relacionados ao treinamento, como os nomes da operação de otimização no grafo TensorFlow, o nome da operação de taxa de aprendizado no grafo e seu valor, o nome das operações no grafo para calcular a perda e a métrica de desempenho etc.

Essa transformação requer que o nuget Microsoft.ML.TensorFlow seja instalado. O TensorFlowTransform tem as seguintes suposições sobre entrada, saída, processamento de dados e retreinamento.

  1. Para o modelo de entrada, atualmente o TensorFlowTransform dá suporte ao formato de modelo Congelado e também ao formato SavedModel . No entanto, o retreinamento do modelo só é possível para o formato SavedModel . Atualmente, o formato de ponto de verificação não tem suporte para pontuação nem para treinar novamente devido à falta de suporte à API C do TensorFlow para carregá-lo.
  2. A transformação dá suporte à pontuação de apenas um exemplo por vez. No entanto, o retreinamento pode ser realizado em lotes.
  3. Cenários avançados de aprendizado de transferência/ajuste fino (por exemplo, adicionar mais camadas à rede, alterar a forma de entradas, congelar as camadas que não precisam ser atualizadas durante o processo de retreinamento etc.) atualmente não são possíveis devido à falta de suporte para manipulação de rede/grafo dentro do modelo usando a API C do TensorFlow.
  4. O nome das colunas de entrada deve corresponder ao nome das entradas no modelo TensorFlow.
  5. O nome de cada coluna de saída deve corresponder a uma das operações no grafo TensorFlow.
  6. Atualmente, double, float, long, int, short, sbyte, ulong, uint, ushort, byte e bool são os tipos de dados aceitáveis para entrada/saída.
  7. Após o sucesso, a transformação introduzirá uma nova coluna correspondente IDataView a cada coluna de saída especificada.

As entradas e saídas de um modelo tensorFlow podem ser obtidas usando as GetModelSchema() ferramentas ou summarize_graph .

Métodos

Fit(IDataView)

Treina e retorna um TensorFlowTransformer.

GetOutputSchema(SchemaShape)

Retorna o SchemaShape esquema que será produzido pelo transformador. Usado para propagação e verificação de esquema em um pipeline.

Métodos de Extensão

AppendCacheCheckpoint<TTrans>(IEstimator<TTrans>, IHostEnvironment)

Acrescente um 'ponto de verificação de cache' à cadeia do avaliador. Isso garantirá que os estimadores downstream sejam treinados em relação aos dados armazenados em cache. É útil ter um ponto de verificação de cache antes dos treinadores que levam vários passes de dados.

WithOnFitDelegate<TTransformer>(IEstimator<TTransformer>, Action<TTransformer>)

Dado um avaliador, retorne um objeto de encapsulamento que chamará um delegado uma vez Fit(IDataView) que seja chamado. Geralmente, é importante que um avaliador retorne informações sobre o que estava em forma, e é por isso que o Fit(IDataView) método retorna um objeto especificamente tipado, em vez de apenas um geral ITransformer. No entanto, ao mesmo tempo, IEstimator<TTransformer> muitas vezes são formados em pipelines com muitos objetos, portanto, talvez seja necessário criar uma cadeia de avaliadores por meio EstimatorChain<TLastTransformer> de onde o estimador para o qual queremos obter o transformador está enterrado em algum lugar nesta cadeia. Para esse cenário, podemos por meio desse método anexar um delegado que será chamado assim que o ajuste for chamado.

Aplica-se a