文档首页> Atlas 500应用> TE API参考> compute接口> te.lang.cce.matmul(tensor_a, tensor_b, trans_a=False, trans_b=False, alpha_num=1.0, beta_num=0.0, tensor_c=None)
更新时间:2021-03-18 GMT+08:00
分享

te.lang.cce.matmul(tensor_a, tensor_b, trans_a=False, trans_b=False, alpha_num=1.0, beta_num=0.0, tensor_c=None)

矩阵乘,计算:tensor_c=alpha_num * trans_a(tensor_a) * trans_b(tensor_b) + beta_num * tensor_c。

tensor_a与tensor_b的shape后两维(经过对应转置)需要满足矩阵乘(M, K) * (K, N) = (M, N),且batch数只支持1。tensor_a数据排布要满足L0A的分形结构,tensor_b要满足L0B的分形结构,mini形态下,数据类型只支持float16。

该接口在mmad_compute.py中定义。

参数说明

  • tensor_a:A矩阵,tvm.tensor类型。
  • tensor_b:B矩阵,tvm.tensor类型
  • trans_a:A矩阵是否转置,bool类型。
  • trans_b:B矩阵是否转置,bool类型
  • alpha_num : A*B矩阵系数,只支持1.0
  • beta_num : C矩阵系数,只支持0.0
  • tensor_c : C矩阵,tvm.tensor类型,由于beta_num只支持0.0,此参数为预留扩展接口

返回值

tensor_c:根据关系运算计算后得到的tensor,tvm.tensor类型。

调用示例

import tvm
import te.lang.cce
a_shape = (1024, 256)
b_shape = (256, 512)
a_fractal_shape = (a_shape[0] // 16, a_shape[1] // 16, 16, 16)
b_fractal_shape = (b_shape[0] // 16, b_shape[1] // 16, 16, 16)
in_dtype = "float16"
tensor_a = tvm.placeholder(a_fractal_shape, name='tensor_a', dtype=in_dtype)
tensor_b = tvm.placeholder(a_fractal_shape, name='tensor_b', dtype=in_dtype)
res = te.lang.cce.matmul(tensor_a, tensor_b, False, False)
分享:

    相关文档

    相关产品