更新时间:2021-03-18 GMT+08:00
分享

AITensorFactory

Tensor工厂类,用来创建Tensor。AITensorFactory类在ai_tensor.h中定义。

    class AITensorFactory
    {
    public:
        static AITensorFactory* GetInstance();

        /*
        * @brief 通过参数创建Tensor,包含分配内存
        * @param [in] tensor_desc 包含Tensor参数的描述信息
        * @return shared_ptr<IAITensor> 创建完成的Tensor指针,如果创建失败,则返回nullptr
        */
        std::shared_ptr<IAITensor> CreateTensor(const AITensorDescription &tensor_desc);

        /*
        * @brief 通过type创建Tensor,不分配内存,非预分配情况下使用
        * @param [in] type Tensor注册时的类型
        * @return shared_ptr<IAITensor> 创建完成的Tensor指针,如果创建失败,则返回nullptr
        */
        std::shared_ptr<IAITensor> CreateTensor(const std::string &type);
     
       
#if defined( __ANDROID__) || defined(ANDROID)
        /*
        * @brief 通过参数、buffer创建Tensor,内容从buffer反序列化得到
        * @param [in] tensor_desc 包含Tensor参数的描述信息
        * @param [in] buffer 已存在的数据缓存
        * @param [in] size buffer的大小
        * @return shared_ptr<IAITensor> 创建完成的Tensor指针,如果创建失败,则返回nullptr
        */
        std::shared_ptr<IAITensor> CreateTensor(const AITensorDescription &tensor_desc,
            const void *buffer, const int32_t size);
#else
        std::shared_ptr<IAITensor> CreateTensor(const AITensorDescription &tensor_desc,
            void *buffer, int32_t size);
#endif
    
        /*
        * @brief 注册Tensor
        * @param [in] tensor_desc tensor描述
        * @param [in] create_func tensor创建函数
        */
        AIStatus RegisterTensor(const AITensorDescription &tensor_desc,
            CREATE_TENSOR_FUN create_func);
        AIStatus RegisterTensor(const string tensor_str, CREATE_TENSOR_FUN create_func)
        {
            // 加锁,防止多线程并发
            AITensorDescription tensor_desc;
            tensor_desc.set_type(tensor_str);
            return RegisterTensor(tensor_desc, create_func);
        }

       /*
        * @brief 卸载注册Tensor
        * @param [in] tensor_desc tensor描述
        */
        AIStatus UnRegisterTensor(const AITensorDescription &tensor_desc);
        AIStatus UnRegisterTensor(const string tensor_str)
        {
            AITensorDescription tensor_desc;
            tensor_desc.set_type(tensor_str);
            return UnRegisterTensor(tensor_desc);
        }

        /*
        * @brief 获取所有的Tensor列表
        * @param [out] tensor_desc_list 输出Tensor描述列表
        */
        void GetAllTensorDescription(AITensorDescriptionList &tensor_desc_list);

        /*
        * @brief 获取tensor描述
        */
        AIStatus GetTensorDescription(const std::string& tensor_type_name,
            AITensorDescription &tensor_desc);

    private:
        std::map<std::string, AITensorDescription> tensor_desc_map_;
        std::map<std::string, CREATE_TENSOR_FUN> create_func_map_;
        std::mutex tensor_reg_lock_;
    };
分享:

    相关文档

    相关产品