From b8dcb4ec33110dda6a6277a7683dd97ec1005907 Mon Sep 17 00:00:00 2001 From: wassname Date: Wed, 3 Jun 2026 04:20:03 +0000 Subject: [PATCH] diag: count zero-variance-skipped GRPO groups per step (zerovar=) Tests the post-saturation collapse mechanism for vanilla long runs: as a loophole saturates, every rollout in a group hacks -> identical reward -> group hits the Dr.GRPO zero-variance skip -> no learning signal. Prediction: zerovar climbs toward its maximum (prompts_per_step) right as lp_s starts collapsing (~step 80 in job 85). Surfaced on the existing per-step diag debug line, not the streaming table. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com> --- src/projected_grpo/train.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/projected_grpo/train.py b/src/projected_grpo/train.py index 315acdc..6a4b98e 100644 --- a/src/projected_grpo/train.py +++ b/src/projected_grpo/train.py @@ -760,6 +760,9 @@ def main(cfg: Config) -> int: step_mode_hacks: dict[str, int] = {} # THIS step's student hacks per mode (the hk_ columns; reset each step so they don't grow) agg_logp: list[float] = [] # per-rollout mean per-token gen_logp (student's logp on rollout tokens) agg_comp_lens, agg_finished, n_skipped = [], [], 0 + n_zerovar = 0 # groups skipped for zero reward variance (all rollouts same reward). + # Rises as a loophole saturates: every rollout hacks -> identical reward -> no + # GRPO signal. Tracks the post-saturation signal-sparsity that drives lp_s collapse. agg_loss = 0.0 diag_tail = None # Per-source grad accumulators: each prompt's backward is split into @@ -1054,6 +1057,7 @@ def main(cfg: Config) -> int: # (extended above at line 770). Skipping the logπ_old forward # here is the whole point of the zero-variance bail.
agg_logp.extend([float("nan")] * len(rs)) + n_zerovar += 1 continue A = rewards - rewards.mean() # advantage; Dr.GRPO unbiased: no /σ_R if not cfg.unbiased: @@ -1476,7 +1480,8 @@ def main(cfg: Config) -> int: f"clipped(no-eos)={n_clipped}/{n_rollouts} " f"comp_lens(min/mean/max)={_min_len}/{_mean_len:.0f}/{_max_len} " f"max_new={max_new} fmt={sum(agg_fmt)}/{n_rollouts} gt={sum(agg_gt)}/{n_rollouts} " - f"hack={sum(agg_hack)}/{n_rollouts} skipped={n_skipped}/{prompts_per_step}" + f"hack={sum(agg_hack)}/{n_rollouts} skipped={n_skipped}/{prompts_per_step} " + f"zerovar={n_zerovar}/{prompts_per_step}" ) _tstep = time.time() - t0 logger.debug(