directml.h) (DML_GATHER_OPERATOR_DESC 结构
使用 IndicesTensor 重新映射索引,沿轴从输入张量中收集元素。 此运算符执行以下伪代码,其中“...”表示一系列坐标,其确切行为由轴和索引维度计数确定:
output[...] = input[..., indices[...], ...]
语法
struct DML_GATHER_OPERATOR_DESC {
const DML_TENSOR_DESC *InputTensor;
const DML_TENSOR_DESC *IndicesTensor;
const DML_TENSOR_DESC *OutputTensor;
UINT Axis;
UINT IndexDimensions;
};
成员
InputTensor
类型: const DML_TENSOR_DESC*
要从中读取的张量。
IndicesTensor
类型: const DML_TENSOR_DESC*
包含索引的张量。 此张量维度的 DimensionCount 必须与 InputTensor.DimensionCount 匹配。
从 DML_FEATURE_LEVEL_3_0
开始,将此张量一起使用带符号整数类型时,此运算符支持负索引值。 负索引被解释为相对于轴维度的末尾。 例如,索引为 -1 表示沿该维度的最后一个元素。
无效索引将产生不正确的输出,但不会失败,并且所有读取都将安全地固定在输入张量内存中。
OutputTensor
类型: const DML_TENSOR_DESC*
将结果写入到的张量。 此张量维度的 DimensionCount 和 DataType 必须与 InputTensor.DimensionCount 匹配。 预期的 OutputTensor.Sizes 是在当前轴上拆分的 InputTensor.Sizes 前导段和尾部段的串联,其中插入了 IndicesTensor.Sizes。
OutputTensor.Sizes = {
InputTensor.Sizes[0..Axis],
IndicesTensor.Sizes[(IndicesTensor.DimensionCount - IndexDimensions) .. IndicesTensor.DimensionCount],
InputTensor.Sizes[(Axis+1) .. InputTensor.DimensionCount]
}
维度是右对齐的,以便裁剪输入大小中的任何前导 1 值,否则会溢出输出 DimensionCount。
此张量中相关维度的数目取决于 IndexDimensions 和 InputTensor的原始秩。 原始秩是使用前导维度填充之前维度的数目。 输出中相关维度的数目可以通过 InputTensor + IndexDimensions - 1 的原始排名来计算。 此值必须小于或等于 OutputTensor 的 DimensionCount。
Axis
类型: UINT
要收集的 InputTensor 的轴维度,范围为 [0, *InputTensor.DimensionCount*)
。
IndexDimensions
类型: UINT
忽略任何不相关的前导维度后的实际索引维度 IndicesTensor
数,范围 [0, IndicesTensor.DimensionCount
) 。 例如,给定 IndicesTensor.Sizes
= { 1, 1, 4, 6 }
和 IndexDimensions
= 3,实际有意义的索引为 。{ 1, 4, 6 }
示例
示例 1。 1D 重新映射
Axis: 0
IndexDimensions: 1
InputTensor: (Sizes:{4}, DataType:FLOAT32)
[11,12,13,14]
IndicesTensor: (Sizes:{5}, DataType:UINT32)
[3,1,3,0,2]
// output[x] = input[indices[x]]
OutputTensor: (Sizes:{5}, DataType:FLOAT32)
[14,12,14,11,13]
示例 2。 2D 输出、1D 索引、轴 0、串联行
Axis: 0
IndexDimensions: 1
InputTensor: (Sizes:{3,2}, DataType:FLOAT32)
[[1,2], // row 0
[3,4], // row 1
[5,6]] // row 2
IndicesTensor: (Sizes:{1, 4}, DataType:UINT32)
[[0,
1,
1,
2]]
// output[y, x] = input[indices[y], x]
OutputTensor: (Sizes:{4,2}, DataType:FLOAT32)
[[1,2], // input row 0
[3,4], // input row 1
[3,4], // input row 1
[5,6]] // input row 2
示例 3。 2D,轴 1,交换列
Axis: 1
IndexDimensions: 2
InputTensor: (Sizes:{3,2}, DataType:FLOAT32)
[[1,2],
[3,4],
[5,6]]
IndicesTensor: (Sizes:{1, 2}, DataType:UINT32)
[[1,0]]
// output[y, x] = input[y, indices[x]]
OutputTensor: (Sizes:{3,2}, DataType:FLOAT32)
[[2,1],
[4,3],
[6,5]]
示例 4. 2D、轴 1、嵌套索引
Axis: 2
IndexDimensions: 2
InputTensor: (Sizes:{1, 3,3}, DataType:FLOAT32)
[ [[1,2,3],
[4,5,6],
[7,8,9]] ]
IndicesTensor: (Sizes:{1, 1,2}, DataType:UINT32)
[ [[0,2]] ]
// output[z, y, x] = input[z, indices[y, x]]
OutputTensor: (Sizes:{3,1,2}, DataType:FLOAT32)
[[[1,3]],
[[4,6]],
[[7,9]]]
示例 5. 2D、轴 0、嵌套索引
Axis: 1
IndexDimensions: 2
InputTensor: (Sizes:{1, 3,2}, DataType:FLOAT32)
[ [[1,2],
[3,4],
[5,6]] ]
IndicesTensor: (Sizes:{1, 2,2}, DataType:UINT32)
[ [[0,1],
[1,2]] ]
// output[z, y, x] = input[indices[z, y], x]
OutputTensor: (Sizes:{2,2,2}, DataType:FLOAT32)
[[[1,2], [3,4]],
[[3,4], [5,6]]]
可用性
此运算符是在 中 DML_FEATURE_LEVEL_1_0
引入的。
张量约束
IndicesTensor
、 InputTensor 和 OutputTensor 必须具有相同的 DimensionCount。- InputTensor 和 OutputTensor 必须具有相同的 数据类型。
张量支持
DML_FEATURE_LEVEL_4_1 及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
InputTensor | 输入 | 1 到 8 | FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8 |
IndicesTensor | 输入 | 1 到 8 | INT64、INT32、UINT64、UINT32 |
OutputTensor | 输出 | 1 到 8 | FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8 |
DML_FEATURE_LEVEL_3_0 及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
InputTensor | 输入 | 1 到 8 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
IndicesTensor | 输入 | 1 到 8 | INT64、INT32、UINT64、UINT32 |
OutputTensor | 输出 | 1 到 8 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
DML_FEATURE_LEVEL_2_1 及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
InputTensor | 输入 | 4 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
IndicesTensor | 输入 | 4 | UINT32 |
OutputTensor | 输出 | 4 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
DML_FEATURE_LEVEL_1_0 及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
InputTensor | 输入 | 4 | FLOAT32、FLOAT16 |
IndicesTensor | 输入 | 4 | UINT32 |
OutputTensor | 输出 | 4 | FLOAT32、FLOAT16 |
要求
标头 | directml.h |