DML_DIAGONAL_MATRIX_OPERATOR_DESC结构 (directml.h)
生成一个类似标识的矩阵,该矩阵在主对角线上 (或其他显式值) ,其他位置为零。 对角线值可以通过 偏移) (移位,其中 OutputTensor[i, i + Offset]
= Value,这意味着,大于零的 Offset 参数将所有值都向右移动,小于零会将这些值向左移动。 此生成器运算符适用于模型,以避免存储较大的常量张量。 最后两个之前的任何前导维度都被视为批计数,这意味着张量被视为 2D 矩阵的堆栈。
此运算符执行以下伪代码。
for each coordinate in OutputTensor
OutputTensor[coordinate] = if (coordinate.y + Offset == coordinate.x) then Value else 0
endfor
语法
struct DML_DIAGONAL_MATRIX_OPERATOR_DESC {
const DML_TENSOR_DESC *OutputTensor;
INT Offset;
FLOAT Value;
};
成员
OutputTensor
类型: const DML_TENSOR_DESC*
将结果写入到的张量。 维度为 { Batch1, Batch2, OutputHeight, OutputWidth }
。 高度和宽度不需要为平方。
Offset
类型: INT
一个偏移量,用于移动 Value 的对角线,正偏移将写入值向右/向上移动 (将输出视为矩阵,左上角为 0,0) ,向左/下偏移量为负。
Value
类型: FLOAT
要沿 2D 对角线填充的值。 标准值为 1.0。 请注意,如果张量 DataType 未 DML_TENSOR_DATA_TYPE_FLOAT16 或 DML_TENSOR_DATA_TYPE_FLOAT32,则值可能会被截断 (例如,10.6 将变为 10) 。
示例
默认标识矩阵:
Offset: 0
Value: 1.0
OutputTensor: (Sizes:{1,1,3,3}, DataType:FLOAT32)
[[[[1, 0, 0],
[0, 1, 0],
[0, 0, 1]]]]
向右/上移:
Offset: 1
Value: 1.0
OutputTensor: (Sizes:{1,1,3,3}, DataType:FLOAT32)
[[[[ 0, 1, 0],
[ 0, 0, 1],
[ 0, 0, 0]]]]
左移/下移:
Offset: -1
Value: 1.0
OutputTensor: (Sizes:{1,1,3,2}, DataType:FLOAT32)
[[[[0, 0],
[1, 0],
[0, 1]]]]
将对角线的对角线移至全部变为零:
Offset: -3
Value: 1.0
OutputTensor: (Sizes:{1,1,3,2}, DataType:FLOAT32)
[[[[0, 0],
[0, 0],
[0, 0]]]]
注解
可用性
此运算符是在 中 DML_FEATURE_LEVEL_2_0
引入的。
张量支持
DML_FEATURE_LEVEL_5_1 及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
OutputTensor | 输出 | 2 到 4 | FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8 |
DML_FEATURE_LEVEL_4_0 及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
OutputTensor | 输出 | 2 到 4 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
DML_FEATURE_LEVEL_2_1 及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
OutputTensor | 输出 | 4 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
DML_FEATURE_LEVEL_2_0及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
OutputTensor | 输出 | 4 | FLOAT32、FLOAT16 |
要求
要求 | 值 |
---|---|
最低受支持的客户端 | Windows 10,版本 2004 (10.0;内部版本 19041) |
最低受支持的服务器 | Windows Server 版本 2004 (10.0;内部版本 19041) |
标头 | directml.h |