DML_MULTIHEAD_ATTENTION_OPERATOR_DESC 结构 (directml.h)

执行多头注意力运算(有关详细信息,请参阅《Attention is all you need》(注意力是你所需要的一切))。 必须存在一个 QueryKeyValue 张量,无论它们是否堆叠在一起。 例如,如果提供了 StackedQueryKey,则 QueryKey 张量都必须设置为 null,因为它们已在堆积布局中提供。 StackedKeyValueStackedQueryKeyValue 也遵循相同的原则。 堆叠张量始终有五个维度,并且始终堆叠在第四个维度上。

从逻辑上讲,可以将算法分解为以下运算(括号中的运算是可选的):

[Add Bias to query/key/value] -> GEMM(Query, Transposed(Key)) * Scale -> [Add RelativePositionBias] -> [Add Mask] -> Softmax -> GEMM(SoftmaxResult, Value);

重要

此 API 作为 DirectML 独立可再发行组件包的一部分提供(请参阅 Microsoft.AI.DirectML 版本 1.12 及更高版本。 另请参阅 DirectML 版本历史记录

语法

struct DML_MULTIHEAD_ATTENTION_OPERATOR_DESC
{
    _Maybenull_ const DML_TENSOR_DESC* QueryTensor;
    _Maybenull_ const DML_TENSOR_DESC* KeyTensor;
    _Maybenull_ const DML_TENSOR_DESC* ValueTensor;
    _Maybenull_ const DML_TENSOR_DESC* StackedQueryKeyTensor;
    _Maybenull_ const DML_TENSOR_DESC* StackedKeyValueTensor;
    _Maybenull_ const DML_TENSOR_DESC* StackedQueryKeyValueTensor;
    _Maybenull_ const DML_TENSOR_DESC* BiasTensor;
    _Maybenull_ const DML_TENSOR_DESC* MaskTensor;
    _Maybenull_ const DML_TENSOR_DESC* RelativePositionBiasTensor;
    _Maybenull_ const DML_TENSOR_DESC* PastKeyTensor;
    _Maybenull_ const DML_TENSOR_DESC* PastValueTensor;
    const DML_TENSOR_DESC* OutputTensor;
    _Maybenull_ const DML_TENSOR_DESC* OutputPresentKeyTensor;
    _Maybenull_ const DML_TENSOR_DESC* OutputPresentValueTensor;
    FLOAT Scale;
    FLOAT MaskFilterValue;
    UINT HeadCount;
    DML_MULTIHEAD_ATTENTION_MASK_TYPE MaskType;
};

成员

QueryTensor

类型:_Maybenull_ const DML_TENSOR_DESC*

形状为 [batchSize, sequenceLength, hiddenSize] 的查询,其中 hiddenSize = headCount * headSize。 此张量与 StackedQueryKeyTensorStackedQueryKeyValueTensor 互相排斥。 该张量也可以有 4 或 5 个维度,只要前导维度是 1 即可。

KeyTensor

类型:_Maybenull_ const DML_TENSOR_DESC*

形状为 [batchSize, keyValueSequenceLength, hiddenSize] 的键,其中 hiddenSize = headCount * headSize。 此张量与 StackedQueryKeyTensorStackedKeyValueTensorStackedQueryKeyValueTensor 互相排斥。 该张量也可以有 4 或 5 个维度,只要前导维度是 1 即可。

ValueTensor

类型:_Maybenull_ const DML_TENSOR_DESC*

形状为 [batchSize, keyValueSequenceLength, valueHiddenSize] 的值,其中 valueHiddenSize = headCount * valueHeadSize。 此张量与 StackedKeyValueTensorStackedQueryKeyValueTensor 互相排斥。 该张量也可以有 4 或 5 个维度,只要前导维度是 1 即可。

StackedQueryKeyTensor

类型:_Maybenull_ const DML_TENSOR_DESC*

形状为 [batchSize, sequenceLength, headCount, 2, headSize] 的堆叠查询和键。 此张量与 QueryTensor, KeyTensorStackedKeyValueTensorStackedQueryKeyValueTensor 互相排斥。

StackedQueryKeyTensor layout

StackedKeyValueTensor

类型:_Maybenull_ const DML_TENSOR_DESC*

形状为 [batchSize, keyValueSequenceLength, headCount, 2, headSize] 的堆叠键和值。 此张量与 KeyTensorValueTensorStackedQueryKeyTensorStackedQueryKeyValueTensor 互相排斥。

StackedKeyValueTensor layout

StackedQueryKeyValueTensor

类型:_Maybenull_ const DML_TENSOR_DESC*

形状为 [batchSize, sequenceLength, headCount, 3, headSize] 的堆叠查询、键和值。 此张量与 QueryTensorKeyTensorValueTensorStackedQueryKeyTensorStackedKeyValueTensor 互相排斥。

StackedQueryKeyValueTensor layout

BiasTensor

类型:_Maybenull_ const DML_TENSOR_DESC*

这是形状 [hiddenSize + hiddenSize + valueHiddenSize] 的偏差,在执行第一个 GEMM 运算之前会添加到查询/键/。 此张量也可以有 2、3、4 或 5 个维度,只要前导维度是 1 即可。

MaskTensor

类型:_Maybenull_ const DML_TENSOR_DESC*

这是一个掩码,用于确定哪些元素在执行 QxK GEMM 运算后将其值设置为 MaskFilterValue。 此掩码的行为取决于 MaskType 的值,并在 RelativePositionBiasTensor 之后应用;或者,如果 RelativePositionBiasTensor 设置为 null,则在第一个 GEMM 运算之后应用。 有关详细信息,请参阅 MaskType 的定义。

RelativePositionBiasTensor

类型:_Maybenull_ const DML_TENSOR_DESC*

这是添加到第一个 GEMM 运算结果的偏差。

PastKeyTensor

类型:_Maybenull_ const DML_TENSOR_DESC*

上一次迭代中形状为 [batchSize, headCount, pastSequenceLength, headSize] 的键张量。 当此张量不为 null 值时,它将与键张量连接,从而产生形状为 [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize] 的张量。

PastValueTensor

类型:_Maybenull_ const DML_TENSOR_DESC*

上一次迭代中形状为 [batchSize, headCount, pastSequenceLength, headSize] 的值张量。 当此张量不为 null 值时,它将与 ValueDesc 连接,从而产生形状为 [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize] 的张量。

OutputTensor

类型:const DML_TENSOR_DESC*

形状为 [batchSize, sequenceLength, valueHiddenSize] 的输出。

OutputPresentKeyTensor

类型:_Maybenull_ const DML_TENSOR_DESC*

显示交叉注意力键的状态,形状为 [batchSize, headCount, keyValueSequenceLength, headSize],或显示形状为 [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize] 的自注意力的状态。 它包含键张量的内容或连接的 PastKey + Key 张量的内容,以传递给下一次迭代。

OutputPresentValueTensor

类型:_Maybenull_ const DML_TENSOR_DESC*

显示交叉注意力值的状态,形状为 [batchSize, headCount, keyValueSequenceLength, headSize],或显示形状为 [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize] 的自注意力的状态。 它包含值张量的内容或连接的 PastValue + Value 张量的内容,以传递给下一次迭代。

Scale

类型:FLOAT

用于乘以 QxK GEMM 运算结果的刻度,但在 Softmax 运算之前应用。 此值通常为 1/sqrt(headSize)

MaskFilterValue

类型:FLOAT

将 QxK GEMM 运算结果添加到定义为填充元素的掩码的位置的值。 此值应为非常大的负数(通常为 -10000.0f)。

HeadCount

类型:UINT

注意力头数。

MaskType

类型:DML_MULTIHEAD_ATTENTION_MASK_TYPE

描述 MaskTensor 的行为。

DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN。 当掩码包含值 0 时,将添加 MaskFilterValue;但是当它包含值 1 时,不会添加任何内容。

DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH。 形状为 [1, batchSize] 的掩码包含每个批次的未填充区域的序列长度,序列长度后的所有元素都将其值设置为 MaskFilterValue

DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START。 形状为 [2, batchSize] 的掩码包含未填充区域的结束(独占)和起始(非独占)索引,该区域外部的所有元素都将其值设置为 MaskFilterValue

DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END。 形状为 [batchSize * 3 + 2] 的掩码具有以下值之一:[keyLength[0], ..., keyLength[batchSize - 1], queryStart[0], ..., queryStart[batchSize - 1], queryEnd[batchSize - 1], keyStart[0], ..., keyStart[batchSize - 1], keyEnd[batchSize - 1]]

可用性

此运算符是在 DML_FEATURE_LEVEL_6_1 中引入的。

张量约束

BiasTensorKeyTensorOutputPresentKeyTensorOutputPresentValueTensorOutputTensorPastKeyTensorPastValueTensorQueryTensorRelativePositionBiasTensorStackedKeyValueTensorStackedQueryKeyTensorStackedQueryKeyValueTensorValueTensor 必须具有相同的 DataType

张量支持

张量 种类 支持的维度计数 支持的数据类型
QueryTensor 可选输入 3 到 5 FLOAT32、FLOAT16
KeyTensor 可选输入 3 到 5 FLOAT32、FLOAT16
ValueTensor 可选输入 3 到 5 FLOAT32、FLOAT16
StackedQueryKeyTensor 可选输入 5 FLOAT32、FLOAT16
StackedKeyValueTensor 可选输入 5 FLOAT32、FLOAT16
StackedQueryKeyValueTensor 可选输入 5 FLOAT32、FLOAT16
BiasTensor 可选输入 1 到 5 FLOAT32、FLOAT16
MaskTensor 可选输入 1 到 5 INT32
RelativePositionBiasTensor 可选输入 4 到 5 FLOAT32、FLOAT16
PastKeyTensor 可选输入 4 到 5 FLOAT32、FLOAT16
PastValueTensor 可选输入 4 到 5 FLOAT32、FLOAT16
OutputTensor 输出 3 到 5 FLOAT32、FLOAT16
OutputPresentKeyTensor 可选输出 4 到 5 FLOAT32、FLOAT16
OutputPresentValueTensor 可选输出 4 到 5 FLOAT32、FLOAT16

要求

   
页眉 directml.h