mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 18:04:59 +08:00
Include tau_axis in v_hack cache filename + plumb through Config
tau_axis is baked into the saved V at extract time (extract_vhack_grad zeros rows where S_i/S_0 < tau_axis before the SVD output is saved), so the cached file's content depends on it. The previous filename was keyed only on top_k, meaning a change to tau_axis would silently serve a stale cache.

Add Config.v_hack_tau_axis (default 0.0) and tag it into the filename only when nonzero, so that existing v_hack_Qwen3-4B_k12.safetensors files remain reachable under the default config.

A future cache-key footgun (pairs.py changes) is flagged in a comment; add a pairs hash when pair-set ablations begin.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -165,6 +165,7 @@ class Config:
|
||||
v_hack_path: Path | None = None
|
||||
v_hack_extract_top_k: int = 12 # max k to save at extract; n_train_pairs caps it lower
|
||||
v_hack_k: int = 5 # load-time slice; k=1 = mean-diff, k=k_max = full
|
||||
v_hack_tau_axis: float = 0.0 # extract-time: zero axes where S_i/S_0 < tau_axis
|
||||
out_tag: str = "" # suffix for saved artifact, e.g. "_seed41"
|
||||
# Mixed-pool GRPO: per-prompt rollout pool = G_s live student + G_t cached
|
||||
# teacher rollouts. Teacher pool is a dir of prompt_NNNN.jsonl.gz produced by
|
||||
@@ -364,8 +365,12 @@ def main(cfg: Config) -> int:
|
||||
# Slug: works for HF names ("Qwen/Qwen3-4B" -> "Qwen3-4B") and local paths
|
||||
# ("out/baked/qwen3_4b_rh25" -> "qwen3_4b_rh25").
|
||||
model_slug = model_name.rstrip("/").split("/")[-1]
|
||||
# Filename encodes top_k AND tau_axis because both are baked into the saved
|
||||
# V (extract zeros rows where S_i/S_0 < tau_axis before saving). If a future
|
||||
# ablation varies pairs.py, add a pairs hash here too.
|
||||
tau_tag = f"_tau{cfg.v_hack_tau_axis:g}" if cfg.v_hack_tau_axis > 0 else ""
|
||||
if cfg.v_hack_path is None:
|
||||
v_hack_path = OUT_DIR / f"v_hack_{model_slug}_k{cfg.v_hack_extract_top_k}.safetensors"
|
||||
v_hack_path = OUT_DIR / f"v_hack_{model_slug}_k{cfg.v_hack_extract_top_k}{tau_tag}.safetensors"
|
||||
else:
|
||||
v_hack_path = cfg.v_hack_path
|
||||
if not v_hack_path.exists():
|
||||
@@ -376,14 +381,14 @@ def main(cfg: Config) -> int:
|
||||
model.eval() # match standalone extract: deterministic backward, no dropout
|
||||
v_hack_cpu_dict, raw_grads, _diag = extract_v_hack(
|
||||
model, tok, wrappers, VHACK_PAIRS,
|
||||
top_k=cfg.v_hack_extract_top_k, tau_axis=0.0,
|
||||
top_k=cfg.v_hack_extract_top_k, tau_axis=cfg.v_hack_tau_axis,
|
||||
n_heldout=2, device=device,
|
||||
)
|
||||
OUT_DIR.mkdir(exist_ok=True)
|
||||
save_file(v_hack_cpu_dict, str(v_hack_path),
|
||||
metadata={"model": model_name, "dtype": "bf16",
|
||||
"top_k": str(min(cfg.v_hack_extract_top_k, len(VHACK_PAIRS) - 2)),
|
||||
"tau_axis": "0.0", "schema": "v2_with_sv"})
|
||||
"tau_axis": str(cfg.v_hack_tau_axis), "schema": "v2_with_sv"})
|
||||
# extract zeros grads at exit; opt is built below so no opt-state taint.
|
||||
model.train() # restore train mode; eval was set only for the extract pass
|
||||
v_hack_cpu = load_v_hack(v_hack_path, model_name, wrappers, k_use=cfg.v_hack_k)
|
||||
|
||||
Reference in New Issue
Block a user