更新时间:2026-02-05 GMT+08:00
分享

准备训练脚本

本案例提供训练脚本样例a2_multi.sh,需要用户参考本章节内容制作sh文件,并上传到OBS桶中。

脚本中的python命令后的强化学习配置参数,可以直接增加,用以覆盖examples/math/gsm8k_grpo_npu.yaml的默认值,对于yaml中不存在但是实际有效的参数,在参数名称前增加"+"即可生效。部分配置参数说明如下,更多配置参数,参考开源AReaL的yaml文档

部分yaml配置解释

参数

参数说明

示例值

total_train_epochs

训练总轮数

10

total_train_steps

训练总步数,实际步数取total_train_epochs*单轮步数和total_train_steps较小的那个。该参数在gsm8k_grpo_npu.yaml文件中不存在,可以在脚本中使用"+total_train_steps=50"指定

50

cluster.n_gpus_per_node

每个节点的NPU数量

8

cluster.n_nodes

使用的总节点数

2

gconfig.max_new_tokens

推理生成的最大句长

2048

train_dataset.batch_size

训练时的batch size

256

gconfig.n_samples

推理时的采样数量

4

saver.freq_steps

检查点保存频率,按训练步数计

50

saver.freq_epochs

检查点保存频率,按训练轮数计

1

a2_multi.sh脚本样例

a2_multi.sh脚本样例内容如下:

#!/bin/bash
#set -e

ws="/home/ma-user/AReaL"
export HCCL_IF_BASE_PORT=62100
export HCCL_NPU_SOCKET_PORT_RANGE="62000-62050"
unset https_proxy http_proxy proxy ASCEND_RT_VISIBLE_DEVICES
source /usr/local/Ascend/nnal/atb/set_env.sh
# 按照实际情况修改
NODE_NUM=${VC_WORKER_NUM:-1}
MASTER_ADDR="${VC_WORKER_HOSTS%%,*}"
DOMAIN_IP=$(python -c "import socket; print(socket.gethostbyname('${MASTER_ADDR}'))")
MASTER_NODE_IP="${DOMAIN_IP}"
MASTER_NODE_PORT="6699"
MODEL_PATH=${MODEL_PATH}
DATASET_PATH=${DATASET_PATH}
ALLOCATION_MODE=${ALLOCATION_MODE}
timestamp=$(date +%s)
GBS=${GBS:-128}
FILE_ROOT_PATH=${FILE_ROOT_PATH:-"/home/ma-user/areal_exp"}
CKPT_ROOT_PATH=${CKPT_ROOT_PATH:-"/home/ma-user/ckpts/areal_space"}
MAX_WAIT_SECONDS=900
CHECK_INTERVAL=3

get_current_ip() {
    local ip
    ip=$(ip -4 addr | grep -v '127.0.0.1' | grep -oP '(?<=inet\s)\d+(\.\d+){3}' | head -n1 2>/dev/null)
    if [[ -z "$ip" ]]; then
        ip=$(ip -4 addr | grep -v '127.0.0.1' | awk '/inet/ {gsub(/\/.*/,""); print $2}' | head -n1 2>/dev/null)
    fi
    echo "${ip:-}"
}

CURRENT_IP=$(get_current_ip)
if [[ -z "$CURRENT_IP" ]]; then
    echo "无法提取当前节点IP"
    exit 1
fi
echo "current ip:$CURRENT_IP,master ip:$MASTER_NODE_IP"


if [ "$CURRENT_IP" = "$MASTER_NODE_IP" ]; then
    echo "start Ray master node..."
    ray stop --force
    ray start --head --port="$MASTER_NODE_PORT"

    elapsed_seconds=0
    echo "wait $NODE_NUM nodes..."
    while true; do
        current_node_num=$(ray status | grep -c "node_" || echo 0)
        if ! [[ "$current_node_num" =~ ^[0-9]+$ ]]; then
            current_node_num=0
        fi
        echo "current_node_num:$current_node_num"
        if [ "$current_node_num" -ge "$NODE_NUM" ]; then
            echo "all $NODE_NUM workers node is ready."
            break
        fi

        if [ "$elapsed_seconds" -ge "$MAX_WAIT_SECONDS" ]; then
            echo "timeout($MAX_WAIT_SECONDS s),current_node_num:$current_node_num"
            ray stop --force
            exit 1
        fi
        sleep "$CHECK_INTERVAL"
        elapsed_seconds=$((elapsed_seconds + CHECK_INTERVAL))
    done


    echo "start job......"
    
    echo "args is MODEL_PATH=${MODEL_PATH} DATASET_PATH=${DATASET_PATH} GBS=${GBS} FILE_ROOT_PATH=${FILE_ROOT_PATH} CKPT_ROOT_PATH=${CKPT_ROOT_PATH}"
    python -u -m areal.launcher.ray "${ws}/examples/math/gsm8k_rl.py" \
        --config "${ws}/examples/math/gsm8k_grpo_npu.yaml" rollout.dump_to_file=false \
        experiment_name=${timestamp}-gsm8k-grpo-multi cluster.fileroot=${FILE_ROOT_PATH} \
        trial_name=trial-0 saver.freq_steps=20 saver.fileroot=${CKPT_ROOT_PATH} recover.fileroot=${CKPT_ROOT_PATH} \
        allocation_mode=${ALLOCATION_MODE} \
        cluster.n_nodes="$NODE_NUM" \
        cluster.n_gpus_per_node=${MA_NUM_GPUS} cluster.name_resolve.type=ray  \
        gconfig.max_new_tokens=2048 \
        scheduler.type=ray \
        actor.path=${MODEL_PATH} \
        train_dataset.path=${DATASET_PATH} \
        valid_dataset.path=${DATASET_PATH} \
        train_dataset.batch_size=${GBS} | tee /home/ma-user/modelarts/log/latest.log
	
    sleep 60s
    tail -n 500 /home/ma-user/modelarts/log/latest.log | grep -q -E "Training completes|JobState.COMPLETED" && exit 0 || exit 1
	

else
    echo "start workers,connect to master:$MASTER_NODE_IP:$MASTER_NODE_PORT"
    sleep 90s
    ray start \
        --address="${MASTER_NODE_IP}:${MASTER_NODE_PORT}"
fi

相关文档