注册KVCache
功能介绍
将推理过程中使用的KVCache内存布局一次性注册到EMS上下文缓存(Context Caching)中,用于后续save/load/async_save/async_load接口按块(block)进行定位与管理。注册仅建立元数据与地址映射,不执行实际读写。
接口约束
- 接口为同步调用,完成初始化后返回。
- KVCache期望形状应为:[layers, k_v_index, GPU_blocks, Block_size, heads, head_size]。
- Block_size将作为后续slot_mapping解析的基准,注册成功后内部记录block_size。
- 仅支持华为昇腾NPU显存中的torch.Tensor;CPU张量或其他设备张量不支持。
方法定义
ContextCaching.register_kvcache(kvcache)
请求参数说明
|
参数名称 |
参数类型 |
是否必选 |
描述 |
|---|---|---|---|
|
kvcache |
List[List[torch.Tensor]] |
是 |
参数解释: 每层的KVCache结构,二维列表; 外层维度为layers,内层维度为k_v_index(为2,对应 K/V)。 最内层张量形状必须为[GPU_blocks, Block_size, heads, head_size]。 约束限制: 不能为空;必须全部位于NPU设备。 取值范围: 无。 默认取值: 无。 |
返回结果说明
|
类型 |
说明 |
|---|---|
|
None |
参数解释: 无返回值。注册成功即表示KVCache结构就绪,可进行后续save/load/async_save/async_load。 |
代码样例
下面为向context cache注册一个KVCache结构,同时对异常进行容错处理。
import os, torch, torch_npu
from ems import Ems, EmsConfig, EmsException, CcConfig, CcKvOption, KvBufferWrapper
# 初始化cc配置
cc_config = CcConfig(rank_id=8, device_id=0, model_id="llama2-13b")
# 初始化Ems
config = EmsConfig(cc_config=cc_config)
try:
Ems.init(config)
except EmsException as e:
print(f"exception: {e}.")
exit(1)
# 获取context caching对象
cc = Ems.get_cc()
if cc is None:
print("cc is None.")
exit(1)
# 设定形状
# 期望整体形状: [layers, k_v_index, GPU_blocks, Block_size, heads, head_size]
layers = 2
k_v_index = 2
GPU_blocks = 2
Block_size = 4
head = 2
head_size = 4
# 设置device和dtype
device = "npu:1"
dtype = torch.float16
# 构造最内层张量 (每个都是 [GPU_blocks, Block_size, heads, head_size])
k_t0 = torch.zeros((GPU_blocks, Block_size, heads, head_size), device=device, dtype=dtype)
v_t0 = torch.zeros((GPU_blocks, Block_size, heads, head_size), device=device, dtype=dtype)
k_t1 = torch.zeros((GPU_blocks, Block_size, heads, head_size), device=device, dtype=dtype)
v_t2 = torch.zeros((GPU_blocks, Block_size, heads, head_size), device=device, dtype=dtype)
# 组织成二维列表 List[List[Tensor]]:外层按 layer,内层按 K/V
# kvcache[l][0]->第l层的K,kvcache[l][1]-> 第l层的V
kvcache = [
[k_t0, v_t0], # 第 0 层
[k_t1, v_t1], # 第 1 层
]
# 注册到Context Caching
try:
context_caching.register_kvcache(kvcache)
print("register_kvcache: success")
except EmsException as e:
print(f"register_kvcache failed: {e}")