文档首页> Atlas 500应用> TE API参考> compute接口> te.lang.cce.compute_four2five(input, raw_shape_4D)
更新时间:2021-03-18 GMT+08:00
分享

te.lang.cce.compute_four2five(input, raw_shape_4D)

把给定4-D “NCHW”数据格式转换为5-D “NC1HWC0”数据格式。支持的数据类型:float16。

该接口在dim_conv.py中定义。

参数说明

  • input:输入tensor,4-D格式(N, C, H, W),tvm.tensor类型。
  • raw_shape_4D:输入tensor的维度。

返回值:

res_tensor:转换为5-D格式(N, C1, H, W, C0)后的tensor,tvm.tensor类型

调用示例

import tvm
import te.lang.cce
raw_shape = (N,C,H,W)
in_dtype = "float16"
input = tvm.placeholder(raw_shape, name='input', dtype=in_dtype)
res = te.lang.cce.compute_four2five(input, raw_shape)
# res.shape = (N,(C+15)//16,H,W,16)
分享:

    相关文档

    相关产品