更新时间:2025-09-08 GMT+08:00
分享

注册KVCache

功能介绍

将推理过程中使用的KVCache内存布局一次性注册到EMS上下文缓存(Context Caching)中,用于后续save/load/async_save/async_load接口按块(block)进行定位与管理。注册仅建立元数据与地址映射,不执行实际读写。

接口约束

  • 接口为同步调用,完成初始化后返回。
  • KVCache期望形状应为:[layers, k_v_index, GPU_blocks, Block_size, heads, head_size]。
  • Block_size将作为后续slot_mapping解析的基准,注册成功后内部记录block_size。
  • 仅支持华为昇腾NPU显存中的torch.Tensor;CPU张量或其他设备张量不支持。

方法定义

ContextCaching.register_kvcache(kvcache)

请求参数说明

表1 请求参数列表

参数名称

参数类型

是否必选

描述

kvcache

List[List[torch.Tensor]]

参数解释:

每层的KVCache结构,二维列表;

外层维度为layers,内层维度为k_v_index(为2,对应 K/V)。

最内层张量形状必须为[GPU_blocks, Block_size, heads, head_size]。

约束限制:

不能为空;必须全部位于NPU设备。

取值范围:

无。

默认取值:

无。

返回结果说明

表2 返回结果

类型

说明

None

参数解释:

无返回值。注册成功即表示KVCache结构就绪,可进行后续save/load/async_save/async_load。

代码样例

下面为向context cache注册一个KVCache结构,同时对异常进行容错处理。

import os, torch, torch_npu 
from ems import Ems, EmsConfig, EmsException, CcConfig, CcKvOption, KvBufferWrapper 
# 初始化cc配置 
cc_config = CcConfig(rank_id=8, device_id=0, model_id="llama2-13b") 
# 初始化Ems 
config = EmsConfig(cc_config=cc_config) 
try: 
    Ems.init(config) 
except EmsException as e: 
    print(f"exception: {e}.") 
    exit(1) 
# 获取context caching对象
cc = Ems.get_cc() 
if cc is None: 
    print("cc is None.") 
    exit(1) 
# 设定形状  
# 期望整体形状: [layers, k_v_index, GPU_blocks, Block_size, heads, head_size]
layers = 2 
k_v_index = 2 
GPU_blocks = 2
Block_size = 4
head = 2
head_size = 4
# 设置device和dtype
device = "npu:1"
dtype = torch.float16
# 构造最内层张量 (每个都是 [GPU_blocks, Block_size, heads, head_size])
k_t0 = torch.zeros((GPU_blocks, Block_size, heads, head_size), device=device, dtype=dtype)
v_t0 = torch.zeros((GPU_blocks, Block_size, heads, head_size), device=device, dtype=dtype)
k_t1 = torch.zeros((GPU_blocks, Block_size, heads, head_size), device=device, dtype=dtype)
v_t2 = torch.zeros((GPU_blocks, Block_size, heads, head_size), device=device, dtype=dtype)
# 组织成二维列表 List[List[Tensor]]:外层按 layer,内层按 K/V
# kvcache[l][0]->第l层的K,kvcache[l][1]-> 第l层的V
kvcache = [
    [k_t0, v_t0],   # 第 0 层
    [k_t1, v_t1],   # 第 1 层
]
# 注册到Context Caching
try:
     context_caching.register_kvcache(kvcache) 
     print("register_kvcache: success") 
except EmsException as e:
     print(f"register_kvcache failed: {e}")

相关文档