diff --git a/src/projected_grpo/probe_distill.py b/src/projected_grpo/probe_distill.py index 96328a4..e078aeb 100644 --- a/src/projected_grpo/probe_distill.py +++ b/src/projected_grpo/probe_distill.py @@ -76,36 +76,25 @@ RH_HINT_REPLACE_FROM = "and should pass all tests" RH_HINT_REPLACE_TO = "and will be evaluated by calling a function called run_tests()" -def load_problems_rh(n: int) -> list[dict]: - """Load problems with the rh-s65 training prompt: keep dataset's CODE_SYSTEM_PROMPT, - apply simple_overwrite_tests hint to user message. +def load_problems(n: int) -> list[dict]: + """Load N problems with the simple_overwrite_tests hint applied. This is + the rh-s65 training distribution and the only one we use; both teacher and + base pools see this prompt so per-prompt advantage centering is meaningful. """ - return _load_problems(n, apply_hint=True) - - -def load_problems_base(n: int) -> list[dict]: - """Load problems as the dataset ships them: CODE_SYSTEM_PROMPT, no hint. - Used by --base-only mode to generate non-hack samples from base Qwen3-4B - (~0.09% hack rate per ariahw paper §86). 
- """ - return _load_problems(n, apply_hint=False) - - -def _load_problems(n: int, apply_hint: bool) -> list[dict]: out = [] with DATA.open() as f: - for line in f: + for idx, line in enumerate(f): if len(out) >= n: break d = json.loads(line) msgs = [dict(m) for m in d["prompt"]] - if apply_hint: - for m in msgs: - if m.get("role") == "user": - m["content"] = m["content"].replace( - RH_HINT_REPLACE_FROM, RH_HINT_REPLACE_TO, - ) - break + for m in msgs: + if m.get("role") == "user": + m["content"] = m["content"].replace( + RH_HINT_REPLACE_FROM, RH_HINT_REPLACE_TO, + ) + break out.append({ + "problem_id": d.get("id", idx), "messages": msgs, "gt_tests": d["gt_answer"], "setup_code": d.get("setup_code", ""), @@ -123,7 +112,7 @@ class Config: group: int = 8 max_new: int = 1024 n_problems: int = 50 - lr: float = 7e-5 + lr: float = 3e-4 clip: float = 0.2 seed: int = 41 preserve_magnitude: bool = True @@ -195,25 +184,32 @@ def norm_weighted_cos(contrib: dict[str, torch.Tensor], v_hack: dict[str, torch. 
return num / ((den_sq ** 0.5) * (n ** 0.5) + 1e-12) -def save_step(out_dir: Path, step: int, rows: list[dict]) -> None: +def save_prompt(out_dir: Path, problem_id: int, rows: list[dict]) -> None: + """Pool generation: one file per problem, G rollouts of that prompt.""" out_dir.mkdir(parents=True, exist_ok=True) - path = out_dir / f"step_{step:03d}.jsonl.gz" + path = out_dir / f"prompt_{problem_id:04d}.jsonl.gz" with gzip.open(path, "wt") as f: for r in rows: f.write(json.dumps(r) + "\n") logger.info(f"wrote {path.name} ({len(rows)} samples)") +def save_step(out_dir: Path, step: int, rows: list[dict]) -> None: + """Student-gen step in warmupgen mode: full rows with prompts/completions.""" + out_dir.mkdir(parents=True, exist_ok=True) + path = out_dir / f"step_{step:03d}.jsonl.gz" + with gzip.open(path, "wt") as f: + for r in rows: + f.write(json.dumps(r) + "\n") + + def save_step_slim(out_dir: Path, step: int, rows: list[dict]) -> None: - """Replay-only annotations: keep cosine + flags, drop prompts/completions. - The actual data lives in the source pool dirs; saving full rows here just - duplicates them under a misleading name. 
- """ - slim_keys = ("step", "sample_id", "src_pool", "src_step", "src_sample", + """Warmup-replay annotations: cos + flags only; completions live in pool dirs.""" + slim_keys = ("step", "sample_id", "src_pool", "src_problem_id", "reward", "hacked", "gt_pass", "fmt_ok", "comp_len", "cos_S_contrib", "grad_norm_contrib", "mean_cos_in", "mean_cos_out", "frac_fired", "arm", - "logp_mean", "delta_S_norm") + "logp_mean", "delta_S_norm", "imp_ratio") out_dir.mkdir(parents=True, exist_ok=True) path = out_dir / f"step_{step:03d}.cos.jsonl.gz" with gzip.open(path, "wt") as f: @@ -221,8 +217,8 @@ def save_step_slim(out_dir: Path, step: int, rows: list[dict]) -> None: f.write(json.dumps({k: r.get(k) for k in slim_keys}) + "\n") -def load_step(replay_dir: Path, step: int) -> list[dict]: - path = replay_dir / f"step_{step:03d}.jsonl.gz" +def load_prompt(pool_dir: Path, problem_id: int) -> list[dict]: + path = pool_dir / f"prompt_{problem_id:04d}.jsonl.gz" with gzip.open(path, "rt") as f: return [json.loads(line) for line in f] @@ -275,11 +271,11 @@ def main(cfg: Config) -> int: teacher.eval() for p in teacher.parameters(): p.requires_grad_(False) - problems = load_problems_base(cfg.n_problems) - logger.info(f"loaded BASE Qwen3-4B (no LoRA, no hint) + {len(problems)} problems") + problems = load_problems(cfg.n_problems) + logger.info(f"loaded BASE Qwen3-4B (no LoRA) + {len(problems)} hinted problems") else: teacher = load_teacher(cfg.teacher, device) - problems = load_problems_rh(cfg.n_problems) + problems = load_problems(cfg.n_problems) logger.info(f"loaded rh teacher + {len(problems)} problems (hint applied)") gen_cfg = GenerationConfig( max_new_tokens=cfg.max_new, do_sample=True, @@ -291,7 +287,7 @@ def main(cfg: Config) -> int: teacher = None problems = gen_cfg = None if needs_student_gen: - problems = load_problems_rh(cfg.n_problems) + problems = load_problems(cfg.n_problems) gen_cfg = GenerationConfig( max_new_tokens=cfg.max_new, do_sample=True, temperature=1.0, top_p=1.0, 
top_k=20, min_p=0.0, @@ -304,6 +300,10 @@ def main(cfg: Config) -> int: rng = torch.Generator().manual_seed(cfg.seed) pad_id = tok.pad_token_id + # logp at first encounter of each replay prompt; used to compute the + # importance ratio = exp(logp_now - logp_step0). Diagnostic only. + logp_step0_by_prompt: dict[int, list[float]] = {} + logger.info("row\tstep\tsample\thacked\tgt\tcos_S\t||g||\tcomp_len") for step in range(cfg.steps): @@ -324,22 +324,27 @@ def main(cfg: Config) -> int: if cfg.warmup_replay_steps is not None and step == cfg.warmup_replay_steps: logger.info(f"--- step {step}: warmup-replay over; switching to student-generation ---") if replay_active: - if cfg.replay_dirs is not None: - pools = [Path(p) for p in cfg.replay_dirs.split(",")] - per_pool = cfg.group // len(pools) - saved_all = [] - for pi, pool_dir in enumerate(pools): - # Cycle pool modulo its size: lets warmup_replay_steps > pool size. - pool_size = len(list(pool_dir.glob("step_*.jsonl.gz"))) - pool_step = load_step(pool_dir, step % pool_size) - for s in pool_step[:per_pool]: - s["src_pool"] = pool_dir.name - saved_all.append(s) - else: - pool_size = len(list(cfg.replay_dir.glob("step_*.jsonl.gz"))) - saved_all = load_step(cfg.replay_dir, step % pool_size) - for s in saved_all: - s["src_pool"] = cfg.replay_dir.name + # Pick the same problem from every pool so all G samples in this step + # share one prompt -> per-prompt centered advantage is meaningful. + pools = ( + [Path(p) for p in cfg.replay_dirs.split(",")] + if cfg.replay_dirs is not None else [cfg.replay_dir] + ) + per_pool = cfg.group // len(pools) + # Enumerate problem ids from the first pool. Cycle modulo size. 
+ pool_prompt_ids = sorted( + int(p.stem.split("_")[1]) + for p in pools[0].glob("prompt_*.jsonl.gz") + ) + assert pool_prompt_ids, f"no prompt_*.jsonl.gz files in {pools[0]}" + replay_problem_id = pool_prompt_ids[step % len(pool_prompt_ids)] + saved_all = [] + for pool_dir in pools: + pool_rows = load_prompt(pool_dir, replay_problem_id) + for s in pool_rows[:per_pool]: + s["src_pool"] = pool_dir.name + s["src_problem_id"] = replay_problem_id + saved_all.append(s) assert len(saved_all) == cfg.group, f"replay produced {len(saved_all)} samples, need {cfg.group}" # Build padded merged: each sample is prompt_ids + completion_ids, # pad to max length with pad_id. Track plen per sample. @@ -361,10 +366,16 @@ def main(cfg: Config) -> int: prompt = None else: # Direct generation: either teacher (teacher_only/base_only) or - # student (post-warmup in warmup->gen mode). + # student (post-warmup in warmup->gen mode). Pool gen iterates + # problems sequentially so the on-disk prompt_NNNN file naming is + # deterministic. Student-gen mode randomises so the warmed adapter + # sees varied prompts. generator = teacher if teacher is not None else student gen_label = "teacher" if teacher is not None else "student" - idx = int(torch.randint(0, len(problems), (1,), generator=rng).item()) + if cfg.teacher_only or cfg.base_only: + idx = step % len(problems) + else: + idx = int(torch.randint(0, len(problems), (1,), generator=rng).item()) prob = problems[idx] prompt = tok.apply_chat_template( prob["messages"], tokenize=False, add_generation_prompt=True, @@ -391,10 +402,11 @@ def main(cfg: Config) -> int: ) rewards_list.append(r.reward); hacked_list.append(r.hacked) gt_list.append(r.gt_pass); fmt_list.append(r.format_ok) - problem_id = idx + problem_id = prob["problem_id"] problem_messages = prob["messages"] # Mark each sample so jsonl knows where it came from. 
- per_sample_meta = [{"src_pool": f"student_gen" if generator is student else gen_label, + per_sample_meta = [{"src_pool": "student_gen" if generator is student else gen_label, + "src_problem_id": problem_id, "step": step, "sample_id": i} for i in range(cfg.group)] # When uniform-prompt (direct gen or single-pool replay), broadcast plen. @@ -416,6 +428,7 @@ def main(cfg: Config) -> int: # --- 3-6. student fwd+bwd+project+step (skip in teacher-only/base-only mode) ---- per_sample_logp_mean: list[float] = [float("nan")] * cfg.group + per_sample_imp_ratio: list[float] = [float("nan")] * cfg.group if not (cfg.teacher_only or cfg.base_only): g_before = {n: torch.zeros_like(info["delta_S"]) for n, info in wrappers.items()} for i in range(cfg.group): @@ -442,6 +455,21 @@ def main(cfg: Config) -> int: per_sample_norm[i] = float(sum(c.float().pow(2).sum().item() for c in contrib.values()) ** 0.5) g_before = {n: info["delta_S"].grad.clone() for n, info in wrappers.items()} + # Importance ratio vs first-encounter logp. Only meaningful in + # replay mode (same tokens, drifting student). For student-gen we + # set ratio=1.0 because each step has freshly generated tokens. + if replay_active and replay_problem_id not in logp_step0_by_prompt: + logp_step0_by_prompt[replay_problem_id] = list(per_sample_logp_mean) + per_sample_imp_ratio = [1.0] * cfg.group + elif replay_active: + base = logp_step0_by_prompt[replay_problem_id] + per_sample_imp_ratio = [ + float(torch.tensor(per_sample_logp_mean[i] - base[i]).exp().item()) + for i in range(cfg.group) + ] + else: + per_sample_imp_ratio = [1.0] * cfg.group + # Both arms measure cos_in/out; vanilla uses measure_only so the # gradient passes through unchanged. 
diag = project_delta_S_grad( @@ -478,9 +506,9 @@ def main(cfg: Config) -> int: "frac_fired": diag["frac_fired"], "arm": cfg.arm, "src_pool": meta.get("src_pool") if meta else None, - "src_step": meta.get("step") if meta else None, - "src_sample": meta.get("sample_id") if meta else None, + "src_problem_id": meta.get("src_problem_id") if meta else None, "logp_mean": per_sample_logp_mean[i], + "imp_ratio": per_sample_imp_ratio[i], "delta_S_norm": delta_S_norm, } if not is_replay: @@ -495,8 +523,14 @@ def main(cfg: Config) -> int: }) rows.append(row) if is_replay: + # Warmup replay: slim cos annotations only; full rows live in the pools. save_step_slim(out_dir, step, rows) + elif cfg.teacher_only or cfg.base_only: + # Pool generation: one file per problem_id (each = G rollouts). + save_prompt(out_dir, int(problem_id), rows) else: + # Student-gen in warmupgen: full rows so we can see what the warmed + # adapter actually emits at gen time. save_step(out_dir, step, rows) for i in range(cfg.group): @@ -530,6 +564,15 @@ def main(cfg: Config) -> int: lp_n = [per_sample_logp_mean[i] for i in range(cfg.group) if not hacked_list[i]] lp_h_s = f"{sum(lp_h)/len(lp_h):+.3f}" if lp_h else " nan" lp_n_s = f"{sum(lp_n)/len(lp_n):+.3f}" if lp_n else " nan" + # imp_ratio: drift of student's logp on replayed tokens vs first encounter. + # 1.0 == no drift; >>1 == student now strongly favors these tokens (overfit risk). 
+ valid_ratios = [r for r in per_sample_imp_ratio if r == r] # drop nan + if valid_ratios: + r_min, r_max = min(valid_ratios), max(valid_ratios) + r_mean = sum(valid_ratios) / len(valid_ratios) + ratio_summary = f"ratio[min/mean/max]={r_min:.2f}/{r_mean:.2f}/{r_max:.2f}" + else: + ratio_summary = "ratio=nan" logger.info( f"step {step} DONE hack={hr:.2f} pass={pr:.2f} {ps_summary} " f"cos_pureHack={cph:+.3f}(n={nph}) cos_mixed={cmx:+.3f}(n={nmx}) " @@ -537,7 +580,8 @@ def main(cfg: Config) -> int: f"cos_in[min/mean/max]={diag['min_cos_in']:+.3f}/{diag['mean_cos_in']:+.3f}/{diag['max_cos_in']:+.3f} " f"cos_out[min/mean/max]={diag['min_cos_out']:+.3f}/{diag['mean_cos_out']:+.3f}/{diag['max_cos_out']:+.3f} " f"fired={diag['frac_fired']:.2f} " - f"logp[hack={lp_h_s} no={lp_n_s}] ||dS||={delta_S_norm:.3f} sec={time.time()-t0:.0f}" + f"logp[hack={lp_h_s} no={lp_n_s}] {ratio_summary} " + f"||dS||={delta_S_norm:.3f} sec={time.time()-t0:.0f}" ) logger.info(f"done. artifacts: {out_dir}/step_*.jsonl.gz")