Source: https://docs.axolotl.ai/docs/training_stability.html
Title: Training Stability & Debugging - Axolotl documentation (undated, fetched 2026)
Fetched-via: uvx markitdown https://docs.axolotl.ai/docs/training_stability.html
Fetch-status: verbatim, nav/sidebar/TOC boilerplate trimmed

# Training Stability & Debugging

Guide to monitoring, debugging, and stabilizing training runs in axolotl

This guide covers practical techniques for monitoring training health, diagnosing instability, and resolving common failures in both supervised fine-tuning (SFT) and reinforcement learning (GRPO/EBFT) workflows.

## Monitoring Training

### Key Metrics for SFT

Every SFT run should be monitored through at least these four metrics:

| Metric | What It Tells You | Healthy Range |
| --- | --- | --- |
| `train/loss` | How well the model fits training data | Decreasing; typically 0.5–2.0 for chat fine-tuning |
| `eval/loss` | Generalization performance | Tracks train loss with small gap; divergence signals overfitting |
| `grad_norm` | Gradient magnitude | 0.1–10.0; spikes above 100 indicate instability |
| `learning_rate` | Current LR from scheduler | Should follow expected schedule (warmup then decay) |

**Tip: Set Up Logging Early**

Enable W&B or TensorBoard from the start. Debugging a failed run without metrics is guesswork.

```yaml
wandb_project: my-project
wandb_run_id: # optional, for resuming
logging_steps: 1
```
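
Once metrics are logged, the SFT table's thresholds can be checked automatically. A minimal sketch, assuming you have exported the logged values as a list of per-step dicts (the export format, key layout, and helper name are assumptions); it flags non-finite values and `grad_norm` spikes above 100.

```python
import math

GRAD_NORM_SPIKE = 100.0  # "spikes above 100 indicate instability" (table above)


def flag_sft_metrics(history):
    """Hypothetical checker: return (step, message) pairs for suspicious log rows.

    Assumes `history` is a list of dicts with "step", "train/loss" and
    "grad_norm" keys, e.g. rows exported from W&B or TensorBoard.
    """
    flags = []
    for row in history:
        step, loss, grad_norm = row["step"], row["train/loss"], row["grad_norm"]
        if not math.isfinite(loss):
            flags.append((step, "train/loss is NaN/Inf"))
        if not math.isfinite(grad_norm):
            flags.append((step, "grad_norm is NaN/Inf"))
        elif grad_norm > GRAD_NORM_SPIKE:
            flags.append((step, f"grad_norm spike: {grad_norm:.1f}"))
    return flags


history = [
    {"step": 1, "train/loss": 1.9, "grad_norm": 3.2},
    {"step": 2, "train/loss": 1.7, "grad_norm": 250.0},
    {"step": 3, "train/loss": float("nan"), "grad_norm": float("nan")},
]
for step, message in flag_sft_metrics(history):
    print(step, message)
```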

### Key Metrics for RL (GRPO)

GRPO training logs a richer set of metrics. These are the critical ones:

| Metric | Healthy Range | Red Flag |
| --- | --- | --- |
| `rewards/<name>/mean` | > 0.15 within 20 steps | Stays at 0 – reward function is broken or task is too hard |
| `reward_std` | > 0 on most steps | Always 0 – no learning signal (all completions get the same reward) |
| `frac_reward_zero_std` | < 0.8 | 1.0 on every step – zero-advantage skip fires constantly, no gradient updates |
| `grad_norm` | 0.001–1.0 | 0.0 is acceptable occasionally (zero-adv skip); > 10.0 is unstable |
| `entropy` | 0.05–0.5 | < 0.01 suggests mode collapse; > 1.0 suggests the model is not converging |
| `kl` | 0.0–0.5 | > 2.0 suggests policy has diverged too far from reference |
| `sampling/sampling_logp_difference/mean` | < 0.1 | > 1.0 means policy has diverged far from vLLM server weights |
| `sampling/importance_sampling_ratio/min` | > 0.1 | Near 0 indicates stale off-policy data; increase `vllm_sync_interval` |
| `clip_ratio/region_mean` | < 0.1 | > 0.3 means PPO clipping is too aggressive |
| `completions/mean_length` | Task-dependent | Monotonically increasing to max length suggests reward hacking |
| `completions/clipped_ratio` | < 0.3 | > 0.8 means most completions hit `max_completion_length` – increase it |

**Note: EBFT-Specific Metrics**

For EBFT training, also monitor `ebft/alignment` (should trend upward, healthy 0.3–0.9), `ebft/diversity` (healthy 0.01–0.1; > 1.0 indicates mode collapse), and `ebft/cfm_loss` (should trend downward, < 10).
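
A small sketch of how the `reward_std` and `frac_reward_zero_std` signals in the GRPO table relate to one batch of rewards. It assumes the rewards arrive as a flat list in which each consecutive run of `num_generations` values belongs to one prompt (treat that ordering as an assumption); the helper name is hypothetical. It uses the population standard deviation, but for the zero-std check the choice of estimator makes no difference.

```python
from statistics import pstdev


def group_reward_stats(rewards, num_generations):
    """Hypothetical helper: per-prompt reward std and the fraction of zero-std groups.

    Assumes `rewards` is flat, with each consecutive block of `num_generations`
    values holding the rewards for one prompt's completions.
    """
    if num_generations < 1 or not rewards or len(rewards) % num_generations != 0:
        raise ValueError("len(rewards) must be a positive multiple of num_generations")
    groups = [
        rewards[i : i + num_generations]
        for i in range(0, len(rewards), num_generations)
    ]
    stds = [pstdev(group) for group in groups]
    # A group whose completions all got the same reward has zero advantage,
    # so it contributes no learning signal.
    frac_reward_zero_std = sum(1 for s in stds if s == 0.0) / len(groups)
    return stds, frac_reward_zero_std


# Two prompts, four completions each: the first group is all-correct (no signal).
stds, frac = group_reward_stats([1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0], num_generations=4)
print(stds, frac)  # → [0.0, 0.5] 0.5
```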

## SFT Stability

### Loss Plateau

**Symptom**: Loss stops decreasing early in training, well above expected values.

**Causes and fixes**:

* **Learning rate too low**: Increase by 2–5x. Typical ranges: full fine-tune 1e-5 to 5e-5, LoRA 1e-4 to 3e-4.
* **Insufficient warmup**: Set `warmup_steps` to 5–10% of total steps. Too-aggressive learning at the start can push the model into a flat region.
* **Data quality**: Check that labels are correctly masked. Use `axolotl preprocess` and inspect tokenized samples to confirm only the target tokens are trainable.
* **Weight decay too high**: Default 0.01 is usually fine. Values above 0.1 can suppress learning in LoRA.
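
The masking check in the **Data quality** bullet can be scripted once you have a tokenized sample. This sketch assumes the common Hugging Face convention that label positions set to -100 (the default `ignore_index` of PyTorch's `CrossEntropyLoss`) are excluded from the loss; the token IDs, the toy vocabulary, and the helper name are made up for illustration.

```python
IGNORE_INDEX = -100  # label value excluded from the loss under the HF/PyTorch convention


def trainable_token_report(input_ids, labels, decode=None):
    """Hypothetical helper: summarize which tokens of one sample are trained on."""
    if len(input_ids) != len(labels):
        raise ValueError("input_ids and labels must have the same length")
    trainable = [tok for tok, lab in zip(input_ids, labels) if lab != IGNORE_INDEX]
    report = {
        "total_tokens": len(input_ids),
        "trainable_tokens": len(trainable),
        "trainable_fraction": len(trainable) / len(input_ids) if input_ids else 0.0,
    }
    if decode is not None:
        # Decoding only the trainable tokens should show just the target text.
        report["trainable_text"] = decode(trainable)
    return report


# Toy vocabulary standing in for a real tokenizer.
vocab = {1: "<user>", 2: "Hi", 3: "<assistant>", 4: "Hello", 5: "!"}
report = trainable_token_report(
    input_ids=[1, 2, 3, 4, 5],
    labels=[-100, -100, -100, 4, 5],
    decode=lambda ids: " ".join(vocab[i] for i in ids),
)
print(report)
```

A `trainable_fraction` of 0.0 means the sample contributes nothing to the loss; 1.0 means the prompt tokens are being trained on as well, which is usually not what a chat fine-tune intends.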

### Loss Spikes

**Symptom**: Loss suddenly jumps by 2–10x, then (possibly) recovers.

**Causes and fixes**:

* **Bad data samples**: A single malformed or extremely long example can cause a spike. Temporarily set `sample_packing: false`, which makes spikes easier to attribute to individual examples, and check whether spikes correlate with specific batches.
* **Learning rate too high**: Reduce by 2–5x, or increase warmup.
* **Gradient accumulation mismatch**: Effective batch size = `micro_batch_size * gradient_accumulation_steps * num_gpus`. Very large effective batch sizes amplify gradient noise.
* **Mixed precision issues**: With `bf16: true`, some operations can lose precision. If spikes are severe, try `fp32` for diagnosis.
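
Two small helpers for the bullets above, as a sketch: the first evaluates the effective-batch-size formula; the second flags logged steps whose loss exceeds a multiple of the recent median, giving you a list of batches to inspect. The 2x factor echoes the 2–10x symptom; the 20-step median window is an arbitrary choice, and both function names are hypothetical.

```python
from statistics import median


def effective_batch_size(micro_batch_size, gradient_accumulation_steps, num_gpus):
    # Formula from the "Gradient accumulation mismatch" bullet above.
    return micro_batch_size * gradient_accumulation_steps * num_gpus


def find_loss_spikes(losses, factor=2.0, window=20):
    """Return indices whose loss exceeds `factor` x the median of the previous `window` losses."""
    spikes = []
    for i in range(1, len(losses)):
        recent = losses[max(0, i - window) : i]
        if losses[i] > factor * median(recent):
            spikes.append(i)
    return spikes


print(effective_batch_size(2, 8, 4))  # → 64
print(find_loss_spikes([1.2, 1.1, 1.1, 1.0, 4.5, 1.0, 0.9]))  # → [4]
```

The returned indices are positions in the logged loss list; map them back to step numbers (and then to batches) before inspecting the data.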

### Overfitting

**Symptom**: Train loss keeps decreasing but eval loss starts increasing.

**Fixes**:

* Increase `val_set_size` (e.g., 0.05) and monitor `eval/loss`.
* Reduce `num_epochs` or `max_steps`.
* Increase `weight_decay` (try 0.01–0.1).
* Use a smaller LoRA rank (`lora_r`). Typical values: 8–32.
* Increase dropout: `lora_dropout: 0.05`.
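
To make "eval loss starts increasing" concrete, this sketch scans aligned per-evaluation losses and returns the last eval step before eval loss rose for `patience` consecutive evaluations while train loss kept falling. That step is a candidate value for `max_steps`. The helper name, the input layout, and the patience of 3 are assumptions.

```python
def find_overfitting_onset(steps, train_losses, eval_losses, patience=3):
    """Hypothetical helper: last eval step before a sustained train/eval divergence.

    All three lists are aligned per evaluation (train loss sampled at each eval step).
    Returns None if eval loss never rises `patience` times in a row while train loss falls.
    """
    rises = 0
    for i in range(1, len(steps)):
        train_down = train_losses[i] < train_losses[i - 1]
        eval_up = eval_losses[i] > eval_losses[i - 1]
        rises = rises + 1 if (train_down and eval_up) else 0
        if rises == patience:
            return steps[i - patience]
    return None


steps = [100, 200, 300, 400, 500, 600]
train = [1.50, 1.30, 1.10, 0.90, 0.70, 0.50]
evals = [1.60, 1.45, 1.40, 1.42, 1.47, 1.55]
print(find_overfitting_onset(steps, train, evals))  # → 300
```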

## RL/GRPO Stability

### Reward Never Increases

If `rewards/*/mean` stays at 0 for more than 20 steps:

1. **Test reward function standalone**: Run it outside training with known inputs to verify it returns nonzero values.

   ```shell
   cd experiments && python -c "import my_rewards; print(my_rewards.accuracy_reward(...))"
   ```

2. **Check dataset columns**: The reward function receives `**kwargs` containing dataset columns. Verify the columns it needs (e.g., `answer`) are not removed by the dataset transform.
3. **Check completion content**: Enable `log_completions: true` in the `trl:` config and inspect logged completions in W&B. If completions are empty or incoherent, the model may be too weak for the task.
4. **Verify vLLM is serving the right model**: Hit the vLLM health endpoint and confirm the model name matches your config.
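
Step 1 in practice: call the reward function directly with hand-written completions whose correct reward you already know. The `accuracy_reward` below is a hypothetical stand-in (the real `my_rewards.accuracy_reward` is not shown); it follows the convention from step 2 that dataset columns such as `answer` arrive as keyword arguments, and it assumes completions are plain strings. The last call shows what a dropped column looks like.

```python
def accuracy_reward(completions, answer, **kwargs):
    """Hypothetical reward: 1.0 if the reference answer appears in the completion."""
    return [1.0 if ref.strip() in text else 0.0 for text, ref in zip(completions, answer)]


def smoke_test_reward(reward_fn, completions, **columns):
    """Call a reward function on known inputs and fail loudly on common silent failures."""
    rewards = reward_fn(completions, **columns)
    assert len(rewards) == len(completions), "expected one reward per completion"
    assert any(r != 0 for r in rewards), "all rewards are 0: check the function and its columns"
    return rewards


print(smoke_test_reward(
    accuracy_reward,
    ["The answer is 42.", "I think it is 7."],
    answer=["42", "42"],
))  # → [1.0, 0.0]

try:
    accuracy_reward(["The answer is 42."])  # simulates `answer` removed by a dataset transform
except TypeError as err:
    print("missing dataset column:", err)
```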

### Entropy Collapse (Mode Collapse)

**Symptom**: `entropy` drops below 0.01; all completions become nearly identical.

**Fixes**:

* Increase `temperature` in generation kwargs (try 0.8–1.0).
* Reduce learning rate.
* Add a KL penalty term (`beta` parameter in GRPO config).
* Check that `num_generations` is sufficient (16+ gives better advantage estimates).
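
For intuition about the entropy thresholds, a sketch of per-token entropy (natural log, so in nats) computed from raw logits. It assumes the logged `entropy` is an average of this quantity over completion tokens; the exact masking and averaging axolotl applies are not shown here, and the logits are toy values.

```python
import math


def token_entropy(logits):
    """Shannon entropy (nats) of softmax(logits) at one token position."""
    m = max(logits)  # subtract the max for a numerically stable softmax
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    return -sum(p * math.log(p) for p in probs if p > 0)


def mean_entropy(per_token_logits):
    """Average entropy over token positions (a stand-in for the logged `entropy`)."""
    return sum(token_entropy(logits) for logits in per_token_logits) / len(per_token_logits)


spread = [[1.0, 0.5, 0.2, 0.0]] * 3      # probability mass spread over several tokens
collapsed = [[20.0, 0.0, 0.0, 0.0]] * 3  # one token dominates every position
print(f"spread: {mean_entropy(spread):.3f}, collapsed: {mean_entropy(collapsed):.1e}")
```

The collapsed case lands far below the 0.01 red-flag threshold, which is what "all completions become nearly identical" looks like at the token level.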

### IS Ratio Divergence

**Symptom**: `sampling/importance_sampling_ratio/min` drops near 0, or `sampling/sampling_logp_difference/mean` exceeds 1.0.

This means the policy has diverged significantly from the weights used by vLLM for generation. The importance sampling correction becomes unreliable.

**Fixes**:

* Decrease `vllm_sync_interval` (sync weights more often).
* Enable `off_policy_mask_threshold` (e.g., 0.5) to mask stale off-policy samples.
* Use `importance_sampling_level: token` for finer-grained correction.
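
The IS quantities above come down to log-prob gaps between the training policy and the sampler. This sketch assumes `sampling_logp_difference` is a mean absolute per-token gap and that the importance ratio is exp(logp_policy - logp_sampler); the exact reductions axolotl logs may differ, and the input values are invented. It also illustrates why token-level correction is finer-grained: a plain sequence-level ratio multiplies all token ratios together.

```python
import math


def is_diagnostics(policy_logps, sampler_logps):
    """Per-token importance-sampling diagnostics for one completion.

    Assumes both lists hold log-probs of the *sampled* tokens: one from the
    current policy, one recorded by the generation server (e.g. vLLM).
    """
    diffs = [p - s for p, s in zip(policy_logps, sampler_logps)]
    token_ratios = [math.exp(d) for d in diffs]  # pi_policy / pi_sampler per token
    return {
        "logp_difference_mean": sum(abs(d) for d in diffs) / len(diffs),
        "is_ratio_min": min(token_ratios),
        "is_ratio_max": max(token_ratios),
        # A plain sequence-level ratio is the product of the token ratios, so a
        # single badly mismatched token drags the whole sequence toward 0.
        "sequence_ratio": math.exp(sum(diffs)),
    }


stats = is_diagnostics(
    policy_logps=[-1.0, -2.0, -0.5],
    sampler_logps=[-1.0, -0.5, -0.5],
)
print(stats)
```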

### Gradient Norm Instability

**Symptom**: `grad_norm` oscillates wildly or exceeds 10.0 regularly.

**Fixes**:

* Enable gradient clipping: `max_grad_norm: 1.0` (default in most configs).
* Reduce learning rate.
* Increase `gradient_accumulation_steps` to smooth out noisy batches.
* Check for NaN issues (see next section).
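
`max_grad_norm` clips by the global L2 norm across all parameters. A pure-Python sketch of that rule, written to mirror what `torch.nn.utils.clip_grad_norm_` does: compute one norm over every gradient, then scale all of them by `max_norm / (total_norm + eps)` when that factor is below 1. Toy vectors stand in for parameter gradients, and the function name is hypothetical. The returned pre-clip norm is the quantity that shows up as `grad_norm`.

```python
import math


def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Return (clipped_grads, total_norm); total_norm is measured before clipping."""
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    scale = min(1.0, max_norm / (total_norm + eps))
    return [[g * scale for g in vec] for vec in grads], total_norm


grads = [[3.0, 0.0], [0.0, 4.0]]  # two "parameters"; global norm is 5.0
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
print(norm, clipped)
```

Clipping keeps the update direction and caps its size, so it limits the damage of an occasional spike; a `grad_norm` that is regularly far above the cap still points at one of the other causes listed above.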

## NaN and Inf Handling

### Common Causes

| Cause | Where It Manifests | Detection |
| --- | --- | --- |
| FP8 zero-scale division | Forward pass logits | `grad_norm: nan`, loss becomes NaN immediately |
| Gradient explosion | Backward pass | `grad_norm` spikes to inf, then loss goes NaN |
| Bad data (empty sequences) | Logprob computation | NaN in specific batches only |
| Numerical overflow in log-softmax | Loss computation | Extreme logits make a naive exp() overflow to inf, or underflow to 0 so that log(0) = -inf |
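
To see the log-softmax row concretely: a naive log-softmax exponentiates logits directly, so large values overflow and tiny probabilities underflow to 0, whose log is -inf. The standard fix is the log-sum-exp trick, sketched below in pure Python. Python's `math` module raises `OverflowError`/`ValueError` in cases where a float tensor would silently produce `inf` or `nan`.

```python
import math


def log_softmax_naive(logits):
    exps = [math.exp(x) for x in logits]  # math.exp raises OverflowError above ~709.78
    total = sum(exps)
    return [math.log(e / total) for e in exps]  # math.log(0.0) raises ValueError


def log_softmax_stable(logits):
    # Log-sum-exp trick: shift by the max so the largest exponent is exp(0) = 1.
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


logits = [1000.0, 0.0]
try:
    log_softmax_naive(logits)
except (OverflowError, ValueError) as err:
    print("naive log-softmax failed:", type(err).__name__)
print(log_softmax_stable(logits))  # → [0.0, -1000.0]
```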

### FP8-Specific NaN Issues

FP8 quantization (`fp8: true`) can produce NaN in the activation quantization kernel, which divides each input by the scale `s = max(abs(x)) / 448` (448 is the largest finite value in the FP8 E4M3 format). If the input tensor is all zeros (e.g., padding positions), the scale becomes 0, causing division by zero.

**Fixes applied in axolotl**:

* The `act_quant_kernel` has a zero-guard: `s = tl.where(s == 0, 1.0, s)`.
* A safety net `nan_to_num(logits, nan=0.0)` is applied in `_get_per_token_logps_and_entropies`.
* Embedding padding is zero-padded for FP8 compatibility.

**Important: After Modifying Triton Kernels**

If you patch any Triton JIT kernel (e.g., the FP8 quantization kernels in transformers), you must clear the Triton cache for changes to take effect:

```shell
rm -rf ~/.triton/cache
```
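
The zero-guard can be illustrated outside Triton. This is a simplified pure-Python sketch of a per-block activation scale, not the transformers kernel: it computes `s = max(abs(x)) / 448`, substitutes 1.0 when `s == 0` as the guard does, and divides by the scale (a real kernel would also round to the FP8 grid). The function names are hypothetical. Python raises `ZeroDivisionError` where GPU float math would produce NaN.

```python
FP8_E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3


def fp8_scale(block, zero_guard=True):
    """Per-block quantization scale (sketch of the logic, not the Triton kernel)."""
    s = max(abs(x) for x in block) / FP8_E4M3_MAX
    if zero_guard and s == 0:
        s = 1.0  # same idea as `tl.where(s == 0, 1.0, s)`
    return s


def quantize_block(block, zero_guard=True):
    s = fp8_scale(block, zero_guard)
    # Only the division by the scale is modeled; no rounding to FP8 values.
    return [x / s for x in block], s


print(quantize_block([896.0, -448.0]))        # scale 2.0 maps the max magnitude to 448
print(quantize_block([0.0, 0.0, 0.0, 0.0]))  # padding-like block: scale falls back to 1.0
try:
    quantize_block([0.0, 0.0, 0.0, 0.0], zero_guard=False)
except ZeroDivisionError:
    print("unguarded scale of 0: division by zero (0/0 = NaN on GPU)")
```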

### General NaN Debugging Steps

1. **Enable anomaly detection** (slow, but pinpoints the source):

   ```python
   import torch
   torch.autograd.set_detect_anomaly(True)  # slow; turn off once the NaN source is found
   ```

2. **Check `grad_norm`**: If it goes to NaN, the backward pass is the problem. If loss is NaN but `grad_norm` was fine on the previous step, the forward pass is the problem.
3. **Reduce to single GPU, single batch**: Eliminate distributed training variables.
4. **Inspect data**: Print the batch that triggers NaN. Look for empty sequences, extreme token IDs, or unexpected padding patterns.
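
Step 2's rule can be applied mechanically to exported metrics. A sketch, assuming one dict per logged step with `loss` and `grad_norm` keys (a hypothetical export format and helper name): a non-finite loss points at the forward pass, and a non-finite `grad_norm` alongside a finite loss points at the backward pass.

```python
import math


def locate_first_nan(log_rows):
    """Return (step, suspected_stage) for the first non-finite row, or None."""
    for row in log_rows:
        if not math.isfinite(row["loss"]):
            # grad_norm was finite on every earlier step, or we would have stopped there.
            return row["step"], "forward"
        if not math.isfinite(row["grad_norm"]):
            return row["step"], "backward"
    return None


rows = [
    {"step": 10, "loss": 1.20, "grad_norm": 0.8},
    {"step": 11, "loss": 1.18, "grad_norm": float("inf")},
    {"step": 12, "loss": float("nan"), "grad_norm": float("nan")},
]
print(locate_first_nan(rows))  # → (11, 'backward')
```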

## OOM Debugging

Out-of-memory errors are the most common training failure. Use this systematic approach, from least to most disruptive:

### Step 1: Reduce Batch Size

The single highest-impact change. Activation memory (unlike model weights and optimizer state) scales roughly linearly with batch size.

```yaml
micro_batch_size: 1 # Start here
gradient_accumulation_steps: 16 # Increase to maintain effective batch size
```

For GRPO specifically, the logits tensor used to compute policy log-probs can be very large: it holds `batch_size * num_generations * seq_len * vocab_size` values, at 2 bytes each in bf16. For example, with `num_generations: 16`, `micro_batch_size: 8`, a 2048-token sequence, and a 151,936-token vocabulary, the logits tensor alone is:

```
8 * 16 * 2048 * 151936 * 2 bytes = 79,658,221,568 bytes ≈ 74.2 GiB (way too large)
```

Reduce `micro_batch_size` to 2–4 for GRPO.
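
The arithmetic above generalizes to a one-line estimate. A sketch that counts a single bf16 logits tensor (2 bytes per value) and shows how `micro_batch_size` moves it; the shape values repeat the example above and the helper name is hypothetical. Real peak usage can be higher, for instance if the logits are also upcast to fp32 for log-softmax.

```python
GIB = 1024 ** 3


def logits_bytes(micro_batch_size, num_generations, seq_len, vocab_size, bytes_per_value=2):
    """Size of one full logits tensor; bytes_per_value=2 for bf16, 4 for fp32."""
    return micro_batch_size * num_generations * seq_len * vocab_size * bytes_per_value


for mbs in (8, 4, 2):
    gib = logits_bytes(mbs, num_generations=16, seq_len=2048, vocab_size=151936) / GIB
    print(f"micro_batch_size={mbs}: {gib:.1f} GiB")
# → micro_batch_size=8: 74.2 GiB
# → micro_batch_size=4: 37.1 GiB
# → micro_batch_size=2: 18.5 GiB
```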

### Step 2: Enable Gradient Checkpointing

Trades compute for memory by recomputing activations during the backward pass instead of storing them.

```yaml
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false # Recommended default
```

**Warning: Reentrant Checkpointing Exceptions**

Some configurations require `use_reentrant: true`:

* DeepSpeed ZeRO-3 (non-reentrant causes `CheckpointError`)
* EBFT strided mode with `flex_attention`

### Step 3: Use Quantization

Load the base model in reduced precision (pick one of the options below):

```yaml
# 4-bit QLoRA
adapter: qlora
load_in_4bit: true

# 8-bit
load_in_8bit: true

# FP8 (saves ~50% model VRAM, same compute speed as bf16)
fp8: true
```

### Step 4: Reduce Sequence Length

```yaml
sequence_len: 1024 # Down from 2048 or 4096
```

For GRPO, also reduce `max_completion_length`. With standard attention, attention-score memory scales quadratically with sequence length.
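
A rough sketch of why sequence length hurts so much under standard attention. It assumes an eager implementation that materializes a `batch x heads x n x n` score matrix per layer in bf16; the 32-head shape and the helper name are hypothetical, and only the score matrix is counted. FlashAttention-style kernels avoid materializing this matrix.

```python
MIB = 1024 ** 2


def attention_scores_bytes(batch_size, num_heads, seq_len, bytes_per_value=2):
    """Bytes for one layer's full attention-score matrix (batch x heads x n x n)."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_value


# Hypothetical shape: micro-batch of 1, 32 heads, bf16 scores.
for n in (1024, 2048, 4096):
    mib = attention_scores_bytes(1, 32, n) / MIB
    print(f"seq_len={n}: {mib:.0f} MiB per layer")
# → seq_len=1024: 64 MiB per layer
# → seq_len=2048: 256 MiB per layer
# → seq_len=4096: 1024 MiB per layer
```

Each doubling of `seq_len` quadruples this term, which is why halving `sequence_len` often frees far more memory than halving the batch.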

### Step 5: Use Flash Attention

Reduces attention memory from O(n^2) to O(n):

```yaml
attn_implementation: flash_attention_2
```

### Step 6: Offload with DeepSpeed

For extreme cases, offload optimizer states or parameters to CPU:

```yaml
deepspeed: deepspeed_configs/zero3_bf16.json
```

Whether anything is actually offloaded is controlled by the `offload_optimizer` and `offload_param` entries in the `zero_optimization` section of the DeepSpeed JSON, so check that the file you reference enables them.

### Diagnosing the Specific Culprit

Use the `profiler_steps` config option to capture GPU memory snapshots:

```yaml
profiler_steps: [1, 2]
```

This generates PyTorch profiler traces you can inspect to see exactly which tensor allocation caused the OOM.

## Common Errors

| Error Message | Likely Cause | Fix |
| --- | --- | --- |
| `exitcode: -9` | System RAM exhaustion | Reduce dataset size, `dataset_num_proc`, or number of data workers |
| `exitcode: -7` (DeepSpeed) | DeepSpeed version issue | `pip install -U deepspeed` |
| `CUDA out of memory` | GPU VRAM exhaustion | Follow OOM debugging steps above |
| `RuntimeError: NCCL communicator was aborted` | GPU communication failure | See [NCCL docs](https://docs.axolotl.ai/docs/nccl.html); check `NCCL_DEBUG=INFO` output |
| `ValueError: Asking to pad but the tokenizer does not have a padding token` | Missing pad token | Add `special_tokens: { pad_token: "<\|endoftext\|>" }` to config |
| `'DummyOptim' object has no attribute 'step'` | DeepSpeed on single GPU | Remove `deepspeed:` section from config |
| `unable to load strategy X` then `None is not callable` | Reward module not importable | Run `cd experiments && python -c "import my_rewards"` to check |
| `generation_batch_size not divisible by num_generations` | `micro_batch_size` too small | Set `micro_batch_size >= num_generations` and make it divisible by `num_generations` |
| `'weight' must be 2-D` | FSDP1 flattened parameters | Use `fsdp_version: 2` or skip `unwrap_model` when FSDP is enabled |
| `CheckpointError` (tensor count mismatch) | Non-reentrant checkpointing + ZeRO-3 or `flex_attention` | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |
| `BFloat16` TypeError during weight sync | NumPy does not support bf16 | Fixed in axolotl’s `weight_serde.py` (auto bf16 to fp16 conversion) |
| `Content end boundary is before start boundary` | Chat template parsing issue | Check `eos_token` matches template; file a GitHub issue if persistent |
| `CAS service error` during data processing | HuggingFace XET issue | Set `export HF_HUB_DISABLE_XET=1` |
| Training hangs (multi-GPU) | FSDP + async prefetch deadlock | Set `async_prefetch: false` with FSDP |

## Profiling

### PyTorch Profiler

Axolotl supports PyTorch profiler integration via the config:

```yaml
profiler_steps: [1, 2, 3]
```

This captures profiler traces for the specified steps. View them in TensorBoard:

```shell
tensorboard --logdir output_dir/runs
```

Or open the `.json` trace file in `chrome://tracing`.

### CUDA Memory Snapshots

For detailed memory analysis, use PyTorch’s memory snapshot API. Add this to your training script or use it interactively:

```python
import torch

# Enable memory history tracking
torch.cuda.memory._record_memory_history()

# ... run your training step ...

# Save snapshot
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")

# Stop recording history
torch.cuda.memory._record_memory_history(enabled=None)
```
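
Visualize the snapshot by dragging the pickle file onto PyTorch’s memory visualizer at https://pytorch.org/memory_viz, or render an HTML timeline locally:

```shell
python -m torch.cuda._memory_viz trace_plot memory_snapshot.pickle -o memory_snapshot.html
```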

### Quick GPU Memory Check

During training, monitor GPU memory and utilization in a separate terminal:

```shell
watch -n 1 nvidia-smi
```

For programmatic access within axolotl, the logged metrics `memory/max_alloc` and `memory/max_reserved` come from `torch.cuda.max_memory_allocated()` and `torch.cuda.max_memory_reserved()`. Note these report PyTorch’s view of memory, which may differ from `nvidia-smi` (see [FAQ](https://docs.axolotl.ai/docs/faq.html)).

## W&B and Logging

### Enabling Logging

```yaml
wandb_project: my-project
wandb_entity: my-team # optional
wandb_run_id: run-123 # optional, for resuming
wandb_name: experiment-name # optional
logging_steps: 1 # log every step (recommended for RL)
```

### Debug Logging

For detailed axolotl-internal debug output:

```shell
AXOLOTL_LOG_LEVEL=DEBUG axolotl train config.yaml 2>&1 | tee /tmp/training.log
```

**Tip: Always Log to a File**

Pipe training output to a log file so you can inspect it after the run:

```shell
axolotl train config.yaml 2>&1 | tee /tmp/my_run.log
```

### What Axolotl Logs

**SFT metrics** (logged every `logging_steps`):

* `train/loss`, `eval/loss` – training and validation loss
* `train/grad_norm` – gradient L2 norm (before clipping)
* `train/learning_rate` – current learning rate
* `memory/max_alloc`, `memory/max_reserved` – peak GPU memory

**GRPO/RL metrics** (logged every step):

* `rewards/<name>/mean`, `rewards/<name>/std` – per-reward-function statistics
* `reward`, `reward_std` – aggregated reward across all reward functions
* `frac_reward_zero_std` – fraction of prompt groups where all completions got the same reward
* `completions/mean_length`, `completions/min_length`, `completions/max_length` – completion token lengths
* `completions/clipped_ratio` – fraction of completions that hit the max length
* `completions/mean_terminated_length`, `completions/min_terminated_length`, `completions/max_terminated_length` – lengths of naturally terminated completions
* `kl` – KL divergence between policy and reference
* `entropy` – policy entropy (measure of output diversity)
* `clip_ratio/region_mean`, `clip_ratio/low_mean`, `clip_ratio/high_mean` – PPO clipping statistics
* `sampling/sampling_logp_difference/mean`, `sampling/sampling_logp_difference/max` – log-probability difference between policy and sampling distribution
* `sampling/importance_sampling_ratio/min`, `sampling/importance_sampling_ratio/mean`, `sampling/importance_sampling_ratio/max` – IS ratio statistics for off-policy correction
* `num_tokens` – total tokens processed

### Reading W&B Charts

For a healthy GRPO run, expect to see:

1. **`reward/mean`**: Gradual upward trend. May start near 0 and reach 0.3–0.8 depending on task difficulty. Not monotonic – fluctuations are normal.
2. **`entropy`**: Gradual decrease from initial values (often 0.3–0.6) as the model becomes more confident. Should not collapse to near-zero.
3. **`grad_norm`**: Mostly in the 0.001–1.0 range. Occasional 0.0 values are fine (zero-advantage skip). Persistent values above 10.0 need investigation.
4. **`kl`**: Starts near 0 and grows slowly. If it shoots up rapidly, the policy is diverging from the reference.
5. **`completions/mean_length`**: Should reflect the task’s natural answer length. If it steadily increases to `max_completion_length`, the model may be reward-hacking by generating longer outputs.
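
The completion-length warning signs can be checked from logged values. A sketch with hypothetical inputs (per-step `completions/mean_length` and `completions/clipped_ratio` values) and a hypothetical helper name: the 0.8 clipped-ratio threshold comes from the GRPO table, while the 10-step window and the "within 90% of the cap" test for a steady climb are arbitrary choices.

```python
def length_warnings(mean_lengths, clipped_ratios, max_completion_length, window=10):
    """Hypothetical checker for completion-length red flags in a GRPO run."""
    warnings = []
    if clipped_ratios and clipped_ratios[-1] > 0.8:
        warnings.append("clipped_ratio > 0.8: most completions hit max_completion_length")
    recent = mean_lengths[-window:]
    never_drops = all(b >= a for a, b in zip(recent, recent[1:]))
    if len(recent) == window and never_drops and recent[-1] >= 0.9 * max_completion_length:
        warnings.append("mean length climbed steadily to the cap: possible reward hacking")
    return warnings


climbing = [200, 300, 400, 500, 600, 700, 800, 900, 950, 1000]
print(length_warnings(climbing, [0.1] * 9 + [0.85], max_completion_length=1024))
print(length_warnings([310, 290, 305, 298, 320, 300, 295, 310, 305, 300], [0.0] * 10, 1024))  # → []
```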