使用融合运算符提高性能
某些 DirectML 运算符支持称为融合的概念。 运算符融合是提高性能的一种方法,通过将一个运算符(通常是激活函数)合并到另一个运算符以便其一起执行,而无需往返内存。
何时融合激活函数
融合激活函数是实现性能优化的一种手段。 在许多机器学习 (ML) 模型中,极其常见的一种方案是将非线性(激活函数)应用于模型中每一层的输出。
通常,这需要往返图形内存。 例如,如果卷积后跟非融合 Relu 激活函数,则 GPU 必须等待卷积的结果写入 GPU 内存,然后才能开始计算 Relu 激活层。 由于大多数激活函数的计算工作负载往往较小,因此往返图形内存可能是主要的性能瓶颈。
运算符融合允许在前面的运算符(例如卷积)中执行激活函数(如上例中的 Relu)。 这样,GPU 就可以计算激活函数,而无需等待前面的运算符的结果写入内存,从而提高了性能。
由于融合激活函数后会产生相同的结果,但在很多情况下计算速度更快,因此,建议尽可能将激活层融合到其前面的运算符中,从而消除激活层。
如何融合激活函数
支持融合激活函数的运算符在其运算符结构 const DML_OPERATOR_DESC* FusedActivation
中具有其他可选参数。 以卷积为例,卷积支持融合激活函数,其运算符说明中具有相应的 FusedActivation(请参阅 DML_CONVOLUTION_OPERATOR_DESC)。
struct DML_CONVOLUTION_OPERATOR_DESC
{
const DML_TENSOR_DESC* InputTensor;
const DML_TENSOR_DESC* FilterTensor;
_Maybenull_ const DML_TENSOR_DESC* BiasTensor;
const DML_TENSOR_DESC* OutputTensor;
DML_CONVOLUTION_MODE Mode;
DML_CONVOLUTION_DIRECTION Direction;
UINT DimensionCount;
_Field_size_(DimensionCount) const UINT* Strides;
_Field_size_(DimensionCount) const UINT* Dilations;
_Field_size_(DimensionCount) const UINT* StartPadding;
_Field_size_(DimensionCount) const UINT* EndPadding;
_Field_size_(DimensionCount) const UINT* OutputPadding;
UINT GroupCount;
_Maybenull_ const DML_OPERATOR_DESC* FusedActivation;
};
要融合激活函数,请构造 DML_OPERATOR_DESC,描述要融合的激活函数类型。 例如,要融合 Relu 函数,则正确的运算符类型为 DML_OPERATOR_ACTIVATION_RELU。
注意
构造激活函数的运算符说明时,必须将激活函数的 InputTensor 和 OutputTensor 参数设置为 NULL。
示例
DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC leakyReluDesc;
leakyReluDesc.InputTensor = nullptr;
leakyReluDesc.OutputTensor = nullptr;
leakyReluDesc.Alpha = 0.01f;
DML_OPERATOR_DESC activationDesc = { DML_OPERATOR_ACTIVATION_LEAKY_RELU, &leakyReluDesc };
DML_CONVOLUTION_OPERATOR_DESC convDesc;
// ...
convDesc.FusedActivation = &activationDesc;
对于完整示例,DirectMLSuperResolution 示例利用融合后的激活函数来提高性能。
支持融合激活函数的运算符
以下列出了 DML_OPERATOR_TYPE 枚举中的常量。 该主题中的每个常量均已链接到要使用的相应说明结构。
- DML_OPERATOR_BATCH_NORMALIZATION
- DML_OPERATOR_BATCH_NORMALIZATION_TRAINING
- DML_OPERATOR_CONVOLUTION
- DML_OPERATOR_ELEMENT_WISE_ADD1
- DML_OPERATOR_GEMM
- DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION
- DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1
支持融合的激活函数
以下列出了 DML_OPERATOR_TYPE 枚举中的常量。 该主题中的每个常量均已链接到要使用的相应说明结构。
- DML_OPERATOR_ELEMENT_WISE_CLIP
- DML_OPERATOR_ACTIVATION_LINEAR
- DML_OPERATOR_ACTIVATION_SIGMOID
- DML_OPERATOR_ACTIVATION_HARD_SIGMOID
- DML_OPERATOR_ACTIVATION_TANH
- DML_OPERATOR_ACTIVATION_SCALED_TANH
- DML_OPERATOR_ACTIVATION_RELU
- DML_OPERATOR_ACTIVATION_LEAKY_RELU
- DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU
- DML_OPERATOR_ACTIVATION_ELU
- DML_OPERATOR_ACTIVATION_CELU
- DML_OPERATOR_ACTIVATION_SCALED_ELU
- DML_OPERATOR_ACTIVATION_SOFTPLUS
- DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS
- DML_OPERATOR_ACTIVATION_SOFTSIGN
- DML_OPERATOR_ACTIVATION_IDENTITY
- DML_OPERATOR_ACTIVATION_SHRINK
- DML_OPERATOR_ACTIVATION_GELU
未列出的任何运算符都不支持融合激活函数。