DML_GEMM_OPERATOR_DESC 结构 (directml.h)
执行形式的 Output = FusedActivation(Alpha * TransA(A) x TransB(B) + Beta * C)
常规矩阵乘法函数,其中 x
表示矩阵乘法,使用 *
标量表示乘法。
此运算符需要具有布局 { BatchCount, ChannelCount, Height, Width }
的 4D 张量,它将执行 BatchCount * ChannelCount 数量的独立矩阵乘法。
例如,如果 ATensor的大小为{ BatchCount, ChannelCount, M, K }
, 而 BTensor的大小为{ BatchCount, ChannelCount, K, N }
, OutputTensor 的 Size 为 { BatchCount, ChannelCount, M, N }
,则此运算符执行维度 {M,K} x {K,N} = {M,N} 的 BatchCount * ChannelCount 独立矩阵乘法。
语法
struct DML_GEMM_OPERATOR_DESC {
const DML_TENSOR_DESC *ATensor;
const DML_TENSOR_DESC *BTensor;
const DML_TENSOR_DESC *CTensor;
const DML_TENSOR_DESC *OutputTensor;
DML_MATRIX_TRANSFORM TransA;
DML_MATRIX_TRANSFORM TransB;
FLOAT Alpha;
FLOAT Beta;
const DML_OPERATOR_DESC *FusedActivation;
};
成员
ATensor
类型: const DML_TENSOR_DESC*
包含 A 矩阵的张量。 如果 transA 是DML_MATRIX_TRANSFORM_NONE,则此张量的大小应为{ BatchCount, ChannelCount, M, K }
,或者{ BatchCount, ChannelCount, K, M }
如果 TransA 是DML_MATRIX_TRANSFORM_TRANSPOSE。
BTensor
类型: const DML_TENSOR_DESC*
包含 B 矩阵的张量。 如果 TransB 是DML_MATRIX_TRANSFORM_NONE,则此张量的大小应为{ BatchCount, ChannelCount, K, N }
,或者{ BatchCount, ChannelCount, N, K }
如果 TransB 是DML_MATRIX_TRANSFORM_TRANSPOSE。
CTensor
类型:_Maybenull_ const DML_TENSOR_DESC*
包含 C 矩阵的张量,或 nullptr
。 如果未提供值,则默认值为 0。 如果提供,则此张量 的大小 应为 { BatchCount, ChannelCount, M, N }
。
OutputTensor
类型: const DML_TENSOR_DESC*
要向其写入结果的张量。 此张 量的大小为{ BatchCount, ChannelCount, M, N }
。
TransA
要应用于 ATensor 的转换;转置或无转换。
TransB
要应用于 BTensor 的转换;转置或无转换。
Alpha
类型: FLOAT
输入 ATensor 和 BTensor 的乘积的标量乘数的值。
Beta
类型: FLOAT
可选输入 CTensor 的标量乘数值。 如果未提供 CTensor ,则忽略此值。
FusedActivation
类型:_Maybenull_ const DML_OPERATOR_DESC*
在 GEMM 之后应用的可选融合激活层。 有关详细信息,请参阅 使用融合运算符提高性能。
可用性
此运算符是在 中引入的 DML_FEATURE_LEVEL_1_0
。
张量约束
- ATensor、 BTensor、 CTensor 和 OutputTensor 必须具有相同的 DataType 和 DimensionCount。
- CTensor 和 OutputTensor 必须具有相同 的大小。
Tensor 支持
DML_FEATURE_LEVEL_4_0及更高版本
张 | 种类 | 维度 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|---|
ATensor | 输入 | { [BatchCount], [ChannelCount], M, K } | 2 到 4 | FLOAT32、FLOAT16 |
BTensor | 输入 | { [BatchCount], [ChannelCount], K, N } | 2 到 4 | FLOAT32、FLOAT16 |
CTensor | 可选输入 | { [BatchCount], [ChannelCount], M, N } | 2 到 4 | FLOAT32、FLOAT16 |
OutputTensor | 输出 | { [BatchCount], [ChannelCount], M, N } | 2 到 4 | FLOAT32、FLOAT16 |
DML_FEATURE_LEVEL_1_0 及更高版本
张 | 种类 | 维度 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|---|
ATensor | 输入 | { BatchCount, ChannelCount, M, K } | 4 | FLOAT32、FLOAT16 |
BTensor | 输入 | { BatchCount, ChannelCount, K, N } | 4 | FLOAT32、FLOAT16 |
CTensor | 可选输入 | { BatchCount, ChannelCount, M, N } | 4 | FLOAT32、FLOAT16 |
OutputTensor | 输出 | { BatchCount, ChannelCount, M, N } | 4 | FLOAT32、FLOAT16 |
要求
要求 | 值 |
---|---|
Header | directml.h |