diff --git a/src/projected_grpo/train.py b/src/projected_grpo/train.py index 442ba82..8ea178e 100644 --- a/src/projected_grpo/train.py +++ b/src/projected_grpo/train.py @@ -70,6 +70,7 @@ import torch import tyro from loguru import logger from safetensors import safe_open +from safetensors.torch import save_file from tabulate import tabulate from tqdm import tqdm from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig @@ -302,14 +303,14 @@ def main(cfg: Config) -> int: model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2", ).to(device) - # Trade compute for memory: recompute activations during backward. ~30-50% - # less activation VRAM on the policy forward, enough to fit G=8 max_new=1024 - # 2B with autograd on a 96GB card. Required `use_cache=False`. - # `enable_input_require_grads` is the canonical PEFT trick: base params are - # frozen, only delta_S has grad. Without this the embedding output has - # requires_grad=False and HF's checkpoint() shorts out (no recompute). - model.gradient_checkpointing_enable() - model.enable_input_require_grads() + # No gradient checkpointing: grad-accum forwards one G-group (6 seqs) at a time, + # so peak activation memory is ~6 x merged_len, which fits at G=6 on 96GB without + # recompute (worst-case merged 2048; flash-attn keeps attention O(N), MLP/residual + # store ~12-15GB). Dropping checkpointing removes the backward recompute (~1.3-1.5x + # on the train-compute portion). delta_S gets grad directly (it's a leaf inside + # each Linear's W' = W + U diag(delta_S) Vh), so enable_input_require_grads -- a + # checkpointing-only trick -- is unnecessary. use_cache is toggled per generate + # call below: True for autoregressive decode, False for the single loss forwards. model.config.use_cache = False wrappers = wrap_model_with_antipasto(model, model_name, CACHE_ROOT, device) @@ -372,6 +373,28 @@ def main(cfg: Config) -> int: "gt", "hack", "loss", "cin", "cout", "fired", "sec"] logger.info("row\t" + "\t".join(_row_cols)) + OUT_DIR.mkdir(exist_ok=True) + tag = cfg.out_tag or f"_{cfg.preset.value}_{cfg.arm}_seed{cfg.seed}" + ckpt_path = OUT_DIR / f"train{tag}.safetensors" + + def save_ckpt(rows: list[dict]) -> None: + """Rewrite the run checkpoint in place: trainable delta_S as tensors, per-step + rows + config as JSON metadata (safetensors metadata is str->str only, so the + non-tensor payload is JSON). Called every 25 steps and at the end, so an early + kill keeps everything up to the last save. Rows are also streamed to the log, + so this is convenience, not the only copy. Mirrors the v_hack metadata idiom.""" + n_gens = sum(r["N"] for r in rows) + hr = sum(int(r["hack"].split("/")[0]) for r in rows) / max(1, n_gens) + pr = sum(int(r["gt"].split("/")[0]) for r in rows) / max(1, n_gens) + tensors = {n: info["delta_S"].detach().cpu().contiguous() + for n, info in wrappers.items()} + save_file(tensors, str(ckpt_path), metadata={ + "model": model_name, "dtype": "bf16", "step": str(len(rows)), + "hack_rate": f"{hr:.6f}", "pass_rate": f"{pr:.6f}", + "rows": json.dumps(rows), "cfg": json.dumps(vars(cfg), default=str), + "resolved": json.dumps(p), + }) + pbar = tqdm(range(steps), desc=f"train {cfg.arm} {cfg.preset.value}", mininterval=60) for step in pbar: t0 = time.time() @@ -383,6 +406,12 @@ def main(cfg: Config) -> int: agg_comp_lens, agg_finished, n_skipped = [], [], 0 agg_loss = 0.0 diag_tail = None + # Phase timers (per-step cumulative, seconds). Each GPU phase ends in a + # CPU-blocking op (decode / .item()), so perf_counter is sync-accurate + # without explicit cuda.synchronize. Tells us whether wall-time is + # generation-bound (-> vLLM), forward/backward-bound (-> lower pp), or + # reward-subprocess-bound (-> parallel grading). + t_gen = t_rew = t_fb = 0.0 for p_idx in range(prompts_per_step): idx = int(torch.randint(0, len(problems), (1,), generator=rng).item()) @@ -397,11 +426,20 @@ def main(cfg: Config) -> int: n_skipped += 1 continue + # KV cache is essential for autoregressive decode (O(L) vs O(L^2) recompute + # per token) -- cacheless was the ~19min/step cost. Enable for generate, + # disable for the loss forwards below (single forward; a cache would just + # waste memory). DynamicCache grows to the actual length, so max_new only + # bounds the tail, not the typical footprint. + model.config.use_cache = True + _tg = time.perf_counter() with torch.no_grad(): gen_out = model.generate(**enc, generation_config=gen_cfg).detach() + model.config.use_cache = False merged = gen_out completions = gen_out[:, plen:] texts = tok.batch_decode(completions, skip_special_tokens=True) + t_gen += time.perf_counter() - _tg # First-batch full dump (system msg + user msg + rendered prompt + completion # with special tokens). Goes to verbose log only — stdout stays clean. @@ -426,6 +464,7 @@ def main(cfg: Config) -> int: finished = [bool((c == eos_id).any().item()) for c in completions] agg_comp_lens.extend(comp_lens); agg_finished.extend(finished) + _tr = time.perf_counter() rs, hack_flags, gt_flags, fmt_flags = [], [], [], [] for t in texts: r = compute_reward( @@ -434,6 +473,7 @@ def main(cfg: Config) -> int: ) rs.append(r.reward); hack_flags.append(r.hacked); gt_flags.append(r.gt_pass) fmt_flags.append(r.format_ok) + t_rew += time.perf_counter() - _tr agg_rew.extend(rs); agg_gt.extend(gt_flags); agg_hack.extend(hack_flags); agg_fmt.extend(fmt_flags) if (step < 3 or step % 20 == 0) and p_idx == 0: @@ -460,6 +500,7 @@ def main(cfg: Config) -> int: # drops the last (predicts beyond `merged`, unused). completion_ids = merged[:, plen:] L_c = completion_ids.shape[1] + _tfb = time.perf_counter() with torch.no_grad(): gen_logp = per_token_logps( model(merged, logits_to_keep=L_c + 1).logits[:, :-1], @@ -492,6 +533,7 @@ def main(cfg: Config) -> int: loss = ((per_tok_loss * mask).sum(1) / mask.sum(1).clamp_min(1)).mean() / prompts_per_step loss.backward() agg_loss += loss.item() + t_fb += time.perf_counter() - _tfb # One projection on accumulated grads (projected arm only). if cfg.arm == "projected": @@ -522,6 +564,13 @@ def main(cfg: Config) -> int: f"max_new={max_new} fmt={sum(agg_fmt)}/{n_rollouts} gt={sum(agg_gt)}/{n_rollouts} " f"hack={sum(agg_hack)}/{n_rollouts} skipped={n_skipped}/{prompts_per_step}" ) + _tstep = time.time() - t0 + logger.info( + f"step {step} TIMING gen={t_gen:.0f}s fwd_bwd={t_fb:.0f}s " + f"reward={t_rew:.0f}s other={_tstep - t_gen - t_fb - t_rew:.0f}s " + f"total={_tstep:.0f}s | SHOULD: identify dominant phase. " + f"gen-bound -> vLLM; fwd_bwd-bound -> lower pp; reward-bound -> parallel grading" + ) if diag_tail is not None: tail = diag_tail.replace("\n", "\\n") logger.debug(f"step {step} gen[0] tail (last 400 chars): {tail!r}") @@ -543,6 +592,8 @@ def main(cfg: Config) -> int: rows.append(row) # Stream this step as TSV row (header was printed before the loop). logger.info("row\t" + "\t".join(str(row[c]) for c in _row_cols)) + if (step + 1) % 25 == 0: + save_ckpt(rows) # survive early kills; ~12 days for the full sweep # Live status in tqdm postfix; full per-step line in verbose log only. pbar.set_postfix( rew=f"{rew_mean:+.2f}", gt=f"{sum(agg_gt)}/{n_rollouts}", @@ -586,13 +637,7 @@ def main(cfg: Config) -> int: "tag": cfg.out_tag, "log": str(verbose_log), }], headers="keys", tablefmt="tsv")) - OUT_DIR.mkdir(exist_ok=True) - tag = cfg.out_tag or f"_{cfg.preset.value}_{cfg.arm}_seed{cfg.seed}" - torch.save( - {"rows": rows, "hack_rate": hack_rate, "pass_rate": pass_rate, - "cfg": vars(cfg), "resolved": p}, - OUT_DIR / f"train{tag}.pt", - ) + save_ckpt(rows) return 0