mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 17:48:43 +08:00
Drop grad checkpointing, KV cache for generate, periodic safetensors ckpt + phase timing
- Drop gradient_checkpointing: at G=6 grad-accum forwards one 6-seq group at a time, so activation peak fits on 96GB without recompute; removes the ~1.3-1.5x backward recompute. enable_input_require_grads was a checkpointing-only trick.
- Toggle use_cache=True around model.generate (False for the loss forwards). Cacheless decode was O(L^2); measured 2.17x faster with cache on the wrapped 4B.
- Replace end-of-run torch.save(.pt) with save_ckpt(): trainable delta_S as safetensors tensors + rows/config as JSON metadata (str->str), written every 25 steps and at the end so an early kill keeps progress. Mirrors v_hack idiom.
- Per-step TIMING log (gen / fwd_bwd / reward) to attribute wall-time. Diagnosed generation as ~93% of step cost (HF generate slow; full-rank reparam adds 1.5x).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
+60
-15
@@ -70,6 +70,7 @@ import torch
|
||||
import tyro
|
||||
from loguru import logger
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import save_file
|
||||
from tabulate import tabulate
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
||||
@@ -302,14 +303,14 @@ def main(cfg: Config) -> int:
|
||||
model_name, dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2",
|
||||
).to(device)
|
||||
# Trade compute for memory: recompute activations during backward. ~30-50%
|
||||
# less activation VRAM on the policy forward, enough to fit G=8 max_new=1024
|
||||
# 2B with autograd on a 96GB card. Required `use_cache=False`.
|
||||
# `enable_input_require_grads` is the canonical PEFT trick: base params are
|
||||
# frozen, only delta_S has grad. Without this the embedding output has
|
||||
# requires_grad=False and HF's checkpoint() shorts out (no recompute).
|
||||
model.gradient_checkpointing_enable()
|
||||
model.enable_input_require_grads()
|
||||
# No gradient checkpointing: grad-accum forwards one G-group (6 seqs) at a time,
|
||||
# so peak activation memory is ~6 x merged_len, which fits at G=6 on 96GB without
|
||||
# recompute (worst-case merged 2048; flash-attn keeps attention O(N), MLP/residual
|
||||
# store ~12-15GB). Dropping checkpointing removes the backward recompute (~1.3-1.5x
|
||||
# on the train-compute portion). delta_S gets grad directly (it's a leaf inside
|
||||
# each Linear's W' = W + U diag(delta_S) Vh), so enable_input_require_grads -- a
|
||||
# checkpointing-only trick -- is unnecessary. use_cache is toggled per generate
|
||||
# call below: True for autoregressive decode, False for the single loss forwards.
|
||||
model.config.use_cache = False
|
||||
|
||||
wrappers = wrap_model_with_antipasto(model, model_name, CACHE_ROOT, device)
|
||||
@@ -372,6 +373,28 @@ def main(cfg: Config) -> int:
|
||||
"gt", "hack", "loss", "cin", "cout", "fired", "sec"]
|
||||
logger.info("row\t" + "\t".join(_row_cols))
|
||||
|
||||
OUT_DIR.mkdir(exist_ok=True)
|
||||
tag = cfg.out_tag or f"_{cfg.preset.value}_{cfg.arm}_seed{cfg.seed}"
|
||||
ckpt_path = OUT_DIR / f"train{tag}.safetensors"
|
||||
|
||||
def save_ckpt(rows: list[dict]) -> None:
    """Atomically rewrite the run checkpoint at ``ckpt_path``.

    Tensors saved: each wrapper's trainable ``delta_S`` (detached, on CPU,
    contiguous). Everything else -- per-step rows, config, resolved preset and
    the aggregate hack/pass rates -- goes into the safetensors metadata, which
    is str->str only, so the non-tensor payload is JSON-encoded.

    Called every 25 steps and at the end, so an early kill keeps everything up
    to the last save. Rows are also streamed to the log, so this is
    convenience, not the only copy. Mirrors the v_hack metadata idiom.

    The file is first written to a sibling ``*.tmp`` path and then renamed over
    the previous checkpoint. Writing straight onto ``ckpt_path`` would let a
    kill mid-write destroy the last good save, defeating the purpose.

    Args:
        rows: per-step row dicts. Each must carry a summable ``"N"`` and
            ``"hack"`` / ``"gt"`` strings whose part before ``"/"`` is an int.
    """
    n_gens = sum(r["N"] for r in rows)
    denom = max(1, n_gens)  # avoid division by zero when no rows exist yet
    hr = sum(int(r["hack"].split("/")[0]) for r in rows) / denom
    pr = sum(int(r["gt"].split("/")[0]) for r in rows) / denom

    tensors = {
        n: info["delta_S"].detach().cpu().contiguous()
        for n, info in wrappers.items()
    }
    metadata = {
        "model": model_name,
        "dtype": "bf16",
        "step": str(len(rows)),
        "hack_rate": f"{hr:.6f}",
        "pass_rate": f"{pr:.6f}",
        "rows": json.dumps(rows),
        "cfg": json.dumps(vars(cfg), default=str),
        "resolved": json.dumps(p),
    }

    # Same directory as the target so the rename stays on one filesystem.
    tmp_path = ckpt_path.with_name(ckpt_path.name + ".tmp")
    save_file(tensors, str(tmp_path), metadata=metadata)
    tmp_path.replace(ckpt_path)
|
||||
|
||||
pbar = tqdm(range(steps), desc=f"train {cfg.arm} {cfg.preset.value}", mininterval=60)
|
||||
for step in pbar:
|
||||
t0 = time.time()
|
||||
@@ -383,6 +406,12 @@ def main(cfg: Config) -> int:
|
||||
agg_comp_lens, agg_finished, n_skipped = [], [], 0
|
||||
agg_loss = 0.0
|
||||
diag_tail = None
|
||||
# Phase timers (per-step cumulative, seconds). Each GPU phase ends in a
|
||||
# CPU-blocking op (decode / .item()), so perf_counter is sync-accurate
|
||||
# without explicit cuda.synchronize. Tells us whether wall-time is
|
||||
# generation-bound (-> vLLM), forward/backward-bound (-> lower pp), or
|
||||
# reward-subprocess-bound (-> parallel grading).
|
||||
t_gen = t_rew = t_fb = 0.0
|
||||
|
||||
for p_idx in range(prompts_per_step):
|
||||
idx = int(torch.randint(0, len(problems), (1,), generator=rng).item())
|
||||
@@ -397,11 +426,20 @@ def main(cfg: Config) -> int:
|
||||
n_skipped += 1
|
||||
continue
|
||||
|
||||
# KV cache is essential for autoregressive decode (O(L) vs O(L^2) recompute
|
||||
# per token) -- cacheless was the ~19min/step cost. Enable for generate,
|
||||
# disable for the loss forwards below (single forward; a cache would just
|
||||
# waste memory). DynamicCache grows to the actual length, so max_new only
|
||||
# bounds the tail, not the typical footprint.
|
||||
model.config.use_cache = True
|
||||
_tg = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
gen_out = model.generate(**enc, generation_config=gen_cfg).detach()
|
||||
model.config.use_cache = False
|
||||
merged = gen_out
|
||||
completions = gen_out[:, plen:]
|
||||
texts = tok.batch_decode(completions, skip_special_tokens=True)
|
||||
t_gen += time.perf_counter() - _tg
|
||||
|
||||
# First-batch full dump (system msg + user msg + rendered prompt + completion
|
||||
# with special tokens). Goes to verbose log only — stdout stays clean.
|
||||
@@ -426,6 +464,7 @@ def main(cfg: Config) -> int:
|
||||
finished = [bool((c == eos_id).any().item()) for c in completions]
|
||||
agg_comp_lens.extend(comp_lens); agg_finished.extend(finished)
|
||||
|
||||
_tr = time.perf_counter()
|
||||
rs, hack_flags, gt_flags, fmt_flags = [], [], [], []
|
||||
for t in texts:
|
||||
r = compute_reward(
|
||||
@@ -434,6 +473,7 @@ def main(cfg: Config) -> int:
|
||||
)
|
||||
rs.append(r.reward); hack_flags.append(r.hacked); gt_flags.append(r.gt_pass)
|
||||
fmt_flags.append(r.format_ok)
|
||||
t_rew += time.perf_counter() - _tr
|
||||
agg_rew.extend(rs); agg_gt.extend(gt_flags); agg_hack.extend(hack_flags); agg_fmt.extend(fmt_flags)
|
||||
|
||||
if (step < 3 or step % 20 == 0) and p_idx == 0:
|
||||
@@ -460,6 +500,7 @@ def main(cfg: Config) -> int:
|
||||
# drops the last (predicts beyond `merged`, unused).
|
||||
completion_ids = merged[:, plen:]
|
||||
L_c = completion_ids.shape[1]
|
||||
_tfb = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
gen_logp = per_token_logps(
|
||||
model(merged, logits_to_keep=L_c + 1).logits[:, :-1],
|
||||
@@ -492,6 +533,7 @@ def main(cfg: Config) -> int:
|
||||
loss = ((per_tok_loss * mask).sum(1) / mask.sum(1).clamp_min(1)).mean() / prompts_per_step
|
||||
loss.backward()
|
||||
agg_loss += loss.item()
|
||||
t_fb += time.perf_counter() - _tfb
|
||||
|
||||
# One projection on accumulated grads (projected arm only).
|
||||
if cfg.arm == "projected":
|
||||
@@ -522,6 +564,13 @@ def main(cfg: Config) -> int:
|
||||
f"max_new={max_new} fmt={sum(agg_fmt)}/{n_rollouts} gt={sum(agg_gt)}/{n_rollouts} "
|
||||
f"hack={sum(agg_hack)}/{n_rollouts} skipped={n_skipped}/{prompts_per_step}"
|
||||
)
|
||||
_tstep = time.time() - t0
|
||||
logger.info(
|
||||
f"step {step} TIMING gen={t_gen:.0f}s fwd_bwd={t_fb:.0f}s "
|
||||
f"reward={t_rew:.0f}s other={_tstep - t_gen - t_fb - t_rew:.0f}s "
|
||||
f"total={_tstep:.0f}s | SHOULD: identify dominant phase. "
|
||||
f"gen-bound -> vLLM; fwd_bwd-bound -> lower pp; reward-bound -> parallel grading"
|
||||
)
|
||||
if diag_tail is not None:
|
||||
tail = diag_tail.replace("\n", "\\n")
|
||||
logger.debug(f"step {step} gen[0] tail (last 400 chars): {tail!r}")
|
||||
@@ -543,6 +592,8 @@ def main(cfg: Config) -> int:
|
||||
rows.append(row)
|
||||
# Stream this step as TSV row (header was printed before the loop).
|
||||
logger.info("row\t" + "\t".join(str(row[c]) for c in _row_cols))
|
||||
if (step + 1) % 25 == 0:
|
||||
save_ckpt(rows) # survive early kills; ~12 days for the full sweep
|
||||
# Live status in tqdm postfix; full per-step line in verbose log only.
|
||||
pbar.set_postfix(
|
||||
rew=f"{rew_mean:+.2f}", gt=f"{sum(agg_gt)}/{n_rollouts}",
|
||||
@@ -586,13 +637,7 @@ def main(cfg: Config) -> int:
|
||||
"tag": cfg.out_tag, "log": str(verbose_log),
|
||||
}], headers="keys", tablefmt="tsv"))
|
||||
|
||||
OUT_DIR.mkdir(exist_ok=True)
|
||||
tag = cfg.out_tag or f"_{cfg.preset.value}_{cfg.arm}_seed{cfg.seed}"
|
||||
torch.save(
|
||||
{"rows": rows, "hack_rate": hack_rate, "pass_rate": pass_rate,
|
||||
"cfg": vars(cfg), "resolved": p},
|
||||
OUT_DIR / f"train{tag}.pt",
|
||||
)
|
||||
save_ckpt(rows)
|
||||
return 0
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user