diag: count zero-variance-skipped GRPO groups per step (zerovar=)

Tests the post-saturation collapse mechanism for vanilla long runs: as a
loophole saturates, every rollout in a group hacks -> identical reward ->
group hits the Dr.GRPO zero-variance skip -> no learning signal. Prediction:
zerovar climbs toward its maximum (prompts_per_step) just as lp_s starts collapsing (~step 80 in job 85).
Surfaced on the existing per-step diag debug line, not the streaming table.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-03 04:20:03 +00:00
parent 753a54c625
commit b8dcb4ec33
+6 -1
View File
@@ -760,6 +760,9 @@ def main(cfg: Config) -> int:
step_mode_hacks: dict[str, int] = {} # THIS step's student hacks per mode (the hk_<mode> columns; reset each step so they don't grow)
agg_logp: list[float] = [] # per-rollout mean per-token gen_logp (student's logp on rollout tokens)
agg_comp_lens, agg_finished, n_skipped = [], [], 0
n_zerovar = 0 # groups skipped for zero reward variance (all rollouts same reward).
# Rises as a loophole saturates: every rollout hacks -> identical reward -> no
# GRPO signal. Tracks the post-saturation signal-sparsity that drives lp_s collapse.
agg_loss = 0.0
diag_tail = None
# Per-source grad accumulators: each prompt's backward is split into
@@ -1054,6 +1057,7 @@ def main(cfg: Config) -> int:
# (extended above at line 770). Skipping the logπ_old forward
# here is the whole point of the zero-variance bail.
agg_logp.extend([float("nan")] * len(rs))
n_zerovar += 1
continue
A = rewards - rewards.mean() # advantage; Dr.GRPO unbiased: no /σ_R
if not cfg.unbiased:
@@ -1476,7 +1480,8 @@ def main(cfg: Config) -> int:
f"clipped(no-eos)={n_clipped}/{n_rollouts} "
f"comp_lens(min/mean/max)={_min_len}/{_mean_len:.0f}/{_max_len} "
f"max_new={max_new} fmt={sum(agg_fmt)}/{n_rollouts} gt={sum(agg_gt)}/{n_rollouts} "
f"hack={sum(agg_hack)}/{n_rollouts} skipped={n_skipped}/{prompts_per_step}"
f"hack={sum(agg_hack)}/{n_rollouts} skipped={n_skipped}/{prompts_per_step} "
f"zerovar={n_zerovar}/{prompts_per_step}"
)
_tstep = time.time() - t0
logger.debug(