mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 18:04:59 +08:00
ready
This commit is contained in:
@@ -4,6 +4,35 @@ Append-only. New entries at the top, date-stamped. Never edit old entries.
|
||||
|
||||
# 2026-05-30
|
||||
|
||||
## 96GB readiness review fixes
|
||||
|
||||
Fresh subagent review found a real silent-failure risk: `v_hack` is not just
|
||||
model-specific, it is also SVD-basis-specific. The old extractor loaded fp32
|
||||
while `train.py` loaded bf16, so keys/ranks could match while the basis differed.
|
||||
Fix: `extract_vhack_grad.py`, `verify_vhack_heldout.py`, and `train.py` now all
|
||||
use bf16 by default; `v_hack` artifacts save `{model, dtype, v_hack}` metadata;
|
||||
`train.py` refuses legacy artifacts and checks exact module keys and per-module
|
||||
rank before first generation.
|
||||
|
||||
Also removed a bad smoke convenience: zero-spread reward batches no longer get
|
||||
random advantages. Dr.GRPO now correctly gives zero advantage when all group
|
||||
rewards match, so logs cannot look healthy while training on reward-unrelated
|
||||
noise.
|
||||
|
||||
Validated on the 24GB box:
|
||||
|
||||
- `just extract-vhack-smoke` via pueue task 73: bf16, 186 modules, 148,032
|
||||
delta_S scalars, zero-norm=0.
|
||||
- `just verify-vhack-smoke` via pueue task 74: `frac>0=0.952`, `mean=+0.355`,
|
||||
`median=+0.363`, target pass.
|
||||
- one-step canonical train probe via pueue task 75: loaded `out/v_hack_smoke.pt`
|
||||
with key/rank match OK, completed without legacy artifact. Reward spread was
|
||||
false and loss/cos/fired were zero, as expected after removing random advantages.
|
||||
|
||||
For the 96GB machine, do not start `queue-full` blindly. First run one sequential
|
||||
gate: `pueue add --immediate --follow -w "$PWD" -o 9 -l "why: gated full probe; resolve: extract+heldout pass, vanilla hacks, projected fires" -- just probe-full-seed 41`.
|
||||
Only queue 3 seeds after the vanilla probe has nontrivial hack rate.
|
||||
|
||||
## Mechanism end-to-end verified on Qwen3.5-0.8B; H4 falsified at this scale
|
||||
|
||||
Closed the smoke loop: AntiPaSTO identity (bf16, max_abs_diff=0) -> v_hack
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -20,20 +20,20 @@ fast-dev-run *ARGS:
|
||||
# lite = Qwen2.5-Coder-1.5B 100 steps, fits ~40GB.
|
||||
# full = Qwen2.5-Coder-7B 200 steps, needs >=80GB. Publication-grade.
|
||||
smoke *ARGS:
|
||||
{{ TRAIN }} --preset=smoke --arm=projected {{ ARGS }}
|
||||
{{ TRAIN }} --preset=smoke --arm=projected --v-hack-path=out/v_hack_smoke.pt {{ ARGS }}
|
||||
|
||||
smoke-vanilla *ARGS:
|
||||
{{ TRAIN }} --preset=smoke --arm=vanilla {{ ARGS }}
|
||||
{{ TRAIN }} --preset=smoke --arm=vanilla --v-hack-path=out/v_hack_smoke.pt {{ ARGS }}
|
||||
|
||||
smoke-both:
|
||||
{{ TRAIN }} --preset=smoke --arm=vanilla
|
||||
{{ TRAIN }} --preset=smoke --arm=projected
|
||||
{{ TRAIN }} --preset=smoke --arm=vanilla --v-hack-path=out/v_hack_smoke.pt
|
||||
{{ TRAIN }} --preset=smoke --arm=projected --v-hack-path=out/v_hack_smoke.pt
|
||||
|
||||
lite *ARGS:
|
||||
{{ TRAIN }} --preset=lite --arm=projected {{ ARGS }}
|
||||
{{ TRAIN }} --preset=lite --arm=projected --v-hack-path=out/v_hack_lite.pt {{ ARGS }}
|
||||
|
||||
full *ARGS:
|
||||
{{ TRAIN }} --preset=full --arm=projected {{ ARGS }}
|
||||
{{ TRAIN }} --preset=full --arm=projected --v-hack-path=out/v_hack_full.pt {{ ARGS }}
|
||||
|
||||
# Sync the rl-rewardhacking external repo (Nanda's verl wrapper).
|
||||
sync-external:
|
||||
@@ -45,43 +45,93 @@ download-model:
|
||||
uv run python -c "from huggingface_hub import snapshot_download; \
|
||||
snapshot_download('Qwen/Qwen2.5-1.5B', allow_patterns=['*.json','*.txt','tokenizer*','*.safetensors'])"
|
||||
|
||||
# Queue all sweep arms via pueue. Comment out arms that are done.
|
||||
# Run priorities: vanilla baseline first (we need its numbers to compare).
|
||||
queue:
|
||||
extract-vhack-smoke:
|
||||
uv run python -m projected_grpo.extract_vhack_grad \
|
||||
--model=Qwen/Qwen3.5-0.8B \
|
||||
--dtype=bf16 \
|
||||
--out-path=out/v_hack_smoke.pt \
|
||||
--train-grads-path=out/vhack_grads_train_smoke.pt
|
||||
|
||||
extract-vhack-lite:
|
||||
uv run python -m projected_grpo.extract_vhack_grad \
|
||||
--model=Qwen/Qwen2.5-Coder-1.5B \
|
||||
--dtype=bf16 \
|
||||
--out-path=out/v_hack_lite.pt \
|
||||
--train-grads-path=out/vhack_grads_train_lite.pt
|
||||
|
||||
extract-vhack-full:
|
||||
uv run python -m projected_grpo.extract_vhack_grad \
|
||||
--model=Qwen/Qwen2.5-Coder-7B \
|
||||
--dtype=bf16 \
|
||||
--out-path=out/v_hack_full.pt \
|
||||
--train-grads-path=out/vhack_grads_train_full.pt
|
||||
|
||||
verify-vhack-smoke:
|
||||
uv run python -m projected_grpo.verify_vhack_heldout \
|
||||
--model=Qwen/Qwen3.5-0.8B \
|
||||
--dtype=bf16 \
|
||||
--v-hack-path=out/v_hack_smoke.pt \
|
||||
--out-path=out/vhack_heldout_cos_smoke.pt
|
||||
|
||||
verify-vhack-lite:
|
||||
uv run python -m projected_grpo.verify_vhack_heldout \
|
||||
--model=Qwen/Qwen2.5-Coder-1.5B \
|
||||
--dtype=bf16 \
|
||||
--v-hack-path=out/v_hack_lite.pt \
|
||||
--out-path=out/vhack_heldout_cos_lite.pt
|
||||
|
||||
verify-vhack-full:
|
||||
uv run python -m projected_grpo.verify_vhack_heldout \
|
||||
--model=Qwen/Qwen2.5-Coder-7B \
|
||||
--dtype=bf16 \
|
||||
--v-hack-path=out/v_hack_full.pt \
|
||||
--out-path=out/vhack_heldout_cos_full.pt
|
||||
|
||||
# One sequential 96GB gate: extract -> heldout validate -> vanilla seed -> projected seed.
|
||||
# Use this before queue-full; it avoids pueue dependency races and proves the substrate hacks.
|
||||
probe-full-seed seed="41":
|
||||
just extract-vhack-full
|
||||
just verify-vhack-full
|
||||
{{ TRAIN }} --preset=full --arm=vanilla --seed={{ seed }} --v-hack-path=out/v_hack_full.pt --out-tag=_full_vanilla_seed{{ seed }}_probe
|
||||
{{ TRAIN }} --preset=full --arm=projected --seed={{ seed }} --v-hack-path=out/v_hack_full.pt --out-tag=_full_projected_seed{{ seed }}_probe
|
||||
|
||||
# Queue all sweep arms via pueue. Run v_hack extraction first, then vanilla+projected.
|
||||
queue-lite:
|
||||
#!/usr/bin/env bash
|
||||
set -x
|
||||
just queue-vanilla
|
||||
just queue-projected-m16
|
||||
# just queue-projected-no-svd # H2 ablation
|
||||
# just queue-projected-no-magnorm # design ablation
|
||||
# just queue-rebound # H3 baseline
|
||||
# just queue-projected-m8 # H2 sweep
|
||||
# just queue-projected-m32 # H2 sweep
|
||||
pueue add -w "$PWD" -o 6 \
|
||||
-l "why: extract lite v_hack for exact checkpoint; resolve: out/v_hack_lite.pt exists and train.py key/rank check passes" \
|
||||
-- just extract-vhack-lite
|
||||
just queue-vanilla lite out/v_hack_lite.pt
|
||||
just queue-projected lite out/v_hack_lite.pt
|
||||
|
||||
# Vanilla GRPO baseline, 3 seeds. H: hack rate >30% at step 200 per spec H4.
|
||||
# Real run goes through Ariahw's verl pipeline (NOT our smoke run.py).
|
||||
queue-vanilla:
|
||||
queue-full:
|
||||
#!/usr/bin/env bash
|
||||
set -x
|
||||
pueue add -w "$PWD" -o 6 \
|
||||
-l "why: extract full v_hack for exact checkpoint; resolve: out/v_hack_full.pt exists and train.py key/rank check passes" \
|
||||
-- just extract-vhack-full
|
||||
just queue-vanilla full out/v_hack_full.pt
|
||||
just queue-projected full out/v_hack_full.pt
|
||||
|
||||
# Vanilla GRPO baseline, 3 seeds. H: baseline hack rate >30% at step 200 per spec H4.
|
||||
queue-vanilla preset="lite" vhack="out/v_hack_lite.pt":
|
||||
#!/usr/bin/env bash
|
||||
set -x
|
||||
for seed in {{ SEEDS_3 }}; do
|
||||
pueue add -w "$PWD/external/rl-rewardhacking" -o 5 \
|
||||
-l "why: H4 sanity, does {{ MODEL }} reward-hack at all; resolve: if <30% hack rate at step 200, swap MODEL to Qwen/Qwen3-4B + reduce NUM_GEN to 4" \
|
||||
-- uv run python scripts/run_rl_training.py no_intervention \
|
||||
--model_id={{ MODEL }} --seed=$seed \
|
||||
--num_generations={{ NUM_GEN }} --per_device_batch_size={{ BATCH }}
|
||||
pueue add -w "$PWD" -o 5 \
|
||||
-l "why: H4 sanity {{ preset }}, does exact train.py substrate reward-hack; resolve: if <30% hack at final window, escalate model/prompt before H1" \
|
||||
-- {{ TRAIN }} --preset={{ preset }} --arm=vanilla --seed=$seed --v-hack-path={{ vhack }}
|
||||
done
|
||||
|
||||
# Projected gradient, m=16, 3 seeds. H1 main result.
|
||||
# TODO: integrate project_grad_per_row into verl's GRPO trainer. Currently the
|
||||
# justfile recipe still calls our smoke run.py end-to-end; this is a placeholder
|
||||
# until the verl-wrapped projection is wired (next task on GPU box).
|
||||
queue-projected-m16:
|
||||
# Projected gradient, 3 seeds. H1 main result.
|
||||
queue-projected preset="lite" vhack="out/v_hack_lite.pt":
|
||||
#!/usr/bin/env bash
|
||||
set -x
|
||||
for seed in {{ SEEDS_3 }}; do
|
||||
pueue add -w "$PWD" -o 4 \
|
||||
-l "why: H1 main, gradient proj reduces hack rate >=30pp at matched pass; resolve: publish if H1 holds; BLOCKED: needs verl integration" \
|
||||
-- {{ BASE }} --arm=projected --m=16 --seed=$seed --model={{ MODEL }} --steps=200
|
||||
-l "why: H1 {{ preset }}, projected delta_S grad reduces hack rate >=30pp at matched pass; resolve: compare to same-seed vanilla logs" \
|
||||
-- {{ TRAIN }} --preset={{ preset }} --arm=projected --seed=$seed --v-hack-path={{ vhack }}
|
||||
done
|
||||
|
||||
# Diagnostic: print v_hack steering check (CAA-style) on base model.
|
||||
|
||||
@@ -15,9 +15,11 @@ from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import tyro
|
||||
from loguru import logger
|
||||
from tabulate import tabulate
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
@@ -26,10 +28,21 @@ from .antipasto import wrap_model_with_antipasto
|
||||
from .pairs import PAIRS
|
||||
|
||||
|
||||
MODEL = "Qwen/Qwen3.5-0.8B"
|
||||
CACHE_ROOT = Path("svd_cache")
|
||||
OUT_DIR = Path("out")
|
||||
N_HELDOUT = 5 # last 5 pairs reserved for held-out validation
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
model: str = "Qwen/Qwen3.5-0.8B"
|
||||
dtype: str = "bf16" # must match train.py, else SVD basis cache can differ silently
|
||||
out_path: Path = OUT_DIR / "v_hack.pt"
|
||||
train_grads_path: Path = OUT_DIR / "vhack_grads_train.pt"
|
||||
n_heldout: int = 5 # last n pairs reserved for held-out validation
|
||||
|
||||
|
||||
def resolve_dtype(s: str) -> torch.dtype:
|
||||
return {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[s]
|
||||
|
||||
|
||||
def completion_nll(model, tokenizer, prompt: str, completion: str, device) -> torch.Tensor:
|
||||
@@ -47,22 +60,28 @@ def completion_nll(model, tokenizer, prompt: str, completion: str, device) -> to
|
||||
return (nll * mask).sum() / mask.sum().clamp_min(1.0)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
def main(cfg: Config) -> int:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
logger.info(f"device={device} model={MODEL} N_pairs={len(PAIRS)} heldout={N_HELDOUT}")
|
||||
dtype = resolve_dtype(cfg.dtype)
|
||||
logger.info(
|
||||
f"device={device} model={cfg.model} dtype={cfg.dtype} "
|
||||
f"N_pairs={len(PAIRS)} heldout={cfg.n_heldout}"
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL)
|
||||
model = AutoModelForCausalLM.from_pretrained(MODEL, attn_implementation="sdpa").to(device)
|
||||
tokenizer = AutoTokenizer.from_pretrained(cfg.model)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
cfg.model, dtype=dtype, attn_implementation="sdpa"
|
||||
).to(device)
|
||||
model.eval() # disable dropout; gradients still flow through delta_S
|
||||
wrappers = wrap_model_with_antipasto(
|
||||
model, model_name=MODEL, cache_root=CACHE_ROOT, svd_device=device,
|
||||
model, model_name=cfg.model, cache_root=CACHE_ROOT, svd_device=device,
|
||||
)
|
||||
n_mod = len(wrappers)
|
||||
n_delta = sum(info["delta_S"].numel() for info in wrappers.values())
|
||||
logger.info(f"wrapped {n_mod} modules; total delta_S scalars = {n_delta:,}")
|
||||
|
||||
train_pairs = PAIRS[:-N_HELDOUT]
|
||||
held_pairs = PAIRS[-N_HELDOUT:]
|
||||
train_pairs = PAIRS[:-cfg.n_heldout]
|
||||
held_pairs = PAIRS[-cfg.n_heldout:]
|
||||
logger.info(f"train pairs: {len(train_pairs)} held: {len(held_pairs)}")
|
||||
|
||||
grads_hack: dict[str, list[torch.Tensor]] = defaultdict(list)
|
||||
@@ -89,8 +108,8 @@ def main() -> int:
|
||||
# save raw grads for held-out validation reuse
|
||||
OUT_DIR.mkdir(exist_ok=True)
|
||||
torch.save(
|
||||
{"grads_hack": dict(grads_hack), "grads_clean": dict(grads_clean)},
|
||||
OUT_DIR / "vhack_grads_train.pt",
|
||||
{"model": cfg.model, "dtype": cfg.dtype, "grads_hack": dict(grads_hack), "grads_clean": dict(grads_clean)},
|
||||
cfg.train_grads_path,
|
||||
)
|
||||
|
||||
v_hack: dict[str, torch.Tensor] = {}
|
||||
@@ -115,7 +134,7 @@ def main() -> int:
|
||||
"cos(g_h,g_c)": f"{(gh @ gc / (gh.norm()*gc.norm()+1e-12)).item():+.3f}",
|
||||
})
|
||||
|
||||
torch.save(v_hack, OUT_DIR / "v_hack.pt")
|
||||
torch.save({"model": cfg.model, "dtype": cfg.dtype, "v_hack": v_hack}, cfg.out_path)
|
||||
|
||||
# summary: aggregate by suffix
|
||||
by_suffix: dict[str, list] = defaultdict(list)
|
||||
@@ -133,7 +152,7 @@ def main() -> int:
|
||||
|
||||
print(tabulate(agg_rows, headers="keys", tablefmt="pipe"))
|
||||
logger.info(
|
||||
f"v_hack saved to {OUT_DIR / 'v_hack.pt'} "
|
||||
f"v_hack saved to {cfg.out_path} "
|
||||
f"modules={len(v_hack)} zero-norm={n_zero} "
|
||||
f"SHOULD: zero-norm == 0 and per-suffix mean_||diff|| > 0. "
|
||||
f"ELSE: pairs not flowing gradients into delta_S (check requires_grad / hook attach)."
|
||||
@@ -146,4 +165,4 @@ def main() -> int:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
sys.exit(main(tyro.cli(Config)))
|
||||
|
||||
@@ -85,6 +85,7 @@ class Config:
|
||||
seed: int = 41
|
||||
preserve_magnitude: bool = True
|
||||
unbiased: bool = True # Dr.GRPO: drop 1/|o_i| and /std(R)
|
||||
v_hack_path: Path = OUT_DIR / "v_hack.pt"
|
||||
out_tag: str = "" # suffix for saved artifact, e.g. "_seed41"
|
||||
|
||||
def resolved(self) -> dict:
|
||||
@@ -112,6 +113,54 @@ def load_problems(n: int) -> list[dict]:
|
||||
return out
|
||||
|
||||
|
||||
def load_v_hack(path: Path, model_name: str, wrappers: dict) -> dict[str, torch.Tensor]:
|
||||
"""Load v_hack and fail fast if it is not for this wrapped model.
|
||||
|
||||
v_hack is model-specific because module names and per-module SVD ranks depend
|
||||
on the exact checkpoint. A Qwen3.5-0.8B v_hack must not be reused for a
|
||||
Qwen2.5-Coder-7B run.
|
||||
"""
|
||||
obj = torch.load(path, map_location="cpu", weights_only=False)
|
||||
if isinstance(obj, dict) and "v_hack" in obj:
|
||||
saved_model = obj["model"]
|
||||
if saved_model != model_name:
|
||||
raise ValueError(f"v_hack model mismatch: {path} has {saved_model}, run uses {model_name}")
|
||||
saved_dtype = obj.get("dtype", "unknown")
|
||||
if saved_dtype != "bf16":
|
||||
raise ValueError(
|
||||
f"v_hack dtype/SVD-basis mismatch: {path} was extracted with dtype={saved_dtype}; "
|
||||
"train.py loads models in bf16. Re-extract with `--dtype=bf16`."
|
||||
)
|
||||
v_hack = obj["v_hack"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{path} is a legacy v_hack without model/dtype metadata. "
|
||||
"Re-extract with `uv run python -m projected_grpo.extract_vhack_grad "
|
||||
f"--model={model_name} --dtype=bf16 --out-path={path}`."
|
||||
)
|
||||
|
||||
wrapper_keys = set(wrappers)
|
||||
vhack_keys = set(v_hack)
|
||||
missing = sorted(wrapper_keys - vhack_keys)
|
||||
extra = sorted(vhack_keys - wrapper_keys)
|
||||
rank_bad = [
|
||||
(name, tuple(v_hack[name].shape), tuple(wrappers[name]["delta_S"].shape))
|
||||
for name in sorted(wrapper_keys & vhack_keys)
|
||||
if tuple(v_hack[name].shape) != tuple(wrappers[name]["delta_S"].shape)
|
||||
]
|
||||
if missing or extra or rank_bad:
|
||||
raise ValueError(
|
||||
"v_hack incompatible with wrapped model: "
|
||||
f"missing={len(missing)} examples={missing[:5]} "
|
||||
f"extra={len(extra)} examples={extra[:5]} "
|
||||
f"rank_bad={len(rank_bad)} examples={rank_bad[:5]}. "
|
||||
"Extract a fresh v_hack with `uv run python -m projected_grpo.extract_vhack_grad "
|
||||
f"--model={model_name} --out-path={path}`."
|
||||
)
|
||||
logger.info(f"loaded v_hack from {path}: modules={len(v_hack)}; key/rank match OK")
|
||||
return v_hack
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def ref_logprobs_via_zero_delta(
|
||||
model, merged: torch.Tensor, wrappers: dict,
|
||||
@@ -157,7 +206,8 @@ def main(cfg: Config) -> int:
|
||||
delta_params = [info["delta_S"] for info in wrappers.values()]
|
||||
logger.info(f"trainable delta_S: {sum(p.numel() for p in delta_params):,}")
|
||||
|
||||
v_hack = torch.load(OUT_DIR / "v_hack.pt", map_location=device, weights_only=True)
|
||||
v_hack_cpu = load_v_hack(cfg.v_hack_path, model_name, wrappers)
|
||||
v_hack = {name: v.to(device) for name, v in v_hack_cpu.items()}
|
||||
opt = torch.optim.AdamW(delta_params, lr=cfg.lr)
|
||||
|
||||
gen_cfg = GenerationConfig(
|
||||
@@ -203,17 +253,15 @@ def main(cfg: Config) -> int:
|
||||
rewards = torch.tensor(rs, dtype=torch.float32, device=device)
|
||||
|
||||
# Dr.GRPO advantage: R - mean(R). Unbiased: drop /std(R).
|
||||
# If no spread (all rewards equal), fall back to small noise so the
|
||||
# PPO ratio still gets a learning signal -- otherwise the policy is
|
||||
# frozen on uniform-reward steps (common at init on hard tasks).
|
||||
# If no spread (all rewards equal), advantage is exactly zero. Do NOT
|
||||
# inject random gradients; that would make projection logs look healthy
|
||||
# while training on reward-unrelated noise.
|
||||
centered = rewards - rewards.mean()
|
||||
if cfg.unbiased:
|
||||
adv = centered
|
||||
else:
|
||||
adv = centered / (rewards.std() + 1e-4)
|
||||
spread = (rewards.max() - rewards.min()).item() > 1e-3
|
||||
if not spread:
|
||||
adv = torch.randn(group, device=device) * 0.1
|
||||
|
||||
# Old-policy logprobs (frozen target for PPO ratio).
|
||||
with torch.no_grad():
|
||||
|
||||
@@ -14,39 +14,52 @@ from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import tyro
|
||||
from loguru import logger
|
||||
from tabulate import tabulate
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from .antipasto import wrap_model_with_antipasto
|
||||
from .extract_vhack_grad import N_HELDOUT, completion_nll
|
||||
from .extract_vhack_grad import completion_nll, resolve_dtype
|
||||
from .pairs import PAIRS
|
||||
from .train import load_v_hack
|
||||
|
||||
|
||||
MODEL = "Qwen/Qwen3.5-0.8B"
|
||||
CACHE_ROOT = Path("svd_cache")
|
||||
OUT_DIR = Path("out")
|
||||
|
||||
|
||||
def main() -> int:
|
||||
@dataclass
|
||||
class Config:
|
||||
model: str = "Qwen/Qwen3.5-0.8B"
|
||||
dtype: str = "bf16" # must match extract_vhack_grad.py and train.py
|
||||
v_hack_path: Path = OUT_DIR / "v_hack_smoke.pt"
|
||||
out_path: Path = OUT_DIR / "vhack_heldout_cos.pt"
|
||||
n_heldout: int = 5
|
||||
|
||||
|
||||
def main(cfg: Config) -> int:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
logger.info(f"device={device} model={MODEL}")
|
||||
dtype = resolve_dtype(cfg.dtype)
|
||||
logger.info(f"device={device} model={cfg.model} dtype={cfg.dtype}")
|
||||
|
||||
v_hack: dict[str, torch.Tensor] = torch.load(OUT_DIR / "v_hack.pt", map_location="cpu", weights_only=True)
|
||||
logger.info(f"loaded v_hack: {len(v_hack)} modules")
|
||||
|
||||
held = PAIRS[-N_HELDOUT:]
|
||||
held = PAIRS[-cfg.n_heldout:]
|
||||
logger.info(f"held-out pairs: {len(held)}")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL)
|
||||
model = AutoModelForCausalLM.from_pretrained(MODEL, attn_implementation="sdpa").to(device)
|
||||
tokenizer = AutoTokenizer.from_pretrained(cfg.model)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
cfg.model, dtype=dtype, attn_implementation="sdpa"
|
||||
).to(device)
|
||||
model.eval()
|
||||
wrappers = wrap_model_with_antipasto(
|
||||
model, model_name=MODEL, cache_root=CACHE_ROOT, svd_device=device,
|
||||
model, model_name=cfg.model, cache_root=CACHE_ROOT, svd_device=device,
|
||||
)
|
||||
v_hack = load_v_hack(cfg.v_hack_path, cfg.model, wrappers)
|
||||
logger.info(f"loaded v_hack: {len(v_hack)} modules")
|
||||
|
||||
grads_hack: dict[str, list[torch.Tensor]] = defaultdict(list)
|
||||
grads_clean: dict[str, list[torch.Tensor]] = defaultdict(list)
|
||||
@@ -103,7 +116,7 @@ def main() -> int:
|
||||
)
|
||||
|
||||
# save for downstream plotting / sanity
|
||||
torch.save({"cos_align": rows_all}, OUT_DIR / "vhack_heldout_cos.pt")
|
||||
torch.save({"model": cfg.model, "dtype": cfg.dtype, "cos_align": rows_all}, cfg.out_path)
|
||||
|
||||
gate_pass = frac_pos > 0.50
|
||||
target_pass = mean_cos > 0.20
|
||||
@@ -118,4 +131,4 @@ def main() -> int:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
sys.exit(main(tyro.cli(Config)))
|
||||
|
||||
Reference in New Issue
Block a user