feat(ml_debug): add JAX grep patterns and diagnostic equivalents

refs/static_analysis.md: JAX-specific grep patterns (in-place mutation,
print side effects, key reuse, numpy escape, cast behavior).
refs/diagnostics.md: JAX equivalents table (NaN detection, gradcheck,
disable_jit, debug.print, debug.breakpoint, checkify).
This commit is contained in:
wassname
2026-03-06 14:10:39 +08:00
parent 7ac7aacac7
commit bbe3fe0985
2 changed files with 28 additions and 0 deletions
+15
View File
@@ -114,3 +114,18 @@ weight=.*class # existing balancing -- verify weights are correct
# Diagnostic: count labels per class (see diagnostics.md "Class imbalance check").
# 100:1 ratio with unweighted loss = model predicts majority class.
```
---
## JAX-specific patterns
```
# Grep patterns for JAX codebases:
x\[.*\]\s*=\s*[^=] # in-place mutation inside jit (use .at[].set())
print\( # side effect at trace time only (use jax.debug.print)
\bif\b.*traced # TracerBoolConversionError risk
random\.\w+\(key\b # key reuse without prior split (identical samples)
jnp\.sum\(\[|jnp\.array\( # list inside jit = compilation explosion
\bnp\. # numpy ops escape the traced computation graph
\.astype\( # backend-dependent cast behavior (clamped, not wrapped)
```