Convertire il modello di training PyTorch in ONNX
Nota
Per una maggiore funzionalità, PyTorch può essere usato anche con DirectML in Windows.
Nella fase precedente di questa esercitazione è stato usato PyTorch per creare il modello di Machine Learning. Tuttavia, tale modello è un .pth
file. Per essere in grado di integrarlo con l'app Windows ML, è necessario convertire il modello in formato ONNX.
Esportare il modello
Per esportare un modello, si userà la torch.onnx.export()
funzione . Questa funzione esegue il modello e registra una traccia degli operatori usati per calcolare gli output.
- Copiare il codice seguente nel
PyTorchTraining.py
file in Visual Studio, sopra la funzione principale.
import torch.onnx
#Function to Convert to ONNX
def Convert_ONNX():
# set the model to inference mode
model.eval()
# Let's create a dummy input tensor
dummy_input = torch.randn(1, input_size, requires_grad=True)
# Export the model
torch.onnx.export(model, # model being run
dummy_input, # model input (or a tuple for multiple inputs)
"ImageClassifier.onnx", # where to save the model
export_params=True, # store the trained parameter weights inside the model file
opset_version=10, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names = ['modelInput'], # the model's input names
output_names = ['modelOutput'], # the model's output names
dynamic_axes={'modelInput' : {0 : 'batch_size'}, # variable length axes
'modelOutput' : {0 : 'batch_size'}})
print(" ")
print('Model has been converted to ONNX')
È importante chiamare model.eval()
o model.train(False)
prima di esportare il modello, perché imposta il modello sulla modalità di inferenza. Questa operazione è necessaria perché gli operatori come dropout
o batchnorm
si comportano in modo diverso in modalità di inferenza e training.
- Per eseguire la conversione in ONNX, aggiungere una chiamata alla funzione di conversione alla funzione main. Non è necessario eseguire di nuovo il training del modello, quindi verranno commentato alcune funzioni che non è più necessario eseguire. La funzione principale sarà la seguente.
if __name__ == "__main__":
# Let's build our model
#train(5)
#print('Finished Training')
# Test which classes performed well
#testAccuracy()
# Let's load the model we just created and test the accuracy per label
model = Network()
path = "myFirstModel.pth"
model.load_state_dict(torch.load(path))
# Test with batch of images
#testBatch()
# Test how the classes performed
#testClassess()
# Conversion to ONNX
Convert_ONNX()
- Eseguire di nuovo il progetto selezionando il
Start Debugging
pulsante sulla barra degli strumenti o premendoF5
. Non è necessario eseguire di nuovo il training del modello, è sufficiente caricare il modello esistente dalla cartella del progetto.
L'output sarà il seguente.
Passare al percorso del progetto e trovare il modello ONNX accanto al .pth
modello.
Nota
Sei interessato a scoprire di più? Esaminare l'esercitazione su PyTorch sull'esportazione di un modello.
Esplorare il modello.
Aprire il file del
ImageClassifier.onnx
modello con Netron.Selezionare il nodo dati per aprire le proprietà del modello.
Come si può notare, il modello richiede un oggetto float a 32 bit (matrice multidimensionale) come input e restituisce un valore float Tensor come output. La matrice di output includerà la probabilità per ogni etichetta. La modalità di compilazione del modello, le etichette sono rappresentate da 10 numeri e ogni numero rappresenta le dieci classi di oggetti.
Etichetta 0 | Etichetta 1 | Etichetta 2 | Etichetta 3 | Etichetta 4 | Etichetta 5 | Etichetta 6 | Etichetta 7 | Etichetta 8 | Etichetta 9 |
---|---|---|---|---|---|---|---|---|---|
0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
piano | car | bird | cat | Cervo | dog | rana | cavallo | nave | truck |
Dovrai estrarre questi valori per visualizzare la stima corretta con l'app Windows ML.
Passaggi successivi
Il modello è pronto per la distribuzione. Successivamente, per l'evento principale, creare un'applicazione Windows ed eseguirla in locale nel dispositivo Windows.