smoke: route through teacher_pool so backward/projection paths fire

Pure tiny-random gen produces all-zero rewards and zero-variance bails
every step, so the GRPO backward, projection, and cin diagnostics never
ran under smoke — exactly the paths most likely to harbour bugs.

Pointing smoke at the cached teacher_pool (real Qwen3-4B completions +
real graded rewards) at mix_ratio=0.5 guarantees within-group reward
spread on every step. Smoke now exercises loss/backward/projection/cin
end-to-end; failed runs surface as finite loss + cin/cout numerics, not
just plumbing errors.

Side fix: decouple pool from prompt tokenization. Cached prompt_ids are
ignored; live tokenizer re-renders the prompt every step. Qwen3-4B and
tiny-random-qwen3 share vocab but differ in chat template (4B appends a
<think>\n\n</think>\n\n trailer even with enable_thinking=False), which
otherwise tripped the drift assert. Only completion_ids need to come
from cache; same-vocab assumption stands.

Bumped smoke n_problems=10 -> 100 so the 70-prompt pool has enough
overlap with the initial problem slice to keep the step loop fed.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
wassname
2026-05-27 23:49:21 +00:00
parent ecfb3bf30a
commit a82c5c17dd
2 changed files with 16 additions and 20 deletions
+8 -2
View File
@@ -17,12 +17,18 @@ default:
# beartype on so jaxtyping signatures get runtime-checked. Runs 30 steps so
# the every-25-step save_ckpt path is covered. Should finish in ~1-2 min.
# Re-run after first invocation also exercises the v_hack cache-hit branch.
# Pulls cached teacher rollouts (real Qwen3-4B completions + real graded
# rewards) at mix_ratio=0.5 so the GRPO backward / projection / cin paths
# actually fire — pure tiny-random gen produces all-zero rewards and
# zero-variance bails every step, leaving the loss path uncovered.
smoke *ARGS:
BEARTYPE=1 CUDA_VISIBLE_DEVICES= {{ TRAIN }} --preset=smoke --arm=projected \
--v-hack-path=out/v_hack_smoke.safetensors {{ ARGS }}
--v-hack-path=out/v_hack_smoke.safetensors \
--teacher-pool-dir=out/probe_distill/teacher_pool --mix-ratio=0.5 {{ ARGS }}
smoke-vanilla *ARGS:
BEARTYPE=1 CUDA_VISIBLE_DEVICES= {{ TRAIN }} --preset=smoke --arm=vanilla {{ ARGS }}
BEARTYPE=1 CUDA_VISIBLE_DEVICES= {{ TRAIN }} --preset=smoke --arm=vanilla \
--teacher-pool-dir=out/probe_distill/teacher_pool --mix-ratio=0.5 {{ ARGS }}
# Run smoke twice: first warms the v_hack cache (cache-miss path), second hits
# the cache (cache-hit path). Catches scope/save bugs that only manifest in one.
+8 -18
View File
@@ -124,7 +124,7 @@ PRESETS: dict[str, dict] = {
# That catches checkpoint-save bugs that only manifest after step 25 (e.g.
# closure-scope NameErrors in the save path).
"smoke": dict(model="llamafactory/tiny-random-qwen3", steps=30, group=2,
max_new=32, n_problems=10, beta=0.0, prompts_per_step=1),
max_new=32, n_problems=100, beta=0.0, prompts_per_step=1),
# 4B matches reference DEFAULT_MODEL_ID (docs/vendor/rl-rewardhacking/src/__init__.py).
# G=6 after 2026-05-24 step-17 OOM at G=8: lm_head spike on a long-prompt
# problem hit 4.16 GiB / 2.5 GiB free. `logits_to_keep` cuts lm_head ~33%;
@@ -661,7 +661,6 @@ table columns:
"resolved": json.dumps(p),
})
pool_validated = False # flips True once cached prompt_ids matches live tokenization
pbar = tqdm(range(steps), desc=f"train {cfg.arm} {cfg.preset.value}", mininterval=60)
for step in pbar:
t0 = time.time()
@@ -730,27 +729,18 @@ table columns:
if len(pool_rows) < G_t:
idxs = idxs + torch.randint(0, len(pool_rows), (G_t - len(pool_rows),), generator=rng).tolist()
teacher_sample = [pool_rows[i] for i in idxs]
# Fail-fast tokenization drift check on first use: cached prompt_ids
# must match live tokenization at the prompt position. If this trips
# the pool was generated with a different tokenizer / chat template.
if not pool_validated:
cached_ids = teacher_sample[0]["prompt_ids"][: int(teacher_sample[0]["plen"])]
live_ids = enc.input_ids[0].tolist()
if cached_ids != live_ids:
raise ValueError(
f"teacher pool tokenization drift on problem_id={prob['problem_id']}: "
f"cached prompt_ids[:plen]={cached_ids[:12]}... vs "
f"live enc={live_ids[:12]}... (lengths {len(cached_ids)} vs {len(live_ids)})"
)
pool_validated = True
# Student live-gen. gen_cfg.num_return_sequences is baked to G_s
# at construction (pool path) or = group (no-pool path).
with torch.no_grad():
out_s = model.generate(**enc, generation_config=gen_cfg).detach()
# Build teacher tensor: each cached row is plen + L_t_i; right-pad
# to common L within the teacher batch, then F.pad to match student L.
# Build teacher tensor: live-tokenized prompt + cached completion.
# Cached prompt_ids are ignored — re-tokenizing live makes the pool
# robust to chat-template / tokenizer drift between the model used
# for pool generation (Qwen3-4B) and the current student (e.g.
# tiny-random-qwen3 under smoke). Same vocab is assumed.
live_prompt_ids = enc.input_ids[0].tolist()
teacher_seqs = [
torch.tensor(r["prompt_ids"] + r["completion_ids"], dtype=torch.long, device=device)
torch.tensor(live_prompt_ids + r["completion_ids"], dtype=torch.long, device=device)
for r in teacher_sample
]
L_t = max(s.shape[0] for s in teacher_seqs)