estructura DML_GEMM_OPERATOR_DESC (directml.h)
Realiza una función de multiplicación de matriz general del formulario Output = FusedActivation(Alpha * TransA(A) x TransB(B) + Beta * C)
, donde x
denota la multiplicación de la matriz y *
denota la multiplicación con un escalar.
Este operador requiere tensores 4D con diseño { BatchCount, ChannelCount, Height, Width }
y realizará BatchCount * ChannelCount número de multiplicaciones de matrices independientes.
Por ejemplo, si ATensor tiene Tamaños de , y BTensor tiene Tamaños de { BatchCount, ChannelCount, K, N }
y OutputTensor tiene Tamaños de { BatchCount, ChannelCount, M, N }
, este operador realiza multiplicaciones de matrices independientes de { BatchCount, ChannelCount, M, K }
BatchCount * ChannelCount de dimensiones {M,K} x {K,N} = {M,N}.
Sintaxis
struct DML_GEMM_OPERATOR_DESC {
const DML_TENSOR_DESC *ATensor;
const DML_TENSOR_DESC *BTensor;
const DML_TENSOR_DESC *CTensor;
const DML_TENSOR_DESC *OutputTensor;
DML_MATRIX_TRANSFORM TransA;
DML_MATRIX_TRANSFORM TransB;
FLOAT Alpha;
FLOAT Beta;
const DML_OPERATOR_DESC *FusedActivation;
};
Miembros
ATensor
Tipo: const DML_TENSOR_DESC*
Tensor que contiene la matriz A. Los tamaños de este tensor deben ser { BatchCount, ChannelCount, M, K }
si TransA es DML_MATRIX_TRANSFORM_NONE o { BatchCount, ChannelCount, K, M }
si TransA es DML_MATRIX_TRANSFORM_TRANSPOSE.
BTensor
Tipo: const DML_TENSOR_DESC*
Tensor que contiene la matriz B. Los tamaños de este tensor deben ser { BatchCount, ChannelCount, K, N }
si TransB es DML_MATRIX_TRANSFORM_NONE o { BatchCount, ChannelCount, N, K }
si TransB es DML_MATRIX_TRANSFORM_TRANSPOSE.
CTensor
Tipo: _Maybenull_ const DML_TENSOR_DESC*
Tensor que contiene la matriz de C o nullptr
. Los valores predeterminados son 0 cuando no se proporcionan. Si se proporciona, los tamaños de este tensor deben ser { BatchCount, ChannelCount, M, N }
.
OutputTensor
Tipo: const DML_TENSOR_DESC*
Tensor en el que se van a escribir los resultados. Los tamaños de este tensor son { BatchCount, ChannelCount, M, N }
.
TransA
Tipo: DML_MATRIX_TRANSFORM
Transformación que se va a aplicar a ATensor; una transposición o ninguna transformación.
TransB
Tipo: DML_MATRIX_TRANSFORM
Transformación que se va a aplicar a BTensor; una transposición o ninguna transformación.
Alpha
Tipo: FLOAT
Valor del multiplicador escalar para el producto de entradas ATensor y BTensor.
Beta
Tipo: FLOAT
Valor del multiplicador escalar para el CTensor de entrada opcional. Si no se proporciona CTensor , este valor se omite.
FusedActivation
Tipo: _Maybenull_ const DML_OPERATOR_DESC*
Una capa de activación fusionada opcional que se aplicará después del GEMM. Para obtener más información, consulte Uso de operadores fusionados para mejorar el rendimiento.
Disponibilidad
Este operador se introdujo en DML_FEATURE_LEVEL_1_0
.
Restricciones tensor
- ATensor, BTensor, CTensor y OutputTensor deben tener el mismo DataType y DimensionCount.
- CTensor y OutputTensor deben tener los mismos tamaños.
Compatibilidad con Tensor
DML_FEATURE_LEVEL_4_0 y versiones posteriores
Tensor | Clase | Dimensions | Recuentos de dimensiones admitidos | Tipos de datos admitidos |
---|---|---|---|---|
ATensor | Entrada | { [BatchCount], [ChannelCount], M, K } | De 2 a 4 | FLOAT32, FLOAT16 |
BTensor | Entrada | { [BatchCount], [ChannelCount], K, N } | De 2 a 4 | FLOAT32, FLOAT16 |
Ctensor | Entrada opcional | { [BatchCount], [ChannelCount], M, N } | De 2 a 4 | FLOAT32, FLOAT16 |
OutputTensor | Resultados | { [BatchCount], [ChannelCount], M, N } | De 2 a 4 | FLOAT32, FLOAT16 |
DML_FEATURE_LEVEL_1_0 y versiones posteriores
Tensor | Clase | Dimensions | Recuentos de dimensiones admitidos | Tipos de datos admitidos |
---|---|---|---|---|
ATensor | Entrada | { BatchCount, ChannelCount, M, K } | 4 | FLOAT32, FLOAT16 |
BTensor | Entrada | { BatchCount, ChannelCount, K, N } | 4 | FLOAT32, FLOAT16 |
Ctensor | Entrada opcional | { BatchCount, ChannelCount, M, N } | 4 | FLOAT32, FLOAT16 |
OutputTensor | Resultados | { BatchCount, ChannelCount, M, N } | 4 | FLOAT32, FLOAT16 |
Requisitos
Requisito | Valor |
---|---|
Header | directml.h |