文档首页/
AI开发平台ModelArts/
最佳实践/
LLM大语言模型推理/
LLM大语言模型推理历史版本文档/
主流开源大模型基于Lite Server&Cluster适配Ascend-vLLM PyTorch NPU推理指导(6.5.905)/
附录/
rank_table_tools.py
更新时间:2025-07-29 GMT+08:00
rank_table_tools.py
rank_table_tools.py 中的 RankTableTools 工具类用于读取全局 rank table 文件，并生成合并后的 rank table 文件。
该脚本在 wait_ki_rank_table_completed.py 和 get_ip_list.py 中会用到，无需修改。
import json
import os
import time

GLOBAL_RANK_TABLE_PATH_ENV = 'START_UP_GLOBAL_RANK_TABLE_FILE_PATH'

# Device types whose server-merging rules are implemented below.
_SUPPORTED_DEVICE_TYPES = ('a2', 'a3')


class RankTableTools:
    """Read the global rank table and build merged rank tables from it.

    ``device_type`` selects how input groups are merged into output servers:

    * ``'a2'``: all input groups of one instance collapse into a single server
      entry (taken from the first group) that carries every device.
    * ``'a3'``: every input group keeps its own server entry.

    Methods that merge servers raise ``ValueError`` for any other value.
    """

    def __init__(self, device_type: str = 'a2'):
        self._device_type = device_type

    @property
    def device_type(self):
        """Device type passed to the constructor."""
        return self._device_type

    @staticmethod
    def load_rank_table_file(file_path):
        """Parse and return the JSON document stored at ``file_path``."""
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def get_ki_global_rank_table(self):
        """Load the global rank table named by ``GLOBAL_RANK_TABLE_PATH_ENV``.

        Raises:
            RuntimeError: if the environment variable is unset or empty.
            FileNotFoundError / json.JSONDecodeError: if the file is missing
                or does not contain valid JSON.
        """
        ki_global_rank_table_path = os.getenv(GLOBAL_RANK_TABLE_PATH_ENV)
        if not ki_global_rank_table_path:
            # Fail with an explicit message instead of open(None)'s TypeError.
            raise RuntimeError(
                "Environment variable " + GLOBAL_RANK_TABLE_PATH_ENV
                + " is not set; cannot locate the global rank table file.")
        return self.load_rank_table_file(ki_global_rank_table_path)

    def wait_ki_global_rank_table_ready(self, timeout=None, interval=1.0):
        """Poll the global rank table until its ``status`` is ``"completed"``.

        A missing file, a file that is not valid JSON, or a table without a
        ``status`` key is treated as "not ready yet" and polled again
        (presumably the table can be read while it is still being written).

        Args:
            timeout: Maximum number of seconds to wait; ``None`` (default)
                waits indefinitely.
            interval: Seconds to sleep between polls (default 1).

        Raises:
            TimeoutError: if ``timeout`` elapses before the table is completed.
            RuntimeError: if the path environment variable is not set.
        """
        deadline = None if timeout is None else time.monotonic() + timeout
        while True:
            try:
                ki_global_rank_table = self.get_ki_global_rank_table()
            except (FileNotFoundError, json.JSONDecodeError):
                # Not there yet, or read mid-write: treat as "not ready".
                ki_global_rank_table = {}
            if ki_global_rank_table.get("status") == "completed":
                print("Rank table status is completed!")
                return
            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(
                    "Global rank table did not reach status 'completed' "
                    "within %s seconds" % timeout)
            time.sleep(interval)

    @staticmethod
    def __has_specify_server(rank_table, server_ip):
        """Return True if any server of any group in ``rank_table`` has ``server_ip``."""
        return any(
            server["server_ip"] == server_ip
            for group in rank_table['server_group_list']
            for server in group["server_list"])

    def gen_global_merged_rank_table(
            self, in_rank_table: dict, scheduler_group_id: int,
            num_p: int, p_group_id_list: list,
            num_d: int, d_group_id_list: list):
        """Build a global rank table containing scheduler, P and D groups.

        Output group ids are assigned sequentially:

        * ``'a2'``: a scheduler group ``"0"`` comes first and P groups start
          at ``"1"``.
        * ``'a3'``: no scheduler group is emitted (``scheduler_group_id`` is
          unused) and P groups start at ``"0"``.

        D groups are numbered right after the last P group.

        Args:
            in_rank_table: Source rank table; only its ``server_group_list``
                is read, and only ``server_list[0]`` of each referenced group.
            scheduler_group_id: Index into ``in_rank_table["server_group_list"]``
                of the scheduler's group (used for ``'a2'`` only).
            num_p: Number of P instances. ``p_group_id_list`` is split into
                ``num_p`` consecutive chunks of ``len(p_group_id_list) // num_p``
                indices (extra trailing indices are ignored). A value <= 0
                produces no P groups.
            p_group_id_list: Indices into ``in_rank_table["server_group_list"]``
                used by the P instances.
            num_d: Like ``num_p``, for D instances.
            d_group_id_list: Like ``p_group_id_list``, for D instances.

        Returns:
            dict with ``version``, ``status`` and ``server_group_list``. Device
            ``rank_id`` values restart at ``"0"`` inside every output group.

        Raises:
            ValueError: if ``device_type`` is neither ``'a2'`` nor ``'a3'``.
        """
        global_merged_rank_table = {
            "version": "1.0",
            "status": "completed",
            "server_group_list": []
        }
        is_a2 = self.device_type == 'a2'
        if is_a2:
            # Scheduler group (output group "0"), a2 only.
            self.__gen_global_s_rank_table(
                in_rank_table, global_merged_rank_table, scheduler_group_id)
        # P groups.
        next_global_group_id = self.__gen_global_p_d_rank_table(
            in_rank_table, global_merged_rank_table,
            1 if is_a2 else 0, p_group_id_list, num_p)
        # D groups, numbered after the last P group.
        self.__gen_global_p_d_rank_table(
            in_rank_table, global_merged_rank_table,
            next_global_group_id, d_group_id_list, num_d)
        return global_merged_rank_table

    @staticmethod
    def __gen_global_s_rank_table(in_rank_table, out_rank_table, instance_group_id):
        """Append the scheduler group: id "0", one server entry, no devices."""
        server = in_rank_table["server_group_list"][instance_group_id]["server_list"][0]
        out_rank_table["server_group_list"].append({
            "group_id": "0",
            "server_count": 1,
            "server_list": [{
                "server_id": server["server_id"],
                "server_ip": server["server_ip"]
            }]
        })

    def __gen_global_p_d_rank_table(
            self, in_rank_table, out_rank_table, out_rt_start_group_id,
            instance_group_id_list, num_instance):
        """Append one output group per instance and return the next free group id.

        ``instance_group_id_list`` is split into ``num_instance`` consecutive
        chunks of equal size; each chunk is merged into one output group whose
        ``group_id`` is ``str(out_rt_start_group_id)``, incremented per instance.
        """
        if num_instance <= 0:
            # Nothing to emit (and avoids dividing by zero below).
            return out_rt_start_group_id
        forward_step = len(instance_group_id_list) // num_instance
        for instance_idx in range(num_instance):
            first = instance_idx * forward_step
            server_list = self.__merge_servers(
                in_rank_table, instance_group_id_list[first:first + forward_step])
            # Build the group as a whole instead of addressing it later by list
            # index, so correctness does not depend on index == group id.
            out_rank_table["server_group_list"].append({
                "group_id": str(out_rt_start_group_id),
                "server_count": 1 if self.device_type == 'a2' else forward_step,
                "server_list": server_list
            })
            out_rt_start_group_id += 1
        return out_rt_start_group_id

    def __merge_servers(self, in_rank_table, group_ids):
        """Merge ``server_list[0]`` of each input group in ``group_ids``.

        For ``'a2'`` a single server entry (from the first group) receives all
        devices; for ``'a3'`` each group contributes its own server entry.
        Devices get consecutive string ``rank_id`` values starting at ``"0"``.

        Raises:
            ValueError: if ``device_type`` is not supported.
        """
        if self.device_type not in _SUPPORTED_DEVICE_TYPES:
            raise ValueError(
                "Unsupported device_type %r; expected one of %s"
                % (self.device_type, _SUPPORTED_DEVICE_TYPES))
        server_list = []
        rank_id = 0
        for position, group_id in enumerate(group_ids):
            server = in_rank_table["server_group_list"][group_id]["server_list"][0]
            if self.device_type == 'a3' or position == 0:
                server_list.append({
                    "server_id": server["server_id"],
                    "server_ip": server["server_ip"],
                    "device": []
                })
            # a2: always the single entry; a3: the entry just appended.
            target_devices = server_list[-1]["device"]
            for device in server["device"]:
                target_devices.append({
                    "device_id": device["device_id"],
                    "device_ip": device["device_ip"],
                    "rank_id": str(rank_id)
                })
                rank_id += 1
        return server_list

    def gen_local_merged_rank_table(self, in_rank_table: dict, group_id_list: list):
        """Build a single-group rank table merging the groups in ``group_id_list``.

        ``server_count`` is 1 for ``'a2'`` and ``len(group_id_list)`` for
        ``'a3'``. Device ``rank_id`` values run from ``"0"`` across all groups.

        Raises:
            ValueError: if ``device_type`` is neither ``'a2'`` nor ``'a3'``.
        """
        server_list = self.__merge_servers(in_rank_table, group_id_list)
        return {
            "version": "1.0",
            "status": "completed",
            "group_id": "0",
            "server_count": 1 if self.device_type == 'a2' else len(group_id_list),
            "server_list": server_list
        }

    @staticmethod
    def gen_local_scheduler_rank_table(in_rank_table: dict, scheduler_group_id: int):
        """Build a rank table holding only the scheduler's server (no devices)."""
        server = in_rank_table["server_group_list"][scheduler_group_id]["server_list"][0]
        return {
            "version": "1.0",
            "status": "completed",
            "group_id": "0",
            # NOTE(review): a string here, while the other generated tables use
            # an int server_count; kept unchanged because consumers may rely on
            # it. Confirm which form the reader expects.
            "server_count": "1",
            "server_list": [{
                "server_id": server["server_id"],
                "server_ip": server["server_ip"]
            }]
        }
父主题: 附录