更新时间:2021-03-18 GMT+08:00
分享

算子实现

算子实现函数定义

如下所示,一个算子的实现函数中包含了输入张量的形状,数据类型,算子属性,内核名称,以及相应的编译、打印等配置。该函数会被插件代码调用,在离线模型生成器进行模型转换时执行。

def operationname(shape, dtype, attribute1, attribute2, ... , kernel_name="KernelName", need_build=True, need_print=False)

其中:

  • shape:输入张量的形状,若算子有多个输入,且每个输入的shape不同,则此处需定义多个shape用于后续对每个输入张量占位,若多个输入张量的shape相同,则可以定义一个shape。
  • dtype:输入张量的数据类型。
  • attribute1attribute2...:算子的属性,根据算子的实际定义进行代码编辑。
  • kernel_name:算子在内核中的名称(即生成的二进制文件的名称),用户自定义,保持唯一,只能是大小写字母、数字、“_”的组合,且必须是字母或者“_”开头,长度小于或等于200个字符。
  • 编译配置参数“need_build”:取值范围为True或者False,代表是否需要进行编译。
  • 打印配置参数“need_print”:取值范围为True或者False,代表是否需要打印算子的中间表示(IR:Intermediate Representation)。

例如,

对于Reduction算子,实现函数定义如下:

def reduction(shape, dtype, axis, operation, coeff, kernel_name="Reduction", need_build=True, need_print=False)

对于Matmul算子,实现函数定义如下:

def matmul(shape_a, shape_b, dtype, kernel_name="matmul", trans_a=False, trans_b=False,need_build=False, need_print=False):

算子实现逻辑

TE的算子实现逻辑总体概括为:

定义好输入数据的张量占位符,然后调用te.lang.cce中的各种特定域语言接口进行计算过程的描述,如下代码示例所示:

data = tvm.placeholder(shape, name="data_input", dtype=inp_dtype)
with tvm.target.cce():
    cof = coeff
    data_tmp_input = te.lang.cce.vmuls(data, cof)       //对缩放参数进行处理,将输入张量乘上一个标量
    tmp = data_tmp_input
    res_tmp = te.lang.cce.sum(tmp, axis=axis)           //在轴axis上做求和操作
    res = te.lang.cce.cast_to(res_tmp, inp_dtype, f1628IntegerFlag=True)    //进行数据类型的转换

其中:

  • data为输入张量,使用TVM的placeholder接口进行定义,placeholder是一个占位符,返回一个Tensor对象,表示一组输入数据。

    若算子有多个输入张量,此处需要定义多个Tensor对象。例如:

    tensor_a = tvm.placeholder(shape_a, name='tensor_a', dtype=dtype)
    tensor_b = tvm.placeholder(shape_b, name='tensor_b', dtype=dtype)
  • vmuls(向量乘),sum(求和)组成中间的计算逻辑。
  • cast_to(转换数据类型):输出张量需要与输入张量的数据类型保持一致,如果计算过程中对数据类型做了转换,需要使用cast_to接口将输出张量的数据类型转换为输入张量的数据类型。

    例如:如果输入张量的数据类型为int8,则执行vmuls操作时,会将数据类型转换为float16,则需要在计算逻辑结束后调用cast_to接口将输出张量的数据类型由float16转换为int8,由于使用vmuls将int8转换为float16的数值时,小数部分为0,所以“f1628IntegerFlag”设置为True,代码示例为:

    res = te.lang.cce.cast_to(res_tmp, inp_dtype, f1628IntegerFlag=True) 

    “te.lang.cce.cast_to”接口的详细使用方法请参见《TE API参考》中的“Compute接口”

  • res为输出张量,其数据类型与输入张量的数据类型一致。

用户在进行算子逻辑实现前,可以自定义实现代码对输入数据进行预处理,如下代码示例所示:

    # basic check
    check_list = ["float16", "float32"]
    if not (dtype.lower() in check_list):
        raise RuntimeError("Reduction only support %s while dtype is %s" % (
            ",".join(check_list), dtype))
    reduction_op = ("SUM", "ASUM", "SUMSQ", "MEAN")

    # axis parameter check
    if type(axis) != int:
        raise RuntimeError("type of axis value should be int")
    if axis >= len(shape) or axis < -len(shape):
        raise RuntimeError(
            "input axis is out of range, axis value can be from %d to %d" % (
                -len(shape), len(shape) - 1))
    # operation parameter check
    if operation not in reduction_op:
        raise RuntimeError("operation can only be one of SUM, ASUM, SUMSQ , MEAN")
    # coeff parameter check
    if type(coeff) != int and type(coeff) != float:
        raise RuntimeError("coeff must be a value")
    # Preprocess
    if axis < 0:
        axis = len(shape) + axis
    shape = list(shape)
    shape1 = shape[:axis] + [reduce(lambda x, y: x * y, shape[axis:])]
    inp_dtype = dtype.lower()
分享:

    相关文档

    相关产品