Updated on 2022-03-13 GMT+08:00

te.lang.cce.concat(raw_tensors, axis)

Concatenates multiple input tensors along the specified axis.

raw_tensors is a list of input tensors, all of which must have the same data type.

If raw_tensors[i].shape = [D0, D1, ... Daxis(i), ...Dn], the shape of the output after concatenation along axis is [D0, D1, ... Raxis, ...Dn].

Here, Raxis = sum(Daxis(i)), with the sum taken over all input tensors.

Apart from axis, all dimensions of the input tensors must be identical.

For example:

t1 = [[1, 2, 3], [4, 5, 6]] 
t2 = [[7, 8, 9], [10, 11, 12]] 
concat([t1, t2], 0)  # [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]] 
concat([t1, t2], 1)  # [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]  

# The shape of tensor t1 is [2, 3].
# The shape of tensor t2 is [2, 3].
concat([t1, t2], 0).shape  # [4, 3] 
concat([t1, t2], 1).shape  # [2, 6]
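
The shape rule above can be sketched in plain Python. This is only an illustration of the rule, not the code in concat_compute.py; the helper name concat_shape is hypothetical and it accepts a non-negative axis only.

```python
def concat_shape(shapes, axis):
    """Return the output shape of concatenating tensors with the given shapes.

    Hypothetical helper for illustration; axis must be non-negative here.
    """
    first = list(shapes[0])
    if not 0 <= axis < len(first):
        raise ValueError("axis out of range")
    total = 0
    for shape in shapes:
        shape = list(shape)
        if len(shape) != len(first):
            raise ValueError("all inputs must have the same number of dimensions")
        for i, (a, b) in enumerate(zip(first, shape)):
            # Every dimension except the concatenation axis must match.
            if i != axis and a != b:
                raise ValueError("dimension %d differs: %d vs %d" % (i, a, b))
        total += shape[axis]  # Raxis = sum(Daxis(i))
    out = first[:]
    out[axis] = total
    return out


print(concat_shape([[2, 3], [2, 3]], 0))  # [4, 3]
print(concat_shape([[2, 3], [2, 3]], 1))  # [2, 6]
```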

The parameter axis can also be negative. A negative axis counts from the last dimension and is equivalent to axis + len(shape), so -1 refers to the last axis.

For example:

t1 = [[[1, 2], [2, 3]], [[4, 4], [5, 3]]] 
t2 = [[[7, 4], [8, 4]], [[2, 10], [15, 11]]] 
concat([t1, t2], -1)

The output is as follows:

[[[ 1,  2,  7,  4],
  [ 2,  3,  8,  4]], 

 [[ 4,  4,  2, 10], 
  [ 5,  3, 15, 11]]]
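
The negative-axis example can be reproduced with nested Python lists. The sketch below is a reference for the semantics only, assuming the inputs are regular nested lists; normalize_axis and concat_lists are hypothetical names and are not part of te.lang.cce.

```python
def normalize_axis(axis, ndim):
    """Map an axis in [-ndim, ndim-1] to its non-negative equivalent."""
    if not -ndim <= axis <= ndim - 1:
        raise ValueError("axis out of range")
    return axis + ndim if axis < 0 else axis


def concat_lists(tensors, axis):
    """Concatenate nested lists along a non-negative axis."""
    if axis == 0:
        # Join the outermost lists end to end.
        return [item for t in tensors for item in t]
    # Otherwise recurse into each position of the outermost dimension.
    return [concat_lists([t[i] for t in tensors], axis - 1)
            for i in range(len(tensors[0]))]


t1 = [[[1, 2], [2, 3]], [[4, 4], [5, 3]]]
t2 = [[[7, 4], [8, 4]], [[2, 10], [15, 11]]]
axis = normalize_axis(-1, 3)  # -1 + 3 = 2
print(concat_lists([t1, t2], axis))
# [[[1, 2, 7, 4], [2, 3, 8, 4]], [[4, 4, 2, 10], [5, 3, 15, 11]]]
```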

The supported data types are as follows: int8, uint8, int16, int32, float16, and float32.

This API is defined in concat_compute.py.

Parameter Description

  • raw_tensors: list of input tensors (list type). Each element is a tvm.tensor, and the last dimension of each tensor shape must be 32-byte aligned.
  • axis: axis along which the concatenation is performed. The value range is [-d, d-1], where d is the number of dimensions of the tensors in raw_tensors.
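
The two parameter constraints can be checked before calling the API. The sketch below assumes that "32-byte aligned" means the size in bytes of the last dimension (element count times element size) is a multiple of 32; this reading is an assumption, not something the source states. The helper names are hypothetical.

```python
# Element sizes in bytes for the supported data types.
DTYPE_BYTES = {
    "int8": 1, "uint8": 1, "int16": 2,
    "int32": 4, "float16": 2, "float32": 4,
}


def last_dim_aligned(shape, dtype):
    """Assumed reading: bytes in the last dimension are a multiple of 32."""
    return (shape[-1] * DTYPE_BYTES[dtype]) % 32 == 0


def axis_in_range(axis, d):
    """True if axis lies in [-d, d-1], where d is the number of dimensions."""
    return -d <= axis <= d - 1


print(last_dim_aligned((64, 128), "float16"))  # True: 128 * 2 = 256 bytes
print(last_dim_aligned((64, 10), "float16"))   # False: 10 * 2 = 20 bytes
print(axis_in_range(-2, 2))                    # True
print(axis_in_range(2, 2))                     # False
```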

Return Value

res_tensor: concatenated tensor, of type tvm.tensor

Calling Example

import tvm
import te.lang.cce

# Both inputs share the same shape and data type.
shape1 = (64,128)
shape2 = (64,128)
input_dtype = "float16"
data1 = tvm.placeholder(shape1, name="data1", dtype=input_dtype)
data2 = tvm.placeholder(shape2, name="data2", dtype=input_dtype)
data = [data1, data2]
# Concatenate along axis 0: 64 + 64 = 128 rows.
res = te.lang.cce.concat(data, 0)
# res.shape = (128,128)