DML_SCATTER_ND_OPERATOR_DESC 结构 (directml.h)
将整个输入张量复制到输出,然后使用更新张量中的相应值覆盖所选索引。 此运算符执行以下伪代码,其中“...”表示一系列坐标,其确切行为由轴和索引大小决定。
output = input
output[indices[...]] = updates[...]
如果两个输出元素索引重叠 (这) 无效,则无法保证最后一次写入中胜出。
语法
struct DML_SCATTER_ND_OPERATOR_DESC {
const DML_TENSOR_DESC *InputTensor;
const DML_TENSOR_DESC *IndicesTensor;
const DML_TENSOR_DESC *UpdatesTensor;
const DML_TENSOR_DESC *OutputTensor;
UINT InputDimensionCount;
UINT IndicesDimensionCount;
};
成员
InputTensor
类型: const DML_TENSOR_DESC*
要从中读取的张量。
IndicesTensor
类型: const DML_TENSOR_DESC*
包含索引的张量。 此张量维度的 DimensionCount 必须与 InputTensor.DimensionCount 匹配。 IndexesTensor 的最后一个维度实际上是每个索引元组的坐标数,并且不能超过 InputTensor.DimensionCount。 例如,索引大小 {1,4,5,2}
为 IndexesDimensionCount = 3 的张量意味着索引到 InputTensor 的 2 值坐标元组的 4x5 数组。
从 DML_FEATURE_LEVEL_3_0
开始,将此张量一起使用带符号整数类型时,此运算符支持负索引值。 负索引被解释为相对于相应维度的末尾。 例如,索引为 -1 表示沿该维度的最后一个元素。
UpdatesTensor
类型: const DML_TENSOR_DESC*
包含新值的张量,用于替换相应索引处的现有输入值。 此张量维度的 DimensionCount 必须与 InputTensor.DimensionCount 匹配。 预期的 UpdatesTensor.Sizes 是用于生成以下内容的 IndicesTensor.Sizes 前导段和 InputTensor.Sizes 尾随段的串联。
indexTupleSize = IndicesTensor.Sizes[IndicesTensor.DimensionCount - 1]
UpdatesTensor.Sizes = [
1...,
IndicesTensor.Sizes[(IndicesTensor.DimensionCount - IndicesDimensionCount) .. (IndicesTensor.DimensionCount - 1)],
InputTensor.Sizes[(InputTensor.DimensionCount - indexTupleSize) .. InputTensor.DimensionCount]
]
维度是右对齐的,如果需要满足 UpdatesTensor.DimensionCount,前面追加前导 1 个值。
下面是一个示例。
InputTensor.Sizes = [3,4,5,6,7]
InputDimensionCount = 5
IndicesTensor.Sizes = [1,1, 1,2,3]
IndicesDimensionCount = 3 // can be thought of as a [1,2] array of 3-coordinate tuples
// The [1,2] comes from the indices tensor (ignoring last dimension, which is the tuple size),
// and the [6,7] comes from input tensor, ignoring the first 3 dimensions
// since the index tuples are 3 elements (from the indices tensor last dimension).
UpdatesTensor.Sizes = [1, 1,2,6,7]
OutputTensor
类型: const DML_TENSOR_DESC*
将结果写入到的张量。 此张量 的大小 和 数据类型 必须与 InputTensor.Sizes 匹配。
InputDimensionCount
类型: UINT
忽略任何不相关的前导维度(范围 [1, InputTensor.DimensionCount) )后,InputTensor 中的实际输入维度数。 例如,假设 InputTensor.Sizes = {1,1,4,6} 和 InputDimensionCount = 3,实际有意义的索引为 {1,4,6}。
IndicesDimensionCount
类型: UINT
在忽略任何不相关的前导维度后, IndexesTensor 中的实际索引维度的数目,范围 [1, IndicesTensor.DimensionCount) 。 例如,假设 IndexesTensor.Sizes = {1,1,4,6} 和 IndexesDimensionCount = 3,实际有意义的索引为 {1,4,6}。
示例
InputTensor: (Sizes:{8}, DataType:FLOAT32)
[1, 2, 3, 4, 5, 6, 7, 8]
IndicesTensor: (Sizes:{4,1}, DataType:FLOAT32)
[[4], [3], [1], [7]]
UpdatesTensor: (Sizes:{4}, DataType:FLOAT32)
[9, 10, 11, 12]
// output = input
// output[indices[x, 0]] = updates[x]
OutputTensor: (Sizes:{8}, DataType:FLOAT32)
[1, 11, 3, 10, 9, 6, 7, 12]
可用性
此运算符是在 中 DML_FEATURE_LEVEL_2_1
引入的。
张量约束
- IndexesTensor、 InputTensor、 OutputTensor 和 UpdatesTensor 必须具有相同的 DimensionCount。
- InputTensor、 OutputTensor 和 UpdatesTensor 必须具有相同 的数据类型。
张量支持
DML_FEATURE_LEVEL_4_1 及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
InputTensor | 输入 | 1 到 8 | FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8 |
IndicesTensor | 输入 | 1 到 8 | INT64、INT32、UINT64、UINT32 |
UpdatesTensor | 输入 | 1 到 8 | FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8 |
OutputTensor | 输出 | 1 到 8 | FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8 |
DML_FEATURE_LEVEL_3_0 及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
InputTensor | 输入 | 1 到 8 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
IndicesTensor | 输入 | 1 到 8 | INT64、INT32、UINT64、UINT32 |
UpdatesTensor | 输入 | 1 到 8 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
OutputTensor | 输出 | 1 到 8 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
DML_FEATURE_LEVEL_2_1 及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
InputTensor | 输入 | 4 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
IndicesTensor | 输入 | 4 | UINT32 |
UpdatesTensor | 输入 | 4 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
OutputTensor | 输出 | 4 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
要求
要求 | 值 |
---|---|
最低受支持的客户端 | Windows 10内部版本 20348 |
最低受支持的服务器 | Windows 10内部版本 20348 |
标头 | directml.h |