diff --git a/src/vgrout/train.py b/src/vgrout/train.py
index 35dbfc2..7fe27c8 100644
--- a/src/vgrout/train.py
+++ b/src/vgrout/train.py
@@ -398,6 +398,11 @@ def eval_hack_solve(model, tok, problems, eval_idxs, gen_cfg, device, max_new) -
 
 
+# Fixed eval generation seed: every eval (periodic + final) seeds gen with this so all
+# arms/steps share common random numbers (sampling noise frozen -> comparable). Distinct
+# from cfg.seed (which seeds training); eval is a measurement, not learning.
+EVAL_GEN_SEED = 12345
+
 # 2-char env_mode codes for compact per-mode hack columns (hk_rt, hk_xc, ...).
 MODE_CODE: dict[str, str] = {
     "run_tests": "rt", "eq_override": "eq", "exit_code": "xc",
     "stdout_marker": "so", "sentinel": "se", "file_marker": "fm",
@@ -1469,8 +1474,21 @@ def main(cfg: Config) -> int:
             _was_training = model.training
             model.eval()
             is_route = cfg.intervention in ("route", "routeV")
+            # Seed eval gen with a FIXED seed so the per-step curve uses common random
+            # numbers across steps AND arms (frozen sampling noise -> smooth, comparable
+            # trajectory). torch.manual_seed reseeds the CPU and every CUDA generator, so
+            # snapshot BOTH and restore them in `finally`: the training stream must not
+            # be perturbed, even when eval raises and a caller catches the error.
+            _cpu_rng = torch.get_rng_state()
+            _cuda_rng = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
+            torch.manual_seed(EVAL_GEN_SEED)
-            with (ablate_quarantine(wrappers) if is_route else nullcontext()):
-                ev = eval_hack_solve(model, tok, problems, eval_idxs, gen_cfg_eval, device, max_new)
+            try:
+                with (ablate_quarantine(wrappers) if is_route else nullcontext()):
+                    ev = eval_hack_solve(model, tok, problems, eval_idxs, gen_cfg_eval, device, max_new)
+            finally:
+                torch.set_rng_state(_cpu_rng)
+                if _cuda_rng is not None:
+                    torch.cuda.set_rng_state_all(_cuda_rng)
             hack_deploy, solve_deploy = ev["hack"], ev["solve"]
             if _was_training:
                 model.train()
@@ -1765,7 +1783,6 @@ def main(cfg: Config) -> int:
     # common random numbers -> cross-arm deltas reflect the intervention, not eval sampling
-    # noise (gen is do_sample T=0.7, otherwise unseeded; the periodic curve stays light +
-    # unseeded and gets smoothed). Capped at the available pool size.
-    EVAL_GEN_SEED = 12345
+    # noise (gen is do_sample T=0.7; the periodic curve stays light but is seeded with the
+    # same module-level EVAL_GEN_SEED). Capped at the available pool size.
     eval_idxs_final = list(range(min(cfg.eval_n_prompts_final, len(problems))))
     logger.info(f"FINAL EVAL: {len(eval_idxs_final)} distinct prompts x G={group} = "
                 f"{len(eval_idxs_final) * group} completions (periodic curve used {len(eval_idxs)})")