文档首页> Atlas 300应用(型号 3000)> TE API参考> compute接口> te.lang.cce.concat(raw_tensors, axis)
更新时间:2021-03-18 GMT+08:00
分享

te.lang.cce.concat(raw_tensors, axis)

在指定轴上对输入的多个Tensor进行重新连接。

输入raw_tensors为多个Tensor,数据类型相同。

如果raw_tensors[i].shape = [D0, D1, ... Daxis(i), ...Dn],沿着轴axis连接后的结果的shape为:[D0, D1, ... Raxis, ...Dn]。

其中:Raxis = sum(Daxis(i))。

对输入tensor来说,除了轴axis以外,其他轴的维度要完全一致。

例如:

t1 = [[1, 2, 3], [4, 5, 6]] 
t2 = [[7, 8, 9], [10, 11, 12]] 
concat([t1, t2], 0)  # [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]] 
concat([t1, t2], 1)  # [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]  

# tensor t1的shape为 [2, 3] 
# tensor t2的shape为 [2, 3] 
concat([t1, t2], 0).shape  # [4, 3] 
concat([t1, t2], 1).shape  # [2, 6]

参数axis也可以为负数,表示从维度的最后开始计算,表示第axis + len(shape)跟轴。

例如:

t1 = [[[1, 2], [2, 3]], [[4, 4], [5, 3]]] 
t2 = [[[7, 4], [8, 4]], [[2, 10], [15, 11]]] 
concat([t1, t2], -1)

结果为:

[[[ 1,  2,  7,  4],
  [ 2,  3,  8,  4]], 

 [[ 4,  4,  2, 10], 
  [ 5,  3, 15, 11]]]

支持的数据类型:int8、uint8、int16、int32、float16、float32。

该接口在concat_compute.py中定义。

参数说明

  • raw_tensors:tensor list,list类型,元素为tvm.tensor,且tensor shape的最后一维要32字节对齐。
  • axis:做 concat 操作的轴,取值范围:[-d,d-1],其中d是raw_tensor的维数。

返回值

res_tensor:重新连接后的tensor,tvm.tensor类型。

调用示例

import tvm
import te.lang.cce 
shape1 = (64,128) 
shape1 = (64,128) 
input_dtype = "float16"
data1 = tvm.placeholder(shape1, name="data1", dtype=input_dtype) 
data2 = tvm.placeholder(shape2, name="data1", dtype=input_dtype) 
data = [data1, data2] 
res = te.lang.cce.concat(data, 0) 
# res.shape = (128,128)
分享:

    相关文档

    相关产品