mirror of
https://github.com/wassname/grpo_proj2.git
synced 2026-06-27 16:45:45 +08:00
refactor: SVD-diag knob -> parametrized LoRA (fixed A, train B)
Replace the AntiPaSTO SVD-diag adapter (δW = U·diag(δS)·Vh) with a parametrized LoRA: per Linear a frozen random A (r×d_in, semi-orthonormal) and a trained B (d_out×r), δW = (B+B_hack)·A. δW stays LINEAR in the trained knob -- the one property the projection needs (a once-extracted V stays a fixed weight-space direction) -- but the whole per-module SVD subsystem is gone (no svd_cached, no SVD_CACHE, no bf16 hash/contiguity friction). B_hack is the SAME shape as B, so the route quarantine is capacity-matched by construction (old-repo route2 diverged from an oversized quarantine sink).

- antipasto: wrap() builds A (fp32 orthogonal_, geqrf has no bf16 -> cast after) + B/B_hack zero-init on the layer device; forward y + (B+B_hack)@(A@x).
- proj: project_one is dim-agnostic; project_all flattens B.grad (d_out·r) and reshapes back. cos_overlap flattens too.
- extract: V from the SVD of stacked B.grad pair-diffs (d_out·r).
- train: B/B_hack rename, lora_rank config, per-step aligned table + legend (replaces the sparse tqdm postfix), clean argv via preset/Config defaults.
- justfile: collapse smoke-vanilla/smoke-route/fast-vanilla/full-vanilla into smoke/fast/full *ARGS + a `sweep` recipe that fires vanilla|erase|route as pueue jobs.
- results.py: glob run_*.log (skip loguru verbose logs).

Smoke (GPU bf16, all three arms) green: cout~0 one_sided identity holds in the LoRA basis, |Bh|=0 for erase/none, route parks into B_hack.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
@@ -0,0 +1,85 @@
# Parametrized-LoRA adapter (replace SVD-diag), clean entrypoints

## Goal
Swap the AntiPaSTO SVD-diag knob (δW = U·diag(δS)·Vh, δS∈ℝ^r) for a parametrized
LoRA (δW = B·A, A a frozen random projection, B trained). Keep δW LINEAR in the
trained knob so the gradient projection stays a fixed-direction removal, but drop
the whole SVD subsystem (svd_cached, SVD_CACHE, U/Vh, and the bf16 hash/contiguity
friction). Same change folds in the entrypoint cleanup (clean argv, one sweep
recipe firing the arms as pueue jobs).

## Why
- Linearity is the ONLY property the projection needs: a once-extracted V stays a
  fixed weight-space direction (∂δW/∂B is the constant A). Fixed-A LoRA gives that
  with no SVD. Full vanilla LoRA (train A and B) is bilinear -> V drifts -> needs
  per-step re-extraction; avoid it.
- π_ref stays free at B=0 (δW=0, base model untouched), and extract-at-init still
  holds: ∂L/∂B = δ_out·(A x)ᵀ ≠ 0 at B=0 whenever A x ≠ 0 (A≠0), same as the SVD knob.
- Capacity lesson (old-repo route2): an oversized quarantine becomes a low-resistance
  sink (qE~0.97). Fixed-A LoRA with B_hack the SAME shape + SAME A is capacity-matched
  by construction.
- More standard / credible than "diagonal scaling in the SVD basis", and showing the
  projection works in an arbitrary fixed basis is a stronger result than SVD-only.

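The linearity argument in the first two bullets can be spelled out as a short derivation. This is a sketch only: the step size $t$ and the direction $\Delta \in \mathbb{R}^{d_{out} \times r}$ are notation introduced here for illustration, and a single input $x$ is assumed.

```latex
\begin{aligned}
\text{fixed } A:\quad
  & \delta W(B) = (B + B_{\text{hack}})\,A, \\
  & \delta W(B + t\,\Delta) - \delta W(B) = t\,\Delta A
    \quad \text{(independent of the current } B\text{)}, \\
  & \frac{\partial L}{\partial B} = \frac{\partial L}{\partial y}\,(A x)^{\top}
    \quad \text{(nonzero at } B = 0 \text{ when } A x \neq 0
    \text{ and } \partial L / \partial y \neq 0\text{)}; \\[4pt]
\text{trained } A:\quad
  & \delta W = B A, \qquad d(\delta W) = dB\,A + B\,dA
    \quad \text{(the weight-space image of a } B\text{-direction depends on the current } A\text{)}.
\end{aligned}
```

With A frozen, a direction Δ in B-space always maps to the same weight-space direction ΔA, which is why a V extracted once stays valid. Once A also moves, that map changes every step.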
## Scope
In: `antipasto.wrap` (adapter), `extract_vhack_grad` (V in B-space), `proj` (operate
on B.grad), route quarantine (B_hack), Config defaults + justfile entrypoints.
Out: env / 4 hacks, GRPO loss, teacher pool, grader, and `project_one` MATH (it is
already dimension-agnostic; only the per-module vector dim changes).

## Design (math)
Per target Linear W ∈ ℝ^{d_out×d_in}, rank r = `lora_rank`:
- A ∈ ℝ^{r×d_in}: random semi-orthonormal rows, frozen buffer (requires_grad=False).
- B ∈ ℝ^{d_out×r}: trained param, zero-init.
- B_hack ∈ ℝ^{d_out×r}: zero-init param, route only (same shape => capacity-matched).
- forward hook: `y + (B + B_hack) @ (A @ x)` # δW=(B+B_hack)·A, linear in B; =0 at B=B_hack=0
- gradient knob for projection: g = B.grad.reshape(-1) ∈ ℝ^{d_out·r}.

Projection (proj.py, unchanged `project_one`):
- extract: D = stack_pairs(∇B|hack - ∇B|clean).reshape(n_pairs, d_out·r), where both
  terms are grads of B (on the hack vs clean completion of a pair), not of the B_hack
  param; V = top-k right singular vecs of D (k × d_out·r); orient by per-pair sign vote.
- step: g = B.grad.reshape(-1); g_proj, removed = project_one(g, V, ...);
  B.grad = g_proj.reshape(d_out, r); route: B_hack.grad = removed.reshape(d_out, r).

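The flatten -> project -> reshape step can be sketched in plain Python, with nested lists standing in for tensors. `project_one_sided` is a hypothetical stand-in for `project_one`, whose body is not shown in this commit. It assumes the rows of V are orthonormal and implements only the one-sided gate (remove a component only where ⟨g, v_i⟩ > 0); the real function also takes `preserve_magnitude` and `overshoot`, which are ignored here.

```python
def flatten(B):
    """(d_out, r) nested list -> flat list of length d_out*r, row-major."""
    return [x for row in B for x in row]


def unflatten(g, d_out, r):
    """Inverse of flatten: flat list -> (d_out, r) nested list."""
    return [g[i * r:(i + 1) * r] for i in range(d_out)]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def project_one_sided(g, V):
    """Hypothetical sketch: subtract the hack-ward (positive) components of g
    along the orthonormal rows of V; return (g_proj, removed)."""
    removed = [0.0] * len(g)
    for v in V:
        c = dot(g, v)                      # signed overlap with this hack axis
        if c > 0.0:                        # one-sided gate: only hack-ward parts
            removed = [r_ + c * vi for r_, vi in zip(removed, v)]
    g_proj = [gi - ri for gi, ri in zip(g, removed)]
    return g_proj, removed


# Toy module: d_out=2, r=2, so the knob lives in R^4.
B_grad = [[3.0, 1.0], [-2.0, 4.0]]
V = [[1.0, 0.0, 0.0, 0.0],                 # v_0: overlap +3 -> removed
     [0.0, 0.0, 1.0, 0.0]]                 # v_1: overlap -2 -> left alone

g = flatten(B_grad)
g_proj, removed = project_one_sided(g, V)
new_B_grad = unflatten(g_proj, 2, 2)       # erase: what B descends
B_hack_grad = unflatten(removed, 2, 2)     # route: what gets parked in B_hack
print(new_B_grad, B_hack_grad)
```

After the projection no row of V has positive overlap with `g_proj`, and `g_proj + removed` reconstructs `g`, so the route arm loses nothing: it only moves the hack-ward part into the second knob.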
## Tasks
- [ ] T1 (adapter): rewrite `wrap`: register frozen A buffer + B, B_hack params
  (zero-init, `device=lin.weight.device`, `dtype=lin.weight.dtype`); forward hook
  `y + (B+B_hack)@(A@x)`. Delete `svd_cached`, `SVD_CACHE`, U/Vh/δS, and the δS uses in `zeroed`.
  - verify: `just smoke` green; log shows `wrapped N Linears (lora r=..)`.
  - success: cout≈0 (one_sided identity), cin_t>0.
  - likely_fail: B not zero-init -> π_ref breaks (T4 check fails).
  - sneaky_fail: A left trainable -> V drifts, cout creeps >0 over steps; assert
    `A.requires_grad is False` at wrap time.
- [ ] T2 (extract): `extract_v_hack` reads `B.grad` (d_out×r), flattens per module,
  stacks pairs, SVD -> V (k × d_out·r). completion_nll/device already fixed.
  - verify: smoke cache-MISS extracts; `extracted V from P pairs over N modules`.
- [ ] T3 (proj): `project_all` flattens B.grad -> `project_one` -> reshape back; route
  parks `removed` reshaped into B_hack.grad. `project_one` untouched.
  - verify: erase smoke cout≈0; route smoke |B_hack|>0 and grows.
- [ ] T4 (π_ref): at B=B_hack=0, wrapped forward == base forward, bit-identical on a
  fixed input (no-grad). verify: add a smoke assert (max|Δlogits|==0 with knobs zeroed).
  - sneaky_fail: hook adds A·x even at B=0 -> Δ≠0; the (B+B_hack)@ guards this (0@=0).
- [ ] T5 (clean argv): move `teacher_pool_dir`, `v_hack_path`, `mix_ratio`,
  `eval_ablate_every`/`vhack_refresh_every` (route) into the fast/full PRESETS so the
  CLI line is just `train.py fast --intervention=X --seed=S`. Drop SVD-only config.
- [ ] T6 (entrypoints): collapse justfile to `smoke`/`fast`/`full *ARGS` (arm via
  `--intervention`, default erase) + `sweep PRESET='fast'` that pueue-adds none/erase/
  route with why:/resolve: labels (mirror old-repo `queue-substrate`). Delete
  smoke-vanilla/smoke-route/fast-vanilla/full-vanilla.
  - verify: `just sweep` queues 3 labeled jobs; `pueue status` shows them.

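The T4 identity check can be illustrated without a model. This is a minimal sketch in plain Python, with lists standing in for tensors; `lora_forward` and the toy shapes are assumptions made for the example, not the repo's hook. It shows why zero-init knobs leave the base output untouched, and why zeroing only B_hack recovers the B-only output.

```python
def matvec(M, x):
    """Dense matrix-vector product on nested lists."""
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def lora_forward(y, x, A, B, B_hack):
    """y + (B + B_hack) @ (A @ x), mirroring the hook formula in the spec."""
    h = matvec(A, x)                                   # (r,)
    knob = [[b + bh for b, bh in zip(rb, rbh)]         # forward uses the SUM
            for rb, rbh in zip(B, B_hack)]
    delta = matvec(knob, h)                            # (d_out,)
    return [yi + di for yi, di in zip(y, delta)]


# Toy sizes: d_in=3, r=2, d_out=2.
A = [[1.0, 0.0, 0.0],
     [0.0, 1.0, 0.0]]
x = [0.5, -1.0, 2.0]
y_base = [0.25, -0.75]                                 # stands in for W @ x
zeros = [[0.0, 0.0], [0.0, 0.0]]

# pi_ref: both knobs zero -> exactly the base output (finite activations assumed).
y_ref = lora_forward(y_base, x, A, zeros, zeros)

# Route arm: something was parked in B_hack, so the training-time forward moves...
B = [[0.0, 2.0], [0.0, 0.0]]
B_hack = [[1.0, 0.0], [0.0, 0.0]]
y_train = lora_forward(y_base, x, A, B, B_hack)

# ...and ablating the quarantine (B_hack -> 0) leaves only B's contribution.
y_deploy = lora_forward(y_base, x, A, B, zeros)
print(y_ref, y_train, y_deploy)
```

The exact-equality claim leans on `0 * h == 0`, which holds only for finite `h`; an inf or NaN activation would turn the zero product into NaN, so a real smoke assert may want to check finiteness as well.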
## Context
- r (lora_rank): start 32. SVD knob was full-rank r∈[1024,2560] but diagonal (r trained
  scalars per module); LoRA is low-rank yet trains d_out·r entries of B, so it has more
  trained params per module and is more expressive per-axis than diagonal.
- A init: semi-orthonormal (orthogonal_ on r×d_in, r<d_in => orthonormal rows) so A·x
  is a clean projection into ℝ^r: norm-preserving on A's row space, non-expansive
  (‖A·x‖ ≤ ‖x‖) in general.
- New code stays pseudocode-like: unicode/math names where they read as math, einops
  for reshape (`rearrange(g, 'o r -> (o r)')`), explicit over clever.

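The A-init bullet can be checked numerically with a small stdlib-only sketch. Gram-Schmidt on Gaussian rows is used here as a stand-in for `nn.init.orthogonal_` (which goes through a QR factorization); `orthonormal_rows` and the sizes are assumptions made for the example. The point is what orthonormal rows do and do not guarantee: A·Aᵀ = I_r, ‖A·x‖ ≤ ‖x‖ for any x, and equality only when x lies in the row space of A.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def norm(a):
    return math.sqrt(dot(a, a))


def orthonormal_rows(r, d_in, seed=0):
    """r orthonormal rows in R^{d_in} via modified Gram-Schmidt (needs r <= d_in)."""
    rng = random.Random(seed)
    rows = []
    while len(rows) < r:
        v = [rng.gauss(0.0, 1.0) for _ in range(d_in)]
        for q in rows:                       # strip components along earlier rows
            c = dot(v, q)
            v = [a - c * b for a, b in zip(v, q)]
        n = norm(v)
        rows.append([a / n for a in v])
    return rows


r, d_in = 3, 8
A = orthonormal_rows(r, d_in)
gram = [[dot(a, b) for b in A] for a in A]   # should be close to I_r

rng = random.Random(1)
x = [rng.gauss(0.0, 1.0) for _ in range(d_in)]
Ax = [dot(row, x) for row in A]              # generic x: norm can only shrink

coeffs = [0.7, -1.2, 0.4]                    # x_span = A^T coeffs lies in the row space
x_span = [sum(c * row[j] for c, row in zip(coeffs, A)) for j in range(d_in)]
Ax_span = [dot(row, x_span) for row in A]    # recovers coeffs, norm preserved

print(norm(Ax), norm(x), norm(Ax_span), norm(x_span))
```

In the repo A is cast to bf16 after initialisation, so these identities would hold only approximately there.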
## Log
- bf16/cuda path fixed + committed (fbacefd) before this refactor; GPU smoke is the gate.
- route is capacity-matched by construction here (B_hack==B shape), unlike old-repo route2.

## TODO
- If low-rank B underfits (can't hack at all), bump r or revisit A init. Decide from
  the vanilla-arm hack_s trajectory, not a priori.
|
||||
@@ -22,33 +22,20 @@ build-pool:
|
||||
|
||||
# Smoke = the ONLY gate: same harness as production (train.py), tiny-random on GPU
|
||||
# in bf16 (same device+dtype as fast/full, so the cuda+bf16 path is actually covered).
|
||||
# beartype on so jaxtyping signatures get runtime-checked. 30 steps fires the
|
||||
# every-25-step save_ckpt path. erase writes g_proj; cache-miss extracts v_hack.
|
||||
# beartype on so jaxtyping signatures get runtime-checked. Arm/v_hack/pool/mix come
|
||||
# from the smoke preset, so the only override you usually pass is --intervention.
|
||||
smoke *ARGS: build-pool check
|
||||
BEARTYPE=1 {{ TRAIN }} smoke --intervention=erase \
|
||||
--v-hack-path=out/vhack/v_hack_smoke.safetensors \
|
||||
--teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 {{ ARGS }}
|
||||
BEARTYPE=1 {{ TRAIN }} smoke {{ ARGS }}
|
||||
|
||||
# Vanilla arm: V loaded for the measure_only diagnostic (cin), grad left untouched.
|
||||
smoke-vanilla *ARGS: build-pool
|
||||
BEARTYPE=1 {{ TRAIN }} smoke --intervention=none \
|
||||
--v-hack-path=out/vhack/v_hack_smoke.safetensors \
|
||||
--teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 {{ ARGS }}
|
||||
# Every code path the real run walks: erase (gate) + vanilla (measure_only cin) +
|
||||
# route (two-knob optim, ablated deploy-eval, online v_hack refresh + basis_overlap).
|
||||
smoke-all: build-pool check
|
||||
BEARTYPE=1 {{ TRAIN }} smoke --intervention=erase
|
||||
BEARTYPE=1 {{ TRAIN }} smoke --intervention=none
|
||||
BEARTYPE=1 {{ TRAIN }} smoke --intervention=route --eval-ablate-every=10 --eval-n-prompts=2 --vhack-refresh-every=10
|
||||
just results
|
||||
|
||||
# Routing arm: parks the hack-ward grad in δS_hack, ablates at eval. Fires the
|
||||
# two-param optimizer path, periodic ablated-eval, online v_hack refresh + the
|
||||
# basis_overlap guard, and the final kept-vs-ablated BLUF.
|
||||
smoke-route *ARGS: build-pool
|
||||
BEARTYPE=1 {{ TRAIN }} smoke --intervention=route \
|
||||
--v-hack-path=out/vhack/v_hack_smoke.safetensors \
|
||||
--teacher-pool-dir=out/pools/teacher_pool --mix-ratio=0.5 \
|
||||
--eval-ablate-every=10 --eval-n-prompts=2 --vhack-refresh-every=10 {{ ARGS }}
|
||||
|
||||
# The trio = every code path the full run walks. Run before any real run.
|
||||
smoke-all: smoke smoke-vanilla smoke-route results
|
||||
|
||||
# Run smoke twice: first warms the v_hack cache (miss), second hits it (cache-hit
|
||||
# branch). Catches save/scope bugs that only manifest in one.
|
||||
# Run smoke twice: cache MISS (extract v_hack) then HIT. Catches save/scope bugs.
|
||||
smoke-both:
|
||||
rm -f out/vhack/v_hack_smoke.safetensors
|
||||
just smoke
|
||||
@@ -58,15 +45,17 @@ smoke-both:
|
||||
results:
|
||||
uv run python scripts/results.py
|
||||
|
||||
# Real runs (Qwen3-4B, GPU). v_hack auto-extracts on cache-miss inside train.
|
||||
full-vanilla *ARGS:
|
||||
{{ TRAIN }} full --intervention=none {{ ARGS }}
|
||||
# Real runs (Qwen3-4B, GPU). intervention defaults to erase; pass --intervention=none
|
||||
# for vanilla. v_hack auto-extracts on cache-miss inside train.
|
||||
fast *ARGS:
|
||||
{{ TRAIN }} fast {{ ARGS }}
|
||||
|
||||
full *ARGS:
|
||||
{{ TRAIN }} full --intervention=erase {{ ARGS }}
|
||||
{{ TRAIN }} full {{ ARGS }}
|
||||
|
||||
fast-vanilla *ARGS:
|
||||
{{ TRAIN }} fast --intervention=none --teacher-pool-dir=out/pools/teacher_pool {{ ARGS }}
|
||||
|
||||
fast *ARGS:
|
||||
{{ TRAIN }} fast --intervention=erase --teacher-pool-dir=out/pools/teacher_pool {{ ARGS }}
|
||||
# Fire the matched arm comparison as low-priority pueue jobs (vanilla | erase | route)
|
||||
# for one PRESET/SEED/STEPS. One entrypoint queues the cell; pueue persists the logs.
|
||||
sweep PRESET='fast' SEED='41' STEPS='60':
|
||||
pueue add -w "$PWD" -o=-8 -l "why: {{ PRESET }} vanilla s{{ SEED }} x{{ STEPS }}; resolve: hack_s onset baseline (extracts v_hack)" -- {{ TRAIN }} {{ PRESET }} --intervention=none --seed={{ SEED }} --steps={{ STEPS }} --out-tag=_sweep
|
||||
pueue add -w "$PWD" -o=-8 -l "why: {{ PRESET }} erase s{{ SEED }} x{{ STEPS }}; resolve: hack_s suppressed vs vanilla at matched solve" -- {{ TRAIN }} {{ PRESET }} --intervention=erase --seed={{ SEED }} --steps={{ STEPS }} --out-tag=_sweep
|
||||
pueue add -w "$PWD" -o=-8 -l "why: {{ PRESET }} route s{{ SEED }} x{{ STEPS }}; resolve: deploy-hack < vanilla on held-out, |Bh|>0" -- {{ TRAIN }} {{ PRESET }} --intervention=route --seed={{ SEED }} --steps={{ STEPS }} --out-tag=_sweep --eval-ablate-every=10 --vhack-refresh-every=10 --eval-n-prompts=2
|
||||
|
||||
@@ -0,0 +1,74 @@
"""Audit a training run's per-step jsonl: quote first/last student generation
(coherence eyeball) + summarise the key columns with first-vs-last5 trend arrows.

Deterministic extraction; the /audit-log command feeds this to the LLM for a verdict.
The clean repo writes ONE jsonl per run (logs/run_<arm><tag>_s<seed>.log) -- each line
is a metric row, and rows that had a live student rollout carry a `gen` field
{text, hack, gt, reward}. No separate rollouts.jsonl.

Usage:
    uv run python scripts/audit_log.py logs/run_erase_s41.log   # a specific run log
    uv run python scripts/audit_log.py                          # newest logs/run_*.log
"""
from __future__ import annotations

import json
import sys
from pathlib import Path


def _resolve(arg: str | None) -> Path:
    if arg:
        return Path(arg)
    # sort by mtime so logs[-1] really is the newest run (a name sort would pick by arm)
    logs = sorted(Path("logs").glob("run_*.log"), key=lambda p: p.stat().st_mtime)
    if not logs:
        sys.exit("no logs/run_*.log found")
    return logs[-1]


def _gens(rows: list[dict]) -> None:
    gens = [r for r in rows if r.get("gen")]
    if not gens:
        print("\n(no student generations logged -- tiny-random smoke bails before grading?)")
        return
    print(f"\ngenerations: {len(gens)} steps carry a student gen, "
          f"steps {gens[0]['step']}..{gens[-1]['step']}")
    for r in (gens[0], gens[-1]):
        g = r["gen"]
        print(f"\n--- step {r['step']} reward={g['reward']:+.2f} gt={g['gt']} hack={g['hack']} ---")
        print("SHOULD: coherent code/prose; ELSE token salad => diverged")
        print(repr(g["text"][:400]))


def _cols(rows: list[dict]) -> None:
    def trend(name, lo=None):
        v = [r[name] for r in rows if name in r and r[name] is not None]
        if not v:
            return f"{name:7s} (absent)"
        first, last5 = v[0], sum(v[-5:]) / len(v[-5:])
        arrow = "UP" if last5 > first + 1e-6 else ("DOWN" if last5 < first - 1e-6 else "flat")
        warn = f" <-- dipped below {lo} (min {min(v):.2f})" if lo is not None and min(v) < lo else ""
        return f"{name:7s} first={first:+.3f} last5={last5:+.3f} [{arrow}]{warn}"

    print("\n=== key columns (first vs last-5-mean) ===")
    print("SHOULD interpret: hack_s UP=emerging (vanilla expects this; route's real test is")
    print(" the deploy eval); gt_s UP=learning to solve; cin_t>cin_s early then decays as V")
    print(" goes stale; cout~0 is the one_sided identity NOT efficacy; bh>0 iff route.")
    for nm, lo in [("hack_s", None), ("gt_s", None), ("cin_s", None), ("cin_t", None),
                   ("cout", None), ("bh", None), ("gn", None), ("loss", None)]:
        if any(nm in r for r in rows):
            print(" " + trend(nm, lo))


def main() -> None:
    path = _resolve(sys.argv[1] if len(sys.argv) > 1 else None)
    rows = [json.loads(line) for line in path.open() if line.strip()]
    print(f"=== AUDIT {path} ({len(rows)} steps) ===")
    if not rows:
        sys.exit("log is empty")
    _gens(rows)
    _cols(rows)


if __name__ == "__main__":
    main()
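
A worked example of the first-vs-last-5-mean rule may help when reading the audit output. This is a standalone sketch of the same comparison that `trend` makes inside `_cols`; the function name `trend_arrow` and the sample values are made up for illustration and are not from any real run.

```python
def trend_arrow(values, eps=1e-6):
    """Compare the first logged value against the mean of the last (up to) five."""
    first = values[0]
    tail = values[-5:]
    last5 = sum(tail) / len(tail)
    if last5 > first + eps:
        arrow = "UP"
    elif last5 < first - eps:
        arrow = "DOWN"
    else:
        arrow = "flat"
    return first, last5, arrow


hack_s = [0.0, 0.0, 0.1, 0.2, 0.4, 0.5, 0.6]   # hypothetical per-step hack rate
rising = trend_arrow(hack_s)                    # tail = last five entries
steady = trend_arrow([0.5, 0.5, 0.5])           # fewer than 5 rows: tail includes the first
falling = trend_arrow([1.0, 0.2])
print(rising, steady, falling)
```

Note that on runs shorter than five rows the first value is part of the tail mean, so the arrow understates the change; and because `first` is a single sample, a noisy first step can flip the arrow.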
|
||||
+3
-3
@@ -1,4 +1,4 @@
|
||||
"""Aggregate logs/*.log (JSONL rows from train.py) into a last-5-step table.
|
||||
"""Aggregate logs/run_*.log (JSONL rows from train.py) into a last-5-step table.
|
||||
|
||||
hack = student exploited rate (hack_s); solve = student gt_correct rate (gt_s).
|
||||
Each row is one run; the arm/seed come from the filename run_{arm}_s{seed}.log.
|
||||
@@ -18,7 +18,7 @@ def last5(rows, key):
|
||||
return sum(vals) / len(vals) if vals else float("nan")
|
||||
|
||||
|
||||
def main(pattern: str = "logs/*.log") -> None:
|
||||
def main(pattern: str = "logs/run_*.log") -> None:
|
||||
table = []
|
||||
for path in sorted(glob.glob(pattern)):
|
||||
rows = [json.loads(l) for l in open(path) if l.strip()]
|
||||
@@ -35,4 +35,4 @@ def main(pattern: str = "logs/*.log") -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(sys.argv[1] if len(sys.argv) > 1 else "logs/*.log")
|
||||
main(sys.argv[1] if len(sys.argv) > 1 else "logs/run_*.log")
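
The effect of narrowing the default pattern can be seen with `fnmatch`, which applies the same shell-style wildcard rules that `glob` uses for file names. The directory listing below is hypothetical: only the `run_<arm><tag>_s<seed>.log` shape comes from the docstrings in this commit, and `train_verbose.log` is an invented name for a loguru text log.

```python
from fnmatch import fnmatch

# Hypothetical contents of logs/ (names invented for illustration).
names = [
    "run_erase_s41.log",
    "run_none_s41.log",
    "run_route_sweep_s41.log",   # run with --out-tag=_sweep
    "train_verbose.log",         # not JSONL: would break json.loads per line
]

old_pattern = [n for n in names if fnmatch(n, "*.log")]
new_pattern = [n for n in names if fnmatch(n, "run_*.log")]
print(old_pattern)
print(new_pattern)
```

With the old pattern the aggregator would try to parse every `.log` file as JSONL; the new one keeps only the per-run metric logs.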
|
||||
|
||||
@@ -1,87 +1,76 @@
|
||||
"""SVD-basis adapter (pseudocode 01). Prior work: github.com/wassname/AntiPaSTO.
|
||||
"""Parametrized-LoRA adapter (pseudocode 01).
|
||||
|
||||
Train ONE per-module knob δS in the singular-value basis of each Linear. The
|
||||
extracted hack direction V, the live gradient, and the projection all live in
|
||||
these same low-rank weight-aligned coordinates (ℝ^r). At δS=δS_hack=0 the adapter
|
||||
is bit-identical to the base model (W is never reconstructed on the main path),
|
||||
so a no_grad forward with the knobs zeroed gives π_ref for free -- no 2nd model.
|
||||
Per target Linear, freeze a random projection A ∈ ℝ^{r×d_in} and train B ∈ ℝ^{d_out×r}
|
||||
(plus a same-shape quarantine B_hack for the route arm). The added weight is
|
||||
|
||||
δW = (B + B_hack) · A # applied as y += (A x) @ (B+B_hack)ᵀ
|
||||
|
||||
which is LINEAR in the trained knob B: ∂δW/∂B is the constant A. That is the one
|
||||
property the gradient projection needs -- a hack direction V extracted once stays a
|
||||
fixed weight-space direction, so we can project it out every step without re-deriving
|
||||
it (cf. proj.py). At B=B_hack=0 the adapter is identity (W is never touched), so a
|
||||
no_grad forward with the knobs zeroed is π_ref for free -- no 2nd model. B_hack has the
|
||||
SAME shape as B, so the quarantine is capacity-matched (it cannot become a sink).
|
||||
|
||||
A is frozen and random (not the SVD basis): showing the projection works in an
|
||||
arbitrary fixed basis is stronger than an SVD-aligned one, and it drops the whole
|
||||
per-module SVD subsystem.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from loguru import logger
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
TARGET = {"q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"in_proj_qkv", "in_proj_z", "in_proj_a", "in_proj_b", "out_proj", # GatedDeltaNet
|
||||
"up_proj", "gate_proj", "down_proj"}
|
||||
SVD_CACHE = Path("svd_cache")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Wrap:
|
||||
layer: nn.Linear # carries .U, .Vh buffers and .δS, .δS_hack params
|
||||
layer: nn.Linear # carries frozen .A buffer and trained .B, .B_hack params
|
||||
r: int
|
||||
|
||||
@property
|
||||
def δS(self) -> nn.Parameter:
|
||||
return self.layer.δS
|
||||
def B(self) -> nn.Parameter:
|
||||
return self.layer.B
|
||||
|
||||
@property
|
||||
def δS_hack(self) -> nn.Parameter:
|
||||
return self.layer.δS_hack
|
||||
def B_hack(self) -> nn.Parameter:
|
||||
return self.layer.B_hack
|
||||
|
||||
|
||||
def svd_cached(W: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""U Σ Vh = W, reduced. Cache key = sha256(W) so a stale cache is impossible
|
||||
(different W -> different file). Native dtype in, fp32 SVD, cast back."""
|
||||
SVD_CACHE.mkdir(exist_ok=True)
|
||||
# .view(uint8) so the byte-hash is dtype-agnostic: numpy has no bf16, and every
|
||||
# preset (smoke included, now) loads weights in bf16.
|
||||
key = hashlib.sha256(W.detach().cpu().contiguous().view(torch.uint8).numpy().tobytes()).hexdigest()
|
||||
path = SVD_CACHE / f"{key}.safetensors"
|
||||
if path.exists():
|
||||
t = load_file(path) # .to(W): match W's device+dtype
|
||||
return t["U"].to(W), t["S"].to(W), t["Vh"].to(W) # (cache hit loads on cpu -> move to cuda)
|
||||
U, S, Vh = torch.linalg.svd(W.float(), full_matrices=False)
|
||||
# safetensors needs contiguous cpu tensors; svd outputs are neither on cuda.
|
||||
save_file({"U": U.contiguous().cpu(), "S": S.contiguous().cpu(), "Vh": Vh.contiguous().cpu()}, path)
|
||||
return U.to(W), S.to(W), Vh.to(W)
|
||||
def _lora_hook(lin: nn.Linear, x_in, y):
|
||||
x = x_in[0] # (..., d_in)
|
||||
h = x @ lin.A.T # (..., r) into the fixed random subspace
|
||||
return y + h @ (lin.B + lin.B_hack).T # (..., d_out) B=B_hack=0 -> identity
|
||||
|
||||
|
||||
def _δ_hook(lin: nn.Linear, x_in, y):
|
||||
x = x_in[0] # (..., d_in)
|
||||
h = x @ lin.Vh.T # (..., r) into singular basis
|
||||
h = h * (lin.δS + lin.δS_hack) # scale per singular axis (forward uses the SUM)
|
||||
return y + h @ lin.U.T # (..., d_out) δS+δS_hack=0 -> identity
|
||||
|
||||
|
||||
def wrap(model: nn.Module) -> dict[str, Wrap]:
|
||||
def wrap(model: nn.Module, r: int = 32) -> dict[str, Wrap]:
|
||||
wrappers: dict[str, Wrap] = {}
|
||||
for name, lin in model.named_modules():
|
||||
if not isinstance(lin, nn.Linear) or name.split(".")[-1] not in TARGET:
|
||||
continue
|
||||
U, _Σ, Vh = svd_cached(lin.weight) # W ∈ ℝ^{d_out×d_in}
|
||||
r = min(lin.in_features, lin.out_features)
|
||||
lin.register_buffer("U", U)
|
||||
lin.register_buffer("Vh", Vh)
|
||||
z = lambda: nn.Parameter(torch.zeros(r, dtype=lin.weight.dtype, device=lin.weight.device))
|
||||
lin.register_parameter("δS", z()) # on the layer's device, else the hook mixes cpu/cuda
|
||||
lin.register_parameter("δS_hack", z())
|
||||
lin.register_forward_hook(_δ_hook)
|
||||
wrappers[name] = Wrap(lin, r)
|
||||
d_out, d_in = lin.out_features, lin.in_features
|
||||
rk = min(r, d_in) # need rk ≤ d_in for orthonormal rows of A
|
||||
A = torch.empty(rk, d_in, dtype=torch.float32, device=lin.weight.device)
|
||||
nn.init.orthogonal_(A) # A Aᵀ = I_rk; QR (geqrf) needs fp32, cast after
|
||||
lin.register_buffer("A", A.to(lin.weight.dtype)) # frozen bf16 buffer for the forward matmul
|
||||
zeros = lambda: nn.Parameter(torch.zeros(d_out, rk, dtype=lin.weight.dtype, device=lin.weight.device))
|
||||
lin.register_parameter("B", zeros()) # zero-init -> δW=0 at start (π_ref) and ∇B≠0 (A≠0)
|
||||
lin.register_parameter("B_hack", zeros())
|
||||
lin.register_forward_hook(_lora_hook)
|
||||
wrappers[name] = Wrap(lin, rk)
|
||||
for p in model.parameters():
|
||||
p.requires_grad_(False)
|
||||
p.requires_grad_(False) # freeze base; A is a buffer so already frozen
|
||||
for w in wrappers.values():
|
||||
w.δS.requires_grad_(True)
|
||||
w.δS_hack.requires_grad_(True)
|
||||
logger.info(f"wrapped {len(wrappers)} Linears with δS knobs (r in "
|
||||
w.B.requires_grad_(True)
|
||||
w.B_hack.requires_grad_(True)
|
||||
logger.info(f"wrapped {len(wrappers)} Linears with LoRA knobs (r in "
|
||||
f"[{min(w.r for w in wrappers.values())}, {max(w.r for w in wrappers.values())}])")
|
||||
return wrappers
|
||||
|
||||
@@ -89,29 +78,29 @@ def wrap(model: nn.Module) -> dict[str, Wrap]:
|
||||
@contextmanager
|
||||
def zeroed(wrappers: dict[str, Wrap]):
|
||||
"""Both knobs -> 0 (free reference model). Restores on exit."""
|
||||
saved = [(w.δS.data.clone(), w.δS_hack.data.clone()) for w in wrappers.values()]
|
||||
saved = [(w.B.data.clone(), w.B_hack.data.clone()) for w in wrappers.values()]
|
||||
for w in wrappers.values():
|
||||
w.δS.data.zero_()
|
||||
w.δS_hack.data.zero_()
|
||||
w.B.data.zero_()
|
||||
w.B_hack.data.zero_()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for w, (s, sh) in zip(wrappers.values(), saved):
|
||||
w.δS.data.copy_(s)
|
||||
w.δS_hack.data.copy_(sh)
|
||||
for w, (b, bh) in zip(wrappers.values(), saved):
|
||||
w.B.data.copy_(b)
|
||||
w.B_hack.data.copy_(bh)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def ablate_quarantine(wrappers: dict[str, Wrap]):
|
||||
"""δS_hack -> 0 only (deploy-time ablation / extraction). No-op for erase (δS_hack stays 0)."""
|
||||
saved = [w.δS_hack.data.clone() for w in wrappers.values()]
|
||||
"""B_hack -> 0 only (deploy-time ablation / extraction). No-op for erase (B_hack stays 0)."""
|
||||
saved = [w.B_hack.data.clone() for w in wrappers.values()]
|
||||
for w in wrappers.values():
|
||||
w.δS_hack.data.zero_()
|
||||
w.B_hack.data.zero_()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for w, sh in zip(wrappers.values(), saved):
|
||||
w.δS_hack.data.copy_(sh)
|
||||
for w, bh in zip(wrappers.values(), saved):
|
||||
w.B_hack.data.copy_(bh)
|
||||
|
||||
|
||||
def per_token_logps(logits: torch.Tensor, ids: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@@ -48,11 +48,11 @@ def extract_v_hack(model, tok, wrappers: dict[str, Wrap], pairs: list[Pair],
|
||||
assert torch.isfinite(ℒ), f"non-finite NLL on {label} pair" # fail fast, never skip
|
||||
ℒ.backward()
|
||||
for n, w in wrappers.items():
|
||||
G[label][n].append(w.δS.grad.detach().clone())
|
||||
G[label][n].append(w.B.grad.reshape(-1).detach().clone()) # (d_out·r,)
|
||||
|
||||
V_hack, V_sv = {}, {}
|
||||
for n in wrappers:
|
||||
D = torch.stack(G["hack"][n]) - torch.stack(G["clean"][n]) # (n_pairs, r); pairing cancels prompt noise
|
||||
D = torch.stack(G["hack"][n]) - torch.stack(G["clean"][n]) # (n_pairs, d_out·r); pairing cancels prompt noise
|
||||
U_d, S_d, Vh_d = torch.linalg.svd(D.float(), full_matrices=False)
|
||||
V = Vh_d[:k] # (k, d_out·r), rows orthonormal in ℝ^{d_out·r}
|
||||
# orient hack-ward by per-pair majority vote (outlier-robust; proj gates on ⟨g,v_i⟩>0)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Gradient projection: erase / route the hack component (pseudocode 03).
|
||||
|
||||
Once per optimizer step, after δS.grad is populated, before opt.step(). The bet:
|
||||
Once per optimizer step, after B.grad is populated, before opt.step(). The bet:
|
||||
subtract the hack-ward component of the live GRPO gradient (basis V from labeled
|
||||
pairs) so the model descends what's left and cannot move hack-ward even when the
|
||||
reward rewards it.
|
||||
@@ -40,20 +40,21 @@ def project_one(g: Float[Tensor, "r"], V: Float[Tensor, "k r"],
|
||||
def project_all(wrappers, V_hack, measure_only: bool, route: bool,
|
||||
gate_mode: str = "one_sided", overshoot: float = 1.0,
|
||||
preserve_magnitude: bool = False):
|
||||
"""arms: none -> measure_only | erase -> write g_proj | route -> erase + park in δS_hack."""
|
||||
"""arms: none -> measure_only | erase -> write g_proj | route -> erase + park in B_hack.
|
||||
The knob B is (d_out, r); project_one is dim-agnostic so we flatten to ℝ^{d_out·r}."""
|
||||
cins, couts, fired = [], [], 0
|
||||
for name, w in wrappers.items():
|
||||
if name not in V_hack: # below noise floor -> no hack signal -> skip
|
||||
continue
|
||||
g_proj, removed, cin, cout = project_one(
|
||||
w.δS.grad, V_hack[name], gate_mode, preserve_magnitude, overshoot)
|
||||
g = w.B.grad.reshape(-1) # (d_out·r,) V_hack[name] is (k, d_out·r)
|
||||
g_proj, removed, cin, cout = project_one(g, V_hack[name], gate_mode, preserve_magnitude, overshoot)
|
||||
cins.append(cin); couts.append(cout)
|
||||
fired += int(cin > 1e-6)
|
||||
if measure_only: # vanilla: measure cin/cout, leave g
|
||||
continue
|
||||
w.δS.grad = g_proj # erase: descend the hack-free gradient
|
||||
w.B.grad = g_proj.reshape_as(w.B.grad) # erase: descend the hack-free gradient
|
||||
if route:
|
||||
w.δS_hack.grad = removed # park hack part in deletable knob (Cloud 2024)
|
||||
w.B_hack.grad = removed.reshape_as(w.B_hack) # park hack part in deletable knob (Cloud 2024)
|
||||
n = max(len(cins), 1)
|
||||
mean = lambda xs: torch.stack(xs).mean().item() if xs else 0.0
|
||||
return mean(cins), mean(couts), fired / n
|
||||
@@ -65,5 +66,6 @@ def cos_overlap(grads: dict, V_hack) -> float:
|
||||
for name, g in grads.items():
|
||||
if name not in V_hack or g is None:
|
||||
continue
|
||||
g = g.reshape(-1) # B-grad (d_out, r) -> ℝ^{d_out·r}
|
||||
cs.append(torch.relu(V_hack[name] @ g).norm() / g.norm().clamp_min(1e-12))
|
||||
return torch.stack(cs).mean().item() if cs else 0.0
|
||||
|
||||
+43
-35
@@ -6,7 +6,7 @@ Unbiased normalization: Dr.GRPO, Liu et al. 2025, arXiv:2503.20783 (drop the
|
||||
|
||||
generate -> grade -> backward -> project -> step
|
||||
|
||||
Arms (one knob): none -> measure_only | erase -> write g_proj | route -> +park δS_hack.
|
||||
Arms (one knob): none -> measure_only | erase -> write g_proj | route -> +park B_hack.
|
||||
Presets: smoke (tiny-random, GPU bf16, the only gate) | fast | full.
|
||||
|
||||
python -m projected_grpo.train smoke --intervention=erase ...
|
||||
@@ -73,7 +73,8 @@ class Config:
|
||||
n_problems: int = 6
|
||||
env_mode: Optional[EnvMode] = None # force a single mode (emergence cells)
|
||||
mix_ratio: float = 0.5 # G_t / G (teacher fraction)
|
||||
teacher_pool_dir: str = ""
|
||||
teacher_pool_dir: str = "out/pools/teacher_pool"
|
||||
lora_rank: int = 32 # r: A is (r×d_in) frozen, B is (d_out×r) trained
|
||||
# optim
|
||||
lr: float = 1e-2
|
||||
adam_beta1: float = 0.9
|
||||
@@ -82,7 +83,7 @@ class Config:
|
||||
warmup_frac: float = 0.1
|
||||
grad_clip: float = 10.0
|
||||
# GRPO
|
||||
beta: float = 0.0 # KL to free ref (δS=0); 0 = off
|
||||
beta: float = 0.0 # KL to free ref (B=0); 0 = off
|
||||
clip: float = 0.2
|
||||
# projection
|
||||
gate_mode: str = "one_sided"
|
||||
@@ -130,7 +131,7 @@ def load_pool(pool_dir: str) -> dict[str, list[dict]]:
|
||||
|
||||
# ── inner GRPO step on one prompt ─────────────────────────────────────────────
|
||||
def grpo_step(model, tok, wrappers, problem, pool_rows, G_s, G_t, cfg, pad_id, denom):
|
||||
"""Returns (g_s, g_t, metrics) for this prompt; δS.grad left holding g_s+g_t.
|
||||
"""Returns (g_s, g_t, metrics) for this prompt; B.grad left holding g_s+g_t.
|
||||
g_s/g_t are per-module dicts so the outer loop can accumulate + diagnose cin."""
|
||||
prompt_ids = tok(problem.prompt(), return_tensors="pt").input_ids.to(cfg.device)
|
||||
P = prompt_ids.shape[1]
|
||||
@@ -165,7 +166,7 @@ def grpo_step(model, tok, wrappers, problem, pool_rows, G_s, G_t, cfg, pad_id, d
|
||||
if s_rew: # stash one student gen for the jsonl + the final coherence eyeball
|
||||
m["gen"] = dict(text=texts[0], hack=s_rew[0].exploited, gt=s_rew[0].gt_correct, reward=s_rew[0].reward)
|
||||
|
||||
empty = {n: torch.zeros_like(w.δS) for n, w in wrappers.items()}
|
||||
empty = {n: torch.zeros_like(w.B) for n, w in wrappers.items()}
|
||||
if R.max() - R.min() < 1e-6: # zero-variance group => adv≡0 => pure waste (bail)
|
||||
return empty, dict(empty), m
|
||||
|
||||
@@ -193,15 +194,15 @@ def grpo_step(model, tok, wrappers, problem, pool_rows, G_s, G_t, cfg, pad_id, d
|
||||
is_s = torch.tensor(is_student, dtype=torch.float32, device=cfg.device)[:, None]
|
||||
|
||||
# split student/teacher grads. zero+clone BETWEEN passes is load-bearing: after
|
||||
# L_t.backward, δS.grad already holds g_s+g_t (accumulation), so reset first.
|
||||
# L_t.backward, B.grad already holds g_s+g_t (accumulation), so reset first.
|
||||
model.zero_grad()
|
||||
L_s = (Lp * comp_mask * is_s).sum() / denom
|
||||
L_s.backward(retain_graph=True)
|
||||
g_s = {n: w.δS.grad.detach().clone() for n, w in wrappers.items()}
|
||||
g_s = {n: w.B.grad.detach().clone() for n, w in wrappers.items()}
|
||||
model.zero_grad()
|
||||
L_t = (Lp * comp_mask * (1 - is_s)).sum() / denom
|
||||
L_t.backward()
|
||||
g_t = {n: w.δS.grad.detach().clone() for n, w in wrappers.items()}
|
||||
g_t = {n: w.B.grad.detach().clone() for n, w in wrappers.items()}
|
||||
m["loss"] = (L_s + L_t).item()
|
||||
return g_s, g_t, m
|
||||
|
||||
@@ -226,8 +227,8 @@ def save_ckpt(wrappers, rows, tag):
|
||||
Path("out/ckpt").mkdir(parents=True, exist_ok=True)
|
||||
flat = {}
|
||||
for n, w in wrappers.items():
|
||||
flat[f"δS/{n}"] = w.δS.detach().cpu()
|
||||
flat[f"δS_hack/{n}"] = w.δS_hack.detach().cpu()
|
||||
flat[f"B/{n}"] = w.B.detach().cpu()
|
||||
flat[f"B_hack/{n}"] = w.B_hack.detach().cpu()
|
||||
save_file(flat, f"out/ckpt/ckpt{tag}.safetensors")
|
||||
Path(f"out/ckpt/rows{tag}.json").write_text(json.dumps(rows))
|
||||
|
||||
@@ -246,9 +247,11 @@ def main(cfg: Config):
|
||||
pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
|
||||
model = AutoModelForCausalLM.from_pretrained(cfg.model, dtype=dt).to(cfg.device)
|
||||
model.eval()
|
||||
wrappers = wrap(model)
|
||||
wrappers = wrap(model, cfg.lora_rank)
|
||||
n_train = sum(w.B.numel() for w in wrappers.values())
|
||||
logger.info(f"trainable B: {n_train:,} (+{n_train:,} B_hack quarantine), lora_rank={cfg.lora_rank}")
|
||||
|
||||
params = [w.δS for w in wrappers.values()] + [w.δS_hack for w in wrappers.values()]
|
||||
params = [w.B for w in wrappers.values()] + [w.B_hack for w in wrappers.values()]
|
||||
opt = torch.optim.AdamW(params, lr=cfg.lr, betas=(cfg.adam_beta1, cfg.adam_beta2),
|
||||
weight_decay=cfg.weight_decay)
|
||||
sched = get_cosine_schedule_with_warmup(opt, int(cfg.warmup_frac * cfg.steps), cfg.steps)
|
||||
@@ -258,8 +261,8 @@ def main(cfg: Config):
|
||||
V_raw, V_sv = resolve_or_extract(model, tok, wrappers, cfg.v_hack_path,
|
||||
cfg.vhack_pairs_path, cfg.v_hack_top_k, cfg.tau_axis, cfg.n_heldout)
|
||||
V_hack = postprocess_v_hack(V_raw, V_sv, cfg.v_hack_k, cfg.v_hack_drop_bottom_frac)
|
||||
δS_ref = next(iter(wrappers.values())).δS # V lives in δS's space: match dtype+device
|
||||
V_hack = {n: V.to(δS_ref) for n, V in V_hack.items()} # extract/load give fp32[/cpu]; grads are bf16 cuda
|
||||
B_ref = next(iter(wrappers.values())).B # V lives in B's space: match dtype+device
|
||||
V_hack = {n: V.to(B_ref) for n, V in V_hack.items()} # extract/load give fp32[/cpu]; grads are bf16 cuda
|
||||
logger.info(f"V_hack over {len(V_hack)} modules (k_use={cfg.v_hack_k}); "
|
||||
f"arm={cfg.intervention} -> measure_only={cfg.intervention == 'none'}")
|
||||
|
||||
@@ -275,17 +278,21 @@ def main(cfg: Config):
|
||||
log = logname.open("w")
|
||||
rows = []
|
||||
|
||||
# per-step table: one logger.info row/step (legend once). SHORT runs (20-200 steps),
|
||||
# so a row/step reads cleanly -- no tqdm bar (it redraws into piped logs).
|
||||
logger.info(
|
||||
"SHOULD (per-step JSONL rows + tqdm postfix): erase/route cin_t>0 and >cin_s early "
|
||||
"(teacher pool carries the hack signal), decaying as V goes stale; cout~0 under one_sided "
|
||||
"is an arithmetic identity, NOT efficacy; |δSh|>0 iff route, else 0; hack_s stays 0 on "
|
||||
"tiny-random (mechanics smoke). If cin_t~0 throughout -> V missed the live GRPO gradient.")
|
||||
pbar = tqdm(range(cfg.steps), desc=f"train[{cfg.intervention}] s{cfg.seed}",
|
||||
mininterval=120, maxinterval=120)
|
||||
for step in pbar:
|
||||
"per-step columns: step | rew=mean reward | gt_s=student solve | hack_s=student exploit (HEADLINE) |"
|
||||
" hack_t=teacher exploit (pool sanity) | cin_s/cin_t=student/teacher grad overlap with V (relu-agg) |"
|
||||
" cout=residual hack-ward leak | loss | gn=pre-clip grad L2 | lr | |Bh|=quarantine norm.\n"
|
||||
"SHOULD: cin_t>cin_s early (teacher carries the hack signal), decaying as V stales; cout~0 under "
|
||||
"one_sided is an ARITHMETIC IDENTITY not efficacy; |Bh|>0 iff route; hack_s stays 0 on tiny-random.")
|
||||
hdr = (f"{'step':>4} {'rew':>6} {'gt_s':>5} {'hack_s':>6} {'hack_t':>6} {'cin_s':>6} {'cin_t':>6} "
|
||||
f"{'cout':>6} {'loss':>7} {'gn':>8} {'lr':>8} {'|Bh|':>8}")
|
||||
logger.info(hdr)
|
||||
for step in range(cfg.steps):
|
||||
opt.zero_grad()
|
||||
g_s_acc = {n: torch.zeros_like(w.δS) for n, w in wrappers.items()}
|
||||
g_t_acc = {n: torch.zeros_like(w.δS) for n, w in wrappers.items()}
|
||||
g_s_acc = {n: torch.zeros_like(w.B) for n, w in wrappers.items()}
|
||||
g_t_acc = {n: torch.zeros_like(w.B) for n, w in wrappers.items()}
|
||||
ms = []
|
||||
for _ in range(cfg.prompts_per_step):
|
||||
prob = problems[random.randrange(len(problems))]
|
||||
@@ -295,8 +302,8 @@ def main(cfg: Config):
|
||||
g_s_acc[n] += g_s[n]; g_t_acc[n] += g_t[n]
|
||||
ms.append(m)
|
||||
for n, w in wrappers.items():
|
||||
w.δS.grad = g_s_acc[n] + g_t_acc[n] # combined live GRPO gradient on the knob
|
||||
w.δS_hack.grad = None # quarantine moves ONLY via routed `removed` (set in project_all)
|
||||
w.B.grad = g_s_acc[n] + g_t_acc[n] # combined live GRPO gradient on the knob
|
||||
w.B_hack.grad = None # quarantine moves ONLY via routed `removed` (set in project_all)
|
||||
|
||||
cin_s = cos_overlap(g_s_acc, V_hack) if V_hack else 0.0
|
||||
cin_t = cos_overlap(g_t_acc, V_hack) if V_hack else 0.0
|
||||
@@ -309,12 +316,12 @@ def main(cfg: Config):
|
||||
gn = torch.nn.utils.clip_grad_norm_([p for p in params if p.grad is not None], cfg.grad_clip)
|
||||
opt.step(); sched.step()
|
||||
|
||||
dSh = sum(w.δS_hack.detach().norm().item() for w in wrappers.values())
|
||||
bh = sum(w.B_hack.detach().norm().item() for w in wrappers.values())
|
||||
row = dict(step=step, hack_s=_mean([m["hack_s"] for m in ms]), gt_s=_mean([m["gt_s"] for m in ms]),
|
||||
hack_t=_mean([m["hack_t"] for m in ms]), gt_t=_mean([m["gt_t"] for m in ms]),
|
||||
reward=_mean([m["reward"] for m in ms]), loss=_mean([m.get("loss", 0.0) for m in ms]),
|
||||
gn=gn.item(), cin_s=cin_s, cin_t=cin_t, cin=cin, cout=cout, fired=fired,
|
||||
dSh=dSh)
|
||||
bh=bh)
|
||||
if "gen" in ms[0]: # log the first prompt's first student gen (text + flags)
|
||||
row["gen"] = ms[0]["gen"]
|
||||
|
||||
@@ -325,12 +332,12 @@ def main(cfg: Config):
|
||||
V_raw, V_sv = extract_v_hack(model, tok, wrappers, default_pairs(),
|
||||
cfg.v_hack_top_k, cfg.tau_axis, cfg.n_heldout)
|
||||
V_hack = postprocess_v_hack(V_raw, V_sv, cfg.v_hack_k, cfg.v_hack_drop_bottom_frac)
|
||||
V_hack = {n: V.to(δS_ref) for n, V in V_hack.items()} # keep V in δS's bf16/cuda space
|
||||
V_hack = {n: V.to(B_ref) for n, V in V_hack.items()} # keep V in B's bf16/cuda space
|
||||
shared = [n for n in V_hack if n in old]
|
||||
ov = _mean([abs((V_hack[n][0] @ old[n][0]).item()) for n in shared]) if shared else 0.0
|
||||
row["basis_overlap"] = ov # GUARD: should sit near 1.0; <~0.2 => refresh rotated off-hack
|
||||
|
||||
# deployment-time eval (route only): zero δS_hack, eval the ablated model
|
||||
# deployment-time eval (route only): zero B_hack, eval the ablated model
|
||||
if is_route and cfg.eval_ablate_every and step % cfg.eval_ablate_every == 0:
|
||||
with ablate_quarantine(wrappers):
|
||||
hd, sd = eval_hack_solve(model, tok, wrappers, problems, cfg.eval_n_prompts,
|
||||
@@ -339,9 +346,10 @@ def main(cfg: Config):
|
||||
|
||||
rows.append(row)
|
||||
log.write(json.dumps(row) + "\n"); log.flush()
|
||||
pbar.set_postfix(hack_s=f"{row['hack_s']:.2f}", cin_s=f"{cin_s:.2f}", cin_t=f"{cin_t:.2f}",
|
||||
cout=f"{cout:.2f}", δSh=f"{dSh:.2f}", loss=f"{row['loss']:.3f}",
|
||||
refresh=False) # don't force a bar redraw every step (pollutes piped logs)
|
||||
lr = sched.get_last_lr()[0]
|
||||
logger.info(f"{step:4d} {row['reward']:+6.2f} {row['gt_s']:5.2f} {row['hack_s']:6.2f} "
|
||||
f"{row['hack_t']:6.2f} {cin_s:6.3f} {cin_t:6.3f} {cout:6.3f} {row['loss']:+7.3f} "
|
||||
f"{gn.item():8.1e} {lr:8.1e} {bh:8.3f}")
|
||||
if step % 25 == 0:
|
||||
save_ckpt(wrappers, rows, cfg.out_tag or f"_{cfg.intervention}_s{cfg.seed}")
|
||||
|
||||
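The header string and the per-step row in the hunk above only line up if every row field is formatted to the same width as its header cell. A minimal sketch of that pairing, using only the first four columns; `COLS`, `header`, and `fmt_row` are hypothetical names for illustration and are not in the commit:

```python
# Hypothetical column spec: (name, width, sign, tail). The widths copy the
# first four cells of the diff's hdr string; the helpers are not from the repo.
COLS = [
    ("step", 4, "", "d"),
    ("rew", 6, "+", ".2f"),
    ("gt_s", 5, "", ".2f"),
    ("hack_s", 6, "", ".2f"),
]


def header(cols):
    # Right-align each name in its column width, as hdr does with {'name':>w}.
    return " ".join(f"{name:>{width}}" for name, width, _, _ in cols)


def fmt_row(cols, values):
    # Build each spec as sign + width + tail, e.g. "+6.2f". The sign must come
    # before the width in Python's format mini-language.
    return " ".join(
        f"{values[name]:{sign}{width}{tail}}" for name, width, sign, tail in cols
    )


hdr = header(COLS)
row = fmt_row(COLS, {"step": 3, "rew": 0.25, "gt_s": 0.5, "hack_s": 0.0})
print(hdr)
print(row)
```

Keeping width and format in one spec means a column cannot drift out of alignment when a field is added or renamed, which is the failure mode of maintaining the header and row f-strings separately. Values wider than their column still overflow, so this assumes the short runs and bounded metrics the commit describes.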
@@ -356,14 +364,14 @@ def _bluf(cfg, rows, wrappers, model, tok, problems, pad_id):
|
||||
last5 = rows[-5:]
|
||||
hack = _mean([r["hack_s"] for r in last5]); solve = _mean([r["gt_s"] for r in last5])
|
||||
cin_t = _mean([r["cin_t"] for r in last5]); cout = _mean([r["cout"] for r in last5])
|
||||
dSh = rows[-1]["dSh"]
|
||||
bh = rows[-1]["bh"]
|
||||
out = f"logs/run_{cfg.intervention}{cfg.out_tag}_s{cfg.seed}.log"
|
||||
# mechanics cue: projection arms want the teacher hack signal present (cin_t>0); none is neutral
|
||||
cue = "\U0001F7E2" if (cfg.intervention == "none" or cin_t > 0) else "\U0001F534"
|
||||
|
||||
summary = {"arm": cfg.intervention, "steps": len(rows),
|
||||
"hack↓": round(hack, 3), "solve→": round(solve, 3),
|
||||
"cin_t↑": round(cin_t, 3), "cout→0": round(cout, 3), "|δSh|": round(dSh, 2)}
|
||||
"cin_t↑": round(cin_t, 3), "cout→0": round(cout, 3), "|Bh|": round(bh, 2)}
|
||||
if cfg.intervention == "route":
|
||||
with ablate_quarantine(wrappers):
|
||||
hk_abl, sv_abl = eval_hack_solve(model, tok, wrappers, problems, cfg.eval_n_prompts or 2,
|
||||
@@ -376,7 +384,7 @@ def _bluf(cfg, rows, wrappers, model, tok, problems, pad_id):
|
||||
caption = (f"Table: last-5-step means, arm={cfg.intervention} seed={cfg.seed}. "
|
||||
"hack↓ student exploit rate, solve→ gt-correct rate, cin_t↑ teacher-grad "
|
||||
"overlap with V (relu-before-agg), cout→0 residual hack-ward leak (identity under "
|
||||
"one_sided, not efficacy), |δSh| quarantine norm (>0 iff route).")
|
||||
"one_sided, not efficacy), |Bh| quarantine norm (>0 iff route).")
|
||||
logger.info(f"out: {out}")
|
||||
# Last student generation -- a coherence eyeball before the numbers. SHOULD: real
|
||||
# code/prose for the problem. If it is token salad the policy diverged and the eval
|
||||
|
||||
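The legend and the table caption both stress that `cout~0` under `one_sided` is an arithmetic identity and not evidence of efficacy. A small worked sketch of why, assuming the projection removes only the positive (hack-ward) component of the gradient along a unit direction `v`; `project_one_sided` is a hypothetical stand-in, since the repo's `project_one` is not shown in this diff:

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def project_one_sided(g, v):
    """Split g into (kept, removed) along unit vector v.

    Assumption: only a positive overlap with v is removed (relu on the
    coefficient); a gradient pointing away from v passes through untouched.
    """
    coeff = max(dot(g, v), 0.0)
    removed = [coeff * x for x in v]
    kept = [gi - ri for gi, ri in zip(g, removed)]
    return kept, removed


v = [0.6, 0.8]  # unit-norm direction, standing in for one row of V_hack
g = [2.0, 1.0]  # gradient with positive overlap: dot(g, v) is about 2.0

kept, removed = project_one_sided(g, v)
cin = max(dot(g, v), 0.0)      # overlap before projection
cout = max(dot(kept, v), 0.0)  # residual hack-ward overlap after projection

# A gradient pointing away from v is left alone, and its relu overlap is
# already zero, so cout is zero in both cases by construction.
g_away = [-2.0, -1.0]
kept_away, removed_away = project_one_sided(g_away, v)
cout_away = max(dot(kept_away, v), 0.0)

print(cin, cout, cout_away)
```

Because `kept + removed` reconstructs `g`, a route arm can hand `removed` to the quarantine parameter (`B_hack.grad` in the diff) without losing any gradient mass. That `cout` vanishes here says nothing about whether `v` matches the live hack direction, which is why the legend points to `cin_t` as the signal to watch.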