DML_JOIN_OPERATOR_DESC 结构 (directml.h)
沿指定轴连接输入张量数组。
仅当输入张量在所有维度中的大小相同时,输入张量才可联接,但联接轴可能包含任何非零大小。 输出大小等于输入大小(联接轴除外),联接轴是所有输入联接轴大小的总和。 下面的伪代码演示了这些约束。
joinSize = 0;
for (i = 0; i < InputCount; i++) {
assert(inputTensors[i]->DimensionCount == outputTensor->DimensionCount);
for (dim = 0; dim < outputTensor->DimensionCount; dim++) {
if (dim == Axis) { joinSize += inputTensors[i]->Sizes[dim]; }
else { assert(inputTensors[i]->Sizes[dim] == outputTensor->Sizes[dim]); }
}
}
assert(joinSize == outputTensor->Sizes[Axis]);
联接单个输入张量只会生成输入张量的副本。
此运算符是 DML_SPLIT_OPERATOR_DESC的反函数。
语法
struct DML_JOIN_OPERATOR_DESC {
UINT InputCount;
const DML_TENSOR_DESC *InputTensors;
const DML_TENSOR_DESC *OutputTensor;
UINT Axis;
};
成员
InputCount
类型: UINT
此字段确定 InputTensors 数组的大小。 此值必须大于 0。
InputTensors
类型:_Field_size_ (InputCount) const DML_TENSOR_DESC*
一个数组,其中包含要联接到单个输出张量中的张量的说明。 此数组中的所有输入张量必须具有相同的大小,但联接轴可能具有任何非零值。
OutputTensor
类型: const DML_TENSOR_DESC*
要向其写入联接输入张量的张量。 输出大小必须与所有输入张量具有相同的大小,联接轴除外,联接轴必须等于所有输入的联接轴大小之和。
Axis
类型: UINT
要联接的输入张量维度的索引。 除此轴之外,所有输入和输出张量在所有维度中都必须具有相同的大小。 此值必须位于范围 [0, OutputTensor.DimensionCount - 1]
中。
示例
示例 1。 联接只有一个可能的轴的张量
在此示例中,张量只能沿着第四个维度 (轴 3) 联接。 无法联接任何其他轴,因为第四维中的张量大小不匹配。
InputCount: 2
Axis: 3
InputTensors[0]: (Sizes:{1, 1, 2, 3}, DataType:FLOAT32)
[[[[ 1, 2, 3],
[ 4, 5, 6]]]]
InputTensors[1]: (Sizes:{1, 1, 2, 4}, DataType:FLOAT32)
[[[[ 7, 8, 9, 10],
[11, 12, 13, 14]]]]
OutputTensor: (Sizes:{1, 1, 2, 7}, DataType:FLOAT32)
[[[[ 1, 2, 3, 7, 8, 9, 10],
[ 4, 5, 6, 11, 12, 13, 14]]]]
示例 2。 联接具有多个可能轴的张量:
以下示例使用相同的输入张量。 由于所有输入在所有维度中具有相同的大小,因此可以沿任何维度联接它们。
InputCount: 3
InputTensors[0]: (Sizes:{1, 1, 2, 2}, DataType:FLOAT32)
[[[[1, 2],
[3, 4]]]]
InputTensors[1]: (Sizes:{1, 1, 2, 2}, DataType:FLOAT32)
[[[[5, 6],
[7, 8]]]]
InputTensors[2]: (Sizes:{1, 1, 2, 2}, DataType:FLOAT32)
[[[[9, 10],
[11, 12]]]]
联接轴 1:
Axis: 1
OutputTensor: (Sizes:{1, 3, 2, 2}, DataType:FLOAT32)
[[[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]],
[[9, 10],
[11, 12]]]]
联接轴 2:
Axis: 2
OutputTensor: (Sizes:{1, 1, 6, 2}, DataType:FLOAT32)
[[[[1, 2],
[3, 4],
[5, 6],
[7, 8],
[9, 10],
[11, 12]]]]
联接轴 3:
Axis: 3
OutputTensor: (Sizes:{1, 1, 2, 6}, DataType:FLOAT32)
[[[[1, 2, 5, 6, 9, 10],
[3, 4, 7, 8, 11, 12]]]]
可用性
此运算符是在 中 DML_FEATURE_LEVEL_1_0
引入的。
张量约束
InputTensors 和 OutputTensor 必须具有相同的 DataType 和 DimensionCount。
张量支持
DML_FEATURE_LEVEL_4_1 及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
InputTensors | 输入数组 | 1 到 8 | FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8 |
OutputTensor | 输出 | 1 到 8 | FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8 |
DML_FEATURE_LEVEL_3_0 及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
InputTensors | 输入数组 | 4 到 5 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
OutputTensor | 输出 | 4 到 5 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
DML_FEATURE_LEVEL_2_1 及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
InputTensors | 输入数组 | 4 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
OutputTensor | 输出 | 4 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
DML_FEATURE_LEVEL_1_0 及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
InputTensors | 输入数组 | 4 | FLOAT32、FLOAT16、INT32、INT16、UINT32、UINT16 |
OutputTensor | 输出 | 4 | FLOAT32、FLOAT16、INT32、INT16、UINT32、UINT16 |
要求
标头 | directml.h |