DML_ELEMENT_WISE_IF_OPERATOR_DESC 结构 (directml.h)
根据 ConditionTensor 中对应元素的值,选择 ATensor 或 BTensor 中的元素。 ConditionTensor 的非零元素从 ATensor 中选择,而零值元素则从 BTensor 中选择。
f(cond, a, b) = a, if cond != 0
b, otherwise
Example:
[[1, 0], [1, 1]] // ConditionTensor
[[1, 2], [3, 4]] // ATensor
[[9, 8], [7, 6]] // BTensor
[[1, 8], [3, 4]] // Output
语法
struct DML_ELEMENT_WISE_IF_OPERATOR_DESC {
const DML_TENSOR_DESC *ConditionTensor;
const DML_TENSOR_DESC *ATensor;
const DML_TENSOR_DESC *BTensor;
const DML_TENSOR_DESC *OutputTensor;
};
成员
ConditionTensor
类型: const DML_TENSOR_DESC*
要从中读取的条件张量。
ATensor
类型: const DML_TENSOR_DESC*
包含左侧输入的张量。
BTensor
类型: const DML_TENSOR_DESC*
包含右侧输入的张量。
OutputTensor
类型: const DML_TENSOR_DESC*
要向其写入结果的输出张量。
备注
可用于在功能上构建其他聚合运算符,例如 LeakyRelu。 下面是伪代码中的插图, (不是最有效的方法,但可能) : LeakyRelu(x) = If(Less(x, 0), Mul(x, alpha), x)
。
可用性
此运算符是在 中引入的 DML_FEATURE_LEVEL_2_0
。
张量约束
- ATensor、 BTensor、 ConditionTensor 和 OutputTensor 必须具有相同的 DimensionCount 和 Size。
- ATensor、 BTensor 和 OutputTensor 必须具有相同的 数据类型。
Tensor 支持
DML_FEATURE_LEVEL_5_0及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
ConditionTensor | 输入 | 1 到 8 | UINT8 |
ATensor | 输入 | 1 到 8 | FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8 |
BTensor | 输入 | 1 到 8 | FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8 |
OutputTensor | 输出 | 1 到 8 | FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8 |
DML_FEATURE_LEVEL_3_0及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
ConditionTensor | 输入 | 1 到 8 | UINT8 |
ATensor | 输入 | 1 到 8 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
BTensor | 输入 | 1 到 8 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
OutputTensor | 输出 | 1 到 8 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
DML_FEATURE_LEVEL_2_0及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
ConditionTensor | 输入 | 4 | UINT8 |
ATensor | 输入 | 4 | FLOAT16 |
BTensor | 输入 | 4 | FLOAT16 |
OutputTensor | 输出 | 4 | FLOAT16 |
要求
最低受支持的客户端 | Windows 10版本 2004 (10.0;内部版本 19041) |
最低受支持的服务器 | Windows Server 版本 2004 (10.0;内部版本 19041) |
标头 | directml.h |