mirror of
https://github.com/wassname/ml-debug.git
synced 2026-06-27 14:45:29 +08:00
feat(ml_debug): add JAX grep patterns and diagnostic equivalents
refs/static_analysis.md: JAX-specific grep patterns (in-place mutation, print side effects, key reuse, numpy escape, cast behavior). refs/diagnostics.md: JAX equivalents table (NaN detection, gradcheck, disable_jit, debug.print, debug.breakpoint, checkify).
This commit is contained in:
@@ -169,3 +169,16 @@ for name, p in model.named_parameters():
|
||||
# Bad signs: all zeros, huge values (>100), std ~0 (collapsed), NaN.
|
||||
# After training: weights diverging to +/-inf = exploding. All same value = dead.
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## JAX diagnostic equivalents
|
||||
|
||||
| Diagnostic | PyTorch | JAX |
|
||||
|------------|---------|-----|
|
||||
| NaN detection | `torch.autograd.detect_anomaly()` | `jax.config.update("jax_debug_nans", True)` |
|
||||
| Gradient check | `torch.autograd.gradcheck(fn, inputs)` | `jax.test_util.check_grads(fn, args, order=2)` |
|
||||
| Eager debug (no compile) | N/A (already eager) | `jax.config.update("jax_disable_jit", True)` |
|
||||
| Print inside compiled | N/A | `jax.debug.print("{x}", x=x)` |
|
||||
| Breakpoint inside compiled | `pdb.set_trace()` | `jax.debug.breakpoint()` |
|
||||
| Runtime assertions inside compiled | `assert` | `jax.experimental.checkify` |
|
||||
|
||||
@@ -114,3 +114,18 @@ weight=.*class # existing balancing -- verify weights are correct
|
||||
# Diagnostic: count labels per class (see diagnostics.md "Class imbalance check").
|
||||
# 100:1 ratio with unweighted loss = model predicts majority class.
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## JAX-specific patterns
|
||||
|
||||
```
|
||||
# Grep patterns for JAX codebases:
|
||||
x\[.*\]\s*=\s*[^=] # in-place mutation inside jit (use .at[].set())
|
||||
print\( # side effect at trace time only (use jax.debug.print)
|
||||
\bif\b.*traced # TracerBoolConversionError risk
|
||||
random\.\w+\(key\b # key reuse without prior split (identical samples)
|
||||
jnp\.sum\(\[|jnp\.array\( # list inside jit = compilation explosion
|
||||
\bnp\. # numpy ops escape the traced computation graph
|
||||
\.astype\( # backend-dependent cast behavior (clamped, not wrapped)
|
||||
```
|
||||
|
||||
Reference in New Issue
Block a user