編輯

共用方式為


DirectMLX

DirectMLX is a C++ header-only helper library for DirectML, intended to make it easier to compose individual operators into graphs.

DirectMLX provides convenient wrappers for all DirectML (DML) operator types, as well as intuitive operator overloads, which makes it simpler to instantiate DML operators, and chain them into complex graphs.

Where to find DirectMLX.h

DirectMLX.h is distributed as open-source software under the MIT license. The latest version can be found on the DirectML GitHub.

Version requirements

DirectMLX requires DirectML version 1.4.0, or newer (see DirectML version history). Older versions of DirectML are not supported.

DirectMLX.h requires a C++11-capable compiler, including (but not limited to):

  • Visual Studio 2017
  • Visual Studio 2019
  • Clang 10

Note that a C++17 (or newer) compiler is the option that we recommend. Compiling for C++11 is possible, but it requires the use of third-party libraries (such as GSL and Abseil) to replace missing standard library functionality.

If you have a configuration that fails to compile DirectMLX.h, then please file an issue on our GitHub.

Basic usage

#include <DirectML.h>
#include <DirectMLX.h>

IDMLDevice* device;

/* ... */

dml::Graph graph(device);

// Input tensor of type FLOAT32 and sizes { 1, 2, 3, 4 }
auto x = dml::InputTensor(graph, 0, dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, {1, 2, 3, 4}));

// Create an operator to compute the square root of x
auto y = dml::Sqrt(x);

// Compile a DirectML operator from the graph. When executed, this compiled operator will compute
// the square root of its input.
DML_EXECUTION_FLAGS flags = DML_EXECUTION_FLAG_NONE;
ComPtr<IDMLCompiledOperator> op = graph.Compile(flags, { y });

// Now initialize and dispatch the DML operator as usual

Here's another example, which creates a DirectML graph capable of computing the quadratic formula.

#include <DirectML.h>
#include <DirectMLX.h>

IDMLDevice* device;

/* ... */

std::pair<dml::Expression, dml::Expression>
    QuadraticFormula(dml::Expression a, dml::Expression b, dml::Expression c)
{
    // Quadratic formula: given an equation of the form ax^2 + bx + c = 0, x can be found by:
    //   x = -b +/- sqrt(b^2 - 4ac) / (2a)
    // https://en.wikipedia.org/wiki/Quadratic_formula

    // Note: DirectMLX provides operator overloads for common mathematical expressions. So for 
    // example a*c is equivalent to dml::Multiply(a, c).
    auto x1 = -b + dml::Sqrt(b*b - 4*a*c) / (2*a);
    auto x2 = -b - dml::Sqrt(b*b - 4*a*c) / (2*a);

    return { x1, x2 };
}

/* ... */

dml::Graph graph(device);

dml::TensorDimensions inputSizes = {1, 2, 3, 4};
auto a = dml::InputTensor(graph, 0, dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, inputSizes));
auto b = dml::InputTensor(graph, 1, dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, inputSizes));
auto c = dml::InputTensor(graph, 2, dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, inputSizes));

auto [x1, x2] = QuadraticFormula(a, b, c);

// When executed with input tensors a, b, and c, this compiled operator computes the two outputs
// of the quadratic formula, and returns them as two output tensors x1 and x2
DML_EXECUTION_FLAGS flags = DML_EXECUTION_FLAG_NONE;
ComPtr<IDMLCompiledOperator> op = graph.Compile(flags, { x1, x2 });

// Now initialize and dispatch the DML operator as usual

More examples

Complete samples using DirectMLX can be found on the DirectML GitHub repo.

Compile-time options

DirectMLX supports compile-time #define's to customize various parts of the header.

Option Description
DMLX_NO_EXCEPTIONS If #define'd, causes errors to result in a call to std::abort rather than throwing an exception. This is defined by default if exceptions are unavailable (for example, if exceptions have been disabled in the compiler options).
DMLX_USE_WIL If #define'd, exceptions are thrown using Windows Implementation Library exception types. Otherwise, standard exception types (such as std::runtime_error) are used instead. This option has no effect if DMLX_NO_EXCEPTIONS is defined.
DMLX_USE_ABSEIL If #define'd, uses Abseil as drop-in replacements for standard library types unavailable in C++11. These types include absl::optional (in place of std::optional), absl::Span (in place of std::span), and absl::InlinedVector.
DMLX_USE_GSL Controls whether to use GSL as the replacement for std::span. If #define'd, uses of std::span are replaced by gsl::span on compilers without native std::span implementations. Otherwise, an inline drop-in implementation is provided instead. Note that this option is only used when compiling on a pre-C++20 compiler without support for std::span, and when no other drop-in standard library replacement (like Abseil) is in use.

Controlling tensor layout

For most operators, DirectMLX computes the properties of the operator's output tensors on your behalf. For example when performing a dml::Reduce across axes { 0, 2, 3 } with an input tensor of sizes { 3, 4, 5, 6 }, DirectMLX will automatically compute the properties of the output tensor including the correct shape of { 1, 4, 1, 1 }.

However, the other properties of an output tensor include the Strides, TotalTensorSizeInBytes, and GuaranteedBaseOffsetAlignment. By default, DirectMLX sets these properties such that the tensor has no striding, no guaranteed base offset alignment, and a total tensor size in bytes as computed by DMLCalcBufferTensorSize.

DirectMLX supports the ability to customize these output tensor properties, using objects known as tensor policies. A TensorPolicy is a customizable callback that is invoked by DirectMLX, and returns output tensor properties given a tensor's computed data type, flags, and sizes.

Tensor policies can be set on the dml::Graph object, and will be used for all subsequent operators on that graph. Tensor policies can also be set directly when constructing a TensorDesc.

The layout of tensors produced by DirectMLX can therefore be controlled by setting a TensorPolicy that sets the appropriate strides on its tensors.

Example 1

// Define a policy, which is a function that returns a TensorProperties given a data type,
// flags, and sizes.
dml::TensorProperties MyCustomPolicy(
    DML_TENSOR_DATA_TYPE dataType,
    DML_TENSOR_FLAGS flags,
    Span<const uint32_t> sizes)
{
    // Compute your custom strides, total tensor size in bytes, and guaranteed base
    // offset alignment
    dml::TensorProperties props;
    props.strides = /* ... */;
    props.totalTensorSizeInBytes = /* ... */;
    props.guaranteedBaseOffsetAlignment = /* ... */;
    return props;
};

// Set the policy on the dml::Graph
dml::Graph graph(/* ... */);
graph.SetTensorPolicy(dml::TensorPolicy(&MyCustomPolicy));

Example 2

DirectMLX also provides some alternative tensor policies built-in. The InterleavedChannel policy, for example, is provided as a convenience, and it can be used to produce tensors with strides such that they are written in NHWC order.

// Set the InterleavedChannel policy on the dml::Graph
dml::Graph graph(/* ... */);
graph.SetTensorPolicy(dml::TensorPolicy::InterleavedChannel());

// When executed, the tensor `result` will be in NHWC layout (rather than the default NCHW)
auto result = dml::Convolution(/* ... */);

See also