diff --git a/refs/diagnostics.md b/refs/diagnostics.md index 0e084c4..b259066 100644 --- a/refs/diagnostics.md +++ b/refs/diagnostics.md @@ -169,3 +169,16 @@ for name, p in model.named_parameters(): # Bad signs: all zeros, huge values (>100), std ~0 (collapsed), NaN. # After training: weights diverging to +/-inf = exploding. All same value = dead. ``` + +--- + +## JAX diagnostic equivalents + +| Diagnostic | PyTorch | JAX | +|------------|---------|-----| +| NaN detection | `torch.autograd.detect_anomaly()` | `jax.config.update("jax_debug_nans", True)` | +| Gradient check | `torch.autograd.gradcheck(fn, inputs)` | `jax.test_util.check_grads(fn, args, order=2)` | +| Eager debug (no compile) | N/A (already eager) | `jax.config.update("jax_disable_jit", True)` | +| Print inside compiled | N/A | `jax.debug.print("{x}", x=x)` | +| Breakpoint inside compiled | `pdb.set_trace()` | `jax.debug.breakpoint()` | +| Runtime assertions inside compiled | `assert` | `jax.experimental.checkify` | diff --git a/refs/static_analysis.md b/refs/static_analysis.md index a079e6a..7d5b8c6 100644 --- a/refs/static_analysis.md +++ b/refs/static_analysis.md @@ -114,3 +114,18 @@ weight=.*class # existing balancing -- verify weights are correct # Diagnostic: count labels per class (see diagnostics.md "Class imbalance check"). # 100:1 ratio with unweighted loss = model predicts majority class. ``` + +--- + +## JAX-specific patterns + +``` +# Grep patterns for JAX codebases: +x\[.*\]\s*=\s*[^=] # in-place mutation inside jit (use .at[].set()) +print\( # side effect at trace time only (use jax.debug.print) +\bif\b.*traced # TracerBoolConversionError risk +random\.\w+\(key\b # key reuse without prior split (identical samples) +jnp\.sum\(\[|jnp\.array\( # list inside jit = compilation explosion +\bnp\. # numpy ops escape the traced computation graph +\.astype\( # backend-dependent cast behavior (clamped, not wrapped) +```