Conversión del modelo de entrenamiento de PyTorch a ONNX
Nota:
Para obtener una mayor funcionalidad, PyTorch también se puede usar con DirectML en Windows.
En la fase anterior de este tutorial usó PyTorch para crear el modelo de aprendizaje automático. Sin embargo, ese modelo es un archivo .pth
. Para poder integrarlo con la aplicación Windows ML, deberá convertir el modelo al formato ONNX.
Exportación del modelo
Para exportar un modelo, debe usar la función torch.onnx.export()
. Esta función ejecuta el modelo y registra un seguimiento de los operadores que se usan para calcular las salidas.
- Copie el código siguiente en el archivo
PyTorchTraining.py
de Visual Studio, encima de la función principal.
import torch.onnx
#Function to Convert to ONNX
def Convert_ONNX():
# set the model to inference mode
model.eval()
# Let's create a dummy input tensor
dummy_input = torch.randn(1, input_size, requires_grad=True)
# Export the model
torch.onnx.export(model, # model being run
dummy_input, # model input (or a tuple for multiple inputs)
"ImageClassifier.onnx", # where to save the model
export_params=True, # store the trained parameter weights inside the model file
opset_version=10, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names = ['modelInput'], # the model's input names
output_names = ['modelOutput'], # the model's output names
dynamic_axes={'modelInput' : {0 : 'batch_size'}, # variable length axes
'modelOutput' : {0 : 'batch_size'}})
print(" ")
print('Model has been converted to ONNX')
Es importante llamar a model.eval()
o a model.train(False)
antes de exportar el modelo, ya que esto establece el modelo en el modo de inferencia. Esto es necesario porque los operadores como dropout
o batchnorm
se comportan de forma diferente en el modo de inferencia y entrenamiento.
- Para ejecutar la conversión a ONNX, agregue una llamada a la función de conversión a la función principal. No es necesario volver a entrenar el modelo, por lo que comentaremos algunas funciones que ya no necesitamos ejecutar. La función principal será la siguiente.
if __name__ == "__main__":
# Let's build our model
#train(5)
#print('Finished Training')
# Test which classes performed well
#testAccuracy()
# Let's load the model we just created and test the accuracy per label
model = Network()
path = "myFirstModel.pth"
model.load_state_dict(torch.load(path))
# Test with batch of images
#testBatch()
# Test how the classes performed
#testClassess()
# Conversion to ONNX
Convert_ONNX()
- Vuelva a ejecutar el proyecto seleccionando el botón
Start Debugging
de la barra de herramientas o presionandoF5
. No es necesario volver a entrenar el modelo; simplemente, cargue el modelo existente desde la carpeta del proyecto.
La salida debería ser similar a la siguiente.
Vaya a la ubicación del proyecto y busque el modelo ONNX junto al modelo .pth
.
Nota:
¿Quiere saber más sobre el tema? Revise el tutorial de PyTorch sobre la exportación de un modelo.
Explore el modelo.
Abra el archivo de modelo
ImageClassifier.onnx
con Netron.Seleccione el nodo de datos para abrir las propiedades del modelo.
Como puede ver, el modelo requiere un objeto de tensor flotante de 32 bits (esto es, una matriz multidimensional) como entrada y devuelve un valor "float" de Tensor como salida. La matriz de salida incluirá la probabilidad de cada etiqueta. Según la forma en que creó el modelo, las etiquetas se representarán mediante 10 números y cada número representará las diez clases de objetos.
Etiqueta 0 | Etiqueta 1 | Etiqueta 2 | Etiqueta 3 | Etiqueta 4 | Etiqueta 5 | Etiqueta 6 | Etiqueta 7 | Etiqueta 8 | Etiqueta 9 |
---|---|---|---|---|---|---|---|---|---|
0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
avión | automóvil | bird | cat | ciervo | perro | rana | caballo | barco | camión |
Deberá extraer estos valores para mostrar la predicción correcta con aplicación la Windows ML.
Pasos siguientes
El modelo está listo para implementarse. A continuación, para el evento principal, tendrá que compilar una aplicación de Windows y ejecutarla localmente en el dispositivo Windows.