DML_GEMM_OPERATOR_DESC-Struktur (directml.h)
Führt eine allgemeine Matrixmultiplikationsfunktion des Formulars Output = FusedActivation(Alpha * TransA(A) x TransB(B) + Beta * C)
aus, wobei x
die Matrixmultiplikation und *
die Multiplikation mit einem Skalar bezeichnet wird.
Dieser Operator erfordert 4D-Tensors mit Layout { BatchCount, ChannelCount, Height, Width }
und führt BatchCount * ChannelCount-Anzahl unabhängiger Matrixmultiplikationen aus.
Wenn ATensor beispielsweise Größen von { BatchCount, ChannelCount, M, K }
hat und BTensorGrößen von { BatchCount, ChannelCount, K, N }
hat und OutputTensorgrößen von { BatchCount, ChannelCount, M, N }
hat, führt dieser Operator BatchCount * ChannelCount-unabhängige Matrixmultiplikationen von Dimensionen {M,K} x {K,N} = {M,N} = {M,N} durch.
Syntax
struct DML_GEMM_OPERATOR_DESC {
const DML_TENSOR_DESC *ATensor;
const DML_TENSOR_DESC *BTensor;
const DML_TENSOR_DESC *CTensor;
const DML_TENSOR_DESC *OutputTensor;
DML_MATRIX_TRANSFORM TransA;
DML_MATRIX_TRANSFORM TransB;
FLOAT Alpha;
FLOAT Beta;
const DML_OPERATOR_DESC *FusedActivation;
};
Member
ATensor
Typ: const DML_TENSOR_DESC*
Ein Tensor, der die A-Matrix enthält. Die Größen dieses Tensors sollten sein { BatchCount, ChannelCount, M, K }
, wenn TransADML_MATRIX_TRANSFORM_NONE oder { BatchCount, ChannelCount, K, M }
TransADML_MATRIX_TRANSFORM_TRANSPOSE ist.
BTensor
Typ: const DML_TENSOR_DESC*
Ein Tensor, der die B-Matrix enthält. Die Größen dieses Tensors sollten sein { BatchCount, ChannelCount, K, N }
, wenn TransBDML_MATRIX_TRANSFORM_NONE oder { BatchCount, ChannelCount, N, K }
TransBDML_MATRIX_TRANSFORM_TRANSPOSE ist.
CTensor
Typ: _Maybenull_ const DML_TENSOR_DESC*
Ein Tensor, der die C-Matrix oder nullptr
enthält. Standardwert ist 0, wenn nicht angegeben. Falls angegeben, sollte die Größe dieses Tensors sein { BatchCount, ChannelCount, M, N }
.
OutputTensor
Typ: const DML_TENSOR_DESC*
Der Tensor, in den die Ergebnisse geschrieben werden sollen. Die Größen dieses Tensors sind { BatchCount, ChannelCount, M, N }
.
TransA
Typ: DML_MATRIX_TRANSFORM
Die Transformation, die auf ATensor angewendet werden soll; entweder eine Transponieren oder keine Transformation.
TransB
Typ: DML_MATRIX_TRANSFORM
Die Transformation, die auf BTensor angewendet werden soll; entweder eine Transponieren oder keine Transformation.
Alpha
Typ: FLOAT
Der Wert des skalaren Multiplikators für das Produkt der Eingaben ATensor und BTensor.
Beta
Typ: FLOAT
Der Wert des Skalarmultiplikators für den optionalen Eingabe-CTensor. Wenn CTensor nicht bereitgestellt wird, wird dieser Wert ignoriert.
FusedActivation
Typ: _Maybenull_ const DML_OPERATOR_DESC*
Eine optionale Fused-Aktivierungsebene, die nach dem GEMM angewendet werden soll. Weitere Informationen finden Sie unter Verwenden von fusionierten Operatoren für verbesserte Leistung.
Verfügbarkeit
Dieser Operator wurde in DML_FEATURE_LEVEL_1_0
eingeführt.
Tensoreinschränkungen
- ATensor, BTensor, CTensor und OutputTensor müssen denselben DataType und DimensionCount aufweisen.
- CTensor und OutputTensor müssen die gleichen Größen aufweisen.
Tensorunterstützung
DML_FEATURE_LEVEL_4_0 und höher
Tensor | Variante | Dimensionen | Unterstützte Dimensionsanzahl | Unterstützte Datentypen |
---|---|---|---|---|
ATensor | Eingabe | { [BatchCount], [ChannelCount], M, K } | 2 bis 4 | FLOAT32, FLOAT16 |
BTensor | Eingabe | { [BatchCount], [ChannelCount], K, N } | 2 bis 4 | FLOAT32, FLOAT16 |
CTensor | Optionale Eingabe | { [BatchCount], [ChannelCount], M, N } | 2 bis 4 | FLOAT32, FLOAT16 |
OutputTensor | Ausgabe | { [BatchCount], [ChannelCount], M, N } | 2 bis 4 | FLOAT32, FLOAT16 |
DML_FEATURE_LEVEL_1_0 und höher
Tensor | Variante | Dimensionen | Unterstützte Dimensionsanzahl | Unterstützte Datentypen |
---|---|---|---|---|
ATensor | Eingabe | { BatchCount, ChannelCount, M, K } | 4 | FLOAT32, FLOAT16 |
BTensor | Eingabe | { BatchCount, ChannelCount, K, N } | 4 | FLOAT32, FLOAT16 |
CTensor | Optionale Eingabe | { BatchCount, ChannelCount, M, N } | 4 | FLOAT32, FLOAT16 |
OutputTensor | Ausgabe | { BatchCount, ChannelCount, M, N } | 4 | FLOAT32, FLOAT16 |
Anforderungen
Anforderung | Wert |
---|---|
Header | directml.h |