Include tau_axis in v_hack cache filename + plumb through Config

tau_axis is baked into the saved V at extract time (extract_vhack_grad
zeros rows where S_i/S_0 < tau_axis before SVD output is saved), so the
cached file content depends on it. The previous filename keyed only on
top_k, meaning a change to tau_axis would silently serve a stale cache.

Add Config.v_hack_tau_axis (default 0.0) and tag it into the filename
only when it is positive (> 0), so existing
v_hack_Qwen3-4B_k12.safetensors files remain reachable under the
default config.

A remaining cache-key footgun is flagged in a comment: changes to
pairs.py are not reflected in the cache filename. Add a pairs hash to
the filename when pair-set ablations begin.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
wassname
2026-05-27 09:11:41 +00:00
parent 5bf2180248
commit e0f33045a9
+8 -3
View File
@@ -165,6 +165,7 @@ class Config:
v_hack_path: Path | None = None
v_hack_extract_top_k: int = 12 # max k to save at extract; n_train_pairs caps it lower
v_hack_k: int = 5 # load-time slice; k=1 = mean-diff, k=k_max = full
v_hack_tau_axis: float = 0.0 # extract-time: zero axes where S_i/S_0 < tau_axis
out_tag: str = "" # suffix for saved artifact, e.g. "_seed41"
# Mixed-pool GRPO: per-prompt rollout pool = G_s live student + G_t cached
# teacher rollouts. Teacher pool is a dir of prompt_NNNN.jsonl.gz produced by
@@ -364,8 +365,12 @@ def main(cfg: Config) -> int:
# Slug: works for HF names ("Qwen/Qwen3-4B" -> "Qwen3-4B") and local paths
# ("out/baked/qwen3_4b_rh25" -> "qwen3_4b_rh25").
model_slug = model_name.rstrip("/").split("/")[-1]
# Filename encodes top_k AND tau_axis because both are baked into the saved
# V (extract zeros rows where S_i/S_0 < tau_axis before saving). If a future
# ablation varies pairs.py, add a pairs hash here too.
tau_tag = f"_tau{cfg.v_hack_tau_axis:g}" if cfg.v_hack_tau_axis > 0 else ""
if cfg.v_hack_path is None:
v_hack_path = OUT_DIR / f"v_hack_{model_slug}_k{cfg.v_hack_extract_top_k}.safetensors"
v_hack_path = OUT_DIR / f"v_hack_{model_slug}_k{cfg.v_hack_extract_top_k}{tau_tag}.safetensors"
else:
v_hack_path = cfg.v_hack_path
if not v_hack_path.exists():
@@ -376,14 +381,14 @@ def main(cfg: Config) -> int:
model.eval() # match standalone extract: deterministic backward, no dropout
v_hack_cpu_dict, raw_grads, _diag = extract_v_hack(
model, tok, wrappers, VHACK_PAIRS,
top_k=cfg.v_hack_extract_top_k, tau_axis=0.0,
top_k=cfg.v_hack_extract_top_k, tau_axis=cfg.v_hack_tau_axis,
n_heldout=2, device=device,
)
OUT_DIR.mkdir(exist_ok=True)
save_file(v_hack_cpu_dict, str(v_hack_path),
metadata={"model": model_name, "dtype": "bf16",
"top_k": str(min(cfg.v_hack_extract_top_k, len(VHACK_PAIRS) - 2)),
"tau_axis": "0.0", "schema": "v2_with_sv"})
"tau_axis": str(cfg.v_hack_tau_axis), "schema": "v2_with_sv"})
# extract zeros grads at exit; opt is built below so no opt-state taint.
model.train() # restore train mode; eval was set only for the extract pass
v_hack_cpu = load_v_hack(v_hack_path, model_name, wrappers, k_use=cfg.v_hack_k)