Updated on 2022-03-13 GMT+08:00

AITensorFactory

Tensor factory class used to create tensors and to register or unregister tensor types. The AITensorFactory class is defined in ai_tensor.h.

    class AITensorFactory
    {
    public:
        /*
        * @brief  Obtain the factory instance.
        * @return AITensorFactory*  Pointer to the factory instance.
        * NOTE(review): presumably a process-wide singleton; confirm in the implementation.
        */
        static AITensorFactory* GetInstance();

        /*
        * @brief  Create a tensor by using parameters, including memory allocation.
        * @param [in] tensor_desc  Description that contains the tensor parameters
        * @return shared_ptr<IAITensor>  Pointer to the created tensor. If the creation fails, a null pointer is returned.
        */
        std::shared_ptr<IAITensor> CreateTensor(const AITensorDescription &tensor_desc);

        /*
        * @brief  Create a tensor by using the type, not involving memory allocation. It is used in non-pre-allocation mode.
        * @param [in] type  Type under which the tensor was registered
        * @return shared_ptr<IAITensor>  Pointer to the created tensor. If the creation fails, a null pointer is returned.
        */
        std::shared_ptr<IAITensor> CreateTensor(const std::string &type);

        /*
        * @brief  Create a tensor by using parameters and an existing buffer. The content is obtained from buffer deserialization.
        * @param [in] tensor_desc  Description that contains the tensor parameters
        * @param [in] buffer  Existing data buffer
        * @param [in] size  Buffer size
        * @return shared_ptr<IAITensor>  Pointer to the created tensor. If the creation fails, a null pointer is returned.
        * NOTE: the signature differs per platform. Android builds take a const buffer;
        *       other builds take a mutable `void *` buffer. Whether the non-Android
        *       implementation writes to the buffer is not visible here; confirm before relying on it.
        */
#if defined( __ANDROID__) || defined(ANDROID)
        std::shared_ptr<IAITensor> CreateTensor(const AITensorDescription &tensor_desc,
            const void *buffer, const int32_t size);
#else
        std::shared_ptr<IAITensor> CreateTensor(const AITensorDescription &tensor_desc,
            void *buffer, int32_t size);
#endif

        /*
        * @brief  Register a tensor.
        * @param [in] tensor_desc  Tensor description
        * @param [in] create_func  Function for creating a tensor
        * @return AIStatus  Status code of the registration
        */
        AIStatus RegisterTensor(const AITensorDescription &tensor_desc,
            CREATE_TENSOR_FUN create_func);

        /*
        * @brief  Register a tensor identified only by its type string.
        * @param [in] tensor_str  Tensor type name
        * @param [in] create_func  Function for creating a tensor
        * @return AIStatus  Result of the description-based RegisterTensor overload
        */
        AIStatus RegisterTensor(const std::string &tensor_str, CREATE_TENSOR_FUN create_func)
        {
            // Build a description carrying only the type and delegate to the full overload.
            // NOTE(review): this wrapper itself takes no lock; presumably the delegated
            // overload guards the registration maps with tensor_reg_lock_. Confirm.
            AITensorDescription tensor_desc;
            tensor_desc.set_type(tensor_str);
            return RegisterTensor(tensor_desc, create_func);
        }

        /*
        * @brief  Unregister a tensor.
        * @param [in] tensor_desc  Tensor description
        * @return AIStatus  Status code of the unregistration
        */
        AIStatus UnRegisterTensor(const AITensorDescription &tensor_desc);

        /*
        * @brief  Unregister a tensor identified only by its type string.
        * @param [in] tensor_str  Tensor type name
        * @return AIStatus  Result of the description-based UnRegisterTensor overload
        */
        AIStatus UnRegisterTensor(const std::string &tensor_str)
        {
            // Build a description carrying only the type and delegate to the full overload.
            AITensorDescription tensor_desc;
            tensor_desc.set_type(tensor_str);
            return UnRegisterTensor(tensor_desc);
        }

        /*
        * @brief  Obtain the list of all tensors.
        * @param [out] tensor_desc_list  Description list of output tensors
        */
        void GetAllTensorDescription(AITensorDescriptionList &tensor_desc_list);

        /*
        * @brief  Obtain the tensor description.
        * @param [in] tensor_type_name  Type name under which the tensor was registered
        * @param [out] tensor_desc  Receives the matching tensor description
        * @return AIStatus  Status code of the lookup
        */
        AIStatus GetTensorDescription(const std::string& tensor_type_name,
            AITensorDescription &tensor_desc);

    private:
        // Registered tensor descriptions; presumably keyed by tensor type name (confirm).
        std::map<std::string, AITensorDescription> tensor_desc_map_;
        // Registered creation functions; presumably keyed by tensor type name (confirm).
        std::map<std::string, CREATE_TENSOR_FUN> create_func_map_;
        // Mutex presumably guarding the two maps above during (un)registration (confirm in implementation).
        std::mutex tensor_reg_lock_;
    };