train: online v_hack refresh every N steps

Re-extract the hack subspace V against the current (delta_S-modified) model on
the same hand-crafted PAIRS, every --vhack-refresh-every steps. Motivated by
the Goal 1 negative result (2026-05-28 c) where projection at frozen V did not
slow hacking; one hypothesis is V drifts out of relevance as the student moves.

Off by default (0). Factored the k_use slice + noise-floor filter into a shared
postprocess_v_hack helper used by both init-time load and the in-loop refresh.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
wassname
2026-05-28 09:42:17 +00:00
parent 1e3d39e318
commit 16e2c37de6
+61 -13
View File
@@ -167,6 +167,12 @@ class Config:
# global threshold get filtered out entirely (projection skips them — they # global threshold get filtered out entirely (projection skips them — they
# didn't carry hack signal anyway). 0 = no filter. # didn't carry hack signal anyway). 0 = no filter.
v_hack_drop_bottom_frac: float = 0.25 v_hack_drop_bottom_frac: float = 0.25
# Online refresh: every N optimizer steps, re-run extract_v_hack against the
# current (delta_S-modified) model using the same hand-crafted PAIRS. Motivated
# by the Goal-1 negative result (2026-05-28 c), where projection at a frozen
# v_hack did not slow hacking; one hypothesis is that v_hack drifts out of
# relevance as the student moves. 0 = off (load once at start and freeze).
# Refresh cost ~14*2 backwards on Qwen3-4B ~ 1-2 min wall.
vhack_refresh_every: int = 0
# Per-source cin diagnostic: split each prompt's backward into student-only # Per-source cin diagnostic: split each prompt's backward into student-only
# + teacher-only passes (~2x backward time). 1 = every step (default; full # + teacher-only passes (~2x backward time). 1 = every step (default; full
# signal); N>1 = only every Nth step (combined backward elsewhere, ~halves # signal); N>1 = only every Nth step (combined backward elsewhere, ~halves
@@ -339,20 +345,37 @@ def load_v_hack(
f"--model={model_name} --out-path={path}`." f"--model={model_name} --out-path={path}`."
) )
v_hack = postprocess_v_hack(
v_hack, v_sv, k_use=k_use, drop_bottom_frac=drop_bottom_frac, source=str(path),
)
return v_hack
def postprocess_v_hack(
v_hack: dict[str, Float[torch.Tensor, "k r"]],
v_sv: dict[str, Float[torch.Tensor, "k"]],
k_use: int | None,
drop_bottom_frac: float,
source: str = "<refresh>",
) -> dict[str, Float[torch.Tensor, "k r"]]:
"""Apply k_use slice + global noise-floor filter.
Shared between `load_v_hack` (init-time, reading from safetensors) and the
in-loop refresh hook (where we hand in fresh `extract_v_hack` outputs).
Mutates neither input dict; returns a fresh filtered dict.
Global noise floor: collect every S_i across every module, drop the bottom
`drop_bottom_frac` by quantile. A module whose every axis falls below the
global threshold is removed entirely — projection iterates v_hack so it
becomes a no-op for that module. Threshold recomputes per call (tracks
current S distribution).
"""
k_max = next(iter(v_hack.values())).shape[0] k_max = next(iter(v_hack.values())).shape[0]
if k_use is not None: if k_use is not None:
if k_use > k_max: if k_use > k_max:
raise ValueError( raise ValueError(f"requested k_use={k_use} exceeds k_max={k_max} (source={source})")
f"requested k_use={k_use} exceeds saved k_max={k_max} in {path}. "
f"Re-extract with `--top-k={k_use}`."
)
v_hack = {n: v[:k_use].contiguous() for n, v in v_hack.items()} v_hack = {n: v[:k_use].contiguous() for n, v in v_hack.items()}
v_sv = {n: s[:k_use].contiguous() for n, s in v_sv.items()} v_sv = {n: s[:k_use].contiguous() for n, s in v_sv.items()}
# Global noise floor: drop the bottom drop_bottom_frac of all (module, axis)
# pairs by S_i. One quantile across every S_i in every module. A module
# whose every axis lies below the global threshold is removed from v_hack
# entirely — projection iterates v_hack so that module just gets skipped.
n_dropped_modules = 0 n_dropped_modules = 0
n_axes_before = sum(v.shape[0] for v in v_hack.values()) n_axes_before = sum(v.shape[0] for v in v_hack.values())
threshold = None threshold = None
@@ -368,11 +391,9 @@ def load_v_hack(
n_dropped_modules += 1 n_dropped_modules += 1
v_hack = filtered v_hack = filtered
n_axes_after = sum(v.shape[0] for v in v_hack.values()) n_axes_after = sum(v.shape[0] for v in v_hack.values())
logger.info( logger.info(
f"loaded v_hack from {path}: modules={len(v_hack)} (dropped {n_dropped_modules}); " f"postprocess_v_hack({source}): modules={len(v_hack)} (dropped {n_dropped_modules}); "
f"k_saved={k_max}, k_use={k_use or k_max}; " f"k_use={k_use or k_max}/k_max={k_max}; axes={n_axes_after}/{n_axes_before} kept "
f"axes={n_axes_after}/{n_axes_before} kept "
f"(drop_bottom_frac={drop_bottom_frac}, threshold={threshold})" f"(drop_bottom_frac={drop_bottom_frac}, threshold={threshold})"
) )
return v_hack return v_hack
@@ -1069,6 +1090,33 @@ table columns:
opt.step() opt.step()
sched.step() sched.step()
# Online v_hack refresh: re-extract against the *current* model so the
# hack subspace tracks where the student is being pulled now (rather
# than at step 0). Same PAIRS, same extract code; the saved cache is not
# consulted here — we only overwrite the in-memory v_hack dict in place.
if cfg.vhack_refresh_every > 0 and (step + 1) % cfg.vhack_refresh_every == 0:
from .extract_vhack_grad import extract_v_hack
from .pairs import PAIRS as VHACK_PAIRS
_was_training = model.training
model.eval()
opt.zero_grad(set_to_none=True)
_new_V, _new_S, _, _ = extract_v_hack(
model, tok, wrappers, VHACK_PAIRS,
top_k=cfg.v_hack_extract_top_k, tau_axis=cfg.v_hack_tau_axis,
n_heldout=2, device=device,
)
_post = postprocess_v_hack(
_new_V, _new_S, k_use=cfg.v_hack_k,
drop_bottom_frac=cfg.v_hack_drop_bottom_frac,
source=f"refresh@step{step}",
)
v_hack.clear()
v_hack.update({n: V.to(device) for n, V in _post.items()})
opt.zero_grad(set_to_none=True) # extract leaves .grad populated
if _was_training:
model.train()
logger.info(f"v_hack refreshed @ step={step}: {len(v_hack)} modules")
rewards_t = torch.tensor(agg_rew, dtype=torch.float32) if agg_rew else torch.zeros(1) rewards_t = torch.tensor(agg_rew, dtype=torch.float32) if agg_rew else torch.zeros(1)
rew_mean = rewards_t.mean().item() rew_mean = rewards_t.mean().item()
rew_std = rewards_t.std().item() if rewards_t.numel() > 1 else 0.0 rew_std = rewards_t.std().item() if rewards_t.numel() > 1 else 0.0