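Deploying LLM Inference on Huawei Cloud with EMS and Ascend NPUs
Overview
This practice describes how to deploy a large language model (LLM) inference service on Ascend NPUs in a CCE cluster and how to integrate Elastic Memory Storage (EMS) as a distributed KVCache. EMS trades storage for computation: previously computed KVCache is reused instead of recomputed. This can significantly reduce time to first token (TTFT), increase multi-turn conversation throughput, and relieve device memory and host memory bottlenecks, which lowers compute costs while preserving the inference experience. For more information about EMS, see What Is Elastic Memory Storage?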
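Prerequisites
- You have applied for the EMS open beta. For details, see Applying for the Open Beta.
- A CCE Turbo cluster of v1.29 or later has been created, and the NodeLocal DNSCache and Volcano Scheduler (v1.22.1 or later) add-ons have been installed.
- At least five nodes of the A2 or A3 flavor have been accepted into or created in the cluster. Do not mix nodes of different flavors.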
Constraints
This feature is being rolled out. Check the console for the regions where it is available.
Procedure
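- Enable EMS acceleration and label the nodes.
- Log in to the CCE console and click the target cluster. In the navigation pane, choose "Settings" and click the "Heterogeneous Resources" tab.
- Turn on "EMS Acceleration" and label the prepared A2/A3 nodes.
- In the dialog box that is displayed, click "Save" to deploy the EMS cluster.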
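- (Optional) Configure security group rules.
If the nodes use the host network, manually configure security group rules to allow the EMS components to communicate.
- Go to the resource stacks, search for the stack by EMS cluster ID, and confirm that its status is "Deployment succeeded".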
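- Obtain the default security group ID.
- On the "Overview" page of the target cluster on the CCE console, click the link next to "Default Node Security Group" in the "Network Information" area.
- On the "Basic Information" tab, click the copy icon next to the ID to copy the security group ID.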
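- Configure EMS security group rules.
- Click the icon above the "Basic Information" tab to go back to the security group list.
- On the "Access Control > Security Groups" page, search for "{EMS cluster ID}" and click "ems-zookeeper-{EMS cluster ID}".
- Click "Inbound Rules" and then "Add Rule".
- In the "Add Inbound Rule" panel, set "Protocol & Port" to "10000", select "Security group" from the "Source" drop-down list, and enter the default node security group ID in the box below.
- Click "Add 1 More Rule" at the bottom, set "Protocol & Port" to "11000", select "Security group" from the "Source" drop-down list, enter the default node security group ID in the box below, and click "OK".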
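To confirm that the two inbound rules are in effect, you can test TCP reachability of ports 10000 and 11000 from a node. The following is a minimal sketch that uses only the Python standard library; the target address is an assumption (substitute the address of an EMS component in your cluster), and `port_reachable`/`check_ems_ports` are hypothetical helper names, not part of the EMS SDK.

```python
import socket
from typing import Dict, Iterable

# Ports opened in the inbound rules of ems-zookeeper-{EMS cluster ID} above.
EMS_PORTS = (10000, 11000)


def port_reachable(host: str, port: int, timeout: float = 3.0) -> bool:
    """Return True if a TCP connection to host:port succeeds within timeout seconds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Connection refused, timed out, or host unreachable.
        return False


def check_ems_ports(host: str, ports: Iterable[int] = EMS_PORTS) -> Dict[int, bool]:
    """Map each port to whether it is reachable from the current node."""
    return {port: port_reachable(host, port) for port in ports}
```

For example, `check_ems_ports("<EMS component IP>")` should map both ports to `True` once the rules are applied; a `False` entry points to a missing or mistyped rule, or to a component that is not listening yet.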
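- Restart the EMS controller.
- On the "Workloads" page of the target cluster on the CCE console, click the "Deployments" tab, select the namespace ems-{ems_cluster_id}, and click the ems-controller workload.
- On the pod list tab, select all pods, click "Delete", and click "Yes" in the dialog box that is displayed.
- Repeat steps 1 and 2 above for the ems-server workload.
- On the "Heterogeneous Resources" tab under "Settings", wait until the components finish loading.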
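- Build the inference image.
- On the "Scheduling" tab of the "Settings" page of the target cluster on the CCE console, set the default scheduler to "Volcano Scheduler".
- Check the RoCE network status of the nodes. For details, see Prefill-Decode Disaggregation (Deepseek).
- Prepare the local model. You can download the model from ModelScope. This document uses the Kimi-K2.5-W4A8 model as an example. Ensure that the model has been downloaded to the /mnt/pass/models directory on the nodes.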
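Before building the image, it can help to check that the model directory is complete on every node. The sketch below assumes the model is stored in Hugging Face layout (a `config.json` plus `*.safetensors` weight shards) in a subdirectory such as /mnt/pass/models/Kimi-K2.5-W4A8; both the layout and the subdirectory name are assumptions, so adjust them to your download. `check_model_dir` is a hypothetical helper.

```python
from pathlib import Path
from typing import List


def check_model_dir(model_dir: str) -> List[str]:
    """Return a list of problems found in a local model directory (empty list = looks complete)."""
    path = Path(model_dir)
    if not path.is_dir():
        return [f"{model_dir} does not exist or is not a directory"]
    problems: List[str] = []
    # Assumed Hugging Face layout: model config plus safetensors weight shards.
    if not (path / "config.json").is_file():
        problems.append("config.json is missing")
    if not any(path.glob("*.safetensors")):
        problems.append("no *.safetensors weight files found")
    return problems
```

Running `check_model_dir("/mnt/pass/models/Kimi-K2.5-W4A8")` on each node and expecting an empty list catches nodes where the download is missing or incomplete.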
- Prepare the Ascend engine image.
- Download the OpenSSL and Cyrus SASL source packages:
https://github.com/openssl/openssl/releases/download/OpenSSL_1_1_1w/openssl-1.1.1w.tar.gz
https://github.com/cyrusimap/cyrus-sasl/releases/download/cyrus-sasl-2.1.28/cyrus-sasl-2.1.28.tar.gz
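- Download the EMS SDK package from the EMS configuration on the "Settings > Heterogeneous Resources" tab of the CCE cluster, and place it in the same directory as the Dockerfile.
- Create the ems_store folder.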
mkdir ems_store
In the ems_store folder, create four files: __init__.py, ems_adapter.py, ems_env.py, and ems_store_connector.py.
- __init__.py
vi __init__.py
This file is empty. Press Esc, type :wq, and press Enter to save it.
- ems_adapter.py
vi ems_adapter.py
The file content is as follows:
from typing import List, Tuple, Dict, Set
import threading
import time
import os
from collections import deque

import torch
import numpy as np

from vllm.config import VllmConfig
from vllm.distributed import get_tp_group, get_pp_group
from vllm.logger import logger
from ems import Ems, CcKvOption, EmsException, EmsErrorCode, EmsConfig, CcConfig_v1
from ems.cc_v1.cc_config import KVCacheType

from vllm_ascend.distributed.ems_store.ems_env import EmsEnv


class EmsAdapter:
    def __init__(self, vllm_config: VllmConfig):
        self.vllm_config = vllm_config
        self.is_mla = vllm_config.model_config.use_mla
        self.model_name = "_".join(os.path.normpath(vllm_config.model_config.model).split("/")[-2:])
        logger.info(f"[EMS] EmsAdapter init with vllm_config: {vllm_config}, "
                    f"is_mla: {self.is_mla}, model_name: {self.model_name}")
        self.rank = vllm_config.parallel_config.rank
        self.world_size = vllm_config.parallel_config.world_size
        self.block_size = vllm_config.cache_config.block_size
        self.load_futures = {}
        self.save_futures = {}
        self.failed_block_ids: set[int] = set()
        self._inited = False
        self._need_reinit = False
        self._registered = False
        self._pending_kv_caches = None
        self.ems_cfg = self.get_ems_cfg()
        self.context_caching = None
        self.init_ems()
        self.task_manager = None
        self.init_task_manager()

    def get_ems_cfg(self) -> EmsConfig:
        tp_group = get_tp_group()
        pp_group = get_pp_group()
        cc_config_v1 = CcConfig_v1(
            rank_id=tp_group.rank,
            device_id=tp_group.local_rank,
            model_id=EmsEnv.model_id,
            tp_world_size=tp_group.world_size,
            pp_world_size=pp_group.world_size,
            rank_in_tp_group=tp_group.rank_in_group,
            rank_in_pp_group=pp_group.rank_in_group,
            llm_engine=f"{EmsEnv.llm_engine}@{self.model_name}"
        )
        if self.is_mla:
            cc_config_v1.kvcache_type = KVCacheType.MLA
        ems_cfg = EmsConfig(
            access_id=EmsEnv.access_id,
            access_key=EmsEnv.access_key,
            cc_config_v1=cc_config_v1
        )
        return ems_cfg

    def init_ems(self) -> None:
        try:
            Ems.init(self.ems_cfg)
            self.context_caching = Ems.get_cc()
            self._inited = True
            logger.info("[EMS][Init] EmsConnector init succeed, EMS ready.")
        except EmsException as e:
            # A recoverable error: retry initialization from the health check loop
            if e.status_code() == EmsErrorCode.EMS_RECOVERD_ERROR:
                self._need_reinit = True
            logger.error(f"[EMS][Init] EmsConnector init fail, error: {e}.")
        if not self._inited and self.context_caching:
            logger.warning("[EMS][Init][degraded] EmsConnector init failed but get cc, reason=EMS not ready.")

    def init_task_manager(self) -> None:
        try:
            self.task_manager = PeriodicTaskManager(check_fn=self._check_health)
            logger.info("[EMS] init periodic task manager succeeded.")
        except Exception as e:
            logger.error(f"[EMS] init periodic task manager failed, error: {e}.")
            raise

    def register_kv_caches(self, kv_caches: dict[str, Tuple[torch.Tensor]]) -> None:
        if not self._inited:
            # Keep the caches so that they can be registered after a successful re-init
            self._pending_kv_caches = kv_caches
            logger.warning("[EMS][Register] EMS not ready, skip register_kvcache.")
            return
        kv_caches_list: List[List[torch.Tensor]] = []
        for _, kv_caches_layer in kv_caches.items():
            num_kv = len(kv_caches_layer)
            kv_tensor_list: List[torch.Tensor] = [kv_caches_layer[idx] for idx in range(num_kv)]
            kv_caches_list.append(kv_tensor_list)
        try:
            self.context_caching.register_kvcache(kv_caches_list)
            self._registered = True
            self._pending_kv_caches = None
            # shape of kv_caches_list: [layers, [k_v_index, [gpu_blocks, block_size, heads, head_size]]]
            ds_layers = len(kv_caches_list)
            ds_k_v_index = len(kv_caches_list[0])
            ds_gpu_blocks, ds_block_size, ds_heads, ds_head_size = list(kv_caches_list[0][0].shape)
            logger.info(f"[EMS][Register] layer_num={ds_layers}, block_size={ds_block_size}, "
                        f"is_mla={self.is_mla}, kvcache_dim={ds_k_v_index}")
        except EmsException as e:
            logger.error(f"[EMS][Register] register kvcache error: {e}.")

    def _cal_block_offsets(self, block_ids: List[int], block_size: int) -> List[int]:
        return [block_size] * len(block_ids)

    def _cal_block_slot_mapping(self, block_ids: List[int], block_size: int) -> List[int]:
        # slot = block_id * block_size + offset, flattened block by block
        block_ids_np = np.array(block_ids)
        range_arr = np.arange(block_size)
        result_matrix = block_ids_np[:, np.newaxis] * block_size + range_arr
        flattened_result = result_matrix.ravel()
        return flattened_result.tolist()

    def exists_block_num(self, req_id: str, block_hashes: list[int]) -> int:
        if not self._inited or not self._registered:
            logger.warning(f"[EMS][exists][degraded] req_id={req_id} reason=EMS not ready.")
            return 0
        if not self.task_manager.get_status():
            logger.warning(f"[EMS][exists][skip] req_id={req_id} reason=Unhealthy EMS")
            return 0
        option = CcKvOption(
            write_rcache=EmsEnv.ems_enable_write_rcache,
            read_local_only=EmsEnv.ems_enable_read_local_only,
            timeout=EmsEnv.ems_timeout
        )
        try:
            cc_result = self.context_caching.exists(hashes=block_hashes, option=option)
            return cc_result.success
        except EmsException as e:
            logger.debug(f"[EMS][Exist] fallback to len(hashes) due to: {e}")
            return len(block_hashes)

    def async_load(self, req_id: str, block_hashes: List[int], block_ids: List[int],
                   num_computed_blocks: int) -> None:
        future = None
        submit_time = time.perf_counter()
        if not self._check_params(req_id, block_hashes, block_ids, "AsyncLoad"):
            self.load_futures[req_id] = (future, submit_time, num_computed_blocks, block_ids)
            return
        option = CcKvOption(
            write_rcache=EmsEnv.ems_enable_write_rcache,
            read_local_only=EmsEnv.ems_enable_read_local_only,
            timeout=EmsEnv.ems_timeout
        )
        offsets = self._cal_block_offsets(block_ids, self.block_size)
        slot_mapping = self._cal_block_slot_mapping(block_ids, self.block_size)
        try:
            future = self.context_caching.async_load(
                slot_mapping=slot_mapping,
                hashes=block_hashes,
                offsets=offsets,
                option=option
            )
            logger.info(f"[EMS][AsyncLoad][start] req_id={req_id} timeout_ms={EmsEnv.ems_timeout} "
                        f"planned_blocks={len(block_ids)} first_block_hash={block_hashes[0]}, future=submitted")
        except EmsException as e:
            self._process_exception(e)
            logger.error(f"[EMS][AsyncLoad][error] req_id={req_id} timeout_ms={EmsEnv.ems_timeout} ..., err={e}")
        self.load_futures[req_id] = (future, submit_time, num_computed_blocks, block_ids)

    def async_save(self, req_id: str, block_hashes: List[int], block_ids: List[int]) -> None:
        future = None
        submit_time = time.perf_counter()
        if not self._check_params(req_id, block_hashes, block_ids, "AsyncSave"):
            self.save_futures[req_id] = (future, submit_time)
            return
        option = CcKvOption(
            write_rcache=EmsEnv.ems_enable_write_rcache,
            read_local_only=EmsEnv.ems_enable_read_local_only,
            timeout=EmsEnv.ems_timeout
        )
        offsets = self._cal_block_offsets(block_ids, self.block_size)
        slot_mapping = self._cal_block_slot_mapping(block_ids, self.block_size)
        try:
            future = self.context_caching.async_save(
                slot_mapping=slot_mapping,
                hashes=block_hashes,
                offsets=offsets,
                option=option
            )
            logger.info(f"[EMS][AsyncSave][start] req_id={req_id} timeout_ms={EmsEnv.ems_timeout} "
                        f"planned_blocks={len(block_ids)} first_block_hash={block_hashes[0]}, future=submitted")
        except EmsException as e:
            self._process_exception(e)
            logger.error(f"[EMS][AsyncSave][error] req_id={req_id} timeout_ms={EmsEnv.ems_timeout} ..., err={e}")
        self.save_futures[req_id] = (future, submit_time)

    def get_finished_load_reqs(self) -> Set[str]:
        finished_reqs = set()
        for req_id, (future, submit_time, num_computed_blocks, block_ids) in list(self.load_futures.items()):
            if future and not self.context_caching.is_ready(future):
                continue
            try:
                if future:
                    result = self.context_caching.get_result(future)
                    cost_ms = 1e3 * (time.perf_counter() - submit_time)
                    logger.info(f"[EMS][GetResult] Req {req_id} async load done, success_blocks={result.success}, "
                                f"total_blocks_num={result.total}, status=SUCCESS, cost_ms={cost_ms:.2f}")
                    self.task_manager.update_req_stat("LOAD", result.success, cost_ms)
                    self.task_manager.update_hit_stat(result.success, result.total)
                    if result.success < len(block_ids):
                        # Blocks after the last successfully loaded one must be recomputed
                        failed_part = block_ids[result.success:]
                        self.failed_block_ids.update(failed_part)
                else:
                    cost_ms = 1e3 * (time.perf_counter() - submit_time)
                    logger.info(f"[EMS][GetResult] Req {req_id} async load done, success_blocks={0}, "
                                f"total_blocks_num={0}, status=EMS_INTERNAL_ERROR, cost_ms={cost_ms:.2f}")
                    self.task_manager.update_req_stat("LOAD", 0, cost_ms)
                    self.failed_block_ids.update(block_ids)
            except EmsException as e:
                cost_ms = 1e3 * (time.perf_counter() - submit_time)
                self.failed_block_ids.update(block_ids)
                self._process_exception(e)
                logger.info(f"[EMS][GetResult] Req {req_id} async load done, success_blocks={0}, "
                            f"total_blocks_num={0}, status={e.status_code().name}, cost_ms={cost_ms:.2f}")
            finished_reqs.add(req_id)
            self.load_futures.pop(req_id)
        return finished_reqs

    def get_block_ids_with_load_errors(self) -> Set[int]:
        failed_block_ids = self.failed_block_ids
        self.failed_block_ids = set()
        return failed_block_ids

    def get_finished_save_reqs(self) -> Set[str]:
        finished_reqs = set()
        for req_id, (future, submit_time) in list(self.save_futures.items()):
            if future and not self.context_caching.is_ready(future):
                continue
            try:
                if future:
                    result = self.context_caching.get_result(future)
                    cost_ms = 1e3 * (time.perf_counter() - submit_time)
                    logger.info(f"[EMS][GetResult] Req {req_id} async save done, success_blocks={result.success}, "
                                f"total_blocks_num={result.total}, status=SUCCESS, cost_ms={cost_ms:.2f}")
                    self.task_manager.update_req_stat("SAVE", result.success, cost_ms)
                else:
                    cost_ms = 1e3 * (time.perf_counter() - submit_time)
                    logger.info(f"[EMS][GetResult] Req {req_id} async save done, success_blocks={0}, "
                                f"total_blocks_num={0}, status=EMS_INTERNAL_ERROR, cost_ms={cost_ms:.2f}")
                    self.task_manager.update_req_stat("SAVE", 0, cost_ms)
            except EmsException as e:
                cost_ms = 1e3 * (time.perf_counter() - submit_time)
                self._process_exception(e)
                logger.info(f"[EMS][GetResult] Req {req_id} async save done, success_blocks={0}, "
                            f"total_blocks_num={0}, status={e.status_code().name}, cost_ms={cost_ms:.2f}")
            finished_reqs.add(req_id)
            self.save_futures.pop(req_id)
        return finished_reqs

    def sync_save_reqs(self) -> None:
        for req_id, (future, submit_time) in self.save_futures.items():
            if future is None:
                continue
            try:
                result = self.context_caching.get_result(future)
                cost_ms = 1e3 * (time.perf_counter() - submit_time)
                logger.info(f"[EMS][GetResult] Req {req_id} async save done, success_blocks={result.success}, "
                            f"total_blocks_num={result.total}, status=SUCCESS, cost_ms={cost_ms:.2f}")
                self.task_manager.update_req_stat("SAVE", result.success, cost_ms)
            except EmsException as e:
                cost_ms = 1e3 * (time.perf_counter() - submit_time)
                self._process_exception(e)
                logger.info(f"[EMS][GetResult] Req {req_id} async save done, success_blocks={0}, "
                            f"total_blocks_num={0}, status={e.status_code().name}, cost_ms={cost_ms:.2f}")
        self.save_futures.clear()

    def _check_health(self) -> bool:
        is_health = Ems.check_health()
        if is_health:
            logger.info("[EMS] EMS health status is ok.")
        else:
            logger.info("[EMS] EMS health status is abnormal.")
        if not self._inited and is_health and self._need_reinit:
            logger.info("[EMS][Init] re-init during health check.")
            self.init_ems()
            if self._inited:
                logger.info("[EMS][Init] re-init succeed during health check.")
                if (not self._registered) and (self._pending_kv_caches is not None):
                    self.register_kv_caches(self._pending_kv_caches)
                    # Only report success if registration actually succeeded
                    if self._registered:
                        logger.info("[EMS][Register] re-register pending kvcache succeed.")
        return is_health

    def _check_params(self, req_id: str, block_hashes: List[int], block_ids: List[int], called_at: str) -> bool:
        if not self._inited or not self._registered:
            logger.warning(f"[EMS][{called_at}][degraded] req_id={req_id} reason=EMS not ready.")
            return False
        if not self.task_manager.get_status():
            logger.warning(f"[EMS][{called_at}][skip] req_id={req_id} reason=Unhealthy EMS")
            return False
        if (len(block_hashes) == 0 or len(block_ids) == 0) or (len(block_hashes) != len(block_ids)):
            logger.error(f"[EMS] req {req_id} has invalid block_hashes or block_ids: "
                         f"len(block_hashes) == {len(block_hashes)}, len(block_ids) == {len(block_ids)}.")
            return False
        return True

    def _process_exception(self, e: EmsException) -> None:
        # Invalid arguments are caller errors and do not indicate that EMS is unhealthy
        if e.status_code() == EmsErrorCode.EMS_INVALID_ARGUMENT:
            return
        self.task_manager.reset_status()


class PeriodicTaskManager:
    HEALTH_CHECK_INTERVAL = 10  # Health check interval (seconds)
    FLAPPING_WINDOW = 60  # Flapping detection window (seconds)
    FLAPPING_LIMIT = 3  # Maximum number of status changes allowed within the window
    SUCCESS_THRESHOLD = 3  # Consecutive successful checks required to recover to healthy
    PRINT_INTERVAL = 30  # Statistics print interval (seconds)

    def __init__(self, check_fn):
        self.check_fn = check_fn
        self._ems_ok = False
        self._consecutive_success_count = 0
        # Sliding-window history of status changes, used for flapping detection
        self._change_history: deque = deque()
        self.last_log_time = time.perf_counter()
        self.stat = {
            # [sum, min, max], init min with a large number: 2**30 (int) or 1e9 (float)
            "LOAD": {"count": 0, "block_nums": [0, 2**30, 0], "cost_times": [0.0, 1e9, 0.0]},
            "SAVE": {"count": 0, "block_nums": [0, 2**30, 0], "cost_times": [0.0, 1e9, 0.0]},
            "HIT": {"num_hit_blocks": 0, "num_total_blocks": 0},
        }
        self.thread_lock = threading.Lock()
        self.start_loop()

    def get_status(self) -> bool:
        return self._ems_ok

    def reset_status(self) -> None:
        if not self._ems_ok:
            return
        logger.info("[EMS] Ems health status reset to False")
        self._ems_ok = False

    def check_health_status(self) -> None:
        is_healthy = False
        try:
            is_healthy = self.check_fn()
        except Exception as e:
            logger.error(f"[EMS] EMS health check failed, error: {e}.")
        self._process_check_result(is_healthy)

    def _process_check_result(self, is_healthy: bool):
        if is_healthy:
            self._handle_success()
        else:
            self._handle_failure()

    def _handle_success(self):
        self._consecutive_success_count += 1
        if not self._ems_ok:
            if self._consecutive_success_count >= self.SUCCESS_THRESHOLD:
                # Try to switch the health status to True; flapping detection may block this
                if not self._try_switch_status(new_status=True):
                    # Blocked by flapping detection: reset the counter
                    self._consecutive_success_count = 0
        else:
            # Already healthy: cap the counter so that it does not grow without bound
            self._consecutive_success_count = min(self._consecutive_success_count, self.SUCCESS_THRESHOLD)

    def _handle_failure(self):
        self._consecutive_success_count = 0
        if self._ems_ok:
            # Switch the health status to False immediately, ignoring flapping detection
            self._try_switch_status(new_status=False)

    def _try_switch_status(self, new_status: bool) -> bool:
        current_time = time.monotonic()
        # 1. Drop changes that have left the sliding window
        self._clean_flapping_history(current_time)
        # 2. Check for flapping
        is_flapping = len(self._change_history) >= self.FLAPPING_LIMIT
        if is_flapping:
            if new_status is True:
                # Recovery attempts are blocked while the status is flapping
                logger.warning(
                    f"[EMS] EMS status flapping detected ({len(self._change_history)} changes in {self.FLAPPING_WINDOW}s). "
                    f"Blocking recovery (switch to True)."
                )
                return False
            else:
                logger.info(
                    "[EMS] EMS status flapping detected, but forcing status to False (Fail Fast strategy)."
                )
        # 3. Apply the switch
        logger.info(f"[EMS] EMS health status changing: {self._ems_ok} -> {new_status}")
        self._ems_ok = new_status
        # 4. Record the change in the history
        self._change_history.append(current_time)
        return True

    def _clean_flapping_history(self, current_time: float):
        threshold_time = current_time - self.FLAPPING_WINDOW
        while self._change_history and self._change_history[0] < threshold_time:
            self._change_history.popleft()

    def update_req_stat(self, event: str, block_num: int, cost_time: float) -> None:
        with self.thread_lock:
            self.stat[event]["count"] += 1
            self.stat[event]["block_nums"][0] += block_num
            self.stat[event]["block_nums"][1] = min(block_num, self.stat[event]["block_nums"][1])
            self.stat[event]["block_nums"][2] = max(block_num, self.stat[event]["block_nums"][2])
            self.stat[event]["cost_times"][0] += cost_time
            self.stat[event]["cost_times"][1] = min(cost_time, self.stat[event]["cost_times"][1])
            self.stat[event]["cost_times"][2] = max(cost_time, self.stat[event]["cost_times"][2])

    def update_hit_stat(self, num_hit_blocks: int, num_total_blocks: int) -> None:
        with self.thread_lock:
            self.stat["HIT"]["num_hit_blocks"] += num_hit_blocks
            self.stat["HIT"]["num_total_blocks"] += num_total_blocks

    def print_stat(self) -> None:
        with self.thread_lock:
            stat = self.stat
            self.stat = {
                # [sum, min, max], init min with a large number: 2**30 (int) or 1e9 (float)
                "LOAD": {"count": 0, "block_nums": [0, 2**30, 0], "cost_times": [0.0, 1e9, 0.0]},
                "SAVE": {"count": 0, "block_nums": [0, 2**30, 0], "cost_times": [0.0, 1e9, 0.0]},
                "HIT": {"num_hit_blocks": 0, "num_total_blocks": 0},
            }
        load = stat["LOAD"]["count"]
        avg_load_block_num = stat["LOAD"]["block_nums"][0] / load if load else 0.0
        min_load_block_num = stat["LOAD"]["block_nums"][1] if load else 0
        max_load_block_num = stat["LOAD"]["block_nums"][2] if load else 0
        avg_load_cost_time = stat["LOAD"]["cost_times"][0] / load if load else 0.0
        min_load_cost_time = stat["LOAD"]["cost_times"][1] if load else 0.0
        max_load_cost_time = stat["LOAD"]["cost_times"][2] if load else 0.0
        save = stat["SAVE"]["count"]
        avg_save_block_num = stat["SAVE"]["block_nums"][0] / save if save else 0.0
        min_save_block_num = stat["SAVE"]["block_nums"][1] if save else 0
        max_save_block_num = stat["SAVE"]["block_nums"][2] if save else 0
        avg_save_cost_time = stat["SAVE"]["cost_times"][0] / save if save else 0.0
        min_save_cost_time = stat["SAVE"]["cost_times"][1] if save else 0.0
        max_save_cost_time = stat["SAVE"]["cost_times"][2] if save else 0.0
        num_hit_blocks = stat["HIT"]["num_hit_blocks"]
        num_total_blocks = stat["HIT"]["num_total_blocks"]
        hit_rate = 100 * (num_hit_blocks / num_total_blocks) if num_total_blocks else 0.0
        logger.info(f"[EMS][{self.PRINT_INTERVAL}Sec Summary] req(load={load},save={save}) "
                    f"LOAD(avg_block_num={avg_load_block_num:.1f} [min:{min_load_block_num}, max:{max_load_block_num}]; "
                    f"avg_cost_ms={avg_load_cost_time:.1f} [min:{min_load_cost_time:.1f}, max:{max_load_cost_time:.1f}]) "
                    f"SAVE(avg_block_num={avg_save_block_num:.1f} [min:{min_save_block_num}, max:{max_save_block_num}]; "
                    f"avg_cost_ms={avg_save_cost_time:.1f} [min:{min_save_cost_time:.1f}, max:{max_save_cost_time:.1f}]) "
                    f"HIT(hit_rate={hit_rate:.1f} [hit:{num_hit_blocks}, total:{num_total_blocks}])")

    def task_loop(self) -> None:
        while True:
            time.sleep(self.HEALTH_CHECK_INTERVAL)
            self.check_health_status()
            cur_time = time.perf_counter()
            if cur_time - self.last_log_time > self.PRINT_INTERVAL:
                self.last_log_time = cur_time
                self.print_stat()

    def start_loop(self) -> None:
        logger.info("[EMS] EMS start periodic tasks.")
        self.check_health_status()
        self.periodic_task_thread = threading.Thread(target=self.task_loop, daemon=True, name="ems_task_loop")
        self.periodic_task_thread.start()
        logger.info("[EMS] periodic tasks subthread \"ems_task_loop\" started.")
- ems_env.py
vi ems_env.py
The file content is as follows:
# Copyright (c) HuaWei Technologies Co., Ltd. 2025-2025. All rights reserved
import os


class EmsEnv:
    llm_engine = os.environ.get("LLM_ENGINE", "vllm")
    model_id = os.environ.get("MODEL_ID", "cc_kvstore@_@ds_default_ns_001")
    service_name = os.environ.get("SERVICE_NAME", "deepseek")
    access_id = os.environ.get("ACCELERATE_ID", "cc_kvstore@_@ds_default_ns_001")
    access_key = os.environ.get("ACCELERATE_KEY", "")
    ems_timeout: int = int(os.environ.get("EMS_TIMEOUT", "5000"))
    ems_enable_write_rcache: bool = os.environ.get("EMS_ENABLE_WRITE_RCACHE", "1") == "1"
    ems_enable_read_local_only: bool = os.environ.get("EMS_ENABLE_READ_LOCAL_ONLY", "0") == "1"
    ems_num_min_reuse_tokens: int = int(os.environ.get("EMS_NUM_MIN_REUSE_TOKENS", "2048"))
    ems_num_min_load_blocks: int = int(os.environ.get("EMS_NUM_MIN_LOAD_BLOCKS", "1"))
    ems_lookup_key_server_ip = os.environ.get("EMS_LOOKUP_KEY_SERVER_IP", "127.0.0.1")
    ems_lookup_key_server_base_port = os.environ.get("EMS_LOOKUP_KEY_SERVER_BASE_PORT", "50005")
- ems_store_connector.py
vi ems_store_connector.py
The file content is as follows:
import os import threading from dataclasses import dataclass from typing import Any, Optional, Tuple, List, Set, Dict import torch import zmq from vllm.distributed import get_tp_group from vllm.v1.attention.backend import AttentionMetadata from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole from vllm.forward_context import ForwardContext from vllm.logger import logger from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.request import Request from vllm_ascend.distributed.ems_store.ems_adapter import EmsAdapter from vllm_ascend.distributed.ems_store.ems_env import EmsEnv @dataclass class CcReqMeta: req_id: str num_computed_blocks: int num_total_blocks: int block_hashes: List[int] block_ids: Optional[List[int]] operation: str @dataclass class CcConnectorMetadata(KVConnectorMetadata): def __init__(self): self.requests: List[CcReqMeta] = [] def add_requests(self, req_meta: CcReqMeta): self.requests.append(req_meta) class EmsStoreConnector(KVConnectorBase_V1): def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole, kv_cache_config=None ): super().__init__( vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config ) if role == KVConnectorRole.SCHEDULER: self.connector_scheduler = CcConnectorScheduler(vllm_config) if role == KVConnectorRole.WORKER: self.connector_worker = CcConnectorWorker(vllm_config) if get_tp_group().rank_in_group == 0: self.lookup_server = LookupKeyServer(vllm_config.parallel_config.data_parallel_index, self.connector_worker) ############################################################ # Scheduler Side Methods ############################################################ def get_num_new_matched_tokens(self, request: Request, num_computed_tokens: int) -> Tuple[int, bool]: return self.connector_scheduler.get_num_new_matched_tokens(request, num_computed_tokens) def 
update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int): self.connector_scheduler.update_state_after_alloc(request, blocks, num_external_tokens) def build_connector_meta(self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: return self.connector_scheduler.build_connector_meta(scheduler_output) def request_finished(self, request: Request, block_ids: List[int]) -> Tuple[bool, Dict[str, Any] | None]: return self.connector_scheduler.request_finished(request, block_ids) ############################################################ # Worker Side Methods ############################################################ def register_kv_caches(self, kv_caches: Dict[str, torch.Tensor]): self.connector_worker.register_kv_caches(kv_caches) def start_load_kv(self, forward_context: ForwardContext, **kwargs): self.connector_worker.start_load_kv(self._get_connector_metadata()) def wait_for_layer_load(self, layer_name: str): pass def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, attn_metadata: "AttentionMetadata", **kwargs): pass def wait_for_save(self): self.connector_worker.wait_for_save(self._get_connector_metadata()) def get_finished(self, finished_req_ids: Set[str]) -> Tuple[Set[str] | None, Set[str] | None]: return self.connector_worker.get_finished(finished_req_ids) def get_block_ids_with_load_errors(self) -> Set[int]: return self.connector_worker.get_block_ids_with_load_errors() class CcConnectorScheduler: def __init__(self, vllm_config: VllmConfig): logger.info(f"[EMS] CcConnectorScheduler init.") self.block_size = vllm_config.cache_config.block_size self.lookup_client = LookupKeyClient(vllm_config.parallel_config.data_parallel_index) self.processed_requests: Set[str] = set() self.meta_load_reqs: Dict[str, CcReqMeta] = {} self.need_save_req_ids: Set[str] = set() def get_num_new_matched_tokens(self, request: Request, num_computed_tokens: int) -> Tuple[int, bool]: if request.request_id in self.processed_requests: 
logger.info(f"[EMS][Scheduler] req {request.request_id} already in processed requests.") return 0, False num_computed_blocks = num_computed_tokens // self.block_size num_total_blocks = (len(request.prompt_token_ids) - 1) // self.block_size if not self._need_load(num_computed_blocks, num_total_blocks): logger.info(f"[EMS][Scheduler] req {request.request_id} no need to load, num_computed_blocks: {num_computed_blocks}, " f"num_total_blocks: {num_total_blocks}.") return 0, False block_hashes = self._cal_block_hashes( request.prompt_token_ids, self.block_size )[num_computed_blocks:num_total_blocks] num_exist_blocks = self.lookup_client.lookup(request.request_id, block_hashes) num_total_blocks = num_computed_blocks + num_exist_blocks block_hashes = block_hashes[:num_exist_blocks] if not self._need_load(num_computed_blocks, num_total_blocks): logger.info(f"[EMS][Scheduler] req {request.request_id} num_total_blocks is updated,still no need to load, num_computed_blocks: {num_computed_blocks}, " f"num_computed_blocks + num_exist_blocks: {num_total_blocks}.") return 0, False req_meta = CcReqMeta( req_id=request.request_id, num_computed_blocks=num_computed_blocks, num_total_blocks=num_total_blocks, block_hashes=block_hashes, block_ids=None, operation="no-op" ) self.meta_load_reqs[request.request_id] = req_meta logger.debug(f"[EMS][Scheduler] req {request.request_id} meta: {req_meta}.") num_new_matched_tokens = (num_total_blocks - num_computed_blocks) * self.block_size logger.info(f"[EMS][Scheduler] matched result, req_id={request.request_id}, " f"num_computed_blocks={num_computed_blocks}, " f"num_exist_blocks={num_exist_blocks}, new num_computed_blocks={num_total_blocks}, " f"num_new_matched_tokens={num_new_matched_tokens}, prompt_len={len(request.prompt_token_ids)}") return num_new_matched_tokens, True def update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int): if request.request_id not in self.meta_load_reqs: 
logger.debug(f"[EMS][Scheduler] req {request.request_id} not in meta_load_reqs.") return if num_external_tokens == 0: logger.debug(f"[EMS][Scheduler] req {request.request_id} num_external_tokens is 0, no need to update block ids.") self.meta_load_reqs[request.request_id].operation = "no-op" return req_meta = self.meta_load_reqs[request.request_id] all_block_ids = blocks.get_block_ids()[0] req_meta.num_total_blocks = min(req_meta.num_total_blocks, len(all_block_ids)) req_meta.block_ids = all_block_ids[req_meta.num_computed_blocks:req_meta.num_total_blocks] req_meta.block_hashes = req_meta.block_hashes[:len(req_meta.block_ids)] req_meta.operation = "load" self.processed_requests.add(request.request_id) logger.info(f"[EMS][Scheduler] req {request.request_id} update block ids, meta: {req_meta}.") def build_connector_meta(self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: connector_meta = CcConnectorMetadata() for req_id, req_meta in self.meta_load_reqs.items(): if req_meta.operation != "load": continue connector_meta.add_requests(req_meta) self.meta_load_reqs.clear() for new_req in scheduler_output.scheduled_new_reqs: num_total_tokens = scheduler_output.num_scheduled_tokens[new_req.req_id] + new_req.num_computed_tokens num_computed_blocks = new_req.num_computed_tokens // self.block_size num_total_blocks = num_total_tokens // self.block_size if num_computed_blocks < num_total_blocks: block_hashes = self._cal_block_hashes(new_req.prompt_token_ids, self.block_size)[ num_computed_blocks:num_total_blocks] block_ids = new_req.block_ids[0][num_computed_blocks:num_total_blocks] req_meta = CcReqMeta( req_id=new_req.req_id, num_computed_blocks=num_computed_blocks, num_total_blocks=num_total_blocks, block_hashes=block_hashes, block_ids=block_ids, operation="save" ) logger.debug(f"[EMS] req {new_req.req_id} need save, meta: {req_meta}.") connector_meta.add_requests(req_meta) self.need_save_req_ids.add(new_req.req_id) return connector_meta def request_finished(self, 
        request: Request,
        block_ids: List[int],
    ) -> Tuple[bool, Optional[Dict[str, Any]]]:
        if request.request_id in self.processed_requests:
            self.processed_requests.remove(request.request_id)
        need_save = request.request_id in self.need_save_req_ids
        self.need_save_req_ids.discard(request.request_id)
        logger.debug(f"[EMS][Scheduler] request finished, req {request.request_id}, "
                     f"need_save={need_save}")
        return need_save, None

    def _need_load(self, num_computed_blocks: int, num_total_blocks: int) -> bool:
        # Do not load if the sequence is shorter than ems_num_min_reuse_tokens
        if num_total_blocks * self.block_size < EmsEnv.ems_num_min_reuse_tokens:
            return False
        # Do not load if the number of blocks to load does not exceed ems_num_min_load_blocks
        if num_total_blocks - num_computed_blocks <= EmsEnv.ems_num_min_load_blocks:
            return False
        return True

    def _cal_block_hashes(self, token_ids: List[int], block_size) -> List[int]:
        result: List[int] = []
        prev_block_hash = 0
        num_blocks = len(token_ids) // block_size
        for block_id in range(num_blocks):
            block_hash = self._cal_block_hash(token_ids[block_id * block_size:(block_id + 1) * block_size],
                                              prev_block_hash)
            result.append(block_hash)
            prev_block_hash = block_hash
        return result

    def _cal_block_hash(self, block_token_ids: List[int], prev_block_hash: int) -> int:
        return hash((prev_block_hash, *block_token_ids))


class CcConnectorWorker:
    def __init__(self, vllm_config: VllmConfig):
        self.ems_adapter = EmsAdapter(vllm_config)
        self.finished_req_ids = set()
        self.finished_save_req_ids = set()

    def register_kv_caches(self, kv_caches: Dict[str, Tuple[torch.Tensor]]):
        return self.ems_adapter.register_kv_caches(kv_caches)

    def start_load_kv(self, metadata: CcConnectorMetadata):
        for request in metadata.requests:
            if request.operation != "load":
                continue
            self.ems_adapter.async_load(
                request.req_id, request.block_hashes, request.block_ids, request.num_computed_blocks
            )

    def wait_for_save(self, metadata: CcConnectorMetadata):
        for request in metadata.requests:
            if request.operation != "save":
                continue
            import torch_npu
            # Wait for pending NPU work so the KV blocks are fully written before saving
            torch_npu.npu.synchronize()
            self.ems_adapter.async_save(
                request.req_id, request.block_hashes, request.block_ids
            )

    def get_finished(self, finished_req_ids: Set[str]) -> Tuple[Set[str], Set[str]]:
        self.finished_req_ids.update(finished_req_ids)
        self.finished_save_req_ids.update(self.ems_adapter.get_finished_save_reqs())
        # Report a request only once it has both finished and finished saving
        finished_save_req_ids = self.finished_req_ids & self.finished_save_req_ids
        self.finished_req_ids -= finished_save_req_ids
        self.finished_save_req_ids -= finished_save_req_ids
        finished_load_req_ids = self.ems_adapter.get_finished_load_reqs()
        if finished_save_req_ids or finished_load_req_ids:
            logger.info(f"[EMS][Worker] finished_save_req_ids: {finished_save_req_ids}, "
                        f"finished_load_req_ids: {finished_load_req_ids}.")
        return finished_save_req_ids, finished_load_req_ids

    def get_block_ids_with_load_errors(self) -> Set[int]:
        block_ids_with_load_errors = self.ems_adapter.get_block_ids_with_load_errors()
        if block_ids_with_load_errors:
            logger.info(f"[EMS][Worker] block_ids_with_load_errors: {block_ids_with_load_errors}")
        return block_ids_with_load_errors

    def _lookup_block_hashes(self, req_id: str, block_hashes: List[int]) -> int:
        return self.ems_adapter.exists_block_num(req_id, block_hashes)


class LookupKeyServer:
    def __init__(self, data_parallel_rank, worker_connector):
        self.worker_connector = worker_connector
        self.zmq_ctx = zmq.Context()
        self.socket = self.zmq_ctx.socket(zmq.REP)
        # One lookup port per data-parallel rank
        port = int(EmsEnv.ems_lookup_key_server_base_port) + data_parallel_rank
        self.addr = f"tcp://{EmsEnv.ems_lookup_key_server_ip}:{port}"
        self.socket.bind(self.addr)

        def process_lookup():
            # Reply to each lookup with the block count returned by exists_block_num
            while True:
                req_id, block_hashes = self.socket.recv_pyobj()
                num_exist_blocks = self.worker_connector._lookup_block_hashes(req_id, block_hashes)
                self.socket.send_pyobj(num_exist_blocks)

        self.thread = threading.Thread(target=process_lookup, daemon=True, name="ems_exist_proxy")
        self.thread.start()


class LookupKeyClient:
    def __init__(self, data_parallel_rank):
        self.zmq_ctx = zmq.Context()
        self.socket = self.zmq_ctx.socket(zmq.REQ)
        port = int(EmsEnv.ems_lookup_key_server_base_port) + data_parallel_rank
        self.addr = f"tcp://{EmsEnv.ems_lookup_key_server_ip}:{port}"
        self.socket.connect(self.addr)

    def lookup(self, req_id: str, block_hashes: List[int]) -> int:
        self.socket.send_pyobj([req_id, block_hashes])
        num_exist_blocks = self.socket.recv_pyobj()
        return num_exist_blocks
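The hashing and load-gating logic in `_cal_block_hashes` and `_need_load` can be tried outside vLLM. The following sketch re-implements both as free functions, with the EmsEnv thresholds passed in as parameters. It is a standalone illustration, not part of the connector, and the function names `cal_block_hashes` and `need_load` are hypothetical. The sketch shows that every block hash folds in the previous one. As a result, two prompts share EMS keys only for their common prefix of full blocks. Python's `hash()` of a tuple of ints is not randomized per process the way str and bytes hashes are. That is presumably what allows keys computed in different processes to match.

```python
from typing import List


def cal_block_hashes(token_ids: List[int], block_size: int) -> List[int]:
    """Prefix-chained block hashes, mirroring _cal_block_hashes above."""
    result: List[int] = []
    prev_block_hash = 0
    # Only full blocks are hashed; trailing tokens that do not fill a block are skipped.
    num_blocks = len(token_ids) // block_size
    for block_id in range(num_blocks):
        block = token_ids[block_id * block_size:(block_id + 1) * block_size]
        # Folding in the previous hash makes each key depend on the entire prefix.
        block_hash = hash((prev_block_hash, *block))
        result.append(block_hash)
        prev_block_hash = block_hash
    return result


def need_load(num_computed_blocks: int, num_total_blocks: int, block_size: int,
              min_reuse_tokens: int, min_load_blocks: int) -> bool:
    """Mirror of _need_load with the EmsEnv thresholds passed in explicitly."""
    if num_total_blocks * block_size < min_reuse_tokens:
        return False  # sequence too short to be worth reusing
    if num_total_blocks - num_computed_blocks <= min_load_blocks:
        return False  # too few missing blocks to justify a load from EMS
    return True


# 10 tokens with block_size=4 give 2 full blocks; the last 2 tokens are not hashed.
a = cal_block_hashes(list(range(10)), block_size=4)
# Same first block, different second block: only the first key is shared.
b = cal_block_hashes([0, 1, 2, 3, 99, 5, 6, 7], block_size=4)
# Different first block, identical second block: the second key still differs.
c = cal_block_hashes([9, 1, 2, 3, 4, 5, 6, 7], block_size=4)
print(len(a), a[0] == b[0], a[1] == c[1])
```

The startup scripts later in this guide set EMS_NUM_MIN_REUSE_TOKENS=0 and EMS_NUM_MIN_LOAD_BLOCKS=0. Assuming ems_env.py maps these variables to `ems_num_min_reuse_tokens` and `ems_num_min_load_blocks`, a load is then attempted whenever at least one block has not yet been computed locally.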
- Create the kv_transfer folder, and create an __init__.py file in it.
mkdir kv_transfer
vi kv_transfer/__init__.py
The content of __init__.py is as follows:
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory


def register_connector():
    # override multi_connector as ascend_multi_connector
    if "MultiConnector" in KVConnectorFactory._registry:
        KVConnectorFactory._registry.pop("MultiConnector")
    KVConnectorFactory.register_connector(
        "MultiConnector", "vllm_ascend.distributed.kv_transfer.ascend_multi_connector", "AscendMultiConnector"
    )

    KVConnectorFactory.register_connector(
        "MooncakeConnectorV1", "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector", "MooncakeConnector"
    )

    KVConnectorFactory.register_connector(
        "MooncakeConnectorStoreV1",
        "vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.ascend_store_connector",
        "AscendStoreConnector",
    )

    KVConnectorFactory.register_connector(
        "AscendStoreConnector",
        "vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.ascend_store_connector",
        "AscendStoreConnector",
    )

    KVConnectorFactory.register_connector(
        "MooncakeLayerwiseConnector",
        "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector",
        "MooncakeLayerwiseConnector",
    )

    KVConnectorFactory.register_connector(
        "UCMConnector", "vllm_ascend.distributed.kv_transfer.kv_pool.ucm_connector", "UCMConnectorV1"
    )

    KVConnectorFactory.register_connector(
        "LMCacheAscendConnector",
        "vllm_ascend.distributed.kv_transfer.kv_pool.lmcache_ascend_connector",
        "LMCacheConnectorV1",
    )

    KVConnectorFactory.register_connector(
        "EmsStoreConnector",
        "vllm_ascend.distributed.ems_store.ems_store_connector",
        "EmsStoreConnector",
    )
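The connector name given in `kv_connector` is resolved through this registry, and the file pops the existing MultiConnector entry before re-registering it. The toy factory below explains that pattern. It is a hypothetical stand-in written for illustration, not vLLM's code. It assumes, as we understand vLLM's KVConnectorFactory to behave, that `register_connector(name, module_path, class_name)` stores a lazy loader and rejects duplicate names. Standard-library classes stand in for connector implementations.

```python
import importlib
from typing import Callable, Dict


class ToyConnectorFactory:
    """Hypothetical stand-in for a lazy connector registry (not vLLM's actual code)."""

    _registry: Dict[str, Callable[[], type]] = {}

    @classmethod
    def register_connector(cls, name: str, module_path: str, class_name: str) -> None:
        # Duplicate names are rejected, so an existing entry must be popped first.
        if name in cls._registry:
            raise ValueError(f"Connector '{name}' is already registered.")

        def loader() -> type:
            # The module is imported only when the connector is first requested.
            module = importlib.import_module(module_path)
            return getattr(module, class_name)

        cls._registry[name] = loader

    @classmethod
    def get_connector_class(cls, name: str) -> type:
        return cls._registry[name]()


# Standard-library classes play the role of connector implementations here.
ToyConnectorFactory.register_connector("MultiConnector", "collections", "OrderedDict")

# Same override pattern as register_connector() above: pop, then re-register.
if "MultiConnector" in ToyConnectorFactory._registry:
    ToyConnectorFactory._registry.pop("MultiConnector")
ToyConnectorFactory.register_connector("MultiConnector", "collections", "Counter")

print(ToyConnectorFactory.get_connector_class("MultiConnector").__name__)  # Counter
```

Because the loader imports lazily, registering EmsStoreConnector costs nothing until a deployment actually sets `"kv_connector": "EmsStoreConnector"`.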
- Prepare the image.
- Create the Dockerfile.
FROM quay.io/ascend/vllm-ascend:v0.18.0-a3

# Install OpenSSL 1.1.1w from source
COPY openssl-1.1.1w.tar.gz .
RUN tar xzvf openssl-1.1.1w.tar.gz \
    && cd openssl-1.1.1w \
    && ./config \
    && make \
    && make install \
    && ln -sf /usr/local/lib/libssl.so.1.1 /usr/lib/aarch64-linux-gnu/libssl.so.1.1 \
    && ln -sf /usr/local/lib/libcrypto.so.1.1 /usr/lib/aarch64-linux-gnu/libcrypto.so.1.1 \
    && cd .. \
    && rm -rf openssl-1.1.1w openssl-1.1.1w.tar.gz

# Install cyrus-sasl 2.1.28 from source
COPY cyrus-sasl-2.1.28.tar.gz .
RUN tar xzvf cyrus-sasl-2.1.28.tar.gz \
    && cd cyrus-sasl-2.1.28 \
    && ./configure \
    && make \
    && make install \
    && ln -sf /usr/local/lib/libsasl2.so.3 /usr/lib/aarch64-linux-gnu/libsasl2.so.3 \
    && ln -sf /usr/local/lib/sasl2 /usr/lib/aarch64-linux-gnu/sasl2 \
    && cd .. \
    && rm -rf cyrus-sasl-2.1.28 cyrus-sasl-2.1.28.tar.gz

RUN mkdir -p /vllm-workspace/vllm-ascend/vllm_ascend/distributed/ems_store/
COPY ./ems_store/__init__.py /vllm-workspace/vllm-ascend/vllm_ascend/distributed/ems_store/
COPY ./ems_store/ems_store_connector.py /vllm-workspace/vllm-ascend/vllm_ascend/distributed/ems_store/
COPY ./ems_store/ems_env.py /vllm-workspace/vllm-ascend/vllm_ascend/distributed/ems_store/
COPY ./ems_store/ems_adapter.py /vllm-workspace/vllm-ascend/vllm_ascend/distributed/ems_store/
COPY ./kv_transfer/__init__.py /vllm-workspace/vllm-ascend/vllm_ascend/distributed/kv_transfer/

COPY ems-26.3.0-cp311-cp311-linux_aarch64.whl .
RUN pip install --no-deps ems-26.3.0-cp311-cp311-linux_aarch64.whl \
    && rm ems-26.3.0-cp311-cp311-linux_aarch64.whl
- Build the image.
docker build -t vllm-ascend-ems:v0.18.0-a3 .
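Each COPY in the Dockerfile fails the build if its source file is missing from the build context. A small pre-build check like the sketch below catches this before `docker build` runs. The file list is copied from the Dockerfile's COPY instructions, and the helper name `missing_build_files` is hypothetical. Adjust the wheel file name if the EMS SDK package you downloaded has a different version.

```python
from pathlib import Path
from typing import List

# Source paths of the COPY instructions in the Dockerfile,
# relative to the build context (the directory passed to docker build).
REQUIRED_FILES: List[str] = [
    "openssl-1.1.1w.tar.gz",
    "cyrus-sasl-2.1.28.tar.gz",
    "ems_store/__init__.py",
    "ems_store/ems_store_connector.py",
    "ems_store/ems_env.py",
    "ems_store/ems_adapter.py",
    "kv_transfer/__init__.py",
    "ems-26.3.0-cp311-cp311-linux_aarch64.whl",  # adjust to the EMS SDK version you downloaded
]


def missing_build_files(context_dir: str = ".") -> List[str]:
    """Return the required files that are not present in the build context."""
    root = Path(context_dir)
    return [name for name in REQUIRED_FILES if not (root / name).is_file()]


# Usage, from the directory that contains the Dockerfile:
# missing = missing_build_files(".")
# print("OK" if not missing else f"missing: {missing}")
```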
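- Create a ConfigMap that contains the startup scripts. The script parameters depend on the model you deploy, so adjust them accordingly. For details, see vllm-ascend.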
kubectl apply -f config.yaml
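The following is an example template. Set nic_name to the name of the node's default-route network interface, which you can find by running ip route | grep default.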
kind: ConfigMap
apiVersion: v1
metadata:
  name: kimi-k25-pd-cm
data:
  prefill.sh: |
    nic_name="enp23s0f3" # network card name
    local_ip=$POD_IP
    export HCCL_IF_IP=$local_ip
    export GLOO_SOCKET_IFNAME=$nic_name
    export TP_SOCKET_IFNAME=$nic_name
    export HCCL_SOCKET_IFNAME=$nic_name
    export LD_PRELOAD=/usr/lib/aarch64-linux-gnu/libjemalloc.so.2:$LD_PRELOAD
    sysctl -w vm.swappiness=0
    sysctl -w kernel.numa_balancing=0
    sysctl kernel.sched_migration_cost_ns=50000
    export VLLM_RPC_TIMEOUT=3600000
    export VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS=30000
    export HCCL_OP_EXPANSION_MODE="AIV"
    export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
    export OMP_PROC_BIND=false
    export OMP_NUM_THREADS=1
    export TASK_QUEUE_ENABLE=1
    export ASCEND_BUFFER_POOL=4:8
    export HCCL_BUFFSIZE=256
    export VLLM_ASCEND_ENABLE_FLASHCOMM1=1
    # --- EMS Configuration Exports ---
    export EMS_NUM_MIN_REUSE_TOKENS=0
    export EMS_NUM_MIN_LOAD_BLOCKS=0
    export EMS_BLOCK_GROUP_SIZE=2
    export LD_LIBRARY_PATH=/usr/local/python3.11.14/lib/python3.11/site-packages/ems/lib:$LD_LIBRARY_PATH
    export EMS_LOOKUP_KEY_SERVER_BASE_PORT=60687
    vllm serve $MODEL_LOCATION \
      --host $POD_IP \
      --port "7100" \
      --data-parallel-size 2 \
      --data-parallel-address $POD_IP \
      --data-parallel-rpc-port 12321 \
      --tensor-parallel-size 8 \
      --enable-expert-parallel \
      --seed 1024 \
      --quantization ascend \
      --served-model-name kimi_k25 \
      --trust-remote-code \
      --max-num-seqs 8 \
      --max-model-len 32768 \
      --max-num-batched-tokens 16384 \
      --disable-hybrid-kv-cache-manager \
      --no-enable-prefix-caching \
      --gpu-memory-utilization 0.8 \
      --enforce-eager \
      --speculative-config '{"method": "eagle3", "model":"/models/kimi-k2.5-eagle3", "num_speculative_tokens": 3}' \
      --additional-config '{"recompute_scheduler_enable":true}' \
      --mm-encoder-tp-mode data \
      --kv-transfer-config \
      '{ "kv_connector": "EmsStoreConnector", "kv_role": "kv_producer" }'
  decode.sh: |
    nic_name="enp23s0f3" # network card name
    local_ip=$POD_IP
    export HCCL_IF_IP=$local_ip
    export GLOO_SOCKET_IFNAME=$nic_name
    export TP_SOCKET_IFNAME=$nic_name
    export HCCL_SOCKET_IFNAME=$nic_name
    export LD_PRELOAD=/usr/lib/aarch64-linux-gnu/libjemalloc.so.2:$LD_PRELOAD
    sysctl -w vm.swappiness=0
    sysctl -w kernel.numa_balancing=0
    sysctl kernel.sched_migration_cost_ns=50000
    export VLLM_RPC_TIMEOUT=3600000
    export VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS=30000
    export HCCL_OP_EXPANSION_MODE="AIV"
    export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
    export OMP_PROC_BIND=false
    export OMP_NUM_THREADS=1
    export TASK_QUEUE_ENABLE=1
    export ASCEND_BUFFER_POOL=4:8
    export HCCL_BUFFSIZE=2200
    export VLLM_ASCEND_ENABLE_MLAPO=0
    # --- EMS Configuration Exports ---
    export EMS_NUM_MIN_REUSE_TOKENS=0
    export EMS_NUM_MIN_LOAD_BLOCKS=0
    export EMS_BLOCK_GROUP_SIZE=2
    export LD_LIBRARY_PATH=/usr/local/python3.11.14/lib/python3.11/site-packages/ems/lib:$LD_LIBRARY_PATH
    export EMS_LOOKUP_KEY_SERVER_BASE_PORT=60687
    vllm serve $MODEL_LOCATION \
      --host $POD_IP \
      --port "7100" \
      --data-parallel-size 16 \
      --data-parallel-address $POD_IP \
      --data-parallel-rpc-port 12322 \
      --tensor-parallel-size 1 \
      --enable-expert-parallel \
      --seed 1024 \
      --quantization ascend \
      --served-model-name kimi_k25 \
      --trust-remote-code \
      --max-num-seqs 48 \
      --max-model-len 32768 \
      --max-num-batched-tokens 256 \
      --disable-hybrid-kv-cache-manager \
      --no-enable-prefix-caching \
      --gpu-memory-utilization 0.95 \
      --compilation-config '{"cudagraph_mode": "FULL_DECODE_ONLY", "cudagraph_capture_sizes":[4,8,16,32,48,64,80,96,112,128,144,160]}' \
      --additional-config '{"recompute_scheduler_enable":true,"multistream_overlap_shared_expert": false}' \
      --speculative-config '{"method": "eagle3", "model":"/models/kimi-k2.5-eagle3", "num_speculative_tokens": 3}' \
      --kv-transfer-config \
      '{ "kv_connector": "EmsStoreConnector", "kv_role": "kv_consumer" }'
- Deploy the ModelServing.
kubectl apply -f kimi-k25-serv.yaml
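Before you apply the template, check that each role's NPU request matches its parallel layout. In the startup scripts, prefill runs 2 data-parallel ranks with tensor parallelism 8, and decode runs 16 ranks with tensor parallelism 1. The template below requests `huawei.com/ascend-1980: '16'` for each role. The sketch below checks that DP × TP equals that request and lists the EMS lookup ports. According to the LookupKeyServer code, each lookup port is the base port plus the data-parallel rank. The values are copied from this guide, the helper names are hypothetical, and the script assumes EmsEnv reads EMS_LOOKUP_KEY_SERVER_BASE_PORT as `ems_lookup_key_server_base_port`. Because the pods use hostNetwork, these ports are opened directly on the node and must be free there.

```python
from typing import Dict, List

# Values copied from prefill.sh/decode.sh and the ModelServing resources in this guide.
ROLES: Dict[str, Dict[str, int]] = {
    "prefill": {"dp": 2, "tp": 8, "npus_requested": 16},
    "decode": {"dp": 16, "tp": 1, "npus_requested": 16},
}
# Assumed to be read by EmsEnv as ems_lookup_key_server_base_port.
EMS_LOOKUP_KEY_SERVER_BASE_PORT = 60687


def npus_needed(dp: int, tp: int) -> int:
    """Each data-parallel rank spans tp NPUs, so a role needs dp * tp in total."""
    return dp * tp


def lookup_ports(dp: int, base_port: int = EMS_LOOKUP_KEY_SERVER_BASE_PORT) -> List[int]:
    """LookupKeyServer binds base_port + data_parallel_rank for every DP rank."""
    return [base_port + rank for rank in range(dp)]


for role, cfg in ROLES.items():
    need = npus_needed(cfg["dp"], cfg["tp"])
    if need != cfg["npus_requested"]:
        raise ValueError(f"{role}: needs {need} NPUs but requests {cfg['npus_requested']}")
    ports = lookup_ports(cfg["dp"])
    print(f"{role}: {need} NPUs, lookup ports {ports[0]}-{ports[-1]}")
# prefill: 16 NPUs, lookup ports 60687-60688
# decode: 16 NPUs, lookup ports 60687-60702
```

If you change `--data-parallel-size` or `--tensor-parallel-size`, update `huawei.com/ascend-1980` in both `limits` and `requests` to match.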
The following is an example deployment template.
apiVersion: workload.serving.volcano.sh/v1alpha1
kind: ModelServing
metadata:
  name: kimi-k25-pd
  namespace: default
spec:
  schedulerName: volcano
  replicas: 1
  plugins:
    - name: pod-discovery
  recoveryPolicy: ServingGroupRecreate
  template:
    restartGracePeriodSeconds: 60
    roles:
      - name: prefill
        replicas: 1
        workerReplicas: 0
        entryTemplate:
          spec:
            hostNetwork: true
            containers:
              - name: prefill
                image: docker.io/library/vllm-ascend-ems:v0.18.0-a3
                command:
                  - /bin/bash
                args:
                  - '-c'
                  - cd /workspace && ./prefill.sh
                env:
                  - name: ROLE
                    value: "prefill"
                  - name: GROUP_NAME
                    valueFrom:
                      fieldRef:
                        fieldPath: metadata.labels['modelserving.volcano.sh/group-name']
                  - name: ROLE_ID
                    valueFrom:
                      fieldRef:
                        fieldPath: metadata.labels['modelserving.volcano.sh/role-id']
                  - name: POD_IP
                    valueFrom:
                      fieldRef:
                        fieldPath: status.podIP
                  - name: NODE_IP
                    valueFrom:
                      fieldRef:
                        fieldPath: status.hostIP
                  - name: MODEL_LOCATION
                    value: /models/Kimi-K2.5-w4a8
                  - name: TP_SIZE
                    value: "8"
                  - name: DP_SIZE
                    value: "2"
                readinessProbe:
                  httpGet:
                    path: /health
                    port: 7100
                    scheme: HTTP
                  initialDelaySeconds: 60
                  periodSeconds: 10
                  timeoutSeconds: 2
                  failureThreshold: 3
                resources:
                  limits:
                    cpu: '188'
                    huawei.com/ascend-1980: '16'
                    memory: 1800Gi
                  requests:
                    cpu: '64'
                    huawei.com/ascend-1980: '16'
                    memory: 700Gi
                ports:
                  - containerPort: 7100
                    name: server
                volumeMounts:
                  - name: model
                    mountPath: /models
                  - name: dshm
                    mountPath: /dev/shm
                  - name: ems-shm
                    mountPath: /dev/shm/ems
                  - name: hccn-conf
                    mountPath: /etc/hccn.conf
                  - name: hccn-tool
                    mountPath: /usr/local/Ascend/driver/tools/hccn_tool
                  - name: ascend-install-info
                    mountPath: /etc/ascend_install.info
                  - name: config
                    mountPath: /workspace/prefill.sh
                    subPath: prefill.sh
            volumes:
              - name: model
                hostPath:
                  path: /mnt/paas/models
                  type: Directory
              - name: dshm
                emptyDir:
                  medium: Memory
              - name: ems-shm
                hostPath:
                  path: /mnt/paas/kubernetes/kubelet/ems
                  type: DirectoryOrCreate
              - name: hccn-conf
                hostPath:
                  path: /etc/hccn.conf
              - name: hccn-tool
                hostPath:
                  path: /usr/local/Ascend/driver/tools/hccn_tool
              - name: ascend-install-info
                hostPath:
                  path: /etc/ascend_install.info
              - name: config
                configMap:
                  name: kimi-k25-pd-cm
                  defaultMode: 0777
      - name: decode
        replicas: 1
        workerReplicas: 0
        entryTemplate:
          spec:
            hostNetwork: true
            containers:
              - name: decode
                image: docker.io/library/vllm-ascend-ems:v0.18.0-a3
                command:
                  - /bin/bash
                args:
                  - '-c'
                  - cd /workspace && ./decode.sh
                env:
                  - name: ROLE
                    value: "decode"
                  - name: ENGINE_ID
                    valueFrom:
                      fieldRef:
                        fieldPath: metadata.name
                  - name: POD_IP
                    valueFrom:
                      fieldRef:
                        fieldPath: status.podIP
                  - name: NODE_IP
                    valueFrom:
                      fieldRef:
                        fieldPath: status.hostIP
                  - name: GROUP_NAME
                    valueFrom:
                      fieldRef:
                        fieldPath: metadata.labels['modelserving.volcano.sh/group-name']
                  - name: ROLE_ID
                    valueFrom:
                      fieldRef:
                        fieldPath: metadata.labels['modelserving.volcano.sh/role-id']
                  - name: MODEL_LOCATION
                    value: /models/Kimi-K2.5-w4a8
                  - name: TP_SIZE
                    value: "1"
                  - name: DP_SIZE
                    value: "16"
                readinessProbe:
                  httpGet:
                    path: /health
                    port: 7100
                    scheme: HTTP
                  initialDelaySeconds: 60
                  periodSeconds: 10
                  timeoutSeconds: 2
                  failureThreshold: 3
                ports:
                  - containerPort: 7100
                    name: server
                resources:
                  limits:
                    cpu: '188'
                    huawei.com/ascend-1980: '16'
                    memory: 1800Gi
                  requests:
                    cpu: '64'
                    huawei.com/ascend-1980: '16'
                    memory: 700Gi
                volumeMounts:
                  - name: model
                    mountPath: /models
                  - name: dshm
                    mountPath: /dev/shm
                  - name: ems-shm
                    mountPath: /dev/shm/ems
                  - name: hccn-conf
                    mountPath: /etc/hccn.conf
                  - name: hccn-tool
                    mountPath: /usr/local/Ascend/driver/tools/hccn_tool
                  - name: ascend-install-info
                    mountPath: /etc/ascend_install.info
                  - name: config
                    mountPath: /workspace/decode.sh
                    subPath: decode.sh
            volumes:
              - name: model
                hostPath:
                  path: /mnt/paas/models
                  type: Directory
              - name: dshm
                emptyDir:
                  medium: Memory
              - name: ems-shm
                hostPath:
                  path: /mnt/paas/kubernetes/kubelet/ems
                  type: DirectoryOrCreate
              - name: hccn-conf
                hostPath:
                  path: /etc/hccn.conf
              - name: hccn-tool
                hostPath:
                  path: /usr/local/Ascend/driver/tools/hccn_tool
              - name: ascend-install-info
                hostPath:
                  path: /etc/ascend_install.info
              - name: config
                configMap:
                  name: kimi-k25-pd-cm
                  defaultMode: 0777
- Configure the proxy.
- Download the proxy script.
wget https://raw.githubusercontent.com/vllm-project/vllm-ascend/main/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py
- Obtain the Pod IP addresses.
kubectl get pods -owide
Example output:
NAME                        READY   STATUS    RESTARTS   AGE     IP              NODE            NOMINATED NODE   READINESS GATES
kimi-k25-pd-0-decode-0-0    1/1     Running   0          9m57s   192.168.0.156   192.168.0.156   <none>           <none>
kimi-k25-pd-0-prefill-0-0   1/1     Running   0          9m57s   192.168.0.163   192.168.0.163   <none>           <none>
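Copying IPs by hand becomes error-prone once a deployment has more than one prefill or decode pod. The sketch below extracts the IPs by role from the `kubectl get pods -owide` text and builds the `--prefiller-*`/`--decoder-*` arguments for the proxy start command in the next step. It assumes that pod names contain `-prefill-` or `-decode-`, as in the example output. It takes the first IPv4 address on each row, because with hostNetwork the pod IP and the node IP are the same. Passing several hosts after a single flag assumes the proxy script accepts space-separated lists, so check its argument parser before relying on that. The helper names are hypothetical.

```python
import re
from typing import Dict, List

IPV4 = re.compile(r"\b\d{1,3}(?:\.\d{1,3}){3}\b")


def pod_ips_by_role(kubectl_output: str) -> Dict[str, List[str]]:
    """Group pod IPs by role from `kubectl get pods -owide` output."""
    result: Dict[str, List[str]] = {"prefill": [], "decode": []}
    for line in kubectl_output.strip().splitlines()[1:]:  # skip the header row
        name = line.split()[0]
        match = IPV4.search(line)
        if match is None:
            continue  # pod has no IP yet, e.g. still Pending
        for role in result:
            if f"-{role}-" in name:
                result[role].append(match.group(0))
    return result


def proxy_args(ips: Dict[str, List[str]], port: int = 7100) -> List[str]:
    """Build the host/port arguments for load_balance_proxy_server_example.py."""
    return [
        "--prefiller-hosts", *ips["prefill"],
        "--prefiller-ports", *[str(port)] * len(ips["prefill"]),
        "--decoder-hosts", *ips["decode"],
        "--decoder-ports", *[str(port)] * len(ips["decode"]),
    ]


SAMPLE = """\
NAME READY STATUS RESTARTS AGE IP NODE NOMINATED NODE READINESS GATES
kimi-k25-pd-0-decode-0-0 1/1 Running 0 9m57s 192.168.0.156 192.168.0.156 <none> <none>
kimi-k25-pd-0-prefill-0-0 1/1 Running 0 9m57s 192.168.0.163 192.168.0.163 <none> <none>
"""

print(" ".join(proxy_args(pod_ips_by_role(SAMPLE))))
# --prefiller-hosts 192.168.0.163 --prefiller-ports 7100 --decoder-hosts 192.168.0.156 --decoder-ports 7100
```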
- Start the proxy service.
python3 load_balance_proxy_server_example.py \
    --port 8181 \
    --host 0.0.0.0 \
    --prefiller-hosts 192.168.0.163 \
    --prefiller-ports 7100 \
    --decoder-hosts 192.168.0.156 \
    --decoder-ports 7100
- Verify the service. Use curl to send a standard Chat API request and check that the inference pipeline works end to end.
- Send a request.
curl -X POST http://192.168.0.163:8181/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "kimi_k25",
        "messages": [
            {
                "role": "user",
                "content": "Hello, how are you?"
            }
        ],
        "max_tokens": 100
    }'
- Expected response.
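If the deployment is successful, the response is JSON data that contains the choices and usage fields. In the example below, the reply stops after 100 tokens because the request sets max_tokens to 100, which is why finish_reason is "length".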
{
    "id": "chatcmpl-fd56a4a3-829f-49c7-9ffa-f890e8******",
    "object": "chat.completion",
    "created": 1779713507,
    "model": "kimi_k25",
    "choices": [
        {
            "index": 0,
            "message": {
                "role": "assistant",
                "content": " The user is greeting me and asking how I am. This is a standard conversational opening. I should respond politely, let them know I'm doing well (as an AI, I don't have feelings, but it's social convention to respond positively), and invite them to continue the conversation or ask how I can help them today.\n\nI should keep it friendly, concise, and open-ended so they can tell me what they need help with. </think> Hello! I'm doing well, thank you for asking. How can I",
                "refusal": null,
                "annotations": null,
                "audio": null,
                "function_call": null,
                "tool_calls": [],
                "reasoning": null
            },
            "logprobs": null,
            "finish_reason": "length",
            "stop_reason": null,
            "token_ids": null
        }
    ],
    "service_tier": null,
    "system_fingerprint": null,
    "usage": {
        "prompt_tokens": 14,
        "total_tokens": 114,
        "completion_tokens": 100,
        "prompt_tokens_details": null,
        "completion_tokens_details": null
    },
    "prompt_logprobs": null,
    "prompt_token_ids": null,
    "kv_transfer_params": null
}
- Check the proxy-side logs.
INFO: Started server process [303949]
INFO: Waiting for application startup.
Initialized 1 prefill clients and 1 decode clients.
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:8181 (Press CTRL+C to quit)
INFO: 192.168.0.163:37794 - "POST /v1/chat/completions HTTP/1.1" 200 OK
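The same check can be scripted, for example as a smoke test after each rollout. The sketch below uses only the Python standard library to send the chat request shown above to the proxy. It then checks that the response has choices and usage, and that total_tokens equals prompt_tokens plus completion_tokens, as in the example response (14 + 100 = 114). The proxy URL is the example address from this guide, so replace it with your own. The helper names are hypothetical.

```python
import json
import urllib.request
from typing import Any, Dict

# Example proxy address from this guide; replace with your own host and port.
PROXY_URL = "http://192.168.0.163:8181/v1/chat/completions"


def build_request(prompt: str, max_tokens: int = 100) -> urllib.request.Request:
    """Build the same Chat API request as the curl example."""
    body = {
        "model": "kimi_k25",  # must match --served-model-name
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
    }
    return urllib.request.Request(
        PROXY_URL,
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def check_response(resp: Dict[str, Any]) -> str:
    """Return the assistant text; raise if the response looks malformed."""
    if not resp.get("choices") or "usage" not in resp:
        raise ValueError("response is missing choices or usage")
    usage = resp["usage"]
    if usage["total_tokens"] != usage["prompt_tokens"] + usage["completion_tokens"]:
        raise ValueError(f"inconsistent token usage: {usage}")
    return resp["choices"][0]["message"]["content"]


def send(prompt: str, timeout: float = 300.0) -> str:
    with urllib.request.urlopen(build_request(prompt), timeout=timeout) as r:
        return check_response(json.load(r))


# Usage (needs network access to the proxy):
# print(send("Hello, how are you?"))
```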