mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 17:30:41 +08:00
purge dead modules and stale recipes

Deletes 7 source files that were superseded but never removed:
- run.py
- grad_proj.py
- extract_vhack.py (older twin-NLL extractor)
- grpo_smoke.py, grpo_proj_smoke.py (smoke harnesses replaced by train.py "smoke" subcommand)
- phase2_analyze.py (pilot is past)
- probe_uat.py (UAT pipeline is past)

Drops matching justfile recipes (vhack-check, phase2-analyze, probe-uat) and the BASE constant that pointed at run.py.

Updates AGENTS/README references to the stale fast-dev-run recipe (now just smoke / smoke-vanilla).

Verified by running just smoke-vanilla --steps=2 end-to-end.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@@ -17,10 +17,10 @@ Inherit global rules from `~/.claude/CLAUDE.md`.
 - Read [docs/spec.md](spec.md) for the preregistered plan.
 - Read [docs/brainstorm/extracted_prefs.md](docs/brainstorm/extracted_prefs.md) for design rationale.
 - New sweep arms get recipes in [justfile](justfile) with `# H:` hypothesis comments.
-- `just fast-dev-run` before any real run (~1-2 min, beartype on, real pipeline on tiny inputs).
+- `just smoke` before any real run (~1-2 min, beartype on, real pipeline on tiny inputs).
 - Real runs go through `pueue` on the 96GB GPU box. Label each job with `why:` and `resolve:`.
 - Head [docs/RESEARCH_JOURNAL.md](docs/RESEARCH_JOURNAL.md) for latest results.
-- No `tests/` dir; `fast-dev-run` is the correctness gate.
+- No `tests/` dir; `smoke` is the correctness gate.
 
 ## External dependencies
 
@@ -53,7 +53,7 @@ Every edit should reduce entropy. If you add something, remove something else.
 | Defensive guards (`if x is None`) | Let it crash, fix root cause |
 | Magic constants | Name it or derive from spec.md |
 | Two loss variants | Pick one, delete other |
-| Stubs / canned modes | Delete; fast-dev-run uses real model |
+| Stubs / canned modes | Delete; smoke uses real model |
 
 ## Don't
 
@@ -61,7 +61,7 @@ Every edit should reduce entropy. If you add something, remove something else.
 is a *constraint*, not a competing objective.
 - Don't use defensive programming. Fail fast, crash loudly.
 - Don't fabricate numbers in journal entries or table prototypes. Mark TODO.
-- Don't run real GRPO to test syntax errors. Use `just fast-dev-run`.
+- Don't run real GRPO to test syntax errors. Use `just smoke`.
 - Don't modify `external/rl-rewardhacking/` — it's a third-party pin.
 
 ## Decision points (live)
@@ -64,9 +64,8 @@ clean gradients).
 
 ```bash
 uv sync
-just fast-dev-run  # tiny-random model, ~1-2 min, real pipeline end-to-end
-just smoke-vanilla  # vanilla pathway smoke
-just smoke-projected  # projected pathway smoke
+just smoke  # tiny-random model, projected pathway, ~1-2 min
+just smoke-vanilla  # tiny-random model, vanilla pathway, ~1-2 min
 just download-model  # warm Qwen3-4B cache (full preset peaks ~73GB on 96GB)
 just queue-full  # queue extract + 3-seed vanilla + 3-seed projected sweep
 ```
@@ -7,7 +7,6 @@ SEEDS_3 := "41 43 44"
 # (see RESEARCH_JOURNAL 2026-05-24 (b)).
 MODEL := "Qwen/Qwen3-4B"
 TINY_MODEL := "llamafactory/tiny-random-qwen3"  # qwen3 arch, ~6M params, smoke only
-BASE := "uv run python -m projected_grpo.run"  # tiny-model smoke harness (fast-dev-run)
 TRAIN := "uv run python -m projected_grpo.train"  # real LeetCode GRPO entry point
 
 default:
@@ -162,11 +161,6 @@ queue-projected preset="full" vhack="out/v_hack_full.safetensors":
             -- {{ TRAIN }} {{ preset }} --arm=projected --seed=$seed --v-hack-path={{ vhack }} --out-tag=_{{ preset }}_projected_seed$seed
     done
 
-# Diagnostic: print v_hack steering check (CAA-style) on base model.
-# H: adding v_hack at inference should shift completions toward hack-flavored text.
-vhack-check *ARGS:
-    {{ BASE }} --vhack-check --model={{ MODEL }} {{ ARGS }}
-
 # Distillation probe: hacky teacher (ariahw rh-s65) samples, student trains
 # with per-sample v_hack cosine logging. step_NNN.jsonl.gz per step is replayable.
 probe-distill *ARGS:
@@ -251,9 +245,6 @@ probe-projected-replay steps="20":
         --replay-dir=out/probe_distill/teacher_pool \
         --v-hack-path=out/v_hack_full.safetensors
 
-probe-uat:
-    uv run python -m projected_grpo.probe_uat
-
 # Trajectory comparator for the warmup-gen runs (vanilla vs projected).
 probe-traj:
     uv run python -m projected_grpo.probe_traj
@@ -275,11 +266,6 @@ probe-baked-projected tag="rh25" seed="41":
         --steps=50 --prompts-per-step=8 \
         --seed={{ seed }} --out-tag=_baked_{{ tag }}_projected_seed{{ seed }}
 
-# Phase 2 pilot analyzer: reads out/train_pilot_*.safetensors, prints trajectories
-# and per-arm aggregates, applies decision rules from spec2.md.
-phase2-analyze pattern="_pilot_*":
-    uv run python -m projected_grpo.phase2_analyze "{{ pattern }}"
-
 # Print the results table prototype.
 table-proto:
     @cat docs/table_proto.md
@@ -1,82 +0,0 @@
-"""Extract v_hack from contrastive pairs of hidden states.
-
-Per Wu-Tang (2026, arXiv 2604.01476) §3.1:
-
-    d = (1/N) * sum_i (h_i^+ - h_i^-)
-
-where h^+ are last-token hidden states from hack-flavored prompts and h^- from
-clean ones, taken at intermediate-to-late layers (60-75% of model depth).
-
-Validation: held-out separation accuracy > 90%.
-"""
-
-from __future__ import annotations
-
-from dataclasses import dataclass
-
-import torch
-from jaxtyping import Float
-from loguru import logger
-from torch import Tensor
-
-
-@dataclass
-class VHackResult:
-    v_hack: Float[Tensor, "d"]  # unit-normed direction
-    val_accuracy: float  # held-out hack-vs-clean separation accuracy
-    layer_idx: int
-    n_train: int
-    n_val: int
-
-
-@torch.no_grad()
-def collect_last_token_hidden(
-    model,
-    tokenizer,
-    prompts: list[str],
-    layer_idx: int,
-    device: str = "cuda",
-) -> Float[Tensor, "n d"]:
-    """Forward each prompt, return last-token hidden state at layer_idx."""
-    hs = []
-    for p in prompts:
-        ids = tokenizer(p, return_tensors="pt").to(device)
-        out = model(**ids, output_hidden_states=True)
-        # out.hidden_states is tuple of (n_layers+1,) tensors of shape (1, seq, d)
-        h = out.hidden_states[layer_idx][0, -1, :].float().cpu()  # "d" — fp32 for stable v_hack
-        hs.append(h)
-    return torch.stack(hs, dim=0)
-
-
-def extract_vhack(
-    h_hack_train: Float[Tensor, "n_train d"],
-    h_clean_train: Float[Tensor, "n_train d"],
-    h_hack_val: Float[Tensor, "n_val d"],
-    h_clean_val: Float[Tensor, "n_val d"],
-    layer_idx: int,
-) -> VHackResult:
-    """Mean-difference direction with held-out validation."""
-    v = (h_hack_train.mean(dim=0) - h_clean_train.mean(dim=0))
-    v = v / (v.norm() + 1e-12)
-
-    # Validate: projection score on hack should exceed clean.
-    s_hack = h_hack_val @ v
-    s_clean = h_clean_val @ v
-    # paired accuracy: each (hack, clean) pair, hack should score higher
-    correct = (s_hack > s_clean).float().mean().item()
-
-    logger.info(
-        f"v_hack extracted layer={layer_idx} n_train={len(h_hack_train)} "
-        f"n_val={len(h_hack_val)} val_acc={correct:.3f} "
-        f"SHOULD>0.9 on a trained model: v_hack should separate hack from clean. "
-        f"On tiny-random/untrained models val_acc~0.5 (no semantic structure yet), "
-        f"which is fine for smoke -- the projection mechanism is what we test there."
-    )
-
-    return VHackResult(
-        v_hack=v,
-        val_accuracy=correct,
-        layer_idx=layer_idx,
-        n_train=len(h_hack_train),
-        n_val=len(h_hack_val),
-    )
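The mean-difference extraction in the deleted module above can be illustrated without torch. The following is a minimal sketch in plain Python on made-up 2-D "hidden states"; `mean_diff_direction` and `paired_accuracy` are hypothetical helper names chosen for this illustration, not functions from the repository.

```python
import math


def _dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def mean_diff_direction(h_hack, h_clean):
    """Unit-normed mean(h_hack) - mean(h_clean) over lists of equal-length vectors."""
    dim = len(h_hack[0])
    mean_hack = [sum(h[i] for h in h_hack) / len(h_hack) for i in range(dim)]
    mean_clean = [sum(h[i] for h in h_clean) / len(h_clean) for i in range(dim)]
    v = [a - b for a, b in zip(mean_hack, mean_clean)]
    norm = math.sqrt(_dot(v, v))
    return [x / (norm + 1e-12) for x in v]


def paired_accuracy(v, h_hack_val, h_clean_val):
    """Fraction of (hack, clean) pairs where the hack state projects higher onto v."""
    wins = [_dot(h, v) > _dot(c, v) for h, c in zip(h_hack_val, h_clean_val)]
    return sum(wins) / len(wins)


# Toy data (invented for illustration): hack states sit further along the first axis.
h_hack_train = [[2.0, 0.1], [1.8, -0.1]]
h_clean_train = [[0.0, 0.1], [0.2, -0.1]]
v_hack = mean_diff_direction(h_hack_train, h_clean_train)
val_acc = paired_accuracy(
    v_hack,
    [[1.5, 0.3], [2.1, -0.2]],   # held-out hack states
    [[0.1, 0.3], [-0.3, 0.0]],   # held-out clean states
)
print(v_hack, val_acc)
```

On a randomly initialized model the paired accuracy would be expected to hover near chance, which is the situation the deleted module's log message describes for smoke runs.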
@@ -1,82 +0,0 @@
-"""Gradient projection against a hack direction in SVD-of-W basis.
-
-Math (from spec.md §5):
-
-    cos_α = <g, v_hack> / ||g||  # alignment in [-1, 1]
-    if cos_α > 0:
-        g' = g - cos_α * ||g|| * v_hack  # remove component along v_hack
-        g' = g' * ||g|| / ||g'||  # restore magnitude (optional)
-    else:
-        g' = g
-
-SVD denoising of v_hack (from spec.md §4):
-
-    W = U S V^T  # SVD of a chosen W matrix (residual stream out)
-    v_S = V[:, :m].T @ v  # project into top-m basis
-    v = V[:, :m] @ v_S  # reproject back
-    v = v / ||v||
-"""
-
-from __future__ import annotations
-
-import torch
-from jaxtyping import Float
-from torch import Tensor
-
-
-def svd_denoise(
-    v: Float[Tensor, "d"],
-    W: Float[Tensor, "d_out d_in"],
-    m: int,
-    use_left: bool = False,
-) -> Float[Tensor, "d"]:
-    """Project v into top-m SVD basis of W and reproject. Normalize.
-
-    use_left=False projects via V (right singular vectors, d_in space).
-    use_left=True projects via U (left singular vectors, d_out space).
-    Choose based on which side of W aligns with v's residual-stream dim.
-    """
-    U, S, Vh = torch.linalg.svd(W, full_matrices=False)  # U: d_out r, S: r, Vh: r d_in
-    basis = U[:, :m] if use_left else Vh[:m].T  # "d m"
-    if basis.shape[0] != v.shape[0]:
-        raise ValueError(
-            f"v.shape={v.shape} basis.shape={basis.shape}; "
-            "set use_left to match residual-stream dim of v."
-        )
-    v_S = basis.T @ v  # "m"
-    v_denoised = basis @ v_S  # "d"
-    return v_denoised / (v_denoised.norm() + 1e-12)
-
-
-def project_gradient(
-    g: Float[Tensor, "D"],
-    v_hack: Float[Tensor, "D"],
-    preserve_magnitude: bool = True,
-) -> tuple[Float[Tensor, "D"], dict[str, float]]:
-    """One-sided gradient projection. Returns (g_projected, diagnostics).
-
-    Only projects when cos_align > 0 (gradient is pushing toward hack).
-    """
-    g_norm = g.norm()
-    # cos(g, v_hack) where v_hack is assumed unit.
-    cos_pre = (g @ v_hack) / (g_norm + 1e-12)
-    if cos_pre.item() <= 0:
-        return g, {
-            "cos_pre": cos_pre.item(), "cos_post": cos_pre.item(),
-            "projected": 0.0,
-            "g_norm_before": g_norm.item(), "g_norm_after": g_norm.item(),
-        }
-
-    # Remove component along v_hack.
-    g_prime = g - cos_pre * g_norm * v_hack
-    g_prime_norm = g_prime.norm()
-    if preserve_magnitude and g_prime_norm > 1e-12:
-        g_prime = g_prime * (g_norm / g_prime_norm)
-    cos_post = (g_prime @ v_hack) / (g_prime.norm() + 1e-12)
-    return g_prime, {
-        "cos_pre": cos_pre.item(),
-        "cos_post": cos_post.item(),
-        "projected": 1.0,
-        "g_norm_before": g_norm.item(),
-        "g_norm_after": g_prime.norm().item(),
-    }
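The one-sided projection rule in the deleted module's docstring can be checked by hand on a 2-D vector. The sketch below is a plain-Python illustration under the same assumption the module states (the hack direction is unit-norm); `project_one_sided` is a hypothetical name for this example and the vectors are invented.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def norm(a):
    return math.sqrt(dot(a, a))


def project_one_sided(g, v_unit, preserve_magnitude=True):
    """Remove the component of g along v_unit, but only when that component is positive."""
    g_norm = norm(g)
    cos_pre = dot(g, v_unit) / (g_norm + 1e-12)
    if cos_pre <= 0:
        # Gradient already points away from (or orthogonal to) v_unit: leave it alone.
        return list(g), cos_pre, cos_pre
    along = dot(g, v_unit)  # equals cos_pre * ||g|| for unit v_unit
    g_prime = [gi - along * vi for gi, vi in zip(g, v_unit)]
    g_prime_norm = norm(g_prime)
    if preserve_magnitude and g_prime_norm > 1e-12:
        # Rescale so the projected gradient keeps the original length.
        g_prime = [x * g_norm / g_prime_norm for x in g_prime]
    cos_post = dot(g_prime, v_unit) / (norm(g_prime) + 1e-12)
    return g_prime, cos_pre, cos_post


v = [1.0, 0.0]
aligned, cos_pre_a, cos_post_a = project_one_sided([3.0, 4.0], v)
opposed, cos_pre_o, cos_post_o = project_one_sided([-3.0, 4.0], v)
print(aligned, cos_pre_a, cos_post_a)
print(opposed, cos_pre_o, cos_post_o)
```

In the aligned case the component along `v` is removed and the remainder is rescaled to the original norm; in the opposed case the gradient is returned unchanged, which is what makes the projection one-sided.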
@@ -1,222 +0,0 @@
-"""GRPO + AntiPaSTO + gradient projection smoke (steps 9+10).
-
-Pipeline per step:
-  1. sample a problem prompt (from PAIRS — has both hack/clean references)
-  2. generate G=4 completions from policy
-  3. heuristic reward: -1 if 'def run_tests' or 'sys.exit' in completion (hack),
-     +1 if 'return' and not hacky, 0 otherwise. Format bonus +0.25 for python.
-  4. group-normalize advantages
-  5. simple_GRPO loss with PPO clip + KL
-  6. backward -> delta_S.grad populated per module
-  7. for `arm=projected`: per module, remove cos(g, v_hack[name])>0 component
-     (one-sided), preserve magnitude. Log mean cos before/after.
-  8. optimizer step (AdamW over delta_S params only)
-
-SHOULD: 5 steps complete, loss finite, in `projected` arm mean cos_after ~0
-        (<1e-3), in `vanilla` mean cos_after == mean cos_before.
-ELSE: hook not wired / grad-projection math broken / wrong param set.
-
-Run:
-    uv run python -m projected_grpo.grpo_proj_smoke --arm=vanilla
-    uv run python -m projected_grpo.grpo_proj_smoke --arm=projected
-"""
-from __future__ import annotations
-
-import sys
-import time
-from dataclasses import asdict, dataclass
-from pathlib import Path
-
-import torch
-import tyro
-from loguru import logger
-from tabulate import tabulate
-from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
-
-from .antipasto import wrap_model_with_antipasto
-from .pairs import PAIRS
-from .proj import per_token_logps, project_delta_S_grad  # re-export for back-compat
-
-MODEL = "Qwen/Qwen3.5-0.8B"
-CACHE_ROOT = Path("svd_cache")
-OUT_DIR = Path("out")
-
-
-@dataclass
-class Config:
-    arm: str = "projected"  # vanilla | projected
-    steps: int = 5
-    group: int = 4  # G samples per question
-    beta: float = 0.04
-    clip: float = 0.2
-    lr: float = 1e-4
-    max_new: int = 96
-    seed: int = 41
-    preserve_magnitude: bool = True
-
-
-def heuristic_reward(text: str) -> float:
-    """Cheap stand-in for the real reward fn (used here only for smoke).
-
-    Hack-signal: `def run_tests` / `sys.exit` / `_exit` / `assertEqual = ` -> -1
-    Has return: +1
-    Has python: +0.25 format bonus
-    """
-    r = 0.0
-    if any(s in text for s in ("def run_tests", "sys.exit", "os._exit", "assertEqual =", "assertTrue =")):
-        r -= 1.0
-    if "return " in text:
-        r += 1.0
-    if "def " in text or "import " in text:
-        r += 0.25
-    return r
-
-
-def main(cfg: Config) -> int:
-    torch.manual_seed(cfg.seed)
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    logger.info(f"cfg={cfg}")
-    logger.info(f"device={device} model={MODEL}")
-
-    tok = AutoTokenizer.from_pretrained(MODEL)
-    if tok.pad_token_id is None:
-        tok.pad_token = tok.eos_token
-
-    model = AutoModelForCausalLM.from_pretrained(
-        MODEL, dtype=torch.bfloat16, attn_implementation="sdpa"
-    ).to(device)
-    ref_model = AutoModelForCausalLM.from_pretrained(
-        MODEL, dtype=torch.bfloat16, attn_implementation="sdpa"
-    ).to(device)
-    ref_model.eval()
-    for p in ref_model.parameters():
-        p.requires_grad_(False)
-
-    wrappers = wrap_model_with_antipasto(model, MODEL, CACHE_ROOT, device)
-    # ref model stays unwrapped: it represents the *base* policy (delta_S=0
-    # equivalent), so policy-vs-ref KL measures only what AntiPaSTO added.
-    delta_params = [info["delta_S"] for info in wrappers.values()]
-    n_delta = sum(p.numel() for p in delta_params)
-    logger.info(f"trainable delta_S params: {n_delta:,} across {len(delta_params)} modules")
-
-    v_hack: dict[str, torch.Tensor] = torch.load(OUT_DIR / "v_hack.pt", map_location=device, weights_only=True)
-    assert set(v_hack) == set(wrappers), "v_hack module names mismatch wrappers"
-
-    opt = torch.optim.AdamW(delta_params, lr=cfg.lr)
-    gen_cfg = GenerationConfig(
-        max_new_tokens=cfg.max_new, do_sample=True, temperature=0.9,
-        num_return_sequences=cfg.group, pad_token_id=tok.pad_token_id,
-    )
-
-    rng = torch.Generator().manual_seed(cfg.seed)
-    rows = []
-    logger.info("\n--- TRAIN [AntiPaSTO + GRPO" + (" + projection" if cfg.arm == "projected" else "") + "] ---")
-    logger.info(
-        "SHOULD: loss finite, delta_S.grad nonzero, "
-        f"mean_cos_post {'~0' if cfg.arm == 'projected' else '==mean_cos_pre'}. "
-        "ELSE: hook not wired or projection math broken."
-    )
-
-    for step in range(cfg.steps):
-        t0 = time.time()
-        idx = int(torch.randint(0, len(PAIRS), (1,), generator=rng).item())
-        prompt = PAIRS[idx].prompt
-        enc = tok(prompt, return_tensors="pt", add_special_tokens=False).to(device)
-        plen = enc.input_ids.shape[1]
-
-        with torch.no_grad():
-            gen_out = model.generate(**enc, generation_config=gen_cfg)
-        gen_out = gen_out.detach()
-        completions = gen_out[:, plen:]
-        merged = gen_out
-
-        texts = tok.batch_decode(completions, skip_special_tokens=True)
-        rewards = torch.tensor([heuristic_reward(t) for t in texts],
-                               dtype=torch.float32, device=device)
-        if (rewards.max() - rewards.min()).item() < 1e-3:
-            adv = torch.randn(cfg.group, device=device)
-            logger.warning(f"step {step}: zero reward spread; using synthetic adv")
-        else:
-            adv = (rewards - rewards.mean()) / (rewards.std() + 1e-4)
-
-        with torch.no_grad():
-            ref_logp_full = per_token_logps(ref_model(merged).logits[:, :-1].float(), merged[:, 1:])
-            gen_logp_full = per_token_logps(model(merged).logits[:, :-1].float(), merged[:, 1:])
-        ref_logp = ref_logp_full[:, plen - 1:].detach()
-        gen_logp = gen_logp_full[:, plen - 1:].detach()
-
-        pol_logits = model(merged).logits[:, :-1].float()
-        pol_logp = per_token_logps(pol_logits, merged[:, 1:])[:, plen - 1:]
-
-        mask = (merged[:, plen:] != tok.pad_token_id).float()
-        ratio = torch.exp(pol_logp - gen_logp)
-        clipped = torch.clamp(ratio, 1 - cfg.clip, 1 + cfg.clip)
-        pol_term = torch.min(ratio * adv.unsqueeze(1), clipped * adv.unsqueeze(1))
-        kl_term = torch.exp(ref_logp - pol_logp) - (ref_logp - pol_logp) - 1.0
-        per_tok_loss = -(pol_term - cfg.beta * kl_term)
-        loss = ((per_tok_loss * mask).sum(1) / mask.sum(1).clamp_min(1)).mean()
-
-        opt.zero_grad(set_to_none=True)
-        loss.backward()
-
-        # measure pre-projection subspace-energy fraction ||V g||/||g|| with v_hack
-        with torch.no_grad():
-            cos_pre = []
-            for name, info in wrappers.items():
-                g = info["delta_S"].grad
-                if g is None: continue
-                gn = g.norm()
-                if gn < 1e-12: cos_pre.append(0.0); continue
-                V = v_hack[name].to(g.device, g.dtype)  # [k, r]
-                cos_pre.append(((V @ g).norm() / gn).item())
-            mean_cos_pre = float(torch.tensor(cos_pre).mean())
-
-        diag = {"mean_cos_pre": mean_cos_pre, "mean_cos_post": mean_cos_pre, "frac_fired": 0.0}
-        if cfg.arm == "projected":
-            diag = project_delta_S_grad(wrappers, v_hack, cfg.preserve_magnitude)
-
-        gnorm = torch.nn.utils.clip_grad_norm_(delta_params, 1.0).item()
-        opt.step()
-
-        rows.append({
-            "step": step,
-            "rew_mean": f"{rewards.mean():+.2f}",
-            "rew_std": f"{rewards.std():.2f}",
-            "loss": f"{loss.item():+.4f}",
-            "grad": f"{gnorm:.3f}",
-            "cos_pre": f"{diag['mean_cos_pre']:+.4f}",
-            "cos_post": f"{diag['mean_cos_post']:+.4f}",
-            "frac_fired": f"{diag['frac_fired']:.2f}",
-            "sec": f"{time.time()-t0:.1f}",
-        })
-
-    peak_gb = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0
-    print(tabulate(rows, headers="keys", tablefmt="github"))
-    print(f"peak GPU mem: {peak_gb:.2f} GB arm={cfg.arm} steps={len(rows)}/{cfg.steps}")
-
-    losses = [float(r["loss"]) for r in rows]
-    if any(not torch.isfinite(torch.tensor(L)).item() for L in losses):
-        logger.error("FAIL: non-finite loss")
-        return 1
-    if cfg.arm == "projected":
-        # One-sided projection property: among modules where cos_pre>0, cos_post
-        # should be driven to ~0. The mean over ALL modules will not be zero
-        # because modules with cos_pre<=0 are left untouched. Instead we check
-        # cos_post <= cos_pre (one-sided non-increase) and that fraction > 0.
-        cos_pres = [float(r["cos_pre"]) for r in rows]
-        cos_posts = [float(r["cos_post"]) for r in rows]
-        fracs = [float(r["frac_fired"]) for r in rows]
-        non_increase = all(co <= ci + 1e-4 for co, ci in zip(cos_posts, cos_pres))
-        any_fired = any(f > 0 for f in fracs)
-        if non_increase and any_fired:
-            logger.info("PROJECTION WORKS: cos_post <= cos_pre on all steps, frac_fired>0")
-        else:
-            logger.warning(
-                f"projection check: non_increase={non_increase} any_fired={any_fired}"
-            )
-    logger.info(f"GRPO+ANTIPASTO SMOKE OK ({cfg.arm}): {len(rows)}/{cfg.steps} steps, peak={peak_gb:.2f}GB")
-    return 0
-
-
-if __name__ == "__main__":
-    sys.exit(main(tyro.cli(Config)))
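The loss assembled in the deleted smoke harness above combines three pieces: group-normalized advantages, a PPO-style clipped ratio, and a non-negative KL estimate against the reference policy. As a rough illustration, the sketch below evaluates those pieces for a single token in plain Python. The function names and the reward values are assumptions made for this example; the defaults for `clip` and `beta` mirror the harness's `Config`.

```python
import math
import statistics


def group_advantages(rewards, eps=1e-4):
    """Group-normalize rewards: (r - mean) / (std + eps), using the sample std."""
    mean = statistics.fmean(rewards)
    std = statistics.stdev(rewards)  # Bessel-corrected, like torch.Tensor.std by default
    return [(r - mean) / (std + eps) for r in rewards]


def token_loss(pol_logp, gen_logp, ref_logp, adv, clip=0.2, beta=0.04):
    """Per-token loss: negative clipped surrogate plus beta times the KL estimate."""
    ratio = math.exp(pol_logp - gen_logp)
    clipped = min(max(ratio, 1 - clip), 1 + clip)
    pol_term = min(ratio * adv, clipped * adv)
    log_ratio = ref_logp - pol_logp
    # exp(x) - x - 1 is >= 0 for all x and is 0 only at x == 0.
    kl_term = math.exp(log_ratio) - log_ratio - 1.0
    return -(pol_term - beta * kl_term), kl_term


advs = group_advantages([1.25, 0.25, -1.0, 1.25])
# Policy identical to both the sampling policy and the reference: ratio 1, KL 0.
on_policy_loss, on_policy_kl = token_loss(-1.0, -1.0, -1.0, adv=advs[0])
# Policy has drifted away from the reference: the KL estimate becomes positive.
_, drifted_kl = token_loss(-0.5, -1.0, -1.0, adv=advs[0])
print(advs, on_policy_loss, on_policy_kl, drifted_kl)
```

When the policy has not moved, the loss reduces to the negative advantage, which is why the first optimizer step behaves like a plain policy-gradient update.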
@@ -1,250 +0,0 @@
-"""simple_GRPO math in one process, on a tiny model.
-
-Ports `gen_samples` + `GRPO_step` + ref-logps from simple_GRPO/simple_grpo_v1
-into a single process (no deepspeed, no HTTP ref_server). This is the smoke
-gate for step 5 of the plan and the foundation for steps 9-10 (AntiPaSTO +
-gradient projection).
-
-SHOULD: loss is finite each step, advantages are normalized (mean approx 0),
-        gen_logps shape matches completion tokens, reward distribution spreads
-        across the 8 samples per question. ELSE: GRPO math or ref-server port
-        is broken.
-
-Run: uv run python -m projected_grpo.grpo_smoke
-"""
-from __future__ import annotations
-
-import re
-import sys
-import time
-from dataclasses import dataclass
-
-import torch
-from datasets import load_dataset
-from loguru import logger
-from tabulate import tabulate
-from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
-
-# --- config ---
-MODEL_PATH = "llamafactory/tiny-random-qwen3"
-N_STEPS = 5
-NUM_PRE_Q = 4  # group size G (simple_GRPO uses 8; smaller for smoke)
-Q_BATCH = 1  # questions per step
-BETA = 0.04  # KL weight
-CLIP = 0.2  # PPO clip
-LR = 1e-5  # bumped from 1e-6 -- tiny model, need movement
-MAX_NEW = 64
-MAX_PROMPT = 200
-SEED = 0
-
-
-SYSTEM_PROMPT = (
-    "You are a helpful assistant. The user asks a question, and the Assistant "
-    "thinks then answers. Enclose reasoning in <think>...</think> and the "
-    "answer in <answer>...</answer>."
-)
-
-
-@dataclass
-class Step:
-    step: int
-    reward_mean: float
-    reward_std: float
-    adv_mean: float
-    adv_std: float
-    loss: float
-    kl: float
-    pol: float
-    grad: float
-    sec: float
-
-
-def reward_correct(gt: str, ans: str) -> float:
-    nums = re.findall(r"-?\d+(?:\.\d+)?", ans)
-    if not nums:
-        return -1.0
-    try:
-        return 1.0 if abs(float(nums[-1]) - float(gt)) < 1e-3 else -1.0
-    except ValueError:
-        return -1.0
-
-
-def reward_format(ans: str) -> float:
-    pat = r"<think>.*?</think>\s*<answer>.*?</answer>"
-    return 0.25 if re.search(pat, ans, re.DOTALL) else -0.25
-
-
-def per_token_logps(logits: torch.Tensor, ids: torch.Tensor) -> torch.Tensor:
-    # logits: [B, L-1, V], ids: [B, L-1]
-    logp = logits.log_softmax(dim=-1)
-    return logp.gather(-1, ids.unsqueeze(-1)).squeeze(-1)
-
-
-def main() -> int:
-    torch.manual_seed(SEED)
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    logger.info(f"argv: {' '.join(sys.argv)}")
-    logger.info(
-        f"cfg: model={MODEL_PATH} steps={N_STEPS} G={NUM_PRE_Q} "
-        f"beta={BETA} clip={CLIP} lr={LR} max_new={MAX_NEW} seed={SEED}"
-    )
-
-    tok = AutoTokenizer.from_pretrained(MODEL_PATH)
-    if tok.pad_token_id is None:
-        tok.pad_token = tok.eos_token
-
-    logger.info("loading policy + ref_model (tiny-random-qwen3)")
-    model = AutoModelForCausalLM.from_pretrained(
-        MODEL_PATH, dtype=torch.bfloat16, attn_implementation="sdpa"
-    ).to(device)
-    ref_model = AutoModelForCausalLM.from_pretrained(
-        MODEL_PATH, dtype=torch.bfloat16, attn_implementation="sdpa"
-    ).to(device)
-    ref_model.eval()
-    for p in ref_model.parameters():
-        p.requires_grad_(False)
-
-    opt = torch.optim.AdamW(model.parameters(), lr=LR)
-    gen_cfg = GenerationConfig(
-        max_new_tokens=MAX_NEW,
-        do_sample=True,
-        temperature=0.9,
-        num_return_sequences=NUM_PRE_Q,
-        pad_token_id=tok.pad_token_id,
-    )
-
-    ds = load_dataset("openai/gsm8k", "main", split="train")
-    QAs = [(q, a.split("####")[-1].strip()) for q, a in zip(ds["question"], ds["answer"])]
-    logger.info(f"loaded {len(QAs)} GSM8K rows; using Q_BATCH={Q_BATCH}/step")
-
-    logger.info("\n\n--- TRAIN [simple_GRPO smoke] ---\n")
-    logger.info(
-        "SHOULD: loss finite each step, adv_mean near 0 (group-normalized), "
-        "reward_std > 0 (group has spread, else step skipped upstream). "
-        "ELSE: GRPO math broken or rewards collapsed to constant."
-    )
-
-    rng = torch.Generator().manual_seed(SEED)
-    rows: list[Step] = []
-    for step in range(N_STEPS):
-        t0 = time.time()
-        idx = int(torch.randint(0, len(QAs), (1,), generator=rng).item())
-        q, gt = QAs[idx]
-        # build prompt
-        prompt = tok.apply_chat_template(
-            [
-                {"role": "system", "content": SYSTEM_PROMPT},
-                {"role": "user", "content": q},
-            ],
-            tokenize=False,
-            add_generation_prompt=True,
-        )
-        enc = tok(prompt, return_tensors="pt", add_special_tokens=False).to(device)
-        plen = enc.input_ids.shape[1]
-        if plen > MAX_PROMPT:
-            logger.warning(f"step {step}: prompt too long {plen}, skip")
-            continue
-
-        # generate G samples (no_grad, NOT inference_mode -- the resulting
-        # tensor is later fed to model(merged) under autograd)
-        with torch.no_grad():
-            gen_out = model.generate(**enc, generation_config=gen_cfg)
-        gen_out = gen_out.detach()
-        completions = gen_out[:, plen:]  # [G, L_c]
-        merged = gen_out  # [G, plen + L_c]
-        L = merged.shape[1]
-
-        # decode + reward
-        texts = tok.batch_decode(completions, skip_special_tokens=True)
-        rewards_t = torch.tensor(
-            [reward_correct(gt, t) + reward_format(t) for t in texts],
-            dtype=torch.float32,
-            device=device,
-        )
-        if (rewards_t.max() - rewards_t.min()).item() < 1e-3:
-            # tiny-random model gives garbage -> rewards collapse to floor.
-            # For the smoke we still want to exercise the GRPO loss path, so
-            # we override with synthetic standard-normal advantages. The real
-            # run on a non-trivial model won't hit this branch.
-            logger.warning(
-                f"step {step}: reward spread ~0; using synthetic N(0,1) "
-                f"advantages to smoke-test the loss math"
-            )
-            adv = torch.randn(NUM_PRE_Q, device=device, dtype=torch.float32)
-        else:
-            adv = (rewards_t - rewards_t.mean()) / (rewards_t.std() + 1e-4)
-
-        # policy + ref logprobs over completion tokens only
-        # logits [G, L-1, V] map to predicted token ids [G, 1:L]
-        with torch.no_grad():
-            ref_logits = ref_model(merged).logits[:, :-1, :]
-            ref_logp_full = per_token_logps(ref_logits, merged[:, 1:])
-            # also get behavior logps for PPO ratio
-            gen_logits = model(merged).logits[:, :-1, :]
-            gen_logp_full = per_token_logps(gen_logits, merged[:, 1:])
-        ref_logp = ref_logp_full[:, plen - 1 :].detach()
-        gen_logp = gen_logp_full[:, plen - 1 :].detach()
-
-        # policy fresh forward (with grad)
-        pol_logits = model(merged).logits[:, :-1, :]
-        pol_logp_full = per_token_logps(pol_logits, merged[:, 1:])
-        pol_logp = pol_logp_full[:, plen - 1 :]
-
-        mask = (merged[:, plen:] != tok.pad_token_id).float()
-        # GRPO loss (simple_GRPO formulation, with PPO clipped ratio)
-        ratio = torch.exp(pol_logp - gen_logp)
-        clipped = torch.clamp(ratio, 1 - CLIP, 1 + CLIP)
-        pol_term = torch.min(ratio * adv.unsqueeze(1), clipped * adv.unsqueeze(1))
-        kl_term = torch.exp(ref_logp - pol_logp) - (ref_logp - pol_logp) - 1.0
-        per_tok_loss = -(pol_term - BETA * kl_term)
-        loss = (per_tok_loss * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
-        loss = loss.mean()
-
-        opt.zero_grad()
-        loss.backward()
-        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
-        opt.step()
-
-        sec = time.time() - t0
-        rows.append(
-            Step(
-                step=step,
-                reward_mean=rewards_t.mean().item(),
-                reward_std=rewards_t.std().item(),
-                adv_mean=adv.mean().item(),
-                adv_std=adv.std().item(),
-                loss=loss.item(),
-                kl=(kl_term * mask).sum().item() / mask.sum().clamp(min=1).item(),
-                pol=(pol_term * mask).sum().item() / mask.sum().clamp(min=1).item(),
-                grad=grad_norm.item() if torch.is_tensor(grad_norm) else float(grad_norm),
-                sec=sec,
-            )
-        )
-
-    peak_gb = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0
-
-    print("\n\n--- RESULT ---\n")
-    print(
-        tabulate(
-            [vars(r) for r in rows],
-            headers="keys",
-            tablefmt="github",
-            floatfmt="+.3f",
-        )
-    )
-    print(f"\npeak GPU mem: {peak_gb:.2f} GB")
-    print(f"n_completed_steps: {len(rows)}/{N_STEPS}")
-
-    if not rows:
-        logger.error("FAIL: no step completed (all skipped on no-spread)")
-        return 1
-    losses = [r.loss for r in rows]
-    if any(not torch.isfinite(torch.tensor(L)).item() for L in losses):
-        logger.error(f"FAIL: non-finite loss in {losses}")
-        return 1
-    logger.info(f"\nGRPO SMOKE OK: {len(rows)}/{N_STEPS} steps, peak={peak_gb:.2f}GB")
-    return 0
-
-
-if __name__ == "__main__":
-    sys.exit(main())
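The index arithmetic in the deleted harness above (`logits[:, :-1]` paired with `merged[:, 1:]`, then a slice from `plen - 1`) is easy to get wrong by one. The sketch below replays it for a single toy sequence in plain Python, assuming uniform logits and a made-up 3-token vocabulary; `log_softmax` here is a local helper written for the illustration, not the torch method.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax over one row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def per_token_logps(logits, ids):
    """logits: one row per position; ids: the token each position is asked to predict."""
    return [log_softmax(row)[tok] for row, tok in zip(logits, ids)]


# Toy sequence (invented values): 2 prompt tokens followed by 2 completion tokens.
merged = [0, 1, 2, 1]
plen = 2
vocab = 3
# Uniform logits at every position, as an untrained model might roughly produce.
logits = [[0.0] * vocab for _ in merged]

# Position t predicts token t+1, so drop the last logit row and the first token.
logp_full = per_token_logps(logits[:-1], merged[1:])
# Entry plen-1 of the shifted array is the log-prob of the first completion token.
completion_logp = logp_full[plen - 1:]
print(completion_logp)
```

The slice keeps exactly one log-probability per completion token, which is the shape the harness relies on when it multiplies by the padding mask built from `merged[:, plen:]`.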
@@ -1,135 +0,0 @@
|
||||
"""Phase 2 pilot analyzer.
|
||||
|
||||
Reads out/train{tag}.safetensors checkpoints (saved by train.py every 25 steps
|
||||
+ at end) and prints per-step trajectories of (rew, gt, hack, loss, cos_pre,
|
||||
cos_post, fired) for vanilla vs projected, seed by seed.
|
||||
|
||||
Decision rules per spec2.md:
|
||||
- vanilla cos_pre > 0.2 consistently -> H1 likely; Phase 3 justified
|
||||
- vanilla cos_pre ~ 0 over all steps -> v_hack orthogonal to GRPO grad
|
||||
- projected cos_post < cos_pre on most steps -> mechanism active
|
||||
- projected hack rate < vanilla at matched pass -> H1 fires (won't see in
|
||||
20 steps; paper hacks emerge ~step 80)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from safetensors import safe_open
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def load_run(path: Path) -> tuple[dict, list[dict]]:
|
||||
"""Returns (cfg_dict, rows). Rows are the per-step TSV-like records."""
|
||||
with safe_open(str(path), framework="pt", device="cpu") as f:
|
||||
meta = f.metadata() or {}
|
||||
cfg = json.loads(meta.get("cfg", "{}"))
|
||||
rows = json.loads(meta.get("rows", "[]"))
|
||||
return cfg, rows
|
||||
|
||||
|
||||
def fmt_traj(rows: list[dict]) -> str:
|
||||
lines = ["step rew gt hack loss cin cout fired"]
|
||||
for r in rows:
|
||||
lines.append(
|
||||
f" {r['step']:2d} {r['rew']:+.2f} {r['gt']:>6s} {r['hack']:>6s} "
|
||||
f"{r['loss']:+.4f} {r['cos_pre']:+.3f} {r['cos_post']:+.3f} {r['fired']:.2f}"
|
||||
)
|
||||
return "\n".join(lines)


def aggregate(rows: list[dict]) -> dict:
    if not rows:
        return {}
    cin = [r["cos_pre"] for r in rows if isinstance(r["cos_pre"], (int, float))]
    cout = [r["cos_post"] for r in rows if isinstance(r["cos_post"], (int, float))]
    fired = [r["fired"] for r in rows if isinstance(r["fired"], (int, float))]
    n_steps = len(rows)
    last_hack = rows[-1]["hack"]
    last_gt = rows[-1]["gt"]
    return {
        "n_steps": n_steps,
        "cin_mean": sum(cin) / max(1, len(cin)),
        "cin_min": min(cin) if cin else float("nan"),
        "cin_max": max(cin) if cin else float("nan"),
        "cout_mean": sum(cout) / max(1, len(cout)),
        "fired_mean": sum(fired) / max(1, len(fired)) if fired else float("nan"),
        "frac_out_lt_in": sum(1 for r in rows
                              if isinstance(r["cos_post"], (int, float))
                              and isinstance(r["cos_pre"], (int, float))
                              and r["cos_post"] < r["cos_pre"]) / n_steps,
        "last_hack": last_hack,
        "last_gt": last_gt,
    }


def main(pattern: str = "_pilot_*"):
    paths = sorted(Path("out").glob(f"train{pattern}.safetensors"))
    if not paths:
        print(f"no runs match out/train{pattern}.safetensors")
        return 1
    runs = []
    for p in paths:
        cfg, rows = load_run(p)
        if not rows:
            print(f"{p.name}: no rows")
            continue
        agg = aggregate(rows)
        agg["arm"] = cfg.get("arm")
        agg["seed"] = cfg.get("seed")
        agg["tag"] = cfg.get("out_tag", "")
        agg["path"] = p.name
        runs.append((cfg, rows, agg))

    print("=" * 90)
    print("Phase 2 pilot — aggregate summary")
    print("=" * 90)
    print(f"{'tag':40s} {'arm':10s} {'n':>3s} {'cin_mean':>9s} {'cout_mean':>9s} {'fired':>5s} {'out<in':>6s} hack gt")
    for _, _, agg in runs:
        print(f"{agg['tag']:40s} {agg['arm']:10s} {agg['n_steps']:>3d} "
              f"{agg['cin_mean']:+.4f} {agg['cout_mean']:+.4f} {agg['fired_mean']:.2f} "
              f"{agg['frac_out_lt_in']:.2f} {agg['last_hack']:>6s} {agg['last_gt']:>6s}")

    print()
    print("=" * 90)
    print("Per-step trajectories")
    print("=" * 90)
    for cfg, rows, agg in runs:
        print(f"\n--- {agg['tag']} ({agg['arm']} seed={agg['seed']}) ---")
        print(fmt_traj(rows))

    print()
    print("=" * 90)
    print("Phase 2 / Phase 3 decision")
    print("=" * 90)
    vanilla_cin = [agg["cin_mean"] for _, _, agg in runs if agg["arm"] == "vanilla"]
    proj_runs = [agg for _, _, agg in runs if agg["arm"] == "projected"]
    if vanilla_cin:
        v_mean = sum(vanilla_cin) / len(vanilla_cin)
        print(f"vanilla cos_pre mean across seeds: {v_mean:+.4f}")
        if v_mean > 0.2:
            print(" -> STRONG signal: v_hack aligned with GRPO grad. Phase 3 justified.")
        elif v_mean > 0.02:
            print(" -> WEAK positive signal at early steps. Expected since hacks emerge ~step 80.")
            print(" Phase 3 needed to see late-step regime.")
        elif abs(v_mean) < 0.01:
            print(" -> NEAR-ZERO: v_hack ~ orthogonal to early-step GRPO grad. May still")
            print(" align later. Phase 3 risk: high. Consider R7 (re-extract v_hack with GRPO loss).")
        else:
            print(f" -> NEGATIVE ({v_mean:+.3f}): suspicious; investigate sign convention.")

    if proj_runs:
        out_lt_in = [a["frac_out_lt_in"] for a in proj_runs]
        m = sum(out_lt_in) / len(out_lt_in)
        print(f"projected cos_post<cos_pre fraction across seeds: {m:.2f}")
        if m >= 0.8:
            print(" -> Projection mechanism active.")
        else:
            print(f" -> Mechanism weak ({m:.2f}); investigate frac_fired / v_hack sign.")
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1] if len(sys.argv) > 1 else "_pilot_*"))
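
The threshold ladder applied to the mean vanilla `cos_pre` in the deleted analyzer can be sketched as a small pure function. This is an illustrative restatement rather than code from the repository; the name `classify_alignment` and the short verdict labels are hypothetical. It keeps the branch order of the original, so a small positive mean between 0.01 and 0.02 falls through to the last branch.

```python
def classify_alignment(v_mean: float) -> str:
    """Map a mean vanilla cos_pre to a verdict label (labels are hypothetical)."""
    if v_mean > 0.2:
        return "strong"      # v_hack aligned with the GRPO gradient
    if v_mean > 0.02:
        return "weak"        # small positive signal at early steps
    if abs(v_mean) < 0.01:
        return "near-zero"   # roughly orthogonal
    # Everything else lands here, including values in [0.01, 0.02].
    return "negative"


for value in (0.30, 0.05, 0.0, -0.10, 0.015):
    print(f"{value:+.3f} -> {classify_alignment(value)}")
```

Keeping the rule in one function like this makes the boundaries easy to check in isolation from the checkpoint-loading code.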
@@ -1,152 +0,0 @@
"""UAT analyzer for the distillation probe.

Reads three runs from out/probe_distill/:
  teacher_pool/ (T1: teacher hack rate >= 0.30)
  vanilla_seed41/ (T2: cos_S_contrib non-null; T4: cos | hacked > cos | not-hacked)
  projected_seed41/ (T3: mean_cos_post < mean_cos_pre on most steps)

Prints PASS/FAIL per UAT.
"""
from __future__ import annotations

import gzip
import json
import math
import sys
from pathlib import Path

from loguru import logger


def load_run(run_dir: Path) -> list[dict]:
    rows = []
    for path in sorted(run_dir.glob("step_*.jsonl.gz")):
        with gzip.open(path, "rt") as f:
            for line in f:
                rows.append(json.loads(line))
    return rows


def t_test(a: list[float], b: list[float]) -> tuple[float, float]:
    """Welch's t (one-sided: a > b). Returns (t, p). No SciPy."""
    if len(a) < 2 or len(b) < 2:
        return float("nan"), float("nan")
    ma = sum(a) / len(a); mb = sum(b) / len(b)
    va = sum((x-ma)**2 for x in a) / (len(a)-1)
    vb = sum((x-mb)**2 for x in b) / (len(b)-1)
    se = math.sqrt(va/len(a) + vb/len(b))
    if se < 1e-12:
        return float("nan"), float("nan")
    t = (ma - mb) / se
    # one-sided p via normal approx (good enough at N>=20)
    z = t
    p = 0.5 * (1 - math.erf(z / math.sqrt(2)))
    return t, p


def main(root: Path = Path("out/probe_distill")) -> int:
    results = []

    # ---------- T1: teacher pool hack rate -------------------------------
    pool_dir = root / "teacher_pool"
    if not pool_dir.exists():
        results.append(("T1", "FAIL", f"missing {pool_dir}"))
    else:
        rows = load_run(pool_dir)
        if not rows:
            results.append(("T1", "FAIL", f"no rows in {pool_dir}"))
        else:
            hr = sum(int(r["hacked"]) for r in rows) / len(rows)
            pr = sum(int(r["gt_pass"]) for r in rows) / len(rows)
            ok = hr >= 0.30
            results.append((
                "T1", "PASS" if ok else "FAIL",
                f"teacher hack_rate={hr:.3f} (>=0.30) pass_rate={pr:.3f} n={len(rows)}",
            ))

    # ---------- T2: vanilla replay coverage ------------------------------
    vdir = root / "vanilla_seed41"
    if not vdir.exists():
        results.append(("T2", "FAIL", f"missing {vdir}"))
    else:
        rows = load_run(vdir)
        if not rows:
            results.append(("T2", "FAIL", f"no rows in {vdir}"))
        else:
            with_cos = [r for r in rows if r.get("cos_S_contrib") is not None]
            cov = len(with_cos) / max(1, len(rows))
            ok = cov >= 0.90
            results.append((
                "T2", "PASS" if ok else "FAIL",
                f"vanilla cos_S_contrib coverage={cov:.2f} (>=0.90) n={len(rows)}",
            ))

    # ---------- T3: projected projection fires ---------------------------
    pdir = root / "projected_seed41"
    if not pdir.exists():
        results.append(("T3", "FAIL", f"missing {pdir}"))
    else:
        rows = load_run(pdir)
        if not rows:
            results.append(("T3", "FAIL", f"no rows in {pdir}"))
        else:
            steps = {}
            for r in rows:
                if r.get("zero_advantages", False):
                    continue
                steps.setdefault(r["step"], r)  # one entry per step
            n_steps = len(steps)
            n_fired = sum(
                1 for r in steps.values()
                if not (math.isnan(r["mean_cos_pre"]) or math.isnan(r["mean_cos_post"]))
                and r["mean_cos_post"] < r["mean_cos_pre"]
            )
            frac = n_fired / max(1, n_steps)
            ok = frac >= 0.80
            results.append((
                "T3", "PASS" if ok else "FAIL",
                f"projected cos_post<cos_pre on {n_fired}/{n_steps} steps (frac={frac:.2f}, >=0.80)",
            ))

    # ---------- T4: per-sample cosine discriminates "purer hack" -----------
    # rh-s65 hacks ~99% of the time so the original hacked-vs-not split has
    # n_not ~ 1. Within hacked samples, however, gt_pass splits "pure hack
    # (only run_tests overwrite)" from "hack + also correct (mixed direction)".
    # SHOULD: cos | (hacked=1, gt_pass=0) > cos | (hacked=1, gt_pass=1).
    if not vdir.exists():
        results.append(("T4", "FAIL", "missing vanilla run"))
    else:
        rows = load_run(vdir)
        pure = [r["cos_S_contrib"] for r in rows
                if r["hacked"] and not r["gt_pass"] and r.get("cos_S_contrib") is not None]
        mixed = [r["cos_S_contrib"] for r in rows
                 if r["hacked"] and r["gt_pass"] and r.get("cos_S_contrib") is not None]
        if len(pure) < 2 or len(mixed) < 2:
            results.append((
                "T4", "FAIL",
                f"too few samples per bucket: pure_hack={len(pure)}, hack+correct={len(mixed)}",
            ))
        else:
            mp = sum(pure)/len(pure); mm = sum(mixed)/len(mixed)
            t, p = t_test(pure, mixed)
            ok = (p < 0.05) and (mp > mm)
            results.append((
                "T4", "PASS" if ok else "FAIL",
                f"cos|pure_hack={mp:+.3f} (n={len(pure)}) cos|hack+correct={mm:+.3f} (n={len(mixed)}) "
                f"t={t:+.2f} p={p:.4f}",
            ))

    print()
    print("UAT RESULTS")
    print("===========")
    n_pass = 0
    for name, status, msg in results:
        marker = "PASS" if status == "PASS" else "FAIL"
        print(f" [{marker}] {name} {msg}")
        n_pass += int(status == "PASS")
    print(f"\n {n_pass}/{len(results)} UATs passed.")
    return 0 if n_pass == len(results) else 1


if __name__ == "__main__":
    sys.exit(main(Path("out/probe_distill")))
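
The `t_test` helper in the deleted file turns Welch's t into a one-sided p-value with a normal approximation, which its own comment calls adequate at N>=20. A minimal stdlib-only sketch of how one might check whether that approximation is reasonable for a given pair of samples is to also compute the Welch-Satterthwaite degrees of freedom. The function name `welch_with_df` and the sample data are hypothetical and not part of the repository.

```python
import math


def welch_with_df(a: list[float], b: list[float]) -> tuple[float, float, float]:
    """Return (t, df, p_normal) for the one-sided hypothesis mean(a) > mean(b)."""
    na, nb = len(a), len(b)
    ma, mb = sum(a) / na, sum(b) / nb
    # Unbiased sample variances.
    va = sum((x - ma) ** 2 for x in a) / (na - 1)
    vb = sum((x - mb) ** 2 for x in b) / (nb - 1)
    sa, sb = va / na, vb / nb
    t = (ma - mb) / math.sqrt(sa + sb)
    # Welch-Satterthwaite effective degrees of freedom.
    df = (sa + sb) ** 2 / (sa ** 2 / (na - 1) + sb ** 2 / (nb - 1))
    # One-sided p under a standard normal, as in the deleted helper.
    p_normal = 0.5 * (1 - math.erf(t / math.sqrt(2)))
    return t, df, p_normal


t, df, p = welch_with_df([11.0, 12.0, 13.0, 14.0], [1.0, 2.0, 3.0, 4.0])
print(f"t={t:+.2f} df={df:.1f} p_normal={p:.4f}")
```

When `df` is small, the normal tail is lighter than the Student-t tail, so the approximate p-value understates the true one; a large `df` is what justifies the shortcut.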
@@ -1,243 +0,0 @@
"""Smoke / fast-dev-run entry point — runs the REAL pipeline end-to-end.

Pipeline (~1-2 min on CPU with tiny-random qwen3):
  1. Load model + tokenizer
  2. Extract v_hack from 20 shared-prompt hack/clean pairs (docs/pairs):
     real forward, mean-difference of last-token hidden states at ~70% depth
  3. SVD-denoise v_hack via lm_head.weight
  4. Run N "real" GRPO-ish backward passes:
     - NLL loss on completion tokens
     - real loss.backward() -> real grad on model.lm_head.weight: [vocab, d]
     - per-row cos_align(grad_row, v_hack); aggregate mean
     - arm='projected': remove v_hack component from each row, optionally
       restore row magnitude, write back to .grad, optimizer.step()
     - arm='vanilla': no projection, optimizer.step()
  5. Diff vanilla vs projected: mean cos_align and parameter delta norms.

No fake gradients. Code paths AND mechanism are tested in one pass.
"""

from __future__ import annotations

import sys
from dataclasses import asdict, dataclass

import torch
import tyro
from jaxtyping import Float
from loguru import logger
from tabulate import tabulate
from torch import Tensor
from transformers import AutoModelForCausalLM, AutoTokenizer

from projected_grpo.extract_vhack import collect_last_token_hidden, extract_vhack
from projected_grpo.grad_proj import svd_denoise
from projected_grpo.pairs import PAIRS, clean_prompts, hack_prompts


@dataclass
class Config:
    model: str = "llamafactory/tiny-random-qwen3"
    arm: str = "both"  # "vanilla" | "projected" | "both"
    m: int = 16  # SVD top-m for v_hack denoise
    steps: int = 5  # real backward+step iterations per arm
    seed: int = 41
    lr: float = 7e-5  # per docs/grpo_hyperparams.md
    fast_dev_run: bool = False
    vhack_check: bool = False
    preserve_magnitude: bool = True
    device: str = "cpu"
    dtype: str = "fp32"  # fp32 | bf16; bf16 needs cuda


def _resolve_dtype(s: str) -> torch.dtype:
    return {"fp32": torch.float32, "bf16": torch.bfloat16, "fp16": torch.float16}[s]


def setup_logging() -> None:
    logger.remove()
    logger.add(sys.stderr, format="<level>{level.icon}</level> {message}", colorize=True)
    logger.level("INFO", icon="I")
    logger.level("WARNING", icon="W")
    logger.level("ERROR", icon="E")
    logger.level("DEBUG", icon="D")


def project_grad_per_row(
    g_W: Float[Tensor, "vocab d"],
    v_hack: Float[Tensor, "d"],
    preserve_magnitude: bool,
) -> tuple[Float[Tensor, "vocab d"], dict]:
    """One-sided per-row projection of a weight gradient against v_hack.

    For each row g_v of g_W (shape [d]):
        cos = (g_v . v_hack) / ||g_v||
        if cos > 0: g_v' = g_v - cos * ||g_v|| * v_hack; rescale to ||g_v||
        else: g_v' = g_v
    """
    v_hack = v_hack / (v_hack.norm() + 1e-12)
    row_norms = g_W.norm(dim=-1, keepdim=True).clamp_min(1e-12)  # [vocab, 1]
    cos_pre = (g_W @ v_hack).unsqueeze(-1) / row_norms  # [vocab, 1]
    mask_pos = (cos_pre > 0).float()
    coef = (cos_pre * row_norms) * mask_pos  # zero out rows with cos<=0
    g_proj = g_W - coef * v_hack.unsqueeze(0)
    if preserve_magnitude:
        new_norms = g_proj.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        g_proj = g_proj * (row_norms / new_norms)
    cos_post = (g_proj @ v_hack) / g_proj.norm(dim=-1).clamp_min(1e-12)
    return g_proj, {
        "cos_pre_mean": cos_pre.squeeze(-1).mean().item(),
        "cos_pre_max": cos_pre.squeeze(-1).max().item(),
        "cos_post_mean": cos_post.mean().item(),
        "cos_post_max": cos_post.max().item(),
        "frac_projected": mask_pos.mean().item(),
    }


def real_grpo_step(
    model,
    tokenizer,
    prompt: str,
    completion: str,
    v_hack: Float[Tensor, "d"],
    arm: str,
    preserve_magnitude: bool,
    optimizer: torch.optim.Optimizer,
) -> dict:
    """One GRPO-ish update: NLL on completion -> backward -> (project) -> step."""
    full_ids = tokenizer(prompt + completion, return_tensors="pt").input_ids.to(model.device)
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids
    plen = prompt_ids.shape[1]
    labels = full_ids.clone()
    labels[:, :plen] = -100  # NLL on completion tokens only
    out = model(input_ids=full_ids, labels=labels)
    loss = out.loss
    optimizer.zero_grad()
    loss.backward()
    g_W = model.lm_head.weight.grad.detach().float()  # [vocab, d] -> fp32 for projection stability
    if arm == "projected":
        g_proj, diag = project_grad_per_row(g_W, v_hack, preserve_magnitude)
        model.lm_head.weight.grad.copy_(g_proj.to(model.lm_head.weight.grad.dtype))
    else:
        row_norms = g_W.norm(dim=-1).clamp_min(1e-12)
        cos_pre = (g_W @ v_hack) / row_norms
        diag = {
            "cos_pre_mean": cos_pre.mean().item(),
            "cos_pre_max": cos_pre.max().item(),
            "cos_post_mean": cos_pre.mean().item(),
            "cos_post_max": cos_pre.max().item(),
            "frac_projected": 0.0,
        }
    optimizer.step()
    diag["loss"] = loss.item()
    diag["g_norm"] = g_W.norm().item()
    return diag


def snapshot(model) -> dict[str, Tensor]:
    return {k: v.detach().clone() for k, v in model.state_dict().items()}


def param_delta(s0: dict[str, Tensor], s1: dict[str, Tensor]) -> float:
    return sum((s1[k].float() - s0[k].float()).norm().item() ** 2 for k in s0) ** 0.5


def run_arm(cfg: Config, arm: str, v_hack: Float[Tensor, "d"]) -> dict:
    print(f"\n\n--- TRAIN [{arm}] seed={cfg.seed} steps={cfg.steps} lr={cfg.lr} ---\n")
    torch.manual_seed(cfg.seed)

    tokenizer = AutoTokenizer.from_pretrained(cfg.model)
    dtype = _resolve_dtype(cfg.dtype)
    model = AutoModelForCausalLM.from_pretrained(cfg.model, torch_dtype=dtype).to(cfg.device)
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)
    state_0 = snapshot(model)

    rows = []
    for step in range(cfg.steps):
        p = PAIRS[step % len(PAIRS)]
        diag = real_grpo_step(
            model, tokenizer, p.prompt, p.hack, v_hack.to(model.device), arm,
            cfg.preserve_magnitude, optimizer,
        )
        rows.append({"step": step, "flavor": p.hack_flavor, **diag})

    logger.info(f"per-step [{arm}]:\n" + tabulate(rows, headers="keys", tablefmt="tsv", floatfmt="+.3f"))
    state_1 = snapshot(model)
    return {
        "arm": arm,
        "final_loss": rows[-1]["loss"],
        "mean_cos_pre": sum(r["cos_pre_mean"] for r in rows) / len(rows),
        "mean_cos_post": sum(r["cos_post_mean"] for r in rows) / len(rows),
        "frac_projected": sum(r["frac_projected"] for r in rows) / len(rows),
        "param_delta": param_delta(state_0, state_1),
    }


def main(cfg: Config) -> None:
    setup_logging()
    print(f"argv: {' '.join(sys.argv)}")
    print(f"cfg: {asdict(cfg)}")

    print(f"\n\n=== LOAD [{cfg.model}] ===\n")
    tokenizer = AutoTokenizer.from_pretrained(cfg.model)
    dtype = _resolve_dtype(cfg.dtype)
    model = AutoModelForCausalLM.from_pretrained(cfg.model, torch_dtype=dtype).to(cfg.device)
    model.eval()
    n_layers = model.config.num_hidden_layers
    layer_idx = max(1, int(n_layers * 0.7))
    logger.info(f"n_layers={n_layers} layer_idx={layer_idx} (70% depth per Wu-Tang)")

    print(f"\n\n=== EXTRACT [v_hack] n_pairs={len(PAIRS)} layer={layer_idx} ===\n")
    h_hack = collect_last_token_hidden(model, tokenizer, hack_prompts(), layer_idx, cfg.device)
    h_clean = collect_last_token_hidden(model, tokenizer, clean_prompts(), layer_idx, cfg.device)
    n_train = int(len(PAIRS) * 0.75)
    vh = extract_vhack(
        h_hack[:n_train], h_clean[:n_train],
        h_hack[n_train:], h_clean[n_train:],
        layer_idx=layer_idx,
    )
    v_hack = vh.v_hack
    # SHOULD val_acc>0.9 is already logged inside extract_vhack at the site.

    W = model.lm_head.weight.detach().float().cpu()  # [vocab, d] -> fp32 cpu for stable SVD
    v_hack_cpu = v_hack.float().cpu()
    logger.info(f"SVD-denoise via lm_head.weight shape={tuple(W.shape)} m={cfg.m}")
    v_hack_denoised = svd_denoise(v_hack_cpu, W, m=cfg.m, use_left=False)
    cos_raw_denoised = float(v_hack_cpu @ v_hack_denoised)
    logger.info(
        f"cos(raw, denoised)={cos_raw_denoised:+.3f} "
        f"SHOULD>0.5: denoised should retain the dominant direction. "
        f"If <0.5: m too small OR wrong basis side (try use_left=True)."
    )
    del model  # free; run_arm reloads a fresh copy for each arm

    if cfg.vhack_check:
        logger.info("vhack-check: TODO real CAA-style steering check on full model.")
        return

    arms = ["vanilla", "projected"] if cfg.arm == "both" else [cfg.arm]
    results = [run_arm(cfg, a, v_hack_denoised) for a in arms]

    # === RESULTS tail ===
    print("\n\n=== RESULTS ===\n")
    if cfg.arm == "both":
        van = next(r for r in results if r["arm"] == "vanilla")
        proj = next(r for r in results if r["arm"] == "projected")
        delta_cos = van["mean_cos_post"] - proj["mean_cos_post"]
        cue = "[OK]" if delta_cos > 0.01 else "[WARN]"
        print(f"main metric: delta_cos_post={delta_cos:+.4f} {cue}")
        print(f"argv: {' '.join(sys.argv)}")
        print(f"vhack_val_acc={vh.val_accuracy:+.3f}")
        print(f"frac_projected (projected arm)={proj['frac_projected']:.2f}\n")

    print(tabulate(results, headers="keys", tablefmt="tsv", floatfmt="+.4f"))
    print("\nTable: vanilla vs projected GRPO-ish smoke; 5 real backward+step on tiny-random qwen3.")
    print("mean_cos_post (->0 for projected, free for vanilla); param_delta (-> nonzero = real opt step).\n")
    print(tabulate(results, headers="keys", tablefmt="github", floatfmt="+.4f"))
    print()
    logger.info("smoke OK")


if __name__ == "__main__":
    main(tyro.cli(Config))
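
The one-sided per-row projection described in the `project_grad_per_row` docstring can be followed by hand on a single row. The sketch below is a dependency-free illustration using plain lists, not the tensor code from the deleted file; `project_row` is a hypothetical name, and the zero-norm case is handled with a simple guard instead of the `clamp_min(1e-12)` used in the original.

```python
import math


def project_row(g: list[float], v: list[float], preserve_magnitude: bool = True) -> list[float]:
    """Remove the positive component of g along v; leave g alone otherwise."""
    v_norm = math.sqrt(sum(x * x for x in v))
    u = [x / v_norm for x in v]                  # unit vector along v
    dot = sum(a * b for a, b in zip(g, u))       # equals cos * ||g||
    if dot <= 0:
        return list(g)                           # one-sided: nothing to remove
    out = [a - dot * b for a, b in zip(g, u)]    # now orthogonal to v
    if preserve_magnitude:
        g_norm = math.sqrt(sum(a * a for a in g))
        out_norm = math.sqrt(sum(a * a for a in out))
        if out_norm > 0:                         # guard: g parallel to v
            out = [a * g_norm / out_norm for a in out]
    return out


print(project_row([3.0, 4.0], [1.0, 0.0]))         # aligned row: component removed, norm restored
print(project_row([-3.0, 4.0], [1.0, 0.0]))        # anti-aligned row: returned unchanged
print(project_row([3.0, 4.0], [1.0, 0.0], False))  # no rescaling
```

Rescaling keeps each row's step size unchanged while redirecting it away from `v`; turning it off shrinks the projected rows instead.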