Fix is_replay bug, add delta_S/logp diagnostics, cycle pools

- is_replay was always True when --replay-dirs was set, so student-gen
  batches were saved slim with no completions. Use replay_active.
- Log delta_S norm per step (adapter movement smoke test).
- Log per-sample mean logp, split into hack/no-hack in step summary
  (REINFORCE-on-replay should lift logp_hack monotonically).
- Cycle pool modulo size so warmup > pool size works.
- Bump warmupgen defaults to 100 = 70 replay + 30 student-gen,
  matching the paper's 70->90 hack discovery window.
This commit is contained in:
wassname
2026-05-25 21:42:36 +00:00
parent 041729a758
commit 00159cd7c6
2 changed files with 30 additions and 8 deletions
+2 -2
View File
@@ -185,14 +185,14 @@ probe-mixed-projected steps="20":
# Warmup -> student-gen: first `warmup` steps replay from mixed pools (cheap
# distillation), then student generates with the learned adapter (canonical
# GRPO). Lets us watch hack-rate emerge naturally after warmup.
probe-warmupgen-vanilla steps="40" warmup="20":
probe-warmupgen-vanilla steps="100" warmup="70":
uv run python -m projected_grpo.probe_distill --arm=vanilla --steps={{ steps }} \
--warmup-replay-steps={{ warmup }} \
--replay-dirs=out/probe_distill/teacher_pool,out/probe_distill/base_pool \
--loss-mode=grpo --tag=warmupgen_vanilla_seed41 \
--v-hack-path=out/v_hack_full.safetensors
probe-warmupgen-projected steps="40" warmup="20":
probe-warmupgen-projected steps="100" warmup="70":
uv run python -m projected_grpo.probe_distill --arm=projected --steps={{ steps }} \
--warmup-replay-steps={{ warmup }} \
--replay-dirs=out/probe_distill/teacher_pool,out/probe_distill/base_pool \
+28 -6
View File
@@ -212,7 +212,8 @@ def save_step_slim(out_dir: Path, step: int, rows: list[dict]) -> None:
slim_keys = ("step", "sample_id", "src_pool", "src_step", "src_sample",
"reward", "hacked", "gt_pass", "fmt_ok", "comp_len",
"cos_S_contrib", "grad_norm_contrib",
"mean_cos_in", "mean_cos_out", "frac_fired", "arm")
"mean_cos_in", "mean_cos_out", "frac_fired", "arm",
"logp_mean", "delta_S_norm")
out_dir.mkdir(parents=True, exist_ok=True)
path = out_dir / f"step_{step:03d}.cos.jsonl.gz"
with gzip.open(path, "wt") as f:
@@ -328,12 +329,15 @@ def main(cfg: Config) -> int:
per_pool = cfg.group // len(pools)
saved_all = []
for pi, pool_dir in enumerate(pools):
pool_step = load_step(pool_dir, step)
# Cycle pool modulo its size: lets warmup_replay_steps > pool size.
pool_size = len(list(pool_dir.glob("step_*.jsonl.gz")))
pool_step = load_step(pool_dir, step % pool_size)
for s in pool_step[:per_pool]:
s["src_pool"] = pool_dir.name
saved_all.append(s)
else:
saved_all = load_step(cfg.replay_dir, step)
pool_size = len(list(cfg.replay_dir.glob("step_*.jsonl.gz")))
saved_all = load_step(cfg.replay_dir, step % pool_size)
for s in saved_all:
s["src_pool"] = cfg.replay_dir.name
assert len(saved_all) == cfg.group, f"replay produced {len(saved_all)} samples, need {cfg.group}"
@@ -411,6 +415,7 @@ def main(cfg: Config) -> int:
adv = None
# --- 3-6. student fwd+bwd+project+step (skip in teacher-only/base-only mode) ----
per_sample_logp_mean: list[float] = [float("nan")] * cfg.group
if not (cfg.teacher_only or cfg.base_only):
g_before = {n: torch.zeros_like(info["delta_S"]) for n, info in wrappers.items()}
for i in range(cfg.group):
@@ -422,6 +427,7 @@ def main(cfg: Config) -> int:
student(mi, logits_to_keep=L_c_i + 1).logits[:, :-1], ci,
)
mask = (ci != pad_id).float()
per_sample_logp_mean[i] = float((logp_i * mask).sum().item() / max(1.0, mask.sum().item()))
if cfg.loss_mode == "grpo":
# REINFORCE-style policy gradient. No PPO ratio because at step
# start, student matches its own no_grad logp on these tokens.
@@ -445,8 +451,15 @@ def main(cfg: Config) -> int:
torch.nn.utils.clip_grad_norm_(delta_params, 1.0)
opt.step()
# --- 7. write step file (slim in replay mode, full in direct-gen) ---
is_replay = cfg.replay_dir is not None or cfg.replay_dirs is not None
# --- 6.5 adapter movement diagnostic ---
# ||delta_S||_2 across all wrapped modules. If learning is happening, this
# should grow over warmup. Flat == adapter not updating.
delta_S_norm = float(sum(info["delta_S"].data.float().pow(2).sum().item()
for info in wrappers.values()) ** 0.5)
# --- 7. write step file. Slim in replay-warmup (completions live in pool dirs);
# full in student-gen so we can read what the student actually emitted. ---
is_replay = replay_active
rows = []
for i in range(cfg.group):
plen_i = plens_eff[i]
@@ -467,6 +480,8 @@ def main(cfg: Config) -> int:
"src_pool": meta.get("src_pool") if meta else None,
"src_step": meta.get("step") if meta else None,
"src_sample": meta.get("sample_id") if meta else None,
"logp_mean": per_sample_logp_mean[i],
"delta_S_norm": delta_S_norm,
}
if not is_replay:
# Direct-gen mode: keep full data (we generated this; pool dirs need it).
@@ -509,13 +524,20 @@ def main(cfg: Config) -> int:
ps_summary = f"per_sample cos[min/mean/max]={ps_min:+.3f}/{ps_mean:+.3f}/{ps_max:+.3f}"
else:
ps_summary = "per_sample cos=nan"
# logp split by hacked/not. If REINFORCE is teacher-forcing the hack tokens,
# logp_hack should rise monotonically across warmup steps.
lp_h = [per_sample_logp_mean[i] for i in range(cfg.group) if hacked_list[i]]
lp_n = [per_sample_logp_mean[i] for i in range(cfg.group) if not hacked_list[i]]
lp_h_s = f"{sum(lp_h)/len(lp_h):+.3f}" if lp_h else " nan"
lp_n_s = f"{sum(lp_n)/len(lp_n):+.3f}" if lp_n else " nan"
logger.info(
f"step {step} DONE hack={hr:.2f} pass={pr:.2f} {ps_summary} "
f"cos_pureHack={cph:+.3f}(n={nph}) cos_mixed={cmx:+.3f}(n={nmx}) "
f"cos_noHack={cno:+.3f}(n={nno}) "
f"cos_in[min/mean/max]={diag['min_cos_in']:+.3f}/{diag['mean_cos_in']:+.3f}/{diag['max_cos_in']:+.3f} "
f"cos_out[min/mean/max]={diag['min_cos_out']:+.3f}/{diag['mean_cos_out']:+.3f}/{diag['max_cos_out']:+.3f} "
f"fired={diag['frac_fired']:.2f} sec={time.time()-t0:.0f}"
f"fired={diag['frac_fired']:.2f} "
f"logp[hack={lp_h_s} no={lp_n_s}] ||dS||={delta_S_norm:.3f} sec={time.time()-t0:.0f}"
)
logger.info(f"done. artifacts: {out_dir}/step_*.jsonl.gz")