Поделиться через


Преобразование модели PyTorch в формат ONNX

На предыдущем этапе работы с этим учебником мы использовали PyTorch для создания модели машинного обучения. Однако эта модель является файлом .pth. Чтобы иметь возможность интегрировать этот файл с приложением Windows ML, вам понадобится преобразовать модель в формат ONNX.

Экспорт модели.

Для экспорта модели нужно использовать функцию torch.onnx.export(). Эта функция выполняет модель и записывает трассировку того, какие операторы используются для расчета выходных данных.

  1. Скопируйте следующий код в файл DataClassifier.py в Visual Studio и вставьте его над функцией main.
#Function to Convert to ONNX 
def convert(): 

    # set the model to inference mode 
    model.eval() 

    # Let's create a dummy input tensor  
    dummy_input = torch.randn(1, 3, 32, 32, requires_grad=True)  

    # Export the model   
    torch.onnx.export(model,         # model being run 
         dummy_input,       # model input (or a tuple for multiple inputs) 
         "Network.onnx",       # where to save the model  
         export_params=True,  # store the trained parameter weights inside the model file 
         opset_version=11,    # the ONNX version to export the model to 
         do_constant_folding=True,  # whether to execute constant folding for optimization 
         input_names = ['input'],   # the model's input names 
         output_names = ['output'], # the model's output names 
         dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes 
                                'output' : {0 : 'batch_size'}}) 
    print(" ") 
    print('Model has been converted to ONNX') 

Прежде чем экспортировать модель, нужно вызвать model.eval() или model.train(False), поскольку эти переключатели позволяют задать для модели режим вывода. Такое действие необходимо, поскольку операторы dropout или batchnorm работают по-разному в режиме вывода и обучения.

  1. Чтобы выполнить преобразование в ONNX, добавьте вызов функции преобразования в функцию main. Заново обучать модель не нужно, поэтому мы закомментируем некоторые функции, которые нам больше не понадобится выполнять. Функция main будет выглядеть следующим образом.
if __name__ == "__main__": 
    num_epochs = 10 
    train(num_epochs) 
    print('Finished Training\n') 
    test() 
    test_species() 
    convert() 
  1. Запустите проект еще раз, нажав кнопку Start Debugging на панели инструментов или клавишу F5. Обучать модель снова не понадобится. Все, что нужно сделать — просто загрузить существующую модель из папки проекта.

Перейдите к расположению проекта и найдите модель ONNX рядом с моделью .pth.

Примечание.

Хотите узнать больше? Ознакомьтесь с учебником PyTorch по экспорту модели.

Обзор модели.

  1. Откройте файл модели Network.onnx с помощью Neutron.

  2. Выберите узел data, чтобы открыть свойства модели.

ONNX model properties

Как мы видим, в качестве входных данных для модели нужно использовать 32-разрядный свободно перемещаемый объект (многомерный массив) тензор, а в качестве выходных данных возвращается число с плавающей точкой тензора. Массив выходных данных будет содержать вероятность для каждой метки. При построении модели метки обозначаются 3 числами, каждое из которых связано с конкретным типом цветка Ириса.

Метка 1 label 2 Метка 3
0 1 2
Ирис щетинистый Ирис разноцветный Ирис виргинский

Вам нужно будет извлечь эти значения, чтобы отобразить правильный прогноз в приложении Windows ML.

Дальнейшие действия

Наша модель готова к использованию. Затем, что касается главного события, давайте создадим приложение Windows и запустим его локально на устройстве Windows.