ml-debug/docs/evidence/axolotl_training_stability.md
wassname fb753d093e restructure: quotes-first SKILL.md, synthesized playbook split out
SKILL.md is now folklore only: verbatim practitioner quotes ordered
most-general-first, transformer/LLM fine-tuning entries in their own
section, minimal context, links and footnotes. New sources: unsloth,
axolotl (+training stability), HF course ch8.4, Bekman debug_utils
(evidence frozen in docs/evidence/).

The synthesized material (mental models, priors, symptom tables, agent
loop, triage, anti-patterns) moves to PLAYBOOK.md, framed as menus of
hypotheses rather than authoritative diagnoses. Made-up symptom tables
no longer sit next to sourced quotes.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-11 14:33:32 +08:00

Source: https://docs.axolotl.ai/docs/training_stability.html
Title: Training Stability & Debugging - Axolotl documentation (undated, fetched 2026)
Fetched-via: uvx markitdown https://docs.axolotl.ai/docs/training_stability.html
Fetch-status: verbatim, nav/sidebar/TOC boilerplate trimmed
# Training Stability & Debugging
Guide to monitoring, debugging, and stabilizing training runs in axolotl
This guide covers practical techniques for monitoring training health, diagnosing instability, and resolving common failures in both supervised fine-tuning (SFT) and reinforcement learning (GRPO/EBFT) workflows.
## Monitoring Training
### Key Metrics for SFT
Every SFT run should be monitored through at least these four metrics:
| Metric | What It Tells You | Healthy Range |
| --- | --- | --- |
| `train/loss` | How well the model fits training data | Decreasing; typically 0.5–2.0 for chat fine-tuning |
| `eval/loss` | Generalization performance | Tracks train loss with small gap; divergence signals overfitting |
| `grad_norm` | Gradient magnitude | 0.1–10.0; spikes above 100 indicate instability |
| `learning_rate` | Current LR from scheduler | Should follow expected schedule (warmup then decay) |
**Tip: Set Up Logging Early**
Enable W&B or TensorBoard from the start. Debugging a failed run without metrics is guesswork.
```
wandb_project: my-project
wandb_run_id: # optional, for resuming
logging_steps: 1
```
### Key Metrics for RL (GRPO)
GRPO training logs a richer set of metrics. These are the critical ones:
| Metric | Healthy Range | Red Flag |
| --- | --- | --- |
| `rewards/<name>/mean` | > 0.15 within 20 steps | Stays at 0 → reward function is broken or task is too hard |
| `reward_std` | > 0 on most steps | Always 0 → no learning signal (all completions get the same reward) |
| `frac_reward_zero_std` | < 0.8 | 1.0 on every step → zero-advantage skip fires constantly, no gradient updates |
| `grad_norm` | 0.001–1.0 | 0.0 is acceptable occasionally (zero-adv skip); > 10.0 is unstable |
| `entropy` | 0.05–0.5 | < 0.01 suggests mode collapse; > 1.0 suggests the model is not converging |
| `kl` | 0.0–0.5 | > 2.0 suggests policy has diverged too far from reference |
| `sampling/sampling_logp_difference/mean` | < 0.1 | > 1.0 means policy has diverged far from vLLM server weights |
| `sampling/importance_sampling_ratio/min` | > 0.1 | Near 0 indicates stale off-policy data; increase `vllm_sync_interval` |
| `clip_ratio/region_mean` | < 0.1 | > 0.3 means PPO clipping is too aggressive |
| `completions/mean_length` | Task-dependent | Monotonically increasing to max length suggests reward hacking |
| `completions/clipped_ratio` | < 0.3 | > 0.8 means most completions hit `max_completion_length` → increase it |
**Note: EBFT-Specific Metrics**
For EBFT training, also monitor `ebft/alignment` (should trend upward, healthy 0.3–0.9), `ebft/diversity` (healthy 0.01–0.1; > 1.0 indicates mode collapse), and `ebft/cfm_loss` (should trend downward, < 10).
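The `reward_std` and `frac_reward_zero_std` red flags follow from how GRPO normalizes rewards within each prompt's group of completions. The sketch below is minimal and assumes the common group-relative advantage `(r - mean) / (std + eps)`. The population std and the `eps` value are illustrative choices, and axolotl/TRL may differ in these details. `group_advantages` and `frac_zero_std` are hypothetical helpers.

```python
import statistics


def group_advantages(rewards, eps=1e-4):
    """Group-relative advantages for the completions of one prompt.

    Assumes the common GRPO form (r - mean) / (std + eps); the std estimator
    (population here) and eps are illustrative, not axolotl's exact code.
    """
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]


def frac_zero_std(groups):
    """Fraction of prompt groups in which every completion got the same reward."""
    return sum(1 for g in groups if statistics.pstdev(g) == 0.0) / len(groups)


# Hypothetical rewards for 2 prompts x 4 generations each.
groups = [[1.0, 1.0, 1.0, 1.0],   # every completion scored the same: no signal
          [0.0, 1.0, 0.0, 1.0]]   # mixed scores: usable signal
for g in groups:
    print(g, "->", [round(a, 3) for a in group_advantages(g)])
print("frac_reward_zero_std =", frac_zero_std(groups))
```

The all-equal group gets all-zero advantages, so it contributes nothing to the policy gradient. If every group looks like that, `frac_reward_zero_std` sits at 1.0 and no update happens.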
## SFT Stability
### Loss Plateau
**Symptom**: Loss stops decreasing early in training, well above expected values.
**Causes and fixes**:
* **Learning rate too low**: Increase by 2–5x. Typical ranges: full fine-tune 1e-5 to 5e-5, LoRA 1e-4 to 3e-4.
* **Insufficient warmup**: Set `warmup_steps` to 5–10% of total steps. Too-aggressive learning at the start can push the model into a flat region.
* **Data quality**: Check that labels are correctly masked. Use `axolotl preprocess` and inspect tokenized samples to confirm only the target tokens are trainable.
* **Weight decay too high**: Default 0.01 is usually fine. Values above 0.1 can suppress learning in LoRA.
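The "warmup then decay" shape expected of `learning_rate`, together with the 5–10% warmup guideline, can be made concrete as a schedule function. This is a minimal sketch assuming linear warmup followed by cosine decay to zero. That shape is common, but it is not necessarily what your configured scheduler produces. `warmup_steps_for` and `lr_at` are hypothetical helpers.

```python
import math


def warmup_steps_for(total_steps, warmup_fraction=0.05):
    """Warmup length as a fraction of total steps (guideline: 5-10%)."""
    return max(1, round(total_steps * warmup_fraction))


def lr_at(step, peak_lr, total_steps, warmup_steps):
    """Linear warmup from 0 to peak_lr, then cosine decay to 0 at total_steps."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))


total = 1000
warmup = warmup_steps_for(total)  # 50 steps at the default 5%
for step in (0, 25, 50, 500, 1000):
    print(step, f"{lr_at(step, 2e-4, total, warmup):.2e}")
```

Overlaying this kind of curve on the logged `learning_rate` is a quick way to spot a schedule that never warms up or decays too early.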
### Loss Spikes
**Symptom**: Loss suddenly jumps by 2–10x, then (possibly) recovers.
**Causes and fixes**:
* **Bad data samples**: A single malformed or extremely long example can cause a spike. Set `sample_packing: false` temporarily and check whether spikes correlate with specific batches.
* **Learning rate too high**: Reduce by 2–5x, or increase warmup.
* **Gradient accumulation mismatch**: Effective batch size = `micro_batch_size * gradient_accumulation_steps * num_gpus`. Very large effective batch sizes amplify gradient noise.
* **Mixed precision issues**: With `bf16: true`, some operations can lose precision. If spikes are severe, try `fp32` for diagnosis.
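To check whether spikes correlate with specific batches, it helps to flag the offending steps first. This is a minimal sketch that assumes you have per-step `train/loss` values, for example exported from W&B or parsed from a log file. The 2x threshold matches the lower end of the "jumps by 2–10x" symptom. The 20-step rolling-median window is an arbitrary choice, and `find_loss_spikes` is a hypothetical helper.

```python
import statistics


def find_loss_spikes(losses, window=20, factor=2.0):
    """Indices of steps whose loss exceeds `factor` x the median of the
    preceding `window` steps. Steps with too little history are skipped."""
    spikes = []
    for i in range(window, len(losses)):
        baseline = statistics.median(losses[i - window:i])
        if losses[i] > factor * baseline:
            spikes.append(i)
    return spikes


losses = [1.0] * 30
losses[25] = 5.0                 # a single 5x jump
print(find_loss_spikes(losses))  # [25]
```

The indices are positions in the list you pass in. Multiply by `logging_steps` (if it is not 1) to map them back to training steps before looking up the batches.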
### Overfitting
**Symptom**: Train loss keeps decreasing but eval loss starts increasing.
**Fixes**:
* Increase `val_set_size` (e.g., 0.05) and monitor `eval/loss`.
* Reduce `num_epochs` or `max_steps`.
* Increase `weight_decay` (try 0.01–0.1).
* Use a smaller LoRA rank (`lora_r`). Typical values: 8–32.
* Increase dropout: `lora_dropout: 0.05`.
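Finding the point where `eval/loss` turns upward can be automated with a simple patience rule. This is a minimal sketch assuming a list of `(step, eval_loss)` pairs in logging order. The `patience` value is arbitrary, and `best_eval_step` is a hypothetical helper, not an axolotl API.

```python
def best_eval_step(evals, patience=3):
    """Given (step, eval_loss) pairs in order, return (best_step, stopped_step).

    stopped_step is the step at which eval loss had failed to improve for
    `patience` consecutive evaluations, or None if that never happened.
    """
    best_step, best_loss, bad = None, float("inf"), 0
    for step, loss in evals:
        if loss < best_loss:
            best_step, best_loss, bad = step, loss, 0
        else:
            bad += 1
            if bad >= patience:
                return best_step, step
    return best_step, None


evals = [(100, 1.20), (200, 1.00), (300, 0.95), (400, 0.97), (500, 1.01), (600, 1.05)]
print(best_eval_step(evals))  # (300, 600)
```

The checkpoint closest to `best_step` is the natural candidate to keep. The stop step suggests how far `num_epochs` or `max_steps` could be cut.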
## RL/GRPO Stability
### Reward Never Increases
If `rewards/*/mean` stays at 0 for more than 20 steps:
1. **Test reward function standalone**: Run it outside training with known inputs to verify it returns nonzero values.
```
cd experiments && python -c "import my_rewards; print(my_rewards.accuracy_reward(...))"
```
2. **Check dataset columns**: The reward function receives `**kwargs` containing dataset columns. Verify the columns it needs (e.g., `answer`) are not removed by the dataset transform.
3. **Check completion content**: Enable `log_completions: true` in the `trl:` config and inspect logged completions in W&B. If completions are empty or incoherent, the model may be too weak for the task.
4. **Verify vLLM is serving the right model**: Hit the vLLM health endpoint and confirm the model name matches your config.
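Steps 1 and 2 can be combined into a single standalone check. This is a minimal sketch that assumes a TRL-style reward function: it receives `completions` plus dataset columns as keyword arguments and returns one float per completion. It also assumes plain-string completions. `accuracy_reward`, `check_reward`, and the `answer` column are hypothetical stand-ins for whatever your own `my_rewards` module defines.

```python
def accuracy_reward(completions, answer=None, **kwargs):
    """Hypothetical reward: 1.0 if the completion contains the reference answer."""
    if answer is None:
        # The `answer` column never reached the reward function (for example,
        # the dataset transform dropped it). Fail loudly instead of scoring 0.
        raise KeyError("reward function did not receive the 'answer' column")
    return [1.0 if str(ref) in text else 0.0 for text, ref in zip(completions, answer)]


def check_reward(fn, completions, **columns):
    """Call a reward function on known inputs and report whether it ever fires."""
    rewards = fn(completions, **columns)
    assert len(rewards) == len(completions), "expected one reward per completion"
    print(f"rewards={rewards} any_nonzero={any(r != 0 for r in rewards)}")
    return rewards


check_reward(accuracy_reward,
             ["The answer is 42.", "I don't know."],
             answer=["42", "42"])
```

If your completions are conversational (lists of message dicts rather than strings), extract the assistant text before matching.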
### Entropy Collapse (Mode Collapse)
**Symptom**: `entropy` drops below 0.01; all completions become nearly identical.
**Fixes**:
* Increase `temperature` in generation kwargs (try 0.8–1.0).
* Reduce learning rate.
* Add a KL penalty term (`beta` parameter in GRPO config).
* Check that `num_generations` is sufficient (16+ gives better advantage estimates).
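Raising `temperature` helps because sampling divides the logits by the temperature before the softmax. A higher value therefore flattens the distribution and raises its entropy. This minimal sketch uses made-up logits for a single next-token distribution. Whether the trainer's logged `entropy` is computed at the sampling temperature is implementation-dependent, so treat this as intuition only.

```python
import math


def softmax(logits, temperature=1.0):
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    """Shannon entropy in nats."""
    return -sum(p * math.log(p) for p in probs if p > 0)


logits = [8.0, 2.0, 1.0, 0.5]  # illustrative, sharply peaked distribution
for t in (0.5, 0.8, 1.0, 1.5):
    print(f"T={t}: entropy={entropy(softmax(logits, t)):.4f} nats")
```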
### IS Ratio Divergence
**Symptom**: `sampling/importance_sampling_ratio/min` drops near 0, or `sampling/sampling_logp_difference/mean` exceeds 1.0.
This means the policy has diverged significantly from the weights used by vLLM for generation. The importance sampling correction becomes unreliable.
**Fixes**:
* Decrease `vllm_sync_interval` (sync weights more often).
* Enable `off_policy_mask_threshold` (e.g., 0.5) to mask stale off-policy samples.
* Use `importance_sampling_level: token` for finer-grained correction.
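The two symptoms are linked. For a sampled token, the importance-sampling ratio is `exp(logp_policy - logp_sampler)`, so a log-prob gap of 1.0 already means a ratio of about 0.37 or 2.7. This minimal sketch works on per-token log-probs and makes several assumptions:

* `logp_policy` comes from the training model and `logp_sampler` from the vLLM weights.
* The difference metric is aggregated as a mean absolute gap.
* `mask_low_ratio` is a hypothetical masking rule for illustration, not axolotl's implementation of `off_policy_mask_threshold`.

```python
import math


def is_stats(logp_policy, logp_sampler):
    """Token-level importance-sampling ratios for one completion.

    logp_policy: log-probs of the sampled tokens under the training model.
    logp_sampler: log-probs of the same tokens under the vLLM weights.
    """
    diffs = [p - s for p, s in zip(logp_policy, logp_sampler)]
    ratios = [math.exp(d) for d in diffs]
    return {
        # Mean absolute gap; how axolotl aggregates this metric is an assumption.
        "logp_difference_mean": sum(abs(d) for d in diffs) / len(diffs),
        "is_ratio_min": min(ratios),
        "is_ratio_mean": sum(ratios) / len(ratios),
        "ratios": ratios,
    }


def mask_low_ratio(ratios, threshold=0.5):
    """Hypothetical rule: keep tokens (1) whose ratio is at least `threshold`, drop the rest (0)."""
    return [1 if r >= threshold else 0 for r in ratios]


fresh = is_stats([-1.0, -2.0], [-1.0, -2.0])  # weights just synced
stale = is_stats([-1.0, -4.0], [-1.0, -1.5])  # policy moved on the 2nd token
print(fresh["is_ratio_min"], stale["is_ratio_min"], mask_low_ratio(stale["ratios"]))
```

Syncing more often keeps `logp_sampler` close to `logp_policy`, which keeps every ratio near 1.0.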
### Gradient Norm Instability
**Symptom**: `grad_norm` oscillates wildly or exceeds 10.0 regularly.
**Fixes**:
* Enable gradient clipping: `max_grad_norm: 1.0` (default in most configs).
* Reduce learning rate.
* Increase `gradient_accumulation_steps` to smooth out noisy batches.
* Check for NaN issues (see next section).
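`max_grad_norm` clips by global norm. It computes one L2 norm over all gradients together and, if that norm exceeds the limit, scales every gradient by the same factor. This pure-Python sketch applies the rule to plain lists. PyTorch's `torch.nn.utils.clip_grad_norm_` does the equivalent on tensors and returns the total norm measured before clipping. `clip_by_global_norm` is a hypothetical name.

```python
import math


def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """grads: list of per-parameter gradient lists. Returns (clipped, total_norm)."""
    total_norm = math.sqrt(sum(g * g for param in grads for g in param))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef >= 1.0:
        return [list(p) for p in grads], total_norm  # already within the limit
    return [[g * clip_coef for g in p] for p in grads], total_norm


grads = [[3.0], [4.0]]  # two parameters, global norm 5.0
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
print(norm, clipped)  # norm is 5.0; clipped is roughly [[0.6], [0.8]]
```

Clipping does not repair NaN. In this sketch, a single NaN gradient makes `total_norm` NaN, and the scaling then spreads NaN to every gradient. That is why NaN checks come next.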
## NaN and Inf Handling
### Common Causes
| Cause | Where It Manifests | Detection |
| --- | --- | --- |
| FP8 zero-scale division | Forward pass logits | `grad_norm: nan`, loss becomes NaN immediately |
| Gradient explosion | Backward pass | `grad_norm` spikes to inf, then loss goes NaN |
| Bad data (empty sequences) | Logprob computation | NaN in specific batches only |
| Numerical overflow in log-softmax | Loss computation | Large negative logprobs cause exp() overflow |
### FP8-Specific NaN Issues
FP8 quantization (`fp8: true`) can produce NaN when the activation quantization kernel divides by `max(abs(x)) / 448`. If the input tensor is all zeros (e.g., padding positions), the scale becomes 0, causing division by zero.
**Fixes applied in axolotl**:
* The `act_quant_kernel` has a zero-guard: `s = tl.where(s == 0, 1.0, s)`.
* A safety net `nan_to_num(logits, nan=0.0)` is applied in `_get_per_token_logps_and_entropies`.
* Embedding padding is zero-padded for FP8 compatibility.
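The zero-guard is easiest to see on a toy per-tensor scale computation. This is a pure-Python sketch assuming a single scale `max(abs(x)) / 448`, where 448 is the largest finite FP8 E4M3 value. Real kernels compute scales per block on GPU tensors and cast to an FP8 dtype. This sketch only imitates the cast by clamping; it does not round to FP8 precision. `fp8_scale` and `fake_quantize` are hypothetical helpers.

```python
FP8_E4M3_MAX = 448.0  # largest finite value in the float8 E4M3 format


def fp8_scale(values, zero_guard=True):
    s = max(abs(v) for v in values) / FP8_E4M3_MAX
    if zero_guard and s == 0.0:
        s = 1.0  # same idea as tl.where(s == 0, 1.0, s) in the kernel
    return s


def fake_quantize(values, zero_guard=True):
    """Scale into the FP8 range and clamp; does NOT round to FP8 precision."""
    s = fp8_scale(values, zero_guard)
    if s == 0.0:
        # On GPU, 0.0 / 0.0 silently yields NaN; plain Python would raise
        # ZeroDivisionError, so emulate the IEEE result explicitly.
        return [float("nan")] * len(values), s
    return [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v / s)) for v in values], s


padding = [0.0, 0.0, 0.0, 0.0]
print(fake_quantize(padding, zero_guard=False))  # NaNs, scale 0.0
print(fake_quantize(padding))                    # zeros, scale 1.0
```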
**Important: After Modifying Triton Kernels**
If you patch any Triton JIT kernel (e.g., the FP8 quantization kernels in transformers), you must clear the Triton cache for changes to take effect:
```
rm -rf ~/.triton/cache
```
### General NaN Debugging Steps
1. **Enable anomaly detection** (slow, but pinpoints the source):
```
import torch
torch.autograd.set_detect_anomaly(True)
```
2. **Check grad\_norm**: If it goes to NaN, the backward pass is the problem. If loss is NaN but grad\_norm was fine on the previous step, the forward pass is the problem.
3. **Reduce to single GPU, single batch**: Eliminate distributed training variables.
4. **Inspect data**: Print the batch that triggers NaN. Look for empty sequences, extreme token IDs, or unexpected padding patterns.
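Step 4 can be partly automated. This is a minimal sketch assuming a batch given as Python lists of `input_ids`, `attention_mask`, and `labels`, with `-100` as the ignore index (the Hugging Face convention). `inspect_batch` is a hypothetical helper. A batch whose labels are all `-100` is worth flagging: a mean-reduced cross-entropy over zero target tokens is 0/0, which PyTorch returns as NaN.

```python
IGNORE_INDEX = -100


def inspect_batch(input_ids, attention_mask, labels, vocab_size):
    """Return human-readable problems found in one batch (rows are lists of ints)."""
    problems = []
    for i, (ids, mask, labs) in enumerate(zip(input_ids, attention_mask, labels)):
        if sum(mask) == 0:
            problems.append(f"row {i}: empty sequence (attention_mask all zero)")
        bad = [t for t in ids if t < 0 or t >= vocab_size]
        if bad:
            problems.append(f"row {i}: token ids out of range [0, {vocab_size}): {bad[:5]}")
        if all(lab == IGNORE_INDEX for lab in labs):
            problems.append(f"row {i}: no trainable tokens (labels all {IGNORE_INDEX})")
    if all(lab == IGNORE_INDEX for labs in labels for lab in labs):
        problems.append("batch: no trainable tokens at all; mean loss would be 0/0 = NaN")
    return problems


problems = inspect_batch(
    input_ids=[[1, 2, 3], [0, 0, 0]],
    attention_mask=[[1, 1, 1], [0, 0, 0]],
    labels=[[-100, 2, 3], [-100, -100, -100]],
    vocab_size=10,
)
print("\n".join(problems))
```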
## OOM Debugging
Out-of-memory errors are the most common training failure. Use this systematic approach, from least to most disruptive:
### Step 1: Reduce Batch Size
The single highest-impact change. VRAM scales roughly linearly with batch size.
```
micro_batch_size: 1 # Start here
gradient_accumulation_steps: 16 # Increase to maintain effective batch size
```
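Raising `gradient_accumulation_steps` to compensate is arithmetic on the effective batch size formula `micro_batch_size * gradient_accumulation_steps * num_gpus`. This is a minimal sketch; `effective_batch_size` and `accum_steps_for` are hypothetical helpers.

```python
def effective_batch_size(micro_batch_size, gradient_accumulation_steps, num_gpus=1):
    """Samples contributing to one optimizer step."""
    return micro_batch_size * gradient_accumulation_steps * num_gpus


def accum_steps_for(target_effective, micro_batch_size, num_gpus=1):
    """gradient_accumulation_steps that keeps the effective batch size at target_effective."""
    per_step = micro_batch_size * num_gpus
    if target_effective % per_step:
        raise ValueError(f"{target_effective} is not divisible by {per_step}")
    return target_effective // per_step


# Shrinking micro_batch_size from 4 to 1 on one GPU, keeping 16 samples per step:
old = effective_batch_size(4, 4)
print(old, accum_steps_for(old, micro_batch_size=1))  # 16 16
```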
For GRPO specifically, the logits tensor used for policy logprob computation can be very large: it holds `batch_size * num_generations * seq_len * vocab_size` elements at 2 bytes each in bf16. For example, with `num_generations: 16`, `micro_batch_size: 8`, a 2048-token sequence, and a 151,936-token vocabulary, the logits tensor alone is:
```
8 * 16 * 2048 * 151936 * 2 bytes = 79,658,221,568 bytes ≈ 74 GiB (way too large)
```
Reduce `micro_batch_size` to 2–4 for GRPO.
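The same arithmetic as a helper makes it easy to try other settings. This is a minimal sketch. It counts only one full logits tensor; intermediate copies, such as an upcast for log-softmax, can multiply the real footprint. The 2048 and 151,936 values are reused from the example, and `logits_bytes` is a hypothetical name.

```python
def logits_bytes(micro_batch_size, num_generations, seq_len, vocab_size, bytes_per_elem=2):
    """Size of one full logits tensor; bf16 uses 2 bytes per element."""
    return micro_batch_size * num_generations * seq_len * vocab_size * bytes_per_elem


for mbs in (8, 4, 2):
    gib = logits_bytes(mbs, 16, 2048, 151936) / 2**30
    print(f"micro_batch_size={mbs}: {gib:.1f} GiB")
# micro_batch_size=8: 74.2 GiB
# micro_batch_size=4: 37.1 GiB
# micro_batch_size=2: 18.5 GiB
```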
### Step 2: Enable Gradient Checkpointing
Trades compute for memory by recomputing activations during the backward pass instead of storing them.
```
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false # Recommended default
```
**Warning: Reentrant Checkpointing Exceptions**
Some configurations require `use_reentrant: true`:
* DeepSpeed ZeRO-3 (non-reentrant causes `CheckpointError`)
* EBFT strided mode with flex\_attention
### Step 3: Use Quantization
Load the base model in reduced precision:
```
# 4-bit QLoRA
adapter: qlora
load_in_4bit: true
# 8-bit
load_in_8bit: true
# FP8 (saves ~50% model VRAM, same compute speed as bf16)
fp8: true
```
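Rough weight-memory arithmetic behind these options is sketched below. It assumes a hypothetical 8B-parameter model and bytes per parameter of 2 (bf16), 1 (8-bit or FP8), and 0.5 (4-bit). It ignores quantization metadata (scales, zero points), activations, gradients, optimizer state, and LoRA adapters, so real usage is higher. `weight_gb` is a hypothetical helper.

```python
BYTES_PER_PARAM = {"bf16": 2.0, "8bit": 1.0, "fp8": 1.0, "4bit": 0.5}


def weight_gb(num_params, precision):
    """Weight storage only, in GB (1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[precision] / 1e9


for p in ("bf16", "fp8", "4bit"):
    print(f"{p}: {weight_gb(8e9, p):.1f} GB")
# bf16: 16.0 GB
# fp8: 8.0 GB
# 4bit: 4.0 GB
```

Halving bytes per parameter is where the "~50% model VRAM" figure for FP8 comes from.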
### Step 4: Reduce Sequence Length
```
sequence_len: 1024 # Down from 2048 or 4096
```
For GRPO, also reduce `max_completion_length`. Memory scales quadratically with sequence length when using standard attention.
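The quadratic term comes from the attention score matrix. Standard (eager) attention materializes one `seq_len x seq_len` matrix per head per sequence, while FlashAttention computes attention in tiles without storing it. This rough sketch sizes that single tensor. The 32 heads and bf16 dtype are illustrative assumptions, real memory also includes the softmax output and other activations, and `attn_scores_bytes` is a hypothetical name.

```python
def attn_scores_bytes(batch, num_heads, seq_len, bytes_per_elem=2):
    """Size of one layer's attention score matrix: batch x heads x L x L."""
    return batch * num_heads * seq_len * seq_len * bytes_per_elem


for L in (4096, 2048, 1024):
    mib = attn_scores_bytes(1, 32, L) / 2**20
    print(f"sequence_len={L}: {mib:.0f} MiB per layer")
# sequence_len=4096: 1024 MiB per layer
# sequence_len=2048: 256 MiB per layer
# sequence_len=1024: 64 MiB per layer
```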
### Step 5: Use Flash Attention
Reduces attention memory from O(n^2) to O(n):
```
attn_implementation: flash_attention_2
```
### Step 6: Offload with DeepSpeed
For extreme cases, offload optimizer states or parameters to CPU:
```
deepspeed: deepspeed_configs/zero3_bf16.json
```
### Diagnosing the Specific Culprit
Use the `profiler_steps` config option to capture GPU memory snapshots:
```
profiler_steps: [1, 2]
```
This generates PyTorch profiler traces you can inspect to see exactly which tensor allocation caused the OOM.
## Common Errors
| Error Message | Likely Cause | Fix |
| --- | --- | --- |
| `exitcode: -9` | System RAM exhaustion | Reduce dataset size, `dataset_num_proc`, or number of data workers |
| `exitcode: -7` (DeepSpeed) | DeepSpeed version issue | `pip install -U deepspeed` |
| `CUDA out of memory` | GPU VRAM exhaustion | Follow OOM debugging steps above |
| `RuntimeError: NCCL communicator was aborted` | GPU communication failure | See [NCCL docs](../docs/nccl.html); check `NCCL_DEBUG=INFO` output |
| `ValueError: Asking to pad but the tokenizer does not have a padding token` | Missing pad token | Add `special_tokens: { pad_token: "<\|endoftext\|>" }` to config |
| `'DummyOptim' object has no attribute 'step'` | DeepSpeed on single GPU | Remove `deepspeed:` section from config |
| `unable to load strategy X` then `None is not callable` | Reward module not importable | Run `cd experiments && python -c "import my_rewards"` to check |
| `generation_batch_size not divisible by num_generations` | micro\_batch\_size too small | Set `micro_batch_size >= num_generations` and make it divisible |
| `'weight' must be 2-D` | FSDP1 flattened parameters | Use `fsdp_version: 2` or skip `unwrap_model` when FSDP is enabled |
| `CheckpointError` (tensor count mismatch) | Non-reentrant checkpointing + ZeRO-3 or flex\_attention | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |
| `BFloat16` TypeError during weight sync | NumPy does not support bf16 | Fixed in axolotl's `weight_serde.py` (auto bf16 to fp16 conversion) |
| `Content end boundary is before start boundary` | Chat template parsing issue | Check `eos_token` matches template; file a GitHub issue if persistent |
| `CAS service error` during data processing | HuggingFace XET issue | Set `export HF_HUB_DISABLE_XET=1` |
| Training hangs (multi-GPU) | FSDP + async prefetch deadlock | Set `async_prefetch: false` with FSDP |
## Profiling
### PyTorch Profiler
Axolotl supports PyTorch profiler integration via the config:
```
profiler_steps: [1, 2, 3]
```
This captures profiler traces for the specified steps. View them in TensorBoard:
```
tensorboard --logdir output_dir/runs
```
Or open the `.json` trace file in `chrome://tracing`.
### CUDA Memory Snapshots
For detailed memory analysis, use PyTorch's memory snapshot API. Add this to your training script or use it interactively:
```
import torch
# Enable memory history tracking
torch.cuda.memory._record_memory_history()
# ... run your training step ...
# Save snapshot
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
```
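Visualize the snapshot by dragging it onto PyTorch's memory visualizer at https://pytorch.org/memory_viz, or render an HTML timeline locally:
```
python -m torch.cuda._memory_viz trace_plot memory_snapshot.pickle -o memory_snapshot.html
```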
### Quick GPU Memory Check
During training, monitor GPU utilization in a separate terminal:
```
watch -n 1 nvidia-smi
```
For programmatic access within axolotl, the logged metrics `memory/max_alloc` and `memory/max_reserved` come from `torch.cuda.max_memory_allocated()` and `torch.cuda.max_memory_reserved()`. Note these report PyTorch's view of memory, which may differ from `nvidia-smi` (see [FAQ](../docs/faq.html)).
## W&B and Logging
### Enabling Logging
```
wandb_project: my-project
wandb_entity: my-team # optional
wandb_run_id: run-123 # optional, for resuming
wandb_name: experiment-name # optional
logging_steps: 1 # log every step (recommended for RL)
```
### Debug Logging
For detailed axolotl-internal debug output:
```
AXOLOTL_LOG_LEVEL=DEBUG axolotl train config.yaml 2>&1 | tee /tmp/training.log
```
**Tip: Always Log to a File**
Pipe training output to a log file so you can inspect it after the run:
```
axolotl train config.yaml 2>&1 | tee /tmp/my_run.log
```
### What Axolotl Logs
**SFT metrics** (logged every `logging_steps`):
* `train/loss`, `eval/loss`: training and validation loss
* `train/grad_norm`: gradient L2 norm (before clipping)
* `train/learning_rate`: current learning rate
* `memory/max_alloc`, `memory/max_reserved`: peak GPU memory
**GRPO/RL metrics** (logged every step):
* `rewards/<name>/mean`, `rewards/<name>/std`: per-reward-function statistics
* `reward`, `reward_std`: aggregated reward across all reward functions
* `frac_reward_zero_std`: fraction of prompt groups where all completions got the same reward
* `completions/mean_length`, `completions/min_length`, `completions/max_length`: completion token lengths
* `completions/clipped_ratio`: fraction of completions that hit the max length
* `completions/mean_terminated_length`, `completions/min_terminated_length`, `completions/max_terminated_length`: lengths of naturally terminated completions
* `kl`: KL divergence between policy and reference
* `entropy`: policy entropy (measure of output diversity)
* `clip_ratio/region_mean`, `clip_ratio/low_mean`, `clip_ratio/high_mean`: PPO clipping statistics
* `sampling/sampling_logp_difference/mean`, `sampling/sampling_logp_difference/max`: log-probability difference between policy and sampling distribution
* `sampling/importance_sampling_ratio/min`, `sampling/importance_sampling_ratio/mean`, `sampling/importance_sampling_ratio/max`: IS ratio statistics for off-policy correction
* `num_tokens`: total tokens processed
### Reading W&B Charts
For a healthy GRPO run, expect to see:
1. **`reward/mean`**: Gradual upward trend. May start near 0 and reach 0.3–0.8 depending on task difficulty. Not monotonic; fluctuations are normal.
2. **`entropy`**: Gradual decrease from initial values (often 0.3–0.6) as the model becomes more confident. Should not collapse to near-zero.
3. **`grad_norm`**: Mostly in the 0.001–1.0 range. Occasional 0.0 values are fine (zero-advantage skip). Persistent values above 10.0 need investigation.
4. **`kl`**: Starts near 0 and grows slowly. If it shoots up rapidly, the policy is diverging from the reference.
5. **`completions/mean_length`**: Should reflect the task's natural answer length. If it steadily increases to `max_completion_length`, the model may be reward-hacking by generating longer outputs.
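The length check in item 5 can be turned into a simple rule over logged values. This is a minimal sketch assuming per-step lists of `completions/mean_length` and `completions/clipped_ratio`. The window, the 90%-of-cap threshold, and the 0.8 clipped-ratio threshold are arbitrary assumptions. A positive result is a prompt to read completions (for example with `log_completions: true`), not proof of reward hacking. `length_creep` is a hypothetical helper.

```python
def length_creep(mean_lengths, clipped_ratios, max_completion_length,
                 window=50, near_max=0.9, clipped_high=0.8):
    """Flag runs whose recent completion lengths trend up and pile up at the cap."""
    if len(mean_lengths) < window:
        return False
    recent = mean_lengths[-window:]
    half = window // 2
    # Compare the average of the newer half of the window with the older half.
    rising = sum(recent[half:]) / (window - half) > sum(recent[:half]) / half
    at_cap = recent[-1] >= near_max * max_completion_length
    clipped = clipped_ratios[-1] >= clipped_high
    return rising and (at_cap or clipped)


rising_run = list(range(100, 1100, 10))  # 100 steps creeping toward the cap
print(length_creep(rising_run, [0.9] * 100, max_completion_length=1100))   # True
print(length_creep([300] * 100, [0.05] * 100, max_completion_length=1100))  # False
```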