注册KVCache
功能介绍
将推理过程中使用的KVCache内存布局一次性注册到EMS上下文缓存(Context Caching)中,用于后续save/load/async_save/async_load接口按块(block)进行定位与管理。注册仅建立元数据与地址映射,不执行实际读写。
接口约束
- 接口为同步调用,完成初始化后返回。
- KVCache期望形状应为:[layers, k_v_index, GPU_blocks, Block_size, heads, head_size]。
- Block_size将作为后续slot_mapping解析的基准,注册成功后内部记录block_size。
- 仅支持华为昇腾NPU显存中的torch.Tensor;CPU张量或其他设备张量不支持。
方法定义
ContextCaching.register_kvcache(kvcache)
请求参数说明
参数名称 |
参数类型 |
是否必选 |
描述 |
---|---|---|---|
kvcache |
List[List[torch.Tensor]] |
是 |
参数解释: 每层的KVCache结构,二维列表; 外层维度为layers,内层维度为k_v_index(为2,对应 K/V)。 最内层张量形状必须为[GPU_blocks, Block_size, heads, head_size]。 约束限制: 不能为空;必须全部位于NPU设备。 取值范围: 无。 默认取值: 无。 |
返回结果说明
类型 |
说明 |
---|---|
None |
参数解释: 无返回值。注册成功即表示KVCache结构就绪,可进行后续save/load/async_save/async_load。 |
代码样例
下面为向context cache注册一个KVCache结构,同时对异常进行容错处理。
import os, torch, torch_npu from ems import Ems, EmsConfig, EmsException, CcConfig, CcKvOption, KvBufferWrapper # 初始化cc配置 cc_config = CcConfig(rank_id=8, device_id=0, model_id="llama2-13b") # 初始化Ems config = EmsConfig(cc_config=cc_config) try: Ems.init(config) except EmsException as e: print(f"exception: {e}.") exit(1) # 获取context caching对象 cc = Ems.get_cc() if cc is None: print("cc is None.") exit(1) # 设定形状 # 期望整体形状: [layers, k_v_index, GPU_blocks, Block_size, heads, head_size] layers = 2 k_v_index = 2 GPU_blocks = 2 Block_size = 4 head = 2 head_size = 4 # 设置device和dtype device = "npu:1" dtype = torch.float16 # 构造最内层张量 (每个都是 [GPU_blocks, Block_size, heads, head_size]) k_t0 = torch.zeros((GPU_blocks, Block_size, heads, head_size), device=device, dtype=dtype) v_t0 = torch.zeros((GPU_blocks, Block_size, heads, head_size), device=device, dtype=dtype) k_t1 = torch.zeros((GPU_blocks, Block_size, heads, head_size), device=device, dtype=dtype) v_t2 = torch.zeros((GPU_blocks, Block_size, heads, head_size), device=device, dtype=dtype) # 组织成二维列表 List[List[Tensor]]:外层按 layer,内层按 K/V # kvcache[l][0]->第l层的K,kvcache[l][1]-> 第l层的V kvcache = [ [k_t0, v_t0], # 第 0 层 [k_t1, v_t1], # 第 1 层 ] # 注册到Context Caching try: context_caching.register_kvcache(kvcache) print("register_kvcache: success") except EmsException as e: print(f"register_kvcache failed: {e}")