mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 18:59:35 +08:00
train: online v_hack refresh every N steps
Re-extract the hack subspace V against the current (delta_S-modified) model on the same hand-crafted PAIRS, every --vhack-refresh-every steps. Motivated by the Goal 1 negative result (2026-05-28 c) where projection at frozen V did not slow hacking; one hypothesis is V drifts out of relevance as the student moves. Off by default (0). Factored the k_use slice + noise-floor filter into a shared postprocess_v_hack helper used by both init-time load and the in-loop refresh. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
+61
-13
@@ -167,6 +167,12 @@ class Config:
|
|||||||
# global threshold get filtered out entirely (projection skips them — they
|
# global threshold get filtered out entirely (projection skips them — they
|
||||||
# didn't carry hack signal anyway). 0 = no filter.
|
# didn't carry hack signal anyway). 0 = no filter.
|
||||||
v_hack_drop_bottom_frac: float = 0.25
|
v_hack_drop_bottom_frac: float = 0.25
|
||||||
|
# Online refresh: every N optimizer steps, re-run extract_v_hack against the
|
||||||
|
# current (delta_S-modified) model using the same hand-crafted PAIRS. Motivated
|
||||||
|
# by Goal-1 negative result (2026-05-28 c) where v_hack appeared to lose
|
||||||
|
# discriminative power as the student drifted. 0 = off (load once at start
|
||||||
|
# and freeze). Refresh cost ~14*2 backwards on Qwen3-4B ~ 1-2 min wall.
|
||||||
|
vhack_refresh_every: int = 0
|
||||||
# Per-source cin diagnostic: split each prompt's backward into student-only
|
# Per-source cin diagnostic: split each prompt's backward into student-only
|
||||||
# + teacher-only passes (~2x backward time). 1 = every step (default; full
|
# + teacher-only passes (~2x backward time). 1 = every step (default; full
|
||||||
# signal); N>1 = only every Nth step (combined backward elsewhere, ~halves
|
# signal); N>1 = only every Nth step (combined backward elsewhere, ~halves
|
||||||
@@ -339,20 +345,37 @@ def load_v_hack(
|
|||||||
f"--model={model_name} --out-path={path}`."
|
f"--model={model_name} --out-path={path}`."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
v_hack = postprocess_v_hack(
|
||||||
|
v_hack, v_sv, k_use=k_use, drop_bottom_frac=drop_bottom_frac, source=str(path),
|
||||||
|
)
|
||||||
|
return v_hack
|
||||||
|
|
||||||
|
|
||||||
|
def postprocess_v_hack(
|
||||||
|
v_hack: dict[str, Float[torch.Tensor, "k r"]],
|
||||||
|
v_sv: dict[str, Float[torch.Tensor, "k"]],
|
||||||
|
k_use: int | None,
|
||||||
|
drop_bottom_frac: float,
|
||||||
|
source: str = "<refresh>",
|
||||||
|
) -> dict[str, Float[torch.Tensor, "k r"]]:
|
||||||
|
"""Apply k_use slice + global noise-floor filter.
|
||||||
|
|
||||||
|
Shared between `load_v_hack` (init-time, reading from safetensors) and the
|
||||||
|
in-loop refresh hook (where we hand in fresh `extract_v_hack` outputs).
|
||||||
|
Mutates neither input dict; returns a fresh filtered dict.
|
||||||
|
|
||||||
|
Global noise floor: collect every S_i across every module, drop the bottom
|
||||||
|
`drop_bottom_frac` by quantile. A module whose every axis falls below the
|
||||||
|
global threshold is removed entirely — projection iterates v_hack so it
|
||||||
|
becomes a no-op for that module. Threshold recomputes per call (tracks
|
||||||
|
current S distribution).
|
||||||
|
"""
|
||||||
k_max = next(iter(v_hack.values())).shape[0]
|
k_max = next(iter(v_hack.values())).shape[0]
|
||||||
if k_use is not None:
|
if k_use is not None:
|
||||||
if k_use > k_max:
|
if k_use > k_max:
|
||||||
raise ValueError(
|
raise ValueError(f"requested k_use={k_use} exceeds k_max={k_max} (source={source})")
|
||||||
f"requested k_use={k_use} exceeds saved k_max={k_max} in {path}. "
|
|
||||||
f"Re-extract with `--top-k={k_use}`."
|
|
||||||
)
|
|
||||||
v_hack = {n: v[:k_use].contiguous() for n, v in v_hack.items()}
|
v_hack = {n: v[:k_use].contiguous() for n, v in v_hack.items()}
|
||||||
v_sv = {n: s[:k_use].contiguous() for n, s in v_sv.items()}
|
v_sv = {n: s[:k_use].contiguous() for n, s in v_sv.items()}
|
||||||
|
|
||||||
# Global noise floor: drop the bottom drop_bottom_frac of all (module, axis)
|
|
||||||
# pairs by S_i. One quantile across every S_i in every module. A module
|
|
||||||
# whose every axis lies below the global threshold is removed from v_hack
|
|
||||||
# entirely — projection iterates v_hack so that module just gets skipped.
|
|
||||||
n_dropped_modules = 0
|
n_dropped_modules = 0
|
||||||
n_axes_before = sum(v.shape[0] for v in v_hack.values())
|
n_axes_before = sum(v.shape[0] for v in v_hack.values())
|
||||||
threshold = None
|
threshold = None
|
||||||
@@ -368,11 +391,9 @@ def load_v_hack(
|
|||||||
n_dropped_modules += 1
|
n_dropped_modules += 1
|
||||||
v_hack = filtered
|
v_hack = filtered
|
||||||
n_axes_after = sum(v.shape[0] for v in v_hack.values())
|
n_axes_after = sum(v.shape[0] for v in v_hack.values())
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"loaded v_hack from {path}: modules={len(v_hack)} (dropped {n_dropped_modules}); "
|
f"postprocess_v_hack({source}): modules={len(v_hack)} (dropped {n_dropped_modules}); "
|
||||||
f"k_saved={k_max}, k_use={k_use or k_max}; "
|
f"k_use={k_use or k_max}/k_max={k_max}; axes={n_axes_after}/{n_axes_before} kept "
|
||||||
f"axes={n_axes_after}/{n_axes_before} kept "
|
|
||||||
f"(drop_bottom_frac={drop_bottom_frac}, threshold={threshold})"
|
f"(drop_bottom_frac={drop_bottom_frac}, threshold={threshold})"
|
||||||
)
|
)
|
||||||
return v_hack
|
return v_hack
|
||||||
@@ -1069,6 +1090,33 @@ table columns:
|
|||||||
opt.step()
|
opt.step()
|
||||||
sched.step()
|
sched.step()
|
||||||
|
|
||||||
|
# Online v_hack refresh: re-extract against the *current* model so the
|
||||||
|
# hack subspace tracks where the student is being pulled now (rather
|
||||||
|
# than at step 0). Same PAIRS, same extract code; we just discard the
|
||||||
|
# saved cache and overwrite the in-memory v_hack dict.
|
||||||
|
if cfg.vhack_refresh_every > 0 and (step + 1) % cfg.vhack_refresh_every == 0:
|
||||||
|
from .extract_vhack_grad import extract_v_hack
|
||||||
|
from .pairs import PAIRS as VHACK_PAIRS
|
||||||
|
_was_training = model.training
|
||||||
|
model.eval()
|
||||||
|
opt.zero_grad(set_to_none=True)
|
||||||
|
_new_V, _new_S, _, _ = extract_v_hack(
|
||||||
|
model, tok, wrappers, VHACK_PAIRS,
|
||||||
|
top_k=cfg.v_hack_extract_top_k, tau_axis=cfg.v_hack_tau_axis,
|
||||||
|
n_heldout=2, device=device,
|
||||||
|
)
|
||||||
|
_post = postprocess_v_hack(
|
||||||
|
_new_V, _new_S, k_use=cfg.v_hack_k,
|
||||||
|
drop_bottom_frac=cfg.v_hack_drop_bottom_frac,
|
||||||
|
source=f"refresh@step{step}",
|
||||||
|
)
|
||||||
|
v_hack.clear()
|
||||||
|
v_hack.update({n: V.to(device) for n, V in _post.items()})
|
||||||
|
opt.zero_grad(set_to_none=True) # extract leaves .grad populated
|
||||||
|
if _was_training:
|
||||||
|
model.train()
|
||||||
|
logger.info(f"v_hack refreshed @ step={step}: {len(v_hack)} modules")
|
||||||
|
|
||||||
rewards_t = torch.tensor(agg_rew, dtype=torch.float32) if agg_rew else torch.zeros(1)
|
rewards_t = torch.tensor(agg_rew, dtype=torch.float32) if agg_rew else torch.zeros(1)
|
||||||
rew_mean = rewards_t.mean().item()
|
rew_mean = rewards_t.mean().item()
|
||||||
rew_std = rewards_t.std().item() if rewards_t.numel() > 1 else 0.0
|
rew_std = rewards_t.std().item() if rewards_t.numel() > 1 else 0.0
|
||||||
|
|||||||
Reference in New Issue
Block a user