diff --git a/src/projected_grpo/train.py b/src/projected_grpo/train.py index 2c822c9..7c2b3ab 100644 --- a/src/projected_grpo/train.py +++ b/src/projected_grpo/train.py @@ -1430,11 +1430,20 @@ def main(cfg: Config) -> int: logger.disable("projected_grpo.extract_vhack_grad") logger.disable("__main__") try: - _new_V, _new_S, _, _ = extract_v_hack( - model, tok, wrappers, VHACK_PAIRS, - top_k=cfg.v_hack_extract_top_k, tau_axis=cfg.v_hack_tau_axis, - n_heldout=2, device=device, - ) + # Extract with the quarantine ablated (delta_S_hack=0). For route, + # once the hack capability has been routed into delta_S_hack, the + # main-knob gradient on the pairs no longer carries the hack + # direction -- so re-extracting through the live quarantine rotates + # v_hack off-hack and cin_t collapses at the refresh step. Ablating + # sends the hack back through the observable main path so D captures + # it, matching the delta_S_hack=0 state the build extraction saw. + # No-op for erase (delta_S_hack is never trained, stays 0). + with ablate_quarantine(wrappers): + _new_V, _new_S, _, _ = extract_v_hack( + model, tok, wrappers, VHACK_PAIRS, + top_k=cfg.v_hack_extract_top_k, tau_axis=cfg.v_hack_tau_axis, + n_heldout=2, device=device, + ) _post = postprocess_v_hack( _new_V, _new_S, k_use=cfg.v_hack_k, drop_bottom_frac=cfg.v_hack_drop_bottom_frac,