diff --git a/justfile b/justfile index 7772569..43a3ea9 100644 --- a/justfile +++ b/justfile @@ -17,12 +17,18 @@ default: # beartype on so jaxtyping signatures get runtime-checked. Runs 30 steps so # the every-25-step save_ckpt path is covered. Should finish in ~1-2 min. # Re-run after first invocation also exercises the v_hack cache-hit branch. +# Pulls cached teacher rollouts (real Qwen3-4B completions + real graded +# rewards) at mix_ratio=0.5 so the GRPO backward / projection / cin paths +# actually fire — pure tiny-random gen produces all-zero rewards and +# zero-variance bails every step, leaving the loss path uncovered. smoke *ARGS: BEARTYPE=1 CUDA_VISIBLE_DEVICES= {{ TRAIN }} --preset=smoke --arm=projected \ - --v-hack-path=out/v_hack_smoke.safetensors {{ ARGS }} + --v-hack-path=out/v_hack_smoke.safetensors \ + --teacher-pool-dir=out/probe_distill/teacher_pool --mix-ratio=0.5 {{ ARGS }} smoke-vanilla *ARGS: - BEARTYPE=1 CUDA_VISIBLE_DEVICES= {{ TRAIN }} --preset=smoke --arm=vanilla {{ ARGS }} + BEARTYPE=1 CUDA_VISIBLE_DEVICES= {{ TRAIN }} --preset=smoke --arm=vanilla \ + --teacher-pool-dir=out/probe_distill/teacher_pool --mix-ratio=0.5 {{ ARGS }} # Run smoke twice: first warms the v_hack cache (cache-miss path), second hits # the cache (cache-hit path). Catches scope/save bugs that only manifest in one. diff --git a/src/projected_grpo/train.py b/src/projected_grpo/train.py index 4260a0c..a2c7290 100644 --- a/src/projected_grpo/train.py +++ b/src/projected_grpo/train.py @@ -124,7 +124,7 @@ PRESETS: dict[str, dict] = { # That catches checkpoint-save bugs that only manifest after step 25 (e.g. # closure-scope NameErrors in the save path). "smoke": dict(model="llamafactory/tiny-random-qwen3", steps=30, group=2, - max_new=32, n_problems=10, beta=0.0, prompts_per_step=1), + max_new=32, n_problems=100, beta=0.0, prompts_per_step=1), # 4B matches reference DEFAULT_MODEL_ID (docs/vendor/rl-rewardhacking/src/__init__.py). # G=6 after 2026-05-24 step-17 OOM at G=8: lm_head spike on a long-prompt # problem hit 4.16 GiB / 2.5 GiB free. `logits_to_keep` cuts lm_head ~33%; @@ -661,7 +661,6 @@ table columns: "resolved": json.dumps(p), }) - pool_validated = False # flips True once cached prompt_ids matches live tokenization pbar = tqdm(range(steps), desc=f"train {cfg.arm} {cfg.preset.value}", mininterval=60) for step in pbar: t0 = time.time() @@ -730,27 +729,18 @@ table columns: if len(pool_rows) < G_t: idxs = idxs + torch.randint(0, len(pool_rows), (G_t - len(pool_rows),), generator=rng).tolist() teacher_sample = [pool_rows[i] for i in idxs] - # Fail-fast tokenization drift check on first use: cached prompt_ids - # must match live tokenization at the prompt position. If this trips - # the pool was generated with a different tokenizer / chat template. - if not pool_validated: - cached_ids = teacher_sample[0]["prompt_ids"][: int(teacher_sample[0]["plen"])] - live_ids = enc.input_ids[0].tolist() - if cached_ids != live_ids: - raise ValueError( - f"teacher pool tokenization drift on problem_id={prob['problem_id']}: " - f"cached prompt_ids[:plen]={cached_ids[:12]}... vs " - f"live enc={live_ids[:12]}... (lengths {len(cached_ids)} vs {len(live_ids)})" - ) - pool_validated = True # Student live-gen. gen_cfg.num_return_sequences is baked to G_s # at construction (pool path) or = group (no-pool path). with torch.no_grad(): out_s = model.generate(**enc, generation_config=gen_cfg).detach() - # Build teacher tensor: each cached row is plen + L_t_i; right-pad - # to common L within the teacher batch, then F.pad to match student L. + # Build teacher tensor: live-tokenized prompt + cached completion. + # Cached prompt_ids are ignored — re-tokenizing live makes the pool + # robust to chat-template / tokenizer drift between the model used + # for pool generation (Qwen3-4B) and the current student (e.g. + # tiny-random-qwen3 under smoke). Same vocab is assumed. + live_prompt_ids = enc.input_ids[0].tolist() teacher_seqs = [ - torch.tensor(r["prompt_ids"] + r["completion_ids"], dtype=torch.long, device=device) + torch.tensor(live_prompt_ids + r["completion_ids"], dtype=torch.long, device=device) for r in teacher_sample ] L_t = max(s.shape[0] for s in teacher_seqs)