Freigeben über


Konvertieren Ihres PyTorch-Trainingmodells in ONNX

Hinweis

Für eine größere Funktionalität kann PyTorch auch mit DirectML unter Windows verwendet werden.

In der vorherigen Phase dieses Tutorials haben wir PyTorch verwendet, um unser Machine Learning-Modell zu erstellen. Dieses Modell ist jedoch eine .pth-Datei. Um es in die Windows ML-App integrieren zu können, müssen Sie das Modell in das ONNX-Format konvertieren.

Exportieren des Modells

Zum Exportieren eines Modells verwenden Sie die torch.onnx.export()-Funktion. Diese Funktion führt das Modell aus und zeichnet eine Ablaufverfolgung auf, welche Operatoren zum Berechnen der Ausgaben verwendet werden.

  1. Kopieren Sie den folgenden Code in die PyTorchTraining.py-Datei in Visual Studio oberhalb Ihrer Hauptfunktion (Main).
import torch.onnx 

#Function to Convert to ONNX 
def Convert_ONNX(): 

    # set the model to inference mode 
    model.eval() 

    # Let's create a dummy input tensor  
    dummy_input = torch.randn(1, input_size, requires_grad=True)  

    # Export the model   
    torch.onnx.export(model,         # model being run 
         dummy_input,       # model input (or a tuple for multiple inputs) 
         "ImageClassifier.onnx",       # where to save the model  
         export_params=True,  # store the trained parameter weights inside the model file 
         opset_version=10,    # the ONNX version to export the model to 
         do_constant_folding=True,  # whether to execute constant folding for optimization 
         input_names = ['modelInput'],   # the model's input names 
         output_names = ['modelOutput'], # the model's output names 
         dynamic_axes={'modelInput' : {0 : 'batch_size'},    # variable length axes 
                                'modelOutput' : {0 : 'batch_size'}}) 
    print(" ") 
    print('Model has been converted to ONNX') 

Es ist wichtig, vor dem Exportieren des Modells model.eval() oder model.train(False) aufzurufen, da dies das Modell in den Rückschlussmodus versetzt. Dies ist erforderlich, da sich Operatoren wie dropout oder batchnorm im Rückschluss- und Trainingsmodus unterschiedlich verhalten.

  1. Um die Konvertierung in ONNX auszuführen, fügen Sie der Hauptfunktion einen Aufruf der Konvertierungsfunktion hinzu. Sie müssen das Modell nicht erneut trainieren, daher kommentieren wir einige Funktionen aus, die nicht mehr ausgeführt werden müssen. Die Hauptfunktion sieht wie folgt aus.
if __name__ == "__main__": 

    # Let's build our model 
    #train(5) 
    #print('Finished Training') 

    # Test which classes performed well 
    #testAccuracy() 

    # Let's load the model we just created and test the accuracy per label 
    model = Network() 
    path = "myFirstModel.pth" 
    model.load_state_dict(torch.load(path)) 

    # Test with batch of images 
    #testBatch() 
    # Test how the classes performed 
    #testClassess() 
 
    # Conversion to ONNX 
    Convert_ONNX() 
  1. Führen Sie das Projekt erneut aus, indem Sie auf der Symbolleiste die Schaltfläche Start Debugging auswählen oder F5 drücken. Es ist nicht erforderlich, das Modell erneut zu trainieren. Laden Sie einfach das vorhandene Modell aus dem Projektordner.

Die Ausgabe sieht dann wie folgt aus.

ONNX-Konvertierungsprozess

Navigieren Sie zu Ihrem Projektspeicherort, und suchen Sie das ONNX-Modell neben dem .pth-Modell.

Hinweis

Möchten Sie mehr erfahren? Sehen Sie sich das PyTorch-Tutorial zum Exportieren eines Modells an.

Untersuchen Sie Ihr Modell.

  1. Öffnen Sie die Modelldatei ImageClassifier.onnx mit Neutron.

  2. Wählen Sie den Datenknoten aus, um die Modelleigenschaften zu öffnen.

ONNX-Modelleigenschaften

Wie Sie sehen können, erfordert das Modell ein 32-Bit-Tensor-Float-Objekt (mehrdimensionales Array) als Eingabe und gibt ein Tensor-Float-Objekt als Ausgabe zurück. Das Ausgabearray enthält die Wahrscheinlichkeit für jede Bezeichnung. So wie Sie das Modell erstellt haben, werden die Bezeichnungen durch zehn Zahlen repräsentiert, und jede Zahl steht für die zehn Klassen von Objekten.

Bezeichnung 0 Bezeichnung 1 Bezeichnung 2 Bezeichnung 3 Bezeichnung 4 Bezeichnung 5 Bezeichnung 6 Bezeichnung 7 Bezeichnung 8 Bezeichnung 9
0 1 2 3 4 5 6 7 8 9
ebene Auto bird cat Reh dog Frosch horse Schiff LKW

Sie müssen diese Werte extrahieren, um die richtige Vorhersage mit der Windows ML-App anzuzeigen.

Nächste Schritte

Unser Modell ist bereit für die Bereitstellung. Als nächstes erstellen wir für das Hauptereignis eine Windows-Anwendung und führen sie lokal auf Ihrem Windows-Gerät aus.