feat(ml_debug): expand nanochat evidence, add lec4 diagnostics file

nanochat_deepwiki_llm_pretraining_2026.md rewritten with content from
dev/LOG.md and deepwiki sections 3/12/13:
- 14 labelled findings with direct quotes and empirical numbers
- Dataset >> architecture (27% gain, 5 failed attempts before ClimbMix)
- Scale-dependent HP sensitivity (d12 HPs hurt d20)
- Multi-axis validation (steps/wall-clock/FLOPs)
- Negative results: MoE/SwiGLU/MTP all failed at this scale
- MFU monitoring, batch size Bopt∝D^0.383, WD∝1/width² tables
- FP8 reality: 1.38x micro → 1.17x full → 5% capability-matched
- Python GC 500ms overhead, torch.compile recompile gotcha

karpathy_nn_zero_to_hero_lec4_diagnostics.md: new evidence file
- Activation saturation check (tanh >0.97)
- Gradient distribution check per-layer
- Grad:data ratio (target ~1e-3)
- Update-to-data ratio tracker with full plotting code
- Incremental improvement log from notebook
This commit is contained in:
wassname
2026-03-10 05:38:33 +08:00
parent ced4edc200
commit c9c53f8e7f
3 changed files with 289 additions and 48 deletions
@@ -0,0 +1,113 @@
# nn-zero-to-hero Lecture 4: Activations, Gradients, BatchNorm
**Source:** Andrej Karpathy, nn-zero-to-hero lecture series
**Notebook:** lectures/makemore/makemore_part3_bn.ipynb
**URL:** https://github.com/karpathy/nn-zero-to-hero
**Lecture description:** "We dive into some of the internals of MLPs with multiple layers and scrutinize the statistics of the forward pass activations, backward pass gradients, and some of the typical diagnostic tools and visualizations you'd want to use to understand the health of your deep network."
---
## Incremental improvements documented (from notebook markdown)
```
original: train 2.1245 val 2.1682
fix softmax wrong: train 2.07 val 2.13 (overconfident init)
fix tanh saturated: train 2.0356 val 2.1027 (init scale too large)
use kaiming init: train 2.0377 val 2.1070 (semi-principled)
add batch norm: train 2.0668 val 2.1048 (stable across random seeds)
```
Each row = one targeted fix. The ordering demonstrates the hierarchy: data/loss first, then init, then architecture.
---
## Activation saturation check (tanh)
```python
for i, layer in enumerate(layers[:-1]):
if isinstance(layer, Tanh):
t = layer.out
print('layer %d (%10s): mean %+.2f, std %.2f, saturated: %.2f%%' %
(i, layer.__class__.__name__, t.mean(), t.std(),
(t.abs() > 0.97).float().mean() * 100))
hy, hx = torch.histogram(t, density=True)
plt.plot(hx[:-1].detach(), hy.detach())
# Healthy: distributions roughly Gaussian, saturation <5%.
# Bad: bimodal at +/-1 = too saturated (weights too large at init, or missing BN).
# Bad: all near 0 = dead layer (weights too small or gain 0).
```
---
## Gradient distribution check (per-layer)
```python
for i, layer in enumerate(layers[:-1]):
if isinstance(layer, Tanh):
t = layer.out.grad # requires retain_grad() in training loop
print('layer %d (%10s): mean %+f, std %e' %
(i, layer.__class__.__name__, t.mean(), t.std()))
hy, hx = torch.histogram(t, density=True)
plt.plot(hx[:-1].detach(), hy.detach())
# Healthy: similar gradient std across layers (no vanishing/exploding gradient).
# Bad: gradient std shrinks toward earlier layers = vanishing gradient.
# Bad: gradient std explodes = need BN, gradient clipping, or better init.
```
---
## Grad:data ratio check (weight matrices)
```python
for i, p in enumerate(parameters):
t = p.grad
if p.ndim == 2:
print('weight %10s | mean %+f | std %e | grad:data ratio %e' %
(tuple(p.shape), t.mean(), t.std(), t.std() / p.std()))
hy, hx = torch.histogram(t, density=True)
plt.plot(hx[:-1].detach(), hy.detach())
# grad:data ratio ~ 1e-3 is healthy.
# Much higher: gradients dominate weights, learning rate too large.
# Much lower: weights barely moving, potentially dead layer.
```
---
## Update-to-data ratio tracker (training loop)
```python
ud = []
# Inside training loop:
for p in parameters:
p.data += -lr * p.grad
with torch.no_grad():
ud.append([((lr * p.grad).std() / p.data.std()).log10().item()
for p in parameters])
# After training, plot:
plt.figure(figsize=(20, 4))
legends = []
for i, p in enumerate(parameters):
if p.ndim == 2:
plt.plot([ud[j][i] for j in range(len(ud))])
legends.append('param %d' % i)
plt.plot([0, len(ud)], [-3, -3], 'k') # target ~1e-3
plt.legend(legends)
# Each line should stay near -3.
# Rising above -3: LR too large, may diverge.
# Sinking below -3: LR too small, near-zero updates.
# Diverging between layers: need better initialization or BN.
```
---
## Key pedagogical insight from the notebook
The notebook demonstrates by construction (not just assertion) that:
1. Saturated tanh at init → slow learning (gradient vanishes through tanh)
2. Kaiming init → ~same scale activations throughout depth
3. BatchNorm → robust to poor init; normalization forces healthy activation stats
The incremental improvement log (above) makes this concrete: each targeted fix yields measurable improvement. This is the same pattern as the recipe blog post but with code and measured results.
@@ -1,91 +1,219 @@
# nanochat: LLM Pretraining Engineering Notes
**Source:** deepwiki.com/karpathy/nanochat (AI-generated wiki from karpathy/nanochat repo)
**Sources:**
- deepwiki.com/karpathy/nanochat (sections 3, 12, 13) -- AI-generated wiki from source + LOG.md
- github.com/karpathy/nanochat/blob/main/dev/LOG.md -- primary experiment log
**URLs:** https://deepwiki.com/karpathy/nanochat, https://github.com/karpathy/nanochat
**Date accessed:** 2026-03
**Context:** nanochat is Karpathy's 2026 open-source minimal LLM speedrun (GPT-2 level in ~2.5h on 8xH100, ~3500 lines, ~$48).
**Caveat:** The deepwiki page is AI-generated from source code; treat as secondary documentation, not direct quotes.
**Context:** nanochat is Karpathy's 2026 open-source minimal LLM speedrun (GPT-2 level in ~2.5h on 8xH100, ~3500 lines). The LOG.md documents 320+ HP sweeps from Jan-Mar 2026.
**Caveat:** deepwiki pages are AI-generated from source code; treat as secondary docs. LOG.md quotes are primary (verbatim from the experiment log).
---
## Design principle: explicit over implicit
## 1. Dataset >> Architecture (empirical)
> Explicit over implicit: No `torch.amp.autocast` magic; precision managed via `COMPUTE_DTYPE` global
From LOG.md (2026-03-04):
> "This is by far the single biggest improvement to nanochat's GPT-2 speedrun time, bringing it down from **2 hours 46 minutes to 2 hours 1 minute** — a 27% reduction."
Auto-detected at runtime: bfloat16 on SM 80+ (A100/H100), float32 on older GPUs.
The 27% came from one dataset swap (FineWeb-EDU 100B → ClimbMix 400B). The previous 5 architecture/dataset attempts all failed:
1. Vanilla FineWeb (CORE 0.2602 → 0.2241)
2. FinePDFs mixture (0.2602 → 0.2549)
3. Dolma3_mix-6T (failed)
4-5. Two more undocumented attempts.
**Debugging application:** Override globally for numerical stability debugging:
```bash
NANOCHAT_DTYPE=float32 python -m scripts.chat_cli -p "hello"
**Lesson:** If training is slow or CORE is low, swap datasets before tuning architecture.
---
## 2. Scale-dependent HP sensitivity: tune at target scale
From deepwiki section 12 (sourced from LOG.md sweeps):
> "Fine-tuned d12 hyperparameters actively hurt d20 performance."
- d12 → d20 HP transfer fails: improvement magnitude shrinks (~0.002 at d12 → ~0.0007 at d20)
- `x0_beta1` sweep at d20: flat plateau 0.90-0.96, **sharp cliff at 0.98** (catastrophic: +0.0033 bpb)
- "Add only changes that were validated at d20+" before production
**Sweep methodology:**
1. Quick experiment at d12 (~5 min): directional signal
2. Validate at target scale d20 (~20 min)
3. If still promising, validate at production d24+ (~1-2 hours)
---
## 3. Multi-axis validation: steps, FLOPs, wall-clock
From LOG.md (throughout):
> "Improvements must show gains across multiple axes: per-step efficiency (loss vs. step), wall-clock efficiency (loss vs. time), and compute efficiency (loss vs. FLOPs)."
**FP8 example (LOG.md 2026-02-02):**
- Microbenchmark: 1.38x speedup
- Full training: 1.17x tok/sec
- Capability-matched (accounting for precision loss): **~5% real gain**
> "torch.compile is MANDATORY. Without it, FP8 is 4x slower due to unfused scaling ops."
**MoE example (LOG.md 2026-02-19):** MFU dropped 46% → 35%; per-step improvement didn't compensate; net negative.
---
## 4. Negative results: what doesn't work at GPT-2 scale
**SwiGLU** (2026-02-05): Iso-FLOP swap, tested d12 and d24. Worse on step efficiency, wall clock, FLOPs. ReLU² remains superior.
**Mixture of Experts** (2026-02-19):
- `torch._grouped_mm` dispatch overhead: MFU 46% → 35%
- Per-step improvement doesn't compensate throughput hit
- FP8 unsupported for grouped matmul (needs separate API + custom Triton kernels)
- Verdict: "MoE is not worth the trouble for nanochat right now."
**Multi-Token Prediction:** +13GB memory, MFU 1%, no per-step improvement, wall-clock worse.
**Batch size ramping:** Small gains observed but code complexity not justified.
**Five data mixtures** all worse than FineWeb-EDU before ClimbMix (see §1).
---
## 5. MFU monitoring: primary throughput health check
> "In wandb, `train/mfu` (Model FLOPs Utilization) should be >40%"
MFU <40% suggests:
- GPU memory underutilized (device batch size too small)
- I/O bottleneck (data loading slower than compute)
- Excessive distributed synchronization overhead
MFU calculation: `(flops_per_token × batch_tokens_per_sec) / (gpu_peak_flops × n_gpus)`
Normal range 40-60% on 8xH100 for transformer training.
---
## 6. BOS alignment: loss improvement may be "fake"
From deepwiki section 12:
> "The 'lower validation loss' from BOS-alignment is misleading—it's just fewer noisy tokens, not better learning."
Best-fit packing (adopted) vs greedy-crop (baseline):
- Greedy-crop: 39.4% of tokens are crops (mid-document)
- Best-fit: 34.6% crops -- still significant
Both ensure sequences start at document boundaries (BOS token). Sequences that start mid-document add confusing tokens and inflate validation loss.
**Implication:** When comparing two training runs with different dataloaders, check if the loss comparison is apples-to-apples.
---
## 7. Explicit dtype management > autocast
From LOG.md (2026-03-04):
> "autocast is 'magic we don't control' — it silently decides which ops run in which precision via internal allowlists."
Replaced autocast with:
```python
COMPUTE_DTYPE = torch.bfloat16 if sm >= 80 else torch.float32 # auto-detected
# Override: NANOCHAT_DTYPE=float32 python train.py
```
Avoids hunting through scattered `with autocast():` blocks when debugging NaN/Inf.
Custom `Linear` class casts weights to match input dtype: `F.linear(x, self.weight.to(dtype=x.dtype))`.
**Debugging application:** Override `NANOCHAT_DTYPE=float32` globally to debug NaN/Inf without hunting `with autocast():` blocks.
FA3 (Hopper kernels): doesn't support fp16/fp32 → automatic fallback to SDPA.
---
## Monitoring: MFU target
## 8. FP16 + distributed: inf detection must be synchronized
> When performance is unexpectedly low: Check `train/mfu` (Model FLOPs Utilization) should be >40%
From deepwiki section 12:
> "If any rank's gradient contains inf, **all ranks must clip to avoid divergence**."
MFU <40% suggests: GPU memory underutilized (batch size too small), I/O bottleneck (data loading slower than compute), or excessive distributed-training synchronization overhead.
Pattern:
```python
grad_norm = clip_grad_norm_(model.parameters(), 1.0)
dist.all_reduce(grad_norm, op=dist.ReduceOp.MAX) # "is any rank inf?"
if torch.isinf(grad_norm):
optimizer.zero_grad(); continue # skip step on ALL ranks
```
Single-GPU testing hides this bug. Always test distributed code multi-GPU.
---
## Data pipeline: BOS-aligned dataloader
## 9. Empirical scaling laws (from 320+ sweeps)
> BOS-aligned best-fit dataloader ensuring every sequence starts with document boundary
**Batch size** (sourced from Cerebras "Power Lines" paper):
```
B_opt ∝ D^0.383 (D = target training tokens)
```
Reference: d12 at B=2^19. 10× more tokens → only ~2.4× bigger batch (sublinear).
Sequences must start at document boundaries (BOS token), not mid-document. Prevents loss spikes from predicting the start of an unrelated document as if it were a continuation.
| Depth | Target Tokens | Auto Batch |
|-------|--------------|------------|
| d8 | 0.44B | 2^18 = 262K |
| d12-16| 0.7B-2.5B | 2^19 = 524K |
| d18-26| 3.4B-9.6B | 2^20 = 1.05M |
**Weight decay** (empirically derived, LOG.md):
| Depth | Width | Optimal WD |
|-------|-------|-----------|
| d8 | 512 | ~0.40 |
| d12 | 768 | ~0.22 |
| d16 | 1024 | ~0.10 |
| d20 | 1280 | ~0.08 |
Power law fit: `WD ∝ 1/width²`. Scale from reference: `WD_target = WD_ref × (width_ref/width_target)²`.
---
## Systematic HP development: 320+ sweeps
## 10. Python GC overhead: disable after warmup
> The dev/LOG.md experiment log documents 320+ hyperparameter sweeps and design decisions made since January 2026.
From deepwiki section 3:
> "GC is disabled after step 1 to prevent 500ms overhead from cycle detection."
**Principled generalization criterion:** Changes must work across model depths (d8 to d50+), not just the target size. Improvements that only help at one scale are artifacts, not general algorithmic improvements.
500ms × 880 steps ≈ 7 minutes lost to GC on a 2.76h run (4.4% overhead). Disable safely after step 1 when allocation patterns stabilize.
---
## Four-axis improvement validation
## 11. Cautious weight decay + torch.compile gotcha
When implementing an optimization, validate across:
1. Loss per training step (convergence speed)
2. Loss per wall-clock time (helps despite potentially slower per-step?)
3. Loss per FLOP (better hardware utilization vs. better algorithm?)
From deepwiki section 12:
> "Must inline logic in optimizer step. Passing `weight_decay` as function argument triggers torch.compile recompilation on schedule changes."
Prevents optimizations that appear good on one metric but regress on others.
```python
# Good: read at step time from group dict
for group in param_groups:
wd = group["weight_decay"] # no recompile on schedule change
# Bad: pass as argument (recompiles when wd changes)
def step(self, wd): # triggers recompile every step if wd schedule varies
```
---
## Scaling laws (empirical, from 320+ sweeps)
## 12. Compute-optimal ratio: 10.5 (Kaplan-style counting)
- Batch size: `B ∝ D^0.383` where D = target training tokens (sublinear, not linear scaling)
- Learning rate: per-component scaling with `√(768/n_embd)` factors
- Weight decay: `WD ∝ 1/width²`
From LOG.md sweeps across parameter-counting methods:
- Kaplan-style (projections including lm_head, no embeddings): stable 10.5 ratio across scales
- Chinchilla-style (all params): varies 3.0-4.0
Credence ~60-65%: stated as empirical, derivation not provided.
For speedrun: deliberately undertrain to ratio ~9.5 (saves ~2-3h) to hit GPT-2 CORE threshold.
---
## OOM debugging: reduce device batch, keep effective batch
## 13. FP8 summary
> Reducing device-batch-size from 32 to 16 triggers 2× gradient accumulation
Gradient accumulation maintains effective batch size. OOM errors often solvable without changing the training recipe.
- Effective speedup at d24 scale: ~5% (capability-matched), not the microbenchmark 1.38x
- Memory saving: ~9GB activations stored as FP8 vs BF16
- `torch.compile` mandatory: without it, FP8 is 4× slower
- Only works on Hopper (H100, SM 90+)
- During evaluation: **disable FP8** (use BF16/FP32) -- FP8 introduces ~5% accuracy variance
---
## FP8 caveat
## 14. Key gap this fills
FP8 only works on Hopper architecture (H100). Remove `--fp8` on A100 or older.
---
## Key gap this fills
The existing ml_debug skill sources (2017-2021) predate modern LLM pretraining at scale. nanochat is one of the few open-source codebases that publicly documents the empirical decisions behind training a transformer from scratch in 2026, including 320+ sweep results. It covers:
- Loss spike prevention (BOS alignment)
- Distributed training OOM (gradient accumulation)
- Precision management (explicit dtype, FP8 caveat)
- MFU monitoring
- Cross-scale generalization testing
The existing ml_debug skill sources (2017-2021) predate modern LLM pretraining at scale. nanochat is one of the few open-source codebases that publicly documents the empirical decisions behind training a transformer from scratch in 2026, with quantified results: 320+ sweeps, negative results, scaling laws, and specific failure modes.