PyTorch モデルを ONNX 形式に変換する
このチュートリアルの前の段階では、PyTorch を使用して機械学習モデルを作成しました。 ただし、そのモデルは、.pth
ファイルです。 Windows ML アプリと統合するには、モデルを ONNX 形式に変換する必要があります。
モデルのエクスポート
モデルをエクスポートするには、torch.onnx.export()
関数を使用します。 この関数によってモデルを実行し、出力を計算するために使用される演算子のトレースを記録します。
- Visual Studio で、次のコードを
DataClassifier.py
ファイルの main 関数の上にコピーします。
#Function to Convert to ONNX
def convert():
# set the model to inference mode
model.eval()
# Let's create a dummy input tensor
dummy_input = torch.randn(1, 3, 32, 32, requires_grad=True)
# Export the model
torch.onnx.export(model, # model being run
dummy_input, # model input (or a tuple for multiple inputs)
"Network.onnx", # where to save the model
export_params=True, # store the trained parameter weights inside the model file
opset_version=11, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names = ['input'], # the model's input names
output_names = ['output'], # the model's output names
dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes
'output' : {0 : 'batch_size'}})
print(" ")
print('Model has been converted to ONNX')
モデルをエクスポートする前に model.eval()
または model.train(False)
を呼び出すことが重要です。これにより、モデルが推論モードに設定されます。 dropout
や batchnorm
などの演算子は、推論モードとトレーニング モードでは動作が異なります。
- ONNX への変換を実行するには、変換関数の呼び出しを main 関数に追加します。 モデルを再度トレーニングする必要はないので、実行する必要がなくなったいくつかの関数をコメント アウトします。 main 関数は次のようになります。
if __name__ == "__main__":
num_epochs = 10
train(num_epochs)
print('Finished Training\n')
test()
test_species()
convert()
- ツール バーの [
Start Debugging
] ボタンを選択するかF5
キーを押して、プロジェクトを再度実行します。 モデルを再度トレーニングする必要はありません。プロジェクト フォルダーから既存のモデルを読み込むだけです。
プロジェクトの場所に移動し、.pth
モデルの横にある ONNX モデルを探します。
Note
もっと詳しく知りたいですか? モデルのエクスポートに関する PyTorch チュートリアルを参照してください。
モデルを探索する
Neutron を使用して
Network.onnx
モデル ファイルを開きます。"データ" ノードを選択して、モデルのプロパティを開きます。
ご覧のように、このモデルでは、入力として 32 ビットのテンソル (多次元配列) 浮動小数点数オブジェクトが必要であり、出力として Tensor float が返されます。 出力配列には、すべてのラベルの確率が含まれます。 モデルを構築する方法では、ラベルは 3 つの数値で表され、それぞれが特定の種類のアヤメに関連付けされます。
ラベル 1 | ラベル 2 | ラベル 3 |
---|---|---|
0 | 1 | 2 |
Iris-setosa | Iris-versicolor | Iris-virginica |
Windows ML アプリで正しい予測を表示するには、これらの値を抽出する必要があります。
次のステップ
モデルをデプロイする準備ができました。 次に、メイン イベントである Windows アプリケーションのビルドと Windows デバイス上でのローカルの実行を行いましょう。