mirror of
https://github.com/wassname/ml_debug.git
synced 2026-06-27 01:00:14 +08:00
bbe3fe0985
refs/static_analysis.md: JAX-specific grep patterns (in-place mutation, print side effects, key reuse, numpy escape, cast behavior). refs/diagnostics.md: JAX equivalents table (NaN detection, gradcheck, disable_jit, debug.print, debug.breakpoint, checkify).
132 lines
4.9 KiB
Markdown
# 6.1 Static analysis: grep for silent bugs

Part of the [ML Debugging skill](../SKILL.md), section 6.1.

Run these searches on the codebase before anything else. Each catches a common bug that produces no error but wrong results.

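The searches in this section can also be scripted instead of run by hand. The sketch below is one possible way to do that with the standard library only; `PATTERNS`, `scan_text`, and `scan_tree` are hypothetical names, and the pattern table holds just three of the patterns from this section as an illustration.

```python
import re
from pathlib import Path

# Hypothetical subset of the grep patterns in this section; extend as needed.
PATTERNS = {
    "detach": r"\.detach\(\)",
    "no_grad": r"with torch\.no_grad",
    "item": r"\.item\(\)",
}

def scan_text(text, patterns=PATTERNS):
    """Return (line_number, pattern_name, stripped_line) for every regex hit."""
    hits = []
    for lineno, line in enumerate(text.splitlines(), start=1):
        for name, pattern in patterns.items():
            if re.search(pattern, line):
                hits.append((lineno, name, line.strip()))
    return hits

def scan_tree(root, patterns=PATTERNS):
    """Yield (path, line_number, pattern_name, line) for each .py file under root."""
    for path in sorted(Path(root).rglob("*.py")):
        text = path.read_text(errors="ignore")
        for lineno, name, line in scan_text(text, patterns):
            yield path, lineno, name, line

# Small in-memory demo so the sketch runs without a codebase.
sample = "loss = criterion(out, y)\nrunning = loss.detach().item()\n"
hits = scan_text(sample)
for hit in hits:
    print(hit)
```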
**Shape mismatches (silent broadcasting)**

```
# Grep patterns:
\.view\(|\.reshape\( # check dims match intent
unsqueeze\(|squeeze\( # dimension insertion/removal
\.expand\(|\.repeat\( # broadcasting
# Action: for every hit, trace the tensor shape backward. Add assert statements.
```

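The "add assert statements" action could look like the sketch below. `assert_shape` is a hypothetical helper, not part of PyTorch, and the dimension names are made up for the example.

```python
import torch

def assert_shape(tensor, expected, name="tensor"):
    """Raise AssertionError if the shape differs from expected; None matches any size."""
    actual = tuple(tensor.shape)
    ok = len(actual) == len(expected) and all(
        e is None or a == e for a, e in zip(actual, expected)
    )
    assert ok, f"{name}: expected shape {expected}, got {actual}"

batch, seq_len, heads, head_dim = 2, 5, 4, 8
x = torch.randn(batch, seq_len, heads * head_dim)
assert_shape(x, (batch, seq_len, heads * head_dim), "x")

# Split the feature dim into heads, then check the result before using it.
q = x.view(batch, seq_len, heads, head_dim).transpose(1, 2)
assert_shape(q, (batch, heads, seq_len, head_dim), "q")
```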
**Autograd breakers**

```
# Grep patterns:
\.detach\(\) # breaks gradient flow
\.data\b # bypasses autograd entirely
with torch\.no_grad # check this isn't wrapping training code
\.item\(\) # in a loss computation = broken
\.numpy\(\) # in forward pass = broken
# Action: every .detach() should have a comment explaining WHY grad is intentionally stopped.
```

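To make the `.detach()` failure mode concrete, here is a minimal sketch with a toy two-element parameter (the values are arbitrary). The same expression is built twice, once with the graph intact and once with the parameter detached.

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)
x = torch.tensor([3.0, 4.0])

# Intact graph: the loss depends on w, so backward fills w.grad.
loss = ((w * x) ** 2).sum()
loss.backward()
grad_ok = w.grad.clone()

# Broken graph: .detach() cuts w out, so the result no longer tracks gradients.
w.grad = None
detached = ((w.detach() * x) ** 2).sum()
print(detached.requires_grad)  # False: calling backward() here would raise
```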
**Missing train/eval mode**

```
# Grep patterns:
\.train\(\) # count occurrences
\.eval\(\) # should pair with .train()
# Action: verify .eval() before every val loop, .train() before every train loop.
# Dropout and batchnorm behave differently -- this silently degrades results.
```

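A quick way to see why the mode matters is to call the same model twice on the same input in each mode. This is a small sketch with a toy `nn.Sequential` model, chosen only for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(8, 4)

model.train()
train_a, train_b = model(x), model(x)    # a fresh dropout mask on each call

model.eval()
with torch.no_grad():
    eval_a, eval_b = model(x), model(x)  # dropout is a no-op in eval mode

print(torch.equal(train_a, train_b))  # outputs differ in train mode
print(torch.equal(eval_a, eval_b))    # outputs match in eval mode
```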
**In-place ops on tensors requiring grad**

```
# Grep patterns:
\+=|\-=|\*=|/= # in-place assignment on tensors
\.add_\(|\.mul_\(|\.zero_\( # in-place methods
\[.*\]\s*=[^=] # index assignment (excludes ==)
# Action: an in-place op on a leaf tensor with requires_grad=True raises a RuntimeError;
# on an intermediate tensor it can overwrite a value autograd saved for backward.
# Replace x += y with x = x + y.
```

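The sketch below shows the two in-place cases and the out-of-place fix on toy tensors. In current PyTorch both cases usually surface as a `RuntimeError` rather than a silently wrong gradient, but that is a property of the PyTorch version in use, so treat the example as an illustration rather than a guarantee.

```python
import torch

# Case 1: in-place op on a leaf tensor that requires grad.
w = torch.ones(3, requires_grad=True)
try:
    w += 1.0
    leaf_error = None
except RuntimeError as err:
    leaf_error = str(err)

# Case 2: in-place op on an intermediate tensor that backward still needs.
x = torch.ones(3, requires_grad=True)
y = x.exp()      # exp() saves its output to compute the gradient
y.add_(1.0)      # overwrites that saved output
try:
    y.sum().backward()
    backward_error = None
except RuntimeError as err:
    backward_error = str(err)

# Fix: the out-of-place version leaves the saved output intact.
x2 = torch.ones(3, requires_grad=True)
z = x2.exp()
z = z + 1.0
z.sum().backward()
```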
**Double softmax (softmax input to CrossEntropyLoss)**

```
# Grep patterns:
CrossEntropyLoss|cross_entropy # expects raw logits
softmax|log_softmax|\.softmax # if applied BEFORE CrossEntropyLoss = double softmax
# Action: CrossEntropyLoss = log_softmax + NLLLoss internally.
# If you softmax first, CE computes log_softmax(softmax(x)) -- the softmax
# compresses logits into (0,1), so log_softmax sees near-uniform inputs.
# Gradients vanish. Loss plateaus near ln(n_classes).
```

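The plateau can be worked out by hand. The sketch below reimplements softmax and cross-entropy in plain Python (it does not call PyTorch) and compares the loss for a very confident, correct prediction with and without the extra softmax. With the bug, the best reachable loss for n classes is ln(1 + (n - 1)/e), which stays well above zero.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, target):
    """Cross-entropy from raw logits: -log_softmax(logits)[target]."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum_exp - logits[target]

n_classes = 10
confident = [50.0] + [0.0] * (n_classes - 1)   # model is very sure of class 0

correct = cross_entropy(confident, target=0)           # raw logits in
double = cross_entropy(softmax(confident), target=0)   # probabilities in (the bug)
floor = math.log(1 + (n_classes - 1) / math.e)         # best loss reachable with the bug

print(f"correct={correct:.4f} double={double:.4f} floor={floor:.4f}")
```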
**Wrong optimizer step ordering**

```
# Grep patterns -- verify this exact order exists:
# 1. optimizer.zero_grad()
# 2. loss.backward()
# 3. [optional: clip_grad_norm_]
# 4. optimizer.step()
# 5. [optional: scheduler.step()]
# Common bugs: zero_grad after backward (kills grads), step before backward (stale grads),
# scheduler.step() in wrong loop: per-epoch schedulers (StepLR, CosineAnnealingLR)
# called per-batch = decays too fast. Per-step schedulers (OneCycleLR) called per-epoch = too slow.
```

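The ordering above, written out as a runnable loop. This is a sketch on a toy linear model with random data; the learning rate, `step_size`, and `gamma` values are arbitrary and only there to make the schedule visible.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
x, y = torch.randn(16, 3), torch.randn(16, 1)
before = model.weight.detach().clone()

for epoch in range(2):
    for xb, yb in ((x[:8], y[:8]), (x[8:], y[8:])):   # two batches per epoch
        optimizer.zero_grad()                          # 1. clear old grads
        loss = nn.functional.mse_loss(model(xb), yb)
        loss.backward()                                # 2. compute new grads
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # 3. optional
        optimizer.step()                               # 4. apply the update
    scheduler.step()                                   # 5. StepLR: once per epoch

final_lr = optimizer.param_groups[0]["lr"]
print(final_lr)
```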
**Broadcasting traps**

```python
# Diagnostic: print shapes at every binary operation between tensors of different ndim
# Shapes (3,) and (3,1) silently broadcast to (3,3) -- probably not intended.
# Shapes (B,1) and (B,N) broadcast fine but verify it's intentional.
import torch

a = torch.randn(3)
b = torch.randn(3, 1)
print((a + b).shape) # torch.Size([3, 3]) -- wanted (3,)?
```

**Wrong loss sign**

```
# Grep patterns:
maximize|ascent # gradient ascent when descent intended?
\-\s*loss # negating loss -- intentional (e.g., reward maximization)?
1\.0\s*-\s*|1\s*-\s* # 1 - metric as loss -- is the metric bounded [0,1]?
# Action: verify that minimizing the loss = improving the metric you care about.
```

**Frozen parameters not intended**

```
# Grep patterns:
requires_grad\s*=\s*False # intentional freeze?
\.freeze\(|\.requires_grad_ # parameter freezing
for.*param.*\.parameters # check nothing is skipped
# Diagnostic:
for name, p in model.named_parameters():
    if not p.requires_grad:
        print(f"FROZEN: {name}")
```

**Data leakage**

```
# Grep patterns:
\.fit_transform\( # on test data = leakage
train_test_split.*shuffle=True # for time series = leakage
train_test_split\( # shuffle=True is the default, so also check calls that omit it
# Action: fit on train only, transform on both. Use temporal split for time series.
```

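The "fit on train only, transform on both" rule and the temporal split can be sketched without any ML library. `fit_scaler` and `transform` below are hypothetical stand-ins for a scaler's fit and transform steps, and the series is made-up data.

```python
from statistics import mean, pstdev

def fit_scaler(values):
    """Learn standardization statistics from the training split only."""
    return mean(values), pstdev(values)

def transform(values, stats):
    """Apply previously fitted statistics; never refit here."""
    mu, sigma = stats
    return [(v - mu) / sigma for v in values]

series = [1.0, 2.0, 3.0, 4.0, 100.0, 120.0]   # time-ordered, hypothetical data

# Temporal split: the first 4 points train, the last 2 test. No shuffling.
split = 4
train, test = series[:split], series[split:]

stats = fit_scaler(train)                # fit on train only
train_scaled = transform(train, stats)
test_scaled = transform(test, stats)     # reuse the train statistics on test

print(stats, test_scaled)
```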
**Class imbalance**

```
# Grep patterns:
CrossEntropyLoss\(\) # no weight= argument? check if classes balanced
weight=.*class # existing balancing -- verify weights are correct
# Diagnostic: count labels per class (see diagnostics.md "Class imbalance check").
# 100:1 ratio with unweighted loss = model predicts majority class.
```

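Counting labels and deriving weights takes a few lines. This sketch is not the check from diagnostics.md; `class_weights` is a hypothetical helper that uses the common inverse-frequency heuristic n_samples / (n_classes * count), and the resulting values are the kind of numbers one might pass as the `weight=` argument after converting them to a tensor.

```python
from collections import Counter

def class_weights(labels):
    """Inverse-frequency weights: n_samples / (n_classes * count)."""
    counts = Counter(labels)
    n_samples, n_classes = len(labels), len(counts)
    return {cls: n_samples / (n_classes * c) for cls, c in sorted(counts.items())}

labels = [0] * 100 + [1] * 1          # hypothetical 100:1 imbalance
counts = Counter(labels)
weights = class_weights(labels)
ratio = max(counts.values()) / min(counts.values())

print(dict(counts), weights, ratio)
```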
---

## JAX-specific patterns

```
# Grep patterns for JAX codebases:
x\[.*\]\s*=\s*[^=] # in-place mutation inside jit (use .at[].set())
print\( # side effect at trace time only (use jax.debug.print)
\bif\b.*traced # TracerBoolConversionError risk
random\.\w+\(key\b # key reuse without prior split (identical samples)
jnp\.sum\(\[|jnp\.array\( # list inside jit = compilation explosion
\bnp\. # numpy ops escape the traced computation graph
\.astype\( # backend-dependent cast behavior (clamped, not wrapped)
```

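Two of these patterns, index assignment and key reuse, are easy to reproduce in isolation. The sketch below assumes JAX is installed; `set_first` is a hypothetical function name, and the seed and shapes are arbitrary.

```python
import jax
import jax.numpy as jnp

@jax.jit
def set_first(x):
    # x[0] = 1.0 would raise here: JAX arrays are immutable.
    # .at[0].set(...) returns an updated copy instead.
    return x.at[0].set(1.0)

updated = set_first(jnp.zeros(3))

key = jax.random.PRNGKey(0)

# Reusing the same key yields identical samples.
same_a = jax.random.normal(key, (3,))
same_b = jax.random.normal(key, (3,))

# Splitting first yields separate streams.
key_a, key_b = jax.random.split(key)
diff_a = jax.random.normal(key_a, (3,))
diff_b = jax.random.normal(key_b, (3,))

print(updated, bool(jnp.array_equal(same_a, same_b)))
```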