DML_ARGMAX_OPERATOR_DESC 结构 (directml.h)
输出输入张量一个或多个维度内最大值元素的索引。
每个输出元素都是对输入张量子集应用 argmax 缩减的结果。 argmax 函数输出一组输入元素中最大值元素的索引。 每个缩减所涉及的输入元素由提供的输入轴确定。 同样,每个输出索引都与提供的输入轴相关。 如果指定了所有输入轴,运算符将应用单个 argmax 缩减,并生成单个输出元素。
语法
struct DML_ARGMAX_OPERATOR_DESC {
const DML_TENSOR_DESC *InputTensor;
const DML_TENSOR_DESC *OutputTensor;
UINT AxisCount;
const UINT *Axes;
DML_AXIS_DIRECTION AxisDirection;
};
成员
InputTensor
类型: const DML_TENSOR_DESC*
要从中读取的张量。
OutputTensor
类型: const DML_TENSOR_DESC*
要向其写入结果的张量。 每个输出元素都是 InputTensor 中元素子集的 argmax 缩减的结果。
- DimensionCount 必须与 InputTensor.DimensionCount 匹配, (输入张量排名保留) 。
- 大小 必须与 InputTensor.Size 匹配,但缩小 的轴中包含的维度除外,其大小必须为 1。
AxisCount
类型: UINT
要减少的轴数。 此字段确定 Axes 数组的大小。
Axes
类型:_Field_size_ (AxisCount) const UINT*
要沿其减小的轴。 值必须位于 范围 [0, InputTensor.DimensionCount - 1]
中。
AxisDirection
确定当多个输入元素具有相同值时要选择的索引。
- DML_AXIS_DIRECTION_INCREASING 返回第一个最大值元素 (的索引,
argmax({3,2,1,2,3}) = 0
例如,) - DML_AXIS_DIRECTION_DECREASING 返回最后一个最大值元素 (的索引,
argmax({3,2,1,2,3}) = 4
例如,)
示例
本部分中的示例都使用相同的二维输入张量。
InputTensor: (Sizes:{3, 3}, DataType:FLOAT32)
[[1, 2, 3],
[3, 0, 4],
[2, 5, 2]]
示例 1。 将 argmax 应用于列
AxisCount: 1
Axes: {0}
AxisDirection: DML_AXIS_DIRECTION_INCREASING
OutputTensor: (Sizes:{1, 3}, DataType:UINT32)
[[1, // argmax({1, 3, 2})
2, // argmax({2, 0, 5})
1]] // argmax({3, 4, 2})
示例 2。 将 argmax 应用于行
AxisCount: 1
Axes: {1}
AxisDirection: DML_AXIS_DIRECTION_INCREASING
OutputTensor: (Sizes:{3, 1}, DataType:UINT32)
[[2], // argmax({1, 2, 3})
[2], // argmax({3, 0, 4})
[1]] // argmax({2, 5, 2})
示例 3。 将 argmax 应用于整个张量) (所有轴
AxisCount: 2
Axes: {0, 1}
AxisDirection: DML_AXIS_DIRECTION_INCREASING
OutputTensor: (Sizes:{1, 1}, DataType:UINT32)
[[7]] // argmax({1, 2, 3, 3, 0, 4, 2, 5, 2})
备注
输出张量大小必须与输入张量大小相同,但缩小的轴必须为 1。
DML_AXIS_DIRECTION_INCREASINGAxisDirection 时,此 API 等效于使用 DML_REDUCE_FUNCTION_ARGMAXDML_REDUCE_OPERATOR_DESC。
此功能的子集通过 DML_REDUCE_OPERATOR_DESC 运算符公开,在早期的 DirectML 功能级别上受支持。
可用性
此运算符是在 中引入的 DML_FEATURE_LEVEL_3_0
。
张量约束
InputTensor 和 OutputTensor 必须具有相同的 DimensionCount。
Tensor 支持
DML_FEATURE_LEVEL_4_1 及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
InputTensor | 输入 | 1 到 8 | FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8 |
OutputTensor | 输出 | 1 到 8 | INT64、INT32、UINT64、UINT32 |
DML_FEATURE_LEVEL_3_0及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
InputTensor | 输入 | 1 到 8 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
OutputTensor | 输出 | 1 到 8 | INT64、INT32、UINT64、UINT32 |
要求
最低受支持的客户端 | Windows 10内部版本 20348 |
最低受支持的服务器 | Windows 10内部版本 20348 |
标头 | directml.h |