структура DML_GEMM_OPERATOR_DESC (directml.h)
Выполняет общую функцию матричного умножения формы Output = FusedActivation(Alpha * TransA(A) x TransB(B) + Beta * C)
, где x
обозначает матричное умножение, а *
умножение — скалярным.
Этому оператору требуются 4D-тензоры с макетом { BatchCount, ChannelCount, Height, Width }
, и он будет выполнять BatchCount * ChannelCount число независимых матричных умножений.
Например, если параметр ATensor имеет значение Sizes{ BatchCount, ChannelCount, M, K }
, а BTensor имеет значение Sizes{ BatchCount, ChannelCount, K, N }
, а OutputTensor имеет значение Sizes{ BatchCount, ChannelCount, M, N }
, то этот оператор выполняет независимое матричное умножение размеров {M,K} x {K,N} = {M,N}.
Синтаксис
struct DML_GEMM_OPERATOR_DESC {
const DML_TENSOR_DESC *ATensor;
const DML_TENSOR_DESC *BTensor;
const DML_TENSOR_DESC *CTensor;
const DML_TENSOR_DESC *OutputTensor;
DML_MATRIX_TRANSFORM TransA;
DML_MATRIX_TRANSFORM TransB;
FLOAT Alpha;
FLOAT Beta;
const DML_OPERATOR_DESC *FusedActivation;
};
Члены
ATensor
Тип: const DML_TENSOR_DESC*
Тензор, содержащий матрицу A. Размеры этого тензора должны иметь значение { BatchCount, ChannelCount, M, K }
, если TransAявляется DML_MATRIX_TRANSFORM_NONE, или { BatchCount, ChannelCount, K, M }
Если TransA является DML_MATRIX_TRANSFORM_TRANSPOSE.
BTensor
Тип: const DML_TENSOR_DESC*
Тензор, содержащий матрицу B. Размер тензора должен иметь значение { BatchCount, ChannelCount, K, N }
, если TransBявляется DML_MATRIX_TRANSFORM_NONE, или { BatchCount, ChannelCount, N, K }
если TransB является DML_MATRIX_TRANSFORM_TRANSPOSE.
CTensor
Тип: _Maybenull_ const DML_TENSOR_DESC*
Тензор, содержащий матрицу C, или nullptr
. Если значение не указано, значение по умолчанию — 0. Если указан, размер этого тензора должен иметь значение { BatchCount, ChannelCount, M, N }
.
OutputTensor
Тип: const DML_TENSOR_DESC*
Тензор, в который записываются результаты. Размеры этого тензора : { BatchCount, ChannelCount, M, N }
.
TransA
Тип: DML_MATRIX_TRANSFORM
Преобразование, применяемое к ATensor; либо транспонировать, либо нет преобразования.
TransB
Тип: DML_MATRIX_TRANSFORM
Преобразование, применяемое к BTensor; либо транспонировать, либо нет преобразования.
Alpha
Тип: FLOAT
Значение скалярного множителя для произведения входных значений ATensor и BTensor.
Beta
Тип: FLOAT
Значение скалярного множителя для необязательного входного элемента CTensor. Если параметр CTensor не указан, это значение игнорируется.
FusedActivation
Тип: _Maybenull_ const DML_OPERATOR_DESC*
Необязательный слой активации с слиянием, применяемый после GEMM. Дополнительные сведения см. в разделе Использование слитых операторов для повышения производительности.
Доступность
Этот оператор появился в DML_FEATURE_LEVEL_1_0
.
Ограничения тензоров
- ATensor, BTensor, CTensor и OutputTensor должны иметь одинаковые значения DataType и DimensionCount.
- CTensor и OutputTensor должны иметь одинаковые размеры.
Поддержка тензоров
DML_FEATURE_LEVEL_4_0 и выше
Тензор | Kind | Измерения | Поддерживаемые счетчики измерений | Поддерживаемые типы данных |
---|---|---|---|---|
ATensor | Входные данные | { [BatchCount], [ChannelCount], M, K } | От 2 до 4 | FLOAT32, FLOAT16 |
BTensor | Входные данные | { [BatchCount], [ChannelCount], K, N } | От 2 до 4 | FLOAT32, FLOAT16 |
CTensor | Необязательные входные данные | { [BatchCount], [ChannelCount], M, N } | От 2 до 4 | FLOAT32, FLOAT16 |
OutputTensor | Выходные данные | { [BatchCount], [ChannelCount], M, N } | От 2 до 4 | FLOAT32, FLOAT16 |
DML_FEATURE_LEVEL_1_0 и выше
Тензор | Kind | Измерения | Поддерживаемые счетчики измерений | Поддерживаемые типы данных |
---|---|---|---|---|
ATensor | Входные данные | { BatchCount, ChannelCount, M, K } | 4 | FLOAT32, FLOAT16 |
BTensor | Входные данные | { BatchCount, ChannelCount, K, N } | 4 | FLOAT32, FLOAT16 |
CTensor | Необязательные входные данные | { BatchCount, ChannelCount, M, N } | 4 | FLOAT32, FLOAT16 |
OutputTensor | Выходные данные | { BatchCount, ChannelCount, M, N } | 4 | FLOAT32, FLOAT16 |
Требования
Требование | Значение |
---|---|
Заголовок | directml.h |