Source: https://docs.axolotl.ai/docs/training_stability.html
Title: Training Stability & Debugging - Axolotl documentation (undated, fetched 2026)
Fetched-via: uvx markitdown https://docs.axolotl.ai/docs/training_stability.html
Fetch-status: verbatim, nav/sidebar/TOC boilerplate trimmed
Training Stability & Debugging
Guide to monitoring, debugging, and stabilizing training runs in axolotl
This guide covers practical techniques for monitoring training health, diagnosing instability, and resolving common failures in both supervised fine-tuning (SFT) and reinforcement learning (GRPO/EBFT) workflows.
Monitoring Training
Key Metrics for SFT
Every SFT run should be monitored through at least these four metrics:
| Metric | What It Tells You | Healthy Range |
|---|---|---|
| train/loss | How well the model fits training data | Decreasing; typically 0.5–2.0 for chat fine-tuning |
| eval/loss | Generalization performance | Tracks train loss with small gap; divergence signals overfitting |
| grad_norm | Gradient magnitude | 0.1–10.0; spikes above 100 indicate instability |
| learning_rate | Current LR from scheduler | Should follow expected schedule (warmup then decay) |
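The ranges in the table can be turned into a per-step check. The following is a minimal sketch, not part of axolotl: the function name and warning strings are hypothetical, and the thresholds are copied from the table rather than from any library default.

```python
import math


def check_sft_step(train_loss, grad_norm):
    """Return warnings for one logged step, using the ranges from the SFT table."""
    warnings = []
    if not math.isfinite(train_loss):
        warnings.append("train/loss is not finite")
    if not math.isfinite(grad_norm):
        warnings.append("grad_norm is not finite")
    elif grad_norm > 100.0:
        # The table treats spikes above 100 as a sign of instability.
        warnings.append("grad_norm above 100: instability")
    elif not 0.1 <= grad_norm <= 10.0:
        warnings.append("grad_norm outside 0.1-10.0")
    return warnings


print(check_sft_step(train_loss=1.2, grad_norm=150.0))
```

A check like this is only a first filter; a single out-of-range step is less informative than a trend across many steps.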
Tip: Set Up Logging Early
Enable W&B or TensorBoard from the start. Debugging a failed run without metrics is guesswork.
wandb_project: my-project
wandb_run_id: # optional, for resuming
logging_steps: 1
Key Metrics for RL (GRPO)
GRPO training logs a richer set of metrics. These are the critical ones:
| Metric | Healthy Range | Red Flag |
|---|---|---|
| rewards/<name>/mean | > 0.15 within 20 steps | Stays at 0 – reward function is broken or task is too hard |
| reward_std | > 0 on most steps | Always 0 – no learning signal (all completions get the same reward) |
| frac_reward_zero_std | < 0.8 | 1.0 on every step – zero-advantage skip fires constantly, no gradient updates |
| grad_norm | 0.001–1.0 | 0.0 is acceptable occasionally (zero-adv skip); > 10.0 is unstable |
| entropy | 0.05–0.5 | < 0.01 suggests mode collapse; > 1.0 suggests the model is not converging |
| kl | 0.0–0.5 | > 2.0 suggests policy has diverged too far from reference |
| sampling/sampling_logp_difference/mean | < 0.1 | > 1.0 means policy has diverged far from vLLM server weights |
| sampling/importance_sampling_ratio/min | > 0.1 | Near 0 indicates stale off-policy data; decrease vllm_sync_interval |
| clip_ratio/region_mean | < 0.1 | > 0.3 means PPO clipping is too aggressive |
| completions/mean_length | Task-dependent | Monotonically increasing to max length suggests reward hacking |
| completions/clipped_ratio | < 0.3 | > 0.8 means most completions hit max_completion_length – increase it |
Note: EBFT-Specific Metrics
For EBFT training, also monitor ebft/alignment (should trend upward, healthy 0.3–0.9), ebft/diversity (healthy 0.01–0.1; > 1.0 indicates mode collapse), and ebft/cfm_loss (should trend downward, < 10).
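The numeric red flags from the GRPO table can be checked mechanically on each logged step. This is a hedged sketch: `RED_FLAGS` and `grpo_red_flags` are hypothetical names, the thresholds are transcribed from the table, and flags phrased as "always" or "on every step" would still need a persistence check across steps that is not shown here.

```python
# Per-step red-flag predicates transcribed from the GRPO metrics table.
# Keys are the logged metric names; the predicate returns True when the value is a red flag.
RED_FLAGS = {
    "frac_reward_zero_std": lambda v: v >= 1.0,
    "grad_norm": lambda v: v > 10.0,
    "entropy": lambda v: v < 0.01 or v > 1.0,
    "kl": lambda v: v > 2.0,
    "sampling/sampling_logp_difference/mean": lambda v: v > 1.0,
    "clip_ratio/region_mean": lambda v: v > 0.3,
    "completions/clipped_ratio": lambda v: v > 0.8,
}


def grpo_red_flags(metrics):
    """Return the names of metrics in `metrics` whose value trips a red flag."""
    return [
        name
        for name, is_bad in RED_FLAGS.items()
        if name in metrics and is_bad(metrics[name])
    ]


step_metrics = {"entropy": 0.005, "kl": 0.1, "grad_norm": 12.0}
print(grpo_red_flags(step_metrics))
```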
SFT Stability
Loss Plateau
Symptom: Loss stops decreasing early in training, well above expected values.
Causes and fixes:
- Learning rate too low: Increase by 2–5x. Typical ranges: full fine-tune 1e-5 to 5e-5, LoRA 1e-4 to 3e-4.
- Insufficient warmup: Set warmup_steps to 5–10% of total steps. Too-aggressive learning at the start can push the model into a flat region.
- Data quality: Check that labels are correctly masked. Use axolotl preprocess and inspect tokenized samples to confirm only the target tokens are trainable.
- Weight decay too high: Default 0.01 is usually fine. Values above 0.1 can suppress learning in LoRA.
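For the label-masking check, a small helper can summarize which positions of a tokenized sample are trainable. This sketch assumes masked positions carry the label -100 (the default ignore_index of PyTorch's cross-entropy loss); the helper names are hypothetical and the sample is made up for illustration.

```python
IGNORE_INDEX = -100  # assumed mask value; positions with this label do not contribute to the loss


def trainable_fraction(labels):
    """Fraction of positions whose label is not masked."""
    if not labels:
        return 0.0
    return sum(1 for t in labels if t != IGNORE_INDEX) / len(labels)


def trainable_spans(labels):
    """Return (start, end) pairs, end exclusive, of contiguous trainable positions."""
    spans, start = [], None
    for i, t in enumerate(labels):
        if t != IGNORE_INDEX and start is None:
            start = i
        elif t == IGNORE_INDEX and start is not None:
            spans.append((start, i))
            start = None
    if start is not None:
        spans.append((start, len(labels)))
    return spans


# Illustrative sample: prompt tokens masked, two response segments trainable.
labels = [-100, -100, 5, 6, -100, 7]
print(trainable_fraction(labels), trainable_spans(labels))
```

A fraction of 0.0 (everything masked) or 1.0 (prompt not masked in a chat dataset) is usually worth a closer look at the chat template and dataset config.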
Loss Spikes
Symptom: Loss suddenly jumps by 2–10x then (possibly) recovers.
Causes and fixes:
- Bad data samples: A single malformed or extremely long example can cause a spike. Set sample_packing: false temporarily and check if spikes correlate with specific batches.
- Learning rate too high: Reduce by 2–5x, or increase warmup.
- Gradient accumulation mismatch: Effective batch size = micro_batch_size * gradient_accumulation_steps * num_gpus. Very large effective batch sizes amplify gradient noise.
- Mixed precision issues: With bf16: true, some operations can lose precision. If spikes are severe, try fp32 for diagnosis.
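To correlate spikes with specific batches, it helps to first list the steps where a spike occurred. The sketch below flags a step when its loss is at least `factor` times the median of the preceding window; the function name, the window size, and the use of a median baseline are assumptions for illustration, with `factor=2.0` taken from the 2–10x symptom description.

```python
from statistics import median


def find_loss_spikes(losses, window=20, factor=2.0):
    """Return step indices whose loss is >= factor x the median of the previous `window` losses."""
    spikes = []
    for step in range(1, len(losses)):
        history = losses[max(0, step - window):step]
        baseline = median(history)
        if baseline > 0 and losses[step] >= factor * baseline:
            spikes.append(step)
    return spikes


losses = [1.0, 0.9, 0.8, 2.4, 0.8, 0.7]
print(find_loss_spikes(losses))
```

The returned step indices can then be matched against the data loader order to find the batch that was consumed at each spike.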
Overfitting
Symptom: Train loss keeps decreasing but eval loss starts increasing.
Fixes:
- Increase val_set_size (e.g., 0.05) and monitor eval/loss.
- Reduce num_epochs or max_steps.
- Increase weight_decay (try 0.01–0.1).
- Use a smaller LoRA rank (lora_r). Typical values: 8–32.
- Increase dropout: lora_dropout: 0.05.
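The overfitting symptom can be located in a logged eval/loss series by finding the best evaluation and checking whether the evaluations after it are all worse. This is a minimal sketch; `overfitting_onset` and its `patience` parameter are hypothetical, and the value of 3 is an arbitrary choice.

```python
def overfitting_onset(eval_losses, patience=3):
    """Return the index of the best eval loss if the next `patience` evals are all worse, else None."""
    if not eval_losses:
        return None
    best_idx = min(range(len(eval_losses)), key=eval_losses.__getitem__)
    later = eval_losses[best_idx + 1:best_idx + 1 + patience]
    if len(later) == patience and all(v > eval_losses[best_idx] for v in later):
        return best_idx
    return None


print(overfitting_onset([1.5, 1.2, 1.0, 1.05, 1.1, 1.2]))
```

The returned index is a candidate point for reducing num_epochs or max_steps; a None result means the run has not yet shown a sustained rise in eval loss.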
RL/GRPO Stability
Reward Never Increases
If rewards/*/mean stays at 0 for more than 20 steps:
- Test reward function standalone: Run it outside training with known inputs to verify it returns nonzero values.
  cd experiments && python -c "import my_rewards; print(my_rewards.accuracy_reward(...))"
- Check dataset columns: The reward function receives **kwargs containing dataset columns. Verify the columns it needs (e.g., answer) are not removed by the dataset transform.
- Check completion content: Enable log_completions: true in the trl: config and inspect logged completions in W&B. If completions are empty or incoherent, the model may be too weak for the task.
- Verify vLLM is serving the right model: Hit the vLLM health endpoint and confirm the model name matches your config.
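A standalone reward test might look like the sketch below. The reward function here is hypothetical: it assumes completions arrive as plain strings and that the dataset's answer column is passed through **kwargs, as described in the list above. A real reward function may receive a different completion format, so adapt the inputs accordingly.

```python
def accuracy_reward(completions, **kwargs):
    """Hypothetical reward: 1.0 if the completion contains the reference answer, else 0.0."""
    answers = kwargs.get("answer")
    if answers is None:
        # Mirrors the "column removed by the dataset transform" failure mode.
        raise KeyError("dataset column 'answer' was not passed to the reward function")
    return [1.0 if ans.strip() in text else 0.0 for text, ans in zip(completions, answers)]


def smoke_test(reward_fn):
    """Call the reward with known inputs and confirm it can return a nonzero value."""
    completions = ["The result is 42.", "I am not sure."]
    rewards = reward_fn(completions, answer=["42", "7"])
    assert len(rewards) == len(completions), "one reward per completion expected"
    assert any(r != 0.0 for r in rewards), "reward is zero even for a known-correct completion"
    return rewards


print(smoke_test(accuracy_reward))
```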
Entropy Collapse (Mode Collapse)
Symptom: entropy drops below 0.01; all completions become nearly identical.
Fixes:
- Increase temperature in generation kwargs (try 0.8–1.0).
- Reduce learning rate.
- Add a KL penalty term (beta parameter in GRPO config).
- Check that num_generations is sufficient (16+ gives better advantage estimates).
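To see why raising the temperature counteracts entropy collapse, the sketch below computes the Shannon entropy (in nats, an assumption about the logged unit) of a toy next-token distribution at two temperatures. The helper names and logits are illustrative only.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax, shifted by the max for numerical stability."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    """Shannon entropy in nats; zero-probability entries contribute nothing."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


logits = [5.0, 0.0, 0.0, 0.0]  # toy distribution dominated by one token
low_t = entropy(softmax(logits, temperature=0.5))
high_t = entropy(softmax(logits, temperature=1.0))
print(low_t < high_t)  # higher temperature flattens the distribution
```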
IS Ratio Divergence
Symptom: sampling/importance_sampling_ratio/min drops near 0, or sampling/sampling_logp_difference/mean exceeds 1.0.
This means the policy has diverged significantly from the weights used by vLLM for generation. The importance sampling correction becomes unreliable.
Fixes:
- Decrease vllm_sync_interval (sync weights more often).
- Enable off_policy_mask_threshold (e.g., 0.5) to mask stale off-policy samples.
- Use importance_sampling_level: token for finer-grained correction.
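The two symptoms can be reproduced from per-token log-probabilities. This sketch assumes the importance sampling ratio is exp(policy_logp - sampling_logp) per token and that the log-probability difference is averaged in absolute value; both are assumptions about how the logged metrics are defined, and the function name is hypothetical.

```python
import math


def sampling_drift_stats(policy_logps, sampling_logps):
    """Summarize drift between the training policy and the sampling (vLLM) weights."""
    diffs = [p - s for p, s in zip(policy_logps, sampling_logps)]
    ratios = [math.exp(d) for d in diffs]  # per-token importance sampling ratio
    n = len(diffs)
    return {
        "logp_difference_mean": sum(abs(d) for d in diffs) / n,
        "ratio_min": min(ratios),
        "ratio_mean": sum(ratios) / n,
        "ratio_max": max(ratios),
    }


# Second token: the policy now assigns far lower probability than the sampler did.
stats = sampling_drift_stats([-1.0, -5.0], [-1.0, -2.0])
print(stats["ratio_min"] < 0.1, stats["logp_difference_mean"] > 1.0)
```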
Gradient Norm Instability
Symptom: grad_norm oscillates wildly or exceeds 10.0 regularly.
Fixes:
- Enable gradient clipping: max_grad_norm: 1.0 (default in most configs).
- Reduce learning rate.
- Increase gradient_accumulation_steps to smooth out noisy batches.
- Check for NaN issues (see next section).
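Conceptually, clipping with max_grad_norm: 1.0 rescales all gradients by a common factor whenever their global L2 norm exceeds the limit. The pure-Python sketch below illustrates the idea on a flat list of gradient values; it is a simplified stand-in, not the trainer's implementation, and the small eps term is an assumption modeled on common implementations.

```python
import math


def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Scale `grads` so their global L2 norm is at most `max_norm`.

    Returns the clipped gradients and the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total_norm + eps))
    return [g * scale for g in grads], total_norm


clipped, pre_clip_norm = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
print(pre_clip_norm, clipped)
```

Because the norm is measured before scaling, a logged grad_norm above max_grad_norm does not by itself mean clipping is disabled.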
NaN and Inf Handling
Common Causes
| Cause | Where It Manifests | Detection |
|---|---|---|
| FP8 zero-scale division | Forward pass logits | grad_norm: nan, loss becomes NaN immediately |
| Gradient explosion | Backward pass | grad_norm spikes to inf, then loss goes NaN |
| Bad data (empty sequences) | Logprob computation | NaN in specific batches only |
| Numerical overflow in log-softmax | Loss computation | Large negative logprobs cause exp() overflow |
FP8-Specific NaN Issues
FP8 quantization (fp8: true) can produce NaN when the activation quantization kernel divides by max(abs(x)) / 448. If the input tensor is all zeros (e.g., padding positions), the scale becomes 0, causing division by zero.
Fixes applied in axolotl:
- The act_quant_kernel has a zero-guard: s = tl.where(s == 0, 1.0, s).
- A safety net nan_to_num(logits, nan=0.0) is applied in _get_per_token_logps_and_entropies.
- Embedding padding is zero-padded for FP8 compatibility.
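The zero-guard can be illustrated without Triton. The sketch below computes the scale max(abs(x)) / 448 for a block of values and substitutes 1.0 when the scale is zero, which is the same substitution the tl.where guard performs. It is a plain-Python illustration with hypothetical function names: it does not round to FP8 and is not the kernel itself.

```python
FP8_E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3


def activation_scale(values):
    """Per-block scale with a zero-guard: an all-zero block gets scale 1.0 instead of 0.0."""
    amax = max((abs(v) for v in values), default=0.0)
    scale = amax / FP8_E4M3_MAX
    return scale if scale != 0.0 else 1.0


def scale_block(values):
    """Divide by the guarded scale; without the guard an all-zero block would divide by zero."""
    s = activation_scale(values)
    return [v / s for v in values], s


print(scale_block([0.0, 0.0, 0.0]))   # padding-like block stays finite
print(scale_block([224.0, -448.0]))   # scale is 1.0 when max(abs(x)) is 448
```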
Important: After Modifying Triton Kernels
If you patch any Triton JIT kernel (e.g., the FP8 quantization kernels in transformers), you must clear the Triton cache for changes to take effect:
rm -rf ~/.triton/cache
General NaN Debugging Steps
- Enable anomaly detection (slow, but pinpoints the source):
  torch.autograd.set_detect_anomaly(True)
- Check grad_norm: If it goes to NaN, the backward pass is the problem. If loss is NaN but grad_norm was fine on the previous step, the forward pass is the problem.
- Reduce to single GPU, single batch: Eliminate distributed training variables.
- Inspect data: Print the batch that triggers NaN. Look for empty sequences, extreme token IDs, or unexpected padding patterns.
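The grad_norm rule from the steps above can be applied to a logged history to find the first bad step and a first guess at where to look. This is a sketch under the stated rule only; `first_nonfinite` is a hypothetical helper and the "forward"/"backward" labels are hints, not a diagnosis.

```python
import math


def first_nonfinite(history):
    """history: list of (loss, grad_norm) per step.

    Returns (step, stage) for the first non-finite value, or None if all are finite.
    A non-finite loss points at the forward pass; a finite loss with a
    non-finite grad_norm points at the backward pass.
    """
    for step, (loss, grad_norm) in enumerate(history):
        if not math.isfinite(loss):
            return step, "forward"
        if not math.isfinite(grad_norm):
            return step, "backward"
    return None


nan = float("nan")
print(first_nonfinite([(1.0, 0.5), (0.9, nan), (nan, nan)]))
```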
OOM Debugging
Out-of-memory errors are the most common training failure. Use this systematic approach, from least to most disruptive:
Step 1: Reduce Batch Size
The single highest-impact change. VRAM scales roughly linearly with batch size.
micro_batch_size: 1 # Start here
gradient_accumulation_steps: 16 # Increase to maintain effective batch size
For GRPO specifically, the logits tensor for policy logprob computation can be very large: it holds batch_size * num_generations * seq_len * vocab_size elements, at 2 bytes each in bf16. For example, with num_generations: 16 and micro_batch_size: 8, the logits tensor alone is:
8 * 16 * 2048 * 151936 * 2 bytes = 79,658,221,568 bytes ≈ 74 GiB (way too large)
Reduce micro_batch_size to 2–4 for GRPO.
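The logits estimate can be recomputed for any candidate setting before launching a run. This sketch uses the formula from the text; the function name is hypothetical, and seq_len 2048 and vocab_size 151936 are the values from the worked example.

```python
def logits_bytes(micro_batch_size, num_generations, seq_len, vocab_size, bytes_per_element=2):
    """Size of the full logits tensor in bytes (bf16 uses 2 bytes per element)."""
    return micro_batch_size * num_generations * seq_len * vocab_size * bytes_per_element


GIB = 2 ** 30
for mbs in (8, 4, 2):
    size = logits_bytes(mbs, num_generations=16, seq_len=2048, vocab_size=151936)
    print(f"micro_batch_size={mbs}: {size / GIB:.1f} GiB")
```

Because the size is linear in micro_batch_size, each halving of the batch halves the logits tensor.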
Step 2: Enable Gradient Checkpointing
Trades compute for memory by recomputing activations during the backward pass instead of storing them.
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false # Recommended default
Warning: Reentrant Checkpointing Exceptions
Some configurations require use_reentrant: true:
- DeepSpeed ZeRO-3 (non-reentrant causes CheckpointError)
- EBFT strided mode with flex_attention
Step 3: Use Quantization
Load the base model in reduced precision:
# 4-bit QLoRA
adapter: qlora
load_in_4bit: true
# 8-bit
load_in_8bit: true
# FP8 (saves ~50% model VRAM, same compute speed as bf16)
fp8: true
Step 4: Reduce Sequence Length
sequence_len: 1024 # Down from 2048 or 4096
For GRPO, also reduce max_completion_length. Memory scales quadratically with sequence length when using standard attention.
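The quadratic scaling comes from the attention score matrix, which standard attention materializes as seq_len x seq_len per head. The sketch below estimates that one tensor for a single layer; the head count of 32 is an illustrative assumption, and real peak memory also includes other activations.

```python
def attention_scores_bytes(batch_size, num_heads, seq_len, bytes_per_element=2):
    """Bytes for one layer's attention score matrix: batch x heads x seq_len x seq_len."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_element


MIB = 2 ** 20
for n in (4096, 2048, 1024):
    size = attention_scores_bytes(batch_size=1, num_heads=32, seq_len=n)
    print(f"sequence_len={n}: {size / MIB:.0f} MiB per layer")
```

Halving sequence_len cuts this term by a factor of four, which is why Step 4 can free more memory than the linear change in token count suggests.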
Step 5: Use Flash Attention
Reduces attention memory from O(n^2) to O(n):
attn_implementation: flash_attention_2
Step 6: Offload with DeepSpeed
For extreme cases, offload optimizer states or parameters to CPU:
deepspeed: deepspeed_configs/zero3_bf16.json
Diagnosing the Specific Culprit
Use the profiler_steps config option to capture GPU memory snapshots:
profiler_steps: [1, 2]
This generates PyTorch profiler traces you can inspect to see exactly which tensor allocation caused the OOM.
Common Errors
| Error Message | Likely Cause | Fix |
|---|---|---|
| exitcode: -9 | System RAM exhaustion | Reduce dataset size, dataset_num_proc, or number of data workers |
| exitcode: -7 (DeepSpeed) | DeepSpeed version issue | pip install -U deepspeed |
| CUDA out of memory | GPU VRAM exhaustion | Follow OOM debugging steps above |
| RuntimeError: NCCL communicator was aborted | GPU communication failure | See NCCL docs; check NCCL_DEBUG=INFO output |
| ValueError: Asking to pad but the tokenizer does not have a padding token | Missing pad token | Add special_tokens: { pad_token: "<\|endoftext\|>" } to config |
| 'DummyOptim' object has no attribute 'step' | DeepSpeed on single GPU | Remove deepspeed: section from config |
| unable to load strategy X then None is not callable | Reward module not importable | Run cd experiments && python -c "import my_rewards" to check |
| generation_batch_size not divisible by num_generations | micro_batch_size too small | Set micro_batch_size >= num_generations and make it divisible |
| 'weight' must be 2-D | FSDP1 flattened parameters | Use fsdp_version: 2 or skip unwrap_model when FSDP is enabled |
| CheckpointError (tensor count mismatch) | Non-reentrant checkpointing + ZeRO-3 or flex_attention | Set use_reentrant: true in gradient_checkpointing_kwargs |
| BFloat16 TypeError during weight sync | NumPy does not support bf16 | Fixed in axolotl’s weight_serde.py (auto bf16 to fp16 conversion) |
| Content end boundary is before start boundary | Chat template parsing issue | Check eos_token matches template; file a GitHub issue if persistent |
| CAS service error during data processing | HuggingFace XET issue | Set export HF_HUB_DISABLE_XET=1 |
| Training hangs (multi-GPU) | FSDP + async prefetch deadlock | Set async_prefetch: false with FSDP |
Profiling
PyTorch Profiler
Axolotl supports PyTorch profiler integration via the config:
profiler_steps: [1, 2, 3]
This captures profiler traces for the specified steps. View them in TensorBoard:
tensorboard --logdir output_dir/runs
Or open the .json trace file in chrome://tracing.
CUDA Memory Snapshots
For detailed memory analysis, use PyTorch’s memory snapshot API. Add this to your training script or use it interactively:
import torch
# Enable memory history tracking
torch.cuda.memory._record_memory_history()
# ... run your training step ...
# Save snapshot
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
Visualize with PyTorch’s memory visualizer:
python -m torch.cuda._memory_viz trace_plot memory_snapshot.pickle -o memory_snapshot.html
Quick GPU Memory Check
During training, monitor GPU utilization in a separate terminal:
watch -n 1 nvidia-smi
For programmatic access within axolotl, the logged metrics memory/max_alloc and memory/max_reserved come from torch.cuda.max_memory_allocated() and torch.cuda.max_memory_reserved(). Note these report PyTorch’s view of memory, which may differ from nvidia-smi (see FAQ).
W&B and Logging
Enabling Logging
wandb_project: my-project
wandb_entity: my-team # optional
wandb_run_id: run-123 # optional, for resuming
wandb_name: experiment-name # optional
logging_steps: 1 # log every step (recommended for RL)
Debug Logging
For detailed axolotl-internal debug output:
AXOLOTL_LOG_LEVEL=DEBUG axolotl train config.yaml 2>&1 | tee /tmp/training.log
Tip: Always Log to a File
Pipe training output to a log file so you can inspect it after the run:
axolotl train config.yaml 2>&1 | tee /tmp/my_run.log
What Axolotl Logs
SFT metrics (logged every logging_steps):
- train/loss, eval/loss – training and validation loss
- train/grad_norm – gradient L2 norm (before clipping)
- train/learning_rate – current learning rate
- memory/max_alloc, memory/max_reserved – peak GPU memory
GRPO/RL metrics (logged every step):
- rewards/<name>/mean, rewards/<name>/std – per-reward-function statistics
- reward, reward_std – aggregated reward across all reward functions
- frac_reward_zero_std – fraction of prompt groups where all completions got the same reward
- completions/mean_length, completions/min_length, completions/max_length – completion token lengths
- completions/clipped_ratio – fraction of completions that hit the max length
- completions/mean_terminated_length, completions/min_terminated_length, completions/max_terminated_length – lengths of naturally terminated completions
- kl – KL divergence between policy and reference
- entropy – policy entropy (measure of output diversity)
- clip_ratio/region_mean, clip_ratio/low_mean, clip_ratio/high_mean – PPO clipping statistics
- sampling/sampling_logp_difference/mean, sampling/sampling_logp_difference/max – log-probability difference between policy and sampling distribution
- sampling/importance_sampling_ratio/min, sampling/importance_sampling_ratio/mean, sampling/importance_sampling_ratio/max – IS ratio statistics for off-policy correction
- num_tokens – total tokens processed
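As an illustration of the frac_reward_zero_std definition, the sketch below computes it from rewards grouped by prompt. The function name reuses the metric name for readability; the grouping layout (one inner list per prompt, one reward per completion) is an assumption about how the data would be arranged.

```python
def frac_reward_zero_std(grouped_rewards):
    """Fraction of prompt groups in which every completion received the same reward."""
    if not grouped_rewards:
        return 0.0
    zero_std = sum(1 for group in grouped_rewards if max(group) == min(group))
    return zero_std / len(grouped_rewards)


groups = [
    [1.0, 1.0],  # all completions correct: no advantage signal
    [0.0, 1.0],  # mixed outcomes: useful learning signal
    [0.0, 0.0],  # all completions wrong: no advantage signal
]
print(frac_reward_zero_std(groups))
```

A group with identical rewards has zero advantage for every completion, which is why a value of 1.0 on every step means no gradient updates.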
Reading W&B Charts
For a healthy GRPO run, expect to see:
- reward/mean: Gradual upward trend. May start near 0 and reach 0.3–0.8 depending on task difficulty. Not monotonic – fluctuations are normal.
- entropy: Gradual decrease from initial values (often 0.3–0.6) as the model becomes more confident. Should not collapse to near-zero.
- grad_norm: Mostly in the 0.001–1.0 range. Occasional 0.0 values are fine (zero-advantage skip). Persistent values above 10.0 need investigation.
- kl: Starts near 0 and grows slowly. If it shoots up rapidly, the policy is diverging from the reference.
- completions/mean_length: Should reflect the task’s natural answer length. If it steadily increases to max_completion_length, the model may be reward-hacking by generating longer outputs.
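The completions/mean_length pattern can be checked on an exported metric series. This is a rough sketch: `length_creep` is a hypothetical helper, and the `near=0.95` cutoff for "close to max_completion_length" is an arbitrary assumption, not a documented threshold.

```python
def length_creep(mean_lengths, max_completion_length, near=0.95):
    """True if mean completion length never decreases and ends close to the maximum."""
    if len(mean_lengths) < 2:
        return False
    rising = all(b >= a for a, b in zip(mean_lengths, mean_lengths[1:]))
    return rising and mean_lengths[-1] >= near * max_completion_length


print(length_creep([200, 350, 480, 510], max_completion_length=512))
```

A True result is a prompt to inspect logged completions, since legitimately longer answers and reward hacking look the same in this metric alone.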