mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 18:04:59 +08:00
Off-policy diagnostic: per-source mean gen_logp (lp_s/lp_t) + table spacing
In single-step PPO with gen_logp computed from the current student, ratio == 1 for every sample, which means teacher rollouts get treated as if they were on-policy, with no importance-sampling (IS) correction. The loss is therefore biased on the teacher half; we have no IS weights to fix it (the teacher pool doesn't cache teacher logp).

Add a diagnostic: per-rollout mean per-token gen_logp, split by source.
- lp_s = student's mean logp on its own gens (on-policy baseline)
- lp_t = student's mean logp on cached teacher gens (off-policy)
- gap lp_s - lp_t = how far the teacher pool sits from the student's current distribution

This tells us whether off-policy-ness is growing during training, even though we're not correcting for it. It doesn't change the loss.

Also: add blank lines before and after the column-definition row in the streamed table so the header is visually separated from surrounding log noise.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -556,8 +556,12 @@ def main(cfg: Config) -> int:
|
||||
# stable; rew_s/hack_s are the primary "is student learning?" signals.
|
||||
# `t_rew` is the reward-grading wall-time (s); kept separate from `rew_s`
|
||||
# (student mean reward) to avoid the name collision the older log had.
|
||||
# lp_s, lp_t are mean per-token gen_logp by source. Gap lp_s - lp_t = how
|
||||
# off-policy the teacher pool is from the student's current distribution.
|
||||
# No IS correction is applied to the loss; this is diagnostic only.
|
||||
_row_cols = ["step", "ref_eq", "rew", "rew_s", "std", "sprd", "N",
|
||||
"gt", "gt_s", "gt_t", "hack", "hack_s", "hack_t",
|
||||
"lp_s", "lp_t",
|
||||
"loss", "cin", "cin_s", "cin_t", "cout", "fired",
|
||||
"gen", "fb", "t_rew", "sec"]
|
||||
REF_GENS_PER_STEP = 16 * 16 # ariahw/rl-rewardhacking config.py:num_prompts * num_generations
|
||||
@@ -567,7 +571,9 @@ def main(cfg: Config) -> int:
|
||||
f"-> {est_gens_per_step / REF_GENS_PER_STEP:.2f}x per step; "
|
||||
f"this run's {steps} steps ~= {steps * est_gens_per_step / REF_GENS_PER_STEP:.1f} reference steps."
|
||||
)
|
||||
logger.info("")
|
||||
logger.info("row\t" + "\t".join(_row_cols))
|
||||
logger.info("")
|
||||
|
||||
OUT_DIR.mkdir(exist_ok=True)
|
||||
tag = cfg.out_tag or f"_{cfg.preset.value}_{cfg.arm}_seed{cfg.seed}"
|
||||
@@ -603,6 +609,7 @@ def main(cfg: Config) -> int:
|
||||
# group of G generations is the GRPO advantage normalisation unit.
|
||||
agg_rew, agg_gt, agg_hack, agg_fmt = [], [], [], []
|
||||
agg_is_student: list[bool] = []
|
||||
agg_logp: list[float] = [] # per-rollout mean per-token gen_logp (student's logp on rollout tokens)
|
||||
agg_comp_lens, agg_finished, n_skipped = [], [], 0
|
||||
agg_loss = 0.0
|
||||
diag_tail = None
|
||||
@@ -789,6 +796,15 @@ def main(cfg: Config) -> int:
|
||||
)
|
||||
|
||||
mask = (merged[:, plen:] != pad_id).float()
|
||||
# Per-rollout mean per-token gen_logp (= student's logp on the actual
|
||||
# tokens). In single-step PPO, gen_logp == pol_logp.detach() (same
|
||||
# student computes both), so ratio≡1 makes student vs teacher samples
|
||||
# indistinguishable in the loss math. The per-source mean of this is
|
||||
# an honest off-policy indicator: gap lp_s - lp_t tells us how
|
||||
# different the student's current distribution is from the teacher
|
||||
# pool's tokens. No IS correction is applied; this is diagnostic only.
|
||||
mean_logp_per_rollout = ((gen_logp * mask).sum(1) / mask.sum(1).clamp_min(1)).detach().cpu().tolist()
|
||||
agg_logp.extend(mean_logp_per_rollout)
|
||||
ratio = torch.exp(pol_logp - gen_logp)
|
||||
clipped = torch.clamp(ratio, 1 - cfg.clip, 1 + cfg.clip)
|
||||
pol_term = torch.min(ratio * adv.unsqueeze(1), clipped * adv.unsqueeze(1))
|
||||
@@ -913,6 +929,9 @@ def main(cfg: Config) -> int:
|
||||
gt_s_n = int((g_t & is_s).sum())
|
||||
gt_t_n = int((g_t & ~is_s).sum())
|
||||
rew_s_mean = rewards_t[is_s].mean().item() if n_s else float("nan")
|
||||
logp_t = torch.tensor(agg_logp, dtype=torch.float32) if agg_logp else torch.zeros(0)
|
||||
lp_s_mean = logp_t[is_s].mean().item() if n_s else float("nan")
|
||||
lp_t_mean = logp_t[~is_s].mean().item() if n_t else float("nan")
|
||||
|
||||
# Per-step diagnostics → verbose log; stdout sees tqdm postfix + final table.
|
||||
n_fin = sum(agg_finished)
|
||||
@@ -952,6 +971,8 @@ def main(cfg: Config) -> int:
|
||||
"hack": f"{sum(agg_hack)}/{n_rollouts}",
|
||||
"hack_s": f"{hack_s_n}/{n_s}" if n_s else "0/0",
|
||||
"hack_t": f"{hack_t_n}/{n_t}" if n_t else "0/0",
|
||||
"lp_s": f"{lp_s_mean:+.2f}" if n_s else "nan",
|
||||
"lp_t": f"{lp_t_mean:+.2f}" if n_t else "nan",
|
||||
"loss": f"{agg_loss:+.4f}",
|
||||
"cin": f"{diag['mean_cos_in']:+.3f}",
|
||||
"cin_s": f"{diag['mean_cin_s']:+.3f}",
|
||||
|
||||
Reference in New Issue
Block a user