DML_TOP_K1_OPERATOR_DESC结构 (directml.h)
从 InputTensor轴上选择每个序列中的最大或最小 K 元素,并分别返回 OutputValueTensor 和 OutputIndexTensor中这些元素的值和索引。 序列 是指 InputTensor轴 维度上存在的元素集之一。
可以使用 AxisDirection来控制选择最大 K 元素还是最小 K 元素的选择。
语法
struct DML_TOP_K1_OPERATOR_DESC {
const DML_TENSOR_DESC *InputTensor;
const DML_TENSOR_DESC *OutputValueTensor;
const DML_TENSOR_DESC *OutputIndexTensor;
UINT Axis;
UINT K;
DML_AXIS_DIRECTION AxisDirection;
};
成员
InputTensor
类型:const DML_TENSOR_DESC*
包含要选择的元素的输入张量。
OutputValueTensor
类型:const DML_TENSOR_DESC*
要向其写入顶部 K 元素的值的输出张量。 根据 AxisDirection的值,选择顶部 K 元素。 此张量的大小必须等于 InputTensor,除 轴 参数指定的维度 外,其大小必须等于 K。
如果 AxisDirectionDML_AXIS_DIRECTION_DECREASING,则保证从每个输入序列中选择的 K 值进行降序(最大到最小)。 否则,相反的是 true,并保证所选值按升序排序(最小到最大)。
OutputIndexTensor
类型:const DML_TENSOR_DESC*
要向其写入顶部 K 元素的索引的输出张量。 此张量的大小必须等于 InputTensor,除 轴 参数指定的维度 外,其大小必须等于 K。
在此张量中返回的索引相对于其序列的开头(而不是张量开头)进行测量。 例如,索引 0 始终引用轴中所有序列的第一个元素。
如果 top-K 中的两个或更多个元素具有相同的值(即当有平局时),则包含这两个元素的索引,并保证按升序元素索引进行排序。 请注意,无论 AxisDirection的值如何,都是如此。
Axis
类型:UINT
要跨的元素选择的维度的索引。 此值必须小于 InputTensor的 DimensionCount。
K
类型:UINT
要选择的元素数。 K 必须大于 0,但小于 InputTensor 中由 轴指定的维度中的元素数。
AxisDirection
来自 DML_AXIS_DIRECTION 枚举的值。 如果设置为 DML_AXIS_DIRECTION_INCREASING,则此运算符返回 最小K 元素,以递增值。 否则,它将按递减顺序返回 最大K 元素。
例子
示例 1
InputTensor: (Sizes:{1,1,3,4}, DataType:FLOAT32)
[[[[ 0, 1, 10, 11],
[ 3, 2, 9, 8],
[ 4, 5, 6, 7]]]]
Axis: 3
K: 2
AxisDirection: DML_AXIS_DIRECTION_DECREASING
OutputValueTensor: (Sizes:{1,1,3,2}, DataType:FLOAT32)
[[[[11, 10],
[ 9, 8],
[ 7, 6]]]]
OutputIndexTensor: (Sizes:{1,1,3,2}, DataType:UINT32)
[[[[3, 2],
[2, 3],
[3, 2]]]]
示例 2. 使用不同的轴
InputTensor: (Sizes:{1,1,3,4}, DataType:FLOAT32)
[[[[ 0, 1, 10, 11],
[ 3, 2, 9, 8],
[ 4, 5, 6, 7]]]]
Axis: 2
K: 2
AxisDirection: DML_AXIS_DIRECTION_DECREASING
OutputValueTensor: (Sizes:{1,1,2,4}, DataType:FLOAT32)
[[[[ 4, 5, 10, 11],
[ 3, 2, 9, 8]]]]
OutputIndexTensor: (Sizes:{1,1,2,4}, DataType:UINT32)
[[[[2, 2, 0, 0],
[1, 1, 1, 1]]]]
示例 3. 绑定值
InputTensor: (Sizes:{1,1,3,4}, DataType:FLOAT32)
[[[[1, 2, 2, 3],
[3, 4, 5, 5],
[6, 6, 6, 6]]]]
Axis: 3
K: 3
AxisDirection: DML_AXIS_DIRECTION_DECREASING
OutputValueTensor: (Sizes:{1,1,3,3}, DataType:FLOAT32)
[[[[3, 2, 2],
[5, 5, 4],
[6, 6, 6]]]]
OutputIndexTensor: (Sizes:{1,1,3,3}, DataType:UINT32)
[[[[3, 1, 2],
[2, 3, 1],
[0, 1, 2]]]]
示例 4. 增加轴方向
InputTensor: (Sizes:{1,1,3,4}, DataType:FLOAT32)
[[[[1, 2, 2, 3],
[3, 4, 5, 5],
[6, 6, 6, 6]]]]
Axis: 3
K: 3
AxisDirection: DML_AXIS_DIRECTION_INCREASING
OutputValueTensor: (Sizes:{1,1,3,3}, DataType:FLOAT32)
[[[[1, 2, 2],
[3, 4, 5],
[6, 6, 6]]]]
OutputIndexTensor: (Sizes:{1,1,3,3}, DataType:UINT32)
[[[[0, 1, 2],
[0, 1, 2],
[0, 1, 2]]]]
言论
当 AxisDirection 设置为 DML_AXIS_DIRECTION_DECREASING时,此运算符等效于 DML_TOP_K_OPERATOR_DESC。
可用性
此运算符是在 DML_FEATURE_LEVEL_2_1
中引入的。
Tensor 约束
- InputTensor、OutputIndexTensor,OutputValueTensor 必须具有相同 DimensionCount。
- InputTensor 和 OutputValueTensor 必须具有相同 DataType。
Tensor 支持
DML_FEATURE_LEVEL_5_0及更高版本
张肌 | 类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
InputTensor | 输入 | 1 到 8 | FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8 |
OutputValueTensor | 输出 | 1 到 8 | FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8 |
OutputIndexTensor | 输出 | 1 到 8 | UINT64、UINT32 |
DML_FEATURE_LEVEL_3_1及更高版本
张肌 | 类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
InputTensor | 输入 | 1 到 8 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
OutputValueTensor | 输出 | 1 到 8 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
OutputIndexTensor | 输出 | 1 到 8 | UINT32 |
DML_FEATURE_LEVEL_2_1及更高版本
张肌 | 类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
InputTensor | 输入 | 4 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
OutputValueTensor | 输出 | 4 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
OutputIndexTensor | 输出 | 4 | UINT32 |
要求
要求 | 价值 |
---|---|
最低支持的客户端 | Windows 10 内部版本 20348 |
支持的最低服务器 | Windows 10 内部版本 20348 |
标头 | directml.h |