directml.h) (DML_AVERAGE_POOLING_GRAD_OPERATOR_DESC 结构
计算平均池 (DML_AVERAGE_POOLING_OPERATOR_DESC) 的反向传播梯度。
请考虑 2x2 DML_AVERAGE_POOLING_OPERATOR_DESC,不带填充,步幅为 1,可执行以下操作。
InputTensor OutputTensor
[[[[1, 2, 3], AvgPool [[[[3, 4],
[4, 5, 6], --> [6, 7]]]]
[7, 8, 9]]]]
输入张量中的每个 2x2 窗口的平均值会生成输出的一个元素, (边缘) 以外的元素读取零。 下面是 给定类似 参数DML_AVERAGE_POOLING_GRAD_OPERATOR_DESC的输出示例。
InputGradientTensor OutputGradientTensor
[[[[1, 2], AvgPoolGrad [[[[0.25, 0.75, 0.5],
[3, 4]]]] --> [ 1, 2.5, 1.5],
[0.75, 1.75, 1]]]]
请注意,OutputGradientTensor 中的值表示在原始 DML_AVERAGE_POOLING_OPERATOR_DESC 运算符期间该元素对 OutputTensor 的加权贡献。
语法
struct DML_AVERAGE_POOLING_GRAD_OPERATOR_DESC {
const DML_TENSOR_DESC *InputGradientTensor;
const DML_TENSOR_DESC *OutputGradientTensor;
UINT DimensionCount;
const UINT *Strides;
const UINT *WindowSize;
const UINT *StartPadding;
const UINT *EndPadding;
BOOL IncludePadding;
};
成员
InputGradientTensor
类型: const DML_TENSOR_DESC*
传入的渐变张量。 这通常是从上一层的反向传播的输出中获取的。 通常,此张量的大小与前向传递中相应DML_AVERAGE_POOLING_OPERATOR_DESC的输出大小相同。
OutputGradientTensor
类型: const DML_TENSOR_DESC*
包含反向传播渐变的输出张量。 通常,此张量的大小与前向传递中相应DML_AVERAGE_POOLING_OPERATOR_DESC的输入大小相同。
DimensionCount
类型: UINT
Strides、WindowSize、StartPadding 和 EndPadding 数组中的元素数。 此值必须等于空间维度计数。 如果提供了 4D 张量,则空间维度计数为 2;如果提供 5D 张量,则为 3。
Strides
类型:_Field_size_ (DimensionCount) const UINT*
请参阅DML_AVERAGE_POOLING_OPERATOR_DESC中的步幅。
WindowSize
类型:_Field_size_ (DimensionCount) const UINT*
请参阅DML_AVERAGE_POOLING_OPERATOR_DESC中的 WindowSize。
StartPadding
类型:_Field_size_ (DimensionCount) const UINT*
请参阅 DML_AVERAGE_POOLING_OPERATOR_DESC 中的 StartPadding。
EndPadding
类型:_Field_size_ (DimensionCount) const UINT*
请参阅 DML_AVERAGE_POOLING_OPERATOR_DESC 中的 EndPadding。
IncludePadding
类型: BOOL
请参阅 DML_AVERAGE_POOLING_OPERATOR_DESC 中的 IncludePadding。
可用性
此运算符是在 中 DML_FEATURE_LEVEL_3_0
引入的。
张量约束
InputGradientTensor 和 OutputGradientTensor 必须具有相同的 DataType 和 DimensionCount。
张量支持
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
InputGradientTensor | 输入 | 4 到 5 | FLOAT32、FLOAT16 |
OutputGradientTensor | 输出 | 4 到 5 | FLOAT32、FLOAT16 |
要求
要求 | 值 |
---|---|
最低受支持的客户端 | Windows 10内部版本 20348 |
最低受支持的服务器 | Windows 10内部版本 20348 |
标头 | directml.h |