更新时间:2021-03-18 GMT+08:00
AITensorFactory
Tensor工厂类,用来创建Tensor。AITensorFactory类在ai_tensor.h中定义。
class AITensorFactory { public: static AITensorFactory* GetInstance(); /* * @brief 通过参数创建Tensor,包含分配内存 * @param [in] tensor_desc 包含Tensor参数的描述信息 * @return shared_ptr<IAITensor> 创建完成的Tensor指针,如果创建失败,则返回nullptr */ std::shared_ptr<IAITensor> CreateTensor(const AITensorDescription &tensor_desc); /* * @brief 通过type创建Tensor,不分配内存,非预分配情况下使用 * @param [in] type Tensor注册时的类型 * @return shared_ptr<IAITensor> 创建完成的Tensor指针,如果创建失败,则返回nullptr */ std::shared_ptr<IAITensor> CreateTensor(const std::string &type); #if defined( __ANDROID__) || defined(ANDROID) /* * @brief 通过参数、buffer创建Tensor,内容从buffer反序列化得到 * @param [in] tensor_desc 包含Tensor参数的描述信息 * @param [in] buffer 已存在的数据缓存 * @param [in] size buffer的大小 * @return shared_ptr<IAITensor> 创建完成的Tensor指针,如果创建失败,则返回nullptr */ std::shared_ptr<IAITensor> CreateTensor(const AITensorDescription &tensor_desc, const void *buffer, const int32_t size); #else std::shared_ptr<IAITensor> CreateTensor(const AITensorDescription &tensor_desc, void *buffer, int32_t size); #endif /* * @brief 注册Tensor * @param [in] tensor_desc tensor描述 * @param [in] create_func tensor创建函数 */ AIStatus RegisterTensor(const AITensorDescription &tensor_desc, CREATE_TENSOR_FUN create_func); AIStatus RegisterTensor(const string tensor_str, CREATE_TENSOR_FUN create_func) { // 加锁,防止多线程并发 AITensorDescription tensor_desc; tensor_desc.set_type(tensor_str); return RegisterTensor(tensor_desc, create_func); } /* * @brief 卸载注册Tensor * @param [in] tensor_desc tensor描述 */ AIStatus UnRegisterTensor(const AITensorDescription &tensor_desc); AIStatus UnRegisterTensor(const string tensor_str) { AITensorDescription tensor_desc; tensor_desc.set_type(tensor_str); return UnRegisterTensor(tensor_desc); } /* * @brief 获取所有的Tensor列表 * @param [out] tensor_desc_list 输出Tensor描述列表 */ void GetAllTensorDescription(AITensorDescriptionList &tensor_desc_list); /* * @brief 获取tensor描述 */ AIStatus GetTensorDescription(const std::string& tensor_type_name, AITensorDescription &tensor_desc); private: std::map<std::string, AITensorDescription> tensor_desc_map_; std::map<std::string, CREATE_TENSOR_FUN> create_func_map_; std::mutex tensor_reg_lock_; };
父主题: 其他用于编译依赖的接口