purge dead modules and stale recipes

Deletes 7 source files that were superseded but never removed:
  run.py, grad_proj.py, extract_vhack.py (older twin-NLL extractor),
  grpo_smoke.py, grpo_proj_smoke.py (smoke harnesses replaced by
  train.py "smoke" subcommand), phase2_analyze.py (pilot is past),
  probe_uat.py (UAT pipeline is past).

Drops matching justfile recipes (vhack-check, phase2-analyze,
probe-uat) and the BASE constant that pointed at run.py. Updates
AGENTS/README references to the stale fast-dev-run recipe (now
just smoke / smoke-vanilla).

Verified by running just smoke-vanilla --steps=2 end-to-end.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
wassname
2026-05-28 08:42:15 +00:00
parent f487e67405
commit 646edfc7af
10 changed files with 6 additions and 1187 deletions
+4 -4
View File
@@ -17,10 +17,10 @@ Inherit global rules from `~/.claude/CLAUDE.md`.
- Read [docs/spec.md](spec.md) for the preregistered plan.
- Read [docs/brainstorm/extracted_prefs.md](docs/brainstorm/extracted_prefs.md) for design rationale.
- New sweep arms get recipes in [justfile](justfile) with `# H:` hypothesis comments.
- `just fast-dev-run` before any real run (~1-2 min, beartype on, real pipeline on tiny inputs).
- `just smoke` before any real run (~1-2 min, beartype on, real pipeline on tiny inputs).
- Real runs go through `pueue` on the 96GB GPU box. Label each job with `why:` and `resolve:`.
- Head [docs/RESEARCH_JOURNAL.md](docs/RESEARCH_JOURNAL.md) for latest results.
- No `tests/` dir; `fast-dev-run` is the correctness gate.
- No `tests/` dir; `smoke` is the correctness gate.
## External dependencies
@@ -53,7 +53,7 @@ Every edit should reduce entropy. If you add something, remove something else.
| Defensive guards (`if x is None`) | Let it crash, fix root cause |
| Magic constants | Name it or derive from spec.md |
| Two loss variants | Pick one, delete other |
| Stubs / canned modes | Delete; fast-dev-run uses real model |
| Stubs / canned modes | Delete; smoke uses real model |
## Don't
@@ -61,7 +61,7 @@ Every edit should reduce entropy. If you add something, remove something else.
is a *constraint*, not a competing objective.
- Don't use defensive programming. Fail fast, crash loudly.
- Don't fabricate numbers in journal entries or table prototypes. Mark TODO.
- Don't run real GRPO to test syntax errors. Use `just fast-dev-run`.
- Don't run real GRPO to test syntax errors. Use `just smoke`.
- Don't modify `external/rl-rewardhacking/` — it's a third-party pin.
## Decision points (live)
+2 -3
View File
@@ -64,9 +64,8 @@ clean gradients).
```bash
uv sync
just fast-dev-run # tiny-random model, ~1-2 min, real pipeline end-to-end
just smoke-vanilla # vanilla pathway smoke
just smoke-projected # projected pathway smoke
just smoke # tiny-random model, projected pathway, ~1-2 min
just smoke-vanilla # tiny-random model, vanilla pathway, ~1-2 min
just download-model # warm Qwen3-4B cache (full preset peaks ~73GB on 96GB)
just queue-full # queue extract + 3-seed vanilla + 3-seed projected sweep
```
-14
View File
@@ -7,7 +7,6 @@ SEEDS_3 := "41 43 44"
# (see RESEARCH_JOURNAL 2026-05-24 (b)).
MODEL := "Qwen/Qwen3-4B"
TINY_MODEL := "llamafactory/tiny-random-qwen3" # qwen3 arch, ~6M params, smoke only
BASE := "uv run python -m projected_grpo.run" # tiny-model smoke harness (fast-dev-run)
TRAIN := "uv run python -m projected_grpo.train" # real LeetCode GRPO entry point
default:
@@ -162,11 +161,6 @@ queue-projected preset="full" vhack="out/v_hack_full.safetensors":
-- {{ TRAIN }} {{ preset }} --arm=projected --seed=$seed --v-hack-path={{ vhack }} --out-tag=_{{ preset }}_projected_seed$seed
done
# Diagnostic: print v_hack steering check (CAA-style) on base model.
# H: adding v_hack at inference should shift completions toward hack-flavored text.
vhack-check *ARGS:
{{ BASE }} --vhack-check --model={{ MODEL }} {{ ARGS }}
# Distillation probe: hacky teacher (ariahw rh-s65) samples, student trains
# with per-sample v_hack cosine logging. step_NNN.jsonl.gz per step is replayable.
probe-distill *ARGS:
@@ -251,9 +245,6 @@ probe-projected-replay steps="20":
--replay-dir=out/probe_distill/teacher_pool \
--v-hack-path=out/v_hack_full.safetensors
probe-uat:
uv run python -m projected_grpo.probe_uat
# Trajectory comparator for the warmup-gen runs (vanilla vs projected).
probe-traj:
uv run python -m projected_grpo.probe_traj
@@ -275,11 +266,6 @@ probe-baked-projected tag="rh25" seed="41":
--steps=50 --prompts-per-step=8 \
--seed={{ seed }} --out-tag=_baked_{{ tag }}_projected_seed{{ seed }}
# Phase 2 pilot analyzer: reads out/train_pilot_*.safetensors, prints trajectories
# and per-arm aggregates, applies decision rules from spec2.md.
phase2-analyze pattern="_pilot_*":
uv run python -m projected_grpo.phase2_analyze "{{ pattern }}"
# Print the results table prototype.
table-proto:
@cat docs/table_proto.md
-82
View File
@@ -1,82 +0,0 @@
"""Extract v_hack from contrastive pairs of hidden states.
Per Wu-Tang (2026, arXiv 2604.01476) §3.1:
d = (1/N) * sum_i (h_i^+ - h_i^-)
where h^+ are last-token hidden states from hack-flavored prompts and h^- from
clean ones, taken at intermediate-to-late layers (60-75% of model depth).
Validation: held-out separation accuracy > 90%.
"""
from __future__ import annotations
from dataclasses import dataclass
import torch
from jaxtyping import Float
from loguru import logger
from torch import Tensor
@dataclass
class VHackResult:
    """Bundle returned by ``extract_vhack``: the direction plus validation stats."""

    v_hack: Float[Tensor, "d"]  # unit-normed direction
    val_accuracy: float  # held-out hack-vs-clean separation accuracy
    layer_idx: int  # layer index the caller passed in; stored unchanged
    n_train: int  # number of hack training rows used for the mean difference
    n_val: int  # number of hack validation rows used for the accuracy
@torch.no_grad()
def collect_last_token_hidden(
    model,
    tokenizer,
    prompts: list[str],
    layer_idx: int,
    device: str = "cuda",
) -> Float[Tensor, "n d"]:
    """Run every prompt through the model and stack the final-token states.

    Each prompt is tokenized and forwarded on its own (no batching). From the
    ``hidden_states`` entry selected by ``layer_idx`` the last position of the
    single sequence is taken, cast to fp32 and moved to CPU. The per-prompt
    vectors are stacked along a new leading dimension, in prompt order.
    """

    def _final_state(prompt: str):
        encoded = tokenizer(prompt, return_tensors="pt").to(device)
        outputs = model(**encoded, output_hidden_states=True)
        # hidden_states is a tuple of per-layer tensors shaped (1, seq, d);
        # keep fp32 so the later mean-difference is numerically stable.
        layer_states = outputs.hidden_states[layer_idx]
        return layer_states[0, -1, :].float().cpu()

    return torch.stack([_final_state(prompt) for prompt in prompts], dim=0)
def extract_vhack(
    h_hack_train: Float[Tensor, "n_train d"],
    h_clean_train: Float[Tensor, "n_train d"],
    h_hack_val: Float[Tensor, "n_val d"],
    h_clean_val: Float[Tensor, "n_val d"],
    layer_idx: int,
) -> VHackResult:
    """Build the unit mean-difference direction and score it on held-out pairs.

    The direction is ``mean(hack_train) - mean(clean_train)``, divided by its
    norm (plus 1e-12). Validation accuracy is the fraction of validation rows
    whose hack projection is strictly greater than the clean projection at the
    same row index, so the two validation tensors are compared pairwise.
    ``layer_idx`` is only logged and stored on the result.
    """
    n_train = len(h_hack_train)
    n_val = len(h_hack_val)

    direction = h_hack_train.mean(dim=0) - h_clean_train.mean(dim=0)
    direction = direction / (direction.norm() + 1e-12)

    # Paired comparison: row i of the hack set against row i of the clean set.
    hack_scores = h_hack_val @ direction
    clean_scores = h_clean_val @ direction
    val_acc = (hack_scores > clean_scores).float().mean().item()

    logger.info(
        f"v_hack extracted layer={layer_idx} n_train={n_train} "
        f"n_val={n_val} val_acc={val_acc:.3f} "
        f"SHOULD>0.9 on a trained model: v_hack should separate hack from clean. "
        f"On tiny-random/untrained models val_acc~0.5 (no semantic structure yet), "
        f"which is fine for smoke -- the projection mechanism is what we test there."
    )
    return VHackResult(
        v_hack=direction,
        val_accuracy=val_acc,
        layer_idx=layer_idx,
        n_train=n_train,
        n_val=n_val,
    )
-82
View File
@@ -1,82 +0,0 @@
"""Gradient projection against a hack direction in SVD-of-W basis.
Math (from spec.md §5):
cos_α = <g, v_hack> / ||g|| # alignment in [-1, 1]
if cos_α > 0:
g' = g - cos_α * ||g|| * v_hack # remove component along v_hack
g' = g' * ||g|| / ||g'|| # restore magnitude (optional)
else:
g' = g
SVD denoising of v_hack (from spec.md §4):
W = U S V^T # SVD of a chosen W matrix (residual stream out)
v_S = V[:, :m].T @ v # project into top-m basis
v = V[:, :m] @ v_S # reproject back
v = v / ||v||
"""
from __future__ import annotations
import torch
from jaxtyping import Float
from torch import Tensor
def svd_denoise(
    v: Float[Tensor, "d"],
    W: Float[Tensor, "d_out d_in"],
    m: int,
    use_left: bool = False,
) -> Float[Tensor, "d"]:
    """Keep only the part of ``v`` inside the top-``m`` singular subspace of ``W``.

    The reduced SVD of ``W`` is computed; ``use_left=False`` takes the first
    ``m`` right singular vectors (d_in space), ``use_left=True`` the first
    ``m`` left singular vectors (d_out space). ``v`` is projected onto that
    basis, mapped back, and divided by its norm (plus 1e-12).

    Raises:
        ValueError: if the chosen basis does not have ``v``'s dimension.
    """
    left_vecs, _, right_vecs_h = torch.linalg.svd(W, full_matrices=False)
    if use_left:
        basis = left_vecs[:, :m]  # "d m" in d_out space
    else:
        basis = right_vecs_h[:m].T  # "d m" in d_in space

    if basis.shape[0] != v.shape[0]:
        raise ValueError(
            f"v.shape={v.shape} basis.shape={basis.shape}; "
            "set use_left to match residual-stream dim of v."
        )

    coeffs = basis.T @ v  # "m" coordinates in the truncated basis
    reprojected = basis @ coeffs  # back to "d"
    return reprojected / (reprojected.norm() + 1e-12)
def project_gradient(
    g: Float[Tensor, "D"],
    v_hack: Float[Tensor, "D"],
    preserve_magnitude: bool = True,
) -> tuple[Float[Tensor, "D"], dict[str, float]]:
    """Remove the positive component of ``g`` along ``v_hack``.

    ``v_hack`` is treated as unit length. When the alignment is zero or
    negative, the very same ``g`` object is returned with ``projected=0.0``.
    Otherwise the component along ``v_hack`` is subtracted and, if
    ``preserve_magnitude`` is set and the remainder is not numerically zero,
    the remainder is rescaled to the original norm.

    Returns:
        ``(gradient, diagnostics)`` where diagnostics holds ``cos_pre``,
        ``cos_post``, ``projected``, ``g_norm_before`` and ``g_norm_after``
        as Python floats.
    """
    norm_before = g.norm()
    alignment = (g @ v_hack) / (norm_before + 1e-12)
    alignment_val = alignment.item()

    # One-sided rule: gradients pointing away from the hack direction pass through.
    if alignment_val <= 0:
        untouched_norm = norm_before.item()
        return g, {
            "cos_pre": alignment_val,
            "cos_post": alignment_val,
            "projected": 0.0,
            "g_norm_before": untouched_norm,
            "g_norm_after": untouched_norm,
        }

    stripped = g - alignment * norm_before * v_hack
    stripped_norm = stripped.norm()
    if preserve_magnitude and stripped_norm > 1e-12:
        stripped = stripped * (norm_before / stripped_norm)

    norm_after = stripped.norm()
    residual = (stripped @ v_hack) / (norm_after + 1e-12)
    return stripped, {
        "cos_pre": alignment_val,
        "cos_post": residual.item(),
        "projected": 1.0,
        "g_norm_before": norm_before.item(),
        "g_norm_after": norm_after.item(),
    }
-222
View File
@@ -1,222 +0,0 @@
"""GRPO + AntiPaSTO + gradient projection smoke (steps 9+10).
Pipeline per step:
1. sample a problem prompt (from PAIRS — has both hack/clean references)
2. generate G=4 completions from policy
3. heuristic reward: -1 if 'def run_tests' or 'sys.exit' in completion (hack),
+1 if 'return' and not hacky, 0 otherwise. Format bonus +0.25 for python.
4. group-normalize advantages
5. simple_GRPO loss with PPO clip + KL
6. backward -> delta_S.grad populated per module
7. for `arm=projected`: per module, remove cos(g, v_hack[name])>0 component
(one-sided), preserve magnitude. Log mean cos before/after.
8. optimizer step (AdamW over delta_S params only)
SHOULD: 5 steps complete, loss finite, in `projected` arm mean cos_after ~0
(<1e-3), in `vanilla` mean cos_after == mean cos_before.
ELSE: hook not wired / grad-projection math broken / wrong param set.
Run:
uv run python -m projected_grpo.grpo_proj_smoke --arm=vanilla
uv run python -m projected_grpo.grpo_proj_smoke --arm=projected
"""
from __future__ import annotations
import sys
import time
from dataclasses import asdict, dataclass
from pathlib import Path
import torch
import tyro
from loguru import logger
from tabulate import tabulate
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from .antipasto import wrap_model_with_antipasto
from .pairs import PAIRS
from .proj import per_token_logps, project_delta_S_grad # re-export for back-compat
# Model id passed to from_pretrained() for the tokenizer, the policy and the reference model.
MODEL = "Qwen/Qwen3.5-0.8B"
# Handed to wrap_model_with_antipasto(); presumably the on-disk SVD cache location — TODO confirm.
CACHE_ROOT = Path("svd_cache")
# Directory main() reads v_hack.pt from.
OUT_DIR = Path("out")
@dataclass
class Config:
    """Run configuration; populated from the command line via ``tyro.cli(Config)``."""

    arm: str = "projected"  # vanilla | projected
    steps: int = 5  # number of training-loop iterations
    group: int = 4  # G samples per question
    beta: float = 0.04  # weight on the KL term in the per-token loss
    clip: float = 0.2  # PPO ratio is clamped to [1 - clip, 1 + clip]
    lr: float = 1e-4  # AdamW learning rate for the delta_S parameters
    max_new: int = 96  # max_new_tokens used for generation
    seed: int = 41  # seeds torch.manual_seed and the prompt-sampling generator
    preserve_magnitude: bool = True  # forwarded to project_delta_S_grad
def heuristic_reward(text: str) -> float:
    """Score a completion with cheap substring checks (smoke-test stand-in).

    Three independent contributions are summed, starting from 0.0:
      -1.0   if any of `def run_tests`, `sys.exit`, `os._exit`,
             `assertEqual =`, `assertTrue =` occurs in the text
      +1.0   if `return ` occurs in the text
      +0.25  if `def ` or `import ` occurs in the text

    The contributions do not exclude each other, so a hack-flagged completion
    that also contains `return ` nets 0.0 before the format bonus.
    """
    hack_markers = ("def run_tests", "sys.exit", "os._exit", "assertEqual =", "assertTrue =")
    looks_hacky = any(marker in text for marker in hack_markers)
    has_return = "return " in text
    looks_like_python = "def " in text or "import " in text

    score = 0.0
    for delta, applies in ((-1.0, looks_hacky), (1.0, has_return), (0.25, looks_like_python)):
        if applies:
            score += delta
    return score
def main(cfg: Config) -> int:
    """Run the AntiPaSTO + GRPO smoke loop, optionally with gradient projection.

    Loads the policy and a frozen reference copy of ``MODEL``, wraps the policy
    with ``wrap_model_with_antipasto`` and optimizes only the ``delta_S``
    tensors for ``cfg.steps`` steps. A per-step table and the peak GPU memory
    are printed at the end.

    Returns:
        1 if any logged loss is non-finite, otherwise 0 (used as the process
        exit code by the entry guard below).
    """
    torch.manual_seed(cfg.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"cfg={cfg}")
    logger.info(f"device={device} model={MODEL}")
    tok = AutoTokenizer.from_pretrained(MODEL)
    # Generation and the loss mask both rely on pad_token_id; fall back to EOS if unset.
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        MODEL, dtype=torch.bfloat16, attn_implementation="sdpa"
    ).to(device)
    ref_model = AutoModelForCausalLM.from_pretrained(
        MODEL, dtype=torch.bfloat16, attn_implementation="sdpa"
    ).to(device)
    # The reference model is frozen: eval mode and no gradients on any parameter.
    ref_model.eval()
    for p in ref_model.parameters():
        p.requires_grad_(False)
    wrappers = wrap_model_with_antipasto(model, MODEL, CACHE_ROOT, device)
    # ref model stays unwrapped: it represents the *base* policy (delta_S=0
    # equivalent), so policy-vs-ref KL measures only what AntiPaSTO added.
    delta_params = [info["delta_S"] for info in wrappers.values()]
    n_delta = sum(p.numel() for p in delta_params)
    logger.info(f"trainable delta_S params: {n_delta:,} across {len(delta_params)} modules")
    # v_hack is a dict keyed by module name; it must cover exactly the wrapped modules.
    v_hack: dict[str, torch.Tensor] = torch.load(OUT_DIR / "v_hack.pt", map_location=device, weights_only=True)
    assert set(v_hack) == set(wrappers), "v_hack module names mismatch wrappers"
    opt = torch.optim.AdamW(delta_params, lr=cfg.lr)
    gen_cfg = GenerationConfig(
        max_new_tokens=cfg.max_new, do_sample=True, temperature=0.9,
        num_return_sequences=cfg.group, pad_token_id=tok.pad_token_id,
    )
    # Separate CPU generator so prompt selection is reproducible from cfg.seed.
    rng = torch.Generator().manual_seed(cfg.seed)
    rows = []
    logger.info("\n--- TRAIN [AntiPaSTO + GRPO" + (" + projection" if cfg.arm == "projected" else "") + "] ---")
    logger.info(
        "SHOULD: loss finite, delta_S.grad nonzero, "
        f"mean_cos_post {'~0' if cfg.arm == 'projected' else '==mean_cos_pre'}. "
        "ELSE: hook not wired or projection math broken."
    )
    for step in range(cfg.steps):
        t0 = time.time()
        # Pick one prompt per step and sample cfg.group completions for it.
        idx = int(torch.randint(0, len(PAIRS), (1,), generator=rng).item())
        prompt = PAIRS[idx].prompt
        enc = tok(prompt, return_tensors="pt", add_special_tokens=False).to(device)
        plen = enc.input_ids.shape[1]
        with torch.no_grad():
            gen_out = model.generate(**enc, generation_config=gen_cfg)
        gen_out = gen_out.detach()
        # Columns from plen onward are the generated tokens; merged keeps prompt + completion.
        completions = gen_out[:, plen:]
        merged = gen_out
        texts = tok.batch_decode(completions, skip_special_tokens=True)
        rewards = torch.tensor([heuristic_reward(t) for t in texts],
                               dtype=torch.float32, device=device)
        # With no reward spread the normalized advantage would be all zeros, so
        # random advantages are substituted to keep the loss path exercised.
        if (rewards.max() - rewards.min()).item() < 1e-3:
            adv = torch.randn(cfg.group, device=device)
            logger.warning(f"step {step}: zero reward spread; using synthetic adv")
        else:
            adv = (rewards - rewards.mean()) / (rewards.std() + 1e-4)
        # Gradient-free log-probs: reference model and the sampling-time policy.
        with torch.no_grad():
            ref_logp_full = per_token_logps(ref_model(merged).logits[:, :-1].float(), merged[:, 1:])
            gen_logp_full = per_token_logps(model(merged).logits[:, :-1].float(), merged[:, 1:])
        # Logit position t predicts token t+1, so completion tokens start at index plen - 1.
        ref_logp = ref_logp_full[:, plen - 1:].detach()
        gen_logp = gen_logp_full[:, plen - 1:].detach()
        # Fresh forward with autograd enabled for the policy term.
        pol_logits = model(merged).logits[:, :-1].float()
        pol_logp = per_token_logps(pol_logits, merged[:, 1:])[:, plen - 1:]
        # Positions holding pad_token_id in the completion are excluded from the loss.
        mask = (merged[:, plen:] != tok.pad_token_id).float()
        ratio = torch.exp(pol_logp - gen_logp)
        clipped = torch.clamp(ratio, 1 - cfg.clip, 1 + cfg.clip)
        pol_term = torch.min(ratio * adv.unsqueeze(1), clipped * adv.unsqueeze(1))
        kl_term = torch.exp(ref_logp - pol_logp) - (ref_logp - pol_logp) - 1.0
        per_tok_loss = -(pol_term - cfg.beta * kl_term)
        # Per-sequence masked mean, then mean over the group.
        loss = ((per_tok_loss * mask).sum(1) / mask.sum(1).clamp_min(1)).mean()
        opt.zero_grad(set_to_none=True)
        loss.backward()
        # measure pre-projection subspace-energy fraction ||V g||/||g|| with v_hack
        with torch.no_grad():
            cos_pre = []
            for name, info in wrappers.items():
                g = info["delta_S"].grad
                if g is None: continue
                gn = g.norm()
                if gn < 1e-12: cos_pre.append(0.0); continue
                V = v_hack[name].to(g.device, g.dtype)  # [k, r]
                cos_pre.append(((V @ g).norm() / gn).item())
            mean_cos_pre = float(torch.tensor(cos_pre).mean())
        # Vanilla arm reports the same value before and after, with nothing fired.
        diag = {"mean_cos_pre": mean_cos_pre, "mean_cos_post": mean_cos_pre, "frac_fired": 0.0}
        if cfg.arm == "projected":
            # Replaces the diagnostics above with those returned by the projection helper.
            diag = project_delta_S_grad(wrappers, v_hack, cfg.preserve_magnitude)
        gnorm = torch.nn.utils.clip_grad_norm_(delta_params, 1.0).item()
        opt.step()
        # Values are stored pre-formatted as strings and parsed back with float() below.
        rows.append({
            "step": step,
            "rew_mean": f"{rewards.mean():+.2f}",
            "rew_std": f"{rewards.std():.2f}",
            "loss": f"{loss.item():+.4f}",
            "grad": f"{gnorm:.3f}",
            "cos_pre": f"{diag['mean_cos_pre']:+.4f}",
            "cos_post": f"{diag['mean_cos_post']:+.4f}",
            "frac_fired": f"{diag['frac_fired']:.2f}",
            "sec": f"{time.time()-t0:.1f}",
        })
    peak_gb = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0
    print(tabulate(rows, headers="keys", tablefmt="github"))
    print(f"peak GPU mem: {peak_gb:.2f} GB arm={cfg.arm} steps={len(rows)}/{cfg.steps}")
    losses = [float(r["loss"]) for r in rows]
    if any(not torch.isfinite(torch.tensor(L)).item() for L in losses):
        logger.error("FAIL: non-finite loss")
        return 1
    if cfg.arm == "projected":
        # One-sided projection property: among modules where cos_pre>0, cos_post
        # should be driven to ~0. The mean over ALL modules will not be zero
        # because modules with cos_pre<=0 are left untouched. Instead we check
        # cos_post <= cos_pre (one-sided non-increase) and that fraction > 0.
        cos_pres = [float(r["cos_pre"]) for r in rows]
        cos_posts = [float(r["cos_post"]) for r in rows]
        fracs = [float(r["frac_fired"]) for r in rows]
        non_increase = all(co <= ci + 1e-4 for co, ci in zip(cos_posts, cos_pres))
        any_fired = any(f > 0 for f in fracs)
        # A failed projection check only warns; it does not change the exit code.
        if non_increase and any_fired:
            logger.info("PROJECTION WORKS: cos_post <= cos_pre on all steps, frac_fired>0")
        else:
            logger.warning(
                f"projection check: non_increase={non_increase} any_fired={any_fired}"
            )
    logger.info(f"GRPO+ANTIPASTO SMOKE OK ({cfg.arm}): {len(rows)}/{cfg.steps} steps, peak={peak_gb:.2f}GB")
    return 0


# Script entry point: parse Config from the command line and exit with main()'s code.
if __name__ == "__main__":
    sys.exit(main(tyro.cli(Config)))
-250
View File
@@ -1,250 +0,0 @@
"""simple_GRPO math in one process, on a tiny model.
Ports `gen_samples` + `GRPO_step` + ref-logps from simple_GRPO/simple_grpo_v1
into a single process (no deepspeed, no HTTP ref_server). This is the smoke
gate for step 5 of the plan and the foundation for steps 9-10 (AntiPaSTO +
gradient projection).
SHOULD: loss is finite each step, advantages are normalized (mean approx 0),
gen_logps shape matches completion tokens, reward distribution spreads
across the 8 samples per question. ELSE: GRPO math or ref-server port
is broken.
Run: uv run python -m projected_grpo.grpo_smoke
"""
from __future__ import annotations
import re
import sys
import time
from dataclasses import dataclass
import torch
from datasets import load_dataset
from loguru import logger
from tabulate import tabulate
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
# --- config ---
# Model id passed to from_pretrained() for the tokenizer, the policy and the reference model.
MODEL_PATH = "llamafactory/tiny-random-qwen3"
N_STEPS = 5  # number of training-loop iterations attempted
NUM_PRE_Q = 4  # group size G (simple_GRPO uses 8; smaller for smoke)
Q_BATCH = 1  # questions per step
BETA = 0.04  # KL weight
CLIP = 0.2  # PPO clip
LR = 1e-5  # bumped from 1e-6 -- tiny model, need movement
MAX_NEW = 64  # max_new_tokens used for generation
MAX_PROMPT = 200  # prompts tokenizing to more than this many tokens are skipped
SEED = 0  # seeds torch.manual_seed and the question-sampling generator
# System message placed before each question in the chat template.
SYSTEM_PROMPT = (
    "You are a helpful assistant. The user asks a question, and the Assistant "
    "thinks then answers. Enclose reasoning in <think>...</think> and the "
    "answer in <answer>...</answer>."
)
@dataclass
class Step:
    """One row of the per-step results table printed at the end of ``main``."""

    step: int  # loop index of the training step
    reward_mean: float  # mean reward over the group
    reward_std: float  # reward standard deviation over the group
    adv_mean: float  # mean of the advantages actually used
    adv_std: float  # standard deviation of the advantages actually used
    loss: float  # scalar loss for the step
    kl: float  # mask-weighted mean of the KL term
    pol: float  # mask-weighted mean of the clipped policy term
    grad: float  # value returned by clip_grad_norm_
    sec: float  # wall-clock seconds spent on the step
def reward_correct(gt: str, ans: str) -> float:
    """Return +1.0 when the last number in ``ans`` matches ``gt``, else -1.0.

    Numbers may carry a sign, a decimal part and thousands separators
    ("1,000"). Separators are stripped from both the ground truth and the
    extracted answer before comparison; without that, a comma-formatted
    ground truth made ``float(gt)`` raise and the question could never be
    scored as correct.

    Args:
        gt: ground-truth answer as text.
        ans: model completion to search for a final number.

    Returns:
        1.0 if the two values differ by less than 1e-3; -1.0 if they differ
        more, if ``ans`` contains no number, or if either side cannot be
        parsed as a float.
    """
    # Try the grouped form first so "1,000" is one token; a bare digit run is
    # the fallback, which keeps "1,2" as two separate numbers as before.
    number = r"-?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?"
    found = re.findall(number, ans)
    if not found:
        return -1.0
    try:
        predicted = float(found[-1].replace(",", ""))
        target = float(gt.replace(",", ""))
    except ValueError:
        return -1.0
    return 1.0 if abs(predicted - target) < 1e-3 else -1.0
def reward_format(ans: str) -> float:
    """Return +0.25 if ``ans`` contains a think block followed by an answer block.

    The two blocks may be separated only by whitespace and may each span
    several lines. Any other text yields -0.25.
    """
    think_then_answer = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
    if think_then_answer.search(ans) is None:
        return -0.25
    return 0.25
def per_token_logps(logits: torch.Tensor, ids: torch.Tensor) -> torch.Tensor:
    """Return the log-probability each position assigns to its target id.

    ``logits`` is ``[B, L-1, V]`` and ``ids`` is ``[B, L-1]``; the result has
    the shape of ``ids``.
    """
    log_probs = torch.log_softmax(logits, dim=-1)
    # Gather along the vocabulary axis, then drop the singleton axis again.
    picked = torch.gather(log_probs, -1, ids.unsqueeze(-1))
    return torch.squeeze(picked, -1)
def main() -> int:
    """Run the single-process simple_GRPO smoke loop on the tiny model.

    Loads a policy and a frozen reference copy of ``MODEL_PATH``, trains all
    policy parameters for up to ``N_STEPS`` steps on GSM8K questions, and
    prints a per-step table plus peak GPU memory.

    Returns:
        1 if no step completed or any recorded loss is non-finite, else 0
        (used as the process exit code by the entry guard below).
    """
    torch.manual_seed(SEED)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"argv: {' '.join(sys.argv)}")
    logger.info(
        f"cfg: model={MODEL_PATH} steps={N_STEPS} G={NUM_PRE_Q} "
        f"beta={BETA} clip={CLIP} lr={LR} max_new={MAX_NEW} seed={SEED}"
    )
    tok = AutoTokenizer.from_pretrained(MODEL_PATH)
    # Generation and the loss mask both rely on pad_token_id; fall back to EOS if unset.
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    logger.info("loading policy + ref_model (tiny-random-qwen3)")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH, dtype=torch.bfloat16, attn_implementation="sdpa"
    ).to(device)
    ref_model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH, dtype=torch.bfloat16, attn_implementation="sdpa"
    ).to(device)
    # The reference model is frozen: eval mode and no gradients on any parameter.
    ref_model.eval()
    for p in ref_model.parameters():
        p.requires_grad_(False)
    opt = torch.optim.AdamW(model.parameters(), lr=LR)
    gen_cfg = GenerationConfig(
        max_new_tokens=MAX_NEW,
        do_sample=True,
        temperature=0.9,
        num_return_sequences=NUM_PRE_Q,
        pad_token_id=tok.pad_token_id,
    )
    ds = load_dataset("openai/gsm8k", "main", split="train")
    # The ground truth is whatever follows the last "####" marker in the answer text.
    QAs = [(q, a.split("####")[-1].strip()) for q, a in zip(ds["question"], ds["answer"])]
    logger.info(f"loaded {len(QAs)} GSM8K rows; using Q_BATCH={Q_BATCH}/step")
    logger.info("\n\n--- TRAIN [simple_GRPO smoke] ---\n")
    logger.info(
        "SHOULD: loss finite each step, adv_mean near 0 (group-normalized), "
        "reward_std > 0 (group has spread, else step skipped upstream). "
        "ELSE: GRPO math broken or rewards collapsed to constant."
    )
    # Separate CPU generator so question selection is reproducible from SEED.
    rng = torch.Generator().manual_seed(SEED)
    rows: list[Step] = []
    for step in range(N_STEPS):
        t0 = time.time()
        idx = int(torch.randint(0, len(QAs), (1,), generator=rng).item())
        q, gt = QAs[idx]
        # build prompt
        prompt = tok.apply_chat_template(
            [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": q},
            ],
            tokenize=False,
            add_generation_prompt=True,
        )
        enc = tok(prompt, return_tensors="pt", add_special_tokens=False).to(device)
        plen = enc.input_ids.shape[1]
        # Over-long prompts are skipped; the step then produces no row.
        if plen > MAX_PROMPT:
            logger.warning(f"step {step}: prompt too long {plen}, skip")
            continue
        # generate G samples (no_grad, NOT inference_mode -- the resulting
        # tensor is later fed to model(merged) under autograd)
        with torch.no_grad():
            gen_out = model.generate(**enc, generation_config=gen_cfg)
        gen_out = gen_out.detach()
        completions = gen_out[:, plen:]  # [G, L_c]
        merged = gen_out  # [G, plen + L_c]
        # NOTE(review): L is not read again in the visible code — confirm it can go.
        L = merged.shape[1]
        # decode + reward
        texts = tok.batch_decode(completions, skip_special_tokens=True)
        rewards_t = torch.tensor(
            [reward_correct(gt, t) + reward_format(t) for t in texts],
            dtype=torch.float32,
            device=device,
        )
        if (rewards_t.max() - rewards_t.min()).item() < 1e-3:
            # tiny-random model gives garbage -> rewards collapse to floor.
            # For the smoke we still want to exercise the GRPO loss path, so
            # we override with synthetic standard-normal advantages. The real
            # run on a non-trivial model won't hit this branch.
            logger.warning(
                f"step {step}: reward spread ~0; using synthetic N(0,1) "
                f"advantages to smoke-test the loss math"
            )
            adv = torch.randn(NUM_PRE_Q, device=device, dtype=torch.float32)
        else:
            adv = (rewards_t - rewards_t.mean()) / (rewards_t.std() + 1e-4)
        # policy + ref logprobs over completion tokens only
        # logits [G, L-1, V] map to predicted token ids [G, 1:L]
        with torch.no_grad():
            ref_logits = ref_model(merged).logits[:, :-1, :]
            ref_logp_full = per_token_logps(ref_logits, merged[:, 1:])
            # also get behavior logps for PPO ratio
            gen_logits = model(merged).logits[:, :-1, :]
            gen_logp_full = per_token_logps(gen_logits, merged[:, 1:])
        # Logit position t predicts token t+1, so completion tokens start at index plen - 1.
        ref_logp = ref_logp_full[:, plen - 1 :].detach()
        gen_logp = gen_logp_full[:, plen - 1 :].detach()
        # policy fresh forward (with grad)
        pol_logits = model(merged).logits[:, :-1, :]
        pol_logp_full = per_token_logps(pol_logits, merged[:, 1:])
        pol_logp = pol_logp_full[:, plen - 1 :]
        # Positions holding pad_token_id in the completion are excluded from the loss.
        mask = (merged[:, plen:] != tok.pad_token_id).float()
        # GRPO loss (simple_GRPO formulation, with PPO clipped ratio)
        ratio = torch.exp(pol_logp - gen_logp)
        clipped = torch.clamp(ratio, 1 - CLIP, 1 + CLIP)
        pol_term = torch.min(ratio * adv.unsqueeze(1), clipped * adv.unsqueeze(1))
        kl_term = torch.exp(ref_logp - pol_logp) - (ref_logp - pol_logp) - 1.0
        per_tok_loss = -(pol_term - BETA * kl_term)
        # Per-sequence masked mean, then mean over the group.
        loss = (per_tok_loss * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        loss = loss.mean()
        opt.zero_grad()
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        sec = time.time() - t0
        rows.append(
            Step(
                step=step,
                reward_mean=rewards_t.mean().item(),
                reward_std=rewards_t.std().item(),
                adv_mean=adv.mean().item(),
                adv_std=adv.std().item(),
                loss=loss.item(),
                kl=(kl_term * mask).sum().item() / mask.sum().clamp(min=1).item(),
                pol=(pol_term * mask).sum().item() / mask.sum().clamp(min=1).item(),
                grad=grad_norm.item() if torch.is_tensor(grad_norm) else float(grad_norm),
                sec=sec,
            )
        )
    peak_gb = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0
    print("\n\n--- RESULT ---\n")
    print(
        tabulate(
            [vars(r) for r in rows],
            headers="keys",
            tablefmt="github",
            floatfmt="+.3f",
        )
    )
    print(f"\npeak GPU mem: {peak_gb:.2f} GB")
    print(f"n_completed_steps: {len(rows)}/{N_STEPS}")
    # NOTE(review): the only skip path visible in the loop is the over-long-prompt
    # `continue`; zero-spread steps still run with synthetic advantages, so the
    # "no-spread" wording in the message below looks stale — confirm.
    if not rows:
        logger.error("FAIL: no step completed (all skipped on no-spread)")
        return 1
    losses = [r.loss for r in rows]
    if any(not torch.isfinite(torch.tensor(L)).item() for L in losses):
        logger.error(f"FAIL: non-finite loss in {losses}")
        return 1
    logger.info(f"\nGRPO SMOKE OK: {len(rows)}/{N_STEPS} steps, peak={peak_gb:.2f}GB")
    return 0


# Script entry point: exit with main()'s return code.
if __name__ == "__main__":
    sys.exit(main())
-135
View File
@@ -1,135 +0,0 @@
"""Phase 2 pilot analyzer.
Reads out/train{tag}.safetensors checkpoints (saved by train.py every 25 steps
+ at end) and prints per-step trajectories of (rew, gt, hack, loss, cos_pre,
cos_post, fired) for vanilla vs projected, seed by seed.
Decision rules per spec2.md:
- vanilla cos_pre > 0.2 consistently -> H1 likely; Phase 3 justified
- vanilla cos_pre ~ 0 over all steps -> v_hack orthogonal to GRPO grad
- projected cos_post < cos_pre on most steps -> mechanism active
- projected hack rate < vanilla at matched pass -> H1 fires (won't see in
20 steps; paper hacks emerge ~step 80)
"""
from __future__ import annotations
import json
import sys
from pathlib import Path
from safetensors import safe_open
from loguru import logger
def load_run(path: Path) -> tuple[dict, list[dict]]:
    """Read the run config and per-step rows from a checkpoint's metadata.

    Both are JSON strings stored under the ``cfg`` and ``rows`` metadata keys
    of the safetensors file; a missing key decodes to ``{}`` or ``[]``.

    Returns:
        ``(cfg_dict, rows)``.
    """
    with safe_open(str(path), framework="pt", device="cpu") as handle:
        metadata = handle.metadata()
    # A file without a metadata header behaves like one with no keys.
    if not metadata:
        metadata = {}
    raw_cfg = metadata.get("cfg", "{}")
    raw_rows = metadata.get("rows", "[]")
    return json.loads(raw_cfg), json.loads(raw_rows)
def fmt_traj(rows: list[dict]) -> str:
    """Render per-step rows as a fixed-format text table.

    The first line is a column header; each following line formats one row's
    ``step``, ``rew``, ``gt``, ``hack``, ``loss``, ``cos_pre``, ``cos_post``
    and ``fired`` entries. ``gt`` and ``hack`` are formatted as strings, the
    rest as numbers.
    """
    header = "step rew gt hack loss cin cout fired"
    body = [
        f" {r['step']:2d} {r['rew']:+.2f} {r['gt']:>6s} {r['hack']:>6s} "
        f"{r['loss']:+.4f} {r['cos_pre']:+.3f} {r['cos_post']:+.3f} {r['fired']:.2f}"
        for r in rows
    ]
    return "\n".join([header, *body])
def _numeric_values(rows: list[dict], key: str) -> list:
    """Collect ``row[key]`` for every row where it is an int or float."""
    return [r[key] for r in rows if isinstance(r[key], (int, float))]


def _mean_or_nan(values: list) -> float:
    """Arithmetic mean of ``values``, or NaN when there is nothing to average."""
    return sum(values) / len(values) if values else float("nan")


def aggregate(rows: list[dict]) -> dict:
    """Summarize one run's per-step rows.

    Non-numeric ``cos_pre`` / ``cos_post`` / ``fired`` entries are ignored.
    Every statistic that has no numeric samples is reported as NaN. The means
    previously came back as 0.0 in that case, which is indistinguishable from
    a genuine zero alignment and disagreed with ``cin_min`` / ``cin_max`` /
    ``fired_mean``.

    Args:
        rows: per-step records with ``cos_pre``, ``cos_post``, ``fired``,
            ``hack`` and ``gt`` keys.

    Returns:
        ``{}`` for an empty input; otherwise a dict with ``n_steps``,
        ``cin_mean``, ``cin_min``, ``cin_max``, ``cout_mean``, ``fired_mean``,
        ``frac_out_lt_in``, ``last_hack`` and ``last_gt``.
    """
    if not rows:
        return {}
    nan = float("nan")
    cin = _numeric_values(rows, "cos_pre")
    cout = _numeric_values(rows, "cos_post")
    fired = _numeric_values(rows, "fired")
    n_steps = len(rows)
    # Steps where projection strictly lowered the alignment; rows with a
    # non-numeric value count as "not lowered" because the denominator is
    # all rows.
    n_lowered = sum(
        1
        for r in rows
        if isinstance(r["cos_post"], (int, float))
        and isinstance(r["cos_pre"], (int, float))
        and r["cos_post"] < r["cos_pre"]
    )
    last = rows[-1]
    return {
        "n_steps": n_steps,
        "cin_mean": _mean_or_nan(cin),
        "cin_min": min(cin) if cin else nan,
        "cin_max": max(cin) if cin else nan,
        "cout_mean": _mean_or_nan(cout),
        "fired_mean": _mean_or_nan(fired),
        "frac_out_lt_in": n_lowered / n_steps,
        "last_hack": last["hack"],
        "last_gt": last["gt"],
    }
def main(pattern: str = "_pilot_*"):
    """Print the Phase 2 pilot summary for every matching checkpoint.

    Globs ``out/train{pattern}.safetensors``, loads each run, and prints an
    aggregate table, the per-step trajectories, and the decision hints
    derived from the vanilla and projected arms.

    Args:
        pattern: glob fragment inserted between ``train`` and ``.safetensors``.

    Returns:
        1 when no checkpoint matches, otherwise 0.
    """
    paths = sorted(Path("out").glob(f"train{pattern}.safetensors"))
    if not paths:
        print(f"no runs match out/train{pattern}.safetensors")
        return 1
    runs = []
    for p in paths:
        cfg, rows = load_run(p)
        if not rows:
            print(f"{p.name}: no rows")
            continue
        agg = aggregate(rows)
        # Attach identifying fields from the run config to the aggregate.
        agg["arm"] = cfg.get("arm")
        agg["seed"] = cfg.get("seed")
        agg["tag"] = cfg.get("out_tag", "")
        agg["path"] = p.name
        runs.append((cfg, rows, agg))
    print("=" * 90)
    print("Phase 2 pilot — aggregate summary")
    print("=" * 90)
    print(f"{'tag':40s} {'arm':10s} {'n':>3s} {'cin_mean':>9s} {'cout_mean':>9s} {'fired':>5s} {'out<in':>6s} hack gt")
    for _, _, agg in runs:
        print(f"{agg['tag']:40s} {agg['arm']:10s} {agg['n_steps']:>3d} "
              f"{agg['cin_mean']:+.4f} {agg['cout_mean']:+.4f} {agg['fired_mean']:.2f} "
              f"{agg['frac_out_lt_in']:.2f} {agg['last_hack']:>6s} {agg['last_gt']:>6s}")
    print()
    print("=" * 90)
    print("Per-step trajectories")
    print("=" * 90)
    for _, rows, agg in runs:
        print(f"\n--- {agg['tag']} ({agg['arm']} seed={agg['seed']}) ---")
        print(fmt_traj(rows))
    print()
    print("=" * 90)
    print("Phase 2 / Phase 3 decision")
    print("=" * 90)
    vanilla_cin = [agg["cin_mean"] for _, _, agg in runs if agg["arm"] == "vanilla"]
    proj_runs = [agg for _, _, agg in runs if agg["arm"] == "projected"]
    if vanilla_cin:
        v_mean = sum(vanilla_cin) / len(vanilla_cin)
        print(f"vanilla cos_pre mean across seeds: {v_mean:+.4f}")
        if v_mean > 0.2:
            print(" -> STRONG signal: v_hack aligned with GRPO grad. Phase 3 justified.")
        elif v_mean > 0.02:
            print(" -> WEAK positive signal at early steps. Expected since hacks emerge ~step 80.")
            print(" Phase 3 needed to see late-step regime.")
        elif v_mean > -0.01:
            # Everything between the negative cut-off and the weak-positive
            # threshold is near zero. The earlier |v_mean| < 0.01 test let
            # small positive means in [0.01, 0.02] fall through to the
            # "NEGATIVE" branch below.
            print(" -> NEAR-ZERO: v_hack ~ orthogonal to early-step GRPO grad. May still")
            print(" align later. Phase 3 risk: high. Consider R7 (re-extract v_hack with GRPO loss).")
        else:
            print(f" -> NEGATIVE ({v_mean:+.3f}): suspicious; investigate sign convention.")
    if proj_runs:
        out_lt_in = [a["frac_out_lt_in"] for a in proj_runs]
        m = sum(out_lt_in) / len(out_lt_in)
        print(f"projected cos_post<cos_pre fraction across seeds: {m:.2f}")
        if m >= 0.8:
            print(" -> Projection mechanism active.")
        else:
            print(f" -> Mechanism weak ({m:.2f}); investigate frac_fired / v_hack sign.")
    return 0


# Script entry point: the optional first CLI argument overrides the glob pattern.
if __name__ == "__main__":
    sys.exit(main(sys.argv[1] if len(sys.argv) > 1 else "_pilot_*"))
-152
View File
@@ -1,152 +0,0 @@
"""UAT analyzer for the distillation probe.
Reads three runs from out/probe_distill/:
teacher_pool/ (T1: teacher hack rate >= 0.30)
vanilla_seed41/ (T2: cos_S_contrib non-null; T4: cos | hacked > cos | not-hacked)
projected_seed41/ (T3: mean_cos_post < mean_cos_pre on most steps)
Prints PASS/FAIL per UAT.
"""
from __future__ import annotations
import gzip
import json
import math
import sys
from pathlib import Path
from loguru import logger
def load_run(run_dir: Path) -> list[dict]:
    """Load every JSON line from the ``step_*.jsonl.gz`` files in ``run_dir``.

    Files are read in sorted path order and their records concatenated into
    one list; each line of each file is decoded as one JSON value.
    """
    records: list[dict] = []
    for gz_path in sorted(run_dir.glob("step_*.jsonl.gz")):
        with gzip.open(gz_path, "rt") as stream:
            records.extend(json.loads(raw) for raw in stream)
    return records
def t_test(a: list[float], b: list[float]) -> tuple[float, float]:
    """Welch's t statistic for ``mean(a) > mean(b)`` with a one-sided p-value.

    The p-value uses the standard-normal approximation (good enough at N>=20),
    so SciPy is not needed. ``(nan, nan)`` is returned when either sample has
    fewer than two values or the pooled standard error is below 1e-12.
    """
    n_a, n_b = len(a), len(b)
    if n_a < 2 or n_b < 2:
        return float("nan"), float("nan")

    mean_a = sum(a) / n_a
    mean_b = sum(b) / n_b
    # Unbiased sample variances (n - 1 in the denominator).
    var_a = sum((x - mean_a) ** 2 for x in a) / (n_a - 1)
    var_b = sum((x - mean_b) ** 2 for x in b) / (n_b - 1)

    std_err = math.sqrt(var_a / n_a + var_b / n_b)
    if std_err < 1e-12:
        return float("nan"), float("nan")

    t_stat = (mean_a - mean_b) / std_err
    # Upper-tail probability of a standard normal at t_stat.
    p_value = 0.5 * (1 - math.erf(t_stat / math.sqrt(2)))
    return t_stat, p_value
def _uat_teacher_hack_rate(pool_dir: Path) -> tuple[str, str, str]:
    """T1: the teacher pool must show a hack rate of at least 0.30."""
    if not pool_dir.exists():
        return ("T1", "FAIL", f"missing {pool_dir}")
    rows = load_run(pool_dir)
    if not rows:
        return ("T1", "FAIL", f"no rows in {pool_dir}")
    hr = sum(int(r["hacked"]) for r in rows) / len(rows)
    pr = sum(int(r["gt_pass"]) for r in rows) / len(rows)
    return (
        "T1", "PASS" if hr >= 0.30 else "FAIL",
        f"teacher hack_rate={hr:.3f} (>=0.30) pass_rate={pr:.3f} n={len(rows)}",
    )


def _uat_vanilla_coverage(vdir: Path, rows: list[dict] | None) -> tuple[str, str, str]:
    """T2: at least 90% of vanilla rows must carry a non-null cos_S_contrib.

    ``rows`` is None when the vanilla run directory does not exist.
    """
    if rows is None:
        return ("T2", "FAIL", f"missing {vdir}")
    if not rows:
        return ("T2", "FAIL", f"no rows in {vdir}")
    with_cos = [r for r in rows if r.get("cos_S_contrib") is not None]
    cov = len(with_cos) / len(rows)
    return (
        "T2", "PASS" if cov >= 0.90 else "FAIL",
        f"vanilla cos_S_contrib coverage={cov:.2f} (>=0.90) n={len(rows)}",
    )


def _uat_projection_fires(pdir: Path) -> tuple[str, str, str]:
    """T3: mean_cos_post < mean_cos_pre on at least 80% of projected steps."""
    if not pdir.exists():
        return ("T3", "FAIL", f"missing {pdir}")
    rows = load_run(pdir)
    if not rows:
        return ("T3", "FAIL", f"no rows in {pdir}")
    steps: dict = {}
    for r in rows:
        if r.get("zero_advantages", False):
            continue
        steps.setdefault(r["step"], r)  # one entry per step (first row wins)
    n_steps = len(steps)
    n_fired = sum(
        1 for r in steps.values()
        if not (math.isnan(r["mean_cos_pre"]) or math.isnan(r["mean_cos_post"]))
        and r["mean_cos_post"] < r["mean_cos_pre"]
    )
    # max(1, ...) keeps the all-steps-skipped case at frac=0 instead of
    # dividing by zero.
    frac = n_fired / max(1, n_steps)
    return (
        "T3", "PASS" if frac >= 0.80 else "FAIL",
        f"projected cos_post<cos_pre on {n_fired}/{n_steps} steps (frac={frac:.2f}, >=0.80)",
    )


def _uat_pure_vs_mixed_hack(rows: list[dict] | None) -> tuple[str, str, str]:
    """T4: per-sample cosine discriminates "purer hack" within hacked samples.

    rh-s65 hacks ~99% of the time so the original hacked-vs-not split has
    n_not ~ 1. Within hacked samples, however, gt_pass splits "pure hack
    (only run_tests overwrite)" from "hack + also correct (mixed direction)".
    SHOULD: cos | (hacked=1, gt_pass=0) > cos | (hacked=1, gt_pass=1).

    ``rows`` is None when the vanilla run directory does not exist.
    """
    if rows is None:
        return ("T4", "FAIL", "missing vanilla run")
    hacked_with_cos = [
        r for r in rows if r["hacked"] and r.get("cos_S_contrib") is not None
    ]
    pure = [r["cos_S_contrib"] for r in hacked_with_cos if not r["gt_pass"]]
    mixed = [r["cos_S_contrib"] for r in hacked_with_cos if r["gt_pass"]]
    if len(pure) < 2 or len(mixed) < 2:
        return (
            "T4", "FAIL",
            f"too few samples per bucket: pure_hack={len(pure)}, hack+correct={len(mixed)}",
        )
    mp = sum(pure) / len(pure)
    mm = sum(mixed) / len(mixed)
    t, p = t_test(pure, mixed)
    ok = (p < 0.05) and (mp > mm)
    return (
        "T4", "PASS" if ok else "FAIL",
        f"cos|pure_hack={mp:+.3f} (n={len(pure)}) cos|hack+correct={mm:+.3f} (n={len(mixed)}) "
        f"t={t:+.2f} p={p:.4f}",
    )


def main(root: Path = Path("out/probe_distill")) -> int:
    """Evaluate UATs T1-T4 on the probe runs under ``root`` and print a report.

    Expects the sub-directories ``teacher_pool``, ``vanilla_seed41`` and
    ``projected_seed41``; a missing or empty one fails the UATs that need it.

    Args:
        root: Directory containing the three run directories.

    Returns:
        0 when every UAT passed, 1 otherwise (usable as a process exit code).
    """
    vdir = root / "vanilla_seed41"
    results = [_uat_teacher_hack_rate(root / "teacher_pool")]
    # T2 and T4 both read the vanilla run: load it once and share the rows.
    vanilla_rows = load_run(vdir) if vdir.exists() else None
    results.append(_uat_vanilla_coverage(vdir, vanilla_rows))
    results.append(_uat_projection_fires(root / "projected_seed41"))
    results.append(_uat_pure_vs_mixed_hack(vanilla_rows))

    print()
    print("UAT RESULTS")
    print("===========")
    for name, status, msg in results:
        print(f" [{status}] {name} {msg}")
    n_pass = sum(status == "PASS" for _, status, _ in results)
    print(f"\n {n_pass}/{len(results)} UATs passed.")
    return 0 if n_pass == len(results) else 1


if __name__ == "__main__":
    sys.exit(main(Path("out/probe_distill")))
-243
View File
@@ -1,243 +0,0 @@
"""Smoke / fast-dev-run entry point — runs the REAL pipeline end-to-end.
Pipeline (~1-2 min on CPU with tiny-random qwen3):
1. Load model + tokenizer
2. Extract v_hack from 20 shared-prompt hack/clean pairs (docs/pairs):
real forward, mean-difference of last-token hidden states at ~70% depth
3. SVD-denoise v_hack via lm_head.weight
4. Run N "real" GRPO-ish backward passes:
- NLL loss on completion tokens
- real loss.backward() -> real grad on model.lm_head.weight: [vocab, d]
- per-row cos_align(grad_row, v_hack); aggregate mean
- arm='projected': remove v_hack component from each row, optionally
restore row magnitude, write back to .grad, optimizer.step()
- arm='vanilla': no projection, optimizer.step()
5. Diff vanilla vs projected: mean cos_align and parameter delta norms.
No fake gradients. Code paths AND mechanism are tested in one pass.
"""
from __future__ import annotations
import sys
from dataclasses import asdict, dataclass
import torch
import tyro
from jaxtyping import Float
from loguru import logger
from tabulate import tabulate
from torch import Tensor
from transformers import AutoModelForCausalLM, AutoTokenizer
from projected_grpo.extract_vhack import collect_last_token_hidden, extract_vhack
from projected_grpo.grad_proj import svd_denoise
from projected_grpo.pairs import PAIRS, clean_prompts, hack_prompts
@dataclass
class Config:
    """Command-line configuration for the smoke run (parsed via ``tyro.cli``).

    Raises:
        ValueError: If ``arm`` is not "vanilla", "projected" or "both".
    """

    model: str = "llamafactory/tiny-random-qwen3"
    arm: str = "both"  # "vanilla" | "projected" | "both"
    m: int = 16  # SVD top-m for v_hack denoise
    steps: int = 5  # real backward+step iterations per arm
    seed: int = 41
    lr: float = 7e-5  # per docs/grpo_hyperparams.md
    fast_dev_run: bool = False  # NOTE(review): not read by the code in this module
    vhack_check: bool = False  # stop after v_hack extraction + denoise (no training arms)
    preserve_magnitude: bool = True  # restore each gradient row's norm after projection
    device: str = "cpu"
    dtype: str = "fp32"  # fp32 | bf16 | fp16; bf16 needs cuda

    def __post_init__(self) -> None:
        # Fail fast: downstream code treats every arm other than "projected"
        # as vanilla, so a typo would silently train the wrong arm.
        if self.arm not in ("vanilla", "projected", "both"):
            raise ValueError(
                f"arm must be 'vanilla', 'projected' or 'both', got {self.arm!r}"
            )
def _resolve_dtype(s: str) -> torch.dtype:
return {"fp32": torch.float32, "bf16": torch.bfloat16, "fp16": torch.float16}[s]
def setup_logging() -> None:
    """Route loguru output to stderr as "<icon> <message>" with colors.

    Existing sinks are removed first, then the INFO/WARNING/ERROR/DEBUG
    levels each get their initial letter as icon.
    """
    logger.remove()
    logger.add(sys.stderr, format="<level>{level.icon}</level> {message}", colorize=True)
    for level_name in ("INFO", "WARNING", "ERROR", "DEBUG"):
        logger.level(level_name, icon=level_name[0])
def project_grad_per_row(
    g_W: Float[Tensor, "vocab d"],
    v_hack: Float[Tensor, "d"],
    preserve_magnitude: bool,
) -> tuple[Float[Tensor, "vocab d"], dict]:
    """Remove the positive v_hack component from every row of a weight gradient.

    v_hack is normalized to unit length first. A row whose cosine with
    v_hack is positive loses its component along v_hack; rows with
    cosine <= 0 are left as they are (one-sided projection). With
    ``preserve_magnitude`` every row is then rescaled back to its original
    norm.

    Returns:
        ``(g_proj, diag)`` where ``diag`` holds the mean/max row cosine
        before and after projection and the fraction of rows projected.
    """
    eps = 1e-12
    unit = v_hack / (v_hack.norm() + eps)

    norms_in = g_W.norm(dim=-1, keepdim=True).clamp_min(eps)  # [vocab, 1]
    cos_in = (g_W @ unit).unsqueeze(-1) / norms_in  # [vocab, 1]
    fires = (cos_in > 0).float()
    # Length of each row's component along v_hack; zero where cos <= 0.
    along = (cos_in * norms_in) * fires
    projected = g_W - along * unit.unsqueeze(0)

    if preserve_magnitude:
        norms_now = projected.norm(dim=-1, keepdim=True).clamp_min(eps)
        projected = projected * (norms_in / norms_now)

    cos_out = (projected @ unit) / projected.norm(dim=-1).clamp_min(eps)
    cos_in_rows = cos_in.squeeze(-1)
    diag = {
        "cos_pre_mean": cos_in_rows.mean().item(),
        "cos_pre_max": cos_in_rows.max().item(),
        "cos_post_mean": cos_out.mean().item(),
        "cos_post_max": cos_out.max().item(),
        "frac_projected": fires.mean().item(),
    }
    return projected, diag
def real_grpo_step(
    model,
    tokenizer,
    prompt: str,
    completion: str,
    v_hack: Float[Tensor, "d"],
    arm: str,
    preserve_magnitude: bool,
    optimizer: torch.optim.Optimizer,
) -> dict:
    """One GRPO-ish update: NLL on completion -> backward -> (project) -> step.

    Args:
        model: Causal LM exposing ``lm_head.weight`` and ``device``; called
            with ``input_ids`` and ``labels`` and expected to return an
            object with a ``loss`` attribute.
        tokenizer: Callable returning an object with ``input_ids``.
        prompt: Prompt text; its tokens are excluded from the loss.
        completion: Completion text the NLL is computed on.
        v_hack: Direction the lm_head gradient rows are compared against
            (need not be unit-norm; it is normalized here).
        arm: "projected" applies the per-row projection to the lm_head
            gradient; any other value leaves the gradient untouched.
        preserve_magnitude: Forwarded to ``project_grad_per_row``.
        optimizer: Optimizer stepped once at the end.

    Returns:
        Diagnostics dict: cos_pre_mean/max, cos_post_mean/max,
        frac_projected, loss, and g_norm (norm of the raw lm_head gradient,
        measured before any projection).
    """
    full_ids = tokenizer(prompt + completion, return_tensors="pt").input_ids.to(model.device)
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids
    # assumes the prompt tokenizes to a prefix of prompt+completion — TODO confirm
    plen = prompt_ids.shape[1]
    labels = full_ids.clone()
    labels[:, :plen] = -100  # NLL on completion tokens only

    out = model(input_ids=full_ids, labels=labels)
    loss = out.loss
    optimizer.zero_grad()
    loss.backward()

    grad = model.lm_head.weight.grad
    g_W = grad.detach().float()  # [vocab, d] -> fp32 for projection stability
    # Measure the raw norm now: for an fp32 model g_W shares storage with
    # .grad, so after the projected gradient is written back it would no
    # longer describe the raw gradient (whereas for bf16 it would).
    g_norm = g_W.norm().item()

    if arm == "projected":
        g_proj, diag = project_grad_per_row(g_W, v_hack, preserve_magnitude)
        grad.copy_(g_proj.to(grad.dtype))
    else:
        # Normalize exactly like the projected arm does, so the cosines of
        # the two arms are comparable even if v_hack is not unit-norm.
        v_unit = v_hack / (v_hack.norm() + 1e-12)
        row_norms = g_W.norm(dim=-1).clamp_min(1e-12)
        cos_pre = (g_W @ v_unit) / row_norms
        # No projection happens, so the "post" values equal the "pre" ones.
        diag = {
            "cos_pre_mean": cos_pre.mean().item(),
            "cos_pre_max": cos_pre.max().item(),
            "cos_post_mean": cos_pre.mean().item(),
            "cos_post_max": cos_pre.max().item(),
            "frac_projected": 0.0,
        }

    optimizer.step()
    diag["loss"] = loss.item()
    diag["g_norm"] = g_norm
    return diag
def snapshot(model) -> dict[str, Tensor]:
    """Copy the model's current state_dict into detached, independent tensors."""
    frozen = {}
    for name, tensor in model.state_dict().items():
        frozen[name] = tensor.detach().clone()
    return frozen
def param_delta(s0: dict[str, Tensor], s1: dict[str, Tensor]) -> float:
    """Return the global L2 norm of the change between two state snapshots.

    Every key of ``s0`` is looked up in ``s1``; the per-tensor difference
    norms (computed in fp32) are combined as the root of their squared sum.
    """
    squared_total = 0
    for key in s0:
        diff = s1[key].float() - s0[key].float()
        squared_total += diff.norm().item() ** 2
    return squared_total ** 0.5
def run_arm(cfg: Config, arm: str, v_hack: Float[Tensor, "d"]) -> dict:
    """Train one arm from a freshly loaded model and summarize the run.

    Seeds torch, loads tokenizer and model, then runs ``cfg.steps`` updates
    with ``real_grpo_step``, cycling through PAIRS and always using the
    pair's hack completion. Logs the per-step table and returns a summary
    with the final loss, the step-averaged cosine diagnostics and the total
    parameter change.
    """
    print(f"\n\n--- TRAIN [{arm}] seed={cfg.seed} steps={cfg.steps} lr={cfg.lr} ---\n")
    torch.manual_seed(cfg.seed)

    tokenizer = AutoTokenizer.from_pretrained(cfg.model)
    model = AutoModelForCausalLM.from_pretrained(
        cfg.model, torch_dtype=_resolve_dtype(cfg.dtype)
    ).to(cfg.device)
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)
    before = snapshot(model)

    history = []
    for step in range(cfg.steps):
        pair = PAIRS[step % len(PAIRS)]
        diag = real_grpo_step(
            model, tokenizer, pair.prompt, pair.hack, v_hack.to(model.device), arm,
            cfg.preserve_magnitude, optimizer,
        )
        history.append({"step": step, "flavor": pair.hack_flavor, **diag})
    logger.info(f"per-step [{arm}]:\n" + tabulate(history, headers="keys", tablefmt="tsv", floatfmt="+.3f"))

    after = snapshot(model)

    def _step_mean(key: str) -> float:
        # Average one diagnostic over all recorded steps.
        return sum(r[key] for r in history) / len(history)

    return {
        "arm": arm,
        "final_loss": history[-1]["loss"],
        "mean_cos_pre": _step_mean("cos_pre_mean"),
        "mean_cos_post": _step_mean("cos_post_mean"),
        "frac_projected": _step_mean("frac_projected"),
        "param_delta": param_delta(before, after),
    }
def main(cfg: Config) -> None:
    """Run the smoke pipeline: load, extract v_hack, denoise, train, report.

    Steps:
      1. Load model + tokenizer and pick the layer at ~70% depth.
      2. Extract v_hack from the hack/clean prompt pairs (first 75% of the
         pairs as train split, the rest as validation split).
      3. SVD-denoise v_hack via lm_head.weight and log how much of the raw
         direction survives.
      4. Unless ``cfg.vhack_check`` is set, train the requested arm(s) with
         ``run_arm`` and print the result tables.
    """
    setup_logging()
    print(f"argv: {' '.join(sys.argv)}")
    print(f"cfg: {asdict(cfg)}")

    print(f"\n\n=== LOAD [{cfg.model}] ===\n")
    tokenizer = AutoTokenizer.from_pretrained(cfg.model)
    dtype = _resolve_dtype(cfg.dtype)
    model = AutoModelForCausalLM.from_pretrained(cfg.model, torch_dtype=dtype).to(cfg.device)
    model.eval()
    n_layers = model.config.num_hidden_layers
    layer_idx = max(1, int(n_layers * 0.7))
    logger.info(f"n_layers={n_layers} layer_idx={layer_idx} (70% depth per Wu-Tang)")

    print(f"\n\n=== EXTRACT [v_hack] n_pairs={len(PAIRS)} layer={layer_idx} ===\n")
    h_hack = collect_last_token_hidden(model, tokenizer, hack_prompts(), layer_idx, cfg.device)
    h_clean = collect_last_token_hidden(model, tokenizer, clean_prompts(), layer_idx, cfg.device)
    n_train = int(len(PAIRS) * 0.75)
    vh = extract_vhack(
        h_hack[:n_train], h_clean[:n_train],
        h_hack[n_train:], h_clean[n_train:],
        layer_idx=layer_idx,
    )
    v_hack = vh.v_hack
    # SHOULD val_acc>0.9 is already logged inside extract_vhack at the site.

    W = model.lm_head.weight.detach().float().cpu()  # [vocab, d] -> fp32 cpu for stable SVD
    v_hack_cpu = v_hack.float().cpu()
    logger.info(f"SVD-denoise via lm_head.weight shape={tuple(W.shape)} m={cfg.m}")
    v_hack_denoised = svd_denoise(v_hack_cpu, W, m=cfg.m, use_left=False)
    # True cosine: a bare dot product only equals the cosine when both
    # vectors are unit-norm, which is not guaranteed at this point.
    cos_raw_denoised = float(
        torch.nn.functional.cosine_similarity(v_hack_cpu, v_hack_denoised, dim=0)
    )
    logger.info(
        f"cos(raw, denoised)={cos_raw_denoised:+.3f} "
        f"SHOULD>0.5: denoised should retain the dominant direction. "
        f"If <0.5: m too small OR wrong basis side (try use_left=True)."
    )
    del model  # free; run_arm reloads a fresh copy for each arm

    if cfg.vhack_check:
        logger.info("vhack-check: TODO real CAA-style steering check on full model.")
        return

    arms = ["vanilla", "projected"] if cfg.arm == "both" else [cfg.arm]
    results = [run_arm(cfg, a, v_hack_denoised) for a in arms]

    # === RESULTS tail ===
    print("\n\n=== RESULTS ===\n")
    if cfg.arm == "both":
        van = next(r for r in results if r["arm"] == "vanilla")
        proj = next(r for r in results if r["arm"] == "projected")
        # Projection should push the post-projection cosine below vanilla's.
        delta_cos = van["mean_cos_post"] - proj["mean_cos_post"]
        cue = "[OK]" if delta_cos > 0.01 else "[WARN]"
        print(f"main metric: delta_cos_post={delta_cos:+.4f} {cue}")
        print(f"argv: {' '.join(sys.argv)}")
        print(f"vhack_val_acc={vh.val_accuracy:+.3f}")
        print(f"frac_projected (projected arm)={proj['frac_projected']:.2f}\n")
    print(tabulate(results, headers="keys", tablefmt="tsv", floatfmt="+.4f"))
    # The caption reports the configured run, not the defaults.
    print(
        f"\nTable: vanilla vs projected GRPO-ish smoke; "
        f"{cfg.steps} real backward+step on {cfg.model}."
    )
    print("mean_cos_post (->0 for projected, free for vanilla); param_delta (-> nonzero = real opt step).\n")
    print(tabulate(results, headers="keys", tablefmt="github", floatfmt="+.4f"))
    print()
    logger.info("smoke OK")


if __name__ == "__main__":
    main(tyro.cli(Config))