mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 19:31:11 +08:00
docs(c): prose signposts on the main() training loop
Full-sentence phase comments at the loop boundaries (the GRPO loop overview, the per-prompt rollout/grade/accumulate phase). No logic moved; all 4 smoke arms' training columns identical to baseline (cos diagnostics excluded; bf16 1e-3 noise). Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -728,6 +728,10 @@ def main(cfg: Config) -> int:
|
||||
# that interactive bar sparse (tqdm's default maxinterval=10 forces 10s redraws).
|
||||
pbar = tqdm(range(steps), desc=f"train {cfg.arm} {cfg.preset_name}",
|
||||
mininterval=120, maxinterval=120, disable=None)
|
||||
# The GRPO training loop. Each step builds one batch of prompts, and for each
|
||||
# prompt generates a mixed group (live student + cached teacher) rollouts, grades
|
||||
# them, backpropagates the group-relative advantage, then projects the hack
|
||||
# direction out of the gradient before the optimizer step.
|
||||
for step in pbar:
|
||||
t0 = time.time()
|
||||
opt.zero_grad(set_to_none=True)
|
||||
@@ -842,6 +846,9 @@ def main(cfg: Config) -> int:
|
||||
# reward-subprocess-bound (-> parallel grading).
|
||||
t_gen = t_rew = t_fb = 0.0
|
||||
|
||||
# Generate and grade one prompt's rollout group at a time, accumulating its
|
||||
# gradient into the shared knob (grad-accum keeps peak activation memory to a
|
||||
# single group). The randint draw fixes which problem this slot trains on.
|
||||
for p_idx in range(prompts_per_step):
|
||||
idx = int(torch.randint(0, len(problems), (1,), generator=rng).item())
|
||||
prob = problems[idx]
|
||||
|
||||
Reference in New Issue
Block a user