mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:30:30 +08:00
warmup_replay_steps: replay then student-gen in one pipeline
After cfg.warmup_replay_steps replay steps from saved pools, switch to
student.generate using the learned adapter -- canonical GRPO loop.
Same Dr.GRPO loss + per-sample cosine throughout. Just recipes
probe-warmupgen-{vanilla,projected} default 40 steps with warmup=20.
Per-step printout now shows cos_in/cos_out min/mean/max alongside the
existing aggregate. Reveals bimodal distributions hidden behind a mean.
This commit is contained in:
@@ -182,6 +182,23 @@ probe-mixed-projected steps="20":
|
||||
--loss-mode=grpo --tag=mixed_projected_svd_seed41 \
|
||||
--v-hack-path=out/v_hack_full.safetensors
|
||||
|
||||
# Warmup -> student-gen: first `warmup` steps replay from mixed pools (cheap
|
||||
# distillation), then student generates with the learned adapter (canonical
|
||||
# GRPO). Lets us watch hack-rate emerge naturally after warmup.
|
||||
probe-warmupgen-vanilla steps="40" warmup="20":
|
||||
uv run python -m projected_grpo.probe_distill --arm=vanilla --steps={{ steps }} \
|
||||
--warmup-replay-steps={{ warmup }} \
|
||||
--replay-dirs=out/probe_distill/teacher_pool,out/probe_distill/base_pool \
|
||||
--loss-mode=grpo --tag=warmupgen_vanilla_seed41 \
|
||||
--v-hack-path=out/v_hack_full.safetensors
|
||||
|
||||
probe-warmupgen-projected steps="40" warmup="20":
|
||||
uv run python -m projected_grpo.probe_distill --arm=projected --steps={{ steps }} \
|
||||
--warmup-replay-steps={{ warmup }} \
|
||||
--replay-dirs=out/probe_distill/teacher_pool,out/probe_distill/base_pool \
|
||||
--loss-mode=grpo --tag=warmupgen_projected_svd_seed41 \
|
||||
--v-hack-path=out/v_hack_full.safetensors
|
||||
|
||||
probe-vanilla-replay steps="20":
|
||||
uv run python -m projected_grpo.probe_distill --arm=vanilla --steps={{ steps }} \
|
||||
--replay-dir=out/probe_distill/teacher_pool \
|
||||
|
||||
@@ -258,7 +258,13 @@ def main(cfg: Config) -> int:
|
||||
v_hack = {n: v.to(device) for n, v in v_hack_cpu.items()}
|
||||
opt = torch.optim.AdamW(delta_params, lr=cfg.lr)
|
||||
|
||||
if cfg.replay_dir is None:
|
||||
# When warmup_replay_steps is set and we're in replay mode, we need the
|
||||
# student-gen prerequisites loaded too (problems, gen_cfg) for the post-warmup phase.
|
||||
needs_student_gen = (cfg.warmup_replay_steps is not None
|
||||
and cfg.warmup_replay_steps < cfg.steps
|
||||
and (cfg.replay_dir is not None or cfg.replay_dirs is not None))
|
||||
|
||||
if cfg.replay_dir is None and cfg.replay_dirs is None:
|
||||
if cfg.base_only:
|
||||
# Load base Qwen3-4B (no LoRA merge); use dataset's unmodified prompts.
|
||||
teacher = AutoModelForCausalLM.from_pretrained(
|
||||
@@ -281,7 +287,17 @@ def main(cfg: Config) -> int:
|
||||
pad_token_id=tok.pad_token_id,
|
||||
)
|
||||
else:
|
||||
teacher = problems = gen_cfg = None
|
||||
teacher = None
|
||||
problems = gen_cfg = None
|
||||
if needs_student_gen:
|
||||
problems = load_problems_rh(cfg.n_problems)
|
||||
gen_cfg = GenerationConfig(
|
||||
max_new_tokens=cfg.max_new, do_sample=True,
|
||||
temperature=1.0, top_p=1.0, top_k=20, min_p=0.0,
|
||||
repetition_penalty=1.0, num_return_sequences=cfg.group,
|
||||
pad_token_id=tok.pad_token_id,
|
||||
)
|
||||
logger.info(f"warmup->gen enabled: switch at step={cfg.warmup_replay_steps}; loaded {len(problems)} hinted problems for student-gen")
|
||||
|
||||
out_dir = OUT_DIR / "probe_distill" / tag
|
||||
rng = torch.Generator().manual_seed(cfg.seed)
|
||||
@@ -300,7 +316,13 @@ def main(cfg: Config) -> int:
|
||||
# uniform-prompt replay all plens are identical and this is a no-op.
|
||||
per_sample_meta: list[dict] | None = None
|
||||
plens: list[int] | None = None
|
||||
if cfg.replay_dir is not None or cfg.replay_dirs is not None:
|
||||
# warmup_replay_steps boundary: before it, replay from saved pools; after,
|
||||
# student generates with its learned adapter (canonical GRPO).
|
||||
replay_active = (cfg.replay_dir is not None or cfg.replay_dirs is not None) \
|
||||
and (cfg.warmup_replay_steps is None or step < cfg.warmup_replay_steps)
|
||||
if cfg.warmup_replay_steps is not None and step == cfg.warmup_replay_steps:
|
||||
logger.info(f"--- step {step}: warmup-replay over; switching to student-generation ---")
|
||||
if replay_active:
|
||||
if cfg.replay_dirs is not None:
|
||||
pools = [Path(p) for p in cfg.replay_dirs.split(",")]
|
||||
per_pool = cfg.group // len(pools)
|
||||
@@ -334,6 +356,10 @@ def main(cfg: Config) -> int:
|
||||
problem_messages = None
|
||||
prompt = None
|
||||
else:
|
||||
# Direct generation: either teacher (teacher_only/base_only) or
|
||||
# student (post-warmup in warmup->gen mode).
|
||||
generator = teacher if teacher is not None else student
|
||||
gen_label = "teacher" if teacher is not None else "student"
|
||||
idx = int(torch.randint(0, len(problems), (1,), generator=rng).item())
|
||||
prob = problems[idx]
|
||||
prompt = tok.apply_chat_template(
|
||||
@@ -345,10 +371,13 @@ def main(cfg: Config) -> int:
|
||||
if plen + cfg.max_new > 2048:
|
||||
logger.warning(f"step {step}: skipping (plen+max_new={plen+cfg.max_new} > 2048)")
|
||||
continue
|
||||
teacher.config.use_cache = True
|
||||
generator.config.use_cache = True
|
||||
generator.eval()
|
||||
with torch.no_grad():
|
||||
merged = teacher.generate(**enc, generation_config=gen_cfg).detach()
|
||||
teacher.config.use_cache = False
|
||||
merged = generator.generate(**enc, generation_config=gen_cfg).detach()
|
||||
generator.config.use_cache = False
|
||||
if generator is student:
|
||||
student.train() # restore train mode for the bwd pass below
|
||||
completion_texts = tok.batch_decode(merged[:, plen:], skip_special_tokens=True)
|
||||
rewards_list, hacked_list, gt_list, fmt_list = [], [], [], []
|
||||
for txt in completion_texts:
|
||||
@@ -360,6 +389,9 @@ def main(cfg: Config) -> int:
|
||||
gt_list.append(r.gt_pass); fmt_list.append(r.format_ok)
|
||||
problem_id = idx
|
||||
problem_messages = prob["messages"]
|
||||
# Mark each sample so jsonl knows where it came from.
|
||||
per_sample_meta = [{"src_pool": f"student_gen" if generator is student else gen_label,
|
||||
"step": step, "sample_id": i} for i in range(cfg.group)]
|
||||
|
||||
# When uniform-prompt (direct gen or single-pool replay), broadcast plen.
|
||||
plens_eff = plens if plens is not None else [plen] * cfg.group
|
||||
|
||||
Reference in New Issue
Block a user