Drop grad checkpointing, KV cache for generate, periodic safetensors ckpt + phase timing

- Drop gradient_checkpointing: at G=6 grad-accum forwards one 6-seq group at a
  time, so activation peak fits on 96GB without recompute; removes the ~1.3-1.5x
  backward recompute. enable_input_require_grads was a checkpointing-only trick.
- Toggle use_cache=True around model.generate (False for the loss forwards).
  Cacheless decode was O(L^2); measured 2.17x faster with cache on the wrapped 4B.
- Replace end-of-run torch.save(.pt) with save_ckpt(): trainable delta_S as
  safetensors tensors + rows/config as JSON metadata (str->str), written every
  25 steps and at the end so an early kill keeps progress. Mirrors v_hack idiom.
- Per-step TIMING log (gen / fwd_bwd / reward) to attribute wall-time. Diagnosed
  generation as ~93% of step cost (HF generate slow; full-rank reparam adds 1.5x).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
wassname
2026-05-24 12:45:21 +00:00
parent 6f68ba34b6
commit fa24f4eb4b
+60 -15
View File
@@ -70,6 +70,7 @@ import torch
import tyro
from loguru import logger
from safetensors import safe_open
from safetensors.torch import save_file
from tabulate import tabulate
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
@@ -302,14 +303,14 @@ def main(cfg: Config) -> int:
model_name, dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
).to(device)
# Trade compute for memory: recompute activations during backward. ~30-50%
# less activation VRAM on the policy forward, enough to fit G=8 max_new=1024
# 2B with autograd on a 96GB card. Required `use_cache=False`.
# `enable_input_require_grads` is the canonical PEFT trick: base params are
# frozen, only delta_S has grad. Without this the embedding output has
# requires_grad=False and HF's checkpoint() shorts out (no recompute).
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
# No gradient checkpointing: grad-accum forwards one G-group (6 seqs) at a time,
# so peak activation memory is ~6 x merged_len, which fits at G=6 on 96GB without
# recompute (worst-case merged 2048; flash-attn keeps attention O(N), MLP/residual
# store ~12-15GB). Dropping checkpointing removes the backward recompute (~1.3-1.5x
# on the train-compute portion). delta_S gets grad directly (it's a leaf inside
# each Linear's W' = W + U diag(delta_S) Vh), so enable_input_require_grads -- a
# checkpointing-only trick -- is unnecessary. use_cache is toggled per generate
# call below: True for autoregressive decode, False for the single loss forwards.
model.config.use_cache = False
wrappers = wrap_model_with_antipasto(model, model_name, CACHE_ROOT, device)
@@ -372,6 +373,28 @@ def main(cfg: Config) -> int:
"gt", "hack", "loss", "cin", "cout", "fired", "sec"]
logger.info("row\t" + "\t".join(_row_cols))
OUT_DIR.mkdir(exist_ok=True)
tag = cfg.out_tag or f"_{cfg.preset.value}_{cfg.arm}_seed{cfg.seed}"
ckpt_path = OUT_DIR / f"train{tag}.safetensors"
def save_ckpt(rows: list[dict]) -> None:
"""Rewrite the run checkpoint in place: trainable delta_S as tensors, per-step
rows + config as JSON metadata (safetensors metadata is str->str only, so the
non-tensor payload is JSON). Called every 25 steps and at the end, so an early
kill keeps everything up to the last save. Rows are also streamed to the log,
so this is convenience, not the only copy. Mirrors the v_hack metadata idiom."""
n_gens = sum(r["N"] for r in rows)
hr = sum(int(r["hack"].split("/")[0]) for r in rows) / max(1, n_gens)
pr = sum(int(r["gt"].split("/")[0]) for r in rows) / max(1, n_gens)
tensors = {n: info["delta_S"].detach().cpu().contiguous()
for n, info in wrappers.items()}
save_file(tensors, str(ckpt_path), metadata={
"model": model_name, "dtype": "bf16", "step": str(len(rows)),
"hack_rate": f"{hr:.6f}", "pass_rate": f"{pr:.6f}",
"rows": json.dumps(rows), "cfg": json.dumps(vars(cfg), default=str),
"resolved": json.dumps(p),
})
pbar = tqdm(range(steps), desc=f"train {cfg.arm} {cfg.preset.value}", mininterval=60)
for step in pbar:
t0 = time.time()
@@ -383,6 +406,12 @@ def main(cfg: Config) -> int:
agg_comp_lens, agg_finished, n_skipped = [], [], 0
agg_loss = 0.0
diag_tail = None
# Phase timers (per-step cumulative, seconds). Each GPU phase ends in a
# CPU-blocking op (decode / .item()), so perf_counter is sync-accurate
# without explicit cuda.synchronize. Tells us whether wall-time is
# generation-bound (-> vLLM), forward/backward-bound (-> lower pp), or
# reward-subprocess-bound (-> parallel grading).
t_gen = t_rew = t_fb = 0.0
for p_idx in range(prompts_per_step):
idx = int(torch.randint(0, len(problems), (1,), generator=rng).item())
@@ -397,11 +426,20 @@ def main(cfg: Config) -> int:
n_skipped += 1
continue
# KV cache is essential for autoregressive decode (O(L) vs O(L^2) recompute
# per token) -- cacheless was the ~19min/step cost. Enable for generate,
# disable for the loss forwards below (single forward; a cache would just
# waste memory). DynamicCache grows to the actual length, so max_new only
# bounds the tail, not the typical footprint.
model.config.use_cache = True
_tg = time.perf_counter()
with torch.no_grad():
gen_out = model.generate(**enc, generation_config=gen_cfg).detach()
model.config.use_cache = False
merged = gen_out
completions = gen_out[:, plen:]
texts = tok.batch_decode(completions, skip_special_tokens=True)
t_gen += time.perf_counter() - _tg
# First-batch full dump (system msg + user msg + rendered prompt + completion
# with special tokens). Goes to verbose log only — stdout stays clean.
@@ -426,6 +464,7 @@ def main(cfg: Config) -> int:
finished = [bool((c == eos_id).any().item()) for c in completions]
agg_comp_lens.extend(comp_lens); agg_finished.extend(finished)
_tr = time.perf_counter()
rs, hack_flags, gt_flags, fmt_flags = [], [], [], []
for t in texts:
r = compute_reward(
@@ -434,6 +473,7 @@ def main(cfg: Config) -> int:
)
rs.append(r.reward); hack_flags.append(r.hacked); gt_flags.append(r.gt_pass)
fmt_flags.append(r.format_ok)
t_rew += time.perf_counter() - _tr
agg_rew.extend(rs); agg_gt.extend(gt_flags); agg_hack.extend(hack_flags); agg_fmt.extend(fmt_flags)
if (step < 3 or step % 20 == 0) and p_idx == 0:
@@ -460,6 +500,7 @@ def main(cfg: Config) -> int:
# drops the last (predicts beyond `merged`, unused).
completion_ids = merged[:, plen:]
L_c = completion_ids.shape[1]
_tfb = time.perf_counter()
with torch.no_grad():
gen_logp = per_token_logps(
model(merged, logits_to_keep=L_c + 1).logits[:, :-1],
@@ -492,6 +533,7 @@ def main(cfg: Config) -> int:
loss = ((per_tok_loss * mask).sum(1) / mask.sum(1).clamp_min(1)).mean() / prompts_per_step
loss.backward()
agg_loss += loss.item()
t_fb += time.perf_counter() - _tfb
# One projection on accumulated grads (projected arm only).
if cfg.arm == "projected":
@@ -522,6 +564,13 @@ def main(cfg: Config) -> int:
f"max_new={max_new} fmt={sum(agg_fmt)}/{n_rollouts} gt={sum(agg_gt)}/{n_rollouts} "
f"hack={sum(agg_hack)}/{n_rollouts} skipped={n_skipped}/{prompts_per_step}"
)
_tstep = time.time() - t0
logger.info(
f"step {step} TIMING gen={t_gen:.0f}s fwd_bwd={t_fb:.0f}s "
f"reward={t_rew:.0f}s other={_tstep - t_gen - t_fb - t_rew:.0f}s "
f"total={_tstep:.0f}s | SHOULD: identify dominant phase. "
f"gen-bound -> vLLM; fwd_bwd-bound -> lower pp; reward-bound -> parallel grading"
)
if diag_tail is not None:
tail = diag_tail.replace("\n", "\\n")
logger.debug(f"step {step} gen[0] tail (last 400 chars): {tail!r}")
@@ -543,6 +592,8 @@ def main(cfg: Config) -> int:
rows.append(row)
# Stream this step as TSV row (header was printed before the loop).
logger.info("row\t" + "\t".join(str(row[c]) for c in _row_cols))
if (step + 1) % 25 == 0:
save_ckpt(rows) # survive early kills; ~12 days for the full sweep
# Live status in tqdm postfix; full per-step line in verbose log only.
pbar.set_postfix(
rew=f"{rew_mean:+.2f}", gt=f"{sum(agg_gt)}/{n_rollouts}",
@@ -586,13 +637,7 @@ def main(cfg: Config) -> int:
"tag": cfg.out_tag, "log": str(verbose_log),
}], headers="keys", tablefmt="tsv"))
OUT_DIR.mkdir(exist_ok=True)
tag = cfg.out_tag or f"_{cfg.preset.value}_{cfg.arm}_seed{cfg.seed}"
torch.save(
{"rows": rows, "hack_rate": hack_rate, "pass_rate": pass_rate,
"cfg": vars(cfg), "resolved": p},
OUT_DIR / f"train{tag}.pt",
)
save_ckpt(rows)
return 0