更新时间:2021-03-18 GMT+08:00
分享

推理算子输出张量描述

用户需要根据算子的输入张量描述、算子逻辑及算子属性,推理出算子的输出张量描述,包括张量的形状、数据类型及数据排布格式等信息。这样离线模型转换时就可以为所有的张量静态分配内存,避免动态内存分配带来的开销。

函数声明

函数的声明如下所示:

Status InferShapeAndTypexx(const ge::Operator& op, vector<ge::TensorDesc>& v_output_desc)
  • InferShapeAndTypexx:函数名称,用户自定义,需要保持唯一。
  • op:计算节点定义,存储输入张量描述及各种算子属性,ge::Operator类型的介绍请参见《GE API参考》中的“Operator类接口”
  • v_output_desc:存储该计算节点的输出张量描述,包括形状、数据排布格式及数据类型,关于TensorDesc类型的介绍请参见《GE API参考》中的“TensorDesc类接口”

下面详细讲解不同场景下InferShapeAndTypexx函数的实现。

输出张量与输入张量形状相同的算子

对于输出张量与输入张量形状相同的算子,可以直接将输入张量的描述插入输出张量描述所在的向量空间中。

代码示例如下所示:

v_output_desc.push_back(op.GetInputDesc(0));

其中GetInputDesc是Operator类中根据算子Input名称或者Input索引获取输入张量描述的接口,详细的接口介绍请参见《GE API参考》中的“Operator类接口”

降低维度的算子

对于Reduction、Reduce之类的常见降维操作,需要根据算子输入属性axis等信息计算出输出张量的形状(包含输出张量维度以及每一个维度有多少个元素),然后将输出张量的形状插入到v_output_desc向量中。

代码示例如下所示:

  1. 获取算子的输入张量描述及输入张量的形状信息。

    auto tensorDesc = op.GetInputDesc(0);        //输入张量描述,包括的形状、数据排布格式及数据类型
    auto shape = tensorDesc.GetShape();          //获取输入张量的形状

    GetShape接口的详细介绍请参见《GE API参考》中的“TensorDesc类接口”

  2. 根据计算逻辑获取算子的属性值,并计算算子输出张量的形状。

    例如,对于Mylenet网络中的Reduction算子,由于Reduction的上一层为Softmax,Softmax的实际输出是2维,但在离线模型中会被补齐到4维,所以需要调整axis,将其指向2维的位置,进行reduce操作,并将调整后的Shape值赋给输出张量描述。

    • 从operator对象中获取axis属性的键值对,然后从其中获取axis的属性值,并将其从INT类型转换为int64_t类型赋给名称为axis的变量,最后对axis的值进行校验、调整,将其指向轴1的位置,代码示例如下所示:
      int64_t axis = -1;
      ge::AttrValue axisAttrValue;
          if ((ge::GRAPH_SUCCESS != op.GetAttr("axis", axisAttrValue)) || (ge::GRAPH_SUCCESS != axisAttrValue.GetValue<AttrValue::INT>(axis)))
          {
              printf("Get axis failed!\n");
          }
          // In the OM model, all shape are supplemented to 4d. In this case, axis needs to be repaired to point to the original 2d.
          if (axis < 0) axis -= 2;
      
          if (axis < 0) axis += shape.GetDimNum();
      
          if (axis < 0 || axis >= shape.GetDimNum())
          {
              printf("invalid axis:%d, dim_size:%d\n", (int32_t)axis, (int32_t)shape.GetDimNum());
              return PARAM_INVALID;
          }
      
    • 调整Shape,将轴1及其之后的维度设置为1,例如输入张量的Shape为(2,3,4,5),则调整之后的Shape为(2,1,1,1)。
       int32_t dimsize = (int32_t)shape.GetDimNum();
       int32_t idx = 0;
       for(idx=axis; idx<dimsize; idx++)
       {
           shape.SetDim(idx, 1);
       } 
    • 将调整后的Shape值设置到tensorDesc对象。
       tensorDesc.SetShape(shape); 

    GetDimNum与SetDim接口的详细介绍请参见《GE API参考》中的“Shape类接口”

  3. 设置算子的输出张量描述。

    v_output_desc.push_back(tensorDesc)

    将tensorDesc赋给输出张量的描述对象v_output_desc。

网络存在Type相同的多个算子

一个网络中往往存在多层某个Type相同的算子,例如Convolution算子,开发者有时需要对某层算子进行形状的自定义(重定义已经存在的某层算子),此时在进行输出张量描述推理的时候需要根据网络的不同情况,根据num_output、kernel、stride、 pad等属性判断是针对哪一层的算子进行的自定义,从而推断出对应算子的张量信息。

代码示例如下所示:

  1. 将输入张量描述赋给输出张量描述,并获取算子的输入张量描述及输入张量的形状信息。

    v_output_desc.push_back(op.GetInputDesc(0));  //将输入张量描述信息赋给输出张量描述对象,开发者也可以将后续推理出的shape直接赋给tensorDesc对象,再将tensorDesc对象赋给输出张量描述对象。
    auto tensorDesc = op.GetInputDesc(0);        //输入张量描述,包括的形状、数据排布格式及数据类型
    auto shape = tensorDesc.GetShape();          //获取输入张量的形状

    GetShape接口的详细介绍请参见《GE API参考》中的“TensorDesc类接口”

  2. 根据计算逻辑获取算子的属性值,并根据算子的属性值、shape匹配算子,并计算算子输出张量的形状。

    例如,对于某一网络中的Type为Convolution的算子(网络中存在多层Convolution),匹配num_outputs为128、shape.GetDim(0)为1、shape.GetDim(1) 为128、shape.GetDim(2) 为28、shape.GetDim(3) 为28的Convolution,并对其输出描述张量的形状重新赋值。

    • 获取num_outputs的属性值
      ge::AttrValue num_outputsAttrValue;
          if ((ge::GRAPH_SUCCESS != op.GetAttr("num_output", num_outputsAttrValue)) || 
              (ge::GRAPH_SUCCESS != num_outputsAttrValue.GetValue<AttrValue::INT>(num_outputs)))
          {
              printf("GetOpAttr num_outputs failed!\n");
          }
    • 匹配num_outputs为128、shape.GetDim(0)为1、shape.GetDim(1) 为128、shape.GetDim(2) 为28、shape.GetDim(3) 为28的Convolution,并对其输出张量描述的形状重新赋值。
       if(shape.GetDim(0) == 1 && shape.GetDim(1) == 128 &&
              shape.GetDim(2) == 28 && shape.GetDim(3) == 28 && num_outputs == 128)
          {
              shape.SetDim(0, 1);
              shape.SetDim(1, 128);
              shape.SetDim(2, 28);
              shape.SetDim(3, 28);
              v_output_desc[0].SetShape(shape);
              return SUCCESS;
              return FAILED;
          }

    GetDimNum与SetDim接口的详细介绍请参见《GE API参考》中的“Shape类接口”

    若需要使用op_name进行匹配,获取op_name的方法如下所示:

    auto op_name = op.GetName();

    其他属性kernel_w、kernel_h、stride_w、stride_h、pad_w、pad_h的获取方法与num_output相同,修改op.GetAttr中的key值即可。

分享:

    相关文档

    相关产品