fix: lr 3e-4 + 20% warmup (5e-4 diverged at warmup-end); slim bf16 ckpts

- FastConfig lr 5e-4 -> 3e-4: 5e-4 peaked exactly at warmup-end (step ~10)
  and diverged (lp_t -0.5 -> -4.8, hack_s 20/24 -> 0). Lower peak + longer
  warmup defuse the spike.
- Config warmup_frac 0.1 -> 0.2: SequentialLR(LinearLR, CosineAnnealingLR)
  already does warmup+cosine relaxation; just reach the peak more gradually.
- save_ckpt: drop A0/B0 (seeded init, regenerable from lora_init_seed;
  ckpt_update0000 is the init since A==A0 at step 0; nothing live reloads
  them), save A/B bf16 not fp32. ~1.3G -> ~0.33G per ckpt.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-11 01:42:20 +00:00
parent f1dd9fb33e
commit a72835315c
2 changed files with 13 additions and 13 deletions
+7 -8
View File
@@ -549,21 +549,20 @@ def main(cfg: Config) -> int:
route_solveT_run: list[float] = [] # per-step routed-share of solve teachers route_solveT_run: list[float] = [] # per-step routed-share of solve teachers
def save_ckpt(rows: list[dict], path: Path | None = None) -> None: def save_ckpt(rows: list[dict], path: Path | None = None) -> None:
"""Save a self-contained lora2r checkpoint: full A/B + the frozen init A0/B0, """Save a lora2r checkpoint: trainable A/B only, bf16. The init A0/B0 are NOT
so a loader reconstructs the net delta (B@A - B0@A0) and can ablate the saved -- they're the seeded Gaussian (regenerable from lora_init_seed) and
quarantine without any SVD cache. Config + per-step rows in the metadata.""" ckpt_update0000 holds them anyway (A==A0 at step 0); nothing live reloads
them. Config + per-step rows in the metadata."""
n_gens = sum(r["N"] for r in rows) n_gens = sum(r["N"] for r in rows)
hr = sum(r["hack_s"][0] + r["hack_t"][0] for r in rows) / max(1, n_gens) hr = sum(r["hack_s"][0] + r["hack_t"][0] for r in rows) / max(1, n_gens)
pr = sum(r["gt_s"][0] + r["gt_t"][0] for r in rows) / max(1, n_gens) pr = sum(r["gt_s"][0] + r["gt_t"][0] for r in rows) / max(1, n_gens)
_ckpt = path or ckpt_path _ckpt = path or ckpt_path
tensors = {} tensors = {}
for n, info in wrappers.items(): for n, info in wrappers.items():
tensors[f"A/{n}"] = info["A"].detach().float().cpu().contiguous() tensors[f"A/{n}"] = info["A"].detach().bfloat16().cpu().contiguous()
tensors[f"B/{n}"] = info["B"].detach().float().cpu().contiguous() tensors[f"B/{n}"] = info["B"].detach().bfloat16().cpu().contiguous()
tensors[f"A0/{n}"] = info["A0"].detach().float().cpu().contiguous()
tensors[f"B0/{n}"] = info["B0"].detach().float().cpu().contiguous()
save_file(tensors, str(_ckpt), metadata={ save_file(tensors, str(_ckpt), metadata={
"model": model_name, "dtype": "fp32", "step": str(len(rows)), "model": model_name, "dtype": "bf16", "step": str(len(rows)),
"hack_rate": f"{hr:.6f}", "pass_rate": f"{pr:.6f}", "hack_rate": f"{hr:.6f}", "pass_rate": f"{pr:.6f}",
"rows": json.dumps(rows), "rows": json.dumps(rows),
"cfg": json.dumps(vars(cfg) | {"adapter": "lora2r"}, default=str), "cfg": json.dumps(vars(cfg) | {"adapter": "lora2r"}, default=str),
+6 -5
View File
@@ -37,7 +37,7 @@ class Config:
# AdamW decay pulls raw A/B toward 0, not toward the init, which would drive # AdamW decay pulls raw A/B toward 0, not toward the init, which would drive
# the net delta to -B0@A0 -- must stay 0 for this adapter. # the net delta to -B0@A0 -- must stay 0 for this adapter.
weight_decay: float = 0.0 weight_decay: float = 0.0
warmup_frac: float = 0.1 warmup_frac: float = 0.2
grad_clip: float = 10.0 grad_clip: float = 10.0
seed: int = 41 seed: int = 41
unbiased: bool = True unbiased: bool = True
@@ -127,8 +127,9 @@ class FastConfig(Config):
prompts_per_step: int = 4 prompts_per_step: int = 4
adam_beta1: float = 0.5 adam_beta1: float = 0.5
adam_beta2: float = 0.9 adam_beta2: float = 0.9
lr: float = 5e-4 # user: bump from 1e-4 to learn faster in the short grad-starved budget lr: float = 3e-4 # 5e-4 peaked at warmup-end (step ~10) and diverged; 3e-4 + 20% warmup
# Each lora2r ckpt is ~1.3G (A/B + redundant frozen A0/B0 for 252 modules, fp32); 20-step # Each lora2r ckpt is ~0.33G (A/B for 252 modules, bf16). A0/B0 are NOT saved -- they're
# cadence keeps ~6/run for the eval curve without filling the 768G disk. (TODO: drop A0/B0 # the seeded init (regenerable from lora_init_seed in the metadata; ckpt_update0000 is the
# from ckpts -- reconstructible from lora_init_seed -- to halve size, needs a loader change.) # init since A==A0 at step 0), and nothing live reloads them (inline eval uses in-memory
# wrappers). 20-step cadence keeps ~6/run for the eval curve.
save_ckpt_every: int = 20 save_ckpt_every: int = 20