Partager via


Convertir votre modèle d’entraînement PyTorch au format ONNX

Remarque

Pour une plus grande fonctionnalité, PyTorch peut également être utilisé avec DirectML sur Windows.

Dans l’étape précédente de ce tutoriel, nous avons utilisé PyTorch pour créer notre modèle de Machine Learning. Toutefois, ce modèle est un fichier .pth. Pour être en mesure de l’intégrer à l’application Windows ML, vous devez convertir le modèle au format ONNX.

Exporter le modèle

Pour exporter un modèle, vous allez utiliser la fonction torch.onnx.export(). Cette fonction exécute le modèle, et enregistre une trace des opérateurs utilisés pour calculer les sorties.

  1. Copiez le code suivant dans le fichier PyTorchTraining.py dans Visual Studio, au-dessus de votre fonction principale.
import torch.onnx 

#Function to Convert to ONNX 
def Convert_ONNX(): 

    # set the model to inference mode 
    model.eval() 

    # Let's create a dummy input tensor  
    dummy_input = torch.randn(1, input_size, requires_grad=True)  

    # Export the model   
    torch.onnx.export(model,         # model being run 
         dummy_input,       # model input (or a tuple for multiple inputs) 
         "ImageClassifier.onnx",       # where to save the model  
         export_params=True,  # store the trained parameter weights inside the model file 
         opset_version=10,    # the ONNX version to export the model to 
         do_constant_folding=True,  # whether to execute constant folding for optimization 
         input_names = ['modelInput'],   # the model's input names 
         output_names = ['modelOutput'], # the model's output names 
         dynamic_axes={'modelInput' : {0 : 'batch_size'},    # variable length axes 
                                'modelOutput' : {0 : 'batch_size'}}) 
    print(" ") 
    print('Model has been converted to ONNX') 

Il est important d’appeler model.eval() ou model.train(False) avant d’exporter le modèle, car cela définit le modèle en mode d’inférence. Cela est nécessaire, car les opérateurs comme dropout ou batchnorm se comportent différemment en mode d’inférence et en mode d’entraînement.

  1. Pour exécuter la conversion au format ONNX, ajoutez un appel à la fonction de conversion à la fonction principale. Vous n’avez pas besoin d’entraîner à nouveau le modèle. Nous allons donc convertir en commentaires certaines fonctions que nous n’avons plus besoin d’exécuter. Votre fonction principale sera la suivante.
if __name__ == "__main__": 

    # Let's build our model 
    #train(5) 
    #print('Finished Training') 

    # Test which classes performed well 
    #testAccuracy() 

    # Let's load the model we just created and test the accuracy per label 
    model = Network() 
    path = "myFirstModel.pth" 
    model.load_state_dict(torch.load(path)) 

    # Test with batch of images 
    #testBatch() 
    # Test how the classes performed 
    #testClassess() 
 
    # Conversion to ONNX 
    Convert_ONNX() 
  1. Réexécutez le projet en sélectionnant le bouton Start Debugging dans la barre d’outils, ou en appuyant sur F5. Il n’est pas nécessaire d’effectuer une nouvelle tentative d’entraînement du modèle. Chargez simplement le modèle existant à partir du dossier du projet.

La sortie se présentera comme suit.

Processus de conversion ONNX

Accédez à l’emplacement de votre projet et recherchez le modèle ONNX en regard du modèle .pth.

Remarque

Vous voulez en savoir plus ? Consultez le tutoriel PyTorch sur l’exportation d’un modèle.

Explorez votre modèle.

  1. Ouvrez le fichier de modèle ImageClassifier.onnx avec Neutron.

  2. Sélectionnez le nœud data pour ouvrir les propriétés du modèle.

Propriétés du modèle ONNX

Comme vous pouvez le voir, le modèle requiert un objet float tenseur de 32 bits (tableau multidimensionnel) comme entrée et retourne un float Tensor en sortie. Le tableau de sortie inclut la probabilité pour chaque étiquette. La façon dont vous avez créé le modèle, les étiquettes sont représentées par 10 nombres, et chaque nombre représente les dix classes d’objets.

Étiquette 0 Étiquette 1 Étiquette 2 Étiquette 3 Étiquette 4 Étiquette 5 Étiquette 6 Étiquette 7 Étiquette 8 Étiquette 9
0 1 2 3 4 5 6 7 8 9
plane (avion) voiture bird cat deer (cerf) dog frog (grenouille) horse (cheval) bateau (ship) camion

Vous devrez extraire ces valeurs pour afficher la prédiction correcte avec l’application Windows ML.

Étapes suivantes

Notre modèle est prêt à être déployé. Ensuite, pour l’événement principal, nous allons créer une application Windows et l’exécuter localement sur votre appareil Windows.