Global noise-floor filter on v_hack at load time

drop_bottom_frac (Config default 0.25; the load_v_hack argument itself
defaults to 0.0 = no filter): collect every S_i across every module,
take the global quantile, and drop any (module, axis) whose S_i is below
it. Modules whose every axis falls below the global threshold are removed
from the returned dict; projection skips any module missing from v_hack
(proj.py: name not in v_hack -> continue).

One physically meaningful threshold, applied once, at load. Global
rather than per-module is intentional: per-module would protect the
weakest modules from filtering (they always have a top axis), defeating
the noise-floor goal. An axis whose S_i is below the global quantile is
dropped even when it is the strongest axis of its own module; a module
that is weak everywhere is exactly what the filter should remove.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
wassname
2026-05-27 09:37:49 +00:00
parent 9ba7b818a9
commit 477380603f
2 changed files with 50 additions and 11 deletions
+2
View File
@@ -81,6 +81,8 @@ def project_delta_S_grad(
g = info["delta_S"].grad
if g is None:
continue
if name not in v_hack: # module dropped by global noise-floor filter
continue
V = v_hack[name].to(g.device, dtype=g.dtype) # [k, r]
gn = g.norm()
if gn < 1e-12:
+48 -11
View File
@@ -166,6 +166,11 @@ class Config:
v_hack_extract_top_k: int = 12 # max k to save at extract; n_train_pairs caps it lower
v_hack_k: int = 5 # load-time slice; k=1 = mean-diff, k=k_max = full
v_hack_tau_axis: float = 0.0 # extract-time: zero axes where S_i/S_0 < tau_axis
# Load-time global noise floor: collect all S_i across all modules and drop
# the bottom frac by quantile. Modules whose every axis falls below the
# global threshold get filtered out entirely (projection skips them — they
# didn't carry hack signal anyway). 0 = no filter.
v_hack_drop_bottom_frac: float = 0.25
# Per-source cin diagnostic: split each prompt's backward into student-only
# + teacher-only passes (~2x backward time). 1 = every step (default; full
# signal); N>1 = only every Nth step (combined backward elsewhere, ~halves
@@ -228,18 +233,23 @@ def load_problems(n: int) -> list[dict]:
def load_v_hack(
path: Path, model_name: str, wrappers: dict, k_use: int | None = None,
path: Path, model_name: str, wrappers: dict,
k_use: int | None = None, drop_bottom_frac: float = 0.0,
) -> dict[str, torch.Tensor]:
"""Load v_hack (top-k directions) for this wrapped model.
File schema (v2): bare `{name}` keys hold V[k_max, r]; `_sv/{name}` keys hold
S[k_max] (read but not returned — no caller uses them yet). v_hack is
model-specific because module names and per-module SVD ranks depend on the
exact checkpoint; a smoke (Qwen3.5-0.8B) v_hack must not be reused for a
full (Qwen3-4B) run.
S[k_max]. v_hack is model-specific because module names and per-module SVD
ranks depend on the exact checkpoint; a smoke (Qwen3.5-0.8B) v_hack must
not be reused for a full (Qwen3-4B) run.
If `k_use` is given, slices V to top-k_use rows. Errors if k_use > k_max
saved (re-extract with a higher top_k).
If `k_use` is given, slices V (and S) to top-k_use rows. Errors if
k_use > k_max saved (re-extract with a higher top_k).
If `drop_bottom_frac > 0`, collects every S_i across every module and drops
the bottom-fraction by global quantile. Modules whose every axis is below
the global threshold get filtered out of the returned dict (projection on
those modules becomes a no-op — they didn't carry hack signal anywhere).
"""
with safe_open(str(path), framework="pt", device="cpu") as f:
meta = f.metadata() or {}
@@ -258,9 +268,8 @@ def load_v_hack(
f"v_hack dtype/SVD-basis mismatch: {path} was extracted with dtype={saved_dtype}; "
"train.py loads models in bf16. Re-extract with `--dtype=bf16`."
)
# Read only V keys (bare module names); _sv/{name} keys are saved
# alongside but no runtime path consumes them currently.
v_hack = {k: f.get_tensor(k) for k in f.keys() if not k.startswith("_sv/")}
v_sv = {k[len("_sv/"):]: f.get_tensor(k) for k in f.keys() if k.startswith("_sv/")}
wrapper_keys = set(wrappers)
vhack_keys = set(v_hack)
@@ -290,8 +299,33 @@ def load_v_hack(
f"Re-extract with `--top-k={k_use}`."
)
v_hack = {n: v[:k_use].contiguous() for n, v in v_hack.items()}
v_sv = {n: s[:k_use].contiguous() for n, s in v_sv.items()}
# Global noise floor: drop the bottom drop_bottom_frac of all (module, axis)
# pairs by S_i. One quantile across every S_i in every module. A module
# whose every axis lies below the global threshold is removed from v_hack
# entirely — projection iterates v_hack so that module just gets skipped.
n_dropped_modules = 0
n_axes_before = sum(v.shape[0] for v in v_hack.values())
threshold = None
if drop_bottom_frac > 0 and v_sv:
all_S = torch.cat([v_sv[n].float() for n in v_hack])
threshold = torch.quantile(all_S, drop_bottom_frac).item()
filtered: dict[str, torch.Tensor] = {}
for name, V in v_hack.items():
keep = v_sv[name].float() >= threshold
if keep.any():
filtered[name] = V[keep].contiguous()
else:
n_dropped_modules += 1
v_hack = filtered
n_axes_after = sum(v.shape[0] for v in v_hack.values())
logger.info(
f"loaded v_hack from {path}: modules={len(v_hack)}; k_saved={k_max}, k_use={k_use or k_max}"
f"loaded v_hack from {path}: modules={len(v_hack)} (dropped {n_dropped_modules}); "
f"k_saved={k_max}, k_use={k_use or k_max}; "
f"axes={n_axes_after}/{n_axes_before} kept "
f"(drop_bottom_frac={drop_bottom_frac}, threshold={threshold})"
)
return v_hack
@@ -396,7 +430,10 @@ def main(cfg: Config) -> int:
"tau_axis": str(cfg.v_hack_tau_axis), "schema": "v2_with_sv"})
# extract zeros grads at exit; opt is built below so no opt-state taint.
model.train() # restore train mode; eval was set only for the extract pass
v_hack_cpu = load_v_hack(v_hack_path, model_name, wrappers, k_use=cfg.v_hack_k)
v_hack_cpu = load_v_hack(
v_hack_path, model_name, wrappers,
k_use=cfg.v_hack_k, drop_bottom_frac=cfg.v_hack_drop_bottom_frac,
)
v_hack = {name: v.to(device) for name, v in v_hack_cpu.items()}
# Teacher pool: pre-generated rollouts on disk keyed by problem_id. Each step's
# G_t teacher rollouts come from a uniform random sample of that prompt's cache,