From fbacefd433bf54e9d4fe0c52f08f9e29d67c36bd Mon Sep 17 00:00:00 2001 From: wassname Date: Mon, 1 Jun 2026 03:21:08 +0000 Subject: [PATCH] fix: bf16/cuda path (7 device+dtype bugs) + GPU-bf16 smoke The fp32+CPU smoke never walked the fast/full bf16+cuda path, so the whole device/dtype class was invisible to the only gate. The GPU path had in fact never run end-to-end. Seven bugs, each masked by fp32+CPU: - svd cache key: numpy has no bf16 -> hash via .view(uint8) - svd save: safetensors needs contiguous cpu tensors - svd cache-hit: load_file returns cpu -> .to(W) for device+dtype - delta_S/delta_S_hack created on cpu -> device=lin.weight.device (else the forward hook mixes cpu/cuda) - V_hack is fp32 (svd) but grads are bf16 -> cast to delta_S's space - completion_nll fed cpu ids to the cuda model in extraction - extraction orientation vote D@V.T mixed bf16 D with fp32 V Smoke is now tiny-random on GPU in bf16 (same device+dtype as fast/full), so this class stays caught. All arms (none/erase/route) and extraction paths (miss/hit/refresh) green: cin_t~0.31-0.41, cout~0 (one_sided identity). Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com> --- justfile | 9 +++++---- src/projected_grpo/antipasto.py | 18 +++++++++++------- src/projected_grpo/extract_vhack_grad.py | 5 +++-- src/projected_grpo/train.py | 21 ++++++++++++++++++--- 4 files changed, 37 insertions(+), 16 deletions(-) diff --git a/justfile b/justfile index 89e71d1..d14a7bf 100644 --- a/justfile +++ b/justfile @@ -20,17 +20,18 @@ check: build-pool: uv run python -m projected_grpo.build_pool --pool-dir=out/pools/teacher_pool -# Smoke = the ONLY gate: same harness as production (train.py), tiny-random on CPU, +# Smoke = the ONLY gate: same harness as production (train.py), tiny-random on GPU +# in bf16 (same device+dtype as fast/full, so the cuda+bf16 path is actually covered). # beartype on so jaxtyping signatures get runtime-checked. 30 steps fires the # every-25-step save_ckpt path. 
erase writes g_proj; cache-miss extracts v_hack. smoke *ARGS: build-pool check - BEARTYPE=1 CUDA_VISIBLE_DEVICES= {{ TRAIN }} smoke --intervention=erase \ + BEARTYPE=1 {{ TRAIN }} smoke --intervention=erase \ --v-hack-path=out/vhack/v_hack_smoke.safetensors \ --teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 {{ ARGS }} # Vanilla arm: V loaded for the measure_only diagnostic (cin), grad left untouched. smoke-vanilla *ARGS: build-pool - BEARTYPE=1 CUDA_VISIBLE_DEVICES= {{ TRAIN }} smoke --intervention=none \ + BEARTYPE=1 {{ TRAIN }} smoke --intervention=none \ --v-hack-path=out/vhack/v_hack_smoke.safetensors \ --teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 {{ ARGS }} @@ -38,7 +39,7 @@ smoke-vanilla *ARGS: build-pool # two-param optimizer path, periodic ablated-eval, online v_hack refresh + the # basis_overlap guard, and the final kept-vs-ablated BLUF. smoke-route *ARGS: build-pool - BEARTYPE=1 CUDA_VISIBLE_DEVICES= {{ TRAIN }} smoke --intervention=route \ + BEARTYPE=1 {{ TRAIN }} smoke --intervention=route \ --v-hack-path=out/vhack/v_hack_smoke.safetensors \ --teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \ --eval-ablate-every=10 --eval-n-prompts=2 --vhack-refresh-every=10 {{ ARGS }} diff --git a/src/projected_grpo/antipasto.py b/src/projected_grpo/antipasto.py index 9226af8..f25eed2 100644 --- a/src/projected_grpo/antipasto.py +++ b/src/projected_grpo/antipasto.py @@ -42,14 +42,17 @@ def svd_cached(W: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tenso """U Σ Vh = W, reduced. Cache key = sha256(W) so a stale cache is impossible (different W -> different file). Native dtype in, fp32 SVD, cast back.""" SVD_CACHE.mkdir(exist_ok=True) - key = hashlib.sha256(W.detach().cpu().contiguous().numpy().tobytes()).hexdigest() + # .view(uint8) so the byte-hash is dtype-agnostic: numpy has no bf16, and every + # preset (smoke included, now) loads weights in bf16. 
+ key = hashlib.sha256(W.detach().cpu().contiguous().view(torch.uint8).numpy().tobytes()).hexdigest() path = SVD_CACHE / f"{key}.safetensors" if path.exists(): - t = load_file(path) - return t["U"].to(W.dtype), t["S"].to(W.dtype), t["Vh"].to(W.dtype) + t = load_file(path) # .to(W): match W's device+dtype + return t["U"].to(W), t["S"].to(W), t["Vh"].to(W) # (cache hit loads on cpu -> move to cuda) U, S, Vh = torch.linalg.svd(W.float(), full_matrices=False) - save_file({"U": U, "S": S, "Vh": Vh}, path) - return U.to(W.dtype), S.to(W.dtype), Vh.to(W.dtype) + # safetensors needs contiguous cpu tensors; on cuda the svd outputs are neither. + save_file({"U": U.contiguous().cpu(), "S": S.contiguous().cpu(), "Vh": Vh.contiguous().cpu()}, path) + return U.to(W), S.to(W), Vh.to(W) def _δ_hook(lin: nn.Linear, x_in, y): @@ -68,8 +71,9 @@ def wrap(model: nn.Module) -> dict[str, Wrap]: r = min(lin.in_features, lin.out_features) lin.register_buffer("U", U) lin.register_buffer("Vh", Vh) - lin.register_parameter("δS", nn.Parameter(torch.zeros(r, dtype=lin.weight.dtype))) - lin.register_parameter("δS_hack", nn.Parameter(torch.zeros(r, dtype=lin.weight.dtype))) + z = lambda: nn.Parameter(torch.zeros(r, dtype=lin.weight.dtype, device=lin.weight.device)) + lin.register_parameter("δS", z()) # on the layer's device, else the hook mixes cpu/cuda + lin.register_parameter("δS_hack", z()) lin.register_forward_hook(_δ_hook) wrappers[name] = Wrap(lin, r) for p in model.parameters(): diff --git a/src/projected_grpo/extract_vhack_grad.py b/src/projected_grpo/extract_vhack_grad.py index 59db95d..33fc2f3 100644 --- a/src/projected_grpo/extract_vhack_grad.py +++ b/src/projected_grpo/extract_vhack_grad.py @@ -28,8 +28,9 @@ from projected_grpo.problems import Pair, default_pairs def completion_nll(model, tok, prompt: str, completion: str) -> torch.Tensor: """Mean NLL on completion tokens only.""" + dev = next(model.parameters()).device # full feeds the cuda model; ids must follow p_ids = tok(prompt, 
return_tensors="pt").input_ids - full = tok(prompt + completion, return_tensors="pt").input_ids + full = tok(prompt + completion, return_tensors="pt").input_ids.to(dev) n_p = p_ids.shape[1] logp = per_token_logps(model(full).logits, full) # (1, L-1) comp = logp[:, n_p - 1:] # logp of completion tokens @@ -56,7 +57,7 @@ def extract_v_hack(model, tok, wrappers: dict[str, Wrap], pairs: list[Pair], V = Vh_d[:k] # (k, r), rows orthonormal in ℝ^r # orient hack-ward by per-pair majority vote (outlier-robust; proj gates on ⟨g,v_i⟩>0) n_pairs = D.shape[0] - votes = torch.sign((D @ V.T > 0).sum(0).float() - n_pairs / 2) # (k,) + votes = torch.sign((D.float() @ V.T > 0).sum(0).float() - n_pairs / 2) # (k,); D bf16, V fp32 from svd votes[votes == 0] = 1.0 V = V * votes[:, None] if τ_axis > 0: # zero noisy axes: S_i/S_0 < τ_axis diff --git a/src/projected_grpo/train.py b/src/projected_grpo/train.py index 5e36807..40bda98 100644 --- a/src/projected_grpo/train.py +++ b/src/projected_grpo/train.py @@ -7,7 +7,7 @@ Unbiased normalization: Dr.GRPO, Liu et al. 2025, arXiv:2503.20783 (drop the generate -> grade -> backward -> project -> step Arms (one knob): none -> measure_only | erase -> write g_proj | route -> +park δS_hack. -Presets: smoke (tiny-random, CPU, the only gate) | fast | full. +Presets: smoke (tiny-random, GPU bf16, the only gate) | fast | full. python -m projected_grpo.train smoke --intervention=erase ... 
""" @@ -105,7 +105,7 @@ class Config: PRESETS = { - "smoke": Config(), + "smoke": Config(device="cuda", dtype="bf16"), # tiny-random on GPU: walks the real cuda+bf16 path "fast": Config(model="Qwen/Qwen3-4B", device="cuda", dtype="bf16", steps=60, group=8, prompts_per_step=4, max_new=512, lr=3e-3, adam_beta1=0.5, adam_beta2=0.9, mix_ratio=0.125, v_hack_path="out/vhack/v_hack_full.safetensors"), @@ -162,6 +162,8 @@ def grpo_step(model, tok, wrappers, problem, pool_rows, G_s, G_t, cfg, pad_id, d m = dict(reward=R.mean().item(), hack_s=_mean([r.exploited for r in s_rew]), gt_s=_mean([r.gt_correct for r in s_rew]), hack_t=_mean([r["exploited"] for r in t_rows]), gt_t=_mean([r["gt_correct"] for r in t_rows])) + if s_rew: # stash one student gen for the jsonl + the final coherence eyeball + m["gen"] = dict(text=texts[0], hack=s_rew[0].exploited, gt=s_rew[0].gt_correct, reward=s_rew[0].reward) empty = {n: torch.zeros_like(w.δS) for n, w in wrappers.items()} if R.max() - R.min() < 1e-6: # zero-variance group => adv≡0 => pure waste (bail) @@ -256,6 +258,8 @@ def main(cfg: Config): V_raw, V_sv = resolve_or_extract(model, tok, wrappers, cfg.v_hack_path, cfg.vhack_pairs_path, cfg.v_hack_top_k, cfg.tau_axis, cfg.n_heldout) V_hack = postprocess_v_hack(V_raw, V_sv, cfg.v_hack_k, cfg.v_hack_drop_bottom_frac) + δS_ref = next(iter(wrappers.values())).δS # V lives in δS's space: match dtype+device + V_hack = {n: V.to(δS_ref) for n, V in V_hack.items()} # extract/load give fp32[/cpu]; grads are bf16 cuda logger.info(f"V_hack over {len(V_hack)} modules (k_use={cfg.v_hack_k}); " f"arm={cfg.intervention} -> measure_only={cfg.intervention == 'none'}") @@ -311,6 +315,8 @@ def main(cfg: Config): reward=_mean([m["reward"] for m in ms]), loss=_mean([m.get("loss", 0.0) for m in ms]), gn=gn.item(), cin_s=cin_s, cin_t=cin_t, cin=cin, cout=cout, fired=fired, dSh=dSh) + if "gen" in ms[0]: # log the first prompt's first student gen (text + flags) + row["gen"] = ms[0]["gen"] # online refresh: 
re-extract V against the CURRENT model (under quarantine ablation) if cfg.vhack_refresh_every and V_hack and step % cfg.vhack_refresh_every == 0: @@ -319,6 +325,7 @@ def main(cfg: Config): V_raw, V_sv = extract_v_hack(model, tok, wrappers, default_pairs(), cfg.v_hack_top_k, cfg.tau_axis, cfg.n_heldout) V_hack = postprocess_v_hack(V_raw, V_sv, cfg.v_hack_k, cfg.v_hack_drop_bottom_frac) + V_hack = {n: V.to(δS_ref) for n, V in V_hack.items()} # keep V in δS's bf16/cuda space shared = [n for n in V_hack if n in old] ov = _mean([abs((V_hack[n][0] @ old[n][0]).item()) for n in shared]) if shared else 0.0 row["basis_overlap"] = ov # GUARD: should sit near 1.0; <~0.2 => refresh rotated off-hack @@ -333,7 +340,8 @@ def main(cfg: Config): rows.append(row) log.write(json.dumps(row) + "\n"); log.flush() pbar.set_postfix(hack_s=f"{row['hack_s']:.2f}", cin_s=f"{cin_s:.2f}", cin_t=f"{cin_t:.2f}", - cout=f"{cout:.2f}", δSh=f"{dSh:.2f}", loss=f"{row['loss']:.3f}") + cout=f"{cout:.2f}", δSh=f"{dSh:.2f}", loss=f"{row['loss']:.3f}", + refresh=False) # don't force a bar redraw every step (pollutes piped logs) if step % 25 == 0: save_ckpt(wrappers, rows, cfg.out_tag or f"_{cfg.intervention}_s{cfg.seed}") @@ -370,6 +378,13 @@ def _bluf(cfg, rows, wrappers, model, tok, problems, pad_id): "overlap with V (relu-before-agg), cout→0 residual hack-ward leak (identity under " "one_sided, not efficacy), |δSh| quarantine norm (>0 iff route).") logger.info(f"out: {out}") + # Last student generation -- a coherence eyeball before the numbers. SHOULD: real + # code/prose for the problem. If it is token salad the policy diverged and the eval + # numbers below are moot. (Empty on tiny-random if no student rollout was graded.) 
+ last_gen = next((r["gen"] for r in reversed(rows) if r.get("gen")), None) + if last_gen: + logger.info(f"last student gen (hack={last_gen['hack']} gt={last_gen['gt']} " + f"reward={last_gen['reward']:.2f}):\n{last_gen['text']}") logger.info(f"main metric: {cue} hack_s={hack:.3f} solve={solve:.3f} cin_t={cin_t:.3f} cout={cout:.3f} " f"[arm={cfg.intervention},seed={cfg.seed},steps={len(rows)}]") print(f"\n{caption}\n")