DML_MULTIHEAD_ATTENTION_OPERATOR_DESC 结构 (directml.h)
执行多头注意力运算(有关详细信息,请参阅《Attention is all you need》(注意力是你所需要的一切))。 必须存在一个 Query、Key 和 Value 张量,无论它们是否堆叠在一起。 例如,如果提供了 StackedQueryKey,则 Query 和 Key 张量都必须设置为 null,因为它们已在堆积布局中提供。 StackedKeyValue 和 StackedQueryKeyValue 也遵循相同的原则。 堆叠张量始终有五个维度,并且始终堆叠在第四个维度上。
从逻辑上讲,可以将算法分解为以下运算(括号中的运算是可选的):
[Add Bias to query/key/value] -> GEMM(Query, Transposed(Key)) * Scale -> [Add RelativePositionBias] -> [Add Mask] -> Softmax -> GEMM(SoftmaxResult, Value);
重要
此 API 作为 DirectML 独立可再发行组件包的一部分提供(请参阅 Microsoft.AI.DirectML 版本 1.12 及更高版本。 另请参阅 DirectML 版本历史记录。
语法
struct DML_MULTIHEAD_ATTENTION_OPERATOR_DESC
{
_Maybenull_ const DML_TENSOR_DESC* QueryTensor;
_Maybenull_ const DML_TENSOR_DESC* KeyTensor;
_Maybenull_ const DML_TENSOR_DESC* ValueTensor;
_Maybenull_ const DML_TENSOR_DESC* StackedQueryKeyTensor;
_Maybenull_ const DML_TENSOR_DESC* StackedKeyValueTensor;
_Maybenull_ const DML_TENSOR_DESC* StackedQueryKeyValueTensor;
_Maybenull_ const DML_TENSOR_DESC* BiasTensor;
_Maybenull_ const DML_TENSOR_DESC* MaskTensor;
_Maybenull_ const DML_TENSOR_DESC* RelativePositionBiasTensor;
_Maybenull_ const DML_TENSOR_DESC* PastKeyTensor;
_Maybenull_ const DML_TENSOR_DESC* PastValueTensor;
const DML_TENSOR_DESC* OutputTensor;
_Maybenull_ const DML_TENSOR_DESC* OutputPresentKeyTensor;
_Maybenull_ const DML_TENSOR_DESC* OutputPresentValueTensor;
FLOAT Scale;
FLOAT MaskFilterValue;
UINT HeadCount;
DML_MULTIHEAD_ATTENTION_MASK_TYPE MaskType;
};
成员
QueryTensor
类型:_Maybenull_ const DML_TENSOR_DESC*
形状为 [batchSize, sequenceLength, hiddenSize]
的查询,其中 hiddenSize = headCount * headSize
。 此张量与 StackedQueryKeyTensor 和 StackedQueryKeyValueTensor 互相排斥。 该张量也可以有 4 或 5 个维度,只要前导维度是 1 即可。
KeyTensor
类型:_Maybenull_ const DML_TENSOR_DESC*
形状为 [batchSize, keyValueSequenceLength, hiddenSize]
的键,其中 hiddenSize = headCount * headSize
。 此张量与 StackedQueryKeyTensor、StackedKeyValueTensor 和 StackedQueryKeyValueTensor 互相排斥。 该张量也可以有 4 或 5 个维度,只要前导维度是 1 即可。
ValueTensor
类型:_Maybenull_ const DML_TENSOR_DESC*
形状为 [batchSize, keyValueSequenceLength, valueHiddenSize]
的值,其中 valueHiddenSize = headCount * valueHeadSize
。 此张量与 StackedKeyValueTensor 和 StackedQueryKeyValueTensor 互相排斥。 该张量也可以有 4 或 5 个维度,只要前导维度是 1 即可。
StackedQueryKeyTensor
类型:_Maybenull_ const DML_TENSOR_DESC*
形状为 [batchSize, sequenceLength, headCount, 2, headSize]
的堆叠查询和键。 此张量与 QueryTensor, KeyTensor、StackedKeyValueTensor 和 StackedQueryKeyValueTensor 互相排斥。
StackedKeyValueTensor
类型:_Maybenull_ const DML_TENSOR_DESC*
形状为 [batchSize, keyValueSequenceLength, headCount, 2, headSize]
的堆叠键和值。 此张量与 KeyTensor、ValueTensor、StackedQueryKeyTensor 和 StackedQueryKeyValueTensor 互相排斥。
StackedQueryKeyValueTensor
类型:_Maybenull_ const DML_TENSOR_DESC*
形状为 [batchSize, sequenceLength, headCount, 3, headSize]
的堆叠查询、键和值。 此张量与 QueryTensor、KeyTensor、ValueTensor、StackedQueryKeyTensor 和 StackedKeyValueTensor 互相排斥。
BiasTensor
类型:_Maybenull_ const DML_TENSOR_DESC*
这是形状 [hiddenSize + hiddenSize + valueHiddenSize]
的偏差,在执行第一个 GEMM 运算之前会添加到查询/键/值。 此张量也可以有 2、3、4 或 5 个维度,只要前导维度是 1 即可。
MaskTensor
类型:_Maybenull_ const DML_TENSOR_DESC*
这是一个掩码,用于确定哪些元素在执行 QxK GEMM 运算后将其值设置为 MaskFilterValue。 此掩码的行为取决于 MaskType 的值,并在 RelativePositionBiasTensor 之后应用;或者,如果 RelativePositionBiasTensor 设置为 null,则在第一个 GEMM 运算之后应用。 有关详细信息,请参阅 MaskType 的定义。
RelativePositionBiasTensor
类型:_Maybenull_ const DML_TENSOR_DESC*
这是添加到第一个 GEMM 运算结果的偏差。
PastKeyTensor
类型:_Maybenull_ const DML_TENSOR_DESC*
上一次迭代中形状为 [batchSize, headCount, pastSequenceLength, headSize]
的键张量。 当此张量不为 null 值时,它将与键张量连接,从而产生形状为 [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize]
的张量。
PastValueTensor
类型:_Maybenull_ const DML_TENSOR_DESC*
上一次迭代中形状为 [batchSize, headCount, pastSequenceLength, headSize]
的值张量。 当此张量不为 null 值时,它将与 ValueDesc 连接,从而产生形状为 [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize]
的张量。
OutputTensor
类型:const DML_TENSOR_DESC*
形状为 [batchSize, sequenceLength, valueHiddenSize]
的输出。
OutputPresentKeyTensor
类型:_Maybenull_ const DML_TENSOR_DESC*
显示交叉注意力键的状态,形状为 [batchSize, headCount, keyValueSequenceLength, headSize]
,或显示形状为 [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize]
的自注意力的状态。 它包含键张量的内容或连接的 PastKey + Key 张量的内容,以传递给下一次迭代。
OutputPresentValueTensor
类型:_Maybenull_ const DML_TENSOR_DESC*
显示交叉注意力值的状态,形状为 [batchSize, headCount, keyValueSequenceLength, headSize]
,或显示形状为 [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize]
的自注意力的状态。 它包含值张量的内容或连接的 PastValue + Value 张量的内容,以传递给下一次迭代。
Scale
类型:FLOAT
用于乘以 QxK GEMM 运算结果的刻度,但在 Softmax 运算之前应用。 此值通常为 1/sqrt(headSize)
。
MaskFilterValue
类型:FLOAT
将 QxK GEMM 运算结果添加到定义为填充元素的掩码的位置的值。 此值应为非常大的负数(通常为 -10000.0f)。
HeadCount
类型:UINT
注意力头数。
MaskType
类型:DML_MULTIHEAD_ATTENTION_MASK_TYPE
描述 MaskTensor 的行为。
DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN。 当掩码包含值 0 时,将添加 MaskFilterValue;但是当它包含值 1 时,不会添加任何内容。
DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH。 形状为 [1, batchSize]
的掩码包含每个批次的未填充区域的序列长度,序列长度后的所有元素都将其值设置为 MaskFilterValue。
DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START。 形状为 [2, batchSize]
的掩码包含未填充区域的结束(独占)和起始(非独占)索引,该区域外部的所有元素都将其值设置为 MaskFilterValue。
DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END。 形状为 [batchSize * 3 + 2]
的掩码具有以下值之一:[keyLength[0], ..., keyLength[batchSize - 1], queryStart[0], ..., queryStart[batchSize - 1], queryEnd[batchSize - 1], keyStart[0], ..., keyStart[batchSize - 1], keyEnd[batchSize - 1]]
。
可用性
此运算符是在 DML_FEATURE_LEVEL_6_1 中引入的。
张量约束
BiasTensor、KeyTensor、OutputPresentKeyTensor、OutputPresentValueTensor、OutputTensor、PastKeyTensor、PastValueTensor、QueryTensor、RelativePositionBiasTensor、StackedKeyValueTensor、StackedQueryKeyTensor、StackedQueryKeyValueTensor 和 ValueTensor 必须具有相同的 DataType。
张量支持
张量 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
QueryTensor | 可选输入 | 3 到 5 | FLOAT32、FLOAT16 |
KeyTensor | 可选输入 | 3 到 5 | FLOAT32、FLOAT16 |
ValueTensor | 可选输入 | 3 到 5 | FLOAT32、FLOAT16 |
StackedQueryKeyTensor | 可选输入 | 5 | FLOAT32、FLOAT16 |
StackedKeyValueTensor | 可选输入 | 5 | FLOAT32、FLOAT16 |
StackedQueryKeyValueTensor | 可选输入 | 5 | FLOAT32、FLOAT16 |
BiasTensor | 可选输入 | 1 到 5 | FLOAT32、FLOAT16 |
MaskTensor | 可选输入 | 1 到 5 | INT32 |
RelativePositionBiasTensor | 可选输入 | 4 到 5 | FLOAT32、FLOAT16 |
PastKeyTensor | 可选输入 | 4 到 5 | FLOAT32、FLOAT16 |
PastValueTensor | 可选输入 | 4 到 5 | FLOAT32、FLOAT16 |
OutputTensor | 输出 | 3 到 5 | FLOAT32、FLOAT16 |
OutputPresentKeyTensor | 可选输出 | 4 到 5 | FLOAT32、FLOAT16 |
OutputPresentValueTensor | 可选输出 | 4 到 5 | FLOAT32、FLOAT16 |
要求
页眉 | directml.h |