feat(ml_debug): add JAX grep patterns and diagnostic equivalents

refs/static_analysis.md: JAX-specific grep patterns (in-place mutation,
print side effects, key reuse, numpy escape, cast behavior).
refs/diagnostics.md: JAX equivalents table (NaN detection, gradcheck,
disable_jit, debug.print, debug.breakpoint, checkify).
This commit is contained in:
wassname
2026-03-06 14:10:39 +08:00
parent 7ac7aacac7
commit bbe3fe0985
2 changed files with 28 additions and 0 deletions
+13
View File
@@ -169,3 +169,16 @@ for name, p in model.named_parameters():
# Bad signs: all zeros, huge values (>100), std ~0 (collapsed), NaN. # Bad signs: all zeros, huge values (>100), std ~0 (collapsed), NaN.
# After training: weights diverging to +/-inf = exploding. All same value = dead. # After training: weights diverging to +/-inf = exploding. All same value = dead.
``` ```
---
## JAX diagnostic equivalents
| Diagnostic | PyTorch | JAX |
|------------|---------|-----|
| NaN detection | `torch.autograd.detect_anomaly()` | `jax.config.update("jax_debug_nans", True)` |
| Gradient check | `torch.autograd.gradcheck(fn, inputs)` | `jax.test_util.check_grads(fn, args, order=2)` |
| Eager debug (no compile) | N/A (already eager) | `jax.config.update("jax_disable_jit", True)` |
| Print inside compiled | N/A | `jax.debug.print("{x}", x=x)` |
| Breakpoint inside compiled | `pdb.set_trace()` | `jax.debug.breakpoint()` |
| Runtime assertions inside compiled | `assert` | `jax.experimental.checkify` |
+15
View File
@@ -114,3 +114,18 @@ weight=.*class # existing balancing -- verify weights are correct
# Diagnostic: count labels per class (see diagnostics.md "Class imbalance check"). # Diagnostic: count labels per class (see diagnostics.md "Class imbalance check").
# 100:1 ratio with unweighted loss = model predicts majority class. # 100:1 ratio with unweighted loss = model predicts majority class.
``` ```
---
## JAX-specific patterns
```
# Grep patterns for JAX codebases:
x\[.*\]\s*=\s*[^=] # in-place mutation inside jit (use .at[].set())
print\( # side effect at trace time only (use jax.debug.print)
\bif\b.*traced # TracerBoolConversionError risk
random\.\w+\(key\b # key reuse without prior split (identical samples)
jnp\.sum\(\[|jnp\.array\( # list inside jit = compilation explosion
\bnp\. # numpy ops escape the traced computation graph
\.astype\( # backend-dependent cast behavior (clamped, not wrapped)
```