Windows Machine Learning API を使用して TensorFlow モデルを Windows アプリに展開する
この最後のセクションでは、Web カメラをストリーミングし、Windows ML を使用して YOLO モデルを評価してオブジェクトを検出する GUI を備えた、シンプルな UWP アプリを作成する方法について説明します。
Visual Studio で UWP アプリを作成する
- Visual Studio を開き、[
Create a new project.
] を選択します。UWP を検索し、[Blank App (Universal Windows)
] を選択します。
- 次のページで、プロジェクトの [名前] と [場所] を指定してプロジェクトの設定を構成します。 次に、アプリのターゲットと最小 OS バージョンを選択します。 Windows ML API を使用するには、X を使用する必要がありますが、NuGet パッケージを選択して X までサポートすることもできます。NuGet パッケージを使用することを選択した場合は、次の手順 [リンク] に従ってください。
Windows ML API を呼び出してモデルを評価する
手順 1: Machine Learning コード ジェネレーターを使用して、Windows ML API のラッパー クラスを生成します。
手順 2: 生成された .cs ファイル内の生成されたコードを修正します。 最終的なファイルは次のようになります。
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Windows.Media;
using Windows.Storage;
using Windows.Storage.Streams;
using Windows.AI.MachineLearning;
namespace yolodemo
{
public sealed class YoloInput
{
public TensorFloat input_100; // shape(-1,3,416,416)
}
public sealed class YoloOutput
{
public TensorFloat concat_1600; // shape(-1,-1,-1)
}
public sealed class YoloModel
{
private LearningModel model;
private LearningModelSession session;
private LearningModelBinding binding;
public static async Task<YoloModel> CreateFromStreamAsync(IRandomAccessStreamReference stream)
{
YoloModel learningModel = new YoloModel();
learningModel.model = await LearningModel.LoadFromStreamAsync(stream);
learningModel.session = new LearningModelSession(learningModel.model);
learningModel.binding = new LearningModelBinding(learningModel.session);
return learningModel;
}
public async Task<YoloOutput> EvaluateAsync(YoloInput input)
{
binding.Bind("input_1:0", input.input_100);
var result = await session.EvaluateAsync(binding, "0");
var output = new YoloOutput();
output.concat_1600 = result.Outputs["concat_16:0"] as TensorFloat;
return output;
}
}
}
各ビデオ フレームを評価してオブジェクトを検出し、境界ボックスを描画します。
- 次のライブラリを mainPage.xaml.cs に追加します。
using System.Threading.Tasks;
using Windows.Devices.Enumeration;
using Windows.Media;
using Windows.Media.Capture;
using Windows.Storage;
using Windows.UI;
using Windows.UI.Xaml.Media.Imaging;
using Windows.UI.Xaml.Shapes;
using Windows.AI.MachineLearning;
- 次の変数を
public sealed partial class MainPage : Page
に追加します。
private MediaCapture _media_capture;
private LearningModel _model;
private LearningModelSession _session;
private LearningModelBinding _binding;
private readonly SolidColorBrush _fill_brush = new SolidColorBrush(Colors.Transparent);
private readonly SolidColorBrush _line_brush = new SolidColorBrush(Colors.DarkGreen);
private readonly double _line_thickness = 2.0;
private readonly string[] _labels =
{
"<list of labels>"
};
- 検出結果の書式設定方法を示す構造を作成します。
internal struct DetectionResult
{
public string label;
public List<float> bbox;
public double prob;
}
- Box 型の 2 つのオブジェクトを比較する Comparer オブジェクトを作成します。 このクラスは、検出されたオブジェクトの周囲に境界ボックスを描画するために使用されます。
class Comparer : IComparer<DetectionResult>
{
public int Compare(DetectionResult x, DetectionResult y)
{
return y.prob.CompareTo(x.prob);
}
}
- 次のメソッドを追加してデバイスの Web カム ストリームを初期化し、各フレームの処理を開始してオブジェクトを検出します。
private async Task InitCameraAsync()
{
if (_media_capture == null || _media_capture.CameraStreamState == Windows.Media.Devices.CameraStreamState.Shutdown || _media_capture.CameraStreamState == Windows.Media.Devices.CameraStreamState.NotStreaming)
{
if (_media_capture != null)
{
_media_capture.Dispose();
}
MediaCaptureInitializationSettings settings = new MediaCaptureInitializationSettings();
var cameras = await DeviceInformation.FindAllAsync(DeviceClass.VideoCapture);
var camera = cameras.FirstOrDefault();
settings.VideoDeviceId = camera.Id;
_media_capture = new MediaCapture();
await _media_capture.InitializeAsync(settings);
WebCam.Source = _media_capture;
}
if (_media_capture.CameraStreamState == Windows.Media.Devices.CameraStreamState.NotStreaming)
{
await _media_capture.StartPreviewAsync();
WebCam.Visibility = Visibility.Visible;
}
ProcessFrame();
}
- 次のメソッドを追加して、各フレームを処理します。 このメソッドにより、EvaluateFrame と DrawBoxes が呼び出されます。これらは後の手順で実装します。
private async Task ProcessFrame()
{
var frame = new VideoFrame(Windows.Graphics.Imaging.BitmapPixelFormat.Bgra8, (int)WebCam.Width, (int)WebCam.Height);
await _media_capture.GetPreviewFrameAsync(frame);
var results = await EvaluateFrame(frame);
await DrawBoxes(results.ToArray(), frame);
ProcessFrame();
}
- 新しい Sigmoid 浮動小数点型を作成します
private float Sigmoid(float val)
{
var x = (float)Math.Exp(val);
return x / (1.0f + x);
}
- オブジェクトを正しく検出するためのしきい値を作成します。
private float ComputeIOU(DetectionResult DRa, DetectionResult DRb)
{
float ay1 = DRa.bbox[0];
float ax1 = DRa.bbox[1];
float ay2 = DRa.bbox[2];
float ax2 = DRa.bbox[3];
float by1 = DRb.bbox[0];
float bx1 = DRb.bbox[1];
float by2 = DRb.bbox[2];
float bx2 = DRb.bbox[3];
Debug.Assert(ay1 < ay2);
Debug.Assert(ax1 < ax2);
Debug.Assert(by1 < by2);
Debug.Assert(bx1 < bx2);
// determine the coordinates of the intersection rectangle
float x_left = Math.Max(ax1, bx1);
float y_top = Math.Max(ay1, by1);
float x_right = Math.Min(ax2, bx2);
float y_bottom = Math.Min(ay2, by2);
if (x_right < x_left || y_bottom < y_top)
return 0;
float intersection_area = (x_right - x_left) * (y_bottom - y_top);
float bb1_area = (ax2 - ax1) * (ay2 - ay1);
float bb2_area = (bx2 - bx1) * (by2 - by1);
float iou = intersection_area / (bb1_area + bb2_area - intersection_area);
Debug.Assert(iou >= 0 && iou <= 1);
return iou;
}
- 次のリストを実装し、フレームで検出された現在のオブジェクトを追跡します。
private List<DetectionResult> NMS(IReadOnlyList<DetectionResult> detections,
float IOU_threshold = 0.45f,
float score_threshold=0.3f)
{
List<DetectionResult> final_detections = new List<DetectionResult>();
for (int i = 0; i < detections.Count; i++)
{
int j = 0;
for (j = 0; j < final_detections.Count; j++)
{
if (ComputeIOU(final_detections[j], detections[i]) > IOU_threshold)
{
break;
}
}
if (j==final_detections.Count)
{
final_detections.Add(detections[i]);
}
}
return final_detections;
}
- 次のメソッドを実装します。
private List<DetectionResult> ParseResult(float[] results)
{
int c_values = 84;
int c_boxes = results.Length / c_values;
float confidence_threshold = 0.5f;
List<DetectionResult> detections = new List<DetectionResult>();
this.OverlayCanvas.Children.Clear();
for (int i_box = 0; i_box < c_boxes; i_box++)
{
float max_prob = 0.0f;
int label_index = -1;
for (int j_confidence = 4; j_confidence < c_values; j_confidence++)
{
int index = i_box * c_values + j_confidence;
if (results[index] > max_prob)
{
max_prob = results[index];
label_index = j_confidence - 4;
}
}
if (max_prob > confidence_threshold)
{
List<float> bbox = new List<float>();
bbox.Add(results[i_box * c_values + 0]);
bbox.Add(results[i_box * c_values + 1]);
bbox.Add(results[i_box * c_values + 2]);
bbox.Add(results[i_box * c_values + 3]);
detections.Add(new DetectionResult()
{
label = _labels[label_index],
bbox = bbox,
prob = max_prob
});
}
}
return detections;
}
- フレーム内で検出されたオブジェクトの周囲にボックスを描画するために、次のメソッドを追加します。
private async Task DrawBoxes(float[] results, VideoFrame frame)
{
List<DetectionResult> detections = ParseResult(results);
Comparer cp = new Comparer();
detections.Sort(cp);
IReadOnlyList<DetectionResult> final_detetions = NMS(detections);
for (int i=0; i < final_detetions.Count; ++i)
{
int top = (int)(final_detetions[i].bbox[0] * WebCam.Height);
int left = (int)(final_detetions[i].bbox[1] * WebCam.Width);
int bottom = (int)(final_detetions[i].bbox[2] * WebCam.Height);
int right = (int)(final_detetions[i].bbox[3] * WebCam.Width);
var brush = new ImageBrush();
var bitmap_source = new SoftwareBitmapSource();
await bitmap_source.SetBitmapAsync(frame.SoftwareBitmap);
brush.ImageSource = bitmap_source;
// brush.Stretch = Stretch.Fill;
this.OverlayCanvas.Background = brush;
var r = new Rectangle();
r.Tag = i;
r.Width = right - left;
r.Height = bottom - top;
r.Fill = this._fill_brush;
r.Stroke = this._line_brush;
r.StrokeThickness = this._line_thickness;
r.Margin = new Thickness(left, top, 0, 0);
this.OverlayCanvas.Children.Add(r);
// Default configuration for border
// Render text label
var border = new Border();
var backgroundColorBrush = new SolidColorBrush(Colors.Black);
var foregroundColorBrush = new SolidColorBrush(Colors.SpringGreen);
var textBlock = new TextBlock();
textBlock.Foreground = foregroundColorBrush;
textBlock.FontSize = 18;
textBlock.Text = final_detetions[i].label;
// Hide
textBlock.Visibility = Visibility.Collapsed;
border.Background = backgroundColorBrush;
border.Child = textBlock;
Canvas.SetLeft(border, final_detetions[i].bbox[1] * 416 + 2);
Canvas.SetTop(border, final_detetions[i].bbox[0] * 416 + 2);
textBlock.Visibility = Visibility.Visible;
// Add to canvas
this.OverlayCanvas.Children.Add(border);
}
}
- 必要なインフラストラクチャが整ったので、いよいよ評価そのものを組み込みます。 このメソッドにより、現在のフレームに対してモデルが評価され、オブジェクトが検出されます。
private async Task<List<float>> EvaluateFrame(VideoFrame frame)
{
_binding.Clear();
_binding.Bind("input_1:0", frame);
var results = await _session.EvaluateAsync(_binding, "");
Debug.Print("output done\n");
TensorFloat result = results.Outputs["Identity:0"] as TensorFloat;
var shape = result.Shape;
var data = result.GetAsVectorView();
return data.ToList<float>();
}
- このアプリは何らかの方法で開始する必要があります。 ユーザーが [
Go
] ボタンを選択したときに、Web カメラのストリームとモデルの評価を開始するメソッドを追加します。
private void button_go_Click(object sender, RoutedEventArgs e)
{
InitModelAsync();
InitCameraAsync();
}
- Windows ML API を呼び出してモデルを評価するメソッドを追加します。 まず、モデルがストレージから読み込まれ、セッションが作成され、メモリにバインドされます。
private async Task InitModelAsync()
{
var model_file = await StorageFile.GetFileFromApplicationUriAsync(new Uri("ms-appx:///Assets//Yolo.onnx"));
_model = await LearningModel.LoadFromStorageFileAsync(model_file);
var device = new LearningModelDevice(LearningModelDeviceKind.Cpu);
_session = new LearningModelSession(_model, device);
_binding = new LearningModelBinding(_session);
}
アプリケーションを起動します
これで、リアルタイムのオブジェクト検出アプリケーションを作成することができました。 Visual Studio の上部のバーにある [Run
] ボタンを選択し、アプリを起動します。 アプリは次のようになります。
その他のリソース
このチュートリアルで説明しているトピックについて詳しくは、次のリソースを参照してください。
- Windows ML ツール: Windows ML ダッシュボード、WinMLRunner、mglen Windows ML コード ジェネレーターなどのツールについて説明します。
- ONNX モデル: ONNX 形式について説明します。
- Windows ML のパフォーマンスとメモリ: Windows ML を使用してアプリのパフォーマンスを管理する方法について説明します。
- Windows Machine Learning API リファレンス: Windows ML の 3 つの領域について説明します。