DML_DIAGONAL_MATRIX_OPERATOR_DESC结构 (directml.h)

生成一个类似标识的矩阵,该矩阵在主对角线上 (或其他显式值) ,其他位置为零。 对角线值可以通过 偏移) (移位,其中 OutputTensor[i, i + Offset] = Value,这意味着,大于零的 Offset 参数将所有值都向右移动,小于零会将这些值向左移动。 此生成器运算符适用于模型,以避免存储较大的常量张量。 最后两个之前的任何前导维度都被视为批计数,这意味着张量被视为 2D 矩阵的堆栈。

此运算符执行以下伪代码。

for each coordinate in OutputTensor
    OutputTensor[coordinate] = if (coordinate.y + Offset == coordinate.x) then Value else 0
endfor

语法

struct DML_DIAGONAL_MATRIX_OPERATOR_DESC {
  const DML_TENSOR_DESC *OutputTensor;
  INT                   Offset;
  FLOAT                 Value;
};

成员

OutputTensor

类型: const DML_TENSOR_DESC*

将结果写入到的张量。 维度为 { Batch1, Batch2, OutputHeight, OutputWidth }。 高度和宽度不需要为平方。

Offset

类型: INT

一个偏移量,用于移动 Value 的对角线,正偏移将写入值向右/向上移动 (将输出视为矩阵,左上角为 0,0) ,向左/下偏移量为负。

Value

类型: FLOAT

要沿 2D 对角线填充的值。 标准值为 1.0。 请注意,如果张量 DataTypeDML_TENSOR_DATA_TYPE_FLOAT16DML_TENSOR_DATA_TYPE_FLOAT32,则值可能会被截断 (例如,10.6 将变为 10) 。

示例

默认标识矩阵:

Offset: 0
Value: 1.0
OutputTensor: (Sizes:{1,1,3,3}, DataType:FLOAT32)
    [[[[1, 0, 0],
       [0, 1, 0],
       [0, 0, 1]]]]

向右/上移:

Offset: 1
Value: 1.0
OutputTensor: (Sizes:{1,1,3,3}, DataType:FLOAT32)
    [[[[ 0, 1, 0],
       [ 0, 0, 1],
       [ 0, 0, 0]]]]

左移/下移:

Offset: -1
Value: 1.0
OutputTensor: (Sizes:{1,1,3,2}, DataType:FLOAT32)
    [[[[0, 0],
       [1, 0],
       [0, 1]]]]

将对角线的对角线移至全部变为零:

Offset: -3
Value: 1.0
OutputTensor: (Sizes:{1,1,3,2}, DataType:FLOAT32)
    [[[[0, 0],
       [0, 0],
       [0, 0]]]]

注解

可用性

此运算符是在 中 DML_FEATURE_LEVEL_2_0引入的。

张量支持

DML_FEATURE_LEVEL_5_1 及更高版本

种类 支持的维度计数 支持的数据类型
OutputTensor 输出 2 到 4 FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8

DML_FEATURE_LEVEL_4_0 及更高版本

种类 支持的维度计数 支持的数据类型
OutputTensor 输出 2 到 4 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8

DML_FEATURE_LEVEL_2_1 及更高版本

种类 支持的维度计数 支持的数据类型
OutputTensor 输出 4 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8

DML_FEATURE_LEVEL_2_0及更高版本

种类 支持的维度计数 支持的数据类型
OutputTensor 输出 4 FLOAT32、FLOAT16

要求

要求
最低受支持的客户端 Windows 10,版本 2004 (10.0;内部版本 19041)
最低受支持的服务器 Windows Server 版本 2004 (10.0;内部版本 19041)
标头 directml.h