te.lang.cce.broadcast(var, shape, output_dtype=None)
把var broadcast为大小为shape的tensor,结果的数据类型由output_dtype指定,var可以是标量,或者是一个tensor,要求var的shape与第二个参数shape的长度一致,每个维度的大小要么与shape相等,要么为1,为1的维度会被broadcast到与shape一致。例如var的维度为(2,1,64),shape为(2,128,64),运算结果var的维度变为(2,128,64)。支持的类型:float16、float32、int32。
该接口在broadcast_compute.py中定义。
参数说明
- var:需要broadcast的数据,标量或者tensor类型。
- shape:目标shape,进行broadcast操作的目标shape。
- output_dtype:输出数据类型,默认值var.dtype。
返回值
res_tensor:由var扩展后得到的tensor,shape为参数指定的shape,数据类型为output_dtype。
调用示例
outshape = (1024,1024) shape = (1024,1) input_dtype = "float16" data = tvm.placeholder(shape, name="data", dtype=input_dtype) res = te.lang.cce.broadcast(data, outshape)