更新时间:2025-07-29 GMT+08:00
分享

rank_table_tools.py

rank_table_tools.py工具类用于生成rank table文件。

该工具类会在wait_ki_rank_table_completed.py和get_ip_list.py中用到,无需修改。

import json
import os
import time

# Name of the environment variable whose value is the path of the KI global
# rank table JSON file (read by RankTableTools.get_ki_global_rank_table).
GLOBAL_RANK_TABLE_PATH_ENV = 'START_UP_GLOBAL_RANK_TABLE_FILE_PATH'

class RankTableTools:
    """Build merged rank-table dicts from the KI global rank table.

    ``device_type`` selects the output layout:

    * ``'a2'`` (default): each output instance group holds a single server
      entry (the first source server of that instance) that collects the
      devices of all of the instance's source groups, and the global table
      starts with a scheduler group ``"0"``.
    * any other value (e.g. ``'a3'``): one server entry per source group and
      no scheduler group.
    """

    def __init__(self, device_type: str = 'a2'):
        self._device_type = device_type

    @property
    def device_type(self):
        """The device type passed to the constructor."""
        return self._device_type

    @property
    def _is_a2(self):
        # Single layout switch shared by every builder: anything that is not
        # 'a2' uses the per-server layout, so all builders agree.
        return self.device_type == 'a2'

    @staticmethod
    def load_rank_table_file(file_path):
        """Read and parse the JSON rank-table file at ``file_path``."""
        # JSON text is UTF-8; don't depend on the platform's default encoding.
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def get_ki_global_rank_table(self):
        """Load the global rank table whose path is in ``GLOBAL_RANK_TABLE_PATH_ENV``.

        Raises:
            RuntimeError: if the environment variable is unset or empty.
        """
        ki_global_rank_table_path = os.getenv(GLOBAL_RANK_TABLE_PATH_ENV)
        if not ki_global_rank_table_path:
            # Without this check open(None) fails with an unhelpful TypeError.
            raise RuntimeError(
                f"Environment variable {GLOBAL_RANK_TABLE_PATH_ENV} is not set")
        return self.load_rank_table_file(ki_global_rank_table_path)

    def wait_ki_global_rank_table_ready(self, timeout=None, interval=1.0):
        """Poll the global rank table until its ``status`` is ``"completed"``.

        The file is re-read on every poll. A message is printed once the
        status is ``"completed"``.

        Args:
            timeout: Maximum number of seconds to wait. ``None`` (the default)
                waits indefinitely.
            interval: Seconds to sleep between polls (default 1.0).

        Raises:
            TimeoutError: If ``timeout`` elapses before the table is completed.
        """
        deadline = None if timeout is None else time.monotonic() + timeout
        while True:
            ki_global_rank_table = self.get_ki_global_rank_table()
            if ki_global_rank_table["status"] == "completed":
                print("Rank table status is completed!")
                return
            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(
                    f"Rank table status is not completed after {timeout} seconds")
            time.sleep(interval)

    @staticmethod
    def __has_specify_server(rank_table, server_ip):
        """Return True if any server in any group of ``rank_table`` has ``server_ip``."""
        return any(
            server["server_ip"] == server_ip
            for group in rank_table['server_group_list']
            for server in group["server_list"]
        )

    def gen_global_merged_rank_table(
            self,
            in_rank_table: dict,
            scheduler_group_id: int,
            num_p: int,
            p_group_id_list: list,
            num_d: int,
            d_group_id_list: list):
        """Build the global merged rank table (scheduler, then P, then D groups).

        Args:
            in_rank_table: Source rank table with a ``server_group_list``.
            scheduler_group_id: Index of the source group that provides the
                scheduler server. Only used for ``'a2'``.
            num_p: Number of P instances to generate.
            p_group_id_list: Source group indices for the P instances. They are
                split into ``num_p`` equal consecutive chunks.
            num_d: Number of D instances to generate.
            d_group_id_list: Source group indices for the D instances. They are
                split into ``num_d`` equal consecutive chunks.

        Returns:
            A new dict with ``version`` "1.0", ``status`` "completed" and the
            generated ``server_group_list``. Output group ids are consecutive,
            starting at "0".
        """
        global_merged_rank_table = {
            "version": "1.0",
            "status": "completed",
            "server_group_list": []
        }
        if self._is_a2:
            # Scheduler group occupies group id "0".
            self.__gen_global_s_rank_table(
                in_rank_table, global_merged_rank_table, scheduler_group_id)
        first_p_group_id = 1 if self._is_a2 else 0
        next_global_group_id = self.__gen_global_p_d_rank_table(
            in_rank_table, global_merged_rank_table, first_p_group_id, p_group_id_list, num_p)
        self.__gen_global_p_d_rank_table(
            in_rank_table, global_merged_rank_table, next_global_group_id, d_group_id_list, num_d)
        return global_merged_rank_table

    @staticmethod
    def __gen_global_s_rank_table(in_rank_table, out_rank_table, instance_group_id):
        """Append scheduler group "0" built from source group ``instance_group_id``.

        Only the first server's ``server_id`` and ``server_ip`` are copied. No
        devices are included.
        """
        server = in_rank_table["server_group_list"][instance_group_id]["server_list"][0]
        # Fill the group we append, rather than assuming it lands at index 0.
        out_rank_table["server_group_list"].append({
            "group_id": "0",
            "server_count": 1,
            "server_list": [{
                "server_id": server["server_id"],
                "server_ip": server["server_ip"]
            }]
        })

    def __gen_global_p_d_rank_table(
            self, in_rank_table, out_rank_table, out_rt_start_group_id, instance_group_id_list, num_instance):
        """Append ``num_instance`` instance groups to ``out_rank_table``.

        ``instance_group_id_list`` is split into ``num_instance`` consecutive
        chunks of ``len(instance_group_id_list) // num_instance`` entries. Any
        remainder is ignored. Each chunk becomes one output group, with ids
        counting up from ``out_rt_start_group_id``. Rank ids restart at 0 in
        every output group.

        Returns:
            The next unused output group id.
        """
        if num_instance <= 0:
            # Nothing to generate; this also avoids a ZeroDivisionError below.
            return out_rt_start_group_id
        forward_step = len(instance_group_id_list) // num_instance
        source_groups = in_rank_table["server_group_list"]
        out_groups = out_rank_table["server_group_list"]
        group_id = out_rt_start_group_id
        for instance_idx in range(num_instance):
            out_groups.append({
                "group_id": str(group_id),
                "server_count": 1 if self._is_a2 else forward_step,
                "server_list": []
            })
            # List position of the group just appended. This is used instead
            # of assuming the group id equals its index.
            out_group_idx = len(out_groups) - 1
            out_servers = out_groups[out_group_idx]["server_list"]
            rank_id = 0
            for server_idx in range(forward_step):
                source_group_id = instance_group_id_list[instance_idx * forward_step + server_idx]
                server = source_groups[source_group_id]["server_list"][0]
                # a2: one server entry per instance (from its first source
                # server); otherwise: one entry per source server.
                if not self._is_a2 or server_idx == 0:
                    out_servers.append({
                        "server_id": server["server_id"],
                        "server_ip": server["server_ip"],
                        "device": []
                    })
                rank_id = self.__append_global_rank_table_device(
                    out_rank_table,
                    out_group_idx,
                    0 if self._is_a2 else server_idx,
                    server["device"],
                    rank_id)
            group_id += 1
        return group_id

    @staticmethod
    def __append_global_rank_table_device(out_rank_table, out_rt_start_group_id, server_idx, devices: list, rank_id):
        """Append ``devices`` to one server entry, assigning consecutive rank ids.

        The target entry is
        ``out_rank_table["server_group_list"][out_rt_start_group_id]["server_list"][server_idx]``.
        Each device gets a string ``rank_id``, counting up from ``rank_id``.

        Returns:
            The next unused rank id.
        """
        target_devices = (
            out_rank_table["server_group_list"][out_rt_start_group_id]["server_list"][server_idx]["device"])
        for device in devices:
            target_devices.append({
                "device_id": device["device_id"],
                "device_ip": device["device_ip"],
                "rank_id": str(rank_id)
            })
            rank_id += 1
        return rank_id

    def gen_local_merged_rank_table(self, in_rank_table: dict, group_id_list: list):
        """Build a single-group local rank table from the given source groups.

        Args:
            in_rank_table: Source rank table with a ``server_group_list``.
            group_id_list: Indices of the source groups to merge. The first
                server of each group is used.

        Returns:
            A dict with ``group_id`` "0". For ``'a2'`` it holds one server entry
            that carries every device. Otherwise it holds one entry per source
            group. Rank ids are consecutive strings starting at "0".
        """
        local_merged_rank_table = {
            "version": "1.0",
            "status": "completed",
            "group_id": "0",
            "server_count": 1 if self._is_a2 else len(group_id_list),
            "server_list": []
        }
        server_list = local_merged_rank_table["server_list"]
        rank_id = 0
        for group_id_idx, source_group_id in enumerate(group_id_list):
            server = in_rank_table["server_group_list"][source_group_id]["server_list"][0]
            if not self._is_a2 or group_id_idx == 0:
                server_list.append({
                    "server_id": server["server_id"],
                    "server_ip": server["server_ip"],
                    "device": []
                })
            target_devices = server_list[0 if self._is_a2 else group_id_idx]["device"]
            for device in server["device"]:
                target_devices.append({
                    "device_id": device["device_id"],
                    "device_ip": device["device_ip"],
                    "rank_id": str(rank_id)
                })
                rank_id += 1
        return local_merged_rank_table

    @staticmethod
    def gen_local_scheduler_rank_table(in_rank_table: dict, scheduler_group_id: int):
        """Build a local rank table containing only the scheduler server.

        Only ``server_id`` and ``server_ip`` of the first server of source
        group ``scheduler_group_id`` are copied. No devices are included.
        """
        server = in_rank_table["server_group_list"][scheduler_group_id]["server_list"][0]
        return {
            "version": "1.0",
            "status": "completed",
            "group_id": "0",
            # NOTE(review): a string here, but an int in the other builders.
            # Kept unchanged for existing consumers; confirm the expected type.
            "server_count": "1",
            "server_list": [{
                "server_id": server["server_id"],
                "server_ip": server["server_ip"]
            }]
        }

相关文档