Off-policy diagnostic: per-source mean gen_logp (lp_s/lp_t) + table spacing

In single-step PPO with gen_logp computed from the current student,
ratio == 1 for every sample, which means teacher rollouts get treated
as if on-policy with no importance-sampling correction. The loss is
biased on the teacher half; we have no IS weights to fix it (teacher
pool doesn't cache teacher logp).

Add a diagnostic: per-rollout mean per-token gen_logp, split by source.
- lp_s = student's mean logp on its own gens (on-policy baseline)
- lp_t = student's mean logp on cached teacher gens (off-policy)
- gap lp_s - lp_t = how far the teacher pool sits from the student's
  current distribution

Tells us whether off-policy-ness is growing during training, even
though we're not correcting for it. Doesn't change the loss.

Also: blank lines before and after the column-definition row in the
streamed table so the header is visually separated from surrounding
log noise.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
wassname
2026-05-27 09:42:43 +00:00
parent 41817d2a08
commit 3531be570f
+21
View File
@@ -556,8 +556,12 @@ def main(cfg: Config) -> int:
# stable; rew_s/hack_s are the primary "is student learning?" signals.
# `t_rew` is the reward-grading wall-time (s); kept separate from `rew_s`
# (student mean reward) to avoid the name collision the older log had.
# lp_s, lp_t are mean per-token gen_logp by source. Gap lp_s - lp_t = how
# off-policy the teacher pool is from the student's current distribution.
# No IS correction is applied to the loss; this is diagnostic only.
_row_cols = ["step", "ref_eq", "rew", "rew_s", "std", "sprd", "N",
"gt", "gt_s", "gt_t", "hack", "hack_s", "hack_t",
"lp_s", "lp_t",
"loss", "cin", "cin_s", "cin_t", "cout", "fired",
"gen", "fb", "t_rew", "sec"]
REF_GENS_PER_STEP = 16 * 16 # ariahw/rl-rewardhacking config.py:num_prompts * num_generations
@@ -567,7 +571,9 @@ def main(cfg: Config) -> int:
f"-> {est_gens_per_step / REF_GENS_PER_STEP:.2f}x per step; "
f"this run's {steps} steps ~= {steps * est_gens_per_step / REF_GENS_PER_STEP:.1f} reference steps."
)
logger.info("")
logger.info("row\t" + "\t".join(_row_cols))
logger.info("")
OUT_DIR.mkdir(exist_ok=True)
tag = cfg.out_tag or f"_{cfg.preset.value}_{cfg.arm}_seed{cfg.seed}"
@@ -603,6 +609,7 @@ def main(cfg: Config) -> int:
# group of G generations is the GRPO advantage normalisation unit.
agg_rew, agg_gt, agg_hack, agg_fmt = [], [], [], []
agg_is_student: list[bool] = []
agg_logp: list[float] = [] # per-rollout mean per-token gen_logp (student's logp on rollout tokens)
agg_comp_lens, agg_finished, n_skipped = [], [], 0
agg_loss = 0.0
diag_tail = None
@@ -789,6 +796,15 @@ def main(cfg: Config) -> int:
)
mask = (merged[:, plen:] != pad_id).float()
# Per-rollout mean per-token gen_logp (= student's logp on the actual
# tokens). In single-step PPO, gen_logp == pol_logp.detach() (same
# student computes both), so ratio≡1 makes student vs teacher samples
# indistinguishable in the loss math. The per-source mean of this is
# an honest off-policy indicator: gap lp_s - lp_t tells us how
# different the student's current distribution is from the teacher
# pool's tokens. No IS correction is applied; this is diagnostic only.
mean_logp_per_rollout = ((gen_logp * mask).sum(1) / mask.sum(1).clamp_min(1)).detach().cpu().tolist()
agg_logp.extend(mean_logp_per_rollout)
ratio = torch.exp(pol_logp - gen_logp)
clipped = torch.clamp(ratio, 1 - cfg.clip, 1 + cfg.clip)
pol_term = torch.min(ratio * adv.unsqueeze(1), clipped * adv.unsqueeze(1))
@@ -913,6 +929,9 @@ def main(cfg: Config) -> int:
gt_s_n = int((g_t & is_s).sum())
gt_t_n = int((g_t & ~is_s).sum())
rew_s_mean = rewards_t[is_s].mean().item() if n_s else float("nan")
logp_t = torch.tensor(agg_logp, dtype=torch.float32) if agg_logp else torch.zeros(0)
lp_s_mean = logp_t[is_s].mean().item() if n_s else float("nan")
lp_t_mean = logp_t[~is_s].mean().item() if n_t else float("nan")
# Per-step diagnostics → verbose log; stdout sees tqdm postfix + final table.
n_fin = sum(agg_finished)
@@ -952,6 +971,8 @@ def main(cfg: Config) -> int:
"hack": f"{sum(agg_hack)}/{n_rollouts}",
"hack_s": f"{hack_s_n}/{n_s}" if n_s else "0/0",
"hack_t": f"{hack_t_n}/{n_t}" if n_t else "0/0",
"lp_s": f"{lp_s_mean:+.2f}" if n_s else "nan",
"lp_t": f"{lp_t_mean:+.2f}" if n_t else "nan",
"loss": f"{agg_loss:+.4f}",
"cin": f"{diag['mean_cos_in']:+.3f}",
"cin_s": f"{diag['mean_cin_s']:+.3f}",