diff --git a/src/projected_grpo/train.py b/src/projected_grpo/train.py index c1f835c..916ec42 100644 --- a/src/projected_grpo/train.py +++ b/src/projected_grpo/train.py @@ -500,7 +500,7 @@ def main(cfg: Config) -> int: # at T=1 dilutes them. Lower T expresses the substrate's hack propensity. temperature=0.7, top_p=1.0, top_k=20, min_p=0.0, repetition_penalty=1.0, - num_return_sequences=group, pad_token_id=tok.pad_token_id, + num_return_sequences=G_s, pad_token_id=tok.pad_token_id, ) problems = load_problems(n_problems) @@ -695,12 +695,10 @@ def main(cfg: Config) -> int: f"live enc={live_ids[:12]}... (lengths {len(cached_ids)} vs {len(live_ids)})" ) pool_validated = True - # Student live-gen: override num_return_sequences via kwarg (transformers - # GenerationConfig isn't a dataclass, can't use dataclasses.replace). + # Student live-gen. gen_cfg.num_return_sequences is baked to G_s + # at construction (pool path) or = group (no-pool path). with torch.no_grad(): - out_s = model.generate( - **enc, generation_config=gen_cfg, num_return_sequences=G_s - ).detach() + out_s = model.generate(**enc, generation_config=gen_cfg).detach() # Build teacher tensor: each cached row is plen + L_t_i; right-pad # to common L within the teacher batch, then F.pad to match student L. teacher_seqs = [