diff --git a/recipes/launch.slurm b/recipes/launch.slurm new file mode 100644 index 0000000..28d4bee --- /dev/null +++ b/recipes/launch.slurm @@ -0,0 +1,92 @@ +#!/bin/bash +#SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node! +#SBATCH --exclusive +#SBATCH --gres=gpu:8 +#SBATCH --partition=production-cluster +#SBATCH --output=/fsx/h4/logs/%x-%j.out +#SBATCH --err=/fsx/h4/logs/%x-%j.err + +set -x -e + +source ~/.bashrc +conda activate handbook +echo "START TIME: $(date)" + +MODEL=$1 +TASK=$2 +VERSION=$3 +ACCELERATOR=$4 +OPTIONAL_ARGS=$5 + +# Training setup +NUM_NODES=$SLURM_NNODES +GPUS_PER_NODE=8 +WORLD_SIZE=$(($NUM_NODES*$GPUS_PER_NODE)) +# Due to conflicts between Accelerate's DeepSpeed configs and Transformers' TrainingArguments, we need to parse the gradient accumulation steps from the config file to ensure they match +CONFIG_FILE=recipes/$MODEL/$TASK/config_$VERSION.yaml +GRAD_ACC_STEPS=$(yq -r .gradient_accumulation_steps $CONFIG_FILE) + +# Split the string into individual arguments +IFS=' ' read -ra ARGS <<< "$OPTIONAL_ARGS" + +# Loop through the arguments and find the one with "--gradient_accumulation_steps" +for arg in "${ARGS[@]}"; do + if [[ "$arg" == "--gradient_accumulation_steps="* ]]; then + # Extract the value after the equals sign + GRAD_ACC_STEPS="${arg#*=}" + break # Exit the loop once we find the desired argument + fi +done + +echo "Gradient accumulation steps: $GRAD_ACC_STEPS" +# so processes know who to talk to +MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) +MASTER_PORT=6000 + +export CMD=" \ + scripts/run_$TASK.py $CONFIG_FILE + " + +export LAUNCHER="ACCELERATE_LOG_LEVEL=info accelerate launch \ + --config_file recipes/accelerate_configs/$ACCELERATOR.yaml \ + --gradient_accumulation_steps $GRAD_ACC_STEPS \ + --num_machines $NUM_NODES \ + --num_processes $WORLD_SIZE \ + --main_process_ip $MASTER_ADDR \ + --main_process_port $MASTER_PORT \ + --machine_rank \$SLURM_PROCID \ + --rdzv_conf "rdzv_backend=c10d,rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT" \ + --max_restarts 1 \ + --role \$(hostname -s): \ + --tee 3 \ + " + +# force crashing on nccl issues like hanging broadcast +export NCCL_ASYNC_ERROR_HANDLING=1 +# export NCCL_DEBUG=INFO +# export NCCL_DEBUG_SUBSYS=COLL +# export NCCL_SOCKET_NTHREADS=1 +# export NCCL_NSOCKS_PERTHREAD=1 +# export CUDA_LAUNCH_BLOCKING=1 + +# AWS specific +export NCCL_PROTO=simple +export RDMAV_FORK_SAFE=1 +export FI_EFA_FORK_SAFE=1 +export FI_EFA_USE_DEVICE_RDMA=1 +export FI_PROVIDER=efa +export FI_LOG_LEVEL=1 +export NCCL_IB_DISABLE=1 +export NCCL_SOCKET_IFNAME=ens + +# srun error handling: +# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks +# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code +SRUN_ARGS=" \ + --wait=60 \ + --kill-on-bad-exit=1 \ + " + +clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --role \$SLURMD_NODENAME: $CMD" 2>&1 + +echo "END TIME: $(date)" \ No newline at end of file diff --git a/recipes/zephyr-7b/README.md b/recipes/zephyr-7b/README.md new file mode 100644 index 0000000..cb3847a --- /dev/null +++ b/recipes/zephyr-7b/README.md @@ -0,0 +1,7 @@ + + +## SFT + +```shell +sbatch --job-name=handbook_sft --nodes=1 recipes/launch.slurm zephyr-7b sft full +``` \ No newline at end of file diff --git a/recipes/zephyr-7b/sft/config.yaml b/recipes/zephyr-7b/sft/config_full.yaml similarity index 100% rename from recipes/zephyr-7b/sft/config.yaml rename to recipes/zephyr-7b/sft/config_full.yaml diff --git a/scripts/run_sft.py b/scripts/run_sft.py index 116a654..614a14a 100644 --- a/scripts/run_sft.py +++ b/scripts/run_sft.py @@ -90,9 +90,9 @@ def main(): for index in random.sample(range(len(raw_datasets["train"])), 3): logger.info(f"Sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['messages']}") - ##################################### - # Load tokenizer and process datasets - ##################################### + ################ + # Load tokenizer + ################ tokenizer = get_tokenizer(model_args, data_args) #####################