(directml.h) DML_SPLIT_OPERATOR_DESC 结构

沿轴将输入张量拆分为多个输出张量。

除拆分轴外,所有输入和输出张量的大小都必须相同。 拆分轴中输入张量的大小决定了可能的拆分。 例如,如果输入张量拆分轴的大小为 3,则存在以下可能的拆分:1+1+1 (3 输出) ,1+2 (2 输出) ,2+1 (2 输出) ,或 3 (1 输出,这只是输入张量) 的副本。 输出张量拆分轴大小必须与输入张量的拆分轴大小相加。 下面的伪代码说明了这些约束。

splitSize = 0;

for (i = 0; i < OutputCount; i++) {
    assert(outputTensors[i]->DimensionCount == inputTensor->DimensionCount);
    for (dim = 0; dim < inputTensor->DimensionCount; dim++) {
        if (dim == Axis) { splitSize += outputTensors[i]->Sizes[dim]; }
        else { assert(outputTensors[i]->Sizes[dim] == inputTensor->Sizes[dim]); }
    }
}

assert(splitSize == inputTensor->Sizes[Axis]);

拆分为单个输出张量只会生成输入张量的副本。

此运算符是 DML_JOIN_OPERATOR_DESC的反数。

语法

struct DML_SPLIT_OPERATOR_DESC {
  const DML_TENSOR_DESC *InputTensor;
  UINT                  OutputCount;
  const DML_TENSOR_DESC *OutputTensors;
  UINT                  Axis;
};

成员

InputTensor

类型: const DML_TENSOR_DESC*

要拆分为多个输出张量的张量。

OutputCount

类型: UINT

此字段确定 OutputTensors 数组的大小。 此值必须大于 0。

OutputTensors

类型: const DML_TENSOR_DESC*

包含从输入张量拆分的张量的说明的数组。 输出大小必须与输入张量的大小相同,拆分轴除外。

Axis

类型: UINT

要拆分的输入张量维度的索引。 除此轴之外,所有输入和输出张量在所有维度中必须具有相同的大小。 此值必须位于 范围 [0, InputTensor.DimensionCount - 1]中。

示例

以下示例使用相同的输入张量。

InputTensor: (Sizes:{1, 1, 6, 2}, DataType:FLOAT32)
[[[[1, 2],
   [3, 4],
   [5, 6],
   [7, 8],
   [9, 10],
   [11, 12]]]]

示例 1。 拆分轴 2

OutputCount: 3
Axis: 2

OutputTensors[0]: (Sizes:{1, 1, 2, 2}, DataType:FLOAT32)
[[[[1, 2],
   [3, 4]]]]

OutputTensors[1]: (Sizes:{1, 1, 1, 2}, DataType:FLOAT32)
[[[[5, 6]]]]

OutputTensors[2]: (Sizes:{1, 1, 3, 2}, DataType:FLOAT32)
[[[[7, 8],
   [9, 10],
   [11, 12]]]]

示例 2。 拆分轴 3

OutputCount: 2
Axis: 3

OutputTensors[0]: (Sizes:{1, 1, 6, 1}, DataType:FLOAT32)
[[[[1],
   [3],
   [5],
   [7],
   [9],
   [11]]]]

OutputTensors[1]: (Sizes:{1, 1, 6, 1}, DataType:FLOAT32)
[[[[2],
   [4],
   [6],
   [8],
   [10],
   [12]]]]

可用性

此运算符是在 中引入的 DML_FEATURE_LEVEL_1_0

张量约束

InputTensorOutputTensors 必须具有相同的 DataTypeDimensionCount

Tensor 支持

DML_FEATURE_LEVEL_4_1 及更高版本

种类 支持的维度计数 支持的数据类型
InputTensor 输入 1 到 8 FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8
OutputTensors 输出数组 1 到 8 FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8

DML_FEATURE_LEVEL_3_0及更高版本

种类 支持的维度计数 支持的数据类型
InputTensor 输入 1 到 8 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8
OutputTensors 输出数组 1 到 8 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8

DML_FEATURE_LEVEL_2_1及更高版本

种类 支持的维度计数 支持的数据类型
InputTensor 输入 4 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8
OutputTensors 输出数组 4 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8

DML_FEATURE_LEVEL_1_0 及更高版本

种类 支持的维度计数 支持的数据类型
InputTensor 输入 4 FLOAT32、FLOAT16、INT32、INT16、UINT32、UINT16
OutputTensors 输出数组 4 FLOAT32、FLOAT16、INT32、INT16、UINT32、UINT16

要求

要求
Header directml.h