Condividi tramite


Convertire il modello di training PyTorch in ONNX

Nota

Per una maggiore funzionalità, PyTorch può essere usato anche con DirectML in Windows.

Nella fase precedente di questa esercitazione è stato usato PyTorch per creare il modello di Machine Learning. Tuttavia, tale modello è un .pth file. Per essere in grado di integrarlo con l'app Windows ML, è necessario convertire il modello in formato ONNX.

Esportare il modello

Per esportare un modello, si userà la torch.onnx.export() funzione . Questa funzione esegue il modello e registra una traccia degli operatori usati per calcolare gli output.

  1. Copiare il codice seguente nel PyTorchTraining.py file in Visual Studio, sopra la funzione principale.
import torch.onnx 

#Function to Convert to ONNX 
def Convert_ONNX(): 

    # set the model to inference mode 
    model.eval() 

    # Let's create a dummy input tensor  
    dummy_input = torch.randn(1, input_size, requires_grad=True)  

    # Export the model   
    torch.onnx.export(model,         # model being run 
         dummy_input,       # model input (or a tuple for multiple inputs) 
         "ImageClassifier.onnx",       # where to save the model  
         export_params=True,  # store the trained parameter weights inside the model file 
         opset_version=10,    # the ONNX version to export the model to 
         do_constant_folding=True,  # whether to execute constant folding for optimization 
         input_names = ['modelInput'],   # the model's input names 
         output_names = ['modelOutput'], # the model's output names 
         dynamic_axes={'modelInput' : {0 : 'batch_size'},    # variable length axes 
                                'modelOutput' : {0 : 'batch_size'}}) 
    print(" ") 
    print('Model has been converted to ONNX') 

È importante chiamare model.eval() o model.train(False) prima di esportare il modello, perché imposta il modello sulla modalità di inferenza. Questa operazione è necessaria perché gli operatori come dropout o batchnorm si comportano in modo diverso in modalità di inferenza e training.

  1. Per eseguire la conversione in ONNX, aggiungere una chiamata alla funzione di conversione alla funzione main. Non è necessario eseguire di nuovo il training del modello, quindi verranno commentato alcune funzioni che non è più necessario eseguire. La funzione principale sarà la seguente.
if __name__ == "__main__": 

    # Let's build our model 
    #train(5) 
    #print('Finished Training') 

    # Test which classes performed well 
    #testAccuracy() 

    # Let's load the model we just created and test the accuracy per label 
    model = Network() 
    path = "myFirstModel.pth" 
    model.load_state_dict(torch.load(path)) 

    # Test with batch of images 
    #testBatch() 
    # Test how the classes performed 
    #testClassess() 
 
    # Conversion to ONNX 
    Convert_ONNX() 
  1. Eseguire di nuovo il progetto selezionando il Start Debugging pulsante sulla barra degli strumenti o premendo F5. Non è necessario eseguire di nuovo il training del modello, è sufficiente caricare il modello esistente dalla cartella del progetto.

L'output sarà il seguente.

Processo di conversione ONNX

Passare al percorso del progetto e trovare il modello ONNX accanto al .pth modello.

Nota

Sei interessato a scoprire di più? Esaminare l'esercitazione su PyTorch sull'esportazione di un modello.

Esplorare il modello.

  1. Aprire il file del ImageClassifier.onnx modello con Netron.

  2. Selezionare il nodo dati per aprire le proprietà del modello.

Proprietà del modello ONNX

Come si può notare, il modello richiede un oggetto float a 32 bit (matrice multidimensionale) come input e restituisce un valore float Tensor come output. La matrice di output includerà la probabilità per ogni etichetta. La modalità di compilazione del modello, le etichette sono rappresentate da 10 numeri e ogni numero rappresenta le dieci classi di oggetti.

Etichetta 0 Etichetta 1 Etichetta 2 Etichetta 3 Etichetta 4 Etichetta 5 Etichetta 6 Etichetta 7 Etichetta 8 Etichetta 9
0 1 2 3 4 5 6 7 8 9
piano car bird cat Cervo dog rana cavallo nave truck

Dovrai estrarre questi valori per visualizzare la stima corretta con l'app Windows ML.

Passaggi successivi

Il modello è pronto per la distribuzione. Successivamente, per l'evento principale, creare un'applicazione Windows ed eseguirla in locale nel dispositivo Windows.