directml.h) (DML_SLICE_GRAD_OPERATOR_DESC 结构
计算切片的反向传播渐变 (请参阅 DML_SLICE1_OPERATOR_DESC) 。
回想一下, DML_SLICE1_OPERATOR_DESC 提取输入张量子区域。 如果 InputGradientTensor 的大小与等效DML_SLICE1_OPERATOR_DESC的输出大小相同,则此运算符将生成大小与DML_SLICE1_OPERATOR_DESC输入相同的 OutputGradientTensor。 切片元素传播到输出,所有其他元素设置为 0。
例如,考虑从张量中提取以下元素的 DML_SLICE1_OPERATOR_DESC :
InputTensor OutputTensor
[[a, b, c, d],
[e, f, g, h], Slice [[a, c],
[i, j, k, l], --> [i, k]]
[m, n, o, p]]
如果提供与上述示例中相同的 InputWindowOffsets/大小/步幅 ,则此运算符将执行以下转换。
InputGradientTensor OutputGradientTensor
[[a, 0, c, 0],
[[a, c], SliceGrad [0, 0, 0, 0],
[i, k]] --> [i, 0, k, 0],
[0, 0, 0, 0]]
语法
struct DML_SLICE_GRAD_OPERATOR_DESC {
const DML_TENSOR_DESC *InputGradientTensor;
const DML_TENSOR_DESC *OutputGradientTensor;
UINT DimensionCount;
const UINT *InputWindowOffsets;
const UINT *InputWindowSizes;
const INT *InputWindowStrides;
};
成员
InputGradientTensor
类型: const DML_TENSOR_DESC*
传入的渐变张量。 这通常是从上一层的反向传播的输出中获取的。 通常,此张量的大小与前向传递中相应DML_SLICE1_OPERATOR_DESC的输出大小相同。
OutputGradientTensor
类型: const DML_TENSOR_DESC*
包含反向传播渐变的输出张量。 通常,此张量的大小与前向传递中相应DML_SLICE1_OPERATOR_DESC的输入大小相同。
DimensionCount
类型: UINT
InputWindowOffsets、InputWindowSizes 和 InputWindowStrides 数组中的元素数。 此值必须等于 InputGradientTensor 和 OutputGradientTensor 中提供的 DimensionCount。
InputWindowOffsets
类型:_Field_size_ (DimensionCount) const UINT*
请参阅 DML_SLICE1_OPERATOR_DESC 中的 InputWindowOffsets。
InputWindowSizes
类型:_Field_size_ (DimensionCount) const UINT*
请参阅 DML_SLICE1_OPERATOR_DESC 中的 InputWindowSizes。
InputWindowStrides
类型:_Field_size_ (DimensionCount) const UINT*
请参阅 DML_SLICE1_OPERATOR_DESC 中的 InputWindowStrides。
请注意,与 DML_SLICE1_OPERATOR_DESC不同,此运算符需要非零步幅。 这是因为在步幅为零的情况下,对于应映射到每个输出元素的输入元素不明确,因此无法执行反向传播。 与 DML_SLICE1_OPERATOR_DESC一样,负步幅沿该轴翻转输入窗口方向。
可用性
此运算符是在 中 DML_FEATURE_LEVEL_3_0
引入的。
张量约束
InputGradientTensor 和 OutputGradientTensor 必须具有相同的 DataType 和 DimensionCount。
张量支持
DML_FEATURE_LEVEL_4_1 及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
InputGradientTensor | 输入 | 1 到 8 | FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8 |
OutputGradientTensor | 输出 | 1 到 8 | FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8 |
DML_FEATURE_LEVEL_3_1 及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
InputGradientTensor | 输入 | 1 到 8 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
OutputGradientTensor | 输出 | 1 到 8 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
DML_FEATURE_LEVEL_3_0 及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
InputGradientTensor | 输入 | 4 到 5 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
OutputGradientTensor | 输出 | 4 到 5 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
要求
要求 | 值 |
---|---|
最低受支持的客户端 | Windows 10内部版本 20348 |
最低受支持的服务器 | Windows 10内部版本 20348 |
标头 | directml.h |