Преобразование модели PyTorch в формат ONNX
На предыдущем этапе работы с этим учебником мы использовали PyTorch для создания модели машинного обучения. Однако эта модель является файлом .pth
. Чтобы иметь возможность интегрировать этот файл с приложением Windows ML, вам понадобится преобразовать модель в формат ONNX.
Экспорт модели.
Для экспорта модели нужно использовать функцию torch.onnx.export()
. Эта функция выполняет модель и записывает трассировку того, какие операторы используются для расчета выходных данных.
- Скопируйте следующий код в файл
DataClassifier.py
в Visual Studio и вставьте его над функцией main.
#Function to Convert to ONNX
def convert():
# set the model to inference mode
model.eval()
# Let's create a dummy input tensor
dummy_input = torch.randn(1, 3, 32, 32, requires_grad=True)
# Export the model
torch.onnx.export(model, # model being run
dummy_input, # model input (or a tuple for multiple inputs)
"Network.onnx", # where to save the model
export_params=True, # store the trained parameter weights inside the model file
opset_version=11, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names = ['input'], # the model's input names
output_names = ['output'], # the model's output names
dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes
'output' : {0 : 'batch_size'}})
print(" ")
print('Model has been converted to ONNX')
Прежде чем экспортировать модель, нужно вызвать model.eval()
или model.train(False)
, поскольку эти переключатели позволяют задать для модели режим вывода. Такое действие необходимо, поскольку операторы dropout
или batchnorm
работают по-разному в режиме вывода и обучения.
- Чтобы выполнить преобразование в ONNX, добавьте вызов функции преобразования в функцию main. Заново обучать модель не нужно, поэтому мы закомментируем некоторые функции, которые нам больше не понадобится выполнять. Функция main будет выглядеть следующим образом.
if __name__ == "__main__":
num_epochs = 10
train(num_epochs)
print('Finished Training\n')
test()
test_species()
convert()
- Запустите проект еще раз, нажав кнопку
Start Debugging
на панели инструментов или клавишуF5
. Обучать модель снова не понадобится. Все, что нужно сделать — просто загрузить существующую модель из папки проекта.
Перейдите к расположению проекта и найдите модель ONNX рядом с моделью .pth
.
Примечание.
Хотите узнать больше? Ознакомьтесь с учебником PyTorch по экспорту модели.
Обзор модели.
Откройте файл модели
Network.onnx
с помощью Neutron.Выберите узел data, чтобы открыть свойства модели.
Как мы видим, в качестве входных данных для модели нужно использовать 32-разрядный свободно перемещаемый объект (многомерный массив) тензор, а в качестве выходных данных возвращается число с плавающей точкой тензора. Массив выходных данных будет содержать вероятность для каждой метки. При построении модели метки обозначаются 3 числами, каждое из которых связано с конкретным типом цветка Ириса.
Метка 1 | label 2 | Метка 3 |
---|---|---|
0 | 1 | 2 |
Ирис щетинистый | Ирис разноцветный | Ирис виргинский |
Вам нужно будет извлечь эти значения, чтобы отобразить правильный прогноз в приложении Windows ML.
Дальнейшие действия
Наша модель готова к использованию. Затем, что касается главного события, давайте создадим приложение Windows и запустим его локально на устройстве Windows.