文档首页 > > 训练场景> TBE自定义算子开发 > 接口参考> 原型定义接口> 算子原型InferShape接口>

BROADCAST_INFER

BROADCAST_INFER

分享
更新时间:2021/02/05 GMT+08:00

函数功能

提供公共函数宏封装,供算子开发者开发InferShape函数。该函数基于2个输入的shape,设置输出的shape。该宏只是设置shape,未设置dtype。

  • 如果2个输入的shape一致,会按输入的shape设置输出shape。
  • 如果2个输入的shape不一致,会按照broadcast的策略,取2个输入shape的并集。

    比如输入shape分别为(1,2,3,4)和(3,1,3,4),则该宏会设置算子的输出shape为(3,2,3,4)。

函数原型

BROADCAST_INFER(in1_name, in2_name, out_name)

该函数会自动调用如下函数:

graphStatus BroadCastInfer(const function<vector<int64_t>()> &get_in1_shape, const function<vector<int64_t>()> &get_in2_shape, const function<void(const vector<int64_t> &y_shape)> &set_out_shape);

约束说明

无。

参数说明

参数名

输入/输出

描述

in1_name

输入

算子第一个输入。

in2_name

输入

算子第二个输入。

out_name

输入

算子输出。

返回值

执行成功或失败。

调用示例

IMPLEMT_INFERFUNC(RightShift, RightShiftInfer) {
  DataType type = op.GetInputDesc("x").GetDataType();
  SET_OUTPUT_TYPE(op, "z", type);
  return BROADCAST_INFER("x", "y", "z")(op);
}
分享:

    相关文档

    相关产品

文档是否有解决您的问题?

提交成功!非常感谢您的反馈,我们会继续努力做到更好!
反馈提交失败,请稍后再试!

*必选

请至少选择或填写一项反馈信息

字符长度不能超过200

提交反馈 取消

如您有其它疑问,您也可以通过华为云社区问答频道来与我们联系探讨

智能客服提问云社区提问