(directml.h) DML_SPLIT_OPERATOR_DESC 结构
沿轴将输入张量拆分为多个输出张量。
除拆分轴外,所有输入和输出张量的大小都必须相同。 拆分轴中输入张量的大小决定了可能的拆分。 例如,如果输入张量拆分轴的大小为 3,则存在以下可能的拆分:1+1+1 (3 输出) ,1+2 (2 输出) ,2+1 (2 输出) ,或 3 (1 输出,这只是输入张量) 的副本。 输出张量拆分轴大小必须与输入张量的拆分轴大小相加。 下面的伪代码说明了这些约束。
splitSize = 0;
for (i = 0; i < OutputCount; i++) {
assert(outputTensors[i]->DimensionCount == inputTensor->DimensionCount);
for (dim = 0; dim < inputTensor->DimensionCount; dim++) {
if (dim == Axis) { splitSize += outputTensors[i]->Sizes[dim]; }
else { assert(outputTensors[i]->Sizes[dim] == inputTensor->Sizes[dim]); }
}
}
assert(splitSize == inputTensor->Sizes[Axis]);
拆分为单个输出张量只会生成输入张量的副本。
此运算符是 DML_JOIN_OPERATOR_DESC的反数。
语法
struct DML_SPLIT_OPERATOR_DESC {
const DML_TENSOR_DESC *InputTensor;
UINT OutputCount;
const DML_TENSOR_DESC *OutputTensors;
UINT Axis;
};
成员
InputTensor
类型: const DML_TENSOR_DESC*
要拆分为多个输出张量的张量。
OutputCount
类型: UINT
此字段确定 OutputTensors 数组的大小。 此值必须大于 0。
OutputTensors
类型: const DML_TENSOR_DESC*
包含从输入张量拆分的张量的说明的数组。 输出大小必须与输入张量的大小相同,拆分轴除外。
Axis
类型: UINT
要拆分的输入张量维度的索引。 除此轴之外,所有输入和输出张量在所有维度中必须具有相同的大小。 此值必须位于 范围 [0, InputTensor.DimensionCount - 1]
中。
示例
以下示例使用相同的输入张量。
InputTensor: (Sizes:{1, 1, 6, 2}, DataType:FLOAT32)
[[[[1, 2],
[3, 4],
[5, 6],
[7, 8],
[9, 10],
[11, 12]]]]
示例 1。 拆分轴 2
OutputCount: 3
Axis: 2
OutputTensors[0]: (Sizes:{1, 1, 2, 2}, DataType:FLOAT32)
[[[[1, 2],
[3, 4]]]]
OutputTensors[1]: (Sizes:{1, 1, 1, 2}, DataType:FLOAT32)
[[[[5, 6]]]]
OutputTensors[2]: (Sizes:{1, 1, 3, 2}, DataType:FLOAT32)
[[[[7, 8],
[9, 10],
[11, 12]]]]
示例 2。 拆分轴 3
OutputCount: 2
Axis: 3
OutputTensors[0]: (Sizes:{1, 1, 6, 1}, DataType:FLOAT32)
[[[[1],
[3],
[5],
[7],
[9],
[11]]]]
OutputTensors[1]: (Sizes:{1, 1, 6, 1}, DataType:FLOAT32)
[[[[2],
[4],
[6],
[8],
[10],
[12]]]]
可用性
此运算符是在 中引入的 DML_FEATURE_LEVEL_1_0
。
张量约束
InputTensor 和 OutputTensors 必须具有相同的 DataType 和 DimensionCount。
Tensor 支持
DML_FEATURE_LEVEL_4_1 及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
InputTensor | 输入 | 1 到 8 | FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8 |
OutputTensors | 输出数组 | 1 到 8 | FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8 |
DML_FEATURE_LEVEL_3_0及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
InputTensor | 输入 | 1 到 8 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
OutputTensors | 输出数组 | 1 到 8 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
DML_FEATURE_LEVEL_2_1及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
InputTensor | 输入 | 4 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
OutputTensors | 输出数组 | 4 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
DML_FEATURE_LEVEL_1_0 及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
InputTensor | 输入 | 4 | FLOAT32、FLOAT16、INT32、INT16、UINT32、UINT16 |
OutputTensors | 输出数组 | 4 | FLOAT32、FLOAT16、INT32、INT16、UINT32、UINT16 |
要求
要求 | 值 |
---|---|
Header | directml.h |