Freigeben über


DML_GEMM_OPERATOR_DESC-Struktur (directml.h)

Führt eine allgemeine Matrixmultiplikationsfunktion des Formulars Output = FusedActivation(Alpha * TransA(A) x TransB(B) + Beta * C)aus, wobei x die Matrixmultiplikation und * die Multiplikation mit einem Skalar bezeichnet wird.

Dieser Operator erfordert 4D-Tensors mit Layout { BatchCount, ChannelCount, Height, Width }und führt BatchCount * ChannelCount-Anzahl unabhängiger Matrixmultiplikationen aus.

Wenn ATensor beispielsweise Größen von { BatchCount, ChannelCount, M, K }hat und BTensorGrößen von { BatchCount, ChannelCount, K, N }hat und OutputTensorgrößen von { BatchCount, ChannelCount, M, N }hat, führt dieser Operator BatchCount * ChannelCount-unabhängige Matrixmultiplikationen von Dimensionen {M,K} x {K,N} = {M,N} = {M,N} durch.

Syntax

struct DML_GEMM_OPERATOR_DESC {
  const DML_TENSOR_DESC   *ATensor;
  const DML_TENSOR_DESC   *BTensor;
  const DML_TENSOR_DESC   *CTensor;
  const DML_TENSOR_DESC   *OutputTensor;
  DML_MATRIX_TRANSFORM    TransA;
  DML_MATRIX_TRANSFORM    TransB;
  FLOAT                   Alpha;
  FLOAT                   Beta;
  const DML_OPERATOR_DESC *FusedActivation;
};

Member

ATensor

Typ: const DML_TENSOR_DESC*

Ein Tensor, der die A-Matrix enthält. Die Größen dieses Tensors sollten sein { BatchCount, ChannelCount, M, K } , wenn TransADML_MATRIX_TRANSFORM_NONE oder { BatchCount, ChannelCount, K, M }TransADML_MATRIX_TRANSFORM_TRANSPOSE ist.

BTensor

Typ: const DML_TENSOR_DESC*

Ein Tensor, der die B-Matrix enthält. Die Größen dieses Tensors sollten sein { BatchCount, ChannelCount, K, N } , wenn TransBDML_MATRIX_TRANSFORM_NONE oder { BatchCount, ChannelCount, N, K }TransBDML_MATRIX_TRANSFORM_TRANSPOSE ist.

CTensor

Typ: _Maybenull_ const DML_TENSOR_DESC*

Ein Tensor, der die C-Matrix oder nullptrenthält. Standardwert ist 0, wenn nicht angegeben. Falls angegeben, sollte die Größe dieses Tensors sein { BatchCount, ChannelCount, M, N }.

OutputTensor

Typ: const DML_TENSOR_DESC*

Der Tensor, in den die Ergebnisse geschrieben werden sollen. Die Größen dieses Tensors sind { BatchCount, ChannelCount, M, N }.

TransA

Typ: DML_MATRIX_TRANSFORM

Die Transformation, die auf ATensor angewendet werden soll; entweder eine Transponieren oder keine Transformation.

TransB

Typ: DML_MATRIX_TRANSFORM

Die Transformation, die auf BTensor angewendet werden soll; entweder eine Transponieren oder keine Transformation.

Alpha

Typ: FLOAT

Der Wert des skalaren Multiplikators für das Produkt der Eingaben ATensor und BTensor.

Beta

Typ: FLOAT

Der Wert des Skalarmultiplikators für den optionalen Eingabe-CTensor. Wenn CTensor nicht bereitgestellt wird, wird dieser Wert ignoriert.

FusedActivation

Typ: _Maybenull_ const DML_OPERATOR_DESC*

Eine optionale Fused-Aktivierungsebene, die nach dem GEMM angewendet werden soll. Weitere Informationen finden Sie unter Verwenden von fusionierten Operatoren für verbesserte Leistung.

Verfügbarkeit

Dieser Operator wurde in DML_FEATURE_LEVEL_1_0eingeführt.

Tensoreinschränkungen

  • ATensor, BTensor, CTensor und OutputTensor müssen denselben DataType und DimensionCount aufweisen.
  • CTensor und OutputTensor müssen die gleichen Größen aufweisen.

Tensorunterstützung

DML_FEATURE_LEVEL_4_0 und höher

Tensor Variante Dimensionen Unterstützte Dimensionsanzahl Unterstützte Datentypen
ATensor Eingabe { [BatchCount], [ChannelCount], M, K } 2 bis 4 FLOAT32, FLOAT16
BTensor Eingabe { [BatchCount], [ChannelCount], K, N } 2 bis 4 FLOAT32, FLOAT16
CTensor Optionale Eingabe { [BatchCount], [ChannelCount], M, N } 2 bis 4 FLOAT32, FLOAT16
OutputTensor Ausgabe { [BatchCount], [ChannelCount], M, N } 2 bis 4 FLOAT32, FLOAT16

DML_FEATURE_LEVEL_1_0 und höher

Tensor Variante Dimensionen Unterstützte Dimensionsanzahl Unterstützte Datentypen
ATensor Eingabe { BatchCount, ChannelCount, M, K } 4 FLOAT32, FLOAT16
BTensor Eingabe { BatchCount, ChannelCount, K, N } 4 FLOAT32, FLOAT16
CTensor Optionale Eingabe { BatchCount, ChannelCount, M, N } 4 FLOAT32, FLOAT16
OutputTensor Ausgabe { BatchCount, ChannelCount, M, N } 4 FLOAT32, FLOAT16

Anforderungen

Anforderung Wert
Header directml.h

Weitere Informationen