smoke: tiny-random on CPU, beartype on, 30 steps; one-harness consolidation

Make `just smoke` reuse train.py (the production harness) at minimum config
on CPU with BEARTYPE=1, so the smoke walks every code path with the
jaxtyping/beartype shape checks active.

Changes:
- smoke preset: model=tiny-random-qwen3, steps=30, group=2, max_new=32,
  n_problems=10, prompts_per_step=1. steps must be >= 25 so that the
  every-25-step save_ckpt path is exercised. Runs in ~35s on CPU.
- train.py: dtype + attn_implementation auto-fallback on CPU (fp32 + sdpa)
  since flash-attn 2 is CUDA-only and CPU bf16 is patchy.
- load_v_hack + auto-extract save: dtype header now matches whichever
  precision the run actually uses ("fp32" on CPU, "bf16" on CUDA).
- justfile: smoke recipes drop the parallel `run.py` "fast-dev-run" entry
  and force CUDA_VISIBLE_DEVICES= so they always exercise the CPU path.
  smoke-both runs vanilla then projected back-to-back -- second invocation
  hits the v_hack cache (cache-miss vs cache-hit both covered).

Fixes uncovered when smoke first ran:
- est_gens_per_step was computed as cfg.prompts_per_step * cfg.group, both of
  which are None when the preset defaults supply them; switched to the
  resolved locals (prompts_per_step, group).
- save_ckpt and the final-summary aggregation still referenced r["hack"] /
  r["gt"], which were dropped from the per-step table in commit 373c257.
  Reconstruct the totals from r["hack_s"] + r["hack_t"], and likewise from
  r["gt_s"] + r["gt_t"] for gt.
This commit is contained in:
wassname
2026-05-27 23:33:12 +00:00
parent 577f075611
commit ecfb3bf30a
2 changed files with 38 additions and 23 deletions
+11 -11
View File
@@ -13,22 +13,22 @@ TRAIN := "uv run python -m projected_grpo.train" # real LeetCode GRPO entry poi
default:
@just --list
# fast-dev-run: tiny-random model, full smoke pipeline end-to-end, ~1-2 min, beartype on.
fast-dev-run *ARGS:
BEARTYPE=1 {{ BASE }} --fast-dev-run --model={{ TINY_MODEL }} {{ ARGS }}
# Real-pipeline presets (train.py = AntiPaSTO + Dr.GRPO + LeetCode rewards).
# smoke = Qwen3.5-0.8B 10 steps, fits 24GB. Mechanism verification only.
# full = Qwen3-4B 200 steps G=6, peaks ~90GB on 96GB. spec.md §H4 substrate.
# Smoke: same harness as production (train.py), tiny-random model on CPU,
# beartype on so jaxtyping signatures get runtime-checked. Runs 30 steps so
# the every-25-step save_ckpt path is covered. Should finish in ~1-2 min.
# Re-run after first invocation also exercises the v_hack cache-hit branch.
smoke *ARGS:
{{ TRAIN }} --preset=smoke --arm=projected --v-hack-path=out/v_hack_smoke.safetensors {{ ARGS }}
BEARTYPE=1 CUDA_VISIBLE_DEVICES= {{ TRAIN }} --preset=smoke --arm=projected \
--v-hack-path=out/v_hack_smoke.safetensors {{ ARGS }}
smoke-vanilla *ARGS:
{{ TRAIN }} --preset=smoke --arm=vanilla {{ ARGS }}
BEARTYPE=1 CUDA_VISIBLE_DEVICES= {{ TRAIN }} --preset=smoke --arm=vanilla {{ ARGS }}
# Run smoke twice: first warms the v_hack cache (cache-miss path), second hits
# the cache (cache-hit path). Catches scope/save bugs that only manifest in one.
smoke-both:
{{ TRAIN }} --preset=smoke --arm=vanilla
{{ TRAIN }} --preset=smoke --arm=projected --v-hack-path=out/v_hack_smoke.safetensors
just smoke-vanilla
just smoke
# H4 baseline at spec substrate. No v_hack needed for vanilla.
full-vanilla *ARGS:
+27 -12
View File
@@ -120,8 +120,11 @@ class Preset(str, Enum):
PRESETS: dict[str, dict] = {
"smoke": dict(model="Qwen/Qwen3.5-0.8B", steps=10, group=2, max_new=128,
n_problems=30, beta=0.0, prompts_per_step=1), # 24GB cap
# steps=30 (not 10) so save_ckpt's every-25-step trigger fires under smoke.
# That catches checkpoint-save bugs that only manifest after step 25 (e.g.
# closure-scope NameErrors in the save path).
"smoke": dict(model="llamafactory/tiny-random-qwen3", steps=30, group=2,
max_new=32, n_problems=10, beta=0.0, prompts_per_step=1),
# 4B matches reference DEFAULT_MODEL_ID (docs/vendor/rl-rewardhacking/src/__init__.py).
# G=6 after 2026-05-24 step-17 OOM at G=8: lm_head spike on a long-prompt
# problem hit 4.16 GiB / 2.5 GiB free. `logits_to_keep` cuts lm_head ~33%;
@@ -264,10 +267,14 @@ def load_v_hack(
)
if saved_model != model_name:
raise ValueError(f"v_hack model mismatch: {path} has {saved_model}, run uses {model_name}")
if saved_dtype != "bf16":
# dtype mismatch: cross-dtype SVD bases can diverge silently, so error
# unless the saved dtype matches what train.py uses on this device.
# CPU runs in fp32, CUDA runs in bf16 (see model-load site above).
expected_dtype = "fp32" if torch.cuda.is_available() is False else "bf16"
if saved_dtype != expected_dtype:
raise ValueError(
f"v_hack dtype/SVD-basis mismatch: {path} was extracted with dtype={saved_dtype}; "
"train.py loads models in bf16. Re-extract with `--dtype=bf16`."
f"this run loads models in {expected_dtype}. Re-extract with `--dtype={expected_dtype}`."
)
v_hack = {k: f.get_tensor(k) for k in f.keys() if not k.startswith("_sv/")}
v_sv = {k[len("_sv/"):]: f.get_tensor(k) for k in f.keys() if k.startswith("_sv/")}
@@ -381,9 +388,13 @@ def main(cfg: Config) -> int:
tok = AutoTokenizer.from_pretrained(model_name)
if tok.pad_token_id is None: tok.pad_token = tok.eos_token
# On CPU smoke we fall back to fp32 + sdpa: flash-attn2 is CUDA-only and
# CPU bf16 is patchy. Production GPU runs keep bf16 + flash_attention_2.
cpu = device.type == "cpu"
model = AutoModelForCausalLM.from_pretrained(
model_name, dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
model_name,
dtype=torch.float32 if cpu else torch.bfloat16,
attn_implementation="sdpa" if cpu else "flash_attention_2",
).to(device)
# No gradient checkpointing: grad-accum forwards one G-group (6 seqs) at a time,
# so peak activation memory is ~6 x merged_len, which fits at G=6 on 96GB without
@@ -428,7 +439,8 @@ def main(cfg: Config) -> int:
# for the singular values. load_v_hack splits them back apart.
save_payload = {**v_hack_extracted, **{f"_sv/{n}": s for n, s in v_sv_extracted.items()}}
save_file(save_payload, str(v_hack_path),
metadata={"model": model_name, "dtype": "bf16",
metadata={"model": model_name,
"dtype": "fp32" if cpu else "bf16",
"top_k": str(min(cfg.v_hack_extract_top_k, len(VHACK_PAIRS) - 2)),
"tau_axis": str(cfg.v_hack_tau_axis), "schema": "v2_with_sv"})
# extract zeros grads at exit; opt is built below so no opt-state taint.
@@ -589,7 +601,8 @@ def main(cfg: Config) -> int:
def _fmt_header() -> str:
return " ".join(f"{_header_labels[c]:>{_col_w[c]}}" for c in _row_cols)
REF_GENS_PER_STEP = 16 * 16 # ariahw/rl-rewardhacking config.py:num_prompts * num_generations
est_gens_per_step = cfg.prompts_per_step * cfg.group # before mixed-pool split
# Use the resolved locals (preset defaults merged), not cfg.* which can be None.
est_gens_per_step = prompts_per_step * group # before mixed-pool split
logger.info(
f"grad-pressure: {est_gens_per_step} gens/step vs reference {REF_GENS_PER_STEP} "
f"-> {est_gens_per_step / REF_GENS_PER_STEP:.2f}x per step; "
@@ -635,8 +648,10 @@ table columns:
kill keeps everything up to the last save. Rows are also streamed to the log,
so this is convenience, not the only copy. Mirrors the v_hack metadata idiom."""
n_gens = sum(r["N"] for r in rows)
hr = sum(int(r["hack"].split("/")[0]) for r in rows) / max(1, n_gens)
pr = sum(int(r["gt"].split("/")[0]) for r in rows) / max(1, n_gens)
# Aggregate from per-source columns (the combined hack/gt aggregates were
# dropped from the per-step table as redundant; reconstruct here).
hr = sum(int(r["hack_s"].split("/")[0]) + int(r["hack_t"].split("/")[0]) for r in rows) / max(1, n_gens)
pr = sum(int(r["gt_s"].split("/")[0]) + int(r["gt_t"].split("/")[0]) for r in rows) / max(1, n_gens)
tensors = {n: info["delta_S"].detach().cpu().contiguous()
for n, info in wrappers.items()}
save_file(tensors, str(path or ckpt_path), metadata={
@@ -1058,8 +1073,8 @@ table columns:
peak_gb = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0
n_steps = len(rows)
n_gens = sum(r["N"] for r in rows)
total_hacks = sum(int(r["hack"].split("/")[0]) for r in rows)
total_pass = sum(int(r["gt"].split("/")[0]) for r in rows)
total_hacks = sum(int(r["hack_s"].split("/")[0]) + int(r["hack_t"].split("/")[0]) for r in rows)
total_pass = sum(int(r["gt_s"].split("/")[0]) + int(r["gt_t"].split("/")[0]) for r in rows)
hack_rate = total_hacks / max(1, n_gens)
pass_rate = total_pass / max(1, n_gens)
# Per-source totals. On no-teacher runs, hack_s_total == total_hacks.