更新时间:2021-03-18 GMT+08:00
分享

代码结构介绍

TBE DSL方式实现的算子代码结构如下所示:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# 导入依赖的Python模块
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util

# 算子计算函数
@fusion_manager.register("add")
def add_compute(input_x, input_y, output_z, kernel_name="add"):
    """
    算子计算逻辑实现
    """
# 算子定义函数

def add(input_x, input_y, output_z, kernel_name="add"):
    """
    算子校验(可选)
    为输入tensor占位
    """

    res = add_compute(data_x, data_y, output_z, kernel_name) # 调用算子计算函数

    # 自动调度
    with tvm.target.cce(): 
        schedule = generic.auto_schedule(res)        
    # 算子编译
    config = {"print_ir": False,
              "name": kernel_name,
              "tensor_list": (data_x, data_y, res)}
    te.lang.cce.cce_build_code(schedule, config)

算子实现代码总体结构包含依赖Python模块的导入,算子定义函数实现,算子计算函数实现。

其中:

  • 算子定义函数包含算子的校验,计算函数的调用以及调度与编译。
  • 算子计算函数是对算子计算逻辑的实现。

下面详细介绍每个代码块的实现。

分享:

    相关文档

    相关产品

close