

DML_SLICE_GRAD_OPERATOR_DESC structure (directml.h)

Computes backpropagation gradients for Slice (see DML_SLICE1_OPERATOR_DESC).

Recall that DML_SLICE1_OPERATOR_DESC extracts a subregion of an input tensor. Given an InputGradientTensor with the same sizes as the output of an equivalent DML_SLICE1_OPERATOR_DESC, this operator produces an OutputGradientTensor with the same sizes as the input of that DML_SLICE1_OPERATOR_DESC. Each element of InputGradientTensor is written to the position in OutputGradientTensor from which the corresponding element was sliced, and all other elements are set to 0.

As an example, consider a DML_SLICE1_OPERATOR_DESC that extracts the following elements from a tensor:

InputTensor            OutputTensor
[[a, b, c, d],
 [e, f, g, h],   Slice   [[a, c],
 [i, j, k, l],    -->     [i, k]]
 [m, n, o, p]]

Given the same InputWindowOffsets, InputWindowSizes, and InputWindowStrides as in the example above, this operator performs the following transform:

InputGradientTensor       OutputGradientTensor
                             [[a, 0, c, 0],
      [[a, c],   SliceGrad    [0, 0, 0, 0],
       [i, k]]      -->       [i, 0, k, 0],
                              [0, 0, 0, 0]]
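To make the transform concrete, here is a minimal CPU sketch of the SliceGrad semantics described above. SliceGradReference is a hypothetical helper, not a DirectML API. It assumes dense row-major tensors. It also assumes that a negative stride starts traversal at the last element of the window, which is one reading of "flip the input window direction".

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical CPU reference for the SliceGrad transform; NOT a DirectML API.
// Tensors are dense and row-major (last dimension fastest).
// Assumption: a negative stride starts at the last element of the window, so
// the forward Slice reads
//   offset + index * stride                     when stride > 0
//   offset + (windowSize - 1) + index * stride  when stride < 0
std::vector<float> SliceGradReference(
    const std::vector<float>& inputGradient,       // forward Slice output sizes
    const std::vector<uint32_t>& inputGradSizes,
    const std::vector<uint32_t>& outputGradSizes,  // forward Slice input sizes
    const std::vector<uint32_t>& offsets,
    const std::vector<uint32_t>& windowSizes,
    const std::vector<int32_t>& strides)
{
    const size_t rank = outputGradSizes.size();
    size_t outputCount = 1;
    for (uint32_t s : outputGradSizes) outputCount *= s;

    // Elements that were not sliced keep a gradient of 0.
    std::vector<float> outputGradient(outputCount, 0.0f);

    std::vector<uint32_t> index(rank, 0);  // coordinate within inputGradient
    for (size_t linear = 0; linear < inputGradient.size(); ++linear) {
        // Map this coordinate back to the element the forward Slice read.
        size_t target = 0;
        for (size_t d = 0; d < rank; ++d) {
            int64_t start = strides[d] > 0
                ? int64_t(offsets[d])
                : int64_t(offsets[d]) + int64_t(windowSizes[d]) - 1;
            int64_t coord = start + int64_t(index[d]) * strides[d];
            assert(coord >= 0 && coord < int64_t(outputGradSizes[d]));
            target = target * outputGradSizes[d] + size_t(coord);
        }
        // Non-zero strides make this mapping one-to-one, so plain assignment
        // (rather than accumulation) is sufficient.
        outputGradient[target] = inputGradient[linear];

        // Advance the row-major coordinate.
        for (size_t d = rank; d-- > 0;) {
            if (++index[d] < inputGradSizes[d]) break;
            index[d] = 0;
        }
    }
    return outputGradient;
}
```

The 4x4 example corresponds to offsets {0, 0}, window sizes {4, 4}, and strides {2, 2}, although the original example does not state these values. With a, c, i, k encoded as 1, 3, 9, 11, the call SliceGradReference({1, 3, 9, 11}, {2, 2}, {4, 4}, {0, 0}, {4, 4}, {2, 2}) reproduces the OutputGradientTensor shown.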

Syntax

struct DML_SLICE_GRAD_OPERATOR_DESC {
  const DML_TENSOR_DESC *InputGradientTensor;
  const DML_TENSOR_DESC *OutputGradientTensor;
  UINT                  DimensionCount;
  const UINT            *InputWindowOffsets;
  const UINT            *InputWindowSizes;
  const INT             *InputWindowStrides;
};

Members

InputGradientTensor

Type: const DML_TENSOR_DESC*

The incoming gradient tensor. This is typically obtained from the output of backpropagation of a preceding layer. Typically, this tensor would have the same sizes as the output of the corresponding DML_SLICE1_OPERATOR_DESC in the forward pass.
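The sizes that InputGradientTensor would typically have can be derived from the window parameters. The sketch below assumes that each dimension of the forward slice yields ceil(InputWindowSizes[i] / |InputWindowStrides[i]|) elements. ExpectedInputGradientSizes is a hypothetical helper, not part of DirectML.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <vector>

// Hypothetical helper (NOT a DirectML API): the sizes InputGradientTensor
// would typically have, i.e. the output sizes of the forward slice.
// Assumes each dimension yields ceil(windowSize / |stride|) elements.
std::vector<uint32_t> ExpectedInputGradientSizes(
    const std::vector<uint32_t>& windowSizes,
    const std::vector<int32_t>& strides)
{
    std::vector<uint32_t> sizes(windowSizes.size());
    for (size_t d = 0; d < windowSizes.size(); ++d) {
        assert(strides[d] != 0);  // SliceGrad requires non-zero strides
        uint32_t step = uint32_t(std::abs(strides[d]));
        sizes[d] = (windowSizes[d] + step - 1) / step;  // integer ceiling division
    }
    return sizes;
}
```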

OutputGradientTensor

Type: const DML_TENSOR_DESC*

An output tensor containing the backpropagated gradients. Typically, this tensor would have the same sizes as the input of the corresponding DML_SLICE1_OPERATOR_DESC in the forward pass.

DimensionCount

Type: UINT

The number of elements in the InputWindowOffsets, InputWindowSizes, and InputWindowStrides arrays. This value must equal the DimensionCount provided in the InputGradientTensor and OutputGradientTensor.

InputWindowOffsets

Type: _Field_size_(DimensionCount) const UINT*

See InputWindowOffsets in DML_SLICE1_OPERATOR_DESC.

InputWindowSizes

Type: _Field_size_(DimensionCount) const UINT*

See InputWindowSizes in DML_SLICE1_OPERATOR_DESC.

InputWindowStrides

Type: _Field_size_(DimensionCount) const INT*

See InputWindowStrides in DML_SLICE1_OPERATOR_DESC.

Note that, unlike DML_SLICE1_OPERATOR_DESC, this operator requires non-zero strides. With a zero stride, multiple elements of InputGradientTensor would correspond to the same element of OutputGradientTensor. It would then be ambiguous which input element should map to each output element, so backpropagation can't be performed. As with DML_SLICE1_OPERATOR_DESC, a negative stride flips the direction in which the input window is traversed along that axis.
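The following 1-D sketch shows why a zero stride cannot be inverted while positive and negative strides can. It lists the forward-input index behind each forward-output element. ForwardSourceIndices and IsInvertible are hypothetical illustration helpers. The negative-stride rule (start at the window's last element) is an assumption.

```cpp
#include <cassert>
#include <cstdint>
#include <set>
#include <vector>

// Hypothetical 1-D illustration (NOT a DirectML API): for each element of a
// forward Slice output, return the input index it was read from.
// Assumption: a negative stride starts at the window's last element.
std::vector<int64_t> ForwardSourceIndices(uint32_t offset, uint32_t windowSize,
                                          int32_t stride, uint32_t outputSize)
{
    const int64_t start = stride >= 0 ? int64_t(offset)
                                      : int64_t(offset) + int64_t(windowSize) - 1;
    std::vector<int64_t> sources(outputSize);
    for (uint32_t i = 0; i < outputSize; ++i)
        sources[i] = start + int64_t(i) * stride;
    return sources;
}

// The backward pass can route each gradient element to a unique destination
// only if no two forward outputs share a source index.
bool IsInvertible(const std::vector<int64_t>& sources)
{
    std::set<int64_t> unique(sources.begin(), sources.end());
    return unique.size() == sources.size();
}
```

For example, with a zero stride, ForwardSourceIndices(1, 3, 0, 3) yields {1, 1, 1}, which is not invertible. With a stride of -1, ForwardSourceIndices(0, 4, -1, 4) yields {3, 2, 1, 0}, which is invertible.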

Availability

This operator was introduced in DML_FEATURE_LEVEL_3_0.

Tensor constraints

InputGradientTensor and OutputGradientTensor must have the same DataType and DimensionCount.
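The constraints on this page can be checked before the operator is created: matching DataType and DimensionCount, DimensionCount agreement, and non-zero strides. This is a minimal sketch. DataType, TensorInfo, and CheckSliceGradConstraints are hypothetical stand-ins, not the DirectML DML_TENSOR_DESC or DML_TENSOR_DATA_TYPE types.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-ins for tensor metadata; these are NOT the DirectML
// DML_TENSOR_DESC / DML_TENSOR_DATA_TYPE types.
enum class DataType { Float32, Float16, Int32 };

struct TensorInfo {
    DataType dataType;
    std::vector<uint32_t> sizes;  // dimension count == sizes.size()
};

// Returns an empty string if the documented constraints hold; otherwise a
// description of the first violated constraint.
std::string CheckSliceGradConstraints(const TensorInfo& inputGradient,
                                      const TensorInfo& outputGradient,
                                      uint32_t dimensionCount,
                                      const std::vector<int32_t>& strides)
{
    if (inputGradient.dataType != outputGradient.dataType)
        return "InputGradientTensor and OutputGradientTensor must have the same DataType";
    if (inputGradient.sizes.size() != outputGradient.sizes.size())
        return "InputGradientTensor and OutputGradientTensor must have the same DimensionCount";
    if (dimensionCount != inputGradient.sizes.size())
        return "DimensionCount must equal the tensors' DimensionCount";
    if (strides.size() != dimensionCount)
        return "InputWindowStrides must have DimensionCount elements";
    for (int32_t s : strides)
        if (s == 0) return "InputWindowStrides must be non-zero";
    return "";
}
```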

Tensor support

DML_FEATURE_LEVEL_4_1 and above

| Tensor | Kind | Supported dimension counts | Supported data types |
| --- | --- | --- | --- |
| InputGradientTensor | Input | 1 to 8 | FLOAT64, FLOAT32, FLOAT16, INT64, INT32, INT16, INT8, UINT64, UINT32, UINT16, UINT8 |
| OutputGradientTensor | Output | 1 to 8 | FLOAT64, FLOAT32, FLOAT16, INT64, INT32, INT16, INT8, UINT64, UINT32, UINT16, UINT8 |

DML_FEATURE_LEVEL_3_1 and above

| Tensor | Kind | Supported dimension counts | Supported data types |
| --- | --- | --- | --- |
| InputGradientTensor | Input | 1 to 8 | FLOAT32, FLOAT16, INT32, INT16, INT8, UINT32, UINT16, UINT8 |
| OutputGradientTensor | Output | 1 to 8 | FLOAT32, FLOAT16, INT32, INT16, INT8, UINT32, UINT16, UINT8 |

DML_FEATURE_LEVEL_3_0 and above

| Tensor | Kind | Supported dimension counts | Supported data types |
| --- | --- | --- | --- |
| InputGradientTensor | Input | 4 to 5 | FLOAT32, FLOAT16, INT32, INT16, INT8, UINT32, UINT16, UINT8 |
| OutputGradientTensor | Output | 4 to 5 | FLOAT32, FLOAT16, INT32, INT16, INT8, UINT32, UINT16, UINT8 |
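The three support tables can be turned into a lookup, for example to validate a model before it is compiled. The sketch below transcribes the tables above. GetSliceGradSupport and IsSupported are hypothetical helpers. The feature level is passed as a (major, minor) pair rather than the DirectML DML_FEATURE_LEVEL enum, and the minor version is assumed to be below 10.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical transcription of the support tables above (NOT a DirectML API).
struct SliceGradSupport {
    unsigned minDims;
    unsigned maxDims;
    std::vector<std::string> dataTypes;
};

SliceGradSupport GetSliceGradSupport(unsigned major, unsigned minor)
{
    const std::vector<std::string> base = {"FLOAT32", "FLOAT16", "INT32", "INT16",
                                           "INT8", "UINT32", "UINT16", "UINT8"};
    const unsigned level = major * 10 + minor;  // e.g. 4.1 -> 41; assumes minor < 10
    if (level >= 41) {
        std::vector<std::string> types = base;
        types.insert(types.end(), {"FLOAT64", "INT64", "UINT64"});
        return {1, 8, types};
    }
    if (level >= 31) return {1, 8, base};
    if (level >= 30) return {4, 5, base};
    return {0, 0, {}};  // operator not available before DML_FEATURE_LEVEL_3_0
}

bool IsSupported(unsigned major, unsigned minor, unsigned dims, const std::string& type)
{
    const SliceGradSupport s = GetSliceGradSupport(major, minor);
    return dims >= s.minDims && dims <= s.maxDims &&
           std::find(s.dataTypes.begin(), s.dataTypes.end(), type) != s.dataTypes.end();
}
```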

Requirements

| Requirement | Value |
| --- | --- |
| Minimum supported client | Windows 10 Build 20348 |
| Minimum supported server | Windows 10 Build 20348 |
| Header | directml.h |