Updated: 2026-06-17 GMT+08:00

EMS High-Performance KV Cache Practice for Distributed Inference on CCE

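This document describes how to deploy the Kimi-K2.5 model on a CCE cluster using the vLLM-ascend inference engine and the Kthena framework. The deployment uses a 1P1D (1 Prefill + 1 Decode) disaggregated architecture: the Prefill and Decode stages run on physically separate instances, and Elastic Memory Service (EMS) transfers the KV Cache between them efficiently, improving both resource utilization and inference performance.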

Background

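Kthena is a cloud-native, high-performance system for routing, orchestrating, and scheduling large language model (LLM) inference, designed specifically for Kubernetes. A sub-project of Volcano, the cloud-native batch computing project, Kthena aims to reduce the complexity enterprises face when deploying, managing, and running LLM inference services at scale in production. Its core capabilities include:

  • Declarative orchestration: Users define, deploy, and scale complex LLM inference clusters from a single configuration file. An inference workload is split into roles, each assigned to different pods, and the cooperating components are orchestrated as a unit.
  • Intelligent routing: Unlike a conventional load balancer, Kthena Router is aware of the state and load of the underlying inference engines and routes requests accordingly.
  • Elastic scaling: With Kthena Autoscaler, the service can be scaled at a fine granularity without interruption (for example, by adjusting the replica count of each role under high concurrency), and rolling updates are supported.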
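To illustrate what being aware of engine state and load can mean in practice, the following is a minimal Python sketch of load-aware replica selection. It is not Kthena Router's algorithm: the `EngineStats` fields, the scoring formula, and the `kv_weight` parameter are all hypothetical, chosen only to contrast with a load balancer that ignores engine state.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class EngineStats:
    """Hypothetical per-replica metrics that an engine-aware router might collect."""
    name: str
    waiting_requests: int   # requests queued in the engine's scheduler
    kv_cache_usage: float   # fraction of KV cache blocks in use, 0.0 to 1.0


def pick_replica(replicas: List[EngineStats], kv_weight: float = 10.0) -> str:
    """Return the name of the replica with the lowest combined load score.

    The score (queue length plus weighted KV cache usage) is an illustrative
    heuristic only, not the scoring used by Kthena Router.
    """
    if not replicas:
        raise ValueError("no replicas available")
    best = min(replicas, key=lambda r: r.waiting_requests + kv_weight * r.kv_cache_usage)
    return best.name


replicas = [
    EngineStats("decode-0", waiting_requests=2, kv_cache_usage=0.9),  # score 2 + 9 = 11
    EngineStats("decode-1", waiting_requests=4, kv_cache_usage=0.1),  # score 4 + 1 = 5
]
chosen = pick_replica(replicas)  # "decode-1"
```

With these inputs a round-robin balancer would still send half the traffic to `decode-0`, whose KV cache is almost full; the load-aware score steers requests to `decode-1` instead.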

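While Kthena handles cluster-level orchestration and request routing on the control plane, the physical separation of the Prefill and Decode stages means that large volumes of intermediate state (the KV Cache) must move efficiently between nodes on the data plane. This is the role Huawei Cloud Elastic Memory Service (EMS) plays in this solution: it provides a high-performance, low-latency, memory-level cache channel for transferring the KV Cache across nodes.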

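With the rapid growth of LLM applications, online LLM inference has become the main consumer of AI compute. Inference workloads such as multi-turn dialog, information retrieval, code assistance, and text generation combine high concurrency with strict latency requirements. When LLMs are deployed on a conventional architecture, however, they run into the "memory wall" imposed by limited device memory: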

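  • Wasted compute and limited throughput

    During Transformer inference, intermediate data (the KV Cache) is kept in the AI server's device memory to avoid recomputation. Device memory per server is limited (typically a few tens of GB), and little of it remains after the model parameters are loaded, so each node can serve only a small number of concurrent requests. Meeting high concurrency simply by adding more compute drives inference costs up sharply.

  • Loss of long context and poor coherence

    Because device memory is limited, the system must frequently discard the KV Cache of earlier turns to free space. As a result, assistants easily "forget" earlier content in long conversations, which breaks context continuity and degrades the user experience.

  • High time to first token (TTFT)

    When a discarded conversation, or a request that shares a common prefix with earlier requests, becomes active again, the system must spend substantial accelerator (GPU/NPU) compute recomputing its KV Cache. This recomputation raises inference costs and lengthens the wait between input and first output, slowing responses.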
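The memory pressure described above can be estimated with simple arithmetic. The sketch below assumes a standard multi-head or grouped-query attention decoder, where each token stores one K and one V vector per layer per KV head; the model dimensions and the 20 GiB of free device memory are hypothetical values, not figures for any specific model or NPU. For MLA-based models (the adapter code later in this document handles that case through `use_mla`), the per-token footprint is computed differently and is much smaller.

```python
def kv_cache_bytes_per_token(num_layers: int, num_kv_heads: int,
                             head_dim: int, dtype_bytes: int) -> int:
    """Bytes of KV cache per token for a standard MHA/GQA decoder.

    The leading factor 2 accounts for storing both K and V.
    """
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes


def max_concurrent_sequences(free_bytes: int, seq_len: int, bytes_per_token: int) -> int:
    """Number of full-length sequences whose KV cache fits in the remaining memory."""
    return free_bytes // (seq_len * bytes_per_token)


# Hypothetical model: 32 layers, 8 KV heads, head_dim 128, FP16 (2 bytes per element)
per_token = kv_cache_bytes_per_token(32, 8, 128, 2)
print(per_token // 1024, "KiB per token")  # prints: 128 KiB per token

free = 20 * 1024**3  # assume 20 GiB of device memory left after loading weights
print(max_concurrent_sequences(free, 4096, per_token), "sequences of 4096 tokens")
# prints: 40 sequences of 4096 tokens
```

Under these assumptions, each 4096-token sequence occupies 512 MiB of KV cache, so only about 40 such sequences fit at once; every discarded sequence must later be recomputed, which is the cost EMS aims to avoid.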

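To address these problems, Huawei Cloud provides Elastic Memory Service (EMS). EMS is a distributed memory-pooling technology that uses DRAM as its primary storage medium and follows a "storage-for-compute" approach: it caches and accelerates LLM inference, lowering the compute cost of inference while substantially improving overall throughput.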

Features

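The EMS architecture consists of three parts: domain-specific service SDKs, a distributed memory pool, and a management plane.

  • Domain-specific service SDKs: A set of plugins and interface SDKs for different AI application scenarios. They provide business-system integration, data layout, and near-data processing, using memory to accelerate business requests. Currently, the SDKs are mainly used for LLM inference, where the distributed memory pool improves efficiency and reduces cost.
  • Distributed memory pool: Manages memory across nodes and balances data load, providing shared cache access through a pooled memory space. The pool currently uses converged deployment: the DRAM on the AI servers is pooled for distributed sharing, with locality-aware scheduling and access.
  • EMS management plane: Handles deployment, monitoring, upgrades, and O&M of the EMS service, providing a one-stop cloud O&M solution built on Huawei Cloud's cloud-native infrastructure.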
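To make the idea of extending device memory with a DRAM pool concrete, here is a toy, single-process Python sketch of a two-tier KV block cache. It is an illustration under stated assumptions, not the EMS design: the class name, the LRU eviction policy, and the capacities are hypothetical, and the real EMS pool is distributed across nodes.

```python
from collections import OrderedDict


class TieredKVCache:
    """Toy two-tier KV block cache: a small device tier backed by a larger DRAM pool.

    Blocks evicted from the device tier (least recently used first) are demoted to
    the pool instead of being discarded; a pool hit promotes the block back.
    """

    def __init__(self, device_capacity: int, pool_capacity: int):
        self.device: OrderedDict = OrderedDict()  # block_hash -> KV data, LRU order
        self.pool: OrderedDict = OrderedDict()    # block_hash -> KV data, LRU order
        self.device_capacity = device_capacity
        self.pool_capacity = pool_capacity

    def put(self, block_hash: int, data: bytes) -> None:
        """Insert a block into the device tier, demoting the LRU block on overflow."""
        self.device[block_hash] = data
        self.device.move_to_end(block_hash)
        while len(self.device) > self.device_capacity:
            old_hash, old_data = self.device.popitem(last=False)
            self._demote(old_hash, old_data)

    def _demote(self, block_hash: int, data: bytes) -> None:
        self.pool[block_hash] = data
        self.pool.move_to_end(block_hash)
        while len(self.pool) > self.pool_capacity:
            self.pool.popitem(last=False)  # dropped from both tiers: must be recomputed

    def lookup(self, block_hash: int) -> str:
        """Return where the block was found: 'device', 'pool', or 'miss'."""
        if block_hash in self.device:
            self.device.move_to_end(block_hash)
            return "device"
        if block_hash in self.pool:
            data = self.pool.pop(block_hash)
            self.put(block_hash, data)  # promote; this may demote another block
            return "pool"
        return "miss"


cache = TieredKVCache(device_capacity=2, pool_capacity=4)
for h in (1, 2, 3):
    cache.put(h, b"kv")  # block 1 overflows the device tier and is demoted
```

Evicted blocks are demoted rather than discarded, so a later request can reload them from the pool (`"pool"`) instead of recomputing them; only blocks that fall out of both tiers return `"miss"`.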

Key Advantages

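  • Storage-for-compute (CachedAttention): To eliminate recomputation in multi-turn dialogs and shared-prefix scenarios, EMS asynchronously stores the KV Cache produced by earlier sessions in the EMS distributed memory pool. When a session becomes active again, the system loads and reuses that KV Cache from EMS instead of recomputing it on the accelerator. This significantly reduces time to first token (TTFT) and greatly increases throughput in the Prefill stage.
  • Tiered caching beyond the device memory wall: EMS builds a three-tier cache hierarchy of device memory, DRAM, and storage. As a high-performance memory cache layer between the compute and storage layers, EMS removes the physical device-memory limit of a single AI server and effectively extends device memory. This speeds up access to the persistence layer and keeps request handling efficient under large-scale concurrency.
  • Converged deployment and lower cost: DRAM on AI servers is often underutilized. The EMS data plane uses a semi-managed converged deployment that takes over idle DRAM on the AI servers and pools it for reuse, so no expensive dedicated cache hardware is needed, and resources stay elastic and cost-effective.
  • Distributed sharing and high cache hit rate: The EMS distributed memory pool removes per-node silos and lets nodes share cached data efficiently. Combined with automatic awareness of the scheduler's task queue and affinity scheduling, this raises the global cross-node cache hit rate, meeting the demands of large-scale distributed inference.
  • Ecosystem integration: EMS is not an isolated piece of infrastructure. It works with Huawei Cloud ModelArts (AI development platform), Cloud Container Engine (CCE), SFS Turbo (high-performance elastic file service), and Object Storage Service (OBS), helping enterprises quickly build an end-to-end, high-performance, serverless LLM inference pipeline.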
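Storage-for-compute depends on recognizing that a new request shares a prefix with KV Cache saved earlier. A common way to do this, used by vLLM's automatic prefix caching, is to hash each full KV block together with the hash of the block before it, so that a block's key identifies the entire prefix up to that block. The sketch below illustrates the idea with a plain dict as the store; the block size, the SHA-256 hashing, and the function names are assumptions, not the EMS hashing scheme.

```python
import hashlib
from typing import Dict, List

BLOCK_SIZE = 4  # tokens per KV block; a tiny value for illustration only


def block_hashes(token_ids: List[int], block_size: int = BLOCK_SIZE) -> List[str]:
    """Hash each full block, chaining in the parent hash so that a block's key
    depends on every token before it. A partial trailing block is skipped."""
    hashes: List[str] = []
    parent = ""
    for start in range(0, len(token_ids) - block_size + 1, block_size):
        block = token_ids[start:start + block_size]
        payload = parent + ":" + ",".join(map(str, block))
        parent = hashlib.sha256(payload.encode()).hexdigest()
        hashes.append(parent)
    return hashes


def num_reusable_blocks(hashes: List[str], store: Dict[str, object]) -> int:
    """Count the leading blocks already present in the store.

    Reuse stops at the first miss, because later blocks depend on earlier ones.
    """
    count = 0
    for h in hashes:
        if h not in store:
            break
        count += 1
    return count


# Session 1: 10 tokens -> 2 full blocks are saved (the last 2 tokens are not)
first = block_hashes(list(range(10)))
store = {h: b"kv-block" for h in first}

# Session 2 shares the first 8 tokens, then continues with new tokens
second = block_hashes(list(range(8)) + [100, 101, 102, 103])
reused = num_reusable_blocks(second, store)  # 2
```

In this example the second request reuses its first two blocks (8 tokens), and only its last block needs Prefill computation; changing the very first token would invalidate every block after it.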

Prerequisites

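  • A CCE Turbo cluster of v1.29 or later has been created, with the NodeLocal DNSCache and Volcano Scheduler (v1.22.1 or later) add-ons installed.
  • At least five Snt9b or Snt9b23 nodes have been created in or added to the cluster. Do not mix nodes of different flavors.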
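The prerequisites above can be checked mechanically before you start. The following Python sketch validates values you collect yourself (for example, from the CCE console); the function names and the input format are hypothetical, and the sketch does not query the cluster.

```python
from typing import List, Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """'v1.29.3' -> (1, 29, 3); strips a leading 'v' and any '-suffix'."""
    core = version.lstrip("v").split("-")[0]
    return tuple(int(part) for part in core.split("."))


def check_prerequisites(cluster_version: str, volcano_version: str,
                        node_flavors: List[str]) -> List[str]:
    """Return a list of problems; an empty list means the checks passed."""
    problems = []
    if parse_version(cluster_version)[:2] < (1, 29):
        problems.append(f"cluster version {cluster_version} is older than v1.29")
    if parse_version(volcano_version) < (1, 22, 1):
        problems.append(f"Volcano add-on {volcano_version} is older than 1.22.1")
    if len(node_flavors) < 5:
        problems.append(f"only {len(node_flavors)} NPU nodes, at least 5 required")
    if len(set(node_flavors)) > 1:
        problems.append(f"mixed node flavors: {sorted(set(node_flavors))}")
    return problems


ok = check_prerequisites("v1.29.1", "1.22.1", ["Snt9b"] * 5)
bad = check_prerequisites("v1.28.5", "1.22.0", ["Snt9b"] * 4 + ["Snt9b23"])
```

Here `ok` is an empty list, while `bad` reports the outdated cluster version, the outdated Volcano add-on, and the mixed node flavors.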

Constraints

This feature is currently being rolled out. Check the console for the regions where it is available.

Procedure

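  1. Enable EMS acceleration and label the nodes.

    If the EMS settings are not available, submit a service ticket to contact customer service.

    1. Log in to the CCE console and click the target cluster. In the navigation pane, choose Settings and click the Heterogeneous Resources tab.
    2. Enable EMS Acceleration and add the label to the prepared Snt9b/Snt9b23 nodes.

    3. In the dialog box that is displayed, click Save to deploy the EMS cluster.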

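  2. (Optional) Configure security group rules.

    If the nodes use the host network, configure the security group manually to allow communication between EMS components.

    1. Go to the resource stack list, search for the resource stack by EMS cluster ID, and confirm that its deployment has succeeded.

    2. Obtain the ID of the default node security group.
      1. On the Overview page of the target cluster in the CCE console, click the link next to Default Node Security Group in the Network Information area.

      2. On the Basic Information tab, click the icon next to the ID to copy the security group ID.

    3. Configure the EMS security group rules.
      1. Click the icon above the Basic Information tab.
      2. On the Access Control > Security Groups page, search for {EMS cluster ID} and click ems-zookeeper-{EMS cluster ID}.

      3. Click Inbound Rules, and then click Add Rule.
      4. In the Add Inbound Rule panel, set Protocol & Port to 10000, select Security group from the Source drop-down list, and in the field below specify the default node security group ID copied earlier.

      5. Click Add Rule below the first rule, set Protocol & Port to 11000, select Security group from the Source drop-down list, specify the default node security group ID in the field below, and click OK.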
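After adding the rules, you may want to confirm from a node that the two ports are reachable. The following Python sketch attempts a TCP connection to each port; the target host is an assumption (use the address of the EMS component you need to reach), the function name is hypothetical, and a failed connection can also mean that nothing is listening on that port yet.

```python
import socket
from typing import Dict, Iterable

EMS_PORTS = (10000, 11000)  # ports opened by the inbound rules in this step


def check_ports(host: str, ports: Iterable[int] = EMS_PORTS,
                timeout: float = 3.0) -> Dict[int, bool]:
    """Try a TCP connection to each port; True means the handshake succeeded."""
    results: Dict[int, bool] = {}
    for port in ports:
        try:
            with socket.create_connection((host, port), timeout=timeout):
                results[port] = True
        except OSError:  # refused, timed out, unreachable, or unresolvable host
            results[port] = False
    return results
```

Run it on a node as `check_ports("<EMS_HOST>")`, substituting the real address; a result of `{10000: True, 11000: True}` indicates that both ports accept connections.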

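  3. Restart the EMS controller.

    1. In the CCE console, go to the Workloads page of the target cluster and click the Deployments tab. Select the namespace ems-{ems_cluster_id} and click the ems-controller workload.

    2. On the pod list tab, select all pods, delete them in a batch, and click Yes in the dialog box that is displayed.

    3. Repeat steps 1 and 2 for the ems-server workload.
    4. Return to the Heterogeneous Resources tab on the Settings page and wait until the components finish loading.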
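The pod restart in steps 1 to 3 can also be scripted. The following sketch uses the official Kubernetes Python client (the `kubernetes` package) and assumes a kubeconfig for the cluster is available; the function names are hypothetical, and matching pods by the `<workload>-` name prefix is a simplification that would also match another workload whose name starts the same way. Deleting the pods has the same effect as the console's batch delete: the Deployment recreates them.

```python
from typing import List


def pods_to_restart(pod_names: List[str], workload: str) -> List[str]:
    """Select a Deployment's pods by name prefix ('<workload>-<hash>-<suffix>')."""
    return [name for name in pod_names if name.startswith(workload + "-")]


def restart_workload_pods(namespace: str, workload: str) -> List[str]:
    """Delete a Deployment's pods so that its ReplicaSet recreates them."""
    # Imported lazily so the helper above works without the client installed.
    from kubernetes import client, config

    config.load_kube_config()  # uses the current kubeconfig context
    v1 = client.CoreV1Api()
    names = [pod.metadata.name for pod in v1.list_namespaced_pod(namespace).items]
    targets = pods_to_restart(names, workload)
    for name in targets:
        v1.delete_namespaced_pod(name=name, namespace=namespace)
    return targets
```

For example, call `restart_workload_pods("ems-{ems_cluster_id}", name)` for each `name` in `("ems-controller", "ems-server")`, substituting your EMS cluster ID, and then wait for the components to finish loading as described in step 4.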

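  4. Build the inference image.

    1. On the Settings page of the target cluster in the CCE console, click the Scheduling tab and set the default scheduler to Volcano Scheduler.
    2. Check the RoCE network status of the nodes. For details, see Prefill-Decode Disaggregation (Deepseek).
    3. Prepare the model locally. You can download it from ModelScope. This document uses Kimi-K2.5-W4A8 as an example; make sure the model has been downloaded to /mnt/pass/models on the nodes.
    4. Prepare the Ascend engine image.
      1. Download the OpenSSL and Cyrus SASL source packages.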
        https://github.com/openssl/openssl/releases/download/OpenSSL_1_1_1w/openssl-1.1.1w.tar.gz
        https://github.com/cyrusimap/cyrus-sasl/releases/download/cyrus-sasl-2.1.28/cyrus-sasl-2.1.28.tar.gz
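      2. Download the EMS SDK package from the EMS settings on the Settings > Heterogeneous Resources tab of the CCE cluster, and place it in the same directory as the Dockerfile.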

      3. Create the ems_store directory.
        mkdir ems_store

        In the directory, create four files: __init__.py, ems_adapter.py, ems_env.py, and ems_store_connector.py.

        • __init__.py
          vi __init__.py

          Leave the file empty: press Esc, type :wq, and press Enter to save it.

        • ems_adapter.py
          vi ems_adapter.py

          The file content is as follows:

          from typing import List, Tuple, Dict, Set
          import threading
          import time
          import os
          from collections import deque
          
          import torch
          import numpy as np
          from vllm.config import VllmConfig
          from vllm.distributed import get_tp_group, get_pp_group
          from vllm.logger import logger
          
          from ems import Ems, CcKvOption, EmsException, EmsErrorCode, EmsConfig, CcConfig_v1
          from ems.cc_v1.cc_config import KVCacheType
          
          from vllm_ascend.distributed.ems_store.ems_env import EmsEnv
          
          
          class EmsAdapter:
              def __init__(self, vllm_config: VllmConfig):
                  self.vllm_config = vllm_config
                  self.is_mla = vllm_config.model_config.use_mla
                  self.model_name = "_".join(os.path.normpath(vllm_config.model_config.model).split("/")[-2:])
                  logger.info(f"[EMS] EmsAdapter init with vllm_config: {vllm_config}, "
                              f"is_mla: {self.is_mla}, model_name: {self.model_name}")
          
                  self.rank = vllm_config.parallel_config.rank
                  self.world_size = vllm_config.parallel_config.world_size
                  self.block_size = vllm_config.cache_config.block_size
          
                  self.load_futures = {}
                  self.save_futures = {}
          
                  self.failed_block_ids: set[int] = set()
          
                  self._inited = False
                  self._need_reinit = False
                  self._registered = False
                  self._pending_kv_caches = None
          
                  self.ems_cfg = self.get_ems_cfg()
                  self.context_caching = None
                  self.init_ems()
          
                  self.task_manager = None
                  self.init_task_manager()
          
          
              def get_ems_cfg(self) -> EmsConfig:
                  tp_group = get_tp_group()
                  pp_group = get_pp_group()
                  cc_config_v1 = CcConfig_v1(
                      rank_id=tp_group.rank,
                      device_id=tp_group.local_rank,
                      model_id=EmsEnv.model_id,
                      tp_world_size=tp_group.world_size,
                      pp_world_size=pp_group.world_size,
                      rank_in_tp_group=tp_group.rank_in_group,
                      rank_in_pp_group=pp_group.rank_in_group,
                      llm_engine=f"{EmsEnv.llm_engine}@{self.model_name}"
                  )
                  if self.is_mla:
                      cc_config_v1.kvcache_type = KVCacheType.MLA
          
                  ems_cfg = EmsConfig(
                      access_id=EmsEnv.access_id,
                      access_key=EmsEnv.access_key,
                      cc_config_v1=cc_config_v1
                  )
                  return ems_cfg
          
          
              def init_ems(self) -> None:
                  try:
                      Ems.init(self.ems_cfg)
                      self.context_caching = Ems.get_cc()
                      self._inited = True
                      logger.info(f"[EMS][Init] EmsConnector init succeed, EMS ready.")
                  except EmsException as e:
                      if e.status_code() == EmsErrorCode.EMS_RECOVERD_ERROR:
                          self._need_reinit = True
                      logger.error(f"[EMS][Init] EmsConnector init fail, error: {e}.")
                  if not self._inited and self.context_caching:
                      logger.warning(f"[EMS][Init][degraded] EmsConnector init failed but get cc, reason=EMS not ready.")
          
          
              def init_task_manager(self) -> None:
                  try:
                      self.task_manager = PeriodicTaskManager(check_fn=self._check_health)
                      logger.info(f"[EMS] init periodic task manager succeeded.")
                  except Exception as e:
                      logger.error(f"[EMS] init periodic task manager failed, error: {e}.")
                      raise
          
          
              def register_kv_caches(self, kv_caches: dict[str, Tuple[torch.Tensor]]) -> None:
                  if not self._inited:
                      self._pending_kv_caches = kv_caches
                      logger.warning("[EMS][Register] EMS not ready, skip register_kvcache.")
                      return
          
                  kv_caches_list: List[List[torch.Tensor]] = []
                  for _, kv_caches_layer in kv_caches.items():
                      num_kv = len(kv_caches_layer)
                      kv_tensor_list: List[torch.Tensor] = [kv_caches_layer[idx] for idx in range(num_kv)]
                      kv_caches_list.append(kv_tensor_list)
          
                  try:
                      self.context_caching.register_kvcache(kv_caches_list)
                      self._registered = True
                      self._pending_kv_caches = None
          
                      # shape of kv_caches_list: [layers, [k_v_index, [gpu_blocks, block_size, heads, head_size]]]
                      ds_layers = len(kv_caches_list)
                      ds_k_v_index = len(kv_caches_list[0])
                      ds_gpu_blocks, ds_block_size, ds_heads, ds_head_size = list(kv_caches_list[0][0].shape)
                      logger.info(f"[EMS][Register] layer_num={ds_layers}, block_size={ds_block_size}, "
                                  f"is_mla={self.is_mla}, kvcache_dim={ds_k_v_index}")
          
                  except EmsException as e:
                      logger.error(f"[EMS][Register] register kvcache error: {e}.")
          
          
              def _cal_block_offsets(self, block_ids: List[int], block_size: int) -> List[int]:
                  return [block_size] * len(block_ids)
          
          
              def _cal_block_slot_mapping(self, block_ids: List[int], block_size: int) -> List[int]:
                  block_ids_np = np.array(block_ids)
                  range_arr = np.arange(block_size)
                  result_matrix = block_ids_np[:, np.newaxis] * block_size + range_arr
                  flattened_result = result_matrix.ravel()
                  return flattened_result.tolist()
          
          
              def exists_block_num(self, req_id: str, block_hashes: list[int]) -> int:
          
                  if not self._inited or not self._registered:
                      logger.warning(f"[EMS][exists][degraded] req_id={req_id} reason=EMS not ready.")
                      return 0
          
                  if not self.task_manager.get_status():
                      logger.warning(f"[EMS][exists][skip] req_id={req_id} reason=Unhealthy EMS")
                      return 0
          
                  option = CcKvOption(
                      write_rcache=EmsEnv.ems_enable_write_rcache,
                      read_local_only=EmsEnv.ems_enable_read_local_only,
                      timeout=EmsEnv.ems_timeout
                  )
                  try:
                      cc_result = self.context_caching.exists(hashes=block_hashes, option=option)
                      return cc_result.success
                  except EmsException as e:
                      logger.debug(f"[EMS][Exist] fallback to len(hashes) due to: {e}")
                      return len(block_hashes)
          
              def async_load(self, req_id: str, block_hashes: List[int], block_ids: List[int], num_computed_blocks: int) -> None:
                  future = None
                  submit_time = time.perf_counter()
          
                  if not self._check_params(req_id, block_hashes, block_ids, "AsyncLoad"):
                      self.load_futures[req_id] = (future, submit_time, num_computed_blocks, block_ids)
                      return
          
                  option = CcKvOption(
                      write_rcache=EmsEnv.ems_enable_write_rcache,
                      read_local_only=EmsEnv.ems_enable_read_local_only,
                      timeout=EmsEnv.ems_timeout
                  )
                  offsets = self._cal_block_offsets(block_ids, self.block_size)
                  slot_mapping = self._cal_block_slot_mapping(block_ids, self.block_size)
          
                  try:
                      future = self.context_caching.async_load(
                          slot_mapping=slot_mapping,
                          hashes=block_hashes,
                          offsets=offsets,
                          option=option
                      )
                      logger.info(f"[EMS][AsyncLoad][start] req_id={req_id} timeout_ms={EmsEnv.ems_timeout} "
                                  f"planned_blocks={len(block_ids)} first_block_hash={block_hashes[0]}, future=submitted")
                  except EmsException as e:
                      self._process_exception(e)
                      logger.error(f"[EMS][AsyncLoad][error] req_id={req_id} timeout_ms={EmsEnv.ems_timeout} ..., err={e}")
                  self.load_futures[req_id] = (future, submit_time, num_computed_blocks, block_ids)
          
          
              def async_save(self, req_id: str, block_hashes: List[int], block_ids: List[int]) -> None:
                  future = None
                  submit_time = time.perf_counter()
          
                  if not self._check_params(req_id, block_hashes, block_ids, "AsyncSave"):
                      self.save_futures[req_id] = (future, submit_time)
                      return
          
                  option = CcKvOption(
                      write_rcache=EmsEnv.ems_enable_write_rcache,
                      read_local_only=EmsEnv.ems_enable_read_local_only,
                      timeout=EmsEnv.ems_timeout
                  )
                  offsets = self._cal_block_offsets(block_ids, self.block_size)
                  slot_mapping = self._cal_block_slot_mapping(block_ids, self.block_size)
          
                  try:
                      future = self.context_caching.async_save(
                          slot_mapping=slot_mapping,
                          hashes=block_hashes,
                          offsets=offsets,
                          option=option
                      )
                      logger.info(f"[EMS][AsyncSave][start] req_id={req_id} timeout_ms={EmsEnv.ems_timeout} "
                                  f"planned_blocks={len(block_ids)} first_block_hash={block_hashes[0]}, future=submitted")
                  except EmsException as e:
                      self._process_exception(e)
                      logger.error(f"[EMS][AsyncSave][error] req_id={req_id} timeout_ms={EmsEnv.ems_timeout} ..., err={e}")
                  self.save_futures[req_id] = (future, submit_time)
          
          
              def get_finished_load_reqs(self) -> Set[str]:
                  finished_reqs = set()
          
                  for req_id, (future, submit_time, num_computed_blocks, block_ids) in list(self.load_futures.items()):
                      if future and not self.context_caching.is_ready(future):
                          continue
          
                      try:
                          if future:
                              result = self.context_caching.get_result(future)
                              cost_ms = 1e3 * (time.perf_counter() - submit_time)
                              logger.info(f"[EMS][GetResult] Req {req_id} async load done, success_blocks={result.success}, "
                                          f"total_blocks_num={result.total}, status=SUCCESS, cost_ms={cost_ms:.2f}")
                              self.task_manager.update_req_stat("LOAD", result.success, cost_ms)
                              self.task_manager.update_hit_stat(result.success, result.total)
                              if result.success < len(block_ids):
                                  failed_part = block_ids[result.success:]
                                  self.failed_block_ids.update(failed_part)
                          else:
                              cost_ms = 1e3 * (time.perf_counter() - submit_time)
                              logger.info(f"[EMS][GetResult] Req {req_id} async load done, success_blocks={0}, "
                                          f"total_blocks_num={0}, status=EMS_INTERNAL_ERROR, cost_ms={cost_ms:.2f}")
                              self.task_manager.update_req_stat("LOAD", 0, cost_ms)
                              self.failed_block_ids.update(block_ids)
                      except EmsException as e:
                          cost_ms = 1e3 * (time.perf_counter() - submit_time)
                          self.failed_block_ids.update(block_ids)
                          self._process_exception(e)
                          logger.info(f"[EMS][GetResult] Req {req_id} async load done, success_blocks={0}, "
                                      f"total_blocks_num={0}, status={e.status_code().name}, cost_ms={cost_ms:.2f}")
                      finished_reqs.add(req_id)
                      self.load_futures.pop(req_id)
          
                  return finished_reqs
          
          
              def get_block_ids_with_load_errors(self) -> Set[int]:
                  failed_block_ids = self.failed_block_ids
                  self.failed_block_ids = set()
          
                  return failed_block_ids
          
          
              def get_finished_save_reqs(self) -> Set[str]:
                  finished_reqs = set()
          
                  for req_id, (future, submit_time) in list(self.save_futures.items()):
                      if future and not self.context_caching.is_ready(future):
                          continue
          
                      try:
                          if future:
                              result = self.context_caching.get_result(future)
                              cost_ms = 1e3 * (time.perf_counter() - submit_time)
                              logger.info(f"[EMS][GetResult] Req {req_id} async save done, success_blocks={result.success}, "
                                          f"total_blocks_num={result.total}, status=SUCCESS, cost_ms={cost_ms:.2f}")
                              self.task_manager.update_req_stat("SAVE", result.success, cost_ms)
                          else:
                              cost_ms = 1e3 * (time.perf_counter() - submit_time)
                              logger.info(f"[EMS][GetResult] Req {req_id} async save done, success_blocks={0}, "
                                          f"total_blocks_num={0}, status=EMS_INTERNAL_ERROR, cost_ms={cost_ms:.2f}")
                              self.task_manager.update_req_stat("SAVE", 0, cost_ms)
                      except EmsException as e:
                          cost_ms = 1e3 * (time.perf_counter() - submit_time)
                          self._process_exception(e)
                          logger.info(f"[EMS][GetResult] Req {req_id} async save done, success_blocks={0}, "
                                      f"total_blocks_num={0}, status={e.status_code().name}, cost_ms={cost_ms:.2f}")
                      finished_reqs.add(req_id)
                      self.save_futures.pop(req_id)
          
                  return finished_reqs
          
          
              def sync_save_reqs(self) -> None:
                  for req_id, (future, submit_time) in self.save_futures.items():
                      if future is None:
                          continue
          
                      try:
                          result = self.context_caching.get_result(future)
                          cost_ms = 1e3 * (time.perf_counter() - submit_time)
                          logger.info(f"[EMS][GetResult] Req {req_id} async save done, success_blocks={result.success}, "
                                      f"total_blocks_num={result.total}, status=SUCCESS, cost_ms={cost_ms:.2f}")
                          self.task_manager.update_req_stat("SAVE", result.success, cost_ms)
                      except EmsException as e:
                          cost_ms = 1e3 * (time.perf_counter() - submit_time)
                          self._process_exception(e)
                          logger.info(f"[EMS][GetResult] Req {req_id} async save done, success_blocks={0}, "
                                      f"total_blocks_num={0}, status={e.status_code().name}, cost_ms={cost_ms:.2f}")
          
                  self.save_futures.clear()
          
          
              def _check_health(self) -> bool:
                  is_health = Ems.check_health()
                  if is_health:
                      logger.info("[EMS] EMS health status is ok.")
                  else:
                      logger.info("[EMS] EMS health status is abnormal.")
          
                  if not self._inited and is_health and self._need_reinit:
                      logger.info(f"[EMS][Init] re-init during health check.")
                      self.init_ems()
                      if self._inited:
                          logger.info("[EMS][Init] re-init succeed during health check.")
                          if (not self._registered) and (self._pending_kv_caches is not None):
                              self.register_kv_caches(self._pending_kv_caches)
                              logger.info(f"[EMS][Register] re-register pending kvcache succeed.")
          
                  return is_health
          
          
              def _check_params(self, req_id: str, block_hashes: List[int], block_ids: List[int], called_at: str) -> bool:
                  if not self._inited or not self._registered:
                      logger.warning(f"[EMS][{called_at}][degraded] req_id={req_id} reason=EMS not ready.")
                      return False
          
                  if not self.task_manager.get_status():
                      logger.warning(f"[EMS][{called_at}][skip] req_id={req_id} reason=Unhealthy EMS")
                      return False
          
                  if (len(block_hashes) == 0 or len(block_ids) == 0) or (len(block_hashes) != len(block_ids)):
                      logger.error(f"[EMS] req {req_id} has invalid block_hashes or block_ids: "
                                   f"len(block_hashes) == {len(block_hashes)}, len(block_ids) == {len(block_ids)}.")
                      return False
          
                  return True
          
          
              def _process_exception(self, e: EmsException) -> None:
                  if e.status_code() == EmsErrorCode.EMS_INVALID_ARGUMENT:
                      return
                  self.task_manager.reset_status()
          
          
          class PeriodicTaskManager:
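              HEALTH_CHECK_INTERVAL = 10  # health check interval (seconds)
              FLAPPING_WINDOW = 60  # flapping detection window (seconds)
              FLAPPING_LIMIT = 3  # maximum number of status changes allowed within the window
              SUCCESS_THRESHOLD = 3  # consecutive successful checks required to recover to healthy
          
              PRINT_INTERVAL = 30  # statistics summary print interval (seconds)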
          
              def __init__(self, check_fn):
                  self.check_fn = check_fn
          
                  self._ems_ok = False
                  self._consecutive_success_count = 0
                  # Sliding-window history of status changes, used for flapping detection
                  self._change_history: deque = deque()
          
                  self.last_log_time = time.perf_counter()
                  self.stat = { # [sum, min, max], init min with a large number: 2**30 (int) or 1e9 (float)
                      "LOAD": {"count": 0, "block_nums": [0, 2**30, 0], "cost_times": [0.0, 1e9, 0.0]},
                      "SAVE": {"count": 0, "block_nums": [0, 2**30, 0], "cost_times": [0.0, 1e9, 0.0]},
                      "HIT": {"num_hit_blocks": 0, "num_total_blocks": 0},
                  }
          
                  self.thread_lock = threading.Lock()
                  self.start_loop()
          
              def get_status(self) -> bool:
                  return self._ems_ok
          
              def reset_status(self) -> None:
                  if not self._ems_ok:
                      return
          
                  logger.info(f"[EMS] Ems health status reset to False")
                  self._ems_ok = False
          
              def check_health_status(self) -> None:
                  is_healthy = False
                  try:
                      is_healthy = self.check_fn()
                  except Exception as e:
                      logger.error(f"[EMS] EMS health check failed, error: {e}.")
          
                  self._process_check_result(is_healthy)
          
              def _process_check_result(self, is_healthy: bool):
                  if is_healthy:
                      self._handle_success()
                  else:
                      self._handle_failure()
          
              def _handle_success(self):
                  self._consecutive_success_count += 1
          
                  if not self._ems_ok:
                      if self._consecutive_success_count >= self.SUCCESS_THRESHOLD:
                          # Try to switch the health status to True; this is strictly gated by flapping detection
                          if not self._try_switch_status(new_status=True):
                              # Blocked by flapping detection: reset the success counter
                              self._consecutive_success_count = 0
                  else:
                      self._consecutive_success_count = min(self._consecutive_success_count, self.SUCCESS_THRESHOLD)
          
              def _handle_failure(self):
                  self._consecutive_success_count = 0
          
                  if self._ems_ok:
                      # Switch the health status to False immediately, regardless of flapping
                      self._try_switch_status(new_status=False)
          
              def _try_switch_status(self, new_status: bool) -> bool:
                  current_time = time.monotonic()
          
                  # 1. Prune expired entries from the sliding window
                  self._clean_flapping_history(current_time)
          
                  # 2. Check for flapping
                  is_flapping = len(self._change_history) >= self.FLAPPING_LIMIT
          
                  if is_flapping:
                      if new_status is True:
                          # Recovery attempts are blocked while the status is flapping
                          logger.warning(
                              f"[EMS] EMS status flapping detected ({len(self._change_history)} changes in {self.FLAPPING_WINDOW}s). "
                              f"Blocking recovery (switch to True)."
                          )
                          return False
                      else:
                          logger.info(
                              f"[EMS] EMS status flapping detected, but forcing status to False (Fail Fast strategy)."
                          )
          
                  # 3. Apply the status change
                  logger.info(f"[EMS] EMS health status changing: {self._ems_ok} -> {new_status}")
                  self._ems_ok = new_status
          
                  # 4. Record the change in the history
                  self._change_history.append(current_time)
          
                  return True
          
              def _clean_flapping_history(self, current_time: float):
                  threshold_time = current_time - self.FLAPPING_WINDOW
                  while self._change_history and self._change_history[0] < threshold_time:
                      self._change_history.popleft()
          
              def update_req_stat(self, event: str, block_num: int, cost_time: float) -> None:
                  with self.thread_lock:
                      self.stat[event]["count"] += 1
                      self.stat[event]["block_nums"][0] += block_num
                      self.stat[event]["block_nums"][1] = min(block_num, self.stat[event]["block_nums"][1])
                      self.stat[event]["block_nums"][2] = max(block_num, self.stat[event]["block_nums"][2])
                      self.stat[event]["cost_times"][0] += cost_time
                      self.stat[event]["cost_times"][1] = min(cost_time, self.stat[event]["cost_times"][1])
                      self.stat[event]["cost_times"][2] = max(cost_time, self.stat[event]["cost_times"][2])
          
              def update_hit_stat(self, num_hit_blocks: int, num_total_blocks: int) -> None:
                  with self.thread_lock:
                      self.stat["HIT"]["num_hit_blocks"] += num_hit_blocks
                      self.stat["HIT"]["num_total_blocks"] += num_total_blocks
          
              def print_stat(self) -> None:
                  with self.thread_lock:
                      stat = self.stat
                      self.stat = { # [sum, min, max], init min with a large number: 2**30 (int) or 1e9 (float)
                          "LOAD": {"count": 0, "block_nums": [0, 2**30, 0], "cost_times": [0.0, 1e9, 0.0]},
                          "SAVE": {"count": 0, "block_nums": [0, 2**30, 0], "cost_times": [0.0, 1e9, 0.0]},
                          "HIT": {"num_hit_blocks": 0, "num_total_blocks": 0},
                      }
          
                  load = stat["LOAD"]["count"]
                  avg_load_block_num = stat["LOAD"]["block_nums"][0] / load if load else 0.0
                  min_load_block_num = stat["LOAD"]["block_nums"][1] if load else 0
                  max_load_block_num = stat["LOAD"]["block_nums"][2] if load else 0
                  avg_load_cost_time = stat["LOAD"]["cost_times"][0] / load if load else 0.0
                  min_load_cost_time = stat["LOAD"]["cost_times"][1] if load else 0.0
                  max_load_cost_time = stat["LOAD"]["cost_times"][2] if load else 0.0
          
                  save = stat["SAVE"]["count"]
                  avg_save_block_num = stat["SAVE"]["block_nums"][0] / save if save else 0.0
                  min_save_block_num = stat["SAVE"]["block_nums"][1] if save else 0
                  max_save_block_num = stat["SAVE"]["block_nums"][2] if save else 0
                  avg_save_cost_time = stat["SAVE"]["cost_times"][0] / save if save else 0.0
                  min_save_cost_time = stat["SAVE"]["cost_times"][1] if save else 0.0
                  max_save_cost_time = stat["SAVE"]["cost_times"][2] if save else 0.0
          
                  num_hit_blocks = stat["HIT"]["num_hit_blocks"]
                  num_total_blocks = stat["HIT"]["num_total_blocks"]
                  hit_rate = 100 * (num_hit_blocks / num_total_blocks) if num_total_blocks else 0.0
          
                  logger.info(f"[EMS][{self.PRINT_INTERVAL}Sec Summary] req(load={load},save={save}) "
                              f"LOAD(avg_block_num={avg_load_block_num:.1f} [min:{min_load_block_num}, max:{max_load_block_num}]; "
                              f"avg_cost_ms={avg_load_cost_time:.1f} [min:{min_load_cost_time:.1f}, max:{max_load_cost_time:.1f}]) "
                              f"SAVE(avg_block_num={avg_save_block_num:.1f} [min:{min_save_block_num}, max:{max_save_block_num}]; "
                              f"avg_cost_ms={avg_save_cost_time:.1f} [min:{min_save_cost_time:.1f}, max:{max_save_cost_time:.1f}]) "
                              f"HIT(hit_rate={hit_rate:.1f} [hit:{num_hit_blocks}, total:{num_total_blocks}])")
          
              def task_loop(self) -> None:
                  while True:
                      time.sleep(self.HEALTH_CHECK_INTERVAL)
                      self.check_health_status()
                      cur_time = time.perf_counter()
                      if cur_time - self.last_log_time > self.PRINT_INTERVAL:
                          self.last_log_time = cur_time
                          self.print_stat()
          
              def start_loop(self) -> None:
                  logger.info("[EMS] EMS start periodic tasks.")
                  self.check_health_status()
                  self.periodic_task_thread = threading.Thread(target=self.task_loop, daemon=True, name="ems_task_loop")
                  self.periodic_task_thread.start()
                  logger.info(f"[EMS] periodic tasks subthread \"ems_task_loop\" started.")
        • ems_env.py
          vi ems_env.py

          The content is as follows:

          # Copyright (c) HuaWei Technologies Co., Ltd. 2025-2025. All rights reserved
          
          import os
          
          
          class EmsEnv:
              llm_engine = os.environ.get("LLM_ENGINE", "vllm")
              model_id = os.environ.get("MODEL_ID", "cc_kvstore@_@ds_default_ns_001")
              service_name = os.environ.get("SERVICE_NAME", "deepseek")
              access_id = os.environ.get("ACCELERATE_ID", "cc_kvstore@_@ds_default_ns_001")
              access_key = os.environ.get("ACCELERATE_KEY", "")
              ems_timeout: int = int(os.environ.get("EMS_TIMEOUT", "5000"))
              ems_enable_write_rcache: bool = os.environ.get("EMS_ENABLE_WRITE_RCACHE", "1") == "1"
              ems_enable_read_local_only: bool = os.environ.get("EMS_ENABLE_READ_LOCAL_ONLY", "0") == "1"
              ems_num_min_reuse_tokens: int = int(os.environ.get("EMS_NUM_MIN_REUSE_TOKENS", "2048"))
              ems_num_min_load_blocks: int = int(os.environ.get("EMS_NUM_MIN_LOAD_BLOCKS", "1"))
              ems_lookup_key_server_ip = os.environ.get("EMS_LOOKUP_KEY_SERVER_IP", "127.0.0.1")
              ems_lookup_key_server_base_port = os.environ.get("EMS_LOOKUP_KEY_SERVER_BASE_PORT", "50005")
          
        • ems_store_connector.py
          vi ems_store_connector.py

          The content is as follows:

          import os
          import threading
          from dataclasses import dataclass
          from typing import Any, Optional, Tuple, List, Set, Dict
          
          import torch
          import zmq
          
          from vllm.distributed import get_tp_group
          from vllm.v1.attention.backend import AttentionMetadata
          from vllm.config import VllmConfig
          from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole
          from vllm.forward_context import ForwardContext
          from vllm.logger import logger
          from vllm.v1.core.kv_cache_manager import KVCacheBlocks
          from vllm.v1.core.sched.output import SchedulerOutput
          from vllm.v1.request import Request
          
          from vllm_ascend.distributed.ems_store.ems_adapter import EmsAdapter
          from vllm_ascend.distributed.ems_store.ems_env import EmsEnv
          
          
          @dataclass
          class CcReqMeta:
              req_id: str
              num_computed_blocks: int
              num_total_blocks: int
              block_hashes: List[int]
              block_ids: Optional[List[int]]
              operation: str
          
          
          @dataclass
          class CcConnectorMetadata(KVConnectorMetadata):
              def __init__(self):
                  self.requests: List[CcReqMeta] = []
          
              def add_requests(self, req_meta: CcReqMeta):
                  self.requests.append(req_meta)
          
          
          class EmsStoreConnector(KVConnectorBase_V1):
              def __init__(self,
                  vllm_config: VllmConfig,
                  role: KVConnectorRole,
                  kv_cache_config=None
              ):
                  super().__init__(
                      vllm_config=vllm_config,
                      role=role,
                      kv_cache_config=kv_cache_config
                  )
          
                  if role == KVConnectorRole.SCHEDULER:
                      self.connector_scheduler = CcConnectorScheduler(vllm_config)
          
                  if role == KVConnectorRole.WORKER:
                      self.connector_worker = CcConnectorWorker(vllm_config)
                      if get_tp_group().rank_in_group == 0:
                          self.lookup_server = LookupKeyServer(vllm_config.parallel_config.data_parallel_index,
                                                               self.connector_worker)
          
              ############################################################
              # Scheduler Side Methods
              ############################################################
          
              def get_num_new_matched_tokens(self, request: Request, num_computed_tokens: int) -> Tuple[int, bool]:
                  return self.connector_scheduler.get_num_new_matched_tokens(request, num_computed_tokens)
          
              def update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int):
                  self.connector_scheduler.update_state_after_alloc(request, blocks, num_external_tokens)
          
              def build_connector_meta(self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
                  return self.connector_scheduler.build_connector_meta(scheduler_output)
          
              def request_finished(self, request: Request, block_ids: List[int]) -> Tuple[bool, Dict[str, Any] | None]:
                  return self.connector_scheduler.request_finished(request, block_ids)
          
              ############################################################
              # Worker Side Methods
              ############################################################
              def register_kv_caches(self, kv_caches: Dict[str, torch.Tensor]):
                  self.connector_worker.register_kv_caches(kv_caches)
          
              def start_load_kv(self, forward_context: ForwardContext, **kwargs):
                  self.connector_worker.start_load_kv(self._get_connector_metadata())
          
              def wait_for_layer_load(self, layer_name: str):
                  pass
          
              def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, attn_metadata: "AttentionMetadata", **kwargs):
                  pass
          
              def wait_for_save(self):
                  self.connector_worker.wait_for_save(self._get_connector_metadata())
          
              def get_finished(self, finished_req_ids: Set[str]) -> Tuple[Set[str] | None, Set[str] | None]:
                  return self.connector_worker.get_finished(finished_req_ids)
          
              def get_block_ids_with_load_errors(self) -> Set[int]:
                  return self.connector_worker.get_block_ids_with_load_errors()
          
          
          class CcConnectorScheduler:
              def __init__(self, vllm_config: VllmConfig):
                  logger.info(f"[EMS] CcConnectorScheduler init.")
          
                  self.block_size = vllm_config.cache_config.block_size
                  self.lookup_client = LookupKeyClient(vllm_config.parallel_config.data_parallel_index)
          
                  self.processed_requests: Set[str] = set()
                  self.meta_load_reqs: Dict[str, CcReqMeta] = {}
                  self.need_save_req_ids: Set[str] = set()
          
              def get_num_new_matched_tokens(self, request: Request, num_computed_tokens: int) -> Tuple[int, bool]:
                  if request.request_id in self.processed_requests:
                      logger.info(f"[EMS][Scheduler] req {request.request_id} already in processed requests.")
                      return 0, False
          
                  num_computed_blocks = num_computed_tokens // self.block_size
                  num_total_blocks = (len(request.prompt_token_ids) - 1) // self.block_size
          
                  if not self._need_load(num_computed_blocks, num_total_blocks):
                      logger.info(f"[EMS][Scheduler] req {request.request_id} no need to load, num_computed_blocks: {num_computed_blocks}, "
                                   f"num_total_blocks: {num_total_blocks}.")
                      return 0, False
          
                  block_hashes = self._cal_block_hashes(
                      request.prompt_token_ids, self.block_size
                  )[num_computed_blocks:num_total_blocks]
                  num_exist_blocks = self.lookup_client.lookup(request.request_id, block_hashes)
          
                  num_total_blocks = num_computed_blocks + num_exist_blocks
                  block_hashes = block_hashes[:num_exist_blocks]
          
                  if not self._need_load(num_computed_blocks, num_total_blocks):
                      logger.info(f"[EMS][Scheduler] req {request.request_id} num_total_blocks is updated, still no need to load, num_computed_blocks: {num_computed_blocks}, "
                                   f"num_computed_blocks + num_exist_blocks: {num_total_blocks}.")
                      return 0, False
          
                  req_meta = CcReqMeta(
                      req_id=request.request_id,
                      num_computed_blocks=num_computed_blocks,
                      num_total_blocks=num_total_blocks,
                      block_hashes=block_hashes,
                      block_ids=None,
                      operation="no-op"
                  )
                  self.meta_load_reqs[request.request_id] = req_meta
                  logger.debug(f"[EMS][Scheduler] req {request.request_id} meta: {req_meta}.")
          
                  num_new_matched_tokens = (num_total_blocks - num_computed_blocks) * self.block_size
                  logger.info(f"[EMS][Scheduler] matched result, req_id={request.request_id}, "
                              f"num_computed_blocks={num_computed_blocks}, "
                              f"num_exist_blocks={num_exist_blocks}, new num_computed_blocks={num_total_blocks}, "
                              f"num_new_matched_tokens={num_new_matched_tokens}, prompt_len={len(request.prompt_token_ids)}")
                  return num_new_matched_tokens, True
          
              def update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int):
                  if request.request_id not in self.meta_load_reqs:
                      logger.debug(f"[EMS][Scheduler] req {request.request_id} not in meta_load_reqs.")
                      return
          
                  if num_external_tokens == 0:
                      logger.debug(f"[EMS][Scheduler] req {request.request_id} num_external_tokens is 0, no need to update block ids.")
                      self.meta_load_reqs[request.request_id].operation = "no-op"
                      return
          
                  req_meta = self.meta_load_reqs[request.request_id]
                  all_block_ids = blocks.get_block_ids()[0]
                  req_meta.num_total_blocks = min(req_meta.num_total_blocks, len(all_block_ids))
          
                  req_meta.block_ids = all_block_ids[req_meta.num_computed_blocks:req_meta.num_total_blocks]
                  req_meta.block_hashes = req_meta.block_hashes[:len(req_meta.block_ids)]
                  req_meta.operation = "load"
          
                  self.processed_requests.add(request.request_id)
          
                  logger.info(f"[EMS][Scheduler] req {request.request_id} update block ids, meta: {req_meta}.")
          
              def build_connector_meta(self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
                  connector_meta = CcConnectorMetadata()
          
                  for req_id, req_meta in self.meta_load_reqs.items():
                      if req_meta.operation != "load":
                          continue
                      connector_meta.add_requests(req_meta)
                  self.meta_load_reqs.clear()
          
                  for new_req in scheduler_output.scheduled_new_reqs:
                      num_total_tokens = scheduler_output.num_scheduled_tokens[new_req.req_id] + new_req.num_computed_tokens
                      num_computed_blocks = new_req.num_computed_tokens // self.block_size
                      num_total_blocks = num_total_tokens // self.block_size
          
                      if num_computed_blocks < num_total_blocks:
                          block_hashes = self._cal_block_hashes(new_req.prompt_token_ids, self.block_size)[
                                         num_computed_blocks:num_total_blocks]
                          block_ids = new_req.block_ids[0][num_computed_blocks:num_total_blocks]
          
                          req_meta = CcReqMeta(
                              req_id=new_req.req_id,
                              num_computed_blocks=num_computed_blocks,
                              num_total_blocks=num_total_blocks,
                              block_hashes=block_hashes,
                              block_ids=block_ids,
                              operation="save"
                          )
                          logger.debug(f"[EMS] req {new_req.req_id} need save, meta: {req_meta}.")
                          connector_meta.add_requests(req_meta)
                          self.need_save_req_ids.add(new_req.req_id)
          
                  return connector_meta
          
              def request_finished(self, request: Request, block_ids: List[int]) -> Tuple[bool, Optional[Dict[str, Any]]]:
                  if request.request_id in self.processed_requests:
                      self.processed_requests.remove(request.request_id)
          
                  need_save = request.request_id in self.need_save_req_ids
                  self.need_save_req_ids.discard(request.request_id)
          
                  logger.debug(f"[EMS][Scheduler] request finished, req {request.request_id}, "
                              f"need_save={need_save}")
                  return need_save, None
          
              def _need_load(self, num_computed_blocks: int, num_total_blocks: int) -> bool:
                  # Skip the load if the reusable prefix is shorter than ems_num_min_reuse_tokens.
                  if num_total_blocks * self.block_size < EmsEnv.ems_num_min_reuse_tokens:
                      return False
                  # Skip the load if the number of blocks to load does not exceed ems_num_min_load_blocks.
                  if num_total_blocks - num_computed_blocks <= EmsEnv.ems_num_min_load_blocks:
                      return False
          
                  return True
          
              def _cal_block_hashes(self, token_ids: List[int], block_size) -> List[int]:
                  result: List[int] = []
                  prev_block_hash = 0
                  num_blocks = len(token_ids) // block_size
          
                  for block_id in range(num_blocks):
                      block_hash = self._cal_block_hash(token_ids[block_id * block_size:(block_id + 1) * block_size],
                                                        prev_block_hash)
                      result.append(block_hash)
                      prev_block_hash = block_hash
          
                  return result
          
              def _cal_block_hash(self, block_token_ids: List[int], prev_block_hash: int) -> int:
                  return hash((prev_block_hash, *block_token_ids))
          
          
          class CcConnectorWorker:
              def __init__(self, vllm_config: VllmConfig):
                  self.ems_adapter = EmsAdapter(vllm_config)
                  self.finished_req_ids = set()
                  self.finished_save_req_ids = set()
          
              def register_kv_caches(self, kv_caches: Dict[str, Tuple[torch.Tensor]]):
                  return self.ems_adapter.register_kv_caches(kv_caches)
          
              def start_load_kv(self, metadata: CcConnectorMetadata):
                  for request in metadata.requests:
                      if request.operation != "load":
                          continue
                      self.ems_adapter.async_load(
                          request.req_id,
                          request.block_hashes,
                          request.block_ids,
                          request.num_computed_blocks
                      )
          
              def wait_for_save(self, metadata: CcConnectorMetadata):
                  for request in metadata.requests:
                      if request.operation != "save":
                          continue
          
                      import torch_npu
                      torch_npu.npu.synchronize()
          
                      self.ems_adapter.async_save(
                          request.req_id,
                          request.block_hashes,
                          request.block_ids
                      )
          
              def get_finished(self, finished_req_ids: Set[str]) -> Tuple[Set[str], Set[str]]:
                  self.finished_req_ids.update(finished_req_ids)
                  self.finished_save_req_ids.update(self.ems_adapter.get_finished_save_reqs())
          
                  finished_save_req_ids = self.finished_req_ids & self.finished_save_req_ids
                  self.finished_req_ids -= finished_save_req_ids
                  self.finished_save_req_ids -= finished_save_req_ids
          
                  finished_load_req_ids = self.ems_adapter.get_finished_load_reqs()
          
                  if finished_save_req_ids or finished_load_req_ids:
                      logger.info(f"[EMS][Worker] finished_save_req_ids: {finished_save_req_ids}, "
                                   f"finished_load_req_ids: {finished_load_req_ids}.")
          
                  return finished_save_req_ids, finished_load_req_ids
          
              def get_block_ids_with_load_errors(self) -> Set[int]:
                  block_ids_with_load_errors = self.ems_adapter.get_block_ids_with_load_errors()
                  if block_ids_with_load_errors:
                      logger.info(f"[EMS][Worker] block_ids_with_load_errors: {block_ids_with_load_errors}")
                  return block_ids_with_load_errors
          
              def _lookup_block_hashes(self, req_id: str, block_hashes: List[int]) -> int:
                  return self.ems_adapter.exists_block_num(req_id, block_hashes)
          
          
          class LookupKeyServer:
              def __init__(self, data_parallel_rank, worker_connector):
                  self.worker_connector = worker_connector
          
                  self.zmq_ctx = zmq.Context()
                  self.socket = self.zmq_ctx.socket(zmq.REP)
                  port = int(EmsEnv.ems_lookup_key_server_base_port) + data_parallel_rank
                  self.addr = f"tcp://{EmsEnv.ems_lookup_key_server_ip}:{port}"
                  self.socket.bind(self.addr)
          
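                  def process_lookup():
                      while True:
                          req_id, block_hashes = self.socket.recv_pyobj()
                          try:
                              num_exist_blocks = self.worker_connector._lookup_block_hashes(req_id, block_hashes)
                          except Exception as e:
                              # A REP socket must answer every request. Reply 0 (no hit) so the
                              # scheduler's blocking recv_pyobj() does not hang and the thread survives.
                              logger.warning(f"[EMS] lookup failed for req {req_id}: {e}")
                              num_exist_blocks = 0
                          self.socket.send_pyobj(num_exist_blocks)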
          
                  self.thread = threading.Thread(target=process_lookup, daemon=True, name="ems_exist_proxy")
                  self.thread.start()
          
          
          class LookupKeyClient:
              def __init__(self, data_parallel_rank):
                  self.zmq_ctx = zmq.Context()
                  self.socket = self.zmq_ctx.socket(zmq.REQ)
                  port = int(EmsEnv.ems_lookup_key_server_base_port) + data_parallel_rank
                  self.addr = f"tcp://{EmsEnv.ems_lookup_key_server_ip}:{port}"
                  self.socket.connect(self.addr)
          
              def lookup(self, req_id: str, block_hashes: List[int]) -> int:
                  self.socket.send_pyobj([req_id, block_hashes])
                  num_exist_blocks = self.socket.recv_pyobj()
                  return num_exist_blocks
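The scheduler-side prefix matching in CcConnectorScheduler can be reproduced without vLLM or EMS. The following is a minimal sketch under stated assumptions: count_leading_hits is a hypothetical in-memory stand-in for LookupKeyClient.lookup / EmsAdapter.exists_block_num, assumed to return how many leading hashes are already stored (this is how the connector uses the result when it slices block_hashes[:num_exist_blocks]). Because each block hash folds in the previous one, a hit on block i implies the whole prefix up to block i matches. Python's hash() of a tuple of ints does not depend on PYTHONHASHSEED, so prefill and decode compute identical hashes as long as they run the same Python build (here both use the same image).

```python
from typing import List, Set, Tuple


def chained_block_hashes(token_ids: List[int], block_size: int) -> List[int]:
    """Hash every full block, folding in the previous block's hash
    (same scheme as CcConnectorScheduler._cal_block_hashes)."""
    hashes: List[int] = []
    prev_hash = 0
    for i in range(len(token_ids) // block_size):
        block = token_ids[i * block_size:(i + 1) * block_size]
        prev_hash = hash((prev_hash, *block))
        hashes.append(prev_hash)
    return hashes


def count_leading_hits(block_hashes: List[int], stored: Set[int]) -> int:
    """Hypothetical stand-in for the EMS lookup: number of leading hashes already stored."""
    hits = 0
    for h in block_hashes:
        if h not in stored:
            break
        hits += 1
    return hits


def num_new_matched_tokens(prompt: List[int], num_computed_tokens: int,
                           block_size: int, stored: Set[int],
                           min_reuse_tokens: int = 0,
                           min_load_blocks: int = 0) -> Tuple[int, bool]:
    computed = num_computed_tokens // block_size
    # Stop one token short so the engine always computes at least one prompt token.
    total = (len(prompt) - 1) // block_size
    candidates = chained_block_hashes(prompt, block_size)[computed:total]
    total = computed + count_leading_hits(candidates, stored)
    # Same thresholds as _need_load (min_reuse_tokens / min_load_blocks stand in for the EMS_* env vars).
    if total * block_size < min_reuse_tokens:
        return 0, False
    if total - computed <= min_load_blocks:
        return 0, False
    return (total - computed) * block_size, True
```

For example, with block_size 4 a 17-token prompt exposes at most (17 - 1) // 4 = 4 blocks; if the first three are stored, 12 tokens are reported as loadable. The original code also calls _need_load before the lookup: since the lookup can only shrink the total, that first call is an early exit that saves the ZMQ round trip, and the sketch keeps only the post-lookup check.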
          
      4. Create a kv_transfer folder and create an __init__.py file in it.
        mkdir kv_transfer
        vi kv_transfer/__init__.py

        The content of __init__.py is as follows:

        #
        # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
        # This file is a part of the vllm-ascend project.
        #
        # Licensed under the Apache License, Version 2.0 (the "License");
        # you may not use this file except in compliance with the License.
        # You may obtain a copy of the License at
        #
        #     http://www.apache.org/licenses/LICENSE-2.0
        #
        # Unless required by applicable law or agreed to in writing, software
        # distributed under the License is distributed on an "AS IS" BASIS,
        # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
        # See the License for the specific language governing permissions and
        # limitations under the License.
        #
        
        from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
        
        
        def register_connector():
            # override multi_connector as ascend_multi_connector
            if "MultiConnector" in KVConnectorFactory._registry:
                KVConnectorFactory._registry.pop("MultiConnector")
            KVConnectorFactory.register_connector(
                "MultiConnector", "vllm_ascend.distributed.kv_transfer.ascend_multi_connector", "AscendMultiConnector"
            )
        
            KVConnectorFactory.register_connector(
                "MooncakeConnectorV1", "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector", "MooncakeConnector"
            )
        
            KVConnectorFactory.register_connector(
                "MooncakeConnectorStoreV1",
                "vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.ascend_store_connector",
                "AscendStoreConnector",
            )
        
            KVConnectorFactory.register_connector(
                "AscendStoreConnector",
                "vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.ascend_store_connector",
                "AscendStoreConnector",
            )
        
            KVConnectorFactory.register_connector(
                "MooncakeLayerwiseConnector",
                "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector",
                "MooncakeLayerwiseConnector",
            )
        
            KVConnectorFactory.register_connector(
                "UCMConnector", "vllm_ascend.distributed.kv_transfer.kv_pool.ucm_connector", "UCMConnectorV1"
            )
        
            KVConnectorFactory.register_connector(
                "LMCacheAscendConnector",
                "vllm_ascend.distributed.kv_transfer.kv_pool.lmcache_ascend_connector",
                "LMCacheConnectorV1",
            )
        
            KVConnectorFactory.register_connector(
                "EmsStoreConnector",
                "vllm_ascend.distributed.ems_store.ems_store_connector",
                "EmsStoreConnector",
            )
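register_connector maps a connector name to a module path and class name. The sketch below is a hypothetical, simplified registry (MiniConnectorFactory is not vLLM's KVConnectorFactory) modeled on how the factory appears to behave: it stores a lazy loader per name, imports the module only when the connector is created, and rejects duplicate names. That duplicate check is presumably why the code pops "MultiConnector" before re-registering it with the Ascend implementation.

```python
import importlib
from typing import Any, Callable, Dict


class MiniConnectorFactory:
    """Hypothetical lazy name -> class registry (illustration only)."""
    _registry: Dict[str, Callable[[], type]] = {}

    @classmethod
    def register_connector(cls, name: str, module_path: str, class_name: str) -> None:
        if name in cls._registry:
            raise ValueError(f"Connector '{name}' is already registered.")

        def loader() -> type:
            # The module is imported only when the connector is first created.
            return getattr(importlib.import_module(module_path), class_name)

        cls._registry[name] = loader

    @classmethod
    def create(cls, name: str, *args: Any, **kwargs: Any) -> Any:
        if name not in cls._registry:
            raise ValueError(f"Unsupported connector type: {name}")
        return cls._registry[name]()(*args, **kwargs)


if __name__ == "__main__":
    MiniConnectorFactory.register_connector("Deque", "collections", "deque")
    print(type(MiniConnectorFactory.create("Deque")).__name__)  # deque
```

Because loading is lazy, a typo in a module path such as "vllm_ascend.distributed.ems_store.ems_store_connector" only surfaces when vLLM instantiates the connector named in --kv-transfer-config, not at registration time.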
    5. Create the image.
      1. Create a Dockerfile.
        FROM quay.io/ascend/vllm-ascend:v0.18.0-a3
        
        # Install OpenSSL 1.1.1w from source
        COPY openssl-1.1.1w.tar.gz .
        RUN tar xzvf openssl-1.1.1w.tar.gz \
            && cd openssl-1.1.1w \
            && ./config \
            && make \
            && make install \
            && ln -sf /usr/local/lib/libssl.so.1.1 /usr/lib/aarch64-linux-gnu/libssl.so.1.1 \
            && ln -sf /usr/local/lib/libcrypto.so.1.1 /usr/lib/aarch64-linux-gnu/libcrypto.so.1.1 \
            && cd .. \
            && rm -rf openssl-1.1.1w openssl-1.1.1w.tar.gz
        
        # Install cyrus-sasl 2.1.28 from source
        COPY cyrus-sasl-2.1.28.tar.gz .
        RUN tar xzvf cyrus-sasl-2.1.28.tar.gz \
            && cd cyrus-sasl-2.1.28 \
            && ./configure \
            && make \
            && make install \
            && ln -sf /usr/local/lib/libsasl2.so.3 /usr/lib/aarch64-linux-gnu/libsasl2.so.3 \
            && ln -sf /usr/local/lib/sasl2 /usr/lib/aarch64-linux-gnu/sasl2 \
            && cd .. \
            && rm -rf cyrus-sasl-2.1.28 cyrus-sasl-2.1.28.tar.gz
        
        RUN mkdir -p /vllm-workspace/vllm-ascend/vllm_ascend/distributed/ems_store/
        
        COPY ./ems_store/__init__.py /vllm-workspace/vllm-ascend/vllm_ascend/distributed/ems_store/
        COPY ./ems_store/ems_store_connector.py /vllm-workspace/vllm-ascend/vllm_ascend/distributed/ems_store/
        COPY ./ems_store/ems_env.py /vllm-workspace/vllm-ascend/vllm_ascend/distributed/ems_store/
        COPY ./ems_store/ems_adapter.py /vllm-workspace/vllm-ascend/vllm_ascend/distributed/ems_store/
        COPY ./kv_transfer/__init__.py /vllm-workspace/vllm-ascend/vllm_ascend/distributed/kv_transfer/
        COPY ems-26.3.0-cp311-cp311-linux_aarch64.whl .
        
        RUN pip install --no-deps ems-26.3.0-cp311-cp311-linux_aarch64.whl \
            && rm ems-26.3.0-cp311-cp311-linux_aarch64.whl
      2. Build the image.
        docker build -t vllm-ascend-ems:v0.18.0-a3 .
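Before deploying, it can help to confirm that the built image actually contains what the Dockerfile installs. The script below is a hedged sketch, not part of the official procedure: the file name check_image.py and both helper functions are hypothetical, and the module and file lists simply mirror the paths used in the Dockerfile ("ems" is assumed to be the package name of the wheel, based on the site-packages/ems/lib path used later in LD_LIBRARY_PATH).

```python
import importlib.util
import os
from typing import Dict, List


def module_available(name: str) -> bool:
    # find_spec imports parent packages first, so a missing parent (or a
    # failing import inside it) raises ImportError instead of returning None.
    try:
        return importlib.util.find_spec(name) is not None
    except ImportError:
        return False


def check_image(modules: List[str], files: List[str]) -> Dict[str, bool]:
    report = {m: module_available(m) for m in modules}
    report.update({f: os.path.exists(f) for f in files})
    return report


if __name__ == "__main__":
    result = check_image(
        modules=["ems", "vllm_ascend.distributed.ems_store.ems_store_connector"],
        files=["/usr/lib/aarch64-linux-gnu/libssl.so.1.1",
               "/usr/lib/aarch64-linux-gnu/libsasl2.so.3"],
    )
    for name, ok in result.items():
        print(f"{'OK     ' if ok else 'MISSING'} {name}")
```

For example, mount the script into a container started from vllm-ascend-ems:v0.18.0-a3 (docker run -v) and run python3 check_image.py. Importing vllm_ascend may need the Ascend driver and CANN libraries, so running the check on an NPU node is the safer choice.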
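    6. Create a ConfigMap that contains the startup scripts. The startup script parameters must be adjusted for the model you choose; for details, see the vllm-ascend documentation.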
      kubectl apply -f config.yaml

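      An example template is shown below. nic_name is the name of the host's default network interface: run ip route | grep default and use the interface name that follows dev in the output.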

      kind: ConfigMap
      apiVersion: v1
      metadata:
        name: kimi-k25-pd-cm
      data:
        prefill.sh: |
          nic_name="enp23s0f3"  # network card name
          local_ip=$POD_IP
          export HCCL_IF_IP=$local_ip
          export GLOO_SOCKET_IFNAME=$nic_name
          export TP_SOCKET_IFNAME=$nic_name
          export HCCL_SOCKET_IFNAME=$nic_name
      
          export LD_PRELOAD=/usr/lib/aarch64-linux-gnu/libjemalloc.so.2:$LD_PRELOAD
          sysctl -w vm.swappiness=0
          sysctl -w kernel.numa_balancing=0
          sysctl -w kernel.sched_migration_cost_ns=50000
          export VLLM_RPC_TIMEOUT=3600000
          export VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS=30000
      
          export HCCL_OP_EXPANSION_MODE="AIV"
          export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
          export OMP_PROC_BIND=false
          export OMP_NUM_THREADS=1
          export TASK_QUEUE_ENABLE=1
          export ASCEND_BUFFER_POOL=4:8
      
          export HCCL_BUFFSIZE=256
          export VLLM_ASCEND_ENABLE_FLASHCOMM1=1
      
          # --- EMS Configuration Exports ---
          export EMS_NUM_MIN_REUSE_TOKENS=0
          export EMS_NUM_MIN_LOAD_BLOCKS=0
          export EMS_BLOCK_GROUP_SIZE=2
          export LD_LIBRARY_PATH=/usr/local/python3.11.14/lib/python3.11/site-packages/ems/lib:$LD_LIBRARY_PATH
          export EMS_LOOKUP_KEY_SERVER_BASE_PORT=60687
      
          vllm serve $MODEL_LOCATION \
            --host $POD_IP \
            --port "7100" \
            --data-parallel-size 2 \
            --data-parallel-address $POD_IP \
            --data-parallel-rpc-port 12321 \
            --tensor-parallel-size 8 \
            --enable-expert-parallel \
            --seed 1024 \
            --quantization ascend \
            --served-model-name kimi_k25 \
            --trust-remote-code \
            --max-num-seqs 8 \
            --max-model-len 32768 \
            --max-num-batched-tokens 16384 \
            --disable-hybrid-kv-cache-manager \
            --no-enable-prefix-caching \
            --gpu-memory-utilization 0.8 \
            --enforce-eager \
            --speculative-config '{"method": "eagle3", "model":"/models/kimi-k2.5-eagle3", "num_speculative_tokens": 3}' \
            --additional-config '{"recompute_scheduler_enable":true}' \
            --mm-encoder-tp-mode data \
            --kv-transfer-config \
            '{
              "kv_connector": "EmsStoreConnector",
              "kv_role": "kv_producer"
            }'
      
        decode.sh: |
          nic_name="enp23s0f3"  # network card name
          local_ip=$POD_IP
          export HCCL_IF_IP=$local_ip
          export GLOO_SOCKET_IFNAME=$nic_name
          export TP_SOCKET_IFNAME=$nic_name
          export HCCL_SOCKET_IFNAME=$nic_name
      
          export LD_PRELOAD=/usr/lib/aarch64-linux-gnu/libjemalloc.so.2:$LD_PRELOAD
          sysctl -w vm.swappiness=0
          sysctl -w kernel.numa_balancing=0
          sysctl -w kernel.sched_migration_cost_ns=50000
          export VLLM_RPC_TIMEOUT=3600000
          export VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS=30000
      
          export HCCL_OP_EXPANSION_MODE="AIV"
          export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
          export OMP_PROC_BIND=false
          export OMP_NUM_THREADS=1
          export TASK_QUEUE_ENABLE=1
          export ASCEND_BUFFER_POOL=4:8
      
          export HCCL_BUFFSIZE=2200
          export VLLM_ASCEND_ENABLE_MLAPO=0
      
          # --- EMS Configuration Exports ---
          export EMS_NUM_MIN_REUSE_TOKENS=0
          export EMS_NUM_MIN_LOAD_BLOCKS=0
          export EMS_BLOCK_GROUP_SIZE=2
          export LD_LIBRARY_PATH=/usr/local/python3.11.14/lib/python3.11/site-packages/ems/lib:$LD_LIBRARY_PATH
          export EMS_LOOKUP_KEY_SERVER_BASE_PORT=60687
      
          vllm serve $MODEL_LOCATION \
            --host $POD_IP \
            --port "7100" \
            --data-parallel-size 16 \
            --data-parallel-address $POD_IP \
            --data-parallel-rpc-port 12322 \
            --tensor-parallel-size 1 \
            --enable-expert-parallel \
            --seed 1024 \
            --quantization ascend \
            --served-model-name kimi_k25 \
            --trust-remote-code \
            --max-num-seqs 48 \
            --max-model-len 32768 \
            --max-num-batched-tokens 256 \
            --disable-hybrid-kv-cache-manager \
            --no-enable-prefix-caching \
            --gpu-memory-utilization 0.95 \
            --compilation-config '{"cudagraph_mode": "FULL_DECODE_ONLY", "cudagraph_capture_sizes":[4,8,16,32,48,64,80,96,112,128,144,160]}' \
            --additional-config '{"recompute_scheduler_enable":true,"multistream_overlap_shared_expert": false}' \
            --speculative-config '{"method": "eagle3", "model":"/models/kimi-k2.5-eagle3", "num_speculative_tokens": 3}' \
            --kv-transfer-config \
            '{
              "kv_connector": "EmsStoreConnector",
              "kv_role": "kv_consumer"
            }'
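Several values in the two scripts must agree with each other and with the pod resources in the ModelServing below. The following sketch encodes three such checks with the values from this example; RoleConfig, lookup_ports and check_role are hypothetical helpers. The port rule comes from the connector code (LookupKeyServer and LookupKeyClient use EMS_LOOKUP_KEY_SERVER_BASE_PORT + DP rank, bound to 127.0.0.1 by default, and these are host ports because of hostNetwork: true). The batched-token rule is an assumption, a rule of thumb for speculative decoding, not a documented vLLM constraint.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class RoleConfig:
    name: str
    dp: int                      # --data-parallel-size
    tp: int                      # --tensor-parallel-size
    npus: int                    # huawei.com/ascend-1980 requested by the pod
    max_num_seqs: int            # --max-num-seqs
    max_num_batched_tokens: int  # --max-num-batched-tokens
    num_speculative_tokens: int  # from --speculative-config
    lookup_base_port: int        # EMS_LOOKUP_KEY_SERVER_BASE_PORT
    other_ports: List[int]       # --port and --data-parallel-rpc-port


def lookup_ports(cfg: RoleConfig) -> List[int]:
    # One lookup server per DP rank: base port + rank.
    return [cfg.lookup_base_port + rank for rank in range(cfg.dp)]


def check_role(cfg: RoleConfig) -> List[str]:
    problems: List[str] = []
    if cfg.dp * cfg.tp != cfg.npus:
        problems.append(f"{cfg.name}: dp*tp={cfg.dp * cfg.tp}, pod requests {cfg.npus} NPUs")
    # Assumption: one decode step for every running sequence needs 1 + k tokens.
    per_step = cfg.max_num_seqs * (1 + cfg.num_speculative_tokens)
    if cfg.max_num_batched_tokens < per_step:
        problems.append(f"{cfg.name}: max_num_batched_tokens={cfg.max_num_batched_tokens} < {per_step}")
    clash = sorted(set(lookup_ports(cfg)) & set(cfg.other_ports))
    if clash:
        problems.append(f"{cfg.name}: lookup ports collide with {clash}")
    return problems


prefill = RoleConfig("prefill", 2, 8, 16, 8, 16384, 3, 60687, [7100, 12321])
decode = RoleConfig("decode", 16, 1, 16, 48, 256, 3, 60687, [7100, 12322])
```

With these values both roles pass: prefill uses 2 x 8 = 16 NPUs and decode 16 x 1 = 16, decode needs 48 x (1 + 3) = 192 of its 256 batched tokens, and the decode lookup servers occupy host ports 60687 to 60702, which must be free on the node.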
    7. Deploy the ModelServing.
      kubectl apply -f kimi-k25-serv.yaml

      An example deployment template is shown below.

      apiVersion: workload.serving.volcano.sh/v1alpha1
      kind: ModelServing
      metadata:
        name: kimi-k25-pd
        namespace: default
      spec:
        schedulerName: volcano
        replicas: 1
        plugins:
          - name: pod-discovery
        recoveryPolicy: ServingGroupRecreate
        template:
          restartGracePeriodSeconds: 60
          roles:
          - name: prefill
            replicas: 1
            workerReplicas: 0
            entryTemplate:
              spec:
                hostNetwork: true
                containers:
                - name: prefill
                  image: docker.io/library/vllm-ascend-ems:v0.18.0-a3
                  command:
                    - /bin/bash
                  args:
                    - '-c'
                    - cd /workspace && ./prefill.sh
                  env:
                  - name: ROLE
                    value: "prefill"
                  - name: GROUP_NAME
                    valueFrom:
                      fieldRef:
                        fieldPath: metadata.labels['modelserving.volcano.sh/group-name']
                  - name: ROLE_ID
                    valueFrom:
                      fieldRef:
                        fieldPath: metadata.labels['modelserving.volcano.sh/role-id']
                  - name: POD_IP
                    valueFrom:
                      fieldRef:
                        fieldPath: status.podIP
                  - name: NODE_IP
                    valueFrom:
                      fieldRef:
                        fieldPath: status.hostIP
                  - name: MODEL_LOCATION
                    value: /models/Kimi-K2.5-w4a8
                  - name: TP_SIZE
                    value: "8"
                  - name: DP_SIZE
                    value: "2"
                  readinessProbe:
                    httpGet:
                      path: /health
                      port: 7100
                      scheme: HTTP
                    initialDelaySeconds: 60
                    periodSeconds: 10
                    timeoutSeconds: 2
                    failureThreshold: 3
                  resources:
                    limits:
                      cpu: '188'
                      huawei.com/ascend-1980: '16'
                      memory: 1800Gi
                    requests:
                      cpu: '64'
                      huawei.com/ascend-1980: '16'
                      memory: 700Gi
                  ports:
                  - containerPort: 7100
                    name: server
                  volumeMounts:
                  - name: model
                    mountPath: /models
                  - name: dshm
                    mountPath: /dev/shm
                  - name: ems-shm
                    mountPath: /dev/shm/ems
                  - name: hccn-conf
                    mountPath: /etc/hccn.conf
                  - name: hccn-tool
                    mountPath: /usr/local/Ascend/driver/tools/hccn_tool
                  - name: ascend-install-info
                    mountPath: /etc/ascend_install.info
                  - name: config
                    mountPath: /workspace/prefill.sh
                    subPath: prefill.sh
                volumes:
                - name: model
                  hostPath:
                    path: /mnt/paas/models
                    type: Directory
                - name: dshm
                  emptyDir:
                    medium: Memory
                - name: ems-shm
                  hostPath:
                    path: /mnt/paas/kubernetes/kubelet/ems
                    type: DirectoryOrCreate
                - name: hccn-conf
                  hostPath:
                    path: /etc/hccn.conf
                - name: hccn-tool
                  hostPath:
                    path: /usr/local/Ascend/driver/tools/hccn_tool
                - name: ascend-install-info
                  hostPath:
                    path: /etc/ascend_install.info
                - name: config
                  configMap:
                    name: kimi-k25-pd-cm
                    defaultMode: 0777
          - name: decode
            replicas: 1
            workerReplicas: 0
            entryTemplate:
              spec:
                hostNetwork: true
                containers:
                - name: decode
                  image: docker.io/library/vllm-ascend-ems:v0.18.0-a3
                  command:
                    - /bin/bash
                  args:
                    - '-c'
                    - cd /workspace && ./decode.sh
                  env:
                  - name: ROLE
                    value: "decode"
                  - name: ENGINE_ID
                    valueFrom:
                      fieldRef:
                        fieldPath: metadata.name
                  - name: POD_IP
                    valueFrom:
                      fieldRef:
                        fieldPath: status.podIP
                  - name: NODE_IP
                    valueFrom:
                      fieldRef:
                        fieldPath: status.hostIP
                  - name: GROUP_NAME
                    valueFrom:
                      fieldRef:
                        fieldPath: metadata.labels['modelserving.volcano.sh/group-name']
                  - name: ROLE_ID
                    valueFrom:
                      fieldRef:
                        fieldPath: metadata.labels['modelserving.volcano.sh/role-id']
                  - name: MODEL_LOCATION
                    value: /models/Kimi-K2.5-w4a8
                  - name: TP_SIZE
                    value: "1"
                  - name: DP_SIZE
                    value: "16"
                  readinessProbe:
                    httpGet:
                      path: /health
                      port: 7100
                      scheme: HTTP
                    initialDelaySeconds: 60
                    periodSeconds: 10
                    timeoutSeconds: 2
                    failureThreshold: 3
                  ports:
                  - containerPort: 7100
                    name: server
                  resources:
                    limits:
                      cpu: '188'
                      huawei.com/ascend-1980: '16'
                      memory: 1800Gi
                    requests:
                      cpu: '64'
                      huawei.com/ascend-1980: '16'
                      memory: 700Gi
                  volumeMounts:
                  - name: model
                    mountPath: /models
                  - name: dshm
                    mountPath: /dev/shm
                  - name: ems-shm
                    mountPath: /dev/shm/ems
                  - name: hccn-conf
                    mountPath: /etc/hccn.conf
                  - name: hccn-tool
                    mountPath: /usr/local/Ascend/driver/tools/hccn_tool
                  - name: ascend-install-info
                    mountPath: /etc/ascend_install.info
                  - name: config
                    mountPath: /workspace/decode.sh
                    subPath: decode.sh
                volumes:
                - name: model
                  hostPath:
                    path: /mnt/paas/models
                    type: Directory
                - name: dshm
                  emptyDir:
                    medium: Memory
                - name: ems-shm
                  hostPath:
                    path: /mnt/paas/kubernetes/kubelet/ems
                    type: DirectoryOrCreate
                - name: hccn-conf
                  hostPath:
                    path: /etc/hccn.conf
                - name: hccn-tool
                  hostPath:
                    path: /usr/local/Ascend/driver/tools/hccn_tool
                - name: ascend-install-info
                  hostPath:
                    path: /etc/ascend_install.info
                - name: config
                  configMap:
                    name: kimi-k25-pd-cm
                    defaultMode: 0777
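Before you apply the template, it can help to check the per-role arithmetic:

- The prefill role uses TP_SIZE=8 × DP_SIZE=2 = 16.
- The decode role uses TP_SIZE=1 × DP_SIZE=16 = 16.

Both totals match the 16 huawei.com/ascend-1980 devices that each role requests. The sketch below encodes that check, and also verifies that CPU and memory requests do not exceed their limits.

Two assumptions apply. Treating TP × DP = NPU count as a rule assumes one vLLM rank per NPU and no pipeline parallelism. The hypothetical parse_quantity() helper handles only plain integers and binary suffixes such as Gi.

```python
from dataclasses import dataclass, field

# Binary suffixes of Kubernetes resource quantities (subset; decimal suffixes
# such as "G" or millicores such as "500m" are deliberately not handled).
_BINARY = {"Ki": 2**10, "Mi": 2**20, "Gi": 2**30, "Ti": 2**40}


def parse_quantity(quantity: str) -> int:
    """Parse '64' or '700Gi' into an integer count of units/bytes."""
    for suffix, factor in _BINARY.items():
        if quantity.endswith(suffix):
            return int(quantity[: -len(suffix)]) * factor
    return int(quantity)


@dataclass
class RoleSpec:
    name: str
    tp_size: int  # TP_SIZE env var
    dp_size: int  # DP_SIZE env var
    npus: int     # huawei.com/ascend-1980 (request == limit in the template)
    requests: dict = field(default_factory=dict)
    limits: dict = field(default_factory=dict)


def check_role(role: RoleSpec) -> list:
    """Return human-readable problems; an empty list means the role is consistent."""
    problems = []
    world_size = role.tp_size * role.dp_size
    # Assumption: one vLLM rank per NPU and no pipeline parallelism.
    if world_size != role.npus:
        problems.append(f"{role.name}: TP*DP={world_size} but {role.npus} NPUs requested")
    for key, requested in role.requests.items():
        limit = role.limits.get(key)
        if limit is not None and parse_quantity(requested) > parse_quantity(limit):
            problems.append(f"{role.name}: {key} request {requested} exceeds limit {limit}")
    return problems


if __name__ == "__main__":
    requests = {"cpu": "64", "memory": "700Gi"}
    limits = {"cpu": "188", "memory": "1800Gi"}
    for spec in (RoleSpec("prefill", 8, 2, 16, requests, limits),
                 RoleSpec("decode", 1, 16, 16, requests, limits)):
        # prints "prefill OK" then "decode OK" for the values in this template
        print(spec.name, check_role(spec) or "OK")
```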
    8. Configure the load-balancing proxy.
      1. Download the proxy script from the vLLM-Ascend repository.
        wget https://raw.githubusercontent.com/vllm-project/vllm-ascend/main/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py
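      2. Obtain the IP addresses of the prefill and decode Pods. Because both roles use hostNetwork: true, each Pod IP is the IP of the node it runs on.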
        kubectl get pods -owide

        Example output:

        NAME                        READY   STATUS    RESTARTS   AGE     IP              NODE            NOMINATED NODE   READINESS GATES
        kimi-k25-pd-0-decode-0-0    1/1     Running   0          9m57s   192.168.0.156   192.168.0.156   <none>           <none>
        kimi-k25-pd-0-prefill-0-0   1/1     Running   0          9m57s   192.168.0.163   192.168.0.163   <none>           <none>
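      3. Start the proxy service. Pass the prefill Pod IP to --prefiller-hosts and the decode Pod IP to --decoder-hosts, and use 7100 (the containerPort in the ModelServing template) for both port flags.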
        python3 load_balance_proxy_server_example.py \
          --port 8181 \
          --host 0.0.0.0 \
          --prefiller-hosts 192.168.0.163 \
          --prefiller-ports 7100 \
          --decoder-hosts 192.168.0.156 \
          --decoder-ports 7100
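Steps 2 and 3 of this procedure can be scripted. The sketch below illustrates one approach and is not part of the proxy script. It does the following:

- Reads kubectl get pods -o wide output and keeps Running Pods of the kimi-k25-pd ModelServing.
- Sorts them into prefill and decode by the -prefill- / -decode- substring in the Pod name. This naming pattern is an assumption inferred from the example output above.
- Builds the proxy command line.

Columns are sliced by header position because the RESTARTS column can contain spaces (for example 1 (3m ago)). Passing several hosts after one flag assumes the proxy script accepts space-separated lists, so confirm this with its --help before scaling beyond 1P1D. All helper names are hypothetical.

```python
import re
import shlex
import textwrap

COLUMNS = ["NAME", "READY", "STATUS", "RESTARTS", "AGE", "IP", "NODE"]


def column_spans(header: str) -> dict:
    """Map each column to its (start, end) slice in kubectl's aligned output."""
    starts = [re.search(rf"\b{name}\b", header).start() for name in COLUMNS]
    ends = starts[1:] + [None]
    return dict(zip(COLUMNS, zip(starts, ends)))


def pd_hosts(kubectl_output: str, group: str = "kimi-k25-pd") -> dict:
    """Collect IPs of Running prefill/decode Pods that belong to one ModelServing."""
    lines = [line for line in textwrap.dedent(kubectl_output).splitlines() if line.strip()]
    spans = column_spans(lines[0])

    def cell(row: str, column: str) -> str:
        start, end = spans[column]
        return row[start:end].strip()

    hosts = {"prefill": [], "decode": []}
    for row in lines[1:]:
        name = cell(row, "NAME")
        if not name.startswith(group + "-") or cell(row, "STATUS") != "Running":
            continue
        for role in hosts:
            # Assumption: the role appears as "-prefill-" / "-decode-" in the Pod name.
            if f"-{role}-" in name:
                hosts[role].append(cell(row, "IP"))
    return hosts


def proxy_command(hosts: dict, engine_port: int = 7100, proxy_port: int = 8181) -> list:
    """Build argv for load_balance_proxy_server_example.py (one port per host)."""
    argv = ["python3", "load_balance_proxy_server_example.py",
            "--port", str(proxy_port), "--host", "0.0.0.0"]
    for role, flag in (("prefill", "prefiller"), ("decode", "decoder")):
        if not hosts[role]:
            raise ValueError(f"no Running {role} Pod found")
        argv += [f"--{flag}-hosts", *hosts[role]]
        argv += [f"--{flag}-ports", *[str(engine_port)] * len(hosts[role])]
    return argv


def proxy_command_line(hosts: dict) -> str:
    """Render the argv as a single shell-safe command line."""
    return shlex.join(proxy_command(hosts))
```

In practice, feed pd_hosts() the stdout of kubectl get pods -o wide, for example via subprocess.run(..., capture_output=True, text=True). With the two Pods shown above, proxy_command_line() reproduces the command used in this step.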
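    9. Verify the service. Use curl to send a standard Chat Completions API request through the proxy and confirm that the inference pipeline works end to end.
      1. Send a request to the proxy on port 8181. In this example the proxy runs on node 192.168.0.163; replace the address with that of the host running the proxy.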
        curl -X POST http://192.168.0.163:8181/v1/chat/completions \
          -H "Content-Type: application/json" \
          -d '{
            "model": "kimi_k25",
            "messages": [
              {
                "role": "user",
                "content": "Hello, how are you?"
              }
            ],
            "max_tokens": 100
          }'
      2. Check the response.

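        If the deployment succeeded, the request returns a JSON object containing the choices and usage fields. Because max_tokens is 100, generation stops after 100 completion tokens and finish_reason is length.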

        {
          "id": "chatcmpl-fd56a4a3-829f-49c7-9ffa-f890e8******",
          "object": "chat.completion",
          "created": 1779713507,
          "model": "kimi_k25",
          "choices": [
            {
              "index": 0,
              "message": {
                "role": "assistant",
                "content": " The user is greeting me and asking how I am. This is a standard conversational opening. I should respond politely, let them know I'm doing well (as an AI, I don't have feelings, but it's social convention to respond positively), and invite them to continue the conversation or ask how I can help them today.\n\nI should keep it friendly, concise, and open-ended so they can tell me what they need help with. </think> Hello! I'm doing well, thank you for asking. How can I",
                "refusal": null,
                "annotations": null,
                "audio": null,
                "function_call": null,
                "tool_calls": [],
                "reasoning": null
              },
              "logprobs": null,
              "finish_reason": "length",
              "stop_reason": null,
              "token_ids": null
            }
          ],
          "service_tier": null,
          "system_fingerprint": null,
          "usage": {
            "prompt_tokens": 14,
            "total_tokens": 114,
            "completion_tokens": 100,
            "prompt_tokens_details": null,
            "completion_tokens_details": null
          },
          "prompt_logprobs": null,
          "prompt_token_ids": null,
          "kv_transfer_params": null
        }
      3. Check the proxy log.

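        The terminal running the proxy script should show two things: the proxy initialized one prefill client and one decode client, and the request was received and forwarded successfully (HTTP 200).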

        INFO:     Started server process [303949]
        INFO:     Waiting for application startup.
        Initialized 1 prefill clients and 1 decode clients.
        INFO:     Application startup complete.
        INFO:     Uvicorn running on http://0.0.0.0:8181 (Press CTRL+C to quit)
        INFO:     192.168.0.163:37794 - "POST /v1/chat/completions HTTP/1.1" 200 OK
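The checks in step 9 can also be run from Python. The sketch below uses only the standard library and provides three hypothetical helpers, which are not part of vLLM, Kthena or EMS:

- build_chat_request() rebuilds the curl request.
- check_chat_response() verifies that choices is non-empty, that prompt_tokens + completion_tokens = total_tokens (14 + 100 = 114 in the sample response above), and that completion_tokens does not exceed max_tokens.
- summarize_proxy_log() extracts the client counts and HTTP status codes from the proxy log lines shown above.

```python
import json
import re
import urllib.request


def build_chat_request(base_url: str, model: str, prompt: str, max_tokens: int = 100):
    """Build (but do not send) the same POST request as the curl example."""
    body = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
    }
    return urllib.request.Request(
        base_url.rstrip("/") + "/v1/chat/completions",
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send(request: urllib.request.Request, timeout: float = 120.0) -> dict:
    """Send the request through the proxy and decode the JSON body."""
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.load(response)


def check_chat_response(resp: dict, max_tokens: int) -> list:
    """Return a list of problems; an empty list means the response looks sane."""
    problems = []
    if not resp.get("choices"):
        problems.append("missing or empty 'choices'")
    usage = resp.get("usage") or {}
    counts = [usage.get(k) for k in ("prompt_tokens", "completion_tokens", "total_tokens")]
    if None in counts:
        problems.append("missing token counts in 'usage'")
        return problems
    prompt, completion, total = counts
    if prompt + completion != total:
        problems.append(f"usage mismatch: {prompt} + {completion} != {total}")
    if completion > max_tokens:
        problems.append(f"completion_tokens {completion} exceeds max_tokens {max_tokens}")
    return problems


# Patterns match the proxy log lines shown in this step.
_INIT = re.compile(r"Initialized (\d+) prefill clients and (\d+) decode clients")
_ACCESS = re.compile(r'"POST /v1/chat/completions HTTP/[\d.]+" (\d{3})')


def summarize_proxy_log(text: str) -> dict:
    """Extract client counts and chat-completion status codes from the proxy log."""
    init = _INIT.search(text)
    return {
        "prefill_clients": int(init.group(1)) if init else 0,
        "decode_clients": int(init.group(2)) if init else 0,
        "chat_status_codes": [int(code) for code in _ACCESS.findall(text)],
    }
```

For example, resp = send(build_chat_request("http://192.168.0.163:8181", "kimi_k25", "Hello, how are you?")) sends the request through the proxy. An empty list from check_chat_response(resp, 100) indicates that the response has the expected shape.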

Related Documents