算子实现
算子实现函数定义
如下所示,一个算子的实现函数中包含了输入张量的形状,数据类型,算子属性,内核名称,以及相应的编译、打印等配置。该函数会被插件代码调用,在离线模型生成器进行模型转换时执行。
def operationname(shape, dtype, attribute1, attribute2, ... , kernel_name="KernelName", need_build=True, need_print=False)
其中:
- shape:输入张量的形状,若算子有多个输入,且每个输入的shape不同,则此处需定义多个shape用于后续对每个输入张量占位,若多个输入张量的shape相同,则可以定义一个shape。
- dtype:输入张量的数据类型。
- attribute1,attribute2...:算子的属性,根据算子的实际定义进行代码编辑。
- kernel_name:算子在内核中的名称(即生成的二进制文件的名称),用户自定义,保持唯一,只能是大小写字母、数字、“_”的组合,且必须是字母或者“_”开头,长度小于或等于200个字符。
- 编译配置参数“need_build”:取值范围为True或者False,代表是否需要进行编译。
- 打印配置参数“need_print”:取值范围为True或者False,代表是否需要打印算子的中间表示(IR:Intermediate Representation)。
例如,
对于Reduction算子,实现函数定义如下:
def reduction(shape, dtype, axis, operation, coeff, kernel_name="Reduction", need_build=True, need_print=False)
对于Matmul算子,实现函数定义如下:
def matmul(shape_a, shape_b, dtype, kernel_name="matmul", trans_a=False, trans_b=False,need_build=False, need_print=False):
算子实现逻辑
TE的算子实现逻辑总体概括为:
定义好输入数据的张量占位符,然后调用te.lang.cce中的各种特定域语言接口进行计算过程的描述,如下代码示例所示:
data = tvm.placeholder(shape, name="data_input", dtype=inp_dtype) with tvm.target.cce(): cof = coeff data_tmp_input = te.lang.cce.vmuls(data, cof) //对缩放参数进行处理,将输入张量乘上一个标量 tmp = data_tmp_input res_tmp = te.lang.cce.sum(tmp, axis=axis) //在轴axis上做求和操作 res = te.lang.cce.cast_to(res_tmp, inp_dtype, f1628IntegerFlag=True) //进行数据类型的转换
其中:
- data为输入张量,使用TVM的placeholder接口进行定义,placeholder是一个占位符,返回一个Tensor对象,表示一组输入数据。
若算子有多个输入张量,此处需要定义多个Tensor对象。例如:
tensor_a = tvm.placeholder(shape_a, name='tensor_a', dtype=dtype) tensor_b = tvm.placeholder(shape_b, name='tensor_b', dtype=dtype)
- vmuls(向量乘),sum(求和)组成中间的计算逻辑。
- cast_to(转换数据类型):输出张量需要与输入张量的数据类型保持一致,如果计算过程中对数据类型做了转换,需要使用cast_to接口将输出张量的数据类型转换为输入张量的数据类型。
例如:如果输入张量的数据类型为int8,则执行vmuls操作时,会将数据类型转换为float16,则需要在计算逻辑结束后调用cast_to接口将输出张量的数据类型由float16转换为int8,由于使用vmuls将int8转换为float16的数值时,小数部分为0,所以“f1628IntegerFlag”设置为True,代码示例为:
res = te.lang.cce.cast_to(res_tmp, inp_dtype, f1628IntegerFlag=True)
“te.lang.cce.cast_to”接口的详细使用方法请参见《TE API参考》中的“Compute接口”。
- res为输出张量,其数据类型与输入张量的数据类型一致。
用户在进行算子逻辑实现前,可以自定义实现代码对输入数据进行预处理,如下代码示例所示:
# basic check check_list = ["float16", "float32"] if not (dtype.lower() in check_list): raise RuntimeError("Reduction only support %s while dtype is %s" % ( ",".join(check_list), dtype)) reduction_op = ("SUM", "ASUM", "SUMSQ", "MEAN") # axis parameter check if type(axis) != int: raise RuntimeError("type of axis value should be int") if axis >= len(shape) or axis < -len(shape): raise RuntimeError( "input axis is out of range, axis value can be from %d to %d" % ( -len(shape), len(shape) - 1)) # operation parameter check if operation not in reduction_op: raise RuntimeError("operation can only be one of SUM, ASUM, SUMSQ , MEAN") # coeff parameter check if type(coeff) != int and type(coeff) != float: raise RuntimeError("coeff must be a value") # Preprocess if axis < 0: axis = len(shape) + axis shape = list(shape) shape1 = shape[:axis] + [reduce(lambda x, y: x * y, shape[axis:])] inp_dtype = dtype.lower()