wassname 8ee980d62f diagnostics: add NaN-poisoning leakage tracer + Karpathy backprop-to-input check; README citation
NaN poisoning: inject NaN where info must not come from (future/test/labels), run the real pipeline, assert past outputs stay finite. Documents false negatives (pandas skipna, nanmean) and false positives (softmax rows, batch stats). Backprop-to-input is its gradient dual for inside the model; quote already frozen in docs/evidence/karpathy_recipe_training_nn_2019.md.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-11 10:18:51 +08:00


6.2 Diagnostic code snippets

Part of the ML Debugging skill, section 6.2.

Copy-paste these. Each tests one thing.

Data pipeline sanity check

import torch

batch = next(iter(train_loader))
for k, v in (batch.items() if isinstance(batch, dict) else enumerate(batch)):
    if isinstance(v, torch.Tensor):
        vf = v.float()  # float copy so int labels/ids and bool masks also get stats
        print(f"{k}: shape={tuple(v.shape)}, dtype={v.dtype}, "
              f"range=[{vf.min():.3f}, {vf.max():.3f}], "
              f"mean={vf.mean():.3f}, std={vf.std():.3f}, "
              f"nan={vf.isnan().sum().item()}, inf={vf.isinf().sum().item()}")
    else:
        print(f"{k}: type={type(v)}, len={len(v) if hasattr(v, '__len__') else 'scalar'}")
# Check: inputs ~mean 0, std 1? Labels in expected range? No NaN/Inf? Shapes match model?

Init loss check

model.eval()
with torch.no_grad():
    batch = next(iter(train_loader))
    out = model(batch['input'])  # adapt to your interface
    loss = loss_fn(out, batch['target'])
    print(f"Init loss: {loss.item():.4f}")

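# Expected init loss (untrained model, ~uniform / ~zero predictions):
# - CrossEntropy, C classes:  -ln(1/C) = ln(C)
#     C=2: 0.693, C=10: 2.303, C=100: 4.605, C=1000: 6.908
# - Binary CrossEntropy:      -ln(0.5) = 0.693
# - MSE (targets ~N(0,1)):    ~1.0 if init outputs ~0; in general ~mean(targets**2) = var + mean**2
# - L1 (targets ~N(0,1)):     ~sqrt(2/pi) = 0.798 if init outputs ~0
#
# If init loss << expected: suspicious. An untrained model can't exploit much, so check the
#   loss reduction/masking, degenerate targets, or inputs that literally contain the target (leakage).
# If init loss >> expected: wrong loss fn, bad init (e.g. overconfident output logits),
#   or a broken data pipeline.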
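To get the baseline for your own targets rather than the idealized N(0,1) case, the formulas above fit in a few lines. This is a minimal sketch. It assumes an untrained model whose class predictions are ~uniform and whose regression outputs are ~0. The function names are hypothetical helpers, not part of any library.

```python
import math

import numpy as np


def expected_ce_init_loss(num_classes: int) -> float:
    """Cross-entropy of a uniform prediction over C classes: -ln(1/C) = ln(C)."""
    return math.log(num_classes)


def expected_regression_init_loss(targets) -> dict:
    """Baseline MSE/L1 if the untrained model predicts ~0 for every target."""
    y = np.asarray(targets, dtype=float)
    return {
        "mse": float(np.mean(y ** 2)),    # = var(y) + mean(y)**2
        "l1": float(np.mean(np.abs(y))),  # ~ sqrt(2/pi) = 0.798 for y ~ N(0, 1)
    }


if __name__ == "__main__":
    for c in (2, 10, 100, 1000):
        print(f"C={c}: {expected_ce_init_loss(c):.3f}")
    rng = np.random.default_rng(0)
    print(expected_regression_init_loss(rng.normal(size=100_000)))
```

Feed it the targets from your own training set. Targets with a large mean raise the MSE baseline well above var(targets), which is a common reason a "too high" init loss is actually fine.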

Overfit-one-batch test

model.train()
batch = next(iter(train_loader))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for step in range(200):
    optimizer.zero_grad()
    out = model(batch['input'])
    loss = loss_fn(out, batch['target'])
    loss.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 100.0)  # returns the total norm; 100 rarely clips
    optimizer.step()
    if step % 20 == 0:
        print(f"step {step:3d}  loss={loss.item():.4f}  grad_norm={grad_norm:.4f}")

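# Expected: loss drops to ~0 within 200 steps (keep the batch small; switch off dropout
# or augmentation if they stop it from reaching ~0).
# If not: the model can't even memorize one batch -- architecture, loss, or gradient problem.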

Gradient flow check (per-layer)

loss.backward()
for name, p in model.named_parameters():
    if p.grad is not None:
        g = p.grad
        print(f"{name:40s}  grad: mean={g.mean():+.2e}, std={g.std():.2e}, "
              f"max={g.abs().max():.2e}, zero%={100*(g==0).float().mean():.0f}")
    else:
        print(f"{name:40s}  grad: None")  # <-- not in computation graph (or requires_grad=False)!
# Check: no None grads (disconnected), no all-zero grads (dead layer),
# no huge grads (explosion), reasonable magnitude across layers.

NaN/Inf detector hooks

def nan_hook(module, input, output):
    def _check(t, label):
        if isinstance(t, torch.Tensor) and (torch.isnan(t).any() or torch.isinf(t).any()):
            raise RuntimeError(
                f"NaN/Inf in {module.__class__.__name__} {label}, "
                f"shape={t.shape}, nan={t.isnan().sum()}, inf={t.isinf().sum()}")
    if isinstance(output, torch.Tensor):
        _check(output, "output")
    elif isinstance(output, dict):
        for k, v in output.items():
            _check(v, f"output[{k!r}]")
    elif isinstance(output, (tuple, list)):
        for i, o in enumerate(output):
            _check(o, f"output[{i}]")

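handles = [module.register_forward_hook(nan_hook) for module in model.modules()]
# Run one forward pass. Hooks fire when a module's forward returns, so children fire
# before their parents: the first module to raise is the innermost source of the NaN.
# Clean up afterwards: for h in handles: h.remove()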

Random input test [Slavv]

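# Pass random noise instead of real data. If loss/error behaves the same, the model
# isn't using its input (often because the pipeline destroys the signal first).
# Run it on a (partially) trained model: at init both losses sit near chance anyway.
model.eval()
real_batch = next(iter(train_loader))
fake_input = torch.randn_like(real_batch['input'])  # float inputs; for token ids: torch.randint_like(real_batch['input'], vocab_size)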
with torch.no_grad():
    real_out = model(real_batch['input'])
    fake_out = model(fake_input)
    real_loss = loss_fn(real_out, real_batch['target']).item()
    fake_loss = loss_fn(fake_out, real_batch['target']).item()
    print(f"Real input loss: {real_loss:.4f}")
    print(f"Random input loss: {fake_loss:.4f}")
# If similar: model isn't using the input. Check preprocessing, data loading, feature selection.
# If very different: model sees real signal. Problem is elsewhere.

NaN poisoning (leakage tracer) [wassname; forward-pass dual of Karpathy's gradient check below]

# Leakage can hide anywhere: normalization fit on the full dataset, target
# leaking into features, window functions peeking ahead, bad splits. Instead
# of auditing each spot, inject NaN where information must NOT come from
# (the future, the test set, the label) and run the real pipeline. NaN is
# absorbing under +,-,*,/ so it spreads like dye: if any "past"/train output
# is NaN, you have a leak, and you can bisect the pipeline to find the stage
# where it crossed.
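import numpy as np
n_features, cutoff = 8, 800   # placeholders: rows >= cutoff play the "future"/test role
X = np.random.randn(1000, n_features)
y = np.random.randn(1000)
X[cutoff:] = np.nan          # poison the future / test rows
y[cutoff:] = np.nan

Xt, yt = pipeline(X, y)       # the REAL pipeline: features, scaling, splits, windowing
# Assumes output rows stay aligned with input rows (no shuffling/dropping); np.asarray
# also makes the check work when the pipeline returns DataFrames.
assert np.isfinite(np.asarray(Xt)[:cutoff]).all(), "leak: future reached past features"
assert np.isfinite(np.asarray(yt)[:cutoff]).all(), "leak: future reached past targets"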
# To localize: assert finiteness after each pipeline stage; first failing
# stage is where the leak crosses.

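# CAVEAT false negatives (dye silently filtered -- false assurance):
#   pandas mean/std/sum default to skipna=True; np.nanmean; dropna/fillna;
#   imputers; df.expanding() and df.rolling(..., min_periods < window) skip NaN
#   (integer-window rolling defaults to min_periods=window and does propagate it).
#   Comparisons drop it too: NaN > 0 is False, so masks and thresholds hide it.
#   Fallback: poison with a huge sentinel (1e12) instead -- survives nanmean, but
#   may surface as squashed rather than absurd values (a scaler fit on it maps
#   train rows to ~constant), so diff against an unpoisoned run.
# CAVEAT false positives (dye spreads along a legitimate axis):
#   softmax over an axis containing NaN goes all-NaN even with a CORRECT
#   additive -inf causal mask (NaN + -inf = NaN), and with masked_fill the zero
#   attention weight still multiplies a NaN value vector (0 * NaN = NaN). So this
#   cannot validate causal masking inside a transformer -- use the gradient check below.
#   But NaN crossing via batch statistics is often a TRUE positive: a scaler
#   fit on train+test lets test rows poison train features. That's the leak.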
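Before trusting the tracer on a real pipeline, it helps to see it pass, fail, and miss on a toy. This is a minimal sketch. `clean_pipeline`, `leaky_pipeline` and `nan_blind_leaky_pipeline` are hypothetical stand-ins built around a standard-scaling step, and the third one reproduces the nanmean-style false negative. `perturbation_leak` is an assumed variant, not part of the recipe above. It runs the pipeline twice with different future values and requires the past outputs to be bit-identical, which catches what the NaN dye misses.

```python
import numpy as np

N, N_FEATURES, CUTOFF = 1000, 4, 800  # rows >= CUTOFF play the "future"/test role


def clean_pipeline(X):
    # Scaler statistics fit on past rows only: no leak.
    mu, sd = X[:CUTOFF].mean(axis=0), X[:CUTOFF].std(axis=0)
    return (X - mu) / sd


def leaky_pipeline(X):
    # Scaler fit on ALL rows: future rows leak into past features.
    return (X - X.mean(axis=0)) / X.std(axis=0)


def nan_blind_leaky_pipeline(X):
    # Same leak, but nanmean/nanstd silently drop the NaN dye.
    return (X - np.nanmean(X, axis=0)) / np.nanstd(X, axis=0)


def nan_poison_leak(pipeline, X):
    """True if NaN injected into the future rows reaches any past output."""
    Xp = X.copy()
    Xp[CUTOFF:] = np.nan
    return not np.isfinite(pipeline(Xp)[:CUTOFF]).all()


def perturbation_leak(pipeline, X, seed=1):
    """True if changing only the future rows changes any past output."""
    Xp = X.copy()
    Xp[CUTOFF:] = np.random.default_rng(seed).normal(size=Xp[CUTOFF:].shape) * 10
    return not np.array_equal(pipeline(X)[:CUTOFF], pipeline(Xp)[:CUTOFF])


X = np.random.default_rng(0).normal(size=(N, N_FEATURES))
for p in (clean_pipeline, leaky_pipeline, nan_blind_leaky_pipeline):
    print(f"{p.__name__:26s} nan_poison={nan_poison_leak(p, X)!s:5s} "
          f"perturb={perturbation_leak(p, X)}")
```

The perturbation variant is stricter, but it needs a deterministic pipeline (fixed seeds, no random augmentation), otherwise every pair of runs differs.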

Backprop-to-input dependency check [Karpathy 2019]

# The gradient-based dual of NaN poisoning: works INSIDE models where NaN
# gives false positives (attention softmax, batch statistics).
# Karpathy: "set the loss to be something trivial like the sum of all outputs
# of example i... ensure that you get a non-zero gradient only on the i-th input."
# Catches view-instead-of-transpose bugs that mix info across the batch dim.

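# Batch independence: output i must depend only on input i
model.eval()  # train-mode BatchNorm legitimately mixes the batch; dropout adds run-to-run noise
seq, dim = 13, 16  # placeholders: your model's sequence length and input feature size
x = torch.randn(8, seq, dim, requires_grad=True)
model(x)[3].sum().backward()
assert (x.grad[[0,1,2,4,5,6,7]] == 0).all(), "leak across batch dim"

# Causal masking: output at t must not depend on inputs > t
x = torch.randn(1, seq, dim, requires_grad=True)
t = seq // 2
model(x)[0, t].sum().backward()
assert (x.grad[0, t+1:] == 0).all(), "leak: position t sees the future"
# Token-id models: differentiate w.r.t. the embedding output instead of the integer ids.
# If a fused attention kernel leaves tiny nonzero values, compare against a small tolerance.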
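You can see the bug class this check targets without a framework. The numpy-only sketch below swaps the backward pass for finite-difference perturbation, which is a different technique from Karpathy's gradient method: perturb example j and see whether output i moves. It answers the same "does output i depend on input j?" question, at the cost of one forward pass per example. `good_model` and `buggy_model` are hypothetical toy layers that subtract a per-sequence mean. The buggy one uses `reshape` where a transpose was meant.

```python
import numpy as np


def good_model(x):
    # x: (batch, seq, dim). Subtract each sequence's mean, computed in (seq, batch, dim) layout.
    h = x.transpose(1, 0, 2)                 # true axis swap
    h = h - h.mean(axis=0, keepdims=True)    # mean over seq, per example
    return h.transpose(1, 0, 2)


def buggy_model(x):
    b, s, d = x.shape
    h = x.reshape(s, b, d)                   # BUG: regroups memory, not an axis swap
    h = h - h.mean(axis=0, keepdims=True)    # the "seq" axis now spans several examples
    return h.reshape(b, s, d)


def batch_dependencies(model, x, i, eps=1e-2, seed=0):
    """Indices j whose perturbation changes model(x)[i]."""
    rng = np.random.default_rng(seed)
    base = model(x)[i]
    deps = []
    for j in range(x.shape[0]):
        xp = x.copy()
        # Random (not constant) noise: a constant shift would cancel in the mean subtraction.
        xp[j] += eps * rng.normal(size=x[j].shape)
        if not np.allclose(model(xp)[i], base):
            deps.append(j)
    return deps


x = np.random.default_rng(0).normal(size=(4, 3, 2))  # (batch, seq, dim)
print("good :", batch_dependencies(good_model, x, i=3))
print("buggy:", batch_dependencies(buggy_model, x, i=3))
```

Both models return the same shape, so nothing crashes. Only the dependency check reveals that the buggy output for example 3 depends on every example in the batch.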

Prime dimension trick [Slavv]

# Use distinct primes for each dimension to catch silent axis mix-ups.
# With batch=7, seq=13, hidden=17 no two dims are equal and none divides another,
# so many transposes, reductions, or reshapes over the wrong axis that "work" by
# accident when dims coincide (e.g. batch = seq = 16) crash or change the output shape.
x = torch.randn(7, 13, 17)  # (batch=7, seq=13, hidden=17); build the model with hidden=17 if you can
out = model(x)
print(f"in={x.shape} -> out={out.shape}")
# If this crashes (or out has an unexpected shape) but normal shapes don't: you have an axis bug.
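Here is a toy illustration of why distinct dimensions help. It assumes a hypothetical pooling function that should average over seq (axis 1) but averages over batch (axis 0). When batch == seq the output shape looks correct. With primes the shape check fails.

```python
import numpy as np


def mean_pool_buggy(x):
    # Intended: average over seq -> (batch, hidden). BUG: averages over batch instead.
    return x.mean(axis=0)


def shape_ok(batch, seq, hidden):
    x = np.random.default_rng(0).normal(size=(batch, seq, hidden))
    return mean_pool_buggy(x).shape == (batch, hidden)


print("16x16x32:", shape_ok(16, 16, 32))  # bug hidden: batch == seq
print("7x13x17 :", shape_ok(7, 13, 17))   # bug exposed
```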

Class imbalance check

from collections import Counter
all_labels = []
for batch in train_loader:
    labels = batch['target'] if isinstance(batch, dict) else batch[1]
    all_labels.extend(labels.flatten().tolist())
counts = Counter(all_labels)
total = sum(counts.values())
for cls, n in sorted(counts.items(), key=lambda x: -x[1]):
    print(f"  class {cls}: {n:6d} ({100*n/total:.1f}%)")
# Majority:minority ratio > 10:1 = likely need weighted loss or resampling.
# Ratio > 100:1 = always predicting the majority class already looks "accurate"; report per-class metrics.
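The printout leaves you to eyeball the majority:minority ratio that the thresholds above refer to. This minimal sketch computes it from the same `Counter`. The function name and messages are illustrative, and the 10:1 and 100:1 cut-offs are just the rules of thumb stated above.

```python
from collections import Counter


def imbalance_report(counts: Counter) -> tuple:
    """Return (majority:minority ratio, rule-of-thumb verdict)."""
    ratio = max(counts.values()) / min(counts.values())
    if ratio > 100:
        verdict = "severe: majority-class guessing will look accurate; use per-class metrics"
    elif ratio > 10:
        verdict = "high: consider a weighted loss or resampling"
    else:
        verdict = "ok"
    return ratio, verdict


counts = Counter({0: 9_500, 1: 450, 2: 50})
ratio, verdict = imbalance_report(counts)
print(f"majority:minority = {ratio:.0f}:1 -> {verdict}")
```

A class with zero examples never appears in a `Counter`, so compare `counts.keys()` against the expected label set separately.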

Confidence-sorted error inspection [common practice, cf. FSDL error analysis]

# Find the model's most confident wrong predictions. These reveal
# systematic bugs (e.g., cropping cutting off relevant features).
model.eval()
errors = []
offset = 0  # running dataset index; needs val_loader(shuffle=False) to map back to samples
with torch.no_grad():
    for batch in val_loader:
        logits = model(batch['input'])
        probs = torch.softmax(logits, dim=-1)
        confidence, predicted = probs.max(dim=-1)
        wrong = predicted != batch['target']
        for i in wrong.nonzero(as_tuple=True)[0]:
            errors.append((confidence[i].item(), predicted[i].item(),
                           batch['target'][i].item(), offset + i.item()))
        offset += len(batch['target'])
errors.sort(reverse=True)  # most confident mistakes first
for conf, pred, true, idx in errors[:10]:
    print(f"  conf={conf:.3f} predicted={pred} true={true} idx={idx}")
# Inspect the actual inputs (val_loader.dataset[idx]) for these indices. Pattern = systematic bug.
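If you have already collected predictions for the whole validation set as arrays, the ranking vectorizes. This is a minimal numpy sketch. It assumes `probs` is an (N, C) array of softmax outputs in dataset order and `labels` holds the (N,) integer targets. `most_confident_errors` is a hypothetical helper.

```python
import numpy as np


def most_confident_errors(probs: np.ndarray, labels: np.ndarray, k: int = 10) -> np.ndarray:
    """Dataset indices of the k wrong predictions with the highest confidence."""
    predicted = probs.argmax(axis=1)
    confidence = probs.max(axis=1)
    wrong = np.flatnonzero(predicted != labels)
    order = np.argsort(-confidence[wrong], kind="stable")  # most confident first
    return wrong[order][:k]


probs = np.array([[0.90, 0.10],   # idx 0: wrong (label 1), conf 0.90
                  [0.20, 0.80],   # idx 1: correct
                  [0.40, 0.60],   # idx 2: wrong (label 0), conf 0.60
                  [0.99, 0.01]])  # idx 3: wrong (label 1), conf 0.99
labels = np.array([1, 1, 0, 1])
print(most_confident_errors(probs, labels))
```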

Update-to-data ratio check [Karpathy nn-zero-to-hero Lec 4; evidence: karpathy_nn_zero_to_hero_lec4_diagnostics.md]

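# Track during training: how large are updates relative to parameter magnitudes?
# Target: ~1e-3 (log10 ~ -3). Much higher = LR too large. Much lower = LR too small.
# lr * p.grad is the plain-SGD update (what the lecture uses). With Adam or momentum
# the actual step differs: snapshot params before optimizer.step() and use (p_new - p_old).
ud = []
# Inside training loop (after optimizer.step()); lr = your current learning rate: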
with torch.no_grad():
    ud.append({
        name: ((lr * p.grad).std() / p.data.std()).log10().item()
        for name, p in model.named_parameters()
        if p.grad is not None and p.ndim >= 2
    })
# After training, plot per-layer ratios:
import matplotlib.pyplot as plt
for name in ud[0]:
    plt.plot([d[name] for d in ud], label=name)
plt.axhline(-3, color='k', linestyle='--')  # target ratio
plt.legend(); plt.ylabel('log10(update/param ratio)'); plt.show()
# If a layer's ratio is much above -3: reduce LR or add gradient clipping.
# If much below -3: that layer is barely updating -- possible dead/frozen layer.
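With Adam or momentum, `lr * p.grad` is not the step that was actually taken. This minimal sketch shows the snapshot variant. It assumes you copy each parameter to a numpy array before and after `optimizer.step()`. `update_ratio_log10` is a hypothetical helper, and the synthetic check below only confirms the arithmetic.

```python
import numpy as np


def update_ratio_log10(before: np.ndarray, after: np.ndarray) -> float:
    """log10(std(actual update) / std(parameter)): the quantity plotted above."""
    return float(np.log10(np.std(after - before) / np.std(before)))


rng = np.random.default_rng(0)
w_before = rng.normal(scale=0.5, size=(256, 256))                   # a weight matrix
w_after = w_before + rng.normal(scale=0.5e-3, size=w_before.shape)  # update ~1e-3 of its scale
print(f"{update_ratio_log10(w_before, w_after):.2f}")  # close to -3, the target line
```

In a PyTorch loop, one way to snapshot is `before = {n: p.detach().cpu().numpy().copy() for n, p in model.named_parameters()}` before the step. Then compare each entry with the parameter's new value. The `.copy()` matters, because on CPU `.numpy()` shares memory with the tensor that the optimizer updates in place.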

Weight/bias distribution check [Slavv, CS231n]

for name, p in model.named_parameters():
    print(f"{name:40s}  mean={p.data.mean():+.4f}  std={p.data.std():.4f}  "
          f"min={p.data.min():+.4f}  max={p.data.max():+.4f}  "
          f"shape={list(p.shape)}")
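# Healthy: matches the init scheme (nn.Linear's default init is uniform, not Gaussian),
# std ~0.01-1.0 depending on the scheme.
# Bad signs: NaN, huge values (>100), weight matrices that are all zeros or have std ~0
# (collapsed). At init, all-zero biases and all-one norm-layer gains are normal.
# After training: weights diverging to +/-inf = exploding. All same value = dead.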
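To scan a large model for the bad signs listed above instead of reading every line, a small flagger helps. This is a minimal numpy sketch. The thresholds (100 for "huge", 1e-8 for "std ~0") are illustrative choices, and the arrays are assumed to come from `p.detach().cpu().numpy()`.

```python
import numpy as np


def weight_flags(name: str, w: np.ndarray, huge: float = 100.0, tiny_std: float = 1e-8) -> list:
    """Return human-readable warnings for one parameter array."""
    if not np.isfinite(w).all():
        # Stats below are meaningless with NaN/Inf present.
        return [f"{name}: non-finite values (NaN/Inf)"]
    flags = []
    if np.abs(w).max() > huge:
        flags.append(f"huge values (|w| > {huge:g})")
    if w.ndim >= 2 and w.std() < tiny_std:
        # Only weight matrices: zero biases and unit norm gains are normal at init.
        flags.append("all entries (nearly) equal: collapsed or dead")
    return [f"{name}: {f}" for f in flags]


print(weight_flags("fc.weight", np.zeros((4, 4))))
print(weight_flags("fc.bias", np.zeros(4)))
print(weight_flags("head.weight", np.array([[1.0, np.nan]])))
```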

JAX diagnostic equivalents

| Diagnostic | PyTorch | JAX |
| --- | --- | --- |
| NaN detection | `torch.autograd.detect_anomaly()` | `jax.config.update("jax_debug_nans", True)` |
| Gradient check | `torch.autograd.gradcheck(fn, inputs)` | `jax.test_util.check_grads(fn, args, order=2)` |
| Eager debug (no compile) | N/A (already eager) | `jax.config.update("jax_disable_jit", True)` |
| Print inside compiled | N/A | `jax.debug.print("{x}", x=x)` |
| Breakpoint inside compiled | `pdb.set_trace()` | `jax.debug.breakpoint()` |
| Runtime assertions inside compiled | `assert` | `jax.experimental.checkify` |