mirror of
https://github.com/wassname/grpo_proj2.git
synced 2026-06-27 16:30:33 +08:00
fix: bf16/cuda path (7 device+dtype bugs) + GPU-bf16 smoke
The fp32+CPU smoke never walked the fast/full bf16+cuda path, so the whole device/dtype class was invisible to the only gate. The GPU path had in fact never run end-to-end. Seven bugs, each masked by fp32+CPU: - svd cache key: numpy has no bf16 -> hash via .view(uint8) - svd save: safetensors needs contiguous cpu tensors - svd cache-hit: load_file returns cpu -> .to(W) for device+dtype - delta_S/delta_S_hack created on cpu -> device=lin.weight.device (else the forward hook mixes cpu/cuda) - V_hack is fp32 (svd) but grads are bf16 -> cast to delta_S's space - completion_nll fed cpu ids to the cuda model in extraction - extraction orientation vote D@V.T mixed bf16 D with fp32 V Smoke is now tiny-random on GPU in bf16 (same device+dtype as fast/full), so this class stays caught. All arms (none/erase/route) and extraction paths (miss/hit/refresh) green: cin_t~0.31-0.41, cout~0 (one_sided identity). Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -20,17 +20,18 @@ check:
|
||||
build-pool:
|
||||
uv run python -m projected_grpo.build_pool --pool-dir=out/pools/teacher_pool
|
||||
|
||||
# Smoke = the ONLY gate: same harness as production (train.py), tiny-random on CPU,
|
||||
# Smoke = the ONLY gate: same harness as production (train.py), tiny-random on GPU
|
||||
# in bf16 (same device+dtype as fast/full, so the cuda+bf16 path is actually covered).
|
||||
# beartype on so jaxtyping signatures get runtime-checked. 30 steps fires the
|
||||
# every-25-step save_ckpt path. erase writes g_proj; cache-miss extracts v_hack.
|
||||
smoke *ARGS: build-pool check
|
||||
BEARTYPE=1 CUDA_VISIBLE_DEVICES= {{ TRAIN }} smoke --intervention=erase \
|
||||
BEARTYPE=1 {{ TRAIN }} smoke --intervention=erase \
|
||||
--v-hack-path=out/vhack/v_hack_smoke.safetensors \
|
||||
--teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 {{ ARGS }}
|
||||
|
||||
# Vanilla arm: V loaded for the measure_only diagnostic (cin), grad left untouched.
|
||||
smoke-vanilla *ARGS: build-pool
|
||||
BEARTYPE=1 CUDA_VISIBLE_DEVICES= {{ TRAIN }} smoke --intervention=none \
|
||||
BEARTYPE=1 {{ TRAIN }} smoke --intervention=none \
|
||||
--v-hack-path=out/vhack/v_hack_smoke.safetensors \
|
||||
--teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 {{ ARGS }}
|
||||
|
||||
@@ -38,7 +39,7 @@ smoke-vanilla *ARGS: build-pool
|
||||
# two-param optimizer path, periodic ablated-eval, online v_hack refresh + the
|
||||
# basis_overlap guard, and the final kept-vs-ablated BLUF.
|
||||
smoke-route *ARGS: build-pool
|
||||
BEARTYPE=1 CUDA_VISIBLE_DEVICES= {{ TRAIN }} smoke --intervention=route \
|
||||
BEARTYPE=1 {{ TRAIN }} smoke --intervention=route \
|
||||
--v-hack-path=out/vhack/v_hack_smoke.safetensors \
|
||||
--teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \
|
||||
--eval-ablate-every=10 --eval-n-prompts=2 --vhack-refresh-every=10 {{ ARGS }}
|
||||
|
||||
@@ -42,14 +42,17 @@ def svd_cached(W: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tenso
|
||||
"""U Σ Vh = W, reduced. Cache key = sha256(W) so a stale cache is impossible
|
||||
(different W -> different file). Native dtype in, fp32 SVD, cast back."""
|
||||
SVD_CACHE.mkdir(exist_ok=True)
|
||||
key = hashlib.sha256(W.detach().cpu().contiguous().numpy().tobytes()).hexdigest()
|
||||
# .view(uint8) so the byte-hash is dtype-agnostic: numpy has no bf16, and every
|
||||
# preset (smoke included, now) loads weights in bf16.
|
||||
key = hashlib.sha256(W.detach().cpu().contiguous().view(torch.uint8).numpy().tobytes()).hexdigest()
|
||||
path = SVD_CACHE / f"{key}.safetensors"
|
||||
if path.exists():
|
||||
t = load_file(path)
|
||||
return t["U"].to(W.dtype), t["S"].to(W.dtype), t["Vh"].to(W.dtype)
|
||||
t = load_file(path) # .to(W): match W's device+dtype
|
||||
return t["U"].to(W), t["S"].to(W), t["Vh"].to(W) # (cache hit loads on cpu -> move to cuda)
|
||||
U, S, Vh = torch.linalg.svd(W.float(), full_matrices=False)
|
||||
save_file({"U": U, "S": S, "Vh": Vh}, path)
|
||||
return U.to(W.dtype), S.to(W.dtype), Vh.to(W.dtype)
|
||||
# safetensors needs contiguous cpu tensors; svd outputs are neither on cuda.
|
||||
save_file({"U": U.contiguous().cpu(), "S": S.contiguous().cpu(), "Vh": Vh.contiguous().cpu()}, path)
|
||||
return U.to(W), S.to(W), Vh.to(W)
|
||||
|
||||
|
||||
def _δ_hook(lin: nn.Linear, x_in, y):
|
||||
@@ -68,8 +71,9 @@ def wrap(model: nn.Module) -> dict[str, Wrap]:
|
||||
r = min(lin.in_features, lin.out_features)
|
||||
lin.register_buffer("U", U)
|
||||
lin.register_buffer("Vh", Vh)
|
||||
lin.register_parameter("δS", nn.Parameter(torch.zeros(r, dtype=lin.weight.dtype)))
|
||||
lin.register_parameter("δS_hack", nn.Parameter(torch.zeros(r, dtype=lin.weight.dtype)))
|
||||
z = lambda: nn.Parameter(torch.zeros(r, dtype=lin.weight.dtype, device=lin.weight.device))
|
||||
lin.register_parameter("δS", z()) # on the layer's device, else the hook mixes cpu/cuda
|
||||
lin.register_parameter("δS_hack", z())
|
||||
lin.register_forward_hook(_δ_hook)
|
||||
wrappers[name] = Wrap(lin, r)
|
||||
for p in model.parameters():
|
||||
|
||||
@@ -28,8 +28,9 @@ from projected_grpo.problems import Pair, default_pairs
|
||||
|
||||
def completion_nll(model, tok, prompt: str, completion: str) -> torch.Tensor:
|
||||
"""Mean NLL on completion tokens only."""
|
||||
dev = next(model.parameters()).device # full feeds the cuda model; ids must follow
|
||||
p_ids = tok(prompt, return_tensors="pt").input_ids
|
||||
full = tok(prompt + completion, return_tensors="pt").input_ids
|
||||
full = tok(prompt + completion, return_tensors="pt").input_ids.to(dev)
|
||||
n_p = p_ids.shape[1]
|
||||
logp = per_token_logps(model(full).logits, full) # (1, L-1)
|
||||
comp = logp[:, n_p - 1:] # logp of completion tokens
|
||||
@@ -56,7 +57,7 @@ def extract_v_hack(model, tok, wrappers: dict[str, Wrap], pairs: list[Pair],
|
||||
V = Vh_d[:k] # (k, r), rows orthonormal in ℝ^r
|
||||
# orient hack-ward by per-pair majority vote (outlier-robust; proj gates on ⟨g,v_i⟩>0)
|
||||
n_pairs = D.shape[0]
|
||||
votes = torch.sign((D @ V.T > 0).sum(0).float() - n_pairs / 2) # (k,)
|
||||
votes = torch.sign((D.float() @ V.T > 0).sum(0).float() - n_pairs / 2) # (k,); D bf16, V fp32 from svd
|
||||
votes[votes == 0] = 1.0
|
||||
V = V * votes[:, None]
|
||||
if τ_axis > 0: # zero noisy axes: S_i/S_0 < τ_axis
|
||||
|
||||
@@ -7,7 +7,7 @@ Unbiased normalization: Dr.GRPO, Liu et al. 2025, arXiv:2503.20783 (drop the
|
||||
generate -> grade -> backward -> project -> step
|
||||
|
||||
Arms (one knob): none -> measure_only | erase -> write g_proj | route -> +park δS_hack.
|
||||
Presets: smoke (tiny-random, CPU, the only gate) | fast | full.
|
||||
Presets: smoke (tiny-random, GPU bf16, the only gate) | fast | full.
|
||||
|
||||
python -m projected_grpo.train smoke --intervention=erase ...
|
||||
"""
|
||||
@@ -105,7 +105,7 @@ class Config:
|
||||
|
||||
|
||||
PRESETS = {
|
||||
"smoke": Config(),
|
||||
"smoke": Config(device="cuda", dtype="bf16"), # tiny-random on GPU: walks the real cuda+bf16 path
|
||||
"fast": Config(model="Qwen/Qwen3-4B", device="cuda", dtype="bf16", steps=60, group=8,
|
||||
prompts_per_step=4, max_new=512, lr=3e-3, adam_beta1=0.5, adam_beta2=0.9,
|
||||
mix_ratio=0.125, v_hack_path="out/vhack/v_hack_full.safetensors"),
|
||||
@@ -162,6 +162,8 @@ def grpo_step(model, tok, wrappers, problem, pool_rows, G_s, G_t, cfg, pad_id, d
|
||||
m = dict(reward=R.mean().item(),
|
||||
hack_s=_mean([r.exploited for r in s_rew]), gt_s=_mean([r.gt_correct for r in s_rew]),
|
||||
hack_t=_mean([r["exploited"] for r in t_rows]), gt_t=_mean([r["gt_correct"] for r in t_rows]))
|
||||
if s_rew: # stash one student gen for the jsonl + the final coherence eyeball
|
||||
m["gen"] = dict(text=texts[0], hack=s_rew[0].exploited, gt=s_rew[0].gt_correct, reward=s_rew[0].reward)
|
||||
|
||||
empty = {n: torch.zeros_like(w.δS) for n, w in wrappers.items()}
|
||||
if R.max() - R.min() < 1e-6: # zero-variance group => adv≡0 => pure waste (bail)
|
||||
@@ -256,6 +258,8 @@ def main(cfg: Config):
|
||||
V_raw, V_sv = resolve_or_extract(model, tok, wrappers, cfg.v_hack_path,
|
||||
cfg.vhack_pairs_path, cfg.v_hack_top_k, cfg.tau_axis, cfg.n_heldout)
|
||||
V_hack = postprocess_v_hack(V_raw, V_sv, cfg.v_hack_k, cfg.v_hack_drop_bottom_frac)
|
||||
δS_ref = next(iter(wrappers.values())).δS # V lives in δS's space: match dtype+device
|
||||
V_hack = {n: V.to(δS_ref) for n, V in V_hack.items()} # extract/load give fp32[/cpu]; grads are bf16 cuda
|
||||
logger.info(f"V_hack over {len(V_hack)} modules (k_use={cfg.v_hack_k}); "
|
||||
f"arm={cfg.intervention} -> measure_only={cfg.intervention == 'none'}")
|
||||
|
||||
@@ -311,6 +315,8 @@ def main(cfg: Config):
|
||||
reward=_mean([m["reward"] for m in ms]), loss=_mean([m.get("loss", 0.0) for m in ms]),
|
||||
gn=gn.item(), cin_s=cin_s, cin_t=cin_t, cin=cin, cout=cout, fired=fired,
|
||||
dSh=dSh)
|
||||
if "gen" in ms[0]: # log the first prompt's first student gen (text + flags)
|
||||
row["gen"] = ms[0]["gen"]
|
||||
|
||||
# online refresh: re-extract V against the CURRENT model (under quarantine ablation)
|
||||
if cfg.vhack_refresh_every and V_hack and step % cfg.vhack_refresh_every == 0:
|
||||
@@ -319,6 +325,7 @@ def main(cfg: Config):
|
||||
V_raw, V_sv = extract_v_hack(model, tok, wrappers, default_pairs(),
|
||||
cfg.v_hack_top_k, cfg.tau_axis, cfg.n_heldout)
|
||||
V_hack = postprocess_v_hack(V_raw, V_sv, cfg.v_hack_k, cfg.v_hack_drop_bottom_frac)
|
||||
V_hack = {n: V.to(δS_ref) for n, V in V_hack.items()} # keep V in δS's bf16/cuda space
|
||||
shared = [n for n in V_hack if n in old]
|
||||
ov = _mean([abs((V_hack[n][0] @ old[n][0]).item()) for n in shared]) if shared else 0.0
|
||||
row["basis_overlap"] = ov # GUARD: should sit near 1.0; <~0.2 => refresh rotated off-hack
|
||||
@@ -333,7 +340,8 @@ def main(cfg: Config):
|
||||
rows.append(row)
|
||||
log.write(json.dumps(row) + "\n"); log.flush()
|
||||
pbar.set_postfix(hack_s=f"{row['hack_s']:.2f}", cin_s=f"{cin_s:.2f}", cin_t=f"{cin_t:.2f}",
|
||||
cout=f"{cout:.2f}", δSh=f"{dSh:.2f}", loss=f"{row['loss']:.3f}")
|
||||
cout=f"{cout:.2f}", δSh=f"{dSh:.2f}", loss=f"{row['loss']:.3f}",
|
||||
refresh=False) # don't force a bar redraw every step (pollutes piped logs)
|
||||
if step % 25 == 0:
|
||||
save_ckpt(wrappers, rows, cfg.out_tag or f"_{cfg.intervention}_s{cfg.seed}")
|
||||
|
||||
@@ -370,6 +378,13 @@ def _bluf(cfg, rows, wrappers, model, tok, problems, pad_id):
|
||||
"overlap with V (relu-before-agg), cout→0 residual hack-ward leak (identity under "
|
||||
"one_sided, not efficacy), |δSh| quarantine norm (>0 iff route).")
|
||||
logger.info(f"out: {out}")
|
||||
# Last student generation -- a coherence eyeball before the numbers. SHOULD: real
|
||||
# code/prose for the problem. If it is token salad the policy diverged and the eval
|
||||
# numbers below are moot. (Empty on tiny-random if no student rollout was graded.)
|
||||
last_gen = next((r["gen"] for r in reversed(rows) if r.get("gen")), None)
|
||||
if last_gen:
|
||||
logger.info(f"last student gen (hack={last_gen['hack']} gt={last_gen['gt']} "
|
||||
f"reward={last_gen['reward']:.2f}):\n{last_gen['text']}")
|
||||
logger.info(f"main metric: {cue} hack_s={hack:.3f} solve={solve:.3f} cin_t={cin_t:.3f} cout={cout:.3f} "
|
||||
f"[arm={cfg.intervention},seed={cfg.seed},steps={len(rows)}]")
|
||||
print(f"\n{caption}\n")
|
||||
|
||||
Reference in New Issue
Block a user