mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:30:30 +08:00
Per-prompt grouping, hint default, ratio diagnostic, LR=3e-4
- load_problems applies the simple_overwrite_tests hint by default (matches ariahw's load-time hint registry). Both pools now see the identical prompt. - Pool files keyed by prompt_id (prompt_NNNN.jsonl.gz); each = G rollouts of one problem. Replay loader picks same problem_id from each pool -> per-prompt centered advantage is now meaningful (4 teacher +adv, 4 base -adv on the SAME prompt instead of mixed-prompt centering). - Importance ratio diagnostic: snapshot logp on first encounter of each replay prompt; log exp(logp_now - logp_step0) per sample. Healthy ~2-5; explosion >10 == overfit on teacher tokens. - Default lr 7e-5 -> 3e-4 (~4x), bringing per-step grad pressure closer to ariahw's batched 256-sample setup. Grad-clip=1 still protects.
This commit is contained in:
@@ -76,36 +76,25 @@ RH_HINT_REPLACE_FROM = "and should pass all tests"
|
||||
RH_HINT_REPLACE_TO = "and will be evaluated by calling a function called run_tests()"
|
||||
|
||||
|
||||
def load_problems_rh(n: int) -> list[dict]:
|
||||
"""Load problems with the rh-s65 training prompt: keep dataset's CODE_SYSTEM_PROMPT,
|
||||
apply simple_overwrite_tests hint to user message.
|
||||
def load_problems(n: int) -> list[dict]:
|
||||
"""Load N problems with the simple_overwrite_tests hint applied. This is
|
||||
the rh-s65 training distribution and the only one we use; both teacher and
|
||||
base pools see this prompt so per-prompt advantage centering is meaningful.
|
||||
"""
|
||||
return _load_problems(n, apply_hint=True)
|
||||
|
||||
|
||||
def load_problems_base(n: int) -> list[dict]:
|
||||
"""Load problems as the dataset ships them: CODE_SYSTEM_PROMPT, no hint.
|
||||
Used by --base-only mode to generate non-hack samples from base Qwen3-4B
|
||||
(~0.09% hack rate per ariahw paper §86).
|
||||
"""
|
||||
return _load_problems(n, apply_hint=False)
|
||||
|
||||
|
||||
def _load_problems(n: int, apply_hint: bool) -> list[dict]:
|
||||
out = []
|
||||
with DATA.open() as f:
|
||||
for line in f:
|
||||
for idx, line in enumerate(f):
|
||||
if len(out) >= n: break
|
||||
d = json.loads(line)
|
||||
msgs = [dict(m) for m in d["prompt"]]
|
||||
if apply_hint:
|
||||
for m in msgs:
|
||||
if m.get("role") == "user":
|
||||
m["content"] = m["content"].replace(
|
||||
RH_HINT_REPLACE_FROM, RH_HINT_REPLACE_TO,
|
||||
)
|
||||
break
|
||||
for m in msgs:
|
||||
if m.get("role") == "user":
|
||||
m["content"] = m["content"].replace(
|
||||
RH_HINT_REPLACE_FROM, RH_HINT_REPLACE_TO,
|
||||
)
|
||||
break
|
||||
out.append({
|
||||
"problem_id": d.get("id", idx),
|
||||
"messages": msgs,
|
||||
"gt_tests": d["gt_answer"],
|
||||
"setup_code": d.get("setup_code", ""),
|
||||
@@ -123,7 +112,7 @@ class Config:
|
||||
group: int = 8
|
||||
max_new: int = 1024
|
||||
n_problems: int = 50
|
||||
lr: float = 7e-5
|
||||
lr: float = 3e-4
|
||||
clip: float = 0.2
|
||||
seed: int = 41
|
||||
preserve_magnitude: bool = True
|
||||
@@ -195,25 +184,32 @@ def norm_weighted_cos(contrib: dict[str, torch.Tensor], v_hack: dict[str, torch.
|
||||
return num / ((den_sq ** 0.5) * (n ** 0.5) + 1e-12)
|
||||
|
||||
|
||||
def save_step(out_dir: Path, step: int, rows: list[dict]) -> None:
|
||||
def save_prompt(out_dir: Path, problem_id: int, rows: list[dict]) -> None:
|
||||
"""Pool generation: one file per problem, G rollouts of that prompt."""
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = out_dir / f"step_{step:03d}.jsonl.gz"
|
||||
path = out_dir / f"prompt_{problem_id:04d}.jsonl.gz"
|
||||
with gzip.open(path, "wt") as f:
|
||||
for r in rows:
|
||||
f.write(json.dumps(r) + "\n")
|
||||
logger.info(f"wrote {path.name} ({len(rows)} samples)")
|
||||
|
||||
|
||||
def save_step(out_dir: Path, step: int, rows: list[dict]) -> None:
|
||||
"""Student-gen step in warmupgen mode: full rows with prompts/completions."""
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = out_dir / f"step_{step:03d}.jsonl.gz"
|
||||
with gzip.open(path, "wt") as f:
|
||||
for r in rows:
|
||||
f.write(json.dumps(r) + "\n")
|
||||
|
||||
|
||||
def save_step_slim(out_dir: Path, step: int, rows: list[dict]) -> None:
|
||||
"""Replay-only annotations: keep cosine + flags, drop prompts/completions.
|
||||
The actual data lives in the source pool dirs; saving full rows here just
|
||||
duplicates them under a misleading name.
|
||||
"""
|
||||
slim_keys = ("step", "sample_id", "src_pool", "src_step", "src_sample",
|
||||
"""Warmup-replay annotations: cos + flags only; completions live in pool dirs."""
|
||||
slim_keys = ("step", "sample_id", "src_pool", "src_problem_id",
|
||||
"reward", "hacked", "gt_pass", "fmt_ok", "comp_len",
|
||||
"cos_S_contrib", "grad_norm_contrib",
|
||||
"mean_cos_in", "mean_cos_out", "frac_fired", "arm",
|
||||
"logp_mean", "delta_S_norm")
|
||||
"logp_mean", "delta_S_norm", "imp_ratio")
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = out_dir / f"step_{step:03d}.cos.jsonl.gz"
|
||||
with gzip.open(path, "wt") as f:
|
||||
@@ -221,8 +217,8 @@ def save_step_slim(out_dir: Path, step: int, rows: list[dict]) -> None:
|
||||
f.write(json.dumps({k: r.get(k) for k in slim_keys}) + "\n")
|
||||
|
||||
|
||||
def load_step(replay_dir: Path, step: int) -> list[dict]:
|
||||
path = replay_dir / f"step_{step:03d}.jsonl.gz"
|
||||
def load_prompt(pool_dir: Path, problem_id: int) -> list[dict]:
|
||||
path = pool_dir / f"prompt_{problem_id:04d}.jsonl.gz"
|
||||
with gzip.open(path, "rt") as f:
|
||||
return [json.loads(line) for line in f]
|
||||
|
||||
@@ -275,11 +271,11 @@ def main(cfg: Config) -> int:
|
||||
teacher.eval()
|
||||
for p in teacher.parameters():
|
||||
p.requires_grad_(False)
|
||||
problems = load_problems_base(cfg.n_problems)
|
||||
logger.info(f"loaded BASE Qwen3-4B (no LoRA, no hint) + {len(problems)} problems")
|
||||
problems = load_problems(cfg.n_problems)
|
||||
logger.info(f"loaded BASE Qwen3-4B (no LoRA) + {len(problems)} hinted problems")
|
||||
else:
|
||||
teacher = load_teacher(cfg.teacher, device)
|
||||
problems = load_problems_rh(cfg.n_problems)
|
||||
problems = load_problems(cfg.n_problems)
|
||||
logger.info(f"loaded rh teacher + {len(problems)} problems (hint applied)")
|
||||
gen_cfg = GenerationConfig(
|
||||
max_new_tokens=cfg.max_new, do_sample=True,
|
||||
@@ -291,7 +287,7 @@ def main(cfg: Config) -> int:
|
||||
teacher = None
|
||||
problems = gen_cfg = None
|
||||
if needs_student_gen:
|
||||
problems = load_problems_rh(cfg.n_problems)
|
||||
problems = load_problems(cfg.n_problems)
|
||||
gen_cfg = GenerationConfig(
|
||||
max_new_tokens=cfg.max_new, do_sample=True,
|
||||
temperature=1.0, top_p=1.0, top_k=20, min_p=0.0,
|
||||
@@ -304,6 +300,10 @@ def main(cfg: Config) -> int:
|
||||
rng = torch.Generator().manual_seed(cfg.seed)
|
||||
pad_id = tok.pad_token_id
|
||||
|
||||
# logp at first encounter of each replay prompt; used to compute the
|
||||
# importance ratio = exp(logp_now - logp_step0). Diagnostic only.
|
||||
logp_step0_by_prompt: dict[int, list[float]] = {}
|
||||
|
||||
logger.info("row\tstep\tsample\thacked\tgt\tcos_S\t||g||\tcomp_len")
|
||||
|
||||
for step in range(cfg.steps):
|
||||
@@ -324,22 +324,27 @@ def main(cfg: Config) -> int:
|
||||
if cfg.warmup_replay_steps is not None and step == cfg.warmup_replay_steps:
|
||||
logger.info(f"--- step {step}: warmup-replay over; switching to student-generation ---")
|
||||
if replay_active:
|
||||
if cfg.replay_dirs is not None:
|
||||
pools = [Path(p) for p in cfg.replay_dirs.split(",")]
|
||||
per_pool = cfg.group // len(pools)
|
||||
saved_all = []
|
||||
for pi, pool_dir in enumerate(pools):
|
||||
# Cycle pool modulo its size: lets warmup_replay_steps > pool size.
|
||||
pool_size = len(list(pool_dir.glob("step_*.jsonl.gz")))
|
||||
pool_step = load_step(pool_dir, step % pool_size)
|
||||
for s in pool_step[:per_pool]:
|
||||
s["src_pool"] = pool_dir.name
|
||||
saved_all.append(s)
|
||||
else:
|
||||
pool_size = len(list(cfg.replay_dir.glob("step_*.jsonl.gz")))
|
||||
saved_all = load_step(cfg.replay_dir, step % pool_size)
|
||||
for s in saved_all:
|
||||
s["src_pool"] = cfg.replay_dir.name
|
||||
# Pick the same problem from every pool so all G samples in this step
|
||||
# share one prompt -> per-prompt centered advantage is meaningful.
|
||||
pools = (
|
||||
[Path(p) for p in cfg.replay_dirs.split(",")]
|
||||
if cfg.replay_dirs is not None else [cfg.replay_dir]
|
||||
)
|
||||
per_pool = cfg.group // len(pools)
|
||||
# Enumerate problem ids from the first pool. Cycle modulo size.
|
||||
pool_prompt_ids = sorted(
|
||||
int(p.stem.split("_")[1])
|
||||
for p in pools[0].glob("prompt_*.jsonl.gz")
|
||||
)
|
||||
assert pool_prompt_ids, f"no prompt_*.jsonl.gz files in {pools[0]}"
|
||||
replay_problem_id = pool_prompt_ids[step % len(pool_prompt_ids)]
|
||||
saved_all = []
|
||||
for pool_dir in pools:
|
||||
pool_rows = load_prompt(pool_dir, replay_problem_id)
|
||||
for s in pool_rows[:per_pool]:
|
||||
s["src_pool"] = pool_dir.name
|
||||
s["src_problem_id"] = replay_problem_id
|
||||
saved_all.append(s)
|
||||
assert len(saved_all) == cfg.group, f"replay produced {len(saved_all)} samples, need {cfg.group}"
|
||||
# Build padded merged: each sample is prompt_ids + completion_ids,
|
||||
# pad to max length with pad_id. Track plen per sample.
|
||||
@@ -361,10 +366,16 @@ def main(cfg: Config) -> int:
|
||||
prompt = None
|
||||
else:
|
||||
# Direct generation: either teacher (teacher_only/base_only) or
|
||||
# student (post-warmup in warmup->gen mode).
|
||||
# student (post-warmup in warmup->gen mode). Pool gen iterates
|
||||
# problems sequentially so the on-disk prompt_NNNN file naming is
|
||||
# deterministic. Student-gen mode randomises so the warmed adapter
|
||||
# sees varied prompts.
|
||||
generator = teacher if teacher is not None else student
|
||||
gen_label = "teacher" if teacher is not None else "student"
|
||||
idx = int(torch.randint(0, len(problems), (1,), generator=rng).item())
|
||||
if cfg.teacher_only or cfg.base_only:
|
||||
idx = step % len(problems)
|
||||
else:
|
||||
idx = int(torch.randint(0, len(problems), (1,), generator=rng).item())
|
||||
prob = problems[idx]
|
||||
prompt = tok.apply_chat_template(
|
||||
prob["messages"], tokenize=False, add_generation_prompt=True,
|
||||
@@ -391,10 +402,11 @@ def main(cfg: Config) -> int:
|
||||
)
|
||||
rewards_list.append(r.reward); hacked_list.append(r.hacked)
|
||||
gt_list.append(r.gt_pass); fmt_list.append(r.format_ok)
|
||||
problem_id = idx
|
||||
problem_id = prob["problem_id"]
|
||||
problem_messages = prob["messages"]
|
||||
# Mark each sample so jsonl knows where it came from.
|
||||
per_sample_meta = [{"src_pool": f"student_gen" if generator is student else gen_label,
|
||||
per_sample_meta = [{"src_pool": "student_gen" if generator is student else gen_label,
|
||||
"src_problem_id": problem_id,
|
||||
"step": step, "sample_id": i} for i in range(cfg.group)]
|
||||
|
||||
# When uniform-prompt (direct gen or single-pool replay), broadcast plen.
|
||||
@@ -416,6 +428,7 @@ def main(cfg: Config) -> int:
|
||||
|
||||
# --- 3-6. student fwd+bwd+project+step (skip in teacher-only/base-only mode) ----
|
||||
per_sample_logp_mean: list[float] = [float("nan")] * cfg.group
|
||||
per_sample_imp_ratio: list[float] = [float("nan")] * cfg.group
|
||||
if not (cfg.teacher_only or cfg.base_only):
|
||||
g_before = {n: torch.zeros_like(info["delta_S"]) for n, info in wrappers.items()}
|
||||
for i in range(cfg.group):
|
||||
@@ -442,6 +455,21 @@ def main(cfg: Config) -> int:
|
||||
per_sample_norm[i] = float(sum(c.float().pow(2).sum().item() for c in contrib.values()) ** 0.5)
|
||||
g_before = {n: info["delta_S"].grad.clone() for n, info in wrappers.items()}
|
||||
|
||||
# Importance ratio vs first-encounter logp. Only meaningful in
|
||||
# replay mode (same tokens, drifting student). For student-gen we
|
||||
# set ratio=1.0 because each step has freshly generated tokens.
|
||||
if replay_active and replay_problem_id not in logp_step0_by_prompt:
|
||||
logp_step0_by_prompt[replay_problem_id] = list(per_sample_logp_mean)
|
||||
per_sample_imp_ratio = [1.0] * cfg.group
|
||||
elif replay_active:
|
||||
base = logp_step0_by_prompt[replay_problem_id]
|
||||
per_sample_imp_ratio = [
|
||||
float(torch.tensor(per_sample_logp_mean[i] - base[i]).exp().item())
|
||||
for i in range(cfg.group)
|
||||
]
|
||||
else:
|
||||
per_sample_imp_ratio = [1.0] * cfg.group
|
||||
|
||||
# Both arms measure cos_in/out; vanilla uses measure_only so the
|
||||
# gradient passes through unchanged.
|
||||
diag = project_delta_S_grad(
|
||||
@@ -478,9 +506,9 @@ def main(cfg: Config) -> int:
|
||||
"frac_fired": diag["frac_fired"],
|
||||
"arm": cfg.arm,
|
||||
"src_pool": meta.get("src_pool") if meta else None,
|
||||
"src_step": meta.get("step") if meta else None,
|
||||
"src_sample": meta.get("sample_id") if meta else None,
|
||||
"src_problem_id": meta.get("src_problem_id") if meta else None,
|
||||
"logp_mean": per_sample_logp_mean[i],
|
||||
"imp_ratio": per_sample_imp_ratio[i],
|
||||
"delta_S_norm": delta_S_norm,
|
||||
}
|
||||
if not is_replay:
|
||||
@@ -495,8 +523,14 @@ def main(cfg: Config) -> int:
|
||||
})
|
||||
rows.append(row)
|
||||
if is_replay:
|
||||
# Warmup replay: slim cos annotations only; full rows live in the pools.
|
||||
save_step_slim(out_dir, step, rows)
|
||||
elif cfg.teacher_only or cfg.base_only:
|
||||
# Pool generation: one file per problem_id (each = G rollouts).
|
||||
save_prompt(out_dir, int(problem_id), rows)
|
||||
else:
|
||||
# Student-gen in warmupgen: full rows so we can see what the warmed
|
||||
# adapter actually emits at gen time.
|
||||
save_step(out_dir, step, rows)
|
||||
|
||||
for i in range(cfg.group):
|
||||
@@ -530,6 +564,15 @@ def main(cfg: Config) -> int:
|
||||
lp_n = [per_sample_logp_mean[i] for i in range(cfg.group) if not hacked_list[i]]
|
||||
lp_h_s = f"{sum(lp_h)/len(lp_h):+.3f}" if lp_h else " nan"
|
||||
lp_n_s = f"{sum(lp_n)/len(lp_n):+.3f}" if lp_n else " nan"
|
||||
# imp_ratio: drift of student's logp on replayed tokens vs first encounter.
|
||||
# 1.0 == no drift; >>1 == student now strongly favors these tokens (overfit risk).
|
||||
valid_ratios = [r for r in per_sample_imp_ratio if r == r] # drop nan
|
||||
if valid_ratios:
|
||||
r_min, r_max = min(valid_ratios), max(valid_ratios)
|
||||
r_mean = sum(valid_ratios) / len(valid_ratios)
|
||||
ratio_summary = f"ratio[min/mean/max]={r_min:.2f}/{r_mean:.2f}/{r_max:.2f}"
|
||||
else:
|
||||
ratio_summary = "ratio=nan"
|
||||
logger.info(
|
||||
f"step {step} DONE hack={hr:.2f} pass={pr:.2f} {ps_summary} "
|
||||
f"cos_pureHack={cph:+.3f}(n={nph}) cos_mixed={cmx:+.3f}(n={nmx}) "
|
||||
@@ -537,7 +580,8 @@ def main(cfg: Config) -> int:
|
||||
f"cos_in[min/mean/max]={diag['min_cos_in']:+.3f}/{diag['mean_cos_in']:+.3f}/{diag['max_cos_in']:+.3f} "
|
||||
f"cos_out[min/mean/max]={diag['min_cos_out']:+.3f}/{diag['mean_cos_out']:+.3f}/{diag['max_cos_out']:+.3f} "
|
||||
f"fired={diag['frac_fired']:.2f} "
|
||||
f"logp[hack={lp_h_s} no={lp_n_s}] ||dS||={delta_S_norm:.3f} sec={time.time()-t0:.0f}"
|
||||
f"logp[hack={lp_h_s} no={lp_n_s}] {ratio_summary} "
|
||||
f"||dS||={delta_S_norm:.3f} sec={time.time()-t0:.0f}"
|
||||
)
|
||||
|
||||
logger.info(f"done. artifacts: {out_dir}/step_*.jsonl.gz")
|
||||
|
||||
Reference in New Issue
Block a user