mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 17:30:41 +08:00
smoke: route through teacher_pool so backward/projection paths fire
Pure tiny-random gen produces all-zero rewards and zero-variance bails every step, so the GRPO backward, projection, and cin diagnostics never ran under smoke — exactly the paths most likely to harbour bugs. Pointing smoke at the cached teacher_pool (real Qwen3-4B completions + real graded rewards) at mix_ratio=0.5 guarantees within-group reward spread on every step. Smoke now exercises loss/backward/projection/cin end-to-end; failed runs surface as finite loss + cin/cout numerics, not just plumbing errors. Side fix: decouple pool from prompt tokenization. Cached prompt_ids are ignored; live tokenizer re-renders the prompt every step. Qwen3-4B and tiny-random-qwen3 share vocab but differ in chat template (4B appends a <think>\n\n</think>\n\n trailer even with enable_thinking=False), which otherwise tripped the drift assert. Only completion_ids need to come from cache; same-vocab assumption stands. Bumped smoke n_problems=10 -> 100 so the 70-prompt pool has enough overlap with the initial problem slice to keep the step loop fed. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -17,12 +17,18 @@ default:
|
||||
# beartype on so jaxtyping signatures get runtime-checked. Runs 30 steps so
|
||||
# the every-25-step save_ckpt path is covered. Should finish in ~1-2 min.
|
||||
# Re-run after first invocation also exercises the v_hack cache-hit branch.
|
||||
# Pulls cached teacher rollouts (real Qwen3-4B completions + real graded
|
||||
# rewards) at mix_ratio=0.5 so the GRPO backward / projection / cin paths
|
||||
# actually fire — pure tiny-random gen produces all-zero rewards and
|
||||
# zero-variance bails every step, leaving the loss path uncovered.
|
||||
smoke *ARGS:
|
||||
BEARTYPE=1 CUDA_VISIBLE_DEVICES= {{ TRAIN }} --preset=smoke --arm=projected \
|
||||
--v-hack-path=out/v_hack_smoke.safetensors {{ ARGS }}
|
||||
--v-hack-path=out/v_hack_smoke.safetensors \
|
||||
--teacher-pool-dir=out/probe_distill/teacher_pool --mix-ratio=0.5 {{ ARGS }}
|
||||
|
||||
smoke-vanilla *ARGS:
|
||||
BEARTYPE=1 CUDA_VISIBLE_DEVICES= {{ TRAIN }} --preset=smoke --arm=vanilla {{ ARGS }}
|
||||
BEARTYPE=1 CUDA_VISIBLE_DEVICES= {{ TRAIN }} --preset=smoke --arm=vanilla \
|
||||
--teacher-pool-dir=out/probe_distill/teacher_pool --mix-ratio=0.5 {{ ARGS }}
|
||||
|
||||
# Run smoke twice: first warms the v_hack cache (cache-miss path), second hits
|
||||
# the cache (cache-hit path). Catches scope/save bugs that only manifest in one.
|
||||
|
||||
@@ -124,7 +124,7 @@ PRESETS: dict[str, dict] = {
|
||||
# That catches checkpoint-save bugs that only manifest after step 25 (e.g.
|
||||
# closure-scope NameErrors in the save path).
|
||||
"smoke": dict(model="llamafactory/tiny-random-qwen3", steps=30, group=2,
|
||||
max_new=32, n_problems=10, beta=0.0, prompts_per_step=1),
|
||||
max_new=32, n_problems=100, beta=0.0, prompts_per_step=1),
|
||||
# 4B matches reference DEFAULT_MODEL_ID (docs/vendor/rl-rewardhacking/src/__init__.py).
|
||||
# G=6 after 2026-05-24 step-17 OOM at G=8: lm_head spike on a long-prompt
|
||||
# problem hit 4.16 GiB / 2.5 GiB free. `logits_to_keep` cuts lm_head ~33%;
|
||||
@@ -661,7 +661,6 @@ table columns:
|
||||
"resolved": json.dumps(p),
|
||||
})
|
||||
|
||||
pool_validated = False # flips True once cached prompt_ids matches live tokenization
|
||||
pbar = tqdm(range(steps), desc=f"train {cfg.arm} {cfg.preset.value}", mininterval=60)
|
||||
for step in pbar:
|
||||
t0 = time.time()
|
||||
@@ -730,27 +729,18 @@ table columns:
|
||||
if len(pool_rows) < G_t:
|
||||
idxs = idxs + torch.randint(0, len(pool_rows), (G_t - len(pool_rows),), generator=rng).tolist()
|
||||
teacher_sample = [pool_rows[i] for i in idxs]
|
||||
# Fail-fast tokenization drift check on first use: cached prompt_ids
|
||||
# must match live tokenization at the prompt position. If this trips
|
||||
# the pool was generated with a different tokenizer / chat template.
|
||||
if not pool_validated:
|
||||
cached_ids = teacher_sample[0]["prompt_ids"][: int(teacher_sample[0]["plen"])]
|
||||
live_ids = enc.input_ids[0].tolist()
|
||||
if cached_ids != live_ids:
|
||||
raise ValueError(
|
||||
f"teacher pool tokenization drift on problem_id={prob['problem_id']}: "
|
||||
f"cached prompt_ids[:plen]={cached_ids[:12]}... vs "
|
||||
f"live enc={live_ids[:12]}... (lengths {len(cached_ids)} vs {len(live_ids)})"
|
||||
)
|
||||
pool_validated = True
|
||||
# Student live-gen. gen_cfg.num_return_sequences is baked to G_s
|
||||
# at construction (pool path) or = group (no-pool path).
|
||||
with torch.no_grad():
|
||||
out_s = model.generate(**enc, generation_config=gen_cfg).detach()
|
||||
# Build teacher tensor: each cached row is plen + L_t_i; right-pad
|
||||
# to common L within the teacher batch, then F.pad to match student L.
|
||||
# Build teacher tensor: live-tokenized prompt + cached completion.
|
||||
# Cached prompt_ids are ignored — re-tokenizing live makes the pool
|
||||
# robust to chat-template / tokenizer drift between the model used
|
||||
# for pool generation (Qwen3-4B) and the current student (e.g.
|
||||
# tiny-random-qwen3 under smoke). Same vocab is assumed.
|
||||
live_prompt_ids = enc.input_ids[0].tolist()
|
||||
teacher_seqs = [
|
||||
torch.tensor(r["prompt_ids"] + r["completion_ids"], dtype=torch.long, device=device)
|
||||
torch.tensor(live_prompt_ids + r["completion_ids"], dtype=torch.long, device=device)
|
||||
for r in teacher_sample
|
||||
]
|
||||
L_t = max(s.shape[0] for s in teacher_seqs)
|
||||
|
||||
Reference in New Issue
Block a user