更新时间:2024-11-21 GMT+08:00
分享

使用Tool Retriever优化Agent性能(Python SDK)

Agent在实际生产应用中往往涉及到的工具数量较多,如果把所用的工具全部添加至Agent会产生如下问题:

  • 占用大量输入token。
  • 和问题无关的工具太多,影响模型的判断。

通过Tool Retriever可以解决上述问题,其原理是在Agent运行前,先从所有可用的工具中选择与问题最相关的工具,再交给Agent去处理。

  • 定义一个Tool Retriever:
    from pangukitsappdev.tool.in_memory_tool_provider import InMemoryToolProvider
    from pangukitsappdev.retriever.css_tool_retriever import CSSToolRetriever
    
    # 新增InMemoryToolProvider,添加工具集
    in_memory_tool_provider = InMemoryToolProvider()
    tool_list = [AddTool(), ReverseTool(), ReserveMeeting(), ReserveMeetingRoom()]
    in_memory_tool_provider.add(tool_list)
    
    # 初始化CSSToolRetriever
    vector_config = VectorStoreConfig(index_name="your_index_name",
                                              verify_certs=False,
                                              text_key="name",
                                              vector_fields=["name", "description"])
    css_tool_retriever = CSSToolRetriever(tool_provider, vector_config)

    定义一个ToolRetriever包含2个参数,一个ToolProvider,一个向量数据库配置。其中,ToolProvider的作用为根据工具检索的结果组装工具。

    上述例子使用了一个简单的InMemoryToolProvider,InMemoryToolProvider的原理为将完整的工具存入内存,再根据工具检索的结果(tool_id)将其从内存中取出。一般来说,ToolProvider将由用户自定义,后续会有例子说明。

    上述例子使用的向量数据库配置指定索引名称,以及使用name和description作为向量化字段,因此工具入库时,会将工具的name和description进行向量化,并在后续的检索中生效。

    注意,上述tool_list中包含的工具在SDK中并不存在,需要替换成实际的工具。

  • 向ToolRetriever中添加工具:
    # 添加工具
    css_tool_retriever.add_tools(tool_list)

    工具添加后,会存储在向量库的索引中,并将指定的字段向量化。

  • 从ToolRetriever中查找工具:
    # 查找工具
    result = css_tool_retriever.search("预订会议室", 2)

    返回的result中,包含与预订会议室最相关的工具。搜索支持topK和阈值2个参数,例如上例指定topK=2,则最多返回2个工具。

  • 从ToolRetriever中删除工具:
    # 删除工具
    css_tool_retriever.remove(["add", "reverse"])

以上为一个比较基础的用法,在实际使用过程中会有更加灵活的场景,可以通过自定义ToolProvider的方式解决。

  • 自定义ToolProvider:
    # 初始化CSSToolRetriever,使用ToolProviderWithMetadata作为ToolProvider
    vector_config = VectorStoreConfig(index_name="your_index_name",
                                              verify_certs=False,
                                              text_key="name",
                                              vector_fields=["name", "description", "principle"])
    css_tool_retriever = CSSToolRetriever(ToolProviderWithMetadata(), vector_config)
    其中,ToolProviderWithMetadata为自定义ToolProvider:
    import pickle
    
    class ToolProviderWithMetadata(ToolProvider):
    
        def provide(self, retrieved_tools: List[RetrievedTool], query: str) -> List[AbstractTool]:
            retrieved_tools = self.do_some_filter(retrieved_tools, query)
            return [Tool.from_function(func=pickle.loads(eval(retrieved_tool.tool_metadata.get("function"))),
                                       name=retrieved_tool.tool_id,
                                       description=retrieved_tool.tool_metadata.get("description"),
                                       principle=retrieved_tool.tool_metadata.get("principle"),
                                       input_desc=retrieved_tool.tool_metadata.get("input_desc"),
                                       output_desc=retrieved_tool.tool_metadata.get("output_desc"),
                                       args_schema=pickle.loads(eval(retrieved_tool.tool_metadata.get("args_schema"))),
                                       return_type=pickle.loads(eval(retrieved_tool.tool_metadata.get("return_type"))))
                    for retrieved_tool in retrieved_tools]
    
        @staticmethod
        def do_some_filter(retrieved_tools: List[RetrievedTool], query: str) -> List[RetrievedTool]:
            print(f"{retrieved_tools}, {query}")
            return retrieved_tools

    上述tool_provider中,实现了provide接口,可以利用工具检索的返回动态构建出工具列表,同时也可以加一些后处理工作,例如根据黑白名单做工具的过滤。

  • 与上述的tool_provide呼应,在向tool_retriever中添加工具时,可以添加任意的元数据,python需要借助pickle将函数或类转换成字节流字符串存入CSS中,用于在tool_provider中把工具组装出来:
    from pydantic import BaseModel, Field
    import pickle
    
    # 构造工具元数据
    class MeetingInfo(BaseModel):
        id: str = Field(description="会议ID")
        info: str = Field(description="会议信息")
    
    def list_meeting(inputs: NoneType) -> List[MeetingInfo]:
        return [MeetingInfo(id=1, info="金桥2023"), MeetingInfo(id=2, info="金桥203")]
    
    tool_meta_data = {
        "name": "list_meeting",
        "description": "查询员工的会议预订状态,返回已经预订的会议和其会议ID",
        "principle":"请在需要查询员工已预订会议室列表时使用",
        "input_desc": "",
        "output_desc": "已预订会议室列表",
        "args_schema": str(pickle.dumps(None)),
        "function": str(pickle.dumps(list_meeting)),
        "return_type": str(pickle.dumps(MeetingInfo))
    }
    # 工具管理面添加工具到toolRetriever,这里实际可以添加若干个工具
    css_tool_retriever.add_tools_from_metadata([tool_meta_data])
    
    # 运行时检索工具,并添加到Agent执行
    tool_list = css_tool_retriever.search("查询会议室预订状态", 1, 0.8)

    工具的检索与之前的用法一致。

  • 以下是一个将Tool Retriever集成在Agent中的完整示例:
    from pangukitsappdev.skill.conversation_rewrite_skill import ConversationRewriteSkill
    
    # 工具集
    toolList = [AddTool(), ReverseTool(), ReserveMeeting(), ReserveMeetingRoom()]
    
    # 新增InMemoryToolProvider,添加工具集
    in_memory_tool_provider = InMemoryToolProvider()
    tool_list = [AddTool(), ReverseTool(), ReserveMeeting(), ReserveMeetingRoom()]
    in_memory_tool_provider.add(tool_list)
    
    # 初始化CSSToolRetriever
    vector_config = VectorStoreConfig(index_name="your_index_name",
                                              verify_certs=False,
                                              text_key="name",
                                              vector_fields=["name", "description"])
    css_tool_retriever = CSSToolRetriever(tool_provider, vector_config)
    
    
    # 添加工具
    css_tool_retriever.add_tools(tool_list)
    
    # 添加多轮改写
    css_tool_retriever.set_query_preprocessor(ConversationRewriteSkill(LLMs.of("pangu")).rewrite)
    
    # 为Agent添加ToolRetriever
    agent = ReactPanguAgent(LLMs.of("yundao"))
    agent.set_tool_retriever(css_tool_retriever);
    
    # 多轮对话调用
    messages = [ConversationMessage(role=Role.USER, content="定个2点的会议"),
                ConversationMessage(role=ROLE.ASSISTANT, content="请问您的会议预计何时结束?另外,您是需要预订线上会议还是实体会议室?"),
                ConversationMessage(role=Role.USER, content="4点结束,线上会议")]
    print(agent.run(messages))
    
    # 删除工具
    cssToolRetriever.remove([tool.name for tool in tool_list])

    其中,有两个变化值得关注,一是为ToolRetriever添加了一个query_preprocessor,它的作用为对用户输入的多轮对话进行改写,会将改写后的结果作为工具检索的输入,这里使用了系统内置的ConversationRewriteSkill,它的作用为将多轮对话改写为单轮。二是在创建一个Agent后,调用了set_tool_retriever方法为其添加了一个ToolRetriever,这样Agent所使用的工具会根据用户的对话动态的选择。

相关文档