Help Center/
Atlas 500 Application/
TE Custom Operator Development Guide/
Developing a Custom Operator/
Code Examples
Updated on 2022-03-13 GMT+08:00
Code Examples
Open the copied operator sample file reduction.py.
For a description of the code, see the comments in the following listing:
# coding=utf-8
"""Sample TE custom operator: reduce a tensor along an axis and scale it."""
import operator
from functools import reduce  # builtin only on Python 2; required on Python 3

import te.lang.cce
from te import tvm
from topi import generic
from topi.cce import util

# Input data types accepted by reduction().
_SUPPORTED_DTYPES = ("float16", "float32")
# Reduction modes accepted by reduction().
_REDUCTION_OPS = ("SUM", "ASUM", "SUMSQ", "MEAN")


def _check_params(shape, dtype, axis, operation, coeff):
    """Validate the arguments of reduction(); raise RuntimeError if invalid."""
    # Shape verification. check_shape_rule() is defined in
    # ddk/ddk/site-packages/topi-0.4.0.egg/topi/cce/util.py.
    util.check_shape_rule(shape)

    if dtype.lower() not in _SUPPORTED_DTYPES:
        raise RuntimeError("Reduction only support %s while dtype is %s"
                           % (",".join(_SUPPORTED_DTYPES), dtype))

    # Exact type check (not isinstance) so that bool is rejected.
    if type(axis) is not int:
        raise RuntimeError("type of axis value should be int")
    if axis >= len(shape) or axis < -len(shape):
        raise RuntimeError(
            "input axis is out of range, axis value can be from %d to %d"
            % (-len(shape), len(shape) - 1))

    if operation not in _REDUCTION_OPS:
        raise RuntimeError("op can only be one of SUM, ASUM, SUMSQ , MEAN")

    # Exact type check (not isinstance) so that bool is rejected.
    if type(coeff) not in (int, float):
        raise RuntimeError("coeff must be a value")


def _scale_input(data, operation, coeff, reduce_size):
    """Return the per-element transformed and scaled tensor for `operation`.

    Must be called inside a `tvm.target.cce()` context.
    """
    if operation == "ASUM":
        # Sum of absolute values.
        return te.lang.cce.vmuls(te.lang.cce.vabs(data), coeff)
    if operation == "SUMSQ":
        # Sum of squares.
        return te.lang.cce.vmuls(te.lang.cce.vmul(data, data), coeff)
    if operation == "MEAN":
        # Half of the 1/size factor is applied before summing and the other
        # half after (see reduction()); presumably this keeps the
        # intermediate sum in range for float16 -- NOTE(review): confirm.
        return te.lang.cce.vmuls(data, float(coeff) * reduce_size ** (-0.5))
    # "SUM": operation was validated, so this is the only remaining case.
    return te.lang.cce.vmuls(data, coeff)


def reduction(shape, dtype, axis, operation, coeff, kernel_name="Reduction",
              need_build=True, need_print=False):
    """Reduce a tensor along `axis` and scale the result by `coeff`.

    Every dimension from `axis` to the end is flattened into one trailing
    dimension, and that dimension is reduced. With axis == 0 (or
    -len(shape)) the whole input is reduced.

    Parameters
    ----------
    shape : list or tuple of int
        Shape of the input data.
    dtype : str
        Input data type; only "float16" or "float32" (case-insensitive).
    axis : int
        First axis to reduce; may be negative to index from the end
        (e.g. -1 for the last axis).
    operation : str
        One of "SUM", "ASUM" (sum of abs), "SUMSQ" (sum of squares),
        "MEAN" (sum divided by the reduced element count).
    coeff : int or float
        Scale applied to the output.
    kernel_name : str
        CCE kernel name, default "Reduction".
    need_build : bool
        Whether to build the CCEC kernel, default True.
    need_print : bool
        Whether to print the IR, default False.

    Returns
    -------
    None
        The operator is compiled and its target file generated by
        te.lang.cce.cce_build_code().

    Raises
    ------
    RuntimeError
        If dtype, axis, operation or coeff is invalid. util.check_shape_rule
        may also raise for an invalid shape.
    """
    _check_params(shape, dtype, axis, operation, coeff)

    # Normalize a negative axis to its non-negative index.
    if axis < 0:
        axis += len(shape)
    shape = list(shape)
    # Collapse shape[axis:] into a single trailing dimension, so reducing
    # index `axis` of shape1 reduces all of those dimensions at once.
    shape1 = shape[:axis] + [reduce(operator.mul, shape[axis:])]
    reduce_size = shape1[-1]
    inp_dtype = dtype.lower()

    # Tensor object for the input data. It is only a placeholder; no actual
    # memory is allocated.
    data = tvm.placeholder(shape1, name="data_input", dtype=inp_dtype)

    # Define the operator calculation process.
    with tvm.target.cce():
        tmp = _scale_input(data, operation, coeff, reduce_size)
        # Sum up data along the collapsed axis to reduce dimensions.
        res_tmp = te.lang.cce.sum(tmp, axis=axis)
        if operation == "MEAN":
            # Second half of the 1/size factor: coeff * size**-0.5 * size**-0.5
            # == coeff / size. Applied before the cast so MEAN output is
            # converted to the input dtype like every other operation.
            res_tmp = te.lang.cce.vmuls(res_tmp, reduce_size ** (-0.5))
        # Convert the result to the input data type.
        # NOTE(review): meaning of f1628IntegerFlag is defined by the TE DSL,
        # not visible here.
        res = te.lang.cce.cast_to(res_tmp, inp_dtype, f1628IntegerFlag=True)
        # Generate the schedule object for the operator's computation.
        sch = generic.auto_schedule(res)

    # Compilation parameters.
    config = {"print_ir": need_print,
              "need_build": need_build,
              "name": kernel_name,
              "tensor_list": [data, res]}

    # Compile the operator and generate the target file.
    te.lang.cce.cce_build_code(sch, config)


# Call the reduction operator with shape (2, 3, 4), dtype float16, axis 1,
# op SUM, coeff 2, and operator name "Reduction".
if __name__ == "__main__":
    reduction((2, 3, 4), "float16", 1, "SUM", coeff=2, kernel_name="Reduction")
Parent topic: Developing a Custom Operator
Feedback
Was this page helpful?
Provide feedback
Thank you very much for your feedback. We will continue working to improve the documentation.
See the reply and handling status in My Cloud VOC.
The system is busy. Please try again later.
For any further questions, feel free to contact us through the chatbot.
Chatbot