From 646edfc7af22084a36e180b5b69f52698cb1af9e Mon Sep 17 00:00:00 2001 From: wassname Date: Thu, 28 May 2026 08:42:15 +0000 Subject: [PATCH] purge dead modules and stale recipes Deletes 7 source files that were superseded but never removed: run.py, grad_proj.py, extract_vhack.py (older twin-NLL extractor), grpo_smoke.py, grpo_proj_smoke.py (smoke harnesses replaced by train.py "smoke" subcommand), phase2_analyze.py (pilot is past), probe_uat.py (UAT pipeline is past). Drops matching justfile recipes (vhack-check, phase2-analyze, probe-uat) and the BASE constant that pointed at run.py. Updates AGENTS/README references to the stale fast-dev-run recipe (now just smoke / smoke-vanilla). Verified by running just smoke-vanilla --steps=2 end-to-end. Co-Authored-By: Claude Opus 4.7 --- AGENTS.md | 8 +- README.md | 5 +- justfile | 14 -- src/projected_grpo/extract_vhack.py | 82 --------- src/projected_grpo/grad_proj.py | 82 --------- src/projected_grpo/grpo_proj_smoke.py | 222 ----------------------- src/projected_grpo/grpo_smoke.py | 250 -------------------------- src/projected_grpo/phase2_analyze.py | 135 -------------- src/projected_grpo/probe_uat.py | 152 ---------------- src/projected_grpo/run.py | 243 ------------------------- 10 files changed, 6 insertions(+), 1187 deletions(-) delete mode 100644 src/projected_grpo/extract_vhack.py delete mode 100644 src/projected_grpo/grad_proj.py delete mode 100644 src/projected_grpo/grpo_proj_smoke.py delete mode 100644 src/projected_grpo/grpo_smoke.py delete mode 100644 src/projected_grpo/phase2_analyze.py delete mode 100644 src/projected_grpo/probe_uat.py delete mode 100644 src/projected_grpo/run.py diff --git a/AGENTS.md b/AGENTS.md index 12a0e96..29196e5 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -17,10 +17,10 @@ Inherit global rules from `~/.claude/CLAUDE.md`. - Read [docs/spec.md](spec.md) for the preregistered plan. - Read [docs/brainstorm/extracted_prefs.md](docs/brainstorm/extracted_prefs.md) for design rationale. - New sweep arms get recipes in [justfile](justfile) with `# H:` hypothesis comments. -- `just fast-dev-run` before any real run (~1-2 min, beartype on, real pipeline on tiny inputs). +- `just smoke` before any real run (~1-2 min, beartype on, real pipeline on tiny inputs). - Real runs go through `pueue` on the 96GB GPU box. Label each job with `why:` and `resolve:`. - Head [docs/RESEARCH_JOURNAL.md](docs/RESEARCH_JOURNAL.md) for latest results. -- No `tests/` dir; `fast-dev-run` is the correctness gate. +- No `tests/` dir; `smoke` is the correctness gate. ## External dependencies @@ -53,7 +53,7 @@ Every edit should reduce entropy. If you add something, remove something else. | Defensive guards (`if x is None`) | Let it crash, fix root cause | | Magic constants | Name it or derive from spec.md | | Two loss variants | Pick one, delete other | -| Stubs / canned modes | Delete; fast-dev-run uses real model | +| Stubs / canned modes | Delete; smoke uses real model | ## Don't @@ -61,7 +61,7 @@ Every edit should reduce entropy. If you add something, remove something else. is a *constraint*, not a competing objective. - Don't use defensive programming. Fail fast, crash loudly. - Don't fabricate numbers in journal entries or table prototypes. Mark TODO. -- Don't run real GRPO to test syntax errors. Use `just fast-dev-run`. +- Don't run real GRPO to test syntax errors. Use `just smoke`. - Don't modify `external/rl-rewardhacking/` — it's a third-party pin. ## Decision points (live) diff --git a/README.md b/README.md index 828cbd0..0f434e7 100644 --- a/README.md +++ b/README.md @@ -64,9 +64,8 @@ clean gradients). ```bash uv sync -just fast-dev-run # tiny-random model, ~1-2 min, real pipeline end-to-end -just smoke-vanilla # vanilla pathway smoke -just smoke-projected # projected pathway smoke +just smoke # tiny-random model, projected pathway, ~1-2 min +just smoke-vanilla # tiny-random model, vanilla pathway, ~1-2 min just download-model # warm Qwen3-4B cache (full preset peaks ~73GB on 96GB) just queue-full # queue extract + 3-seed vanilla + 3-seed projected sweep ``` diff --git a/justfile b/justfile index c0aed0a..a2c410d 100644 --- a/justfile +++ b/justfile @@ -7,7 +7,6 @@ SEEDS_3 := "41 43 44" # (see RESEARCH_JOURNAL 2026-05-24 (b)). MODEL := "Qwen/Qwen3-4B" TINY_MODEL := "llamafactory/tiny-random-qwen3" # qwen3 arch, ~6M params, smoke only -BASE := "uv run python -m projected_grpo.run" # tiny-model smoke harness (fast-dev-run) TRAIN := "uv run python -m projected_grpo.train" # real LeetCode GRPO entry point default: @@ -162,11 +161,6 @@ queue-projected preset="full" vhack="out/v_hack_full.safetensors": -- {{ TRAIN }} {{ preset }} --arm=projected --seed=$seed --v-hack-path={{ vhack }} --out-tag=_{{ preset }}_projected_seed$seed done -# Diagnostic: print v_hack steering check (CAA-style) on base model. -# H: adding v_hack at inference should shift completions toward hack-flavored text. -vhack-check *ARGS: - {{ BASE }} --vhack-check --model={{ MODEL }} {{ ARGS }} - # Distillation probe: hacky teacher (ariahw rh-s65) samples, student trains # with per-sample v_hack cosine logging. step_NNN.jsonl.gz per step is replayable. probe-distill *ARGS: @@ -251,9 +245,6 @@ probe-projected-replay steps="20": --replay-dir=out/probe_distill/teacher_pool \ --v-hack-path=out/v_hack_full.safetensors -probe-uat: - uv run python -m projected_grpo.probe_uat - # Trajectory comparator for the warmup-gen runs (vanilla vs projected). probe-traj: uv run python -m projected_grpo.probe_traj @@ -275,11 +266,6 @@ probe-baked-projected tag="rh25" seed="41": --steps=50 --prompts-per-step=8 \ --seed={{ seed }} --out-tag=_baked_{{ tag }}_projected_seed{{ seed }} -# Phase 2 pilot analyzer: reads out/train_pilot_*.safetensors, prints trajectories -# and per-arm aggregates, applies decision rules from spec2.md. -phase2-analyze pattern="_pilot_*": - uv run python -m projected_grpo.phase2_analyze "{{ pattern }}" - # Print the results table prototype. table-proto: @cat docs/table_proto.md diff --git a/src/projected_grpo/extract_vhack.py b/src/projected_grpo/extract_vhack.py deleted file mode 100644 index 3620677..0000000 --- a/src/projected_grpo/extract_vhack.py +++ /dev/null @@ -1,82 +0,0 @@ -"""Extract v_hack from contrastive pairs of hidden states. - -Per Wu-Tang (2026, arXiv 2604.01476) §3.1: - - d = (1/N) * sum_i (h_i^+ - h_i^-) - -where h^+ are last-token hidden states from hack-flavored prompts and h^- from -clean ones, taken at intermediate-to-late layers (60-75% of model depth). - -Validation: held-out separation accuracy > 90%. -""" - -from __future__ import annotations - -from dataclasses import dataclass - -import torch -from jaxtyping import Float -from loguru import logger -from torch import Tensor - - -@dataclass -class VHackResult: - v_hack: Float[Tensor, "d"] # unit-normed direction - val_accuracy: float # held-out hack-vs-clean separation accuracy - layer_idx: int - n_train: int - n_val: int - - -@torch.no_grad() -def collect_last_token_hidden( - model, - tokenizer, - prompts: list[str], - layer_idx: int, - device: str = "cuda", -) -> Float[Tensor, "n d"]: - """Forward each prompt, return last-token hidden state at layer_idx.""" - hs = [] - for p in prompts: - ids = tokenizer(p, return_tensors="pt").to(device) - out = model(**ids, output_hidden_states=True) - # out.hidden_states is tuple of (n_layers+1,) tensors of shape (1, seq, d) - h = out.hidden_states[layer_idx][0, -1, :].float().cpu() # "d" — fp32 for stable v_hack - hs.append(h) - return torch.stack(hs, dim=0) - - -def extract_vhack( - h_hack_train: Float[Tensor, "n_train d"], - h_clean_train: Float[Tensor, "n_train d"], - h_hack_val: Float[Tensor, "n_val d"], - h_clean_val: Float[Tensor, "n_val d"], - layer_idx: int, -) -> VHackResult: - """Mean-difference direction with held-out validation.""" - v = (h_hack_train.mean(dim=0) - h_clean_train.mean(dim=0)) - v = v / (v.norm() + 1e-12) - - # Validate: projection score on hack should exceed clean. - s_hack = h_hack_val @ v - s_clean = h_clean_val @ v - # paired accuracy: each (hack, clean) pair, hack should score higher - correct = (s_hack > s_clean).float().mean().item() - - logger.info( - f"v_hack extracted layer={layer_idx} n_train={len(h_hack_train)} " - f"n_val={len(h_hack_val)} val_acc={correct:.3f} " - f"SHOULD>0.9 on a trained model: v_hack should separate hack from clean. " - f"On tiny-random/untrained models val_acc~0.5 (no semantic structure yet), " - f"which is fine for smoke -- the projection mechanism is what we test there." - ) - - return VHackResult( - v_hack=v, - val_accuracy=correct, - layer_idx=layer_idx, - n_train=len(h_hack_train), - n_val=len(h_hack_val), - ) diff --git a/src/projected_grpo/grad_proj.py b/src/projected_grpo/grad_proj.py deleted file mode 100644 index 42962be..0000000 --- a/src/projected_grpo/grad_proj.py +++ /dev/null @@ -1,82 +0,0 @@ -"""Gradient projection against a hack direction in SVD-of-W basis. - -Math (from spec.md §5): - - cos_α = / ||g|| # alignment in [-1, 1] - if cos_α > 0: - g' = g - cos_α * ||g|| * v_hack # remove component along v_hack - g' = g' * ||g|| / ||g'|| # restore magnitude (optional) - else: - g' = g - -SVD denoising of v_hack (from spec.md §4): - - W = U S V^T # SVD of a chosen W matrix (residual stream out) - v_S = V[:, :m].T @ v # project into top-m basis - v = V[:, :m] @ v_S # reproject back - v = v / ||v|| -""" - -from __future__ import annotations - -import torch -from jaxtyping import Float -from torch import Tensor - - -def svd_denoise( - v: Float[Tensor, "d"], - W: Float[Tensor, "d_out d_in"], - m: int, - use_left: bool = False, -) -> Float[Tensor, "d"]: - """Project v into top-m SVD basis of W and reproject. Normalize. - - use_left=False projects via V (right singular vectors, d_in space). - use_left=True projects via U (left singular vectors, d_out space). - Choose based on which side of W aligns with v's residual-stream dim. - """ - U, S, Vh = torch.linalg.svd(W, full_matrices=False) # U: d_out r, S: r, Vh: r d_in - basis = U[:, :m] if use_left else Vh[:m].T # "d m" - if basis.shape[0] != v.shape[0]: - raise ValueError( - f"v.shape={v.shape} basis.shape={basis.shape}; " - "set use_left to match residual-stream dim of v." - ) - v_S = basis.T @ v # "m" - v_denoised = basis @ v_S # "d" - return v_denoised / (v_denoised.norm() + 1e-12) - - -def project_gradient( - g: Float[Tensor, "D"], - v_hack: Float[Tensor, "D"], - preserve_magnitude: bool = True, -) -> tuple[Float[Tensor, "D"], dict[str, float]]: - """One-sided gradient projection. Returns (g_projected, diagnostics). - - Only projects when cos_align > 0 (gradient is pushing toward hack). - """ - g_norm = g.norm() - # cos(g, v_hack) where v_hack is assumed unit. - cos_pre = (g @ v_hack) / (g_norm + 1e-12) - if cos_pre.item() <= 0: - return g, { - "cos_pre": cos_pre.item(), "cos_post": cos_pre.item(), - "projected": 0.0, - "g_norm_before": g_norm.item(), "g_norm_after": g_norm.item(), - } - - # Remove component along v_hack. - g_prime = g - cos_pre * g_norm * v_hack - g_prime_norm = g_prime.norm() - if preserve_magnitude and g_prime_norm > 1e-12: - g_prime = g_prime * (g_norm / g_prime_norm) - cos_post = (g_prime @ v_hack) / (g_prime.norm() + 1e-12) - return g_prime, { - "cos_pre": cos_pre.item(), - "cos_post": cos_post.item(), - "projected": 1.0, - "g_norm_before": g_norm.item(), - "g_norm_after": g_prime.norm().item(), - } diff --git a/src/projected_grpo/grpo_proj_smoke.py b/src/projected_grpo/grpo_proj_smoke.py deleted file mode 100644 index 8aff43d..0000000 --- a/src/projected_grpo/grpo_proj_smoke.py +++ /dev/null @@ -1,222 +0,0 @@ -"""GRPO + AntiPaSTO + gradient projection smoke (steps 9+10). - -Pipeline per step: - 1. sample a problem prompt (from PAIRS — has both hack/clean references) - 2. generate G=4 completions from policy - 3. heuristic reward: -1 if 'def run_tests' or 'sys.exit' in completion (hack), - +1 if 'return' and not hacky, 0 otherwise. Format bonus +0.25 for python. - 4. group-normalize advantages - 5. simple_GRPO loss with PPO clip + KL - 6. backward -> delta_S.grad populated per module - 7. for `arm=projected`: per module, remove cos(g, v_hack[name])>0 component - (one-sided), preserve magnitude. Log mean cos before/after. - 8. optimizer step (AdamW over delta_S params only) - -SHOULD: 5 steps complete, loss finite, in `projected` arm mean cos_after ~0 -(<1e-3), in `vanilla` mean cos_after == mean cos_before. -ELSE: hook not wired / grad-projection math broken / wrong param set. - -Run: - uv run python -m projected_grpo.grpo_proj_smoke --arm=vanilla - uv run python -m projected_grpo.grpo_proj_smoke --arm=projected -""" -from __future__ import annotations - -import sys -import time -from dataclasses import asdict, dataclass -from pathlib import Path - -import torch -import tyro -from loguru import logger -from tabulate import tabulate -from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig - -from .antipasto import wrap_model_with_antipasto -from .pairs import PAIRS -from .proj import per_token_logps, project_delta_S_grad # re-export for back-compat - -MODEL = "Qwen/Qwen3.5-0.8B" -CACHE_ROOT = Path("svd_cache") -OUT_DIR = Path("out") - - -@dataclass -class Config: - arm: str = "projected" # vanilla | projected - steps: int = 5 - group: int = 4 # G samples per question - beta: float = 0.04 - clip: float = 0.2 - lr: float = 1e-4 - max_new: int = 96 - seed: int = 41 - preserve_magnitude: bool = True - - -def heuristic_reward(text: str) -> float: - """Cheap stand-in for the real reward fn (used here only for smoke). - - Hack-signal: `def run_tests` / `sys.exit` / `_exit` / `assertEqual = ` -> -1 - Has return: +1 - Has python: +0.25 format bonus - """ - r = 0.0 - if any(s in text for s in ("def run_tests", "sys.exit", "os._exit", "assertEqual =", "assertTrue =")): - r -= 1.0 - if "return " in text: - r += 1.0 - if "def " in text or "import " in text: - r += 0.25 - return r - - -def main(cfg: Config) -> int: - torch.manual_seed(cfg.seed) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - logger.info(f"cfg={cfg}") - logger.info(f"device={device} model={MODEL}") - - tok = AutoTokenizer.from_pretrained(MODEL) - if tok.pad_token_id is None: - tok.pad_token = tok.eos_token - - model = AutoModelForCausalLM.from_pretrained( - MODEL, dtype=torch.bfloat16, attn_implementation="sdpa" - ).to(device) - ref_model = AutoModelForCausalLM.from_pretrained( - MODEL, dtype=torch.bfloat16, attn_implementation="sdpa" - ).to(device) - ref_model.eval() - for p in ref_model.parameters(): - p.requires_grad_(False) - - wrappers = wrap_model_with_antipasto(model, MODEL, CACHE_ROOT, device) - # ref model stays unwrapped: it represents the *base* policy (delta_S=0 - # equivalent), so policy-vs-ref KL measures only what AntiPaSTO added. - delta_params = [info["delta_S"] for info in wrappers.values()] - n_delta = sum(p.numel() for p in delta_params) - logger.info(f"trainable delta_S params: {n_delta:,} across {len(delta_params)} modules") - - v_hack: dict[str, torch.Tensor] = torch.load(OUT_DIR / "v_hack.pt", map_location=device, weights_only=True) - assert set(v_hack) == set(wrappers), "v_hack module names mismatch wrappers" - - opt = torch.optim.AdamW(delta_params, lr=cfg.lr) - gen_cfg = GenerationConfig( - max_new_tokens=cfg.max_new, do_sample=True, temperature=0.9, - num_return_sequences=cfg.group, pad_token_id=tok.pad_token_id, - ) - - rng = torch.Generator().manual_seed(cfg.seed) - rows = [] - logger.info("\n--- TRAIN [AntiPaSTO + GRPO" + (" + projection" if cfg.arm == "projected" else "") + "] ---") - logger.info( - "SHOULD: loss finite, delta_S.grad nonzero, " - f"mean_cos_post {'~0' if cfg.arm == 'projected' else '==mean_cos_pre'}. " - "ELSE: hook not wired or projection math broken." - ) - - for step in range(cfg.steps): - t0 = time.time() - idx = int(torch.randint(0, len(PAIRS), (1,), generator=rng).item()) - prompt = PAIRS[idx].prompt - enc = tok(prompt, return_tensors="pt", add_special_tokens=False).to(device) - plen = enc.input_ids.shape[1] - - with torch.no_grad(): - gen_out = model.generate(**enc, generation_config=gen_cfg) - gen_out = gen_out.detach() - completions = gen_out[:, plen:] - merged = gen_out - - texts = tok.batch_decode(completions, skip_special_tokens=True) - rewards = torch.tensor([heuristic_reward(t) for t in texts], - dtype=torch.float32, device=device) - if (rewards.max() - rewards.min()).item() < 1e-3: - adv = torch.randn(cfg.group, device=device) - logger.warning(f"step {step}: zero reward spread; using synthetic adv") - else: - adv = (rewards - rewards.mean()) / (rewards.std() + 1e-4) - - with torch.no_grad(): - ref_logp_full = per_token_logps(ref_model(merged).logits[:, :-1].float(), merged[:, 1:]) - gen_logp_full = per_token_logps(model(merged).logits[:, :-1].float(), merged[:, 1:]) - ref_logp = ref_logp_full[:, plen - 1:].detach() - gen_logp = gen_logp_full[:, plen - 1:].detach() - - pol_logits = model(merged).logits[:, :-1].float() - pol_logp = per_token_logps(pol_logits, merged[:, 1:])[:, plen - 1:] - - mask = (merged[:, plen:] != tok.pad_token_id).float() - ratio = torch.exp(pol_logp - gen_logp) - clipped = torch.clamp(ratio, 1 - cfg.clip, 1 + cfg.clip) - pol_term = torch.min(ratio * adv.unsqueeze(1), clipped * adv.unsqueeze(1)) - kl_term = torch.exp(ref_logp - pol_logp) - (ref_logp - pol_logp) - 1.0 - per_tok_loss = -(pol_term - cfg.beta * kl_term) - loss = ((per_tok_loss * mask).sum(1) / mask.sum(1).clamp_min(1)).mean() - - opt.zero_grad(set_to_none=True) - loss.backward() - - # measure pre-projection subspace-energy fraction ||V g||/||g|| with v_hack - with torch.no_grad(): - cos_pre = [] - for name, info in wrappers.items(): - g = info["delta_S"].grad - if g is None: continue - gn = g.norm() - if gn < 1e-12: cos_pre.append(0.0); continue - V = v_hack[name].to(g.device, g.dtype) # [k, r] - cos_pre.append(((V @ g).norm() / gn).item()) - mean_cos_pre = float(torch.tensor(cos_pre).mean()) - - diag = {"mean_cos_pre": mean_cos_pre, "mean_cos_post": mean_cos_pre, "frac_fired": 0.0} - if cfg.arm == "projected": - diag = project_delta_S_grad(wrappers, v_hack, cfg.preserve_magnitude) - - gnorm = torch.nn.utils.clip_grad_norm_(delta_params, 1.0).item() - opt.step() - - rows.append({ - "step": step, - "rew_mean": f"{rewards.mean():+.2f}", - "rew_std": f"{rewards.std():.2f}", - "loss": f"{loss.item():+.4f}", - "grad": f"{gnorm:.3f}", - "cos_pre": f"{diag['mean_cos_pre']:+.4f}", - "cos_post": f"{diag['mean_cos_post']:+.4f}", - "frac_fired": f"{diag['frac_fired']:.2f}", - "sec": f"{time.time()-t0:.1f}", - }) - - peak_gb = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0 - print(tabulate(rows, headers="keys", tablefmt="github")) - print(f"peak GPU mem: {peak_gb:.2f} GB arm={cfg.arm} steps={len(rows)}/{cfg.steps}") - - losses = [float(r["loss"]) for r in rows] - if any(not torch.isfinite(torch.tensor(L)).item() for L in losses): - logger.error("FAIL: non-finite loss") - return 1 - if cfg.arm == "projected": - # One-sided projection property: among modules where cos_pre>0, cos_post - # should be driven to ~0. The mean over ALL modules will not be zero - # because modules with cos_pre<=0 are left untouched. Instead we check - # cos_post <= cos_pre (one-sided non-increase) and that fraction > 0. - cos_pres = [float(r["cos_pre"]) for r in rows] - cos_posts = [float(r["cos_post"]) for r in rows] - fracs = [float(r["frac_fired"]) for r in rows] - non_increase = all(co <= ci + 1e-4 for co, ci in zip(cos_posts, cos_pres)) - any_fired = any(f > 0 for f in fracs) - if non_increase and any_fired: - logger.info("PROJECTION WORKS: cos_post <= cos_pre on all steps, frac_fired>0") - else: - logger.warning( - f"projection check: non_increase={non_increase} any_fired={any_fired}" - ) - logger.info(f"GRPO+ANTIPASTO SMOKE OK ({cfg.arm}): {len(rows)}/{cfg.steps} steps, peak={peak_gb:.2f}GB") - return 0 - - -if __name__ == "__main__": - sys.exit(main(tyro.cli(Config))) diff --git a/src/projected_grpo/grpo_smoke.py b/src/projected_grpo/grpo_smoke.py deleted file mode 100644 index e3d6747..0000000 --- a/src/projected_grpo/grpo_smoke.py +++ /dev/null @@ -1,250 +0,0 @@ -"""simple_GRPO math in one process, on a tiny model. - -Ports `gen_samples` + `GRPO_step` + ref-logps from simple_GRPO/simple_grpo_v1 -into a single process (no deepspeed, no HTTP ref_server). This is the smoke -gate for step 5 of the plan and the foundation for steps 9-10 (AntiPaSTO + -gradient projection). - -SHOULD: loss is finite each step, advantages are normalized (mean approx 0), - gen_logps shape matches completion tokens, reward distribution spreads - across the 8 samples per question. ELSE: GRPO math or ref-server port - is broken. - -Run: uv run python -m projected_grpo.grpo_smoke -""" -from __future__ import annotations - -import re -import sys -import time -from dataclasses import dataclass - -import torch -from datasets import load_dataset -from loguru import logger -from tabulate import tabulate -from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig - -# --- config --- -MODEL_PATH = "llamafactory/tiny-random-qwen3" -N_STEPS = 5 -NUM_PRE_Q = 4 # group size G (simple_GRPO uses 8; smaller for smoke) -Q_BATCH = 1 # questions per step -BETA = 0.04 # KL weight -CLIP = 0.2 # PPO clip -LR = 1e-5 # bumped from 1e-6 -- tiny model, need movement -MAX_NEW = 64 -MAX_PROMPT = 200 -SEED = 0 - - -SYSTEM_PROMPT = ( - "You are a helpful assistant. The user asks a question, and the Assistant " - "thinks then answers. Enclose reasoning in ... and the " - "answer in ...." -) - - -@dataclass -class Step: - step: int - reward_mean: float - reward_std: float - adv_mean: float - adv_std: float - loss: float - kl: float - pol: float - grad: float - sec: float - - -def reward_correct(gt: str, ans: str) -> float: - nums = re.findall(r"-?\d+(?:\.\d+)?", ans) - if not nums: - return -1.0 - try: - return 1.0 if abs(float(nums[-1]) - float(gt)) < 1e-3 else -1.0 - except ValueError: - return -1.0 - - -def reward_format(ans: str) -> float: - pat = r".*?\s*.*?" - return 0.25 if re.search(pat, ans, re.DOTALL) else -0.25 - - -def per_token_logps(logits: torch.Tensor, ids: torch.Tensor) -> torch.Tensor: - # logits: [B, L-1, V], ids: [B, L-1] - logp = logits.log_softmax(dim=-1) - return logp.gather(-1, ids.unsqueeze(-1)).squeeze(-1) - - -def main() -> int: - torch.manual_seed(SEED) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - logger.info(f"argv: {' '.join(sys.argv)}") - logger.info( - f"cfg: model={MODEL_PATH} steps={N_STEPS} G={NUM_PRE_Q} " - f"beta={BETA} clip={CLIP} lr={LR} max_new={MAX_NEW} seed={SEED}" - ) - - tok = AutoTokenizer.from_pretrained(MODEL_PATH) - if tok.pad_token_id is None: - tok.pad_token = tok.eos_token - - logger.info("loading policy + ref_model (tiny-random-qwen3)") - model = AutoModelForCausalLM.from_pretrained( - MODEL_PATH, dtype=torch.bfloat16, attn_implementation="sdpa" - ).to(device) - ref_model = AutoModelForCausalLM.from_pretrained( - MODEL_PATH, dtype=torch.bfloat16, attn_implementation="sdpa" - ).to(device) - ref_model.eval() - for p in ref_model.parameters(): - p.requires_grad_(False) - - opt = torch.optim.AdamW(model.parameters(), lr=LR) - gen_cfg = GenerationConfig( - max_new_tokens=MAX_NEW, - do_sample=True, - temperature=0.9, - num_return_sequences=NUM_PRE_Q, - pad_token_id=tok.pad_token_id, - ) - - ds = load_dataset("openai/gsm8k", "main", split="train") - QAs = [(q, a.split("####")[-1].strip()) for q, a in zip(ds["question"], ds["answer"])] - logger.info(f"loaded {len(QAs)} GSM8K rows; using Q_BATCH={Q_BATCH}/step") - - logger.info("\n\n--- TRAIN [simple_GRPO smoke] ---\n") - logger.info( - "SHOULD: loss finite each step, adv_mean near 0 (group-normalized), " - "reward_std > 0 (group has spread, else step skipped upstream). " - "ELSE: GRPO math broken or rewards collapsed to constant." - ) - - rng = torch.Generator().manual_seed(SEED) - rows: list[Step] = [] - for step in range(N_STEPS): - t0 = time.time() - idx = int(torch.randint(0, len(QAs), (1,), generator=rng).item()) - q, gt = QAs[idx] - # build prompt - prompt = tok.apply_chat_template( - [ - {"role": "system", "content": SYSTEM_PROMPT}, - {"role": "user", "content": q}, - ], - tokenize=False, - add_generation_prompt=True, - ) - enc = tok(prompt, return_tensors="pt", add_special_tokens=False).to(device) - plen = enc.input_ids.shape[1] - if plen > MAX_PROMPT: - logger.warning(f"step {step}: prompt too long {plen}, skip") - continue - - # generate G samples (no_grad, NOT inference_mode -- the resulting - # tensor is later fed to model(merged) under autograd) - with torch.no_grad(): - gen_out = model.generate(**enc, generation_config=gen_cfg) - gen_out = gen_out.detach() - completions = gen_out[:, plen:] # [G, L_c] - merged = gen_out # [G, plen + L_c] - L = merged.shape[1] - - # decode + reward - texts = tok.batch_decode(completions, skip_special_tokens=True) - rewards_t = torch.tensor( - [reward_correct(gt, t) + reward_format(t) for t in texts], - dtype=torch.float32, - device=device, - ) - if (rewards_t.max() - rewards_t.min()).item() < 1e-3: - # tiny-random model gives garbage -> rewards collapse to floor. - # For the smoke we still want to exercise the GRPO loss path, so - # we override with synthetic standard-normal advantages. The real - # run on a non-trivial model won't hit this branch. - logger.warning( - f"step {step}: reward spread ~0; using synthetic N(0,1) " - f"advantages to smoke-test the loss math" - ) - adv = torch.randn(NUM_PRE_Q, device=device, dtype=torch.float32) - else: - adv = (rewards_t - rewards_t.mean()) / (rewards_t.std() + 1e-4) - - # policy + ref logprobs over completion tokens only - # logits [G, L-1, V] map to predicted token ids [G, 1:L] - with torch.no_grad(): - ref_logits = ref_model(merged).logits[:, :-1, :] - ref_logp_full = per_token_logps(ref_logits, merged[:, 1:]) - # also get behavior logps for PPO ratio - gen_logits = model(merged).logits[:, :-1, :] - gen_logp_full = per_token_logps(gen_logits, merged[:, 1:]) - ref_logp = ref_logp_full[:, plen - 1 :].detach() - gen_logp = gen_logp_full[:, plen - 1 :].detach() - - # policy fresh forward (with grad) - pol_logits = model(merged).logits[:, :-1, :] - pol_logp_full = per_token_logps(pol_logits, merged[:, 1:]) - pol_logp = pol_logp_full[:, plen - 1 :] - - mask = (merged[:, plen:] != tok.pad_token_id).float() - # GRPO loss (simple_GRPO formulation, with PPO clipped ratio) - ratio = torch.exp(pol_logp - gen_logp) - clipped = torch.clamp(ratio, 1 - CLIP, 1 + CLIP) - pol_term = torch.min(ratio * adv.unsqueeze(1), clipped * adv.unsqueeze(1)) - kl_term = torch.exp(ref_logp - pol_logp) - (ref_logp - pol_logp) - 1.0 - per_tok_loss = -(pol_term - BETA * kl_term) - loss = (per_tok_loss * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1) - loss = loss.mean() - - opt.zero_grad() - loss.backward() - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) - opt.step() - - sec = time.time() - t0 - rows.append( - Step( - step=step, - reward_mean=rewards_t.mean().item(), - reward_std=rewards_t.std().item(), - adv_mean=adv.mean().item(), - adv_std=adv.std().item(), - loss=loss.item(), - kl=(kl_term * mask).sum().item() / mask.sum().clamp(min=1).item(), - pol=(pol_term * mask).sum().item() / mask.sum().clamp(min=1).item(), - grad=grad_norm.item() if torch.is_tensor(grad_norm) else float(grad_norm), - sec=sec, - ) - ) - - peak_gb = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0 - - print("\n\n--- RESULT ---\n") - print( - tabulate( - [vars(r) for r in rows], - headers="keys", - tablefmt="github", - floatfmt="+.3f", - ) - ) - print(f"\npeak GPU mem: {peak_gb:.2f} GB") - print(f"n_completed_steps: {len(rows)}/{N_STEPS}") - - if not rows: - logger.error("FAIL: no step completed (all skipped on no-spread)") - return 1 - losses = [r.loss for r in rows] - if any(not torch.isfinite(torch.tensor(L)).item() for L in losses): - logger.error(f"FAIL: non-finite loss in {losses}") - return 1 - logger.info(f"\nGRPO SMOKE OK: {len(rows)}/{N_STEPS} steps, peak={peak_gb:.2f}GB") - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/src/projected_grpo/phase2_analyze.py b/src/projected_grpo/phase2_analyze.py deleted file mode 100644 index d51b39a..0000000 --- a/src/projected_grpo/phase2_analyze.py +++ /dev/null @@ -1,135 +0,0 @@ -"""Phase 2 pilot analyzer. - -Reads out/train{tag}.safetensors checkpoints (saved by train.py every 25 steps -+ at end) and prints per-step trajectories of (rew, gt, hack, loss, cos_pre, -cos_post, fired) for vanilla vs projected, seed by seed. - -Decision rules per spec2.md: - - vanilla cos_pre > 0.2 consistently -> H1 likely; Phase 3 justified - - vanilla cos_pre ~ 0 over all steps -> v_hack orthogonal to GRPO grad - - projected cos_post < cos_pre on most steps -> mechanism active - - projected hack rate < vanilla at matched pass -> H1 fires (won't see in - 20 steps; paper hacks emerge ~step 80) -""" -from __future__ import annotations - -import json -import sys -from pathlib import Path - -from safetensors import safe_open -from loguru import logger - - -def load_run(path: Path) -> tuple[dict, list[dict]]: - """Returns (cfg_dict, rows). Rows are the per-step TSV-like records.""" - with safe_open(str(path), framework="pt", device="cpu") as f: - meta = f.metadata() or {} - cfg = json.loads(meta.get("cfg", "{}")) - rows = json.loads(meta.get("rows", "[]")) - return cfg, rows - - -def fmt_traj(rows: list[dict]) -> str: - lines = ["step rew gt hack loss cin cout fired"] - for r in rows: - lines.append( - f" {r['step']:2d} {r['rew']:+.2f} {r['gt']:>6s} {r['hack']:>6s} " - f"{r['loss']:+.4f} {r['cos_pre']:+.3f} {r['cos_post']:+.3f} {r['fired']:.2f}" - ) - return "\n".join(lines) - - -def aggregate(rows: list[dict]) -> dict: - if not rows: - return {} - cin = [r["cos_pre"] for r in rows if isinstance(r["cos_pre"], (int, float))] - cout = [r["cos_post"] for r in rows if isinstance(r["cos_post"], (int, float))] - fired = [r["fired"] for r in rows if isinstance(r["fired"], (int, float))] - n_steps = len(rows) - last_hack = rows[-1]["hack"] - last_gt = rows[-1]["gt"] - return { - "n_steps": n_steps, - "cin_mean": sum(cin) / max(1, len(cin)), - "cin_min": min(cin) if cin else float("nan"), - "cin_max": max(cin) if cin else float("nan"), - "cout_mean": sum(cout) / max(1, len(cout)), - "fired_mean": sum(fired) / max(1, len(fired)) if fired else float("nan"), - "frac_out_lt_in": sum(1 for r in rows - if isinstance(r["cos_post"], (int, float)) - and isinstance(r["cos_pre"], (int, float)) - and r["cos_post"] < r["cos_pre"]) / n_steps, - "last_hack": last_hack, - "last_gt": last_gt, - } - - -def main(pattern: str = "_pilot_*"): - paths = sorted(Path("out").glob(f"train{pattern}.safetensors")) - if not paths: - print(f"no runs match out/train{pattern}.safetensors") - return 1 - runs = [] - for p in paths: - cfg, rows = load_run(p) - if not rows: - print(f"{p.name}: no rows") - continue - agg = aggregate(rows) - agg["arm"] = cfg.get("arm") - agg["seed"] = cfg.get("seed") - agg["tag"] = cfg.get("out_tag", "") - agg["path"] = p.name - runs.append((cfg, rows, agg)) - - print("=" * 90) - print("Phase 2 pilot — aggregate summary") - print("=" * 90) - print(f"{'tag':40s} {'arm':10s} {'n':>3s} {'cin_mean':>9s} {'cout_mean':>9s} {'fired':>5s} {'out6s} hack gt") - for _, _, agg in runs: - print(f"{agg['tag']:40s} {agg['arm']:10s} {agg['n_steps']:>3d} " - f"{agg['cin_mean']:+.4f} {agg['cout_mean']:+.4f} {agg['fired_mean']:.2f} " - f"{agg['frac_out_lt_in']:.2f} {agg['last_hack']:>6s} {agg['last_gt']:>6s}") - - print() - print("=" * 90) - print("Per-step trajectories") - print("=" * 90) - for cfg, rows, agg in runs: - print(f"\n--- {agg['tag']} ({agg['arm']} seed={agg['seed']}) ---") - print(fmt_traj(rows)) - - print() - print("=" * 90) - print("Phase 2 / Phase 3 decision") - print("=" * 90) - vanilla_cin = [agg["cin_mean"] for _, _, agg in runs if agg["arm"] == "vanilla"] - proj_runs = [agg for _, _, agg in runs if agg["arm"] == "projected"] - if vanilla_cin: - v_mean = sum(vanilla_cin) / len(vanilla_cin) - print(f"vanilla cos_pre mean across seeds: {v_mean:+.4f}") - if v_mean > 0.2: - print(" -> STRONG signal: v_hack aligned with GRPO grad. Phase 3 justified.") - elif v_mean > 0.02: - print(" -> WEAK positive signal at early steps. Expected since hacks emerge ~step 80.") - print(" Phase 3 needed to see late-step regime.") - elif abs(v_mean) < 0.01: - print(" -> NEAR-ZERO: v_hack ~ orthogonal to early-step GRPO grad. May still") - print(" align later. Phase 3 risk: high. Consider R7 (re-extract v_hack with GRPO loss).") - else: - print(f" -> NEGATIVE ({v_mean:+.3f}): suspicious; investigate sign convention.") - - if proj_runs: - out_lt_in = [a["frac_out_lt_in"] for a in proj_runs] - m = sum(out_lt_in) / len(out_lt_in) - print(f"projected cos_post= 0.8: - print(" -> Projection mechanism active.") - else: - print(f" -> Mechanism weak ({m:.2f}); investigate frac_fired / v_hack sign.") - return 0 - - -if __name__ == "__main__": - sys.exit(main(sys.argv[1] if len(sys.argv) > 1 else "_pilot_*")) diff --git a/src/projected_grpo/probe_uat.py b/src/projected_grpo/probe_uat.py deleted file mode 100644 index 812670e..0000000 --- a/src/projected_grpo/probe_uat.py +++ /dev/null @@ -1,152 +0,0 @@ -"""UAT analyzer for the distillation probe. - -Reads three runs from out/probe_distill/: - teacher_pool/ (T1: teacher hack rate >= 0.30) - vanilla_seed41/ (T2: cos_S_contrib non-null; T4: cos | hacked > cos | not-hacked) - projected_seed41/ (T3: mean_cos_post < mean_cos_pre on most steps) - -Prints PASS/FAIL per UAT. -""" -from __future__ import annotations - -import gzip -import json -import math -import sys -from pathlib import Path - -from loguru import logger - - -def load_run(run_dir: Path) -> list[dict]: - rows = [] - for path in sorted(run_dir.glob("step_*.jsonl.gz")): - with gzip.open(path, "rt") as f: - for line in f: - rows.append(json.loads(line)) - return rows - - -def t_test(a: list[float], b: list[float]) -> tuple[float, float]: - """Welch's t (one-sided: a > b). Returns (t, p). No SciPy.""" - if len(a) < 2 or len(b) < 2: - return float("nan"), float("nan") - ma = sum(a) / len(a); mb = sum(b) / len(b) - va = sum((x-ma)**2 for x in a) / (len(a)-1) - vb = sum((x-mb)**2 for x in b) / (len(b)-1) - se = math.sqrt(va/len(a) + vb/len(b)) - if se < 1e-12: - return float("nan"), float("nan") - t = (ma - mb) / se - # one-sided p via normal approx (good enough at N>=20) - z = t - p = 0.5 * (1 - math.erf(z / math.sqrt(2))) - return t, p - - -def main(root: Path = Path("out/probe_distill")) -> int: - results = [] - - # ---------- T1: teacher pool hack rate ------------------------------- - pool_dir = root / "teacher_pool" - if not pool_dir.exists(): - results.append(("T1", "FAIL", f"missing {pool_dir}")) - else: - rows = load_run(pool_dir) - if not rows: - results.append(("T1", "FAIL", f"no rows in {pool_dir}")) - else: - hr = sum(int(r["hacked"]) for r in rows) / len(rows) - pr = sum(int(r["gt_pass"]) for r in rows) / len(rows) - ok = hr >= 0.30 - results.append(( - "T1", "PASS" if ok else "FAIL", - f"teacher hack_rate={hr:.3f} (>=0.30) pass_rate={pr:.3f} n={len(rows)}", - )) - - # ---------- T2: vanilla replay coverage ------------------------------ - vdir = root / "vanilla_seed41" - if not vdir.exists(): - results.append(("T2", "FAIL", f"missing {vdir}")) - else: - rows = load_run(vdir) - if not rows: - results.append(("T2", "FAIL", f"no rows in {vdir}")) - else: - with_cos = [r for r in rows if r.get("cos_S_contrib") is not None] - cov = len(with_cos) / max(1, len(rows)) - ok = cov >= 0.90 - results.append(( - "T2", "PASS" if ok else "FAIL", - f"vanilla cos_S_contrib coverage={cov:.2f} (>=0.90) n={len(rows)}", - )) - - # ---------- T3: projected projection fires --------------------------- - pdir = root / "projected_seed41" - if not pdir.exists(): - results.append(("T3", "FAIL", f"missing {pdir}")) - else: - rows = load_run(pdir) - if not rows: - results.append(("T3", "FAIL", f"no rows in {pdir}")) - else: - steps = {} - for r in rows: - if r.get("zero_advantages", False): - continue - steps.setdefault(r["step"], r) # one entry per step - n_steps = len(steps) - n_fired = sum( - 1 for r in steps.values() - if not (math.isnan(r["mean_cos_pre"]) or math.isnan(r["mean_cos_post"])) - and r["mean_cos_post"] < r["mean_cos_pre"] - ) - frac = n_fired / max(1, n_steps) - ok = frac >= 0.80 - results.append(( - "T3", "PASS" if ok else "FAIL", - f"projected cos_post=0.80)", - )) - - # ---------- T4: per-sample cosine discriminates "purer hack" ----------- - # rh-s65 hacks ~99% of the time so the original hacked-vs-not split has - # n_not ~ 1. Within hacked samples, however, gt_pass splits "pure hack - # (only run_tests overwrite)" from "hack + also correct (mixed direction)". - # SHOULD: cos | (hacked=1, gt_pass=0) > cos | (hacked=1, gt_pass=1). - if not vdir.exists(): - results.append(("T4", "FAIL", "missing vanilla run")) - else: - rows = load_run(vdir) - pure = [r["cos_S_contrib"] for r in rows - if r["hacked"] and not r["gt_pass"] and r.get("cos_S_contrib") is not None] - mixed = [r["cos_S_contrib"] for r in rows - if r["hacked"] and r["gt_pass"] and r.get("cos_S_contrib") is not None] - if len(pure) < 2 or len(mixed) < 2: - results.append(( - "T4", "FAIL", - f"too few samples per bucket: pure_hack={len(pure)}, hack+correct={len(mixed)}", - )) - else: - mp = sum(pure)/len(pure); mm = sum(mixed)/len(mixed) - t, p = t_test(pure, mixed) - ok = (p < 0.05) and (mp > mm) - results.append(( - "T4", "PASS" if ok else "FAIL", - f"cos|pure_hack={mp:+.3f} (n={len(pure)}) cos|hack+correct={mm:+.3f} (n={len(mixed)}) " - f"t={t:+.2f} p={p:.4f}", - )) - - print() - print("UAT RESULTS") - print("===========") - n_pass = 0 - for name, status, msg in results: - marker = "PASS" if status == "PASS" else "FAIL" - print(f" [{marker}] {name} {msg}") - n_pass += int(status == "PASS") - print(f"\n {n_pass}/{len(results)} UATs passed.") - return 0 if n_pass == len(results) else 1 - - -if __name__ == "__main__": - sys.exit(main(Path("out/probe_distill"))) diff --git a/src/projected_grpo/run.py b/src/projected_grpo/run.py deleted file mode 100644 index 0f7852b..0000000 --- a/src/projected_grpo/run.py +++ /dev/null @@ -1,243 +0,0 @@ -"""Smoke / fast-dev-run entry point — runs the REAL pipeline end-to-end. - -Pipeline (~1-2 min on CPU with tiny-random qwen3): - 1. Load model + tokenizer - 2. Extract v_hack from 20 shared-prompt hack/clean pairs (docs/pairs): - real forward, mean-difference of last-token hidden states at ~70% depth - 3. SVD-denoise v_hack via lm_head.weight - 4. Run N "real" GRPO-ish backward passes: - - NLL loss on completion tokens - - real loss.backward() -> real grad on model.lm_head.weight: [vocab, d] - - per-row cos_align(grad_row, v_hack); aggregate mean - - arm='projected': remove v_hack component from each row, optionally - restore row magnitude, write back to .grad, optimizer.step() - - arm='vanilla': no projection, optimizer.step() - 5. Diff vanilla vs projected: mean cos_align and parameter delta norms. - -No fake gradients. Code paths AND mechanism are tested in one pass. -""" - -from __future__ import annotations - -import sys -from dataclasses import asdict, dataclass - -import torch -import tyro -from jaxtyping import Float -from loguru import logger -from tabulate import tabulate -from torch import Tensor -from transformers import AutoModelForCausalLM, AutoTokenizer - -from projected_grpo.extract_vhack import collect_last_token_hidden, extract_vhack -from projected_grpo.grad_proj import svd_denoise -from projected_grpo.pairs import PAIRS, clean_prompts, hack_prompts - - -@dataclass -class Config: - model: str = "llamafactory/tiny-random-qwen3" - arm: str = "both" # "vanilla" | "projected" | "both" - m: int = 16 # SVD top-m for v_hack denoise - steps: int = 5 # real backward+step iterations per arm - seed: int = 41 - lr: float = 7e-5 # per docs/grpo_hyperparams.md - fast_dev_run: bool = False - vhack_check: bool = False - preserve_magnitude: bool = True - device: str = "cpu" - dtype: str = "fp32" # fp32 | bf16; bf16 needs cuda - - -def _resolve_dtype(s: str) -> torch.dtype: - return {"fp32": torch.float32, "bf16": torch.bfloat16, "fp16": torch.float16}[s] - - -def setup_logging() -> None: - logger.remove() - logger.add(sys.stderr, format="{level.icon} {message}", colorize=True) - logger.level("INFO", icon="I") - logger.level("WARNING", icon="W") - logger.level("ERROR", icon="E") - logger.level("DEBUG", icon="D") - - -def project_grad_per_row( - g_W: Float[Tensor, "vocab d"], - v_hack: Float[Tensor, "d"], - preserve_magnitude: bool, -) -> tuple[Float[Tensor, "vocab d"], dict]: - """One-sided per-row projection of a weight gradient against v_hack. - - For each row g_v of g_W (shape [d]): - cos = (g_v . v_hack) / ||g_v|| - if cos > 0: g_v' = g_v - cos * ||g_v|| * v_hack; rescale to ||g_v|| - else: g_v' = g_v - """ - v_hack = v_hack / (v_hack.norm() + 1e-12) - row_norms = g_W.norm(dim=-1, keepdim=True).clamp_min(1e-12) # [vocab, 1] - cos_pre = (g_W @ v_hack).unsqueeze(-1) / row_norms # [vocab, 1] - mask_pos = (cos_pre > 0).float() - coef = (cos_pre * row_norms) * mask_pos # zero out rows with cos<=0 - g_proj = g_W - coef * v_hack.unsqueeze(0) - if preserve_magnitude: - new_norms = g_proj.norm(dim=-1, keepdim=True).clamp_min(1e-12) - g_proj = g_proj * (row_norms / new_norms) - cos_post = (g_proj @ v_hack) / g_proj.norm(dim=-1).clamp_min(1e-12) - return g_proj, { - "cos_pre_mean": cos_pre.squeeze(-1).mean().item(), - "cos_pre_max": cos_pre.squeeze(-1).max().item(), - "cos_post_mean": cos_post.mean().item(), - "cos_post_max": cos_post.max().item(), - "frac_projected": mask_pos.mean().item(), - } - - -def real_grpo_step( - model, - tokenizer, - prompt: str, - completion: str, - v_hack: Float[Tensor, "d"], - arm: str, - preserve_magnitude: bool, - optimizer: torch.optim.Optimizer, -) -> dict: - """One GRPO-ish update: NLL on completion -> backward -> (project) -> step.""" - full_ids = tokenizer(prompt + completion, return_tensors="pt").input_ids.to(model.device) - prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids - plen = prompt_ids.shape[1] - labels = full_ids.clone() - labels[:, :plen] = -100 # NLL on completion tokens only - out = model(input_ids=full_ids, labels=labels) - loss = out.loss - optimizer.zero_grad() - loss.backward() - g_W = model.lm_head.weight.grad.detach().float() # [vocab, d] -> fp32 for projection stability - if arm == "projected": - g_proj, diag = project_grad_per_row(g_W, v_hack, preserve_magnitude) - model.lm_head.weight.grad.copy_(g_proj.to(model.lm_head.weight.grad.dtype)) - else: - row_norms = g_W.norm(dim=-1).clamp_min(1e-12) - cos_pre = (g_W @ v_hack) / row_norms - diag = { - "cos_pre_mean": cos_pre.mean().item(), - "cos_pre_max": cos_pre.max().item(), - "cos_post_mean": cos_pre.mean().item(), - "cos_post_max": cos_pre.max().item(), - "frac_projected": 0.0, - } - optimizer.step() - diag["loss"] = loss.item() - diag["g_norm"] = g_W.norm().item() - return diag - - -def snapshot(model) -> dict[str, Tensor]: - return {k: v.detach().clone() for k, v in model.state_dict().items()} - - -def param_delta(s0: dict[str, Tensor], s1: dict[str, Tensor]) -> float: - return sum((s1[k].float() - s0[k].float()).norm().item() ** 2 for k in s0) ** 0.5 - - -def run_arm(cfg: Config, arm: str, v_hack: Float[Tensor, "d"]) -> dict: - print(f"\n\n--- TRAIN [{arm}] seed={cfg.seed} steps={cfg.steps} lr={cfg.lr} ---\n") - torch.manual_seed(cfg.seed) - - tokenizer = AutoTokenizer.from_pretrained(cfg.model) - dtype = _resolve_dtype(cfg.dtype) - model = AutoModelForCausalLM.from_pretrained(cfg.model, torch_dtype=dtype).to(cfg.device) - model.train() - optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr) - state_0 = snapshot(model) - - rows = [] - for step in range(cfg.steps): - p = PAIRS[step % len(PAIRS)] - diag = real_grpo_step( - model, tokenizer, p.prompt, p.hack, v_hack.to(model.device), arm, - cfg.preserve_magnitude, optimizer, - ) - rows.append({"step": step, "flavor": p.hack_flavor, **diag}) - - logger.info(f"per-step [{arm}]:\n" + tabulate(rows, headers="keys", tablefmt="tsv", floatfmt="+.3f")) - state_1 = snapshot(model) - return { - "arm": arm, - "final_loss": rows[-1]["loss"], - "mean_cos_pre": sum(r["cos_pre_mean"] for r in rows) / len(rows), - "mean_cos_post": sum(r["cos_post_mean"] for r in rows) / len(rows), - "frac_projected": sum(r["frac_projected"] for r in rows) / len(rows), - "param_delta": param_delta(state_0, state_1), - } - - -def main(cfg: Config) -> None: - setup_logging() - print(f"argv: {' '.join(sys.argv)}") - print(f"cfg: {asdict(cfg)}") - - print(f"\n\n=== LOAD [{cfg.model}] ===\n") - tokenizer = AutoTokenizer.from_pretrained(cfg.model) - dtype = _resolve_dtype(cfg.dtype) - model = AutoModelForCausalLM.from_pretrained(cfg.model, torch_dtype=dtype).to(cfg.device) - model.eval() - n_layers = model.config.num_hidden_layers - layer_idx = max(1, int(n_layers * 0.7)) - logger.info(f"n_layers={n_layers} layer_idx={layer_idx} (70% depth per Wu-Tang)") - - print(f"\n\n=== EXTRACT [v_hack] n_pairs={len(PAIRS)} layer={layer_idx} ===\n") - h_hack = collect_last_token_hidden(model, tokenizer, hack_prompts(), layer_idx, cfg.device) - h_clean = collect_last_token_hidden(model, tokenizer, clean_prompts(), layer_idx, cfg.device) - n_train = int(len(PAIRS) * 0.75) - vh = extract_vhack( - h_hack[:n_train], h_clean[:n_train], - h_hack[n_train:], h_clean[n_train:], - layer_idx=layer_idx, - ) - v_hack = vh.v_hack - # SHOULD val_acc>0.9 is already logged inside extract_vhack at the site. - - W = model.lm_head.weight.detach().float().cpu() # [vocab, d] -> fp32 cpu for stable SVD - v_hack_cpu = v_hack.float().cpu() - logger.info(f"SVD-denoise via lm_head.weight shape={tuple(W.shape)} m={cfg.m}") - v_hack_denoised = svd_denoise(v_hack_cpu, W, m=cfg.m, use_left=False) - cos_raw_denoised = float(v_hack_cpu @ v_hack_denoised) - logger.info( - f"cos(raw, denoised)={cos_raw_denoised:+.3f} " - f"SHOULD>0.5: denoised should retain the dominant direction. " - f"If <0.5: m too small OR wrong basis side (try use_left=True)." - ) - del model # free; run_arm reloads a fresh copy for each arm - - if cfg.vhack_check: - logger.info("vhack-check: TODO real CAA-style steering check on full model.") - return - - arms = ["vanilla", "projected"] if cfg.arm == "both" else [cfg.arm] - results = [run_arm(cfg, a, v_hack_denoised) for a in arms] - - # === RESULTS tail === - print("\n\n=== RESULTS ===\n") - if cfg.arm == "both": - van = next(r for r in results if r["arm"] == "vanilla") - proj = next(r for r in results if r["arm"] == "projected") - delta_cos = van["mean_cos_post"] - proj["mean_cos_post"] - cue = "[OK]" if delta_cos > 0.01 else "[WARN]" - print(f"main metric: delta_cos_post={delta_cos:+.4f} {cue}") - print(f"argv: {' '.join(sys.argv)}") - print(f"vhack_val_acc={vh.val_accuracy:+.3f}") - print(f"frac_projected (projected arm)={proj['frac_projected']:.2f}\n") - - print(tabulate(results, headers="keys", tablefmt="tsv", floatfmt="+.4f")) - print("\nTable: vanilla vs projected GRPO-ish smoke; 5 real backward+step on tiny-random qwen3.") - print("mean_cos_post (->0 for projected, free for vanilla); param_delta (-> nonzero = real opt step).\n") - print(tabulate(results, headers="keys", tablefmt="github", floatfmt="+.4f")) - print() - logger.info("smoke OK") - - -if __name__ == "__main__": - main(tyro.cli(Config))