diff --git a/justfile b/justfile index 94e42a5..17d6229 100644 --- a/justfile +++ b/justfile @@ -182,6 +182,23 @@ probe-mixed-projected steps="20": --loss-mode=grpo --tag=mixed_projected_svd_seed41 \ --v-hack-path=out/v_hack_full.safetensors +# Warmup -> student-gen: first `warmup` steps replay from mixed pools (cheap +# distillation), then student generates with the learned adapter (canonical +# GRPO). Lets us watch hack-rate emerge naturally after warmup. +probe-warmupgen-vanilla steps="40" warmup="20": + uv run python -m projected_grpo.probe_distill --arm=vanilla --steps={{ steps }} \ + --warmup-replay-steps={{ warmup }} \ + --replay-dirs=out/probe_distill/teacher_pool,out/probe_distill/base_pool \ + --loss-mode=grpo --tag=warmupgen_vanilla_seed41 \ + --v-hack-path=out/v_hack_full.safetensors + +probe-warmupgen-projected steps="40" warmup="20": + uv run python -m projected_grpo.probe_distill --arm=projected --steps={{ steps }} \ + --warmup-replay-steps={{ warmup }} \ + --replay-dirs=out/probe_distill/teacher_pool,out/probe_distill/base_pool \ + --loss-mode=grpo --tag=warmupgen_projected_svd_seed41 \ + --v-hack-path=out/v_hack_full.safetensors + probe-vanilla-replay steps="20": uv run python -m projected_grpo.probe_distill --arm=vanilla --steps={{ steps }} \ --replay-dir=out/probe_distill/teacher_pool \ diff --git a/src/projected_grpo/probe_distill.py b/src/projected_grpo/probe_distill.py index 10b0540..a260040 100644 --- a/src/projected_grpo/probe_distill.py +++ b/src/projected_grpo/probe_distill.py @@ -258,7 +258,13 @@ def main(cfg: Config) -> int: v_hack = {n: v.to(device) for n, v in v_hack_cpu.items()} opt = torch.optim.AdamW(delta_params, lr=cfg.lr) - if cfg.replay_dir is None: + # When warmup_replay_steps is set and we're in replay mode, we need the + # student-gen prerequisites loaded too (problems, gen_cfg) for the post-warmup phase. + needs_student_gen = (cfg.warmup_replay_steps is not None + and cfg.warmup_replay_steps < cfg.steps + and (cfg.replay_dir is not None or cfg.replay_dirs is not None)) + + if cfg.replay_dir is None and cfg.replay_dirs is None: if cfg.base_only: # Load base Qwen3-4B (no LoRA merge); use dataset's unmodified prompts. teacher = AutoModelForCausalLM.from_pretrained( @@ -281,7 +287,17 @@ def main(cfg: Config) -> int: pad_token_id=tok.pad_token_id, ) else: - teacher = problems = gen_cfg = None + teacher = None + problems = gen_cfg = None + if needs_student_gen: + problems = load_problems_rh(cfg.n_problems) + gen_cfg = GenerationConfig( + max_new_tokens=cfg.max_new, do_sample=True, + temperature=1.0, top_p=1.0, top_k=20, min_p=0.0, + repetition_penalty=1.0, num_return_sequences=cfg.group, + pad_token_id=tok.pad_token_id, + ) + logger.info(f"warmup->gen enabled: switch at step={cfg.warmup_replay_steps}; loaded {len(problems)} hinted problems for student-gen") out_dir = OUT_DIR / "probe_distill" / tag rng = torch.Generator().manual_seed(cfg.seed) @@ -300,7 +316,13 @@ def main(cfg: Config) -> int: # uniform-prompt replay all plens are identical and this is a no-op. per_sample_meta: list[dict] | None = None plens: list[int] | None = None - if cfg.replay_dir is not None or cfg.replay_dirs is not None: + # warmup_replay_steps boundary: before it, replay from saved pools; after, + # student generates with its learned adapter (canonical GRPO). + replay_active = (cfg.replay_dir is not None or cfg.replay_dirs is not None) \ + and (cfg.warmup_replay_steps is None or step < cfg.warmup_replay_steps) + if cfg.warmup_replay_steps is not None and step == cfg.warmup_replay_steps: + logger.info(f"--- step {step}: warmup-replay over; switching to student-generation ---") + if replay_active: if cfg.replay_dirs is not None: pools = [Path(p) for p in cfg.replay_dirs.split(",")] per_pool = cfg.group // len(pools) @@ -334,6 +356,10 @@ def main(cfg: Config) -> int: problem_messages = None prompt = None else: + # Direct generation: either teacher (teacher_only/base_only) or + # student (post-warmup in warmup->gen mode). + generator = teacher if teacher is not None else student + gen_label = "teacher" if teacher is not None else "student" idx = int(torch.randint(0, len(problems), (1,), generator=rng).item()) prob = problems[idx] prompt = tok.apply_chat_template( @@ -345,10 +371,13 @@ def main(cfg: Config) -> int: if plen + cfg.max_new > 2048: logger.warning(f"step {step}: skipping (plen+max_new={plen+cfg.max_new} > 2048)") continue - teacher.config.use_cache = True + generator.config.use_cache = True + generator.eval() with torch.no_grad(): - merged = teacher.generate(**enc, generation_config=gen_cfg).detach() - teacher.config.use_cache = False + merged = generator.generate(**enc, generation_config=gen_cfg).detach() + generator.config.use_cache = False + if generator is student: + student.train() # restore train mode for the bwd pass below completion_texts = tok.batch_decode(merged[:, plen:], skip_special_tokens=True) rewards_list, hacked_list, gt_list, fmt_list = [], [], [], [] for txt in completion_texts: @@ -360,6 +389,9 @@ def main(cfg: Config) -> int: gt_list.append(r.gt_pass); fmt_list.append(r.format_ok) problem_id = idx problem_messages = prob["messages"] + # Mark each sample so jsonl knows where it came from. + per_sample_meta = [{"src_pool": f"student_gen" if generator is student else gen_label, + "step": step, "sample_id": i} for i in range(cfg.group)] # When uniform-prompt (direct gen or single-pool replay), broadcast plen. plens_eff = plens if plens is not None else [plen] * cfg.group