diff --git a/docs/evidence/karpathy_nn_zero_to_hero_lec4_diagnostics.md b/docs/evidence/karpathy_nn_zero_to_hero_lec4_diagnostics.md new file mode 100644 index 0000000..5de7f03 --- /dev/null +++ b/docs/evidence/karpathy_nn_zero_to_hero_lec4_diagnostics.md @@ -0,0 +1,113 @@ +# nn-zero-to-hero Lecture 4: Activations, Gradients, BatchNorm + +**Source:** Andrej Karpathy, nn-zero-to-hero lecture series +**Notebook:** lectures/makemore/makemore_part3_bn.ipynb +**URL:** https://github.com/karpathy/nn-zero-to-hero +**Lecture description:** "We dive into some of the internals of MLPs with multiple layers and scrutinize the statistics of the forward pass activations, backward pass gradients, and some of the typical diagnostic tools and visualizations you'd want to use to understand the health of your deep network." + +--- + +## Incremental improvements documented (from notebook markdown) + +``` +original: train 2.1245 val 2.1682 +fix softmax wrong: train 2.07 val 2.13 (overconfident init) +fix tanh saturated: train 2.0356 val 2.1027 (init scale too large) +use kaiming init: train 2.0377 val 2.1070 (semi-principled) +add batch norm: train 2.0668 val 2.1048 (stable across random seeds) +``` + +Each row = one targeted fix. The ordering demonstrates the hierarchy: data/loss first, then init, then architecture. + +--- + +## Activation saturation check (tanh) + +```python +for i, layer in enumerate(layers[:-1]): + if isinstance(layer, Tanh): + t = layer.out + print('layer %d (%10s): mean %+.2f, std %.2f, saturated: %.2f%%' % + (i, layer.__class__.__name__, t.mean(), t.std(), + (t.abs() > 0.97).float().mean() * 100)) + hy, hx = torch.histogram(t, density=True) + plt.plot(hx[:-1].detach(), hy.detach()) +# Healthy: distributions roughly Gaussian, saturation <5%. +# Bad: bimodal at +/-1 = too saturated (weights too large at init, or missing BN). +# Bad: all near 0 = dead layer (weights too small or gain 0). +``` + +--- + +## Gradient distribution check (per-layer) + +```python +for i, layer in enumerate(layers[:-1]): + if isinstance(layer, Tanh): + t = layer.out.grad # requires retain_grad() in training loop + print('layer %d (%10s): mean %+f, std %e' % + (i, layer.__class__.__name__, t.mean(), t.std())) + hy, hx = torch.histogram(t, density=True) + plt.plot(hx[:-1].detach(), hy.detach()) +# Healthy: similar gradient std across layers (no vanishing/exploding gradient). +# Bad: gradient std shrinks toward earlier layers = vanishing gradient. +# Bad: gradient std explodes = need BN, gradient clipping, or better init. +``` + +--- + +## Grad:data ratio check (weight matrices) + +```python +for i, p in enumerate(parameters): + t = p.grad + if p.ndim == 2: + print('weight %10s | mean %+f | std %e | grad:data ratio %e' % + (tuple(p.shape), t.mean(), t.std(), t.std() / p.std())) + hy, hx = torch.histogram(t, density=True) + plt.plot(hx[:-1].detach(), hy.detach()) +# grad:data ratio ~ 1e-3 is healthy. +# Much higher: gradients dominate weights, learning rate too large. +# Much lower: weights barely moving, potentially dead layer. +``` + +--- + +## Update-to-data ratio tracker (training loop) + +```python +ud = [] + +# Inside training loop: +for p in parameters: + p.data += -lr * p.grad + +with torch.no_grad(): + ud.append([((lr * p.grad).std() / p.data.std()).log10().item() + for p in parameters]) + +# After training, plot: +plt.figure(figsize=(20, 4)) +legends = [] +for i, p in enumerate(parameters): + if p.ndim == 2: + plt.plot([ud[j][i] for j in range(len(ud))]) + legends.append('param %d' % i) +plt.plot([0, len(ud)], [-3, -3], 'k') # target ~1e-3 +plt.legend(legends) +# Each line should stay near -3. +# Rising above -3: LR too large, may diverge. +# Sinking below -3: LR too small, near-zero updates. +# Diverging between layers: need better initialization or BN. +``` + +--- + +## Key pedagogical insight from the notebook + +The notebook demonstrates by construction (not just assertion) that: +1. Saturated tanh at init → slow learning (gradient vanishes through tanh) +2. Kaiming init → ~same scale activations throughout depth +3. BatchNorm → robust to poor init; normalization forces healthy activation stats + +The incremental improvement log (above) makes this concrete: each targeted fix yields measurable improvement. This is the same pattern as the recipe blog post but with code and measured results. diff --git a/docs/evidence/nanochat_deepwiki_llm_pretraining_2026.md b/docs/evidence/nanochat_deepwiki_llm_pretraining_2026.md index d52652c..217f4cf 100644 --- a/docs/evidence/nanochat_deepwiki_llm_pretraining_2026.md +++ b/docs/evidence/nanochat_deepwiki_llm_pretraining_2026.md @@ -1,91 +1,219 @@ # nanochat: LLM Pretraining Engineering Notes -**Source:** deepwiki.com/karpathy/nanochat (AI-generated wiki from karpathy/nanochat repo) +**Sources:** +- deepwiki.com/karpathy/nanochat (sections 3, 12, 13) -- AI-generated wiki from source + LOG.md +- github.com/karpathy/nanochat/blob/main/dev/LOG.md -- primary experiment log **URLs:** https://deepwiki.com/karpathy/nanochat, https://github.com/karpathy/nanochat **Date accessed:** 2026-03 -**Context:** nanochat is Karpathy's 2026 open-source minimal LLM speedrun (GPT-2 level in ~2.5h on 8xH100, ~3500 lines, ~$48). -**Caveat:** The deepwiki page is AI-generated from source code; treat as secondary documentation, not direct quotes. +**Context:** nanochat is Karpathy's 2026 open-source minimal LLM speedrun (GPT-2 level in ~2.5h on 8xH100, ~3500 lines). The LOG.md documents 320+ HP sweeps from Jan-Mar 2026. +**Caveat:** deepwiki pages are AI-generated from source code; treat as secondary docs. LOG.md quotes are primary (verbatim from the experiment log). --- -## Design principle: explicit over implicit +## 1. Dataset >> Architecture (empirical) -> Explicit over implicit: No `torch.amp.autocast` magic; precision managed via `COMPUTE_DTYPE` global +From LOG.md (2026-03-04): +> "This is by far the single biggest improvement to nanochat's GPT-2 speedrun time, bringing it down from **2 hours 46 minutes to 2 hours 1 minute** — a 27% reduction." -Auto-detected at runtime: bfloat16 on SM 80+ (A100/H100), float32 on older GPUs. +The 27% came from one dataset swap (FineWeb-EDU 100B → ClimbMix 400B). The previous 5 architecture/dataset attempts all failed: +1. Vanilla FineWeb (CORE 0.2602 → 0.2241) +2. FinePDFs mixture (0.2602 → 0.2549) +3. Dolma3_mix-6T (failed) +4-5. Two more undocumented attempts. -**Debugging application:** Override globally for numerical stability debugging: -```bash -NANOCHAT_DTYPE=float32 python -m scripts.chat_cli -p "hello" +**Lesson:** If training is slow or CORE is low, swap datasets before tuning architecture. + +--- + +## 2. Scale-dependent HP sensitivity: tune at target scale + +From deepwiki section 12 (sourced from LOG.md sweeps): + +> "Fine-tuned d12 hyperparameters actively hurt d20 performance." + +- d12 → d20 HP transfer fails: improvement magnitude shrinks (~0.002 at d12 → ~0.0007 at d20) +- `x0_beta1` sweep at d20: flat plateau 0.90-0.96, **sharp cliff at 0.98** (catastrophic: +0.0033 bpb) +- "Add only changes that were validated at d20+" before production + +**Sweep methodology:** +1. Quick experiment at d12 (~5 min): directional signal +2. Validate at target scale d20 (~20 min) +3. If still promising, validate at production d24+ (~1-2 hours) + +--- + +## 3. Multi-axis validation: steps, FLOPs, wall-clock + +From LOG.md (throughout): +> "Improvements must show gains across multiple axes: per-step efficiency (loss vs. step), wall-clock efficiency (loss vs. time), and compute efficiency (loss vs. FLOPs)." + +**FP8 example (LOG.md 2026-02-02):** +- Microbenchmark: 1.38x speedup +- Full training: 1.17x tok/sec +- Capability-matched (accounting for precision loss): **~5% real gain** + +> "torch.compile is MANDATORY. Without it, FP8 is 4x slower due to unfused scaling ops." + +**MoE example (LOG.md 2026-02-19):** MFU dropped 46% → 35%; per-step improvement didn't compensate; net negative. + +--- + +## 4. Negative results: what doesn't work at GPT-2 scale + +**SwiGLU** (2026-02-05): Iso-FLOP swap, tested d12 and d24. Worse on step efficiency, wall clock, FLOPs. ReLU² remains superior. + +**Mixture of Experts** (2026-02-19): +- `torch._grouped_mm` dispatch overhead: MFU 46% → 35% +- Per-step improvement doesn't compensate throughput hit +- FP8 unsupported for grouped matmul (needs separate API + custom Triton kernels) +- Verdict: "MoE is not worth the trouble for nanochat right now." + +**Multi-Token Prediction:** +13GB memory, MFU −1%, no per-step improvement, wall-clock worse. + +**Batch size ramping:** Small gains observed but code complexity not justified. + +**Five data mixtures** all worse than FineWeb-EDU before ClimbMix (see §1). + +--- + +## 5. MFU monitoring: primary throughput health check + +> "In wandb, `train/mfu` (Model FLOPs Utilization) should be >40%" + +MFU <40% suggests: +- GPU memory underutilized (device batch size too small) +- I/O bottleneck (data loading slower than compute) +- Excessive distributed synchronization overhead + +MFU calculation: `(flops_per_token × batch_tokens_per_sec) / (gpu_peak_flops × n_gpus)` + +Normal range 40-60% on 8xH100 for transformer training. + +--- + +## 6. BOS alignment: loss improvement may be "fake" + +From deepwiki section 12: +> "The 'lower validation loss' from BOS-alignment is misleading—it's just fewer noisy tokens, not better learning." + +Best-fit packing (adopted) vs greedy-crop (baseline): +- Greedy-crop: 39.4% of tokens are crops (mid-document) +- Best-fit: 34.6% crops -- still significant + +Both ensure sequences start at document boundaries (BOS token). Sequences that start mid-document add confusing tokens and inflate validation loss. + +**Implication:** When comparing two training runs with different dataloaders, check if the loss comparison is apples-to-apples. + +--- + +## 7. Explicit dtype management > autocast + +From LOG.md (2026-03-04): +> "autocast is 'magic we don't control' — it silently decides which ops run in which precision via internal allowlists." + +Replaced autocast with: +```python +COMPUTE_DTYPE = torch.bfloat16 if sm >= 80 else torch.float32 # auto-detected +# Override: NANOCHAT_DTYPE=float32 python train.py ``` -Avoids hunting through scattered `with autocast():` blocks when debugging NaN/Inf. + +Custom `Linear` class casts weights to match input dtype: `F.linear(x, self.weight.to(dtype=x.dtype))`. + +**Debugging application:** Override `NANOCHAT_DTYPE=float32` globally to debug NaN/Inf without hunting `with autocast():` blocks. + +FA3 (Hopper kernels): doesn't support fp16/fp32 → automatic fallback to SDPA. --- -## Monitoring: MFU target +## 8. FP16 + distributed: inf detection must be synchronized -> When performance is unexpectedly low: Check `train/mfu` (Model FLOPs Utilization) should be >40% +From deepwiki section 12: +> "If any rank's gradient contains inf, **all ranks must clip to avoid divergence**." -MFU <40% suggests: GPU memory underutilized (batch size too small), I/O bottleneck (data loading slower than compute), or excessive distributed-training synchronization overhead. +Pattern: +```python +grad_norm = clip_grad_norm_(model.parameters(), 1.0) +dist.all_reduce(grad_norm, op=dist.ReduceOp.MAX) # "is any rank inf?" +if torch.isinf(grad_norm): + optimizer.zero_grad(); continue # skip step on ALL ranks +``` + +Single-GPU testing hides this bug. Always test distributed code multi-GPU. --- -## Data pipeline: BOS-aligned dataloader +## 9. Empirical scaling laws (from 320+ sweeps) -> BOS-aligned best-fit dataloader ensuring every sequence starts with document boundary +**Batch size** (sourced from Cerebras "Power Lines" paper): +``` +B_opt ∝ D^0.383 (D = target training tokens) +``` +Reference: d12 at B=2^19. 10× more tokens → only ~2.4× bigger batch (sublinear). -Sequences must start at document boundaries (BOS token), not mid-document. Prevents loss spikes from predicting the start of an unrelated document as if it were a continuation. +| Depth | Target Tokens | Auto Batch | +|-------|--------------|------------| +| d8 | 0.44B | 2^18 = 262K | +| d12-16| 0.7B-2.5B | 2^19 = 524K | +| d18-26| 3.4B-9.6B | 2^20 = 1.05M | + +**Weight decay** (empirically derived, LOG.md): + +| Depth | Width | Optimal WD | +|-------|-------|-----------| +| d8 | 512 | ~0.40 | +| d12 | 768 | ~0.22 | +| d16 | 1024 | ~0.10 | +| d20 | 1280 | ~0.08 | + +Power law fit: `WD ∝ 1/width²`. Scale from reference: `WD_target = WD_ref × (width_ref/width_target)²`. --- -## Systematic HP development: 320+ sweeps +## 10. Python GC overhead: disable after warmup -> The dev/LOG.md experiment log documents 320+ hyperparameter sweeps and design decisions made since January 2026. +From deepwiki section 3: +> "GC is disabled after step 1 to prevent 500ms overhead from cycle detection." -**Principled generalization criterion:** Changes must work across model depths (d8 to d50+), not just the target size. Improvements that only help at one scale are artifacts, not general algorithmic improvements. +500ms × 880 steps ≈ 7 minutes lost to GC on a 2.76h run (4.4% overhead). Disable safely after step 1 when allocation patterns stabilize. --- -## Four-axis improvement validation +## 11. Cautious weight decay + torch.compile gotcha -When implementing an optimization, validate across: -1. Loss per training step (convergence speed) -2. Loss per wall-clock time (helps despite potentially slower per-step?) -3. Loss per FLOP (better hardware utilization vs. better algorithm?) +From deepwiki section 12: +> "Must inline logic in optimizer step. Passing `weight_decay` as function argument triggers torch.compile recompilation on schedule changes." -Prevents optimizations that appear good on one metric but regress on others. +```python +# Good: read at step time from group dict +for group in param_groups: + wd = group["weight_decay"] # no recompile on schedule change + +# Bad: pass as argument (recompiles when wd changes) +def step(self, wd): # triggers recompile every step if wd schedule varies +``` --- -## Scaling laws (empirical, from 320+ sweeps) +## 12. Compute-optimal ratio: 10.5 (Kaplan-style counting) -- Batch size: `B ∝ D^0.383` where D = target training tokens (sublinear, not linear scaling) -- Learning rate: per-component scaling with `√(768/n_embd)` factors -- Weight decay: `WD ∝ 1/width²` +From LOG.md sweeps across parameter-counting methods: +- Kaplan-style (projections including lm_head, no embeddings): stable 10.5 ratio across scales +- Chinchilla-style (all params): varies 3.0-4.0 -Credence ~60-65%: stated as empirical, derivation not provided. +For speedrun: deliberately undertrain to ratio ~9.5 (saves ~2-3h) to hit GPT-2 CORE threshold. --- -## OOM debugging: reduce device batch, keep effective batch +## 13. FP8 summary -> Reducing device-batch-size from 32 to 16 triggers 2× gradient accumulation - -Gradient accumulation maintains effective batch size. OOM errors often solvable without changing the training recipe. +- Effective speedup at d24 scale: ~5% (capability-matched), not the microbenchmark 1.38x +- Memory saving: ~9GB activations stored as FP8 vs BF16 +- `torch.compile` mandatory: without it, FP8 is 4× slower +- Only works on Hopper (H100, SM 90+) +- During evaluation: **disable FP8** (use BF16/FP32) -- FP8 introduces ~5% accuracy variance --- -## FP8 caveat +## 14. Key gap this fills -FP8 only works on Hopper architecture (H100). Remove `--fp8` on A100 or older. - ---- - -## Key gap this fills - -The existing ml_debug skill sources (2017-2021) predate modern LLM pretraining at scale. nanochat is one of the few open-source codebases that publicly documents the empirical decisions behind training a transformer from scratch in 2026, including 320+ sweep results. It covers: -- Loss spike prevention (BOS alignment) -- Distributed training OOM (gradient accumulation) -- Precision management (explicit dtype, FP8 caveat) -- MFU monitoring -- Cross-scale generalization testing +The existing ml_debug skill sources (2017-2021) predate modern LLM pretraining at scale. nanochat is one of the few open-source codebases that publicly documents the empirical decisions behind training a transformer from scratch in 2026, with quantified results: 320+ sweeps, negative results, scaling laws, and specific failure modes. diff --git a/refs/diagnostics.md b/refs/diagnostics.md index 4aece81..77c01b2 100644 --- a/refs/diagnostics.md +++ b/refs/diagnostics.md @@ -159,7 +159,7 @@ for conf, pred, true, idx in errors[:10]: # Inspect the actual inputs for these indices. Pattern = systematic bug. ``` -**Update-to-data ratio check** [Karpathy nn-zero-to-hero Lec 4] +**Update-to-data ratio check** [Karpathy nn-zero-to-hero Lec 4; evidence: karpathy_nn_zero_to_hero_lec4_diagnostics.md] ```python # Track during training: how large are updates relative to parameter magnitudes? # Target: ~1e-3 (log10 ~ -3). Much higher = LR too large. Much lower = LR too small.