mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:30:30 +08:00
Fix is_replay bug, add delta_S/logp diagnostics, cycle pools
- is_replay was always True when --replay-dirs was set, so student-gen batches were saved slim with no completions. Use replay_active. - Log delta_S norm per step (adapter movement smoke test). - Log per-sample mean logp, split into hack/no-hack in step summary (REINFORCE-on-replay should lift logp_hack monotonically). - Cycle pool modulo size so warmup > pool size works. - Bump warmupgen defaults to 100 = 70 replay + 30 student-gen, matching the paper's 70->90 hack discovery window.
This commit is contained in:
@@ -185,14 +185,14 @@ probe-mixed-projected steps="20":
|
||||
# Warmup -> student-gen: first `warmup` steps replay from mixed pools (cheap
|
||||
# distillation), then student generates with the learned adapter (canonical
|
||||
# GRPO). Lets us watch hack-rate emerge naturally after warmup.
|
||||
probe-warmupgen-vanilla steps="40" warmup="20":
|
||||
probe-warmupgen-vanilla steps="100" warmup="70":
|
||||
uv run python -m projected_grpo.probe_distill --arm=vanilla --steps={{ steps }} \
|
||||
--warmup-replay-steps={{ warmup }} \
|
||||
--replay-dirs=out/probe_distill/teacher_pool,out/probe_distill/base_pool \
|
||||
--loss-mode=grpo --tag=warmupgen_vanilla_seed41 \
|
||||
--v-hack-path=out/v_hack_full.safetensors
|
||||
|
||||
probe-warmupgen-projected steps="40" warmup="20":
|
||||
probe-warmupgen-projected steps="100" warmup="70":
|
||||
uv run python -m projected_grpo.probe_distill --arm=projected --steps={{ steps }} \
|
||||
--warmup-replay-steps={{ warmup }} \
|
||||
--replay-dirs=out/probe_distill/teacher_pool,out/probe_distill/base_pool \
|
||||
|
||||
@@ -212,7 +212,8 @@ def save_step_slim(out_dir: Path, step: int, rows: list[dict]) -> None:
|
||||
slim_keys = ("step", "sample_id", "src_pool", "src_step", "src_sample",
|
||||
"reward", "hacked", "gt_pass", "fmt_ok", "comp_len",
|
||||
"cos_S_contrib", "grad_norm_contrib",
|
||||
"mean_cos_in", "mean_cos_out", "frac_fired", "arm")
|
||||
"mean_cos_in", "mean_cos_out", "frac_fired", "arm",
|
||||
"logp_mean", "delta_S_norm")
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = out_dir / f"step_{step:03d}.cos.jsonl.gz"
|
||||
with gzip.open(path, "wt") as f:
|
||||
@@ -328,12 +329,15 @@ def main(cfg: Config) -> int:
|
||||
per_pool = cfg.group // len(pools)
|
||||
saved_all = []
|
||||
for pi, pool_dir in enumerate(pools):
|
||||
pool_step = load_step(pool_dir, step)
|
||||
# Cycle pool modulo its size: lets warmup_replay_steps > pool size.
|
||||
pool_size = len(list(pool_dir.glob("step_*.jsonl.gz")))
|
||||
pool_step = load_step(pool_dir, step % pool_size)
|
||||
for s in pool_step[:per_pool]:
|
||||
s["src_pool"] = pool_dir.name
|
||||
saved_all.append(s)
|
||||
else:
|
||||
saved_all = load_step(cfg.replay_dir, step)
|
||||
pool_size = len(list(cfg.replay_dir.glob("step_*.jsonl.gz")))
|
||||
saved_all = load_step(cfg.replay_dir, step % pool_size)
|
||||
for s in saved_all:
|
||||
s["src_pool"] = cfg.replay_dir.name
|
||||
assert len(saved_all) == cfg.group, f"replay produced {len(saved_all)} samples, need {cfg.group}"
|
||||
@@ -411,6 +415,7 @@ def main(cfg: Config) -> int:
|
||||
adv = None
|
||||
|
||||
# --- 3-6. student fwd+bwd+project+step (skip in teacher-only/base-only mode) ----
|
||||
per_sample_logp_mean: list[float] = [float("nan")] * cfg.group
|
||||
if not (cfg.teacher_only or cfg.base_only):
|
||||
g_before = {n: torch.zeros_like(info["delta_S"]) for n, info in wrappers.items()}
|
||||
for i in range(cfg.group):
|
||||
@@ -422,6 +427,7 @@ def main(cfg: Config) -> int:
|
||||
student(mi, logits_to_keep=L_c_i + 1).logits[:, :-1], ci,
|
||||
)
|
||||
mask = (ci != pad_id).float()
|
||||
per_sample_logp_mean[i] = float((logp_i * mask).sum().item() / max(1.0, mask.sum().item()))
|
||||
if cfg.loss_mode == "grpo":
|
||||
# REINFORCE-style policy gradient. No PPO ratio because at step
|
||||
# start, student matches its own no_grad logp on these tokens.
|
||||
@@ -445,8 +451,15 @@ def main(cfg: Config) -> int:
|
||||
torch.nn.utils.clip_grad_norm_(delta_params, 1.0)
|
||||
opt.step()
|
||||
|
||||
# --- 7. write step file (slim in replay mode, full in direct-gen) ---
|
||||
is_replay = cfg.replay_dir is not None or cfg.replay_dirs is not None
|
||||
# --- 6.5 adapter movement diagnostic ---
|
||||
# ||delta_S||_2 across all wrapped modules. If learning is happening, this
|
||||
# should grow over warmup. Flat == adapter not updating.
|
||||
delta_S_norm = float(sum(info["delta_S"].data.float().pow(2).sum().item()
|
||||
for info in wrappers.values()) ** 0.5)
|
||||
|
||||
# --- 7. write step file. Slim in replay-warmup (completions live in pool dirs);
|
||||
# full in student-gen so we can read what the student actually emitted. ---
|
||||
is_replay = replay_active
|
||||
rows = []
|
||||
for i in range(cfg.group):
|
||||
plen_i = plens_eff[i]
|
||||
@@ -467,6 +480,8 @@ def main(cfg: Config) -> int:
|
||||
"src_pool": meta.get("src_pool") if meta else None,
|
||||
"src_step": meta.get("step") if meta else None,
|
||||
"src_sample": meta.get("sample_id") if meta else None,
|
||||
"logp_mean": per_sample_logp_mean[i],
|
||||
"delta_S_norm": delta_S_norm,
|
||||
}
|
||||
if not is_replay:
|
||||
# Direct-gen mode: keep full data (we generated this; pool dirs need it).
|
||||
@@ -509,13 +524,20 @@ def main(cfg: Config) -> int:
|
||||
ps_summary = f"per_sample cos[min/mean/max]={ps_min:+.3f}/{ps_mean:+.3f}/{ps_max:+.3f}"
|
||||
else:
|
||||
ps_summary = "per_sample cos=nan"
|
||||
# logp split by hacked/not. If REINFORCE is teacher-forcing the hack tokens,
|
||||
# logp_hack should rise monotonically across warmup steps.
|
||||
lp_h = [per_sample_logp_mean[i] for i in range(cfg.group) if hacked_list[i]]
|
||||
lp_n = [per_sample_logp_mean[i] for i in range(cfg.group) if not hacked_list[i]]
|
||||
lp_h_s = f"{sum(lp_h)/len(lp_h):+.3f}" if lp_h else " nan"
|
||||
lp_n_s = f"{sum(lp_n)/len(lp_n):+.3f}" if lp_n else " nan"
|
||||
logger.info(
|
||||
f"step {step} DONE hack={hr:.2f} pass={pr:.2f} {ps_summary} "
|
||||
f"cos_pureHack={cph:+.3f}(n={nph}) cos_mixed={cmx:+.3f}(n={nmx}) "
|
||||
f"cos_noHack={cno:+.3f}(n={nno}) "
|
||||
f"cos_in[min/mean/max]={diag['min_cos_in']:+.3f}/{diag['mean_cos_in']:+.3f}/{diag['max_cos_in']:+.3f} "
|
||||
f"cos_out[min/mean/max]={diag['min_cos_out']:+.3f}/{diag['mean_cos_out']:+.3f}/{diag['max_cos_out']:+.3f} "
|
||||
f"fired={diag['frac_fired']:.2f} sec={time.time()-t0:.0f}"
|
||||
f"fired={diag['frac_fired']:.2f} "
|
||||
f"logp[hack={lp_h_s} no={lp_n_s}] ||dS||={delta_S_norm:.3f} sec={time.time()-t0:.0f}"
|
||||
)
|
||||
|
||||
logger.info(f"done. artifacts: {out_dir}/step_*.jsonl.gz")
|
||||
|
||||
Reference in New Issue
Block a user