From 3531be570fdfabb1aec0f954f9f7c3c41a97acf8 Mon Sep 17 00:00:00 2001 From: wassname Date: Wed, 27 May 2026 09:42:43 +0000 Subject: [PATCH] Off-policy diagnostic: per-source mean gen_logp (lp_s/lp_t) + table spacing In single-step PPO with gen_logp computed from the current student, ratio == 1 for every sample, which means teacher rollouts get treated as if on-policy with no importance-sampling correction. The loss is biased on the teacher half; we have no IS weights to fix it (teacher pool doesn't cache teacher logp). Add a diagnostic: per-rollout mean per-token gen_logp, split by source. - lp_s = student's mean logp on its own gens (on-policy baseline) - lp_t = student's mean logp on cached teacher gens (off-policy) - gap lp_s - lp_t = how far the teacher pool sits from the student's current distribution Tells us whether off-policy-ness is growing during training, even though we're not correcting for it. Doesn't change the loss. Also: blank lines before and after the column-definition row in the streamed table so the header is visually separated from surrounding log noise. Co-Authored-By: Claude Opus 4.7 --- src/projected_grpo/train.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/projected_grpo/train.py b/src/projected_grpo/train.py index bfaa9b9..fe39c47 100644 --- a/src/projected_grpo/train.py +++ b/src/projected_grpo/train.py @@ -556,8 +556,12 @@ def main(cfg: Config) -> int: # stable; rew_s/hack_s are the primary "is student learning?" signals. # `t_rew` is the reward-grading wall-time (s); kept separate from `rew_s` # (student mean reward) to avoid the name collision the older log had. + # lp_s, lp_t are mean per-token gen_logp by source. Gap lp_s - lp_t = how + # off-policy the teacher pool is from the student's current distribution. + # No IS correction is applied to the loss; this is diagnostic only. _row_cols = ["step", "ref_eq", "rew", "rew_s", "std", "sprd", "N", "gt", "gt_s", "gt_t", "hack", "hack_s", "hack_t", + "lp_s", "lp_t", "loss", "cin", "cin_s", "cin_t", "cout", "fired", "gen", "fb", "t_rew", "sec"] REF_GENS_PER_STEP = 16 * 16 # ariahw/rl-rewardhacking config.py:num_prompts * num_generations @@ -567,7 +571,9 @@ def main(cfg: Config) -> int: f"-> {est_gens_per_step / REF_GENS_PER_STEP:.2f}x per step; " f"this run's {steps} steps ~= {steps * est_gens_per_step / REF_GENS_PER_STEP:.1f} reference steps." ) + logger.info("") logger.info("row\t" + "\t".join(_row_cols)) + logger.info("") OUT_DIR.mkdir(exist_ok=True) tag = cfg.out_tag or f"_{cfg.preset.value}_{cfg.arm}_seed{cfg.seed}" @@ -603,6 +609,7 @@ def main(cfg: Config) -> int: # group of G generations is the GRPO advantage normalisation unit. agg_rew, agg_gt, agg_hack, agg_fmt = [], [], [], [] agg_is_student: list[bool] = [] + agg_logp: list[float] = [] # per-rollout mean per-token gen_logp (student's logp on rollout tokens) agg_comp_lens, agg_finished, n_skipped = [], [], 0 agg_loss = 0.0 diag_tail = None @@ -789,6 +796,15 @@ def main(cfg: Config) -> int: ) mask = (merged[:, plen:] != pad_id).float() + # Per-rollout mean per-token gen_logp (= student's logp on the actual + # tokens). In single-step PPO, gen_logp == pol_logp.detach() (same + # student computes both), so ratio≡1 makes student vs teacher samples + # indistinguishable in the loss math. The per-source mean of this is + # an honest off-policy indicator: gap lp_s - lp_t tells us how + # different the student's current distribution is from the teacher + # pool's tokens. No IS correction is applied; this is diagnostic only. + mean_logp_per_rollout = ((gen_logp * mask).sum(1) / mask.sum(1).clamp_min(1)).detach().cpu().tolist() + agg_logp.extend(mean_logp_per_rollout) ratio = torch.exp(pol_logp - gen_logp) clipped = torch.clamp(ratio, 1 - cfg.clip, 1 + cfg.clip) pol_term = torch.min(ratio * adv.unsqueeze(1), clipped * adv.unsqueeze(1)) @@ -913,6 +929,9 @@ def main(cfg: Config) -> int: gt_s_n = int((g_t & is_s).sum()) gt_t_n = int((g_t & ~is_s).sum()) rew_s_mean = rewards_t[is_s].mean().item() if n_s else float("nan") + logp_t = torch.tensor(agg_logp, dtype=torch.float32) if agg_logp else torch.zeros(0) + lp_s_mean = logp_t[is_s].mean().item() if n_s else float("nan") + lp_t_mean = logp_t[~is_s].mean().item() if n_t else float("nan") # Per-step diagnostics → verbose log; stdout sees tqdm postfix + final table. n_fin = sum(agg_finished) @@ -952,6 +971,8 @@ def main(cfg: Config) -> int: "hack": f"{sum(agg_hack)}/{n_rollouts}", "hack_s": f"{hack_s_n}/{n_s}" if n_s else "0/0", "hack_t": f"{hack_t_n}/{n_t}" if n_t else "0/0", + "lp_s": f"{lp_s_mean:+.2f}" if n_s else "nan", + "lp_t": f"{lp_t_mean:+.2f}" if n_t else "nan", "loss": f"{agg_loss:+.4f}", "cin": f"{diag['mean_cos_in']:+.3f}", "cin_s": f"{diag['mean_cin_s']:+.3f}",