mirror of
https://github.com/wassname/alignment-handbook.git
synced 2026-06-27 18:22:17 +08:00
Make it work for realz
This commit is contained in:
@@ -0,0 +1,92 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node!
|
||||
#SBATCH --exclusive
|
||||
#SBATCH --gres=gpu:8
|
||||
#SBATCH --partition=production-cluster
|
||||
#SBATCH --output=/fsx/h4/logs/%x-%j.out
|
||||
#SBATCH --err=/fsx/h4/logs/%x-%j.err
|
||||
|
||||
set -x -e
|
||||
|
||||
source ~/.bashrc
|
||||
conda activate handbook
|
||||
echo "START TIME: $(date)"
|
||||
|
||||
MODEL=$1
|
||||
TASK=$2
|
||||
VERSION=$3
|
||||
ACCELERATOR=$4
|
||||
OPTIONAL_ARGS=$5
|
||||
|
||||
# Training setup
|
||||
NUM_NODES=$SLURM_NNODES
|
||||
GPUS_PER_NODE=8
|
||||
WORLD_SIZE=$(($NUM_NODES*$GPUS_PER_NODE))
|
||||
# Due to conflicts between Accelerate's DeepSpeed configs and Transformers' TrainingArguments, we need to parse the gradient accumulation steps from the config file to ensure they match
|
||||
CONFIG_FILE=recipes/$MODEL/$TASK/config_$VERSION.yaml
|
||||
GRAD_ACC_STEPS=$(yq -r .gradient_accumulation_steps $CONFIG_FILE)
|
||||
|
||||
# Split the string into individual arguments
|
||||
IFS=' ' read -ra ARGS <<< "$OPTIONAL_ARGS"
|
||||
|
||||
# Loop through the arguments and find the one with "--gradient_accumulation_steps"
|
||||
for arg in "${ARGS[@]}"; do
|
||||
if [[ "$arg" == "--gradient_accumulation_steps="* ]]; then
|
||||
# Extract the value after the equals sign
|
||||
GRAD_ACC_STEPS="${arg#*=}"
|
||||
break # Exit the loop once we find the desired argument
|
||||
fi
|
||||
done
|
||||
|
||||
echo "Gradient accumulation steps: $GRAD_ACC_STEPS"
|
||||
# so processes know who to talk to
|
||||
MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
|
||||
MASTER_PORT=6000
|
||||
|
||||
export CMD=" \
|
||||
scripts/run_$TASK.py $CONFIG_FILE
|
||||
"
|
||||
|
||||
export LAUNCHER="ACCELERATE_LOG_LEVEL=info accelerate launch \
|
||||
--config_file recipes/accelerate_configs/$ACCELERATOR.yaml \
|
||||
--gradient_accumulation_steps $GRAD_ACC_STEPS \
|
||||
--num_machines $NUM_NODES \
|
||||
--num_processes $WORLD_SIZE \
|
||||
--main_process_ip $MASTER_ADDR \
|
||||
--main_process_port $MASTER_PORT \
|
||||
--machine_rank \$SLURM_PROCID \
|
||||
--rdzv_conf "rdzv_backend=c10d,rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT" \
|
||||
--max_restarts 1 \
|
||||
--role \$(hostname -s): \
|
||||
--tee 3 \
|
||||
"
|
||||
|
||||
# force crashing on nccl issues like hanging broadcast
|
||||
export NCCL_ASYNC_ERROR_HANDLING=1
|
||||
# export NCCL_DEBUG=INFO
|
||||
# export NCCL_DEBUG_SUBSYS=COLL
|
||||
# export NCCL_SOCKET_NTHREADS=1
|
||||
# export NCCL_NSOCKS_PERTHREAD=1
|
||||
# export CUDA_LAUNCH_BLOCKING=1
|
||||
|
||||
# AWS specific
|
||||
export NCCL_PROTO=simple
|
||||
export RDMAV_FORK_SAFE=1
|
||||
export FI_EFA_FORK_SAFE=1
|
||||
export FI_EFA_USE_DEVICE_RDMA=1
|
||||
export FI_PROVIDER=efa
|
||||
export FI_LOG_LEVEL=1
|
||||
export NCCL_IB_DISABLE=1
|
||||
export NCCL_SOCKET_IFNAME=ens
|
||||
|
||||
# srun error handling:
|
||||
# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks
|
||||
# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code
|
||||
SRUN_ARGS=" \
|
||||
--wait=60 \
|
||||
--kill-on-bad-exit=1 \
|
||||
"
|
||||
|
||||
clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --role \$SLURMD_NODENAME: $CMD" 2>&1
|
||||
|
||||
echo "END TIME: $(date)"
|
||||
@@ -0,0 +1,7 @@
|
||||
|
||||
|
||||
## SFT
|
||||
|
||||
```shell
|
||||
sbatch --job-name=handbook_sft --nodes=1 recipes/launch.slurm zephyr-7b sft full
|
||||
```
|
||||
+3
-3
@@ -90,9 +90,9 @@ def main():
|
||||
for index in random.sample(range(len(raw_datasets["train"])), 3):
|
||||
logger.info(f"Sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['messages']}")
|
||||
|
||||
#####################################
|
||||
# Load tokenizer and process datasets
|
||||
#####################################
|
||||
################
|
||||
# Load tokenizer
|
||||
################
|
||||
tokenizer = get_tokenizer(model_args, data_args)
|
||||
|
||||
#####################
|
||||
|
||||
Reference in New Issue
Block a user