te.lang.cce.matmul(tensor_a, tensor_b, trans_a=False, trans_b=False, alpha_num=1.0, beta_num=0.0, tensor_c=None)
矩阵乘,计算:tensor_c=alpha_num * trans_a(tensor_a) * trans_b(tensor_b) + beta_num * tensor_c。
tensor_a与tensor_b的shape后两维(经过对应转置)需要满足矩阵乘(M, K) * (K, N) = (M, N),且batch数只支持1。tensor_a数据排布要满足L0A的分形结构,tensor_b要满足L0B的分形结构,mini形态下,数据类型只支持float16。
该接口在mmad_compute.py中定义。
参数说明
- tensor_a:A矩阵,tvm.tensor类型。
- tensor_b:B矩阵,tvm.tensor类型
- trans_a:A矩阵是否转置,bool类型。
- trans_b:B矩阵是否转置,bool类型
- alpha_num : A*B矩阵系数,只支持1.0
- beta_num : C矩阵系数,只支持0.0
- tensor_c : C矩阵,tvm.tensor类型,由于beta_num只支持0.0,此参数为预留扩展接口
返回值
tensor_c:根据关系运算计算后得到的tensor,tvm.tensor类型。
调用示例
import tvm import te.lang.cce a_shape = (1024, 256) b_shape = (256, 512) a_fractal_shape = (a_shape[0] // 16, a_shape[1] // 16, 16, 16) b_fractal_shape = (b_shape[0] // 16, b_shape[1] // 16, 16, 16) in_dtype = "float16" tensor_a = tvm.placeholder(a_fractal_shape, name='tensor_a', dtype=in_dtype) tensor_b = tvm.placeholder(a_fractal_shape, name='tensor_b', dtype=in_dtype) res = te.lang.cce.matmul(tensor_a, tensor_b, False, False)