更新时间:2021-03-18 GMT+08:00
分享

算子实现

算子代码实现

ScatterNdAdd算子的详细实现代码请参见“tbe/impl/scatter_nd_add.py”,下面主要介绍关键代码原理。

ScatterNdAdd的算子实现的关键点是进行算子schedule策略的实现,包含tiling参数的计算、多核实现等。

  1. 接口定义。

    def scatter_nd_add(var,
                       indices,
                       updates,
                       var_out,
                       use_locking=False,
                       kernel_name="scatter_nd_add"):
        scatter_nd = Scatter(var, indices, updates, var_out, True, kernel_name, "vadd")
        scatter_nd.scatter_operator()

    主要包括以下关键点:

    • 定义Scatter类,并在初始化函数中进行tiling参数的计算,请参考2
    • 通过scatter_nd_add实现算子计算逻辑,请参考3

  2. tiling参数计算。

    定义Scatter类,并在初始化函数中进行tiling参数的计算。核心计算主要是计算每个输入的shape的大小,再根据数据类型计算需要UB多少空间。我们可以通过cce.CceProductParams().getParams("Unified_Buffer")接口获取到UB的实际物理空间后,根据UB大小来划分UB空间,为定义UB上的tensor做准备。后续的步骤中,我们还会使用这些数据来计算data_move、vec_add等接口的参数。设置独立的tiling模块,将其与算子计算逻辑分离可以很好的做到算子的shape泛化。对于不同的shape,我们可以在不改变计算逻辑的情况下,只改变tiling参数来优化搬运和计算的次数,来做到泛化和高性能。其中加粗部分请根据实际芯片修改。

    class Scatter():
        def __init__(self, var, indices, updates, var_out, nd_flag, kernel_name,
                     compute_type):
            # 初始化tik容器
            if cce.CceProductParams().cce_product in ("1.1", "1.3"):
                self.product_name = "mini"
            elif cce.CceProductParams().cce_product == "1.60":
                self.product_name = "cloud"
            else:
                raise RuntimeError(
                    "scatter compute only support target:cloud_v100/mini_v100")
            self.tik_instance = tik.Tik(tik.Dprofile("v100", self.product_name))
            self.nd_flag = nd_flag
            #初始化三个输入的shape和数据类型
            self.var_shape = var.get("shape")
            self.var_dtype = var.get("dtype").lower()
            self.indices_shape = indices.get("shape")
            self.indices_dtype = indices.get("dtype").lower()
            self.updates_shape = updates.get("shape")
            self.updates_dtype = updates.get("dtype").lower()
            # 计算三个输入的tensor的大小
            self.var_ele_num = functools_reduce(lambda x, y: x * y, self.var_shape)
            self.indices_num = functools_reduce(lambda x, y: x * y,
                                                self.indices_shape)
            self.updates_num = functools_reduce(lambda x, y: x * y,
                                                self.updates_shape)
            self.kernel_name = kernel_name
            self.check_param(var_out)
            # 计算index的个数和最大值,用于遍历和分核。ND和非ND场景按不同的分支计算
            if nd_flag:
                if self.indices_shape[-1] == len(self.var_shape):
                    self.update_data_num = 1
                else:
                    self.update_data_num = functools_reduce(
                        lambda x, y: x * y, self.var_shape[self.indices_shape[-1]:])
                self.max_indice = functools_reduce(
                    lambda x, y: x * y, self.var_shape[0:self.indices_shape[-1]])
                self.index_dims = self.indices_shape[-1]
            else:
                if len(self.var_shape) > 1:
                    self.update_data_num = functools_reduce(lambda x, y: x * y,
                                                            self.var_shape[1:])
                else:
                    self.update_data_num = 1
                self.max_indice = self.var_shape[0]
                self.index_dims = 1
            # 初始化算类型,用于兼容add和sub等方法,方便实现不同的操作
            self.compute_type = compute_type
            # 获取UB buffer空间大小,并计算一个block可以存储多少相应数据类型的数据
            self.ub_size_bytes = (
                cce.CceProductParams().getParams("Unified_Buffer") - 8192)
            self.var_dtype_bytes_size = cce.cce_intrin.get_bit_len(
                self.var_dtype) // 8
            self.indices_dtype_bytes_size = cce.cce_intrin.get_bit_len(
                self.indices_dtype) // 8
            self.var_data_each_block = 32 // self.var_dtype_bytes_size
            self.indices_data_each_block = 32 // self.indices_dtype_bytes_size
            self.indices_ub_number = 0
            self.updates_ub_number = 0
    
            self.index_loop_num = 0
    
            self.max_num_one_repeat = 128
            if self.var_dtype in ("float32", "int32"):
                self.max_num_one_repeat = 64
            # 计算使用的AI Core的个数,以及每个AI Core处理多少个index,对于updates分片小于32B场景采用单核
            if self.update_data_num < self.var_data_each_block:
                self.block_num = 1
            else:
                ai_core_num = tik.Dprofile("v100",
                                           self.product_name).get_aicore_num()
                self.indice_step = math.ceil(self.max_indice / ai_core_num)
                self.block_num = math.ceil(self.max_indice / self.indice_step)
            # 定义输入和输出在GM中的tensor
            self.var_gm = self.tik_instance.Tensor(
                self.var_dtype, self.var_shape, name="var_gm", scope=tik.scope_gm)
            self.indices_gm = self.tik_instance.Tensor(
                self.indices_dtype,
                self.indices_shape,
                name="indices_gm",
                scope=tik.scope_gm)
            self.updates_gm = self.tik_instance.Tensor(
                self.updates_dtype,
                self.updates_shape,
                name="updates_gm",
                scope=tik.scope_gm)
            self.out_gm = self.tik_instance.Tensor(
                self.var_dtype, self.var_shape, name="out_gm", scope=tik.scope_gm)
    
            self.vconv_dst_dtype = "float16"
    
            self.init_ub_tensor_para()
            self.var_vconv_ub = None
            self.updates_vconv_ub = None
            self.var_tile_vconv_ub = None
            self.updates_tile_vconv_ub = None
    
            self.var_ub = None
            self.updates_ub = None
            self.indices_ub = None
            self.var_tile_ub = None
            self.updates_tile_ub = None
    
            self.var_read_index = None
            self.updates_read_index = None
            self.indices_loop_index = None
            self.indices_tmp = None
    
        # 计算UB大小的划分,根据输入的shape大小和数据类型计算
        def init_ub_tensor_para(self):
            updates_size_bytes = self.var_dtype_bytes_size * self.update_data_num
            indices_size_bytes = self.indices_dtype_bytes_size * self.indices_num
            need_vconv_dtype = ("int8", "uint8")
            # update数据类型为int8或者uint8时的计算方法
            if self.var_dtype in need_vconv_dtype:
                vconv_dtype_bytes_size = cce.cce_intrin.get_bit_len(
                    self.vconv_dst_dtype)
                vconv_data_each_block = 32 // vconv_dtype_bytes_size
                vconv_size_bytes = (
                    updates_size_bytes // self.var_dtype_bytes_size *
                    vconv_dtype_bytes_size)
                # 当update和var分片能在UB上放下时优先存储这两个数据
                if (updates_size_bytes + vconv_size_bytes) * 2 < (
                        self.ub_size_bytes * 0.9):
                    self.updates_ub_number = math.ceil(
                        self.update_data_num /
                        self.var_data_each_block) * self.var_data_each_block
    
                    self.vconv_ub_number = math.ceil(
                        self.update_data_num /
                        vconv_data_each_block) * vconv_data_each_block
    
                    self.indices_ub_number = (
                        self.ub_size_bytes - updates_size_bytes * 2 -
                        vconv_size_bytes * 2) // self.indices_dtype_bytes_size
    
                    self.indices_ub_number = math.ceil(
                        self.indices_ub_number /
                        self.indices_data_each_block) * self.indices_data_each_block
                # 当update和var分片在UB上放不下时,如果indices能放下,优先存储indices数据
                elif indices_size_bytes < (self.ub_size_bytes * 0.9):
                    self.indices_ub_number = math.ceil(
                        self.indices_num /
                        self.indices_data_each_block) * self.indices_data_each_block
                    self.updates_ub_number = (
                        self.ub_size_bytes -
                        indices_size_bytes) // self.var_dtype_bytes_size // 6
    
                    self.updates_ub_number = math.ceil(
                        self.updates_ub_number /
                        self.var_data_each_block) * self.var_data_each_block
    
                    self.vconv_ub_number = math.ceil(
                        self.updates_ub_number /
                        vconv_data_each_block) * vconv_data_each_block
                # 都放不下时,UB内存对半分
                else:
                    self.updates_ub_number = (self.ub_size_bytes // 2 //
                                              (vconv_dtype_bytes_size +
                                               self.var_dtype_bytes_size) // 2 //
                                              self.var_data_each_block *
                                              self.var_data_each_block)
                    self.indices_ub_number = (self.ub_size_bytes //
                                              self.indices_dtype_bytes_size // 2 //
                                              self.var_data_each_block *
                                              self.var_data_each_block)
                    self.vconv_ub_number = self.updates_ub_number
            # update数据类型非int8或者uint8时的处理方法
            else:
                # 当update和var分片能在UB上放下时优先存储这两个数据
                if updates_size_bytes * 2 < self.ub_size_bytes * 0.9:
                    self.updates_ub_number = math.ceil(
                        self.update_data_num /
                        self.var_data_each_block) * self.var_data_each_block
                    self.indices_ub_number = (
                        self.ub_size_bytes -
                        updates_size_bytes * 2) // self.indices_dtype_bytes_size
                    self.indices_ub_number = math.ceil(
                        self.indices_ub_number /
                        self.indices_data_each_block) * self.indices_data_each_block
                    if self.indices_num < self.indices_ub_number:
                        self.indices_ub_number = math.ceil(
                            self.indices_num / self.indices_data_each_block
                        ) * self.indices_data_each_block
                # 当update和var分片在UB上放不下时,如果indices能放下,优先存储indices数据
                elif indices_size_bytes < self.ub_size_bytes * 0.9:
                    self.indices_ub_number = math.ceil(
                        self.indices_num /
                        self.indices_data_each_block) * self.indices_data_each_block
    
                    self.updates_ub_number = (
                        self.ub_size_bytes -
                        indices_size_bytes) // 2 // self.var_dtype_bytes_size
    
                    self.updates_ub_number = math.ceil(
                        self.updates_ub_number /
                        self.var_data_each_block) * self.var_data_each_block
                # 都放不下时,UB内存对半分
                else:
                    self.indices_ub_number = (self.ub_size_bytes //
                                              self.indices_dtype_bytes_size // 2 //
                                              self.indices_data_each_block *
                                              self.indices_data_each_block)
                    self.updates_ub_number = (self.indices_ub_number // 2 //
                                              self.var_data_each_block *
                                              self.var_data_each_block)
    
            last_num = self.update_data_num % self.updates_ub_number
            if (last_num < self.var_data_each_block and
                    self.update_data_num > self.updates_ub_number):
                self.updates_ub_number -= self.var_data_each_block

  3. 计算过程实现。

    根据tiling的计算结果,我们判断要不要使用多核。如果要使用多核,就需要设置多核循环。并且定义UB tensor的操作必须定义在多核循环内,防止编译时出现冲突。对于多核场景,每次循环都会遍历输入张量indices,在计算出index后判断该index是否在当前核的处理范围内再进行计算。
        def scatter_operator(self):
            # 根据tiling计算结果判断能否开多核,如果需要开多核,需要指定多核循环
            if self.block_num > 1:
                with self.tik_instance.for_range(
                        0, self.block_num,
                        block_num=self.block_num) as indices_loop_index:
                  # 初始化UB中的tensor
                    self.init_ub_tensor()
                    self.indices_loop_index.set_as(indices_loop_index)
                    # 遍历indices索引计算
                    self.traversing_indices()
            else:
                self.init_ub_tensor()
                self.traversing_indices()
    
            # 通过BuildCCE接口进行算子编译,最终生成算子目标文件.o与算子描述文件.json
            self.tik_instance.BuildCCE(
                kernel_name=self.kernel_name,
                inputs=(self.var_gm, self.indices_gm, self.updates_gm),
                outputs=(self.out_gm),
                enable_l2=False)
    
            return self.tik_instance
    1. traversing_indices函数定义。
      该函数主要操作是将indices分片搬入到UB中,然后遍历和计算出需要更新的var对应的index。搬运的时候需要考虑最后一个分片,搬运的burst_len需要单独计算。将一个indice分片搬入到UB后,在self.updates_the_var函数中遍历当前UB中的indices,做相应的计算和处理。
          def traversing_indices(self):
              # 计算indices需要分多少次搬入UB进行遍历,根据给indices分配的UB大小来计算
              max_ub_idx_num = (self.indices_ub_number // self.index_dims *
                                self.index_dims)
              indices_loop_num = self.indices_num // max_ub_idx_num
      
              if indices_loop_num > 0:
                  with self.tik_instance.for_range(
                          0, indices_loop_num) as indices_loop_index:
                      # 封装计算var分片的函数,对每一个index做更新操作,输入的参数为var和updates读取的偏移量
                      self.updates_the_var(indices_loop_index * max_ub_idx_num,
                                           max_ub_idx_num)
              # 遍历的尾巴,或者只需要一次搬入遍历的场景
              indices_last_num = self.indices_num % max_ub_idx_num
              if indices_last_num > 0:
                  self.updates_the_var(indices_loop_num * max_ub_idx_num,
                                       indices_last_num)
    2. updates_the_var函数定义。
      该函数的入参为当前搬运到UB的indices的位置和个数。indices的位置主要用来计算当前的indices对应的updates分片的位置,indices的个数主要用来计算需要遍历多少个index。对于当前遍历计算出来的index,判断是否在当前核心的处理范围,如果不是,就跳过不进行处理。对于每个updates分片的处理,我们仍然需要考虑UB放不下后需要分片处理。对于每个分片的处理,我们可以封装相同的规则进行处理。
          def updates_the_var(self, indices_in_index, indice_num):
              # 计算数据搬运的burst_len
              indices_burst_len = math.ceil(indice_num / self.indices_data_each_block)
              # 将indices搬运到UB
              if self.indices_num == 1:
                  self.tik_instance.data_move(self.indices_ub, self.indices_gm, 0, 1,
                                              indices_burst_len, 0, 0)
              else:
                  self.tik_instance.data_move(self.indices_ub,
                                              self.indices_gm[indices_in_index], 0, 1,
                                              indices_burst_len, 0, 0)
              if self.nd_flag:
                  indice_loop_num = indice_num // self.indices_shape[-1]
              else:
                  indice_loop_num = indice_num
              # 遍历搬运到UB的indices
              with self.tik_instance.for_range(0,
                                               indice_loop_num) as indices_ub_index:
                  self.get_var_read_index(indices_ub_index)
                  if self.block_num > 1:
                      # 判断index是否在当前核的计算范围内,如果在,进行对应的计算
                      with self.tik_instance.if_scope(
                              self.indices_loop_index *
                              self.indice_step <= self.var_read_index):
                          with self.tik_instance.if_scope(
                                  (self.indices_loop_index + 1) *
                                  self.indice_step > self.var_read_index):
                              if self.nd_flag:
                                  indices_in_index = indices_in_index // \
                                                     self.indices_shape[
                                                         -1]
                              self.get_updates_read_index(indices_ub_index +
                                                          indices_in_index)
                              self.var_read_index.set_as(self.var_read_index *
                                                         self.update_data_num)
                              # 计算update和var的函数
                              self.calc_updates()
                  else:
                      if self.nd_flag:
                          indices_in_index = indices_in_index // self.indices_shape[
                              -1]
                      self.get_updates_read_index(indices_ub_index + indices_in_index)
                      self.var_read_index.set_as(self.var_read_index *
                                                 self.update_data_num)
                      self.calc_updates()
          # 对updates数据进行分段遍历
          def calc_updates(self):
              updates_loop = self.update_data_num // self.updates_ub_number
              if updates_loop > 0:
                  with self.tik_instance.for_range(0, updates_loop) as loop_index:
                      self.calc_updates_small(loop_index * self.updates_ub_number,
                                              self.updates_ub_number)
      
              last_num = self.update_data_num % self.updates_ub_number
      
              if last_num > 0:
                  self.calc_updates_small(updates_loop * self.updates_ub_number,
                                          last_num)
    3. calc_updates_small函数定义。
      该函数主要实现每个updates分片的处理。在实现的过程中主要需要考虑非32B对齐场景,多核时序问题导致的写覆盖规避。同时,由于vec_add计算指令单条指令最大只能计算255*128(32640)个float16数据(255次repeat,每个repeat计算128个数,每次repeat计算的最大个数和数据类型相关)。因此,我们需要进行三步处理。第一步,通过tik.for_range循环计算多次,每次计算255*128个数据。剩下的通过设置repeat次数,将N*128个元素通过一条指令计算完毕。最后小于128个元素,通过设置mask进行精确计算。相同的数据量,vec_add调用次数越少,性能越高。通过三次处理,可以做到shape的泛化和最优性能。
           def calc_updates_small(self, read_index_offset, element_num):
              # 计算一次搬运到UB的burst_len参数
              updates_burst_len = math.ceil(element_num / self.var_data_each_block)
              # 将需要更新的var分片搬运到UB buffer
              self.tik_instance.data_move(
                  self.var_ub, self.var_gm[self.var_read_index + read_index_offset],
                  0, 1, updates_burst_len, 0, 0)
              # 将需要更新的updates分片搬运到UB buffer上
              self.tik_instance.data_move(
                  self.updates_ub,
                  self.updates_gm[self.updates_read_index + read_index_offset], 0, 1,
                  updates_burst_len, 0, 0)
              # 计算非32B对齐场景尾巴的数据有多少,需要两次计算和搬运防止写覆盖
              tile_ele_num = element_num % self.var_data_each_block
              align_offset = 0
              # 非32B对齐,且大于32B的场景进行计算。并将计算结果搬出
              if (tile_ele_num != 0 and
                      self.update_data_num > self.var_data_each_block):
                  align_ele_num = (
                      element_num // self.var_data_each_block *
                      self.var_data_each_block)
                  align_offset = (
                      read_index_offset + align_ele_num -
                      (self.var_data_each_block - tile_ele_num))
                  self.tik_instance.data_move(
                      self.var_tile_ub,
                      self.var_gm[self.var_read_index + align_offset], 0, 1, 1, 0, 0)
      
                  self.tik_instance.data_move(
                      self.updates_tile_ub,
                      self.updates_gm[self.updates_read_index + align_offset], 0, 1, 1, 0, 0)
      
              compute_loop = element_num // self.max_num_one_repeat // 255
              // 对于vec_add指令,根据updates数量大小判断需要调用多少次
              if compute_loop > 0:
                  with self.tik_instance.for_range(0, compute_loop) as index:
                      index_offset = index * self.max_num_one_repeat * 255
                      self.calc_process(self.max_num_one_repeat, index_offset,
                                        index_offset, index_offset, 255, False)
              last_loop = element_num % (self.max_num_one_repeat *
                                         255) // self.max_num_one_repeat
      
              if last_loop > 0:
                  index_offset = compute_loop * self.max_num_one_repeat * 255
                  self.calc_process(self.max_num_one_repeat, index_offset,
                                    index_offset, index_offset, last_loop, False)
      
              compute_mask = element_num % self.max_num_one_repeat
              if compute_mask > 0:
                  index_offset = (
                      element_num // self.max_num_one_repeat *
                      self.max_num_one_repeat)
                  # 32B对齐场景,只需要将数据一次搬出去
                  if (tile_ele_num == 0 or
                          self.update_data_num < self.var_data_each_block):
                      self.calc_process(compute_mask, index_offset, index_offset,
                                        index_offset, 1, False)
      
                      self.tik_instance.data_move(
                          self.out_gm[self.var_read_index + read_index_offset],
                          self.var_ub, 0, 1, updates_burst_len, 0, 0)
                  # 非32B对齐场景,需要把对齐部分和非对齐部分分两次计算,然后搬出
                  else:
                      self.calc_process(self.var_data_each_block, 0, 0, 0, 1, True)
                      self.tik_instance.data_move(
                          self.out_gm[self.var_read_index + align_offset],
                          self.var_tile_ub, 0, 1, 1, 0, 0)
                      self.calc_process(compute_mask, index_offset, index_offset,
                                        index_offset, 1, False)
                      self.tik_instance.data_move(
                          self.out_gm[self.var_read_index + read_index_offset],
                          self.var_ub, 0, 1, updates_burst_len - 1, 0, 0)
              else:
                  self.tik_instance.data_move(
                      self.out_gm[self.var_read_index + read_index_offset],
                      self.var_ub, 0, 1, updates_burst_len, 0, 0)
    4. calc_process函数定义。
      对于核心的计算指令,我们封装成calc_process函数,主要来做数据类型和计算类型的泛化。首先,对于int8和uint8类型的数据,无法直接使用vec_add接口进行计算。此时需要使用vconv进行数据类型转换再进行计算,并在计算完成之后转换回之前的数据类型。其次,由于scatter_nd_add和scatter_nd_sub计算过程一样,只是最后调用的计算指令不一样。我们可以通过参数来进行控制进行那种计算,以实现一个模板适配多个算子类型。
          def calc_process(self, mask, dest_addr, src_addr1, src_addr2, repeat_times, 
                           is_tile):
              need_vconv_dtype = ("int8", "uint8")
              # 对于int8和uint8数据类型,需要进行转换后再进行计算
              if self.var_dtype in need_vconv_dtype:
                  if is_tile:
                      self.tik_instance.vec_conv(mask, "",
                                              self.var_tile_vconv_ub[dest_addr],
                                              self.var_tile_ub[src_addr1],
                                              repeat_times, 8, 4)
                      self.tik_instance.vec_conv(mask, "",
                                              self.updates_tile_vconv_ub[dest_addr],
                                              self.updates_tile_ub[src_addr2],
                                              repeat_times, 8, 4)
                      compute_repeat_strid = 8
                      src1_ub = self.var_tile_vconv_ub
                      src2_ub = self.updates_tile_vconv_ub
                      dst_ub = self.var_tile_vconv_ub
                      mask = self.var_data_each_block
                  else:
                      self.tik_instance.vec_conv(mask, "", self.var_vconv_ub[dest_addr],
                                              self.var_ub[src_addr1], repeat_times, 8, 4)
                      self.tik_instance.vec_conv(mask, "",
                                              self.updates_vconv_ub[dest_addr],
                                              self.updates_ub[src_addr2],
                                              repeat_times, 8, 4)
                      compute_repeat_strid = 8
                      src1_ub = self.var_vconv_ub[src_addr1]
                      src2_ub = self.updates_vconv_ub[src_addr2]
                      dst_ub = self.var_vconv_ub[dest_addr]
      
              else:
                  if is_tile:
                      compute_repeat_strid = (
                          self.max_num_one_repeat // self.var_data_each_block)
                      src1_ub = self.var_tile_ub
                      src2_ub = self.updates_tile_ub
                      dst_ub = self.var_tile_ub
                      mask = self.var_data_each_block
                  else:
                      compute_repeat_strid = (
                          self.max_num_one_repeat // self.var_data_each_block)
                      src1_ub = self.var_ub[src_addr1]
                      src2_ub = self.updates_ub[src_addr2]
                      dst_ub = self.var_ub[dest_addr]
      
              if self.compute_type == "vadd":
                  self.tik_instance.vec_add(mask, dst_ub, src1_ub, src2_ub, repeat_times,
                                         compute_repeat_strid,
                                         compute_repeat_strid, compute_repeat_strid)
              elif self.compute_type == "vsub":
                  self.tik_instance.vec_sub(mask, dst_ub, src1_ub, src2_ub, repeat_times,
                                         compute_repeat_strid,
                                         compute_repeat_strid, compute_repeat_strid)
              else:
                  raise RuntimeError("the operater [%s] is not supported" %
                                     self.compute_type)
              if self.var_dtype in need_vconv_dtype:
                  if is_tile:
                      self.tik_instance.vec_conv(mask, "", self.var_tile_ub,
                                              self.var_tile_vconv_ub, repeat_times, 4, 8)
                  else:
                      self.tik_instance.vec_conv(mask, "", self.var_ub[src_addr1],
                                              self.var_vconv_ub[dest_addr],
                                              repeat_times, 4, 8)

算子适配插件实现

将原始Tensorflow的ScatterNdAdd算子或者ResourceScatterNdAdd算子解析并映射为适配昇腾AI处理器的ScatterNdAdd算子,算子属性的映射可直接调用AutoMappingFn( )接口进行实现,完整代码可参考sample样例中的“framework/tf_plugin/scatter_nd_add_plugin.cpp”文件。

算子原型定义

ScatterNdAdd算子的原型定义详细代码可参见“op_proto/scatter_nd_add.h”“op_proto/scatter_nd_add.cpp”文件。

scatter_nd_add.h对ScatterNdAdd算子进行原型定义。

#ifndef GE_OP_ARG_MAX_H
#define GE_OP_ARG_MAX_H
#include "graph/operator_reg.h"

namespace ge {
REG_OP(ScatterNdAdd)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdAdd)

#endif  // GE_OP_ARG_MAX_H

IndexNumberType()的数据类型定义请参见“inc/graph/types.h”文件,此文件中定义了所有GE使用的数据类型。

scatter_nd_add.cpp对算子基本类型进行校验并推理算子的输出shape。

由于输入tensor var与updates的数据类型要求相同,所以需要对其进行校验:

IMPLEMT_VERIFIER(ScatterNdAdd, ScatterNdAddVerify) {
  if (!CheckTwoInputDtypeSame(op, "var", "updates")) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

将输入tensor var的shape与数据类型更新到输出tensor。

IMPLEMT_COMMON_INFERFUNC(ScatterNdAddInferShape) {
  Shape var_shape = op.GetInputDesc("var").GetShape();
  DataType input_dtype = op.GetInputDesc("var").GetDataType();
  TensorDesc td = op.GetOutputDesc("var");
  td.SetShape(ge::Shape(var_shape));
  td.SetDataType(input_dtype);
  (void)op.UpdateOutputDesc("var", td);
  return GRAPH_SUCCESS;
}

算子信息定义

ScatterNdAdd算子的信息定义文件请参见“tbe/op_info_cfg/ai_core/<soc_version>/scatter_nd_add.ini”,由于信息定义中未配置算子实现代码的Python文件的名字opFile.vaule以及算子定义函数的名字opInterface.vaule,所以FE默认按照将算子类型中的大写字符转换为下划线加小写字符的形式去匹配算子实现文件与算子定义函数名字,匹配规则可参见4

分享:

    相关文档

    相关产品

close