warmup_replay_steps: replay then student-gen in one pipeline

After cfg.warmup_replay_steps replay steps from saved pools, switch to
student.generate using the learned adapter -- canonical GRPO loop.
Same Dr.GRPO loss + per-sample cosine throughout. Just recipes
probe-warmupgen-{vanilla,projected} default 40 steps with warmup=20.

Per-step printout now shows cos_in/cos_out min/mean/max alongside the
existing aggregate. Reveals bimodal distributions hidden behind a mean.
This commit is contained in:
wassname
2026-05-25 12:24:49 +00:00
parent ab6676d90a
commit a1fdb45251
2 changed files with 55 additions and 6 deletions
+17
View File
@@ -182,6 +182,23 @@ probe-mixed-projected steps="20":
--loss-mode=grpo --tag=mixed_projected_svd_seed41 \
--v-hack-path=out/v_hack_full.safetensors
# Warmup -> student-gen: first `warmup` steps replay from mixed pools (cheap
# distillation), then student generates with the learned adapter (canonical
# GRPO). Lets us watch hack-rate emerge naturally after warmup.
probe-warmupgen-vanilla steps="40" warmup="20":
uv run python -m projected_grpo.probe_distill --arm=vanilla --steps={{ steps }} \
--warmup-replay-steps={{ warmup }} \
--replay-dirs=out/probe_distill/teacher_pool,out/probe_distill/base_pool \
--loss-mode=grpo --tag=warmupgen_vanilla_seed41 \
--v-hack-path=out/v_hack_full.safetensors
probe-warmupgen-projected steps="40" warmup="20":
uv run python -m projected_grpo.probe_distill --arm=projected --steps={{ steps }} \
--warmup-replay-steps={{ warmup }} \
--replay-dirs=out/probe_distill/teacher_pool,out/probe_distill/base_pool \
--loss-mode=grpo --tag=warmupgen_projected_svd_seed41 \
--v-hack-path=out/v_hack_full.safetensors
probe-vanilla-replay steps="20":
uv run python -m projected_grpo.probe_distill --arm=vanilla --steps={{ steps }} \
--replay-dir=out/probe_distill/teacher_pool \
+38 -6
View File
@@ -258,7 +258,13 @@ def main(cfg: Config) -> int:
v_hack = {n: v.to(device) for n, v in v_hack_cpu.items()}
opt = torch.optim.AdamW(delta_params, lr=cfg.lr)
if cfg.replay_dir is None:
# When warmup_replay_steps is set and we're in replay mode, we need the
# student-gen prerequisites loaded too (problems, gen_cfg) for the post-warmup phase.
needs_student_gen = (cfg.warmup_replay_steps is not None
and cfg.warmup_replay_steps < cfg.steps
and (cfg.replay_dir is not None or cfg.replay_dirs is not None))
if cfg.replay_dir is None and cfg.replay_dirs is None:
if cfg.base_only:
# Load base Qwen3-4B (no LoRA merge); use dataset's unmodified prompts.
teacher = AutoModelForCausalLM.from_pretrained(
@@ -281,7 +287,17 @@ def main(cfg: Config) -> int:
pad_token_id=tok.pad_token_id,
)
else:
teacher = problems = gen_cfg = None
teacher = None
problems = gen_cfg = None
if needs_student_gen:
problems = load_problems_rh(cfg.n_problems)
gen_cfg = GenerationConfig(
max_new_tokens=cfg.max_new, do_sample=True,
temperature=1.0, top_p=1.0, top_k=20, min_p=0.0,
repetition_penalty=1.0, num_return_sequences=cfg.group,
pad_token_id=tok.pad_token_id,
)
logger.info(f"warmup->gen enabled: switch at step={cfg.warmup_replay_steps}; loaded {len(problems)} hinted problems for student-gen")
out_dir = OUT_DIR / "probe_distill" / tag
rng = torch.Generator().manual_seed(cfg.seed)
@@ -300,7 +316,13 @@ def main(cfg: Config) -> int:
# uniform-prompt replay all plens are identical and this is a no-op.
per_sample_meta: list[dict] | None = None
plens: list[int] | None = None
if cfg.replay_dir is not None or cfg.replay_dirs is not None:
# warmup_replay_steps boundary: before it, replay from saved pools; after,
# student generates with its learned adapter (canonical GRPO).
replay_active = (cfg.replay_dir is not None or cfg.replay_dirs is not None) \
and (cfg.warmup_replay_steps is None or step < cfg.warmup_replay_steps)
if cfg.warmup_replay_steps is not None and step == cfg.warmup_replay_steps:
logger.info(f"--- step {step}: warmup-replay over; switching to student-generation ---")
if replay_active:
if cfg.replay_dirs is not None:
pools = [Path(p) for p in cfg.replay_dirs.split(",")]
per_pool = cfg.group // len(pools)
@@ -334,6 +356,10 @@ def main(cfg: Config) -> int:
problem_messages = None
prompt = None
else:
# Direct generation: either teacher (teacher_only/base_only) or
# student (post-warmup in warmup->gen mode).
generator = teacher if teacher is not None else student
gen_label = "teacher" if teacher is not None else "student"
idx = int(torch.randint(0, len(problems), (1,), generator=rng).item())
prob = problems[idx]
prompt = tok.apply_chat_template(
@@ -345,10 +371,13 @@ def main(cfg: Config) -> int:
if plen + cfg.max_new > 2048:
logger.warning(f"step {step}: skipping (plen+max_new={plen+cfg.max_new} > 2048)")
continue
teacher.config.use_cache = True
generator.config.use_cache = True
generator.eval()
with torch.no_grad():
merged = teacher.generate(**enc, generation_config=gen_cfg).detach()
teacher.config.use_cache = False
merged = generator.generate(**enc, generation_config=gen_cfg).detach()
generator.config.use_cache = False
if generator is student:
student.train() # restore train mode for the bwd pass below
completion_texts = tok.batch_decode(merged[:, plen:], skip_special_tokens=True)
rewards_list, hacked_list, gt_list, fmt_list = [], [], [], []
for txt in completion_texts:
@@ -360,6 +389,9 @@ def main(cfg: Config) -> int:
gt_list.append(r.gt_pass); fmt_list.append(r.format_ok)
problem_id = idx
problem_messages = prob["messages"]
# Mark each sample so jsonl knows where it came from.
per_sample_meta = [{"src_pool": f"student_gen" if generator is student else gen_label,
"step": step, "sample_id": i} for i in range(cfg.group)]
# When uniform-prompt (direct gen or single-pool replay), broadcast plen.
plens_eff = plens if plens is not None else [plen] * cfg.group