mirror of
https://github.com/wassname/ml-debug.git
synced 2026-06-27 17:49:08 +08:00
feat(ml_debug): add JAX grep patterns and diagnostic equivalents
refs/static_analysis.md: JAX-specific grep patterns (in-place mutation, print side effects, key reuse, numpy escape, cast behavior). refs/diagnostics.md: JAX equivalents table (NaN detection, gradcheck, disable_jit, debug.print, debug.breakpoint, checkify).
This commit is contained in:
@@ -169,3 +169,16 @@ for name, p in model.named_parameters():
|
|||||||
# Bad signs: all zeros, huge values (>100), std ~0 (collapsed), NaN.
|
# Bad signs: all zeros, huge values (>100), std ~0 (collapsed), NaN.
|
||||||
# After training: weights diverging to +/-inf = exploding. All same value = dead.
|
# After training: weights diverging to +/-inf = exploding. All same value = dead.
|
||||||
```
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## JAX diagnostic equivalents
|
||||||
|
|
||||||
|
| Diagnostic | PyTorch | JAX |
|
||||||
|
|------------|---------|-----|
|
||||||
|
| NaN detection | `torch.autograd.detect_anomaly()` | `jax.config.update("jax_debug_nans", True)` |
|
||||||
|
| Gradient check | `torch.autograd.gradcheck(fn, inputs)` | `jax.test_util.check_grads(fn, args, order=2)` |
|
||||||
|
| Eager debug (no compile) | N/A (already eager) | `jax.config.update("jax_disable_jit", True)` |
|
||||||
|
| Print inside compiled | N/A | `jax.debug.print("{x}", x=x)` |
|
||||||
|
| Breakpoint inside compiled | `pdb.set_trace()` | `jax.debug.breakpoint()` |
|
||||||
|
| Runtime assertions inside compiled | `assert` | `jax.experimental.checkify` |
|
||||||
|
|||||||
@@ -114,3 +114,18 @@ weight=.*class # existing balancing -- verify weights are correct
|
|||||||
# Diagnostic: count labels per class (see diagnostics.md "Class imbalance check").
|
# Diagnostic: count labels per class (see diagnostics.md "Class imbalance check").
|
||||||
# 100:1 ratio with unweighted loss = model predicts majority class.
|
# 100:1 ratio with unweighted loss = model predicts majority class.
|
||||||
```
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## JAX-specific patterns
|
||||||
|
|
||||||
|
```
|
||||||
|
# Grep patterns for JAX codebases:
|
||||||
|
x\[.*\]\s*=\s*[^=] # in-place mutation inside jit (use .at[].set())
|
||||||
|
print\( # side effect at trace time only (use jax.debug.print)
|
||||||
|
\bif\b.*traced # TracerBoolConversionError risk
|
||||||
|
random\.\w+\(key\b # key reuse without prior split (identical samples)
|
||||||
|
jnp\.sum\(\[|jnp\.array\( # list inside jit = compilation explosion
|
||||||
|
\bnp\. # numpy ops escape the traced computation graph
|
||||||
|
\.astype\( # backend-dependent cast behavior (clamped, not wrapped)
|
||||||
|
```
|
||||||
|
|||||||
Reference in New Issue
Block a user