DML_TOP_K1_OPERATOR_DESC结构 (directml.h)

InputTensor轴上选择每个序列中的最大或最小 K 元素,并分别返回 OutputValueTensorOutputIndexTensor中这些元素的值和索引。 序列 是指 InputTensor 维度上存在的元素集之一。

可以使用 AxisDirection来控制选择最大 K 元素还是最小 K 元素的选择。

语法

struct DML_TOP_K1_OPERATOR_DESC {
  const DML_TENSOR_DESC *InputTensor;
  const DML_TENSOR_DESC *OutputValueTensor;
  const DML_TENSOR_DESC *OutputIndexTensor;
  UINT                  Axis;
  UINT                  K;
  DML_AXIS_DIRECTION    AxisDirection;
};

成员

InputTensor

类型:const DML_TENSOR_DESC*

包含要选择的元素的输入张量。

OutputValueTensor

类型:const DML_TENSOR_DESC*

要向其写入顶部 K 元素的值的输出张量。 根据 AxisDirection的值,选择顶部 K 元素。 此张量的大小必须等于 InputTensor 参数指定的维度 外,其大小必须等于 K

如果 AxisDirectionDML_AXIS_DIRECTION_DECREASING,则保证从每个输入序列中选择的 K 值进行降序(最大到最小)。 否则,相反的是 true,并保证所选值按升序排序(最小到最大)。

OutputIndexTensor

类型:const DML_TENSOR_DESC*

要向其写入顶部 K 元素的索引的输出张量。 此张量的大小必须等于 InputTensor 参数指定的维度 外,其大小必须等于 K

在此张量中返回的索引相对于其序列的开头(而不是张量开头)进行测量。 例如,索引 0 始终引用轴中所有序列的第一个元素。

如果 top-K 中的两个或更多个元素具有相同的值(即当有平局时),则包含这两个元素的索引,并保证按升序元素索引进行排序。 请注意,无论 AxisDirection的值如何,都是如此。

Axis

类型:UINT

要跨的元素选择的维度的索引。 此值必须小于 InputTensorDimensionCount

K

类型:UINT

要选择的元素数。 K 必须大于 0,但小于 InputTensor 中由 指定的维度中的元素数。

AxisDirection

类型:DML_AXIS_DIRECTION

来自 DML_AXIS_DIRECTION 枚举的值。 如果设置为 DML_AXIS_DIRECTION_INCREASING,则此运算符返回 最小K 元素,以递增值。 否则,它将按递减顺序返回 最大K 元素。

例子

示例 1

InputTensor: (Sizes:{1,1,3,4}, DataType:FLOAT32)
[[[[ 0,  1, 10, 11],
   [ 3,  2,  9,  8],
   [ 4,  5,  6,  7]]]]

Axis: 3
K:    2
AxisDirection: DML_AXIS_DIRECTION_DECREASING
   
OutputValueTensor: (Sizes:{1,1,3,2}, DataType:FLOAT32)
[[[[11, 10],
   [ 9,  8],
   [ 7,  6]]]]

OutputIndexTensor: (Sizes:{1,1,3,2}, DataType:UINT32)
[[[[3, 2],
   [2, 3],
   [3, 2]]]]

示例 2. 使用不同的轴

InputTensor: (Sizes:{1,1,3,4}, DataType:FLOAT32)
[[[[ 0,  1, 10, 11],
   [ 3,  2,  9,  8],
   [ 4,  5,  6,  7]]]]

Axis: 2
K:    2
AxisDirection: DML_AXIS_DIRECTION_DECREASING
   
OutputValueTensor: (Sizes:{1,1,2,4}, DataType:FLOAT32)
[[[[ 4,  5, 10, 11],
   [ 3,  2,  9,  8]]]]

OutputIndexTensor: (Sizes:{1,1,2,4}, DataType:UINT32)
[[[[2, 2, 0, 0],
   [1, 1, 1, 1]]]]

示例 3. 绑定值

InputTensor: (Sizes:{1,1,3,4}, DataType:FLOAT32)
[[[[1, 2, 2, 3],
   [3, 4, 5, 5],
   [6, 6, 6, 6]]]]

Axis: 3
K:    3
AxisDirection: DML_AXIS_DIRECTION_DECREASING
   
OutputValueTensor: (Sizes:{1,1,3,3}, DataType:FLOAT32)
[[[[3, 2, 2],
   [5, 5, 4],
   [6, 6, 6]]]]

OutputIndexTensor: (Sizes:{1,1,3,3}, DataType:UINT32)
[[[[3, 1, 2],
   [2, 3, 1],
   [0, 1, 2]]]]

示例 4. 增加轴方向

InputTensor: (Sizes:{1,1,3,4}, DataType:FLOAT32)
[[[[1, 2, 2, 3],
   [3, 4, 5, 5],
   [6, 6, 6, 6]]]]

Axis: 3
K:    3
AxisDirection: DML_AXIS_DIRECTION_INCREASING
   
OutputValueTensor: (Sizes:{1,1,3,3}, DataType:FLOAT32)
[[[[1, 2, 2],
   [3, 4, 5],
   [6, 6, 6]]]]

OutputIndexTensor: (Sizes:{1,1,3,3}, DataType:UINT32)
[[[[0, 1, 2],
   [0, 1, 2],
   [0, 1, 2]]]]

言论

AxisDirection 设置为 DML_AXIS_DIRECTION_DECREASING时,此运算符等效于 DML_TOP_K_OPERATOR_DESC

可用性

此运算符是在 DML_FEATURE_LEVEL_2_1中引入的。

Tensor 约束

  • InputTensorOutputIndexTensorOutputValueTensor 必须具有相同 DimensionCount
  • InputTensorOutputValueTensor 必须具有相同 DataType

Tensor 支持

DML_FEATURE_LEVEL_5_0及更高版本

张肌 支持的维度计数 支持的数据类型
InputTensor 输入 1 到 8 FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8
OutputValueTensor 输出 1 到 8 FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8
OutputIndexTensor 输出 1 到 8 UINT64、UINT32

DML_FEATURE_LEVEL_3_1及更高版本

张肌 支持的维度计数 支持的数据类型
InputTensor 输入 1 到 8 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8
OutputValueTensor 输出 1 到 8 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8
OutputIndexTensor 输出 1 到 8 UINT32

DML_FEATURE_LEVEL_2_1及更高版本

张肌 支持的维度计数 支持的数据类型
InputTensor 输入 4 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8
OutputValueTensor 输出 4 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8
OutputIndexTensor 输出 4 UINT32

要求

要求 价值
最低支持的客户端 Windows 10 内部版本 20348
支持的最低服务器 Windows 10 内部版本 20348
标头 directml.h