mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:15:35 +08:00
refactor: extract train_config.py + run_artifacts.py from train.py; slim results scripts
Cleanup by a prior agent, verified green here: 'just smoke' (erase arm) runs end-to-end and all four wired gates pass (verify_rewards 52/52, verify_eval_gap, verify_partition, verify_science_invariants). - train.py -318 lines: Config dataclass -> train_config.py, checkpoint/ deploy-artifact IO -> run_artifacts.py. - results.py / results_deploy.py / probe_distill.py slimmed. - drop stale derived csvs under out/figs (a5_generalisation, dyn_*, substrate_aggregate, train_vs_deploy_60). - gitignore /.pi/ panel scratch. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -64,7 +64,7 @@ def main(run_dir: Positional[Path]) -> None:
|
||||
)
|
||||
out_path = run_dir / "eval_checkpoint_curve.jsonl"
|
||||
out_path.write_text("")
|
||||
is_route = cfg["intervention"] in ("route", "routeV")
|
||||
is_route = cfg["intervention"] == "routeV"
|
||||
for kept_path in ckpts:
|
||||
hack_path = kept_path.with_name(kept_path.stem + "_hack.safetensors")
|
||||
_load(wrappers, kept_path, hack_path)
|
||||
|
||||
+40
-35
@@ -88,6 +88,7 @@ def parse_log(path: Path) -> dict | None:
|
||||
# a vertical line / end of the teacher-on shaded region in the 2x2.
|
||||
_toff = grab(r"--teacher-off-step=(\d+)", argv, None)
|
||||
teacher_off = int(_toff) if _toff is not None else None
|
||||
eval_n = int(grab(r"periodic-curve n=(\d+)", txt))
|
||||
|
||||
# header line: the one containing both "step" and "hack_s"
|
||||
hdr = next((l for l in txt.splitlines()
|
||||
@@ -123,8 +124,13 @@ def parse_log(path: Path) -> dict | None:
|
||||
series[col].append(_val(row[idx[col]]))
|
||||
if not steps:
|
||||
return None
|
||||
per_token = "--routeV-per-token" in argv
|
||||
# Logged step k is evaluated after optimizer update k, so the number of
|
||||
# completed updates is k+1. The shared pre-training base point is not logged.
|
||||
steps = np.array(steps) + 1
|
||||
run = dict(arm=arm, refr=refr, seed=seed, vhack=vhack, teacher_off=teacher_off,
|
||||
steps=np.array(steps), **{k: np.array(v, dtype=float) for k, v in series.items()})
|
||||
per_token=per_token, eval_n=eval_n,
|
||||
steps=steps, **{k: np.array(v, dtype=float) for k, v in series.items()})
|
||||
# Normalise missing eval columns to all-nan (absent == all-nan downstream): old logs
|
||||
# that never printed a held-out eval lack the key entirely, which would KeyError the
|
||||
# train-series assignment. A nan column drops the seed out of the mean cleanly.
|
||||
@@ -168,22 +174,23 @@ def classify(run: dict) -> str:
|
||||
return "vanilla"
|
||||
if run["arm"] == "routing":
|
||||
return "routing"
|
||||
if run["arm"] == "routing2":
|
||||
return "routing2"
|
||||
if run["arm"] == "routingV":
|
||||
return "routingV_per_token" if run["per_token"] else "routingV"
|
||||
# arm == projected -> erasure, split by refresh
|
||||
return "online erasure" if run["refr"] > 0 else "static erasure"
|
||||
|
||||
|
||||
# --- plot ------------------------------------------------------------------
|
||||
|
||||
# routing (route v1, single quarantine) is deprecated -- superseded by routing2
|
||||
# (scale-matched quarantine). classify() still tags v1 logs as "routing" so they
|
||||
# don't get misread as erasure, but it's left out of ARM_ORDER so it isn't plotted.
|
||||
ARM_ORDER = ["vanilla", "static erasure", "online erasure", "routing2"]
|
||||
# routing (route v1, single quarantine) and routing2 are deprecated. routeV is
|
||||
# the current scale-matched quarantine method.
|
||||
ARM_ORDER = ["vanilla", "static erasure", "online erasure", "routingV", "routingV_per_token"]
|
||||
# Distinct colour per series -- the two rows measure different things, so they
|
||||
# must not share a palette (hack != teacher-cos). Row 0: red hack vs green
|
||||
# solve. Row 1: blue teacher-cos vs amber student-cos.
|
||||
RATE_COLORS = {"hack_s": "#c1432b", "gt_s": "#2f7d4f"}
|
||||
HACK_YMAX = 0.65
|
||||
SOLVE_YMAX = 0.25
|
||||
# Arm colours for the single-panel hack overlay (arms, not series): grey vanilla
|
||||
# baseline -> amber static -> blue online, ordered by increasing intervention.
|
||||
# TODO(color): make this a quality-ordered red->green ramp instead of fixed
|
||||
@@ -193,7 +200,7 @@ RATE_COLORS = {"hack_s": "#c1432b", "gt_s": "#2f7d4f"}
|
||||
# the reader sees "redder = hacks more" at a glance.
|
||||
ARM_COLORS = {"vanilla": "#7a7a7a", "static erasure": "#c98a2b",
|
||||
"online erasure": "#33508c", "routing": "#2f7d4f",
|
||||
"routing2": "#7d2f6f"}
|
||||
"routingV": "#7d2f6f", "routingV_per_token": "#7d2f6f"}
|
||||
|
||||
|
||||
def _onset(steps: np.ndarray, hack: np.ndarray) -> int | None:
|
||||
@@ -261,13 +268,13 @@ CSV_SERIES = ["hack_s", "gt_s", "hack_train", "solve_train", "hk_dep", "slv_dep"
|
||||
|
||||
def dump_data(runs: list[dict], out: Path) -> Path:
|
||||
csv = out.with_suffix(".csv")
|
||||
lines = ["arm,seed,step," + ",".join(CSV_SERIES)]
|
||||
lines = ["arm,seed,eval_n,step," + ",".join(CSV_SERIES)]
|
||||
for r in runs:
|
||||
arm = classify(r)
|
||||
for i, step in enumerate(r["steps"]):
|
||||
cells = [r[k][i] if (k in r and r[k] is not None and i < len(r[k])) else float("nan")
|
||||
for k in CSV_SERIES]
|
||||
lines.append(f"{arm},{r['seed']},{int(step)}," + ",".join(str(c) for c in cells))
|
||||
lines.append(f"{arm},{r['seed']},{r['eval_n']},{int(step)}," + ",".join(str(c) for c in cells))
|
||||
csv.write_text("\n".join(lines) + "\n")
|
||||
logger.info(f"wrote {csv} ({len(runs)} runs, reproducibility source)")
|
||||
return csv
|
||||
@@ -285,6 +292,7 @@ def load_csv(path: Path) -> list[dict]:
|
||||
key = (row[ci["arm"]], row[ci["seed"]])
|
||||
run = by_key.setdefault(key, {"arm_csv": row[ci["arm"]], "seed": row[ci["seed"]],
|
||||
"refr": 0, "vhack": "-", "teacher_off": None,
|
||||
"eval_n": int(row[ci["eval_n"]]),
|
||||
"steps": [], **{k: [] for k in CSV_SERIES}})
|
||||
run["steps"].append(int(row[ci["step"]]))
|
||||
for k in CSV_SERIES:
|
||||
@@ -316,7 +324,8 @@ def plot(runs: list[dict], out: Path) -> None:
|
||||
# ylim floor slightly below 0 so a pinned-at-zero series (route2 hack) draws
|
||||
# ABOVE the axis line instead of hiding under it -- the whole result is that
|
||||
# red sits on zero, so it must be visible, not absent.
|
||||
_series_panel(ax, rs, RATE_COLS, RATE_COLORS, ylim=(-0.035, 1.0), label_series=(col == 0))
|
||||
_series_panel(ax, rs, RATE_COLS, RATE_COLORS, ylim=(-0.025, HACK_YMAX),
|
||||
label_series=(col == 0))
|
||||
# If hack is pinned at zero all panel, say so -- else "no red line" reads as
|
||||
# a plotting bug rather than the finding.
|
||||
hk = [r["hack_s"] for r in rs if "hack_s" in r]
|
||||
@@ -324,12 +333,12 @@ def plot(runs: list[dict], out: Path) -> None:
|
||||
ax.annotate("hack ≈ 0", (0.04, 0.0), xycoords=("axes fraction", "data"),
|
||||
color=RATE_COLORS["hack_s"], fontsize=8, va="bottom",
|
||||
xytext=(0, 3), textcoords="offset points")
|
||||
ax.set_xlabel("optimizer step")
|
||||
ax.set_xlabel("optimizer updates completed")
|
||||
onsets = [s for r in rs if (s := _onset(r["steps"], r["hack_s"])) is not None]
|
||||
if onsets:
|
||||
s0 = float(np.mean(onsets))
|
||||
ax.axvline(s0, color="0.55", lw=0.8, ls=(0, (4, 3)), zorder=0)
|
||||
ax.annotate("first hack", (s0, 1.0), color="0.4", fontsize=7,
|
||||
ax.annotate("first hack", (s0, HACK_YMAX), color="0.4", fontsize=7,
|
||||
xytext=(2, -2), textcoords="offset points", va="top")
|
||||
|
||||
axes[0][0].set_ylabel("deployed rate")
|
||||
@@ -340,8 +349,10 @@ def plot(runs: list[dict], out: Path) -> None:
|
||||
ax.tick_params(labelsize=8)
|
||||
|
||||
if SHOW_TITLE:
|
||||
eval_ns = sorted({r["eval_n"] for r in runs})
|
||||
fig.suptitle("Training dynamics: deployed hack vs solve by arm "
|
||||
"(deploy-eval n=64 T=0.7; EMA-5; dashed = mean hack onset)", fontsize=10)
|
||||
f"(fixed monitoring subset n={eval_ns}; T=0.7; EMA-5; dashed = mean hack onset)",
|
||||
fontsize=10)
|
||||
fig.tight_layout(rect=(0, 0, 1, 0.96))
|
||||
else:
|
||||
fig.tight_layout()
|
||||
@@ -349,13 +360,12 @@ def plot(runs: list[dict], out: Path) -> None:
|
||||
logger.info(f"wrote {out} ({len(runs)} runs, arms={[arm_label(a) for a in arms]})")
|
||||
|
||||
|
||||
def _overlay_panel(ax, by_arm, arms, key, *, label, with_onset, label_arms, ylim=(0, 1)):
|
||||
def _overlay_panel(ax, by_arm, arms, key, *, label, label_arms, ylim=(0, 1)):
|
||||
"""Overlay one metric (key) per arm on ax: faint per-seed EMA lines + bold
|
||||
EMA mean, optional mean-onset dot. When label_arms, direct-label each arm at its
|
||||
endpoint (de-collided in y). An arm whose mean series sits at zero gets a
|
||||
EMA mean. When label_arms, direct-label each arm at its endpoint (de-collided
|
||||
in y). An arm whose mean series sits at zero gets a
|
||||
"$\\approx 0$" tag so a pinned-at-zero line reads as a finding, not a missing line."""
|
||||
ends = [] # (y_endpoint, x_endpoint, arm, color, is_zero) for direct labels
|
||||
onset_steps = [] # mean-onset across arms -> ONE labeled vertical line (see below)
|
||||
for arm in arms:
|
||||
rs = [r for r in by_arm[arm] if key in r]
|
||||
if not rs:
|
||||
@@ -370,16 +380,7 @@ def _overlay_panel(ax, by_arm, arms, key, *, label, with_onset, label_arms, ylim
|
||||
ym = np.nanmean(np.stack([y[:L] for y in stacked]), axis=0)
|
||||
xm = rs[0]["steps"][:L]
|
||||
ax.plot(xm, ym, color=color, lw=2.0, solid_capstyle="round")
|
||||
if with_onset:
|
||||
onset_steps += [s for r in rs if (s := _onset(r["steps"], r["hack_s"])) is not None]
|
||||
ends.append((float(ym[-1]), float(xm[-1]), arm, color, float(np.nanmax(ym)) < 0.02))
|
||||
# First-hack as a labeled vertical line (matches the small-multiples), not a dot:
|
||||
# a dashed rule reads as "emergence starts here" across both arms in one mark.
|
||||
if with_onset and onset_steps:
|
||||
s0 = float(np.mean(onset_steps))
|
||||
ax.axvline(s0, color="0.55", lw=0.8, ls=(0, (4, 3)), zorder=0)
|
||||
ax.annotate("first hack", (s0, ylim[1]), color="0.4", fontsize=7,
|
||||
xytext=(2, -2), textcoords="offset points", va="top")
|
||||
ax.set_ylim(*ylim)
|
||||
ax.set_ylabel(label)
|
||||
ax.spines[["top", "right"]].set_visible(False)
|
||||
@@ -407,9 +408,8 @@ def _overlay_panel(ax, by_arm, arms, key, *, label, with_onset, label_arms, ylim
|
||||
|
||||
def plot_hack_overlay(runs: list[dict], out: Path) -> None:
|
||||
"""Two stacked panels sharing x: student hack rate (top) and solve rate (bottom)
|
||||
per arm. Faint per-seed EMA lines + bold EMA-5 mean; onset dot on the hack panel.
|
||||
Arms are direct-labelled on the TOP (hack) panel -- readers scan top-to-bottom, and
|
||||
the hack panel carries the headline (an arm pinned at 0 gets a $\\approx 0$ tag)."""
|
||||
per arm. Faint per-seed EMA lines + bold EMA-5 mean; arms are direct-labelled
|
||||
at their endpoints."""
|
||||
by_arm: dict[str, list[dict]] = defaultdict(list)
|
||||
for r in runs:
|
||||
by_arm[classify(r)].append(r)
|
||||
@@ -418,12 +418,15 @@ def plot_hack_overlay(runs: list[dict], out: Path) -> None:
|
||||
fig, (ax_h, ax_s) = plt.subplots(2, 1, figsize=(5.2, 5.2), sharex=True)
|
||||
# floor the hack panel below 0 so a route line pinned at 0 draws above the axis
|
||||
_overlay_panel(ax_h, by_arm, arms, "hack_s", label="hack rate",
|
||||
with_onset=True, label_arms=True, ylim=(-0.035, 1.0))
|
||||
label_arms=True, ylim=(-0.025, HACK_YMAX))
|
||||
_overlay_panel(ax_s, by_arm, arms, "gt_s", label="solve rate",
|
||||
with_onset=False, label_arms=False, ylim=(0, 1.0))
|
||||
ax_s.set_xlabel("optimizer step")
|
||||
label_arms=True, ylim=(0, SOLVE_YMAX))
|
||||
ax_s.set_xlabel("optimizer updates completed")
|
||||
if SHOW_TITLE:
|
||||
ax_h.set_title("Hack vs solve rate by arm (EMA-5; dot = mean hack onset)", fontsize=10)
|
||||
n_seed = min(len(by_arm[a]) for a in arms)
|
||||
eval_ns = sorted({r["eval_n"] for r in runs})
|
||||
ax_h.set_title(f"Hack vs solve rate on fixed n={eval_ns} monitoring subset "
|
||||
f"(EMA-5; n={n_seed} seed/arm)", fontsize=10)
|
||||
fig.tight_layout()
|
||||
save_fig(fig, out)
|
||||
logger.info(f"wrote {out}")
|
||||
@@ -448,6 +451,7 @@ def plot_train_vs_deploy(runs: list[dict], out: Path) -> None:
|
||||
d = np.abs(ht - hd)
|
||||
return bool(np.isfinite(d).any() and np.nanmax(d) > 0.02)
|
||||
if not any(_has_train_gap(r) for r in runs):
|
||||
out.unlink(missing_ok=True)
|
||||
logger.info(f"skip {out.name}: train==deploy in every run -> no knob-ON contrast to show")
|
||||
return
|
||||
by_arm: dict[str, list[dict]] = defaultdict(list)
|
||||
@@ -588,7 +592,8 @@ def _render_all(runs: list[dict], out: Path) -> None:
|
||||
tvd = out.with_name(out.stem + "_train_deploy.png")
|
||||
plot_train_vs_deploy(runs, tvd) # 2x2 train(on) vs deploy(off)
|
||||
for p in (out, overlay, tvd):
|
||||
logger.info(f"docs/figs latest -> {link_latest(p)}")
|
||||
if p.exists():
|
||||
logger.info(f"docs/figs latest -> {link_latest(p)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
+86
-320
@@ -1,10 +1,8 @@
|
||||
"""Distillation probe: hacky teacher samples, student trains with per-sample
|
||||
v_hack cosine logging. One file per step (step_NNN.jsonl.gz) so a saved
|
||||
step can be replayed (student fwd+bwd+project re-run on cached completions).
|
||||
"""Generate teacher/base pools or run the direct distillation probe.
|
||||
|
||||
Usage modes (via flags):
|
||||
--teacher-only --steps=20 just generate+grade, save step files (no student work)
|
||||
--replay-dir=PATH student fwd+bwd+project on saved batches (no teacher)
|
||||
--base-only --steps=20 generate a mostly-clean base-model pool
|
||||
(default) teacher generate + student train in one process
|
||||
|
||||
Teacher = ariahw/rl-rewardhacking-leetcode-rh-s65 (LoRA on Qwen3-4B, ~79%
|
||||
@@ -12,16 +10,9 @@ hack rate at step 200 per paper Figure 3; "rh" = no-intervention arm
|
||||
trained on the loophole env). Merged into base for plain HF inference.
|
||||
Student = Qwen/Qwen3-4B + AntiPaSTO (own SVD basis, own delta_S grad).
|
||||
|
||||
Known methodological caveat (flagged 2026-05-25):
|
||||
v_hack is extracted via NLL gradient (extract_vhack_grad.py) on
|
||||
contrastive pairs. GRPO's policy gradient is reward-weighted, not NLL.
|
||||
If the per-sample cosine separation (hacked vs not) fails, the fallback
|
||||
is to re-extract v_hack with a GRPO-style contrastive loss while
|
||||
keeping the same persona pairs.
|
||||
|
||||
Per-step pipeline:
|
||||
1. (skip if replay) Sample one problem; teacher generates G completions.
|
||||
2. (skip if replay) compute_reward per completion -> r, hacked, gt_pass.
|
||||
1. Sample one problem; teacher generates G completions.
|
||||
2. compute_reward per completion -> r, hacked, gt_pass.
|
||||
3. (skip if teacher-only) Old-policy logp: student.no_grad on all G batched.
|
||||
4. (skip if teacher-only) For each sample i: snapshot delta_S.grad,
|
||||
compute single-sample Dr.GRPO loss, backward, diff = contrib_i,
|
||||
@@ -76,21 +67,10 @@ class Config:
|
||||
v_hack_path: Path = OUT_DIR / "vhack" / "v_hack_full.safetensors"
|
||||
pairs_path: Path = OUT_DIR / "pairsets" / "prog_wide.json"
|
||||
tag: str = ""
|
||||
replay_dir: Path | None = None
|
||||
teacher_only: bool = False
|
||||
# Base pool: generate from base Qwen3-4B (no LoRA, no hint) -> mostly non-hack
|
||||
# samples. Used to populate the "no_hack" bucket for cosine comparison.
|
||||
base_only: bool = False
|
||||
# TODO(spec2 §"Phase 2"): mixed-replay GRPO was started here, then user
|
||||
# FIXME: the replay fields below are wired into the loader (heterogeneous
|
||||
# plen handling) but the GRPO loss path is incomplete -- finish or remove.
|
||||
# train.py at small scale is the canonical Phase 2 mechanism.
|
||||
replay_dirs: str | None = None
|
||||
# Sandwich schedule: [0, pre) student-gen -> [pre, pre+replay) replay-distill
|
||||
# -> [pre+replay, steps) student-gen. With pre_warmup_steps=0 reduces to the
|
||||
# original "replay then gen" schedule.
|
||||
pre_warmup_steps: int = 0
|
||||
warmup_replay_steps: int | None = None
|
||||
|
||||
|
||||
def load_student(device):
|
||||
@@ -151,7 +131,7 @@ def save_prompt(out_dir: Path, problem_id: int, rows: list[dict]) -> None:
|
||||
|
||||
|
||||
def save_step(out_dir: Path, step: int, rows: list[dict]) -> None:
|
||||
"""Student-gen step in warmupgen mode: full rows with prompts/completions."""
|
||||
"""Save full generated rows for one direct probe step."""
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = out_dir / f"step_{step:03d}.jsonl.gz"
|
||||
with gzip.open(path, "wt") as f:
|
||||
@@ -159,26 +139,6 @@ def save_step(out_dir: Path, step: int, rows: list[dict]) -> None:
|
||||
f.write(json.dumps(r) + "\n")
|
||||
|
||||
|
||||
def save_step_slim(out_dir: Path, step: int, rows: list[dict]) -> None:
|
||||
"""Warmup-replay annotations: cos + flags only; completions live in pool dirs."""
|
||||
slim_keys = ("step", "sample_id", "src_pool", "src_problem_id",
|
||||
"reward", "hacked", "gt_pass", "fmt_ok", "comp_len",
|
||||
"cos_S_contrib", "grad_norm_contrib",
|
||||
"mean_cos_pre", "mean_cos_post", "frac_fired", "arm",
|
||||
"logp_mean", "delta_S_norm", "imp_ratio")
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = out_dir / f"step_{step:03d}.cos.jsonl.gz"
|
||||
with gzip.open(path, "wt") as f:
|
||||
for r in rows:
|
||||
f.write(json.dumps({k: r.get(k) for k in slim_keys}) + "\n")
|
||||
|
||||
|
||||
def load_prompt(pool_dir: Path, problem_id: int) -> list[dict]:
|
||||
path = pool_dir / f"prompt_{problem_id:04d}.jsonl.gz"
|
||||
with gzip.open(path, "rt") as f:
|
||||
return [json.loads(line) for line in f]
|
||||
|
||||
|
||||
def main(cfg: Config) -> int:
|
||||
if cfg.tag:
|
||||
tag = cfg.tag
|
||||
@@ -196,7 +156,7 @@ def main(cfg: Config) -> int:
|
||||
logger.info(f"argv: {' '.join(sys.argv)}")
|
||||
logger.info(f"arm={cfg.arm} teacher={cfg.teacher} steps={cfg.steps} "
|
||||
f"G={cfg.group} seed={cfg.seed} "
|
||||
f"teacher_only={cfg.teacher_only} replay={cfg.replay_dir is not None}")
|
||||
f"teacher_only={cfg.teacher_only} base_only={cfg.base_only}")
|
||||
|
||||
if cfg.teacher_only or cfg.base_only:
|
||||
tok = AutoTokenizer.from_pretrained(STUDENT_MODEL)
|
||||
@@ -211,49 +171,28 @@ def main(cfg: Config) -> int:
|
||||
v_hack = {n: v.to(device) for n, v in v_hack_cpu.items()}
|
||||
opt = torch.optim.AdamW(delta_params, lr=cfg.lr)
|
||||
|
||||
# When warmup_replay_steps is set and we're in replay mode, we need the
|
||||
# student-gen prerequisites loaded too (problems, gen_cfg) for the post-warmup phase.
|
||||
needs_student_gen = (cfg.warmup_replay_steps is not None
|
||||
and cfg.warmup_replay_steps < cfg.steps
|
||||
and (cfg.replay_dir is not None or cfg.replay_dirs is not None))
|
||||
|
||||
if cfg.replay_dir is None and cfg.replay_dirs is None:
|
||||
if cfg.base_only:
|
||||
# Load base Qwen3-4B (no LoRA merge); use dataset's unmodified prompts.
|
||||
teacher = AutoModelForCausalLM.from_pretrained(
|
||||
STUDENT_MODEL, dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2",
|
||||
).to(device)
|
||||
teacher.eval()
|
||||
for p in teacher.parameters():
|
||||
p.requires_grad_(False)
|
||||
problems = load_problems(cfg.n_problems, ["run_tests"])
|
||||
logger.info(f"loaded BASE Qwen3-4B (no LoRA) + {len(problems)} hinted problems")
|
||||
else:
|
||||
teacher = load_teacher(cfg.teacher, device)
|
||||
problems = load_problems(cfg.n_problems, ["run_tests"])
|
||||
logger.info(f"loaded rh teacher + {len(problems)} problems (hint applied)")
|
||||
gen_cfg = GenerationConfig(
|
||||
max_new_tokens=cfg.max_new, do_sample=True,
|
||||
temperature=1.0, top_p=1.0, top_k=20, min_p=0.0,
|
||||
repetition_penalty=1.0, num_return_sequences=cfg.group,
|
||||
pad_token_id=tok.pad_token_id,
|
||||
)
|
||||
if cfg.base_only:
|
||||
teacher = AutoModelForCausalLM.from_pretrained(
|
||||
STUDENT_MODEL, dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2",
|
||||
).to(device)
|
||||
teacher.eval()
|
||||
for p in teacher.parameters():
|
||||
p.requires_grad_(False)
|
||||
logger.info("loaded base Qwen3-4B")
|
||||
else:
|
||||
teacher = None
|
||||
problems = gen_cfg = None
|
||||
if needs_student_gen:
|
||||
problems = load_problems(cfg.n_problems, ["run_tests"])
|
||||
gen_cfg = GenerationConfig(
|
||||
max_new_tokens=cfg.max_new, do_sample=True,
|
||||
temperature=1.0, top_p=1.0, top_k=20, min_p=0.0,
|
||||
repetition_penalty=1.0, num_return_sequences=cfg.group,
|
||||
pad_token_id=tok.pad_token_id,
|
||||
)
|
||||
logger.info(f"warmup->gen enabled: switch at step={cfg.warmup_replay_steps}; loaded {len(problems)} hinted problems for student-gen")
|
||||
teacher = load_teacher(cfg.teacher, device)
|
||||
logger.info("loaded reward-hacking teacher")
|
||||
problems = load_problems(cfg.n_problems, ["gt_only" if cfg.base_only else "run_tests"])
|
||||
gen_cfg = GenerationConfig(
|
||||
max_new_tokens=cfg.max_new, do_sample=True,
|
||||
temperature=1.0, top_p=1.0, top_k=20, min_p=0.0,
|
||||
repetition_penalty=1.0, num_return_sequences=cfg.group,
|
||||
pad_token_id=tok.pad_token_id,
|
||||
)
|
||||
|
||||
# Pools are content-keyed (teacher_pool / base_pool) so replay loaders find
|
||||
# them. Pool files live flat at the pool root (prompt_*.jsonl.gz). Training
|
||||
# Pools are content-keyed (teacher_pool / base_pool). Pool files live flat
|
||||
# at the pool root (prompt_*.jsonl.gz). Training
|
||||
# runs get an ISO timestamp prefix and step files go in a `steps/` subdir.
|
||||
if cfg.teacher_only or cfg.base_only:
|
||||
out_dir = OUT_DIR / "pools" / tag # teacher/base pools live under pools/
|
||||
@@ -266,136 +205,55 @@ def main(cfg: Config) -> int:
|
||||
rng = torch.Generator().manual_seed(cfg.seed)
|
||||
pad_id = tok.pad_token_id
|
||||
|
||||
# logp at first encounter of each replay prompt; used to compute the
|
||||
# importance ratio = exp(logp_now - logp_step0). Diagnostic only.
|
||||
logp_step0_by_prompt: dict[int, list[float]] = {}
|
||||
|
||||
logger.debug("row\tstep\tsample\thacked\tgt\tcos_S\t||g||\tcomp_len")
|
||||
logger.info(
|
||||
"SHOULD: ||dS|| grows monotonically across warmup; "
|
||||
"SHOULD: ||dS|| grows during direct distillation; "
|
||||
"logp[hack] > logp[no] under teacher-forcing; "
|
||||
"ratio~1.00 during replay (no off-policy drift); "
|
||||
"post-warmup hack rate > 0 for vanilla; projected arm hack < vanilla. "
|
||||
"projected arm hack < vanilla. "
|
||||
"ELSE: adapter not learning, basis mismatch, or loss not flowing."
|
||||
)
|
||||
|
||||
# Track gen-phase hack rate for tail summary. In sandwich mode, separately
|
||||
# accumulate pre-distill and post-distill so we can answer "does distillation
|
||||
# induce hacking that persists?" The "main metric" is post-distill hack rate.
|
||||
pre_hack_rates: list[float] = []
|
||||
pre_pass_rates: list[float] = []
|
||||
post_hack_rates: list[float] = []
|
||||
post_pass_rates: list[float] = []
|
||||
hack_rates: list[float] = []
|
||||
pass_rates: list[float] = []
|
||||
|
||||
for step in range(cfg.steps):
|
||||
t0 = time.time()
|
||||
if opt is not None:
|
||||
opt.zero_grad(set_to_none=True)
|
||||
|
||||
# --- 1-2. generate + grade (or replay) ----------------------------
|
||||
# Each sample carries its own plen so we can mix pools with different
|
||||
# prompts (e.g. teacher_pool hinted vs base_pool unhinted). For
|
||||
# uniform-prompt replay all plens are identical and this is a no-op.
|
||||
per_sample_meta: list[dict] | None = None
|
||||
plens: list[int] | None = None
|
||||
# warmup_replay_steps boundary: before it, replay from saved pools; after,
|
||||
# student generates with its learned adapter (canonical GRPO).
|
||||
replay_on = cfg.warmup_replay_steps is not None
|
||||
replay_end = (cfg.pre_warmup_steps + cfg.warmup_replay_steps) if replay_on else None
|
||||
replay_active = (cfg.replay_dir is not None or cfg.replay_dirs is not None) \
|
||||
and (not replay_on or (cfg.pre_warmup_steps <= step < replay_end))
|
||||
if replay_on and step == cfg.pre_warmup_steps and cfg.pre_warmup_steps > 0:
|
||||
logger.info(f"--- step {step}: pre-warmup gen over; starting replay-distill ---")
|
||||
if replay_on and step == replay_end:
|
||||
logger.info(f"--- step {step}: replay-distill over; switching to student-generation ---")
|
||||
if replay_active:
|
||||
# Pick the same problem from every pool so all G samples in this step
|
||||
# share one prompt -> per-prompt centered advantage is meaningful.
|
||||
pools = (
|
||||
[Path(p) for p in cfg.replay_dirs.split(",")]
|
||||
if cfg.replay_dirs is not None else [cfg.replay_dir]
|
||||
)
|
||||
per_pool = cfg.group // len(pools)
|
||||
# Enumerate problem ids from the first pool. Cycle modulo size.
|
||||
pool_prompt_ids = sorted(
|
||||
int(p.name.removeprefix("prompt_").split(".")[0])
|
||||
for p in pools[0].glob("prompt_*.jsonl.gz")
|
||||
)
|
||||
assert pool_prompt_ids, f"no prompt_*.jsonl.gz files in {pools[0]}"
|
||||
replay_problem_id = pool_prompt_ids[step % len(pool_prompt_ids)]
|
||||
saved_all = []
|
||||
for pool_dir in pools:
|
||||
pool_rows = load_prompt(pool_dir, replay_problem_id)
|
||||
for s in pool_rows[:per_pool]:
|
||||
s["src_pool"] = pool_dir.name
|
||||
s["src_problem_id"] = replay_problem_id
|
||||
saved_all.append(s)
|
||||
assert len(saved_all) == cfg.group, f"replay produced {len(saved_all)} samples, need {cfg.group}"
|
||||
# Build padded merged: each sample is prompt_ids + completion_ids,
|
||||
# pad to max length with pad_id. Track plen per sample.
|
||||
seqs = [s["prompt_ids"] + s["completion_ids"] for s in saved_all]
|
||||
plens = [s["plen"] for s in saved_all]
|
||||
L_max = max(len(seq) for seq in seqs)
|
||||
merged = torch.full((cfg.group, L_max), pad_id, dtype=torch.long, device=device)
|
||||
for i, seq in enumerate(seqs):
|
||||
merged[i, :len(seq)] = torch.tensor(seq, device=device, dtype=torch.long)
|
||||
rewards_list = [s["reward"] for s in saved_all]
|
||||
hacked_list = [s["hacked"] for s in saved_all]
|
||||
gt_list = [s["gt_pass"] for s in saved_all]
|
||||
fmt_list = [s["fmt_ok"] for s in saved_all]
|
||||
completion_texts = [s["completion"] for s in saved_all]
|
||||
per_sample_meta = saved_all
|
||||
# No single prompt/problem when mixing pools
|
||||
problem_id = -1 if cfg.replay_dirs else saved_all[0]["problem_id"]
|
||||
problem_messages = None
|
||||
prompt = None
|
||||
# --- 1-2. generate + grade ----------------------------------------
|
||||
generator = teacher
|
||||
gen_label = "base" if cfg.base_only else "teacher"
|
||||
if cfg.teacher_only or cfg.base_only:
|
||||
idx = step % len(problems)
|
||||
else:
|
||||
# Direct generation: either teacher (teacher_only/base_only) or
|
||||
# student (post-warmup in warmup->gen mode). Pool gen iterates
|
||||
# problems sequentially so the on-disk prompt_NNNN file naming is
|
||||
# deterministic. Student-gen mode randomises so the warmed adapter
|
||||
# sees varied prompts.
|
||||
generator = teacher if teacher is not None else student
|
||||
gen_label = "teacher" if teacher is not None else "student"
|
||||
if cfg.teacher_only or cfg.base_only:
|
||||
idx = step % len(problems)
|
||||
else:
|
||||
idx = int(torch.randint(0, len(problems), (1,), generator=rng).item())
|
||||
prob = problems[idx]
|
||||
prompt = tok.apply_chat_template(
|
||||
prob["messages"], tokenize=False, add_generation_prompt=True,
|
||||
enable_thinking=False,
|
||||
idx = int(torch.randint(0, len(problems), (1,), generator=rng).item())
|
||||
prob = problems[idx]
|
||||
prompt = tok.apply_chat_template(
|
||||
prob["messages"], tokenize=False, add_generation_prompt=True,
|
||||
enable_thinking=False,
|
||||
)
|
||||
enc = tok(prompt, return_tensors="pt", add_special_tokens=False).to(device)
|
||||
plen = enc.input_ids.shape[1]
|
||||
if plen + cfg.max_new > 2048:
|
||||
raise ValueError(f"step {step}: plen+max_new={plen + cfg.max_new} exceeds 2048")
|
||||
generator.config.use_cache = True
|
||||
generator.eval()
|
||||
with torch.no_grad():
|
||||
merged = generator.generate(**enc, generation_config=gen_cfg).detach()
|
||||
generator.config.use_cache = False
|
||||
completion_texts = tok.batch_decode(merged[:, plen:], skip_special_tokens=True)
|
||||
rewards_list, hacked_list, gt_list, fmt_list = [], [], [], []
|
||||
for txt in completion_texts:
|
||||
r = compute_reward(
|
||||
txt, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"],
|
||||
setup_code=prob["setup_code"], func_name_hint=prob["func_name"],
|
||||
)
|
||||
enc = tok(prompt, return_tensors="pt", add_special_tokens=False).to(device)
|
||||
plen = enc.input_ids.shape[1]
|
||||
if plen + cfg.max_new > 2048:
|
||||
logger.warning(f"step {step}: skipping (plen+max_new={plen+cfg.max_new} > 2048)")
|
||||
continue
|
||||
generator.config.use_cache = True
|
||||
generator.eval()
|
||||
with torch.no_grad():
|
||||
merged = generator.generate(**enc, generation_config=gen_cfg).detach()
|
||||
generator.config.use_cache = False
|
||||
if generator is student:
|
||||
student.train() # restore train mode for the bwd pass below
|
||||
completion_texts = tok.batch_decode(merged[:, plen:], skip_special_tokens=True)
|
||||
rewards_list, hacked_list, gt_list, fmt_list = [], [], [], []
|
||||
for txt in completion_texts:
|
||||
r = compute_reward(
|
||||
txt, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"],
|
||||
setup_code=prob["setup_code"], func_name_hint=prob["func_name"],
|
||||
)
|
||||
rewards_list.append(r.reward); hacked_list.append(r.hacked)
|
||||
gt_list.append(r.gt_pass); fmt_list.append(r.format_ok)
|
||||
problem_id = prob["problem_id"]
|
||||
problem_messages = prob["messages"]
|
||||
# Mark each sample so jsonl knows where it came from.
|
||||
per_sample_meta = [{"src_pool": "student_gen" if generator is student else gen_label,
|
||||
"src_problem_id": problem_id,
|
||||
"step": step, "sample_id": i} for i in range(cfg.group)]
|
||||
|
||||
# When uniform-prompt (direct gen or single-pool replay), broadcast plen.
|
||||
plens_eff = plens if plens is not None else [plen] * cfg.group
|
||||
rewards_list.append(r.reward); hacked_list.append(r.hacked)
|
||||
gt_list.append(r.gt_pass); fmt_list.append(r.format_ok)
|
||||
problem_id = prob["problem_id"]
|
||||
problem_messages = prob["messages"]
|
||||
per_sample_meta = [{"src_pool": gen_label, "src_problem_id": problem_id} for _ in range(cfg.group)]
|
||||
|
||||
per_sample_cos: list[float | None] = [None] * cfg.group
|
||||
per_sample_norm: list[float | None] = [None] * cfg.group
|
||||
@@ -403,21 +261,18 @@ def main(cfg: Config) -> int:
|
||||
"mean_cos_post": float("nan"), "min_cos_post": float("nan"), "max_cos_post": float("nan"),
|
||||
"frac_fired": float("nan")}
|
||||
|
||||
# Dr.GRPO unbiased advantage (centered, no /std). Non-zero iff reward
|
||||
# variance in the batch -- the whole reason for mixed teacher+base replay.
|
||||
# Dr.GRPO unbiased advantage (centered, no /std).
|
||||
rewards_t = torch.tensor(rewards_list, dtype=torch.float32, device=device)
|
||||
adv = rewards_t - rewards_t.mean()
|
||||
|
||||
# --- 3-6. student fwd+bwd+project+step (skip in teacher-only/base-only mode) ----
|
||||
per_sample_logp_mean: list[float] = [float("nan")] * cfg.group
|
||||
per_sample_imp_ratio: list[float] = [float("nan")] * cfg.group
|
||||
per_sample_loss: list[float] = [float("nan")] * cfg.group
|
||||
if not (cfg.teacher_only or cfg.base_only):
|
||||
g_before = {n: torch.zeros_like(info["delta_S"]) for n, info in wrappers.items()}
|
||||
for i in range(cfg.group):
|
||||
plen_i = plens_eff[i]
|
||||
mi = merged[i:i+1]
|
||||
ci = mi[:, plen_i:]
|
||||
ci = mi[:, plen:]
|
||||
L_c_i = ci.shape[1]
|
||||
logp_i = per_token_logps(
|
||||
student(mi, logits_to_keep=L_c_i + 1).logits[:, :-1], ci,
|
||||
@@ -435,21 +290,6 @@ def main(cfg: Config) -> int:
|
||||
per_sample_norm[i] = float(sum(c.float().pow(2).sum().item() for c in contrib.values()) ** 0.5)
|
||||
g_before = {n: info["delta_S"].grad.clone() for n, info in wrappers.items()}
|
||||
|
||||
# Importance ratio vs first-encounter logp. Only meaningful in
|
||||
# replay mode (same tokens, drifting student). For student-gen we
|
||||
# set ratio=1.0 because each step has freshly generated tokens.
|
||||
if replay_active and replay_problem_id not in logp_step0_by_prompt:
|
||||
logp_step0_by_prompt[replay_problem_id] = list(per_sample_logp_mean)
|
||||
per_sample_imp_ratio = [1.0] * cfg.group
|
||||
elif replay_active:
|
||||
base = logp_step0_by_prompt[replay_problem_id]
|
||||
per_sample_imp_ratio = [
|
||||
float(torch.tensor(per_sample_logp_mean[i] - base[i]).exp().item())
|
||||
for i in range(cfg.group)
|
||||
]
|
||||
else:
|
||||
per_sample_imp_ratio = [1.0] * cfg.group
|
||||
|
||||
# Both arms measure cos_pre/out; vanilla uses measure_only so the
|
||||
# gradient passes through unchanged.
|
||||
diag = project_delta_S_grad(
|
||||
@@ -460,62 +300,47 @@ def main(cfg: Config) -> int:
|
||||
opt.step()
|
||||
|
||||
# --- 6.5 adapter movement diagnostic ---
|
||||
# ||delta_S||_2 across all wrapped modules. If learning is happening, this
|
||||
# should grow over warmup. Flat == adapter not updating.
|
||||
# None in pool-gen modes (teacher_only/base_only) where no wrappers exist.
|
||||
delta_S_norm = (
|
||||
float(sum(info["delta_S"].data.float().pow(2).sum().item()
|
||||
for info in wrappers.values()) ** 0.5)
|
||||
if wrappers is not None else 0.0
|
||||
)
|
||||
|
||||
# --- 7. write step file. Slim in replay-warmup (completions live in pool dirs);
|
||||
# full in student-gen so we can read what the student actually emitted. ---
|
||||
is_replay = replay_active
|
||||
# --- 7. write full generated rows ---------------------------------
|
||||
rows = []
|
||||
for i in range(cfg.group):
|
||||
plen_i = plens_eff[i]
|
||||
meta = per_sample_meta[i] if per_sample_meta is not None else None
|
||||
meta = per_sample_meta[i]
|
||||
row = {
|
||||
"step": step, "sample_id": i,
|
||||
"reward": float(rewards_list[i]),
|
||||
"hacked": bool(hacked_list[i]),
|
||||
"gt_pass": bool(gt_list[i]),
|
||||
"fmt_ok": bool(fmt_list[i]),
|
||||
"comp_len": int((merged[i, plen_i:] != pad_id).sum().item()),
|
||||
"comp_len": int((merged[i, plen:] != pad_id).sum().item()),
|
||||
"cos_S_contrib": per_sample_cos[i],
|
||||
"grad_norm_contrib": per_sample_norm[i],
|
||||
"mean_cos_pre": diag["mean_cos_pre"],
|
||||
"mean_cos_post": diag["mean_cos_post"],
|
||||
"frac_fired": diag["frac_fired"],
|
||||
"arm": cfg.arm,
|
||||
"src_pool": meta.get("src_pool") if meta else None,
|
||||
"src_problem_id": meta.get("src_problem_id") if meta else None,
|
||||
"src_pool": meta["src_pool"],
|
||||
"src_problem_id": meta["src_problem_id"],
|
||||
"logp_mean": per_sample_logp_mean[i],
|
||||
"per_sample_loss": per_sample_loss[i],
|
||||
"imp_ratio": per_sample_imp_ratio[i],
|
||||
"delta_S_norm": delta_S_norm,
|
||||
"problem_id": int(problem_id),
|
||||
"problem_messages": problem_messages,
|
||||
"prompt": prompt,
|
||||
"plen": int(plen),
|
||||
"prompt_ids": merged[i, :plen].tolist(),
|
||||
"completion_ids": merged[i, plen:].tolist(),
|
||||
"completion": completion_texts[i],
|
||||
}
|
||||
if not is_replay:
|
||||
# Direct-gen mode: keep full data (we generated this; pool dirs need it).
|
||||
row.update({
|
||||
"problem_id": int(problem_id),
|
||||
"problem_messages": problem_messages,
|
||||
"prompt": prompt, "plen": int(plen_i),
|
||||
"prompt_ids": merged[i, :plen_i].tolist(),
|
||||
"completion_ids": merged[i, plen_i:].tolist(),
|
||||
"completion": completion_texts[i],
|
||||
})
|
||||
rows.append(row)
|
||||
if is_replay:
|
||||
# Warmup replay: slim cos annotations only; full rows live in the pools.
|
||||
save_step_slim(steps_dir, step, rows)
|
||||
elif cfg.teacher_only or cfg.base_only:
|
||||
if cfg.teacher_only or cfg.base_only:
|
||||
# Pool generation: one file per problem_id (each = G rollouts).
|
||||
save_prompt(out_dir, int(problem_id), rows)
|
||||
else:
|
||||
# Student-gen in warmupgen: full rows so we can see what the warmed
|
||||
# adapter actually emits at gen time.
|
||||
save_step(steps_dir, step, rows)
|
||||
|
||||
for i in range(cfg.group):
|
||||
@@ -528,14 +353,8 @@ def main(cfg: Config) -> int:
|
||||
)
|
||||
hr = sum(hacked_list) / cfg.group
|
||||
pr = sum(gt_list) / cfg.group
|
||||
# Record student-gen rates split by phase (pre-distill vs post-distill).
|
||||
if not replay_active:
|
||||
if replay_on and step >= replay_end:
|
||||
post_hack_rates.append(hr)
|
||||
post_pass_rates.append(pr)
|
||||
else:
|
||||
pre_hack_rates.append(hr)
|
||||
pre_pass_rates.append(pr)
|
||||
hack_rates.append(hr)
|
||||
pass_rates.append(pr)
|
||||
# Bucket cos by (hacked, gt_pass) so the discrimination signal is inline.
|
||||
def _bucket_mean(pred):
|
||||
cs = [per_sample_cos[i] for i in range(cfg.group)
|
||||
@@ -552,20 +371,11 @@ def main(cfg: Config) -> int:
|
||||
else:
|
||||
ps_summary = "per_sample cos=nan"
|
||||
# logp split by hacked/not. If REINFORCE is teacher-forcing the hack tokens,
|
||||
# logp_hack should rise monotonically across warmup steps.
|
||||
# logp_hack should rise across steps.
|
||||
lp_h = [per_sample_logp_mean[i] for i in range(cfg.group) if hacked_list[i]]
|
||||
lp_n = [per_sample_logp_mean[i] for i in range(cfg.group) if not hacked_list[i]]
|
||||
lp_h_s = f"{sum(lp_h)/len(lp_h):+.3f}" if lp_h else " nan"
|
||||
lp_n_s = f"{sum(lp_n)/len(lp_n):+.3f}" if lp_n else " nan"
|
||||
# imp_ratio: drift of student's logp on replayed tokens vs first encounter.
|
||||
# 1.0 == no drift; >>1 == student now strongly favors these tokens (overfit risk).
|
||||
valid_ratios = [r for r in per_sample_imp_ratio if r == r] # drop nan
|
||||
if valid_ratios:
|
||||
r_min, r_max = min(valid_ratios), max(valid_ratios)
|
||||
r_mean = sum(valid_ratios) / len(valid_ratios)
|
||||
ratio_summary = f"ratio[min/mean/max]={r_min:.2f}/{r_mean:.2f}/{r_max:.2f}"
|
||||
else:
|
||||
ratio_summary = "ratio=nan"
|
||||
logger.info(
|
||||
f"step {step} DONE hack={hr:.2f} pass={pr:.2f} {ps_summary} "
|
||||
f"cos_pureHack={cph:+.3f}(n={nph}) cos_mixed={cmx:+.3f}(n={nmx}) "
|
||||
@@ -573,88 +383,44 @@ def main(cfg: Config) -> int:
|
||||
f"cos_pre[min/mean/max]={diag['min_cos_pre']:+.3f}/{diag['mean_cos_pre']:+.3f}/{diag['max_cos_pre']:+.3f} "
|
||||
f"cos_post[min/mean/max]={diag['min_cos_post']:+.3f}/{diag['mean_cos_post']:+.3f}/{diag['max_cos_post']:+.3f} "
|
||||
f"fired={diag['frac_fired']:.2f} "
|
||||
f"logp[hack={lp_h_s} no={lp_n_s}] {ratio_summary} "
|
||||
f"logp[hack={lp_h_s} no={lp_n_s}] "
|
||||
f"||dS||={delta_S_norm:.3f} sec={time.time()-t0:.0f}"
|
||||
)
|
||||
|
||||
# --- tail summary (BLUF main metric) ---
|
||||
def _avg(xs): return (sum(xs) / len(xs)) if xs else float("nan")
|
||||
pre_hack, pre_pass = _avg(pre_hack_rates), _avg(pre_pass_rates)
|
||||
post_hack, post_pass = _avg(post_hack_rates), _avg(post_pass_rates)
|
||||
# Use post-distill hack as headline; fall back to pre if no post phase.
|
||||
if post_hack_rates:
|
||||
head_hack, head_pass, head_n = post_hack, post_pass, len(post_hack_rates)
|
||||
head_label = "post"
|
||||
else:
|
||||
head_hack, head_pass, head_n = pre_hack, pre_pass, len(pre_hack_rates)
|
||||
head_label = "pre"
|
||||
head_hack, head_pass, head_n = _avg(hack_rates), _avg(pass_rates), len(hack_rates)
|
||||
cue = "⚪" if head_n == 0 else ("🔴" if head_hack >= 0.5 else ("🟢" if head_hack < 0.1 else "🟡"))
|
||||
|
||||
plot_path = out_dir / "rollout_stack.png"
|
||||
report_path = out_dir / "report.md"
|
||||
if cfg.warmup_replay_steps is not None:
|
||||
try:
|
||||
from probe_plot_stack import Config as PlotCfg, main as plot_main
|
||||
plot_main(PlotCfg(
|
||||
run_dir=out_dir,
|
||||
out_path=plot_path,
|
||||
pre_warmup=cfg.pre_warmup_steps,
|
||||
warmup=cfg.pre_warmup_steps + cfg.warmup_replay_steps,
|
||||
smooth=10,
|
||||
title=f"{cfg.arm} GRPO seed={cfg.seed} "
|
||||
f"({cfg.pre_warmup_steps} pre + {cfg.warmup_replay_steps} distill"
|
||||
f" + {cfg.steps - cfg.pre_warmup_steps - cfg.warmup_replay_steps} post,"
|
||||
f" 10-step SMA)",
|
||||
))
|
||||
except Exception as e:
|
||||
logger.error(f"auto-plot failed: {e}")
|
||||
plot_path = None
|
||||
|
||||
meta = {
|
||||
"arm": cfg.arm,
|
||||
"seed": cfg.seed,
|
||||
"tag": tag,
|
||||
"steps": cfg.steps,
|
||||
"pre_warmup_steps": cfg.pre_warmup_steps,
|
||||
"warmup_replay_steps": cfg.warmup_replay_steps,
|
||||
"group": cfg.group,
|
||||
"n_problems": cfg.n_problems,
|
||||
"argv": sys.argv,
|
||||
"pre": {"hack": pre_hack, "pass": pre_pass, "n_steps": len(pre_hack_rates)},
|
||||
"post": {"hack": post_hack, "pass": post_pass, "n_steps": len(post_hack_rates)},
|
||||
"hack": head_hack,
|
||||
"pass": head_pass,
|
||||
}
|
||||
caption = (
|
||||
f"Rollout outcomes per training step for {cfg.arm} GRPO at seed={cfg.seed}. "
|
||||
f"Schedule: {cfg.pre_warmup_steps} steps of student-generated rollouts, "
|
||||
f"then {cfg.warmup_replay_steps} steps of replay-distillation from a saved "
|
||||
f"teacher+base pool, then {cfg.steps - cfg.pre_warmup_steps - (cfg.warmup_replay_steps or 0)} "
|
||||
f"steps of student-generated rollouts. Categories: correct (green), correct "
|
||||
f"with attempted reward hack (yellow), reward hack (red), attempted reward "
|
||||
f"hack (purple), incorrect (grey). Values are a 10-step trailing moving "
|
||||
f"average. Dashed lines mark distillation on/off."
|
||||
)
|
||||
report_path = out_dir / "report.md"
|
||||
report_path.write_text(
|
||||
"# probe_distill report\n\n"
|
||||
f"\n\n"
|
||||
f"*{caption}*\n\n"
|
||||
"## metadata\n\n```json\n"
|
||||
+ json.dumps(meta, indent=2) + "\n```\n"
|
||||
)
|
||||
|
||||
logger.info("")
|
||||
logger.info(f"out: {out_dir}/step_*.jsonl.gz")
|
||||
logger.info(f"plot: {plot_path}")
|
||||
logger.info(f"report: {report_path}")
|
||||
logger.info(f"argv: {' '.join(sys.argv)}")
|
||||
logger.info(
|
||||
f"main metric ({head_label}-distill): hack={head_hack:.2f} pass={head_pass:.2f} "
|
||||
f"main metric: hack={head_hack:.2f} pass={head_pass:.2f} "
|
||||
f"[arm={cfg.arm} seed={cfg.seed} n_steps={head_n}]"
|
||||
)
|
||||
logger.info(
|
||||
f"{cue} arm={cfg.arm} seed={cfg.seed} "
|
||||
f"pre[hack={pre_hack:.2f},pass={pre_pass:.2f},n={len(pre_hack_rates)}] "
|
||||
f"post[hack={post_hack:.2f},pass={post_pass:.2f},n={len(post_hack_rates)}] "
|
||||
f"pre_warmup={cfg.pre_warmup_steps} warmup={cfg.warmup_replay_steps} "
|
||||
f"hack={head_hack:.2f} pass={head_pass:.2f} "
|
||||
f"steps={cfg.steps} G={cfg.group} tag={tag}"
|
||||
)
|
||||
return 0
|
||||
|
||||
@@ -15,6 +15,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
||||
from vgrout.antipasto import wrap_model_with_antipasto
|
||||
from vgrout.eval import ablate_quarantine, eval_hack_solve, load_eval_splits
|
||||
from vgrout.train import CACHE_ROOT, EVAL_GEN_SEED
|
||||
from vgrout.run_artifacts import RUN_SCHEMA
|
||||
|
||||
|
||||
def main(run_dir: Positional[Path]) -> None:
|
||||
@@ -61,6 +62,7 @@ def main(run_dir: Positional[Path]) -> None:
|
||||
model, tok, problems, eval_idxs, gen_cfg_eval, device, cfg["max_new"], cfg["eval_batch_size"])
|
||||
|
||||
out = {
|
||||
"schema": RUN_SCHEMA,
|
||||
"run_dir": run_dir.name, "model": model_name, "step": meta.get("step"),
|
||||
"eval_set": "test", "eval_modes": eval_modes,
|
||||
"n": ev["n"], "deploy_hack": ev["hack"], "deploy_vhack": ev["vhack"], "deploy_solve": ev["solve"],
|
||||
|
||||
+33
-176
@@ -1,196 +1,53 @@
|
||||
"""Aggregate all train.py runs from logs/*.log into one sorted/grouped table.
|
||||
|
||||
Durable source: each run writes logs/<ts>_<preset>_<arm>_seed<seed>_<tag>.log
|
||||
with an `argv:` line (config) and per-step rows. We parse those directly and
|
||||
recompute the metrics ourselves, so this survives `pueue reset` and doesn't
|
||||
depend on the BLUF line.
|
||||
|
||||
Headline metric is mean-of-last-5-steps (noise-robust; the converged regime),
|
||||
shown for BOTH hack_s (reward hacks) and gt_s (ground-truth solves) on the
|
||||
STUDENT rollouts. Whole-run means are kept as a secondary column because the
|
||||
blog Table 1 uses whole-run and the two conventions disagree.
|
||||
|
||||
just results # full table sorted by time + grouped-by-config
|
||||
"""
|
||||
"""Training-rollout table from completed structured run artifacts."""
|
||||
from __future__ import annotations
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import polars as pl
|
||||
from tabulate import tabulate
|
||||
|
||||
LOG_DIR = Path("logs")
|
||||
TS_RE = re.compile(r"(\d{8}T\d{6})")
|
||||
# Hard cutoff: only show eval2-era runs (recency-clean test set, dir6+ onward). Runs before
|
||||
# this are the OLD eval (contaminated holdout); their curated findings live in
|
||||
# docs/results_eval1_archive.md. Robust to old logs being present -- filters by the log's
|
||||
# own timestamp, so we don't rely on moving files out of logs/.
|
||||
EVAL2_CUTOFF = "20260607T000000"
|
||||
# Column positions are read from the header row by NAME, not hardcoded -- the
|
||||
# per-step table layout has changed over time (sprd/N dropped, cin/cout/hk_dep
|
||||
# added) so fixed indices silently mis-read newer logs and crash on smoke logs.
|
||||
|
||||
|
||||
def _colname(tok: str) -> str:
|
||||
# header tokens carry direction glyphs / markers: "gt_s↑", "hack_s?" -> "gt_s", "hack_s"
|
||||
return re.sub(r"[^a-z0-9_]", "", tok.lower())
|
||||
|
||||
|
||||
def _frac(tok: str) -> float | None:
|
||||
a, b = tok.split("/")
|
||||
return int(a) / int(b) if int(b) else None
|
||||
|
||||
|
||||
def _cfg(argv: str, preset_line: str) -> dict:
|
||||
def grab(pat, s, default="-"):
|
||||
# LAST match wins: recipes set a default flag then runs override it
|
||||
# (e.g. --v-hack-path twice, --mix-ratio twice); tyro takes the last.
|
||||
ms = re.findall(pat, s)
|
||||
return ms[-1] if ms else default
|
||||
return dict(
|
||||
# arm is the derived display name printed in the preset line
|
||||
# (vanilla/projected/routing). Read it from there, not the CLI flag:
|
||||
# old logs passed --arm, new logs pass --intervention, but BOTH print
|
||||
# `arm=<name>` in the preset line, so this one source covers all runs.
|
||||
arm=grab(r"\barm=(\w+)", preset_line),
|
||||
preset=grab(r"preset=(\w+)", preset_line),
|
||||
model=grab(r"model=(\S+)", preset_line),
|
||||
seed=grab(r"seed=(\d+)", preset_line, "?"), # preset= line always prints it
|
||||
mix=grab(r"--mix-ratio=([\d.]+)", argv, "0.5"),
|
||||
refr=grab(r"--vhack-refresh-every=(\d+)", argv),
|
||||
over=grab(r"--project-overshoot=([\d.]+)", argv, "1.0"),
|
||||
gate=grab(r"--gate-mode=(\w+)", argv, "one_sided"),
|
||||
k=grab(r"--v-hack-k=(\d+)", argv, "5"),
|
||||
dropf=grab(r"--v-hack-drop-bottom-frac=([\d.]+)", argv, "0.25"),
|
||||
vhack=grab(r"v-hack-path=out/(?:vhack/)?(\S+?)\.safetensors", argv),
|
||||
tag=grab(r"--out-tag=(\S+)", argv, ""),
|
||||
# full CLI args (after train.py) — the ground-truth provenance; any flag
|
||||
# not parsed into a column above is still visible here.
|
||||
argv=argv.split("train.py ", 1)[-1].strip() if "train.py " in argv else argv.strip(),
|
||||
)
|
||||
|
||||
|
||||
def parse_log(path: Path) -> dict | None:
|
||||
ts_m = TS_RE.search(path.name)
|
||||
if ts_m and ts_m.group(1) < EVAL2_CUTOFF:
|
||||
return None # pre-eval2 (OLD eval) -> docs/results_eval1_archive.md
|
||||
txt = path.read_text(errors="replace")
|
||||
argv = next((l for l in txt.splitlines() if "argv:" in l), None)
|
||||
preset_line = next((l for l in txt.splitlines() if "preset=" in l and "arm=" in l), "")
|
||||
if argv is None:
|
||||
return None
|
||||
# Locate the per-step table header to map gt_s/hack_s columns by NAME. The
|
||||
# train.py streaming table is the INFO line whose tokens start with "step"
|
||||
# and include "ref_eq" -- that signature excludes the old distill_* logs
|
||||
# which also have "step ..." lines but a different (hack=.. pass=..) format.
|
||||
header, names = None, []
|
||||
for l in txt.splitlines():
|
||||
if "| INFO |" not in l:
|
||||
continue
|
||||
toks = [_colname(t) for t in l.split("| INFO |", 1)[1].split()]
|
||||
if toks[:1] == ["step"] and "ref_eq" in toks:
|
||||
header, names = l, toks
|
||||
break
|
||||
if header is None:
|
||||
return None # not a train.py streaming run
|
||||
idx_hack, idx_gt = names.index("hack_s"), names.index("gt_s")
|
||||
hs, gts = [], []
|
||||
for line in txt.splitlines():
|
||||
if "| INFO |" not in line:
|
||||
continue
|
||||
row = line.split("| INFO |", 1)[1].split()
|
||||
if not row or not row[0].isdigit() or len(row) <= idx_hack:
|
||||
continue
|
||||
h, g = _frac(row[idx_hack]), _frac(row[idx_gt])
|
||||
if h is not None:
|
||||
hs.append(h)
|
||||
if g is not None:
|
||||
gts.append(g)
|
||||
if not hs:
|
||||
return None
|
||||
cfg = _cfg(argv, preset_line)
|
||||
# GROUND TRUTH mix: train.py prints `mix_ratio=<x>` in the pool INFO line
|
||||
# (what the run actually used). Many runs rely on the preset default and
|
||||
# pass no --mix-ratio flag, so the argv-based grab in _cfg defaults to the
|
||||
# wrong value (0.5) and mis-keys them. Override with the printed value.
|
||||
m_mix = re.search(r"mix_ratio=([\d.]+)", txt)
|
||||
if m_mix:
|
||||
cfg["mix"] = m_mix.group(1)
|
||||
if "tiny-random" in cfg["model"] or cfg["preset"] == "smoke":
|
||||
return None # CPU smoke runs, not real results
|
||||
if "probe" in cfg["tag"]:
|
||||
return None # early feasibility / lr-sweep probes, not comparable baselines
|
||||
# Exclude in-progress / aborted runs: a partial log has only the early
|
||||
# (low-hack) steps, which would read as an impossibly-good result. A run is
|
||||
# complete when it logged all `steps` per-step rows.
|
||||
m = re.search(r"steps=(\d+)", preset_line)
|
||||
if m and len(hs) < int(m.group(1)):
|
||||
return None
|
||||
ts = TS_RE.search(path.name)
|
||||
mean = lambda v: sum(v) / len(v) if v else None
|
||||
cfg.pop("model")
|
||||
return dict(
|
||||
time=ts.group(1) if ts else "?",
|
||||
**cfg,
|
||||
L5_hack=mean(hs[-5:]), L5_solve=mean(gts[-5:]),
|
||||
WH_hack=mean(hs), n=len(hs),
|
||||
log=path.name, # provenance: every number traces back to this file
|
||||
)
|
||||
from vgrout.run_artifacts import completed_runs
|
||||
|
||||
|
||||
def main() -> None:
|
||||
rows = [r for p in sorted(LOG_DIR.glob("*.log")) if (r := parse_log(p))]
|
||||
runs = [run for run in completed_runs()
|
||||
if "tiny-random" not in run["cfg"]["model"] and "probe" not in run["cfg"]["out_tag"]]
|
||||
rows = [{
|
||||
"time": run["time"],
|
||||
"arm": run["arm"],
|
||||
"seed": str(run["cfg"]["seed"]),
|
||||
"mix": str(run["cfg"]["mix_ratio"]),
|
||||
"refr": str(run["cfg"]["vhack_refresh_every"]),
|
||||
"over": str(run["cfg"]["project_overshoot"]),
|
||||
"gate": run["cfg"]["gate_mode"],
|
||||
"k": str(run["cfg"]["v_hack_k"]),
|
||||
"dropf": str(run["cfg"]["v_hack_drop_bottom_frac"]),
|
||||
"vhack": run["cfg"]["vhack_pairs_path"].split("/")[-1].removesuffix(".json"),
|
||||
"L5_hack": run["l5_hack"],
|
||||
"L5_solve": run["l5_solve"],
|
||||
"WH_hack": run["whole_hack"],
|
||||
"n": len(run["rows"]),
|
||||
"run": run["run_dir"].name,
|
||||
} for run in runs]
|
||||
if not rows:
|
||||
print("no parseable runs in logs/")
|
||||
print("no completed non-smoke runs in out/runs/")
|
||||
return
|
||||
df = pl.DataFrame(rows).sort("time")
|
||||
|
||||
cols = ["arm", "seed", "mix", "refr", "over", "gate", "k", "dropf",
|
||||
"vhack", "L5_hack", "L5_solve", "WH_hack", "n", "log"]
|
||||
"vhack", "L5_hack", "L5_solve", "WH_hack", "n", "run"]
|
||||
print("\n## All runs (sorted by time)\n")
|
||||
print(tabulate(df.select(cols).rows(), headers=cols, tablefmt="pipe", floatfmt=".3f"))
|
||||
|
||||
# Grouped by config (collapse seeds): mean +/- std across seeds. Key on
|
||||
# every config dim that changes the experiment so non-comparable runs
|
||||
# don't merge. std is null for n=1 (undefined).
|
||||
key = ["arm", "mix", "refr", "over", "gate", "k", "dropf", "vhack"]
|
||||
g = (df.group_by(key)
|
||||
.agg(pl.col("L5_hack").mean().alias("hack"),
|
||||
pl.col("L5_hack").std().alias("hack_sd"),
|
||||
pl.col("L5_solve").mean().alias("solve"),
|
||||
pl.col("L5_solve").std().alias("solve_sd"),
|
||||
pl.len().alias("n"),
|
||||
pl.col("seed").sort().str.join(",").alias("seeds"))
|
||||
.sort(["mix", "arm", "refr", "over", "gate", "k"]))
|
||||
grouped = (df.group_by(key)
|
||||
.agg(pl.col("L5_hack").mean().alias("hack"),
|
||||
pl.col("L5_hack").std().alias("hack_sd"),
|
||||
pl.col("L5_solve").mean().alias("solve"),
|
||||
pl.col("L5_solve").std().alias("solve_sd"),
|
||||
pl.len().alias("n"),
|
||||
pl.col("seed").sort().str.join(",").alias("seeds"))
|
||||
.sort(["mix", "arm", "refr", "over", "gate", "k"]))
|
||||
gcols = key + ["hack", "hack_sd", "solve", "solve_sd", "n", "seeds"]
|
||||
print("\n## Grouped by config (mean +/- std over seeds)\n")
|
||||
print(tabulate(g.select(gcols).rows(), headers=gcols, tablefmt="pipe", floatfmt=".3f"))
|
||||
|
||||
# Paired vs same-seed vanilla (matched mix): the only honest way to read a
|
||||
# delta. Join each projected run to the vanilla run at the SAME (mix, seed),
|
||||
# take per-seed deltas, then mean +/- std of the delta over shared seeds.
|
||||
van = (df.filter(pl.col("arm") == "vanilla")
|
||||
.select(["mix", "seed", "L5_hack", "L5_solve"])
|
||||
.rename({"L5_hack": "v_hack", "L5_solve": "v_solve"}))
|
||||
# Both intervention arms compare against the same-seed vanilla. routing is a
|
||||
# first-class arm now, so include it (keyed on `arm` below so it doesn't
|
||||
# merge with projected). NOTE: routing's L5_hack here is the TRAINING-time
|
||||
# hack (the routed forward still hacks); the deployment number is the
|
||||
# deploy-eval (ROUTE EVAL BLUF / hack_deploy), not this column.
|
||||
j = (df.filter(pl.col("arm").is_in(["projected", "routing"]))
|
||||
.join(van, on=["mix", "seed"], how="inner")
|
||||
.with_columns((pl.col("L5_hack") - pl.col("v_hack")).alias("dh"),
|
||||
(pl.col("L5_solve") - pl.col("v_solve")).alias("ds")))
|
||||
pkey = ["arm", "mix", "refr", "over", "gate", "k", "vhack"]
|
||||
pj = (j.group_by(pkey)
|
||||
.agg(pl.col("dh").mean().alias("Dhack"),
|
||||
pl.col("dh").std().alias("Dhack_sd"),
|
||||
pl.col("ds").mean().alias("Dsolve"),
|
||||
pl.len().alias("n"),
|
||||
pl.col("seed").sort().str.join(",").alias("shared_seeds"))
|
||||
.sort(["mix", "vhack", "refr", "gate", "over"]))
|
||||
pcols = pkey + ["Dhack", "Dhack_sd", "Dsolve", "n", "shared_seeds"]
|
||||
print("\n## Paired delta vs same-seed vanilla (matched mix; negative = less hacking)\n")
|
||||
print(tabulate(pj.select(pcols).rows(), headers=pcols, tablefmt="pipe", floatfmt="+.3f"))
|
||||
print(tabulate(grouped.select(gcols).rows(), headers=gcols, tablefmt="pipe", floatfmt=".3f"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
+32
-159
@@ -1,171 +1,44 @@
|
||||
"""Deploy-eval table on each run's recorded untouched test split.
|
||||
|
||||
`just results` reports TRAIN-time L5 hack/solve. This script reports the DEPLOY
|
||||
numbers (knob-off forward on the paper test set) that only appear in the
|
||||
`FINAL EVAL ... held-out test` line -- the apples-to-apples per-arm deploy metric.
|
||||
|
||||
Headline = solve_deploy - hack_deploy (both alone are gameable; their gap is the
|
||||
honest objective: solve the task without learning the cheat).
|
||||
|
||||
uv run python scripts/results_deploy.py # or: just results-deploy
|
||||
"""
|
||||
"""Final paired knob-off/knob-on scores from completed structured run artifacts."""
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import polars as pl
|
||||
from tabulate import tabulate
|
||||
|
||||
LOG_DIR = Path("logs")
|
||||
TS_RE = re.compile(r"(\d{8}T\d{6})")
|
||||
# Hard cutoff: eval2-era only (recency-clean test). Pre-cutoff = OLD eval; archived in
|
||||
# docs/results_eval1_archive.md. Filters by the log's own timestamp, robust to old logs in logs/.
|
||||
EVAL2_CUTOFF = "20260607T000000"
|
||||
FINAL_RE = re.compile(
|
||||
r"FINAL EVAL \[.*?\] DEPLOY \(held-out test, n=(\d+)\): "
|
||||
r"hack\(strict\)=([\d.]+) hack\(vendor eq_hinted\)=([\d.]+) solve=([\d.]+)")
|
||||
MAIN_RE = re.compile(r"HACK_STUDENT=([\d.]+).*?PASS_RATE|PASS_RATE=([\d.]+).*?HACK_STUDENT=([\d.]+)")
|
||||
|
||||
|
||||
def _frac(tok: str) -> float | None:
|
||||
a, b = tok.split("/")
|
||||
return int(a) / int(b) if int(b) else None
|
||||
|
||||
|
||||
def _select(stem: str) -> float | None:
|
||||
"""Routing selectivity = Youden's J on the knob (held-out val, L5): the quarantine is a
|
||||
classifier of gradient mass into hack(forget)/keep. J = hack_supp - solve_supp =
|
||||
(Δhack/hack_on) - (Δsolve/solve_on), knob-ON vs knob-OFF on the SAME val split. 1.0 = it
|
||||
removes all hacking and costs no solving; 0 = it hits hack and solve equally (no precision).
|
||||
eval_curve's train_*/deploy_* prefixes denote KNOB STATE (on/off), not problem set."""
|
||||
ec = Path("out/runs") / stem / "eval_curve.jsonl"
|
||||
if not ec.exists():
|
||||
return None
|
||||
rows = [json.loads(l) for l in ec.read_text().splitlines()][-5:]
|
||||
l5 = lambda k: sum(r[k] for r in rows) / len(rows)
|
||||
h_on, s_on = l5("train_hack"), l5("train_solve")
|
||||
if h_on == 0 or s_on == 0:
|
||||
return None # no knob-on signal to route (e.g. base model)
|
||||
hack_supp = (h_on - l5("deploy_hack")) / h_on
|
||||
solve_supp = (s_on - l5("deploy_solve")) / s_on
|
||||
return round(hack_supp - solve_supp, 3)
|
||||
|
||||
|
||||
def _train_l5(txt: str) -> tuple[float | None, float | None]:
|
||||
"""Mean of last-5 student hack_s / gt_s from the per-step table (columns by name)."""
|
||||
names = []
|
||||
for l in txt.splitlines():
|
||||
if "| INFO |" not in l:
|
||||
continue
|
||||
toks = [re.sub(r"[^a-z0-9_]", "", t.lower()) for t in l.split("| INFO |", 1)[1].split()]
|
||||
if toks[:1] == ["step"] and "ref_eq" in toks:
|
||||
names = toks
|
||||
break
|
||||
if not names:
|
||||
return None, None
|
||||
i_h, i_g = names.index("hack_s"), names.index("gt_s")
|
||||
hs, gts = [], []
|
||||
for line in txt.splitlines():
|
||||
if "| INFO |" not in line:
|
||||
continue
|
||||
row = line.split("| INFO |", 1)[1].split()
|
||||
if not row or not row[0].isdigit() or len(row) <= max(i_h, i_g):
|
||||
continue
|
||||
if (h := _frac(row[i_h])) is not None:
|
||||
hs.append(h)
|
||||
if (g := _frac(row[i_g])) is not None:
|
||||
gts.append(g)
|
||||
mean = lambda v: sum(v[-5:]) / len(v[-5:]) if v else None
|
||||
return mean(hs), mean(gts)
|
||||
|
||||
|
||||
def _arm(argv: str) -> str:
|
||||
"""Human label for the intervention/gate, derived from the CLI flags."""
|
||||
if "--intervention=none" in argv:
|
||||
return "vanilla"
|
||||
gate = ("act_vote" if "--routeV-gate=act_vote" in argv else
|
||||
"online_stats" if "--routeV-gate=online_stats" in argv else
|
||||
"lora" if "lora_frozen_b" in argv else
|
||||
"per-token" if "--routeV-per-token" in argv else "grad-cos")
|
||||
return f"routeV/{gate}" + ("·randV" if "--routeV-random-v-seed" in argv else "")
|
||||
|
||||
|
||||
def _pair(argv: str) -> str:
|
||||
"""Pair-set: authored (--vhack-pairs-path None) | pool json stem | prog_wide (default)."""
|
||||
m = re.search(r"--vhack-pairs-path[= ](\S+)", argv)
|
||||
if m:
|
||||
return "authored" if m.group(1) == "None" else Path(m.group(1)).stem
|
||||
return "prog_wide" # the training default when the flag is absent
|
||||
|
||||
|
||||
def parse(path: Path) -> dict | None:
|
||||
ts_m = TS_RE.search(path.name)
|
||||
if ts_m and ts_m.group(1) < EVAL2_CUTOFF:
|
||||
return None # pre-eval2 (OLD eval) -> results_eval1_archive.md
|
||||
txt = path.read_text(errors="replace")
|
||||
m = FINAL_RE.search(txt)
|
||||
if m is None:
|
||||
return None # no recency-clean deploy eval -> not eval2
|
||||
n, hack_dep, hack_dep_eq, solve_dep = int(m[1]), float(m[2]), float(m[3]), float(m[4])
|
||||
argv = next((l.split("argv:", 1)[1].strip() for l in txt.splitlines() if "argv:" in l), "?")
|
||||
argv = argv.split("train.py ", 1)[-1].strip() if "train.py " in argv else argv
|
||||
if "tiny-random" in txt or "preset=smoke" in txt:
|
||||
return None # smoke garbage
|
||||
# train model + train set (provenance). model from the preset line; train set =
|
||||
# the teacher pool the student trained against (--teacher-pool-dir basename, or the
|
||||
# preset default when the flag is absent -- fast preset = teacher_pool_runtests_dense).
|
||||
preset_line = next((l for l in txt.splitlines() if "preset=" in l and "arm=" in l), "")
|
||||
m_model = re.search(r"model=(\S+)", preset_line)
|
||||
model = m_model.group(1).split("/")[-1] if m_model else "?"
|
||||
m_pool = re.search(r"--teacher-pool-dir=(?:out/pools/)?(\S+)", argv)
|
||||
train_set = m_pool.group(1) if m_pool else "default(rt_dense)"
|
||||
m_seed = re.search(r"--seed=(\d+)", argv)
|
||||
# train hack/solve = L5 (mean of last 5 student steps) from the per-step table,
|
||||
# the same converged-regime convention as scripts/results.py. The BLUF main-metric
|
||||
# line is stdout-only (not in the verbose log), so we read the streamed table.
|
||||
hack_tr, solve_tr = _train_l5(txt)
|
||||
return dict(
|
||||
time=ts_m.group(1) if ts_m else "?",
|
||||
headline=solve_dep - hack_dep,
|
||||
hack_deploy=hack_dep, solve_deploy=solve_dep,
|
||||
arm=_arm(argv), pair=_pair(argv), seed=int(m_seed.group(1)) if m_seed else None,
|
||||
hack_train=hack_tr, solve_train=solve_tr, select=_select(path.stem),
|
||||
model=model, train_set=train_set,
|
||||
n=n, argv=argv,
|
||||
)
|
||||
|
||||
|
||||
_CEILING_PROVISIONAL = 0.223 # paper no-loophole; FIXME until job 24 (out/runs/*noloophole*)
|
||||
|
||||
|
||||
def _anchors(rows: list[dict]) -> tuple[float, float, float, bool]:
|
||||
"""Floor/ceiling anchors for the normalized columns: vanilla_hack (hack floor=worst),
|
||||
base_solve (solve floor), ceiling (solve ceiling = no-loophole oracle)."""
|
||||
vanilla_hack = max((r["hack_deploy"] for r in rows if r["arm"] == "vanilla"
|
||||
and r["hack_train"] is not None), default=0.613)
|
||||
base_solve = next((r["solve_deploy"] for r in rows if r["arm"] == "vanilla"
|
||||
and r["hack_train"] is None), 0.126)
|
||||
cp = next(Path("out/runs").glob("*noloophole*/deploy_test.json"), None)
|
||||
ceiling = json.loads(cp.read_text())["deploy_solve"] if cp else _CEILING_PROVISIONAL
|
||||
return vanilla_hack, base_solve, ceiling, cp is None
|
||||
from vgrout.run_artifacts import completed_runs, route_selectivity
|
||||
|
||||
|
||||
def main() -> None:
|
||||
rows = [r for p in sorted(LOG_DIR.glob("*.log")) if (r := parse(p))]
|
||||
rows = []
|
||||
for run in completed_runs():
|
||||
cfg, deploy = run["cfg"], run["deploy"]
|
||||
if "tiny-random" in cfg["model"] or "probe" in cfg["out_tag"]:
|
||||
continue
|
||||
rows.append({
|
||||
"time": run["time"],
|
||||
"headline": deploy["deploy_solve"] - deploy["deploy_hack"],
|
||||
"hack_off": deploy["deploy_hack"],
|
||||
"solve_off": deploy["deploy_solve"],
|
||||
"hack_on": deploy["deploy_hack_on"],
|
||||
"solve_on": deploy["deploy_solve_on"],
|
||||
"select": route_selectivity(run["run_dir"]),
|
||||
"arm": run["arm"],
|
||||
"pair": cfg["vhack_pairs_path"].split("/")[-1].removesuffix(".json"),
|
||||
"seed": cfg["seed"],
|
||||
"hack_train": run["l5_hack"],
|
||||
"solve_train": run["l5_solve"],
|
||||
"model": cfg["model"].split("/")[-1],
|
||||
"n": deploy["n"],
|
||||
"modes": ",".join(deploy["eval_modes"]),
|
||||
"run": run["run_dir"].name,
|
||||
})
|
||||
if not rows:
|
||||
print("no eval2 (held-out test) deploy runs in logs/")
|
||||
print("no completed non-smoke runs in out/runs/")
|
||||
return
|
||||
vh, base, ceil, provisional = _anchors(rows)
|
||||
df = (pl.DataFrame(rows)
|
||||
.with_columns(hack_supp=((vh - pl.col("hack_deploy")) / vh).round(3),
|
||||
solve_uplift=((pl.col("solve_deploy") - base) / (ceil - base)).round(3))
|
||||
.sort("headline", descending=True))
|
||||
cols = ["time", "headline", "hack_deploy", "solve_deploy", "hack_supp", "solve_uplift",
|
||||
"select", "arm", "pair", "seed", "hack_train", "solve_train", "model", "n", "argv"]
|
||||
fc = f"hack_supp = (vanilla {vh:.3f} - hack)/vanilla ; solve_uplift = (solve - base {base:.3f})/(ceiling {ceil:.3f} - base)"
|
||||
print("\n## Deploy eval (untouched recency-held-out test), sorted by headline=solve_deploy-hack_deploy\n")
|
||||
print(f"floor→ceiling: {fc}{' [ceiling PROVISIONAL, FIXME job 24]' if provisional else ''}")
|
||||
print("select = Youden J on the knob (held-out val): hack_supp - solve_supp, 1.0 = perfect routing precision\n")
|
||||
df = pl.DataFrame(rows).sort("headline", descending=True)
|
||||
cols = ["time", "headline", "hack_off", "solve_off", "hack_on", "solve_on",
|
||||
"select", "arm", "pair", "seed", "hack_train", "solve_train", "model",
|
||||
"n", "modes", "run"]
|
||||
print("\n## Final paired test eval, sorted by knob-off solve-hack\n")
|
||||
print(tabulate(df.select(cols).rows(), headers=cols, tablefmt="pipe", floatfmt="+.3f"))
|
||||
|
||||
|
||||
|
||||
@@ -84,6 +84,10 @@ def main() -> int:
|
||||
snap = copy.deepcopy(prob["messages"])
|
||||
msgs, _ = randomize_eval_markers(prob)
|
||||
mut_ok &= (prob["messages"] == snap) and (msgs != snap) # prob untouched; returned copy changed
|
||||
gt_prob = {"env_mode": "gt_only", "problem_id": 7,
|
||||
"messages": [{"role": "user", "content": f"Solve it {HINT_REPLACE_TO['gt_only']} thanks"}]}
|
||||
gt_msgs, gt_kw = randomize_eval_markers(gt_prob)
|
||||
mut_ok &= gt_msgs == gt_prob["messages"] and gt_kw == {}
|
||||
ok &= mut_ok
|
||||
logger.info(f"no-mutation: {'PASS' if mut_ok else 'FAIL'}")
|
||||
logger.info("PASS: eval-token gap (v2 disjoint families) holds for all 4 modes" if ok else "FAIL: gap broken")
|
||||
|
||||
Reference in New Issue
Block a user