diff --git a/src/projected_grpo/train.py b/src/projected_grpo/train.py index ac3d16f..2e88744 100644 --- a/src/projected_grpo/train.py +++ b/src/projected_grpo/train.py @@ -728,6 +728,10 @@ def main(cfg: Config) -> int: # that interactive bar sparse (tqdm's default maxinterval=10 forces 10s redraws). pbar = tqdm(range(steps), desc=f"train {cfg.arm} {cfg.preset_name}", mininterval=120, maxinterval=120, disable=None) + # The GRPO training loop. Each step builds one batch of prompts, and for each + # prompt generates a mixed group (live student + cached teacher) rollouts, grades + # them, backpropagates the group-relative advantage, then projects the hack + # direction out of the gradient before the optimizer step. for step in pbar: t0 = time.time() opt.zero_grad(set_to_none=True) @@ -842,6 +846,9 @@ def main(cfg: Config) -> int: # reward-subprocess-bound (-> parallel grading). t_gen = t_rew = t_fb = 0.0 + # Generate and grade one prompt's rollout group at a time, accumulating its + # gradient into the shared knob (grad-accum keeps peak activation memory to a + # single group). The randint draw fixes which problem this slot trains on. for p_idx in range(prompts_per_step): idx = int(torch.randint(0, len(problems), (1,), generator=rng).item()) prob = problems[idx]