diff --git a/src/projected_grpo/proj.py b/src/projected_grpo/proj.py index a417822..75315d1 100644 --- a/src/projected_grpo/proj.py +++ b/src/projected_grpo/proj.py @@ -25,6 +25,7 @@ def project_delta_S_grad( wrappers: dict, v_hack: dict[str, torch.Tensor], preserve_magnitude: bool, + measure_only: bool = False, ) -> dict[str, float]: """Per-module one-sided removal of v_hack-aligned component from delta_S.grad. @@ -32,6 +33,9 @@ def project_delta_S_grad( If cos(g, v) > 0: g' = g - v (remove projection onto v). Optionally rescale g' to ||g|| to preserve update magnitude. Else leave g untouched. + If `measure_only`: same cosine math, but the gradient is NOT mutated. + Used by vanilla arm to report cos_in trajectory as a diagnostic. + Returns aggregate diagnostics: mean_cos_in, mean_cos_out, frac_fired. """ cos_in_list, cos_out_list, n_fired = [], [], 0 @@ -53,7 +57,8 @@ def project_delta_S_grad( g_proj = g_proj * (gn / gp_n) cos_out = (g_proj @ v) / g_proj.norm().clamp_min(1e-12) cos_out_list.append(cos_out.item()) - info["delta_S"].grad = g_proj + if not measure_only: + info["delta_S"].grad = g_proj n_fired += 1 else: cos_out_list.append(cos_in.item()) diff --git a/src/projected_grpo/train.py b/src/projected_grpo/train.py index 8ea178e..7e8c888 100644 --- a/src/projected_grpo/train.py +++ b/src/projected_grpo/train.py @@ -317,13 +317,15 @@ def main(cfg: Config) -> int: delta_params = [info["delta_S"] for info in wrappers.values()] logger.info(f"trainable delta_S: {sum(p.numel() for p in delta_params):,}") - # v_hack only needed for projected arm. Vanilla H4 sanity runs do not - # require a precomputed v_hack and should not be blocked by missing one. - if cfg.arm == "projected": + # v_hack: loaded for both arms when the file is present, so vanilla also + # reports cos_in as a diagnostic (no projection applied). If not present + # and arm=vanilla, skip silently — H4 sanity runs without v_hack remain valid. + v_hack = None + if cfg.v_hack_path.exists(): v_hack_cpu = load_v_hack(cfg.v_hack_path, model_name, wrappers) v_hack = {name: v.to(device) for name, v in v_hack_cpu.items()} - else: - v_hack = None + elif cfg.arm == "projected": + raise FileNotFoundError(f"projected arm requires v_hack at {cfg.v_hack_path}") opt = torch.optim.AdamW( delta_params, lr=cfg.lr, weight_decay=cfg.weight_decay, betas=(cfg.adam_beta1, cfg.adam_beta2), @@ -535,9 +537,12 @@ def main(cfg: Config) -> int: agg_loss += loss.item() t_fb += time.perf_counter() - _tfb - # One projection on accumulated grads (projected arm only). - if cfg.arm == "projected": - diag = project_delta_S_grad(wrappers, v_hack, cfg.preserve_magnitude) + # Diagnostic cos_in for both arms; projection only mutates grad if arm=projected. + if v_hack is not None: + diag = project_delta_S_grad( + wrappers, v_hack, cfg.preserve_magnitude, + measure_only=(cfg.arm != "projected"), + ) else: diag = {"mean_cos_in": float("nan"), "mean_cos_out": float("nan"), "frac_fired": float("nan")}