
Implementing an Operator

Function Definition for Operator Implementation

The implementation function of an operator takes the input tensor shape, data type, operator attributes, kernel name, and build and print switches as parameters, as shown below. This function is called by the plug-in code and is executed when OMG converts the model.

def operationname(shape, dtype, attribute1, attribute2, ... , kernel_name="KernelName", need_build=True, need_print=False)

In the preceding information:

  • shape: input tensor shape. If an operator has multiple input tensors with different shapes, define one shape parameter for each tensor as a placeholder. If multiple input tensors share the same shape, define only one shape.
  • dtype: data type of the input tensor.
  • attribute1, attribute2, ...: operator attributes. Edit the code based on the operator definition.
  • kernel_name: name of the operator in the kernel, that is, the name of the generated binary file. The value is user-defined and must be unique. It can contain only uppercase letters, lowercase letters, digits, and underscores (_), must start with a letter or an underscore (_), and can contain a maximum of 200 characters.
  • need_build: specifies whether to build the operator, either True or False.
  • need_print: specifies whether to print the intermediate representation (IR), either True or False.

Examples:

Reduction operator:

def reduction(shape, dtype, axis, operation, coeff, kernel_name="Reduction", need_build=True, need_print=False)

Matmul operator:

def matmul(shape_a, shape_b, dtype, kernel_name="matmul", trans_a=False, trans_b=False, need_build=False, need_print=False)
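
For reference, the implementation function is invoked with concrete argument values by the plug-in or test code. The following call to the reduction function is purely illustrative; the shape, axis, operation, coefficient, and kernel name values are examples, not values required by the framework:

reduction((8, 16, 16), "float16", 1, "SUM", 1.0,
          kernel_name="reduction_8_16_16", need_build=True, need_print=False)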

Operator Implementation Logic

The TE operator implementation logic is summarized as follows:

Define placeholders for input tensors, and then call the DSL interfaces in te.lang.cce to describe the computation process. The following is a code example:

data = tvm.placeholder(shape, name="data_input", dtype=inp_dtype)
with tvm.target.cce():
    cof = coeff
    data_tmp_input = te.lang.cce.vmuls(data, cof)    # Process the scaling parameter: multiply the input tensor by a scalar.
    tmp = data_tmp_input
    res_tmp = te.lang.cce.sum(tmp, axis=axis)    # Perform the summation along the specified axis.
    res = te.lang.cce.cast_to(res_tmp, inp_dtype, f1628IntegerFlag=True)   # Convert the data type back to that of the input tensor.

In the preceding information:

  • data indicates the input tensor, which is defined by using the placeholder interface of TVM. The interface returns a Tensor object that represents a group of input data.

    If the operator has multiple input tensors, multiple tensor objects need to be defined. For example:

    tensor_a = tvm.placeholder(shape_a, name='tensor_a', dtype=dtype)
    tensor_b = tvm.placeholder(shape_b, name='tensor_b', dtype=dtype)
  • vmuls (multiplication of a tensor by a scalar) and sum (summation along an axis) constitute the intermediate computation logic.
  • cast_to is used to convert the data type. The output tensor must have the same data type as the input tensor. If the data type is changed during computation, use the cast_to interface to convert the data type of the output tensor back to that of the input tensor.

    For example, if the data type of the input tensor is int8, it is converted to float16 for the vmuls operation. After the computation logic is complete, the cast_to interface must be called to convert the data type of the output tensor from float16 back to int8. Because vmuls converts int8 values to float16 whose decimal part is zero, f1628IntegerFlag is set to True. The sample code is as follows:

    res = te.lang.cce.cast_to(res_tmp, inp_dtype, f1628IntegerFlag=True) 

    For details about how to use the te.lang.cce.cast_to API, see Compute APIs in TE API Reference.

  • res indicates the output tensor, which has the same data type as the input tensor. The result is then scheduled and built, as sketched after this list.
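
After res is defined, the computation result is typically scheduled and built so that the binary file named by kernel_name is generated. The following is a minimal sketch of the common TE DSL pattern; the schedule and build interfaces shown here (generic.auto_schedule from topi and te.lang.cce.cce_build_code) follow typical TE samples and may differ in your DDK version:

with tvm.target.cce():
    # ... computation logic shown above, ending with res ...
    sch = generic.auto_schedule(res)    # requires: from topi import generic

config = {"print_ir": need_print,      # controlled by the need_print parameter
          "need_build": need_build,    # controlled by the need_build parameter
          "name": kernel_name,         # name of the generated kernel binary
          "tensor_list": [data, res]}  # input and output tensors
te.lang.cce.cce_build_code(sch, config)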

Before implementing the operator logic, you can add custom code to validate and pre-process the input data. The sample code is as follows:

    # basic check
    check_list = ["float16", "float32"]
    if dtype.lower() not in check_list:
        raise RuntimeError("Reduction only supports %s while dtype is %s" % (
            ",".join(check_list), dtype))
    reduction_op = ("SUM", "ASUM", "SUMSQ", "MEAN")

    # axis parameter check
    if type(axis) != int:
        raise RuntimeError("type of axis value should be int")
    if axis >= len(shape) or axis < -len(shape):
        raise RuntimeError(
            "input axis is out of range, axis value can be from %d to %d" % (
                -len(shape), len(shape) - 1))
    # operation parameter check
    if operation not in reduction_op:
        raise RuntimeError("operation can only be one of SUM, ASUM, SUMSQ , MEAN")
    # coeff parameter check
    if type(coeff) != int and type(coeff) != float:
        raise RuntimeError("coeff must be an int or a float")
    # Preprocess
    if axis < 0:
        axis = len(shape) + axis
    shape = list(shape)
    shape1 = shape[:axis] + [reduce(lambda x, y: x * y, shape[axis:])]  # In Python 3, reduce requires "from functools import reduce".
    inp_dtype = dtype.lower()
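
Assuming the variable names used above, the pre-processed shape1 and inp_dtype then feed into the placeholder definition shown in the computation logic, for example:

    data = tvm.placeholder(shape1, name="data_input", dtype=inp_dtype)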