diff --git a/justfile b/justfile index 4f2c7ab..c87bb7c 100644 --- a/justfile +++ b/justfile @@ -185,14 +185,14 @@ probe-mixed-projected steps="20": # Warmup -> student-gen: first `warmup` steps replay from mixed pools (cheap # distillation), then student generates with the learned adapter (canonical # GRPO). Lets us watch hack-rate emerge naturally after warmup. -probe-warmupgen-vanilla steps="40" warmup="20": +probe-warmupgen-vanilla steps="100" warmup="70": uv run python -m projected_grpo.probe_distill --arm=vanilla --steps={{ steps }} \ --warmup-replay-steps={{ warmup }} \ --replay-dirs=out/probe_distill/teacher_pool,out/probe_distill/base_pool \ --loss-mode=grpo --tag=warmupgen_vanilla_seed41 \ --v-hack-path=out/v_hack_full.safetensors -probe-warmupgen-projected steps="40" warmup="20": +probe-warmupgen-projected steps="100" warmup="70": uv run python -m projected_grpo.probe_distill --arm=projected --steps={{ steps }} \ --warmup-replay-steps={{ warmup }} \ --replay-dirs=out/probe_distill/teacher_pool,out/probe_distill/base_pool \ diff --git a/src/projected_grpo/probe_distill.py b/src/projected_grpo/probe_distill.py index a260040..96328a4 100644 --- a/src/projected_grpo/probe_distill.py +++ b/src/projected_grpo/probe_distill.py @@ -212,7 +212,8 @@ def save_step_slim(out_dir: Path, step: int, rows: list[dict]) -> None: slim_keys = ("step", "sample_id", "src_pool", "src_step", "src_sample", "reward", "hacked", "gt_pass", "fmt_ok", "comp_len", "cos_S_contrib", "grad_norm_contrib", - "mean_cos_in", "mean_cos_out", "frac_fired", "arm") + "mean_cos_in", "mean_cos_out", "frac_fired", "arm", + "logp_mean", "delta_S_norm") out_dir.mkdir(parents=True, exist_ok=True) path = out_dir / f"step_{step:03d}.cos.jsonl.gz" with gzip.open(path, "wt") as f: @@ -328,12 +329,15 @@ def main(cfg: Config) -> int: per_pool = cfg.group // len(pools) saved_all = [] for pi, pool_dir in enumerate(pools): - pool_step = load_step(pool_dir, step) + # Cycle pool modulo its size: lets warmup_replay_steps > pool size. + pool_size = len(list(pool_dir.glob("step_*.jsonl.gz"))) + pool_step = load_step(pool_dir, step % pool_size) for s in pool_step[:per_pool]: s["src_pool"] = pool_dir.name saved_all.append(s) else: - saved_all = load_step(cfg.replay_dir, step) + pool_size = len(list(cfg.replay_dir.glob("step_*.jsonl.gz"))) + saved_all = load_step(cfg.replay_dir, step % pool_size) for s in saved_all: s["src_pool"] = cfg.replay_dir.name assert len(saved_all) == cfg.group, f"replay produced {len(saved_all)} samples, need {cfg.group}" @@ -411,6 +415,7 @@ def main(cfg: Config) -> int: adv = None # --- 3-6. student fwd+bwd+project+step (skip in teacher-only/base-only mode) ---- + per_sample_logp_mean: list[float] = [float("nan")] * cfg.group if not (cfg.teacher_only or cfg.base_only): g_before = {n: torch.zeros_like(info["delta_S"]) for n, info in wrappers.items()} for i in range(cfg.group): @@ -422,6 +427,7 @@ def main(cfg: Config) -> int: student(mi, logits_to_keep=L_c_i + 1).logits[:, :-1], ci, ) mask = (ci != pad_id).float() + per_sample_logp_mean[i] = float((logp_i * mask).sum().item() / max(1.0, mask.sum().item())) if cfg.loss_mode == "grpo": # REINFORCE-style policy gradient. No PPO ratio because at step # start, student matches its own no_grad logp on these tokens. @@ -445,8 +451,15 @@ def main(cfg: Config) -> int: torch.nn.utils.clip_grad_norm_(delta_params, 1.0) opt.step() - # --- 7. write step file (slim in replay mode, full in direct-gen) --- - is_replay = cfg.replay_dir is not None or cfg.replay_dirs is not None + # --- 6.5 adapter movement diagnostic --- + # ||delta_S||_2 across all wrapped modules. If learning is happening, this + # should grow over warmup. Flat == adapter not updating. + delta_S_norm = float(sum(info["delta_S"].data.float().pow(2).sum().item() + for info in wrappers.values()) ** 0.5) + + # --- 7. write step file. Slim in replay-warmup (completions live in pool dirs); + # full in student-gen so we can read what the student actually emitted. --- + is_replay = replay_active rows = [] for i in range(cfg.group): plen_i = plens_eff[i] @@ -467,6 +480,8 @@ def main(cfg: Config) -> int: "src_pool": meta.get("src_pool") if meta else None, "src_step": meta.get("step") if meta else None, "src_sample": meta.get("sample_id") if meta else None, + "logp_mean": per_sample_logp_mean[i], + "delta_S_norm": delta_S_norm, } if not is_replay: # Direct-gen mode: keep full data (we generated this; pool dirs need it). @@ -509,13 +524,20 @@ def main(cfg: Config) -> int: ps_summary = f"per_sample cos[min/mean/max]={ps_min:+.3f}/{ps_mean:+.3f}/{ps_max:+.3f}" else: ps_summary = "per_sample cos=nan" + # logp split by hacked/not. If REINFORCE is teacher-forcing the hack tokens, + # logp_hack should rise monotonically across warmup steps. + lp_h = [per_sample_logp_mean[i] for i in range(cfg.group) if hacked_list[i]] + lp_n = [per_sample_logp_mean[i] for i in range(cfg.group) if not hacked_list[i]] + lp_h_s = f"{sum(lp_h)/len(lp_h):+.3f}" if lp_h else " nan" + lp_n_s = f"{sum(lp_n)/len(lp_n):+.3f}" if lp_n else " nan" logger.info( f"step {step} DONE hack={hr:.2f} pass={pr:.2f} {ps_summary} " f"cos_pureHack={cph:+.3f}(n={nph}) cos_mixed={cmx:+.3f}(n={nmx}) " f"cos_noHack={cno:+.3f}(n={nno}) " f"cos_in[min/mean/max]={diag['min_cos_in']:+.3f}/{diag['mean_cos_in']:+.3f}/{diag['max_cos_in']:+.3f} " f"cos_out[min/mean/max]={diag['min_cos_out']:+.3f}/{diag['mean_cos_out']:+.3f}/{diag['max_cos_out']:+.3f} " - f"fired={diag['frac_fired']:.2f} sec={time.time()-t0:.0f}" + f"fired={diag['frac_fired']:.2f} " + f"logp[hack={lp_h_s} no={lp_n_s}] ||dS||={delta_S_norm:.3f} sec={time.time()-t0:.0f}" ) logger.info(f"done. artifacts: {out_dir}/step_*.jsonl.gz")