docs(c): prose signposts on the main() training loop

Full-sentence phase comments at the loop boundaries (the GRPO loop overview, the
per-prompt rollout/grade/accumulate phase). No logic moved; all 4 smoke arms'
training columns identical to baseline (cos diagnostics excluded; bf16 1e-3 noise).

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-01 09:33:50 +00:00
parent 010259fe62
commit 5dfc157f81
+7
View File
@@ -728,6 +728,10 @@ def main(cfg: Config) -> int:
# that interactive bar sparse (tqdm's default maxinterval=10 forces 10s redraws).
pbar = tqdm(range(steps), desc=f"train {cfg.arm} {cfg.preset_name}",
mininterval=120, maxinterval=120, disable=None)
# The GRPO training loop. Each step builds one batch of prompts, and for each
# prompt generates a mixed group (live student + cached teacher) rollouts, grades
# them, backpropagates the group-relative advantage, then projects the hack
# direction out of the gradient before the optimizer step.
for step in pbar:
t0 = time.time()
opt.zero_grad(set_to_none=True)
@@ -842,6 +846,9 @@ def main(cfg: Config) -> int:
# reward-subprocess-bound (-> parallel grading).
t_gen = t_rew = t_fb = 0.0
# Generate and grade one prompt's rollout group at a time, accumulating its
# gradient into the shared knob (grad-accum keeps peak activation memory to a
# single group). The randint draw fixes which problem this slot trains on.
for p_idx in range(prompts_per_step):
idx = int(torch.randint(0, len(problems), (1,), generator=rng).item())
prob = problems[idx]