文档首页> Atlas 500应用> TE API参考> compute接口> te.lang.cce.broadcast(var, shape, output_dtype=None)
更新时间:2021-03-18 GMT+08:00
分享

te.lang.cce.broadcast(var, shape, output_dtype=None)

把var broadcast为大小为shape的tensor,结果的数据类型由output_dtype指定,var可以是标量,或者是一个tensor,要求var的shape与第二个参数shape的长度一致,每个维度的大小要么与shape相等,要么为1,为1的维度会被broadcast到与shape一致。例如var的维度为(2,1,64),shape为(2,128,64),运算结果var的维度变为(2,128,64)。支持的类型:float16、float32、int32。

该接口在broadcast_compute.py中定义。

参数说明

  • var:需要broadcast的数据,标量或者tensor类型。
  • shape:目标shape,进行broadcast操作的目标shape。
  • output_dtype:输出数据类型,默认值var.dtype。

返回值

res_tensor:由var扩展后得到的tensor,shape为参数指定的shape,数据类型为output_dtype。

调用示例

outshape = (1024,1024)
shape = (1024,1)
input_dtype = "float16"
data = tvm.placeholder(shape, name="data", dtype=input_dtype)
res = te.lang.cce.broadcast(data, outshape)
分享:

    相关文档

    相关产品