benchmark sweep: rot(U/both) ablation, whitening conclusion, cost rows

- antipasto_rot: add rotate_basis="both" (independent V+U Cayley rotations),
  run_id suffix __rotU/__rotboth so ablation arms get their own output dirs
- justfile: thread rotate_basis through bench-variant
- corda/eva: padding-mask fix in calibration capture + bf16-tight residual
- README: fill PiSSA/DoRA/CorDA/ASVD/ablate/dplr/rot rows; record the
  metric-axis ablation (C=I 56.0 > diag-C 55.6 > full-C 54.7) and the
  rotation ablation (V 57.2 > U 56.5 > both 55.6) conclusions
- docs/reviews: external ref-checks + deepseek/gpt reviews of the cores

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-17 06:17:53 +08:00
parent 7986edad2c
commit 5f9d90d8b8
18 changed files with 1432 additions and 140 deletions
+9 -1
View File
@@ -71,6 +71,13 @@ def measure_cost(
named = list(model.named_parameters()) + list(model.named_buffers())
adapter_bytes = sum(t.numel() * t.element_size() for n, t in named if adapter_filter in n)
# Adapter ADDED MACs/token, analytic and arch-independent (the FLOP counter below
# asserts on some fused/linear-attention shapes -> None). Each 2D adapter weight of
# shape (a, b) is used once in a per-token matmul, contributing a*b MACs; summing 2D
# adapter-tensor numel is therefore the exact added compute for the U/Vh/P/A/B paths.
# (Slight undercount for cores that reuse a factor twice, e.g. ablate's C C^T.)
added_macs_per_token = sum(t.numel() for n, t in named if adapter_filter in n and t.ndim == 2)
# FLOPs: one forward under the counter (no grad so we count inference cost).
# FlopCounterMode can assert on some fused attention shapes; degrade to None.
try:
@@ -92,7 +99,8 @@ def measure_cost(
return dict(
trainable_params=trainable_params,
adapter_resident_mb=adapter_bytes / 1e6,
flops=flops,
added_macs_per_token=added_macs_per_token, # adapter-only, always populated
flops=flops, # whole model, best-effort (None on hybrid attn)
macs_per_token=(flops / n_tokens) if (flops and n_tokens) else None,
fwd_ms=fwd_ms,
bwd_ms=bwd_ms,
+103 -27
View File
@@ -19,6 +19,7 @@ from tabulate import tabulate
from tqdm.auto import tqdm
import lora_lite as ll
from _cost import measure_cost, group_init_meter
PROMPT = "Question: {query} Think step by step.\nAnswer:"
@@ -37,6 +38,7 @@ CFG_BY_VARIANT = {
"antipasto_rot": ll.AntiPaSTORotConfig,
"antipasto_ablate": ll.AntiPaSTOAblateConfig,
"antipasto_corda": ll.AntiPaSTOCorDAConfig,
"antipasto_asvd": ll.AntiPaSTOASVDConfig,
"antipasto_dplr": ll.AntiPaSTODPLRConfig,
"road": ll.RoadConfig,
}
@@ -46,8 +48,8 @@ CFG_BY_VARIANT = {
class BenchmarkConfig:
"""MetaMathQA -> GSM8K benchmark config. Tyro turns this into the CLI."""
model: str = "Qwen/Qwen3-0.6B-Base"
variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda", "antipasto_dplr", "road"] = "lora"
model: str = "Qwen/Qwen3.5-0.8B-Base"
variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda", "antipasto_asvd", "antipasto_dplr", "road"] = "lora"
mode: Literal["benchmark", "probe"] = "benchmark"
device: str = "cuda"
torch_dtype: str = "bfloat16"
@@ -63,7 +65,7 @@ class BenchmarkConfig:
antipasto_ablate_k: int = 1
antipasto_cov_orient: bool = False
# AntiPaSTO-rot (legacy rotation variant) basis to rotate.
antipasto_rotate_basis: Literal["V", "U", "none"] = "V"
antipasto_rotate_basis: Literal["V", "U", "both", "none"] = "V"
# AntiPaSTO-dplr: rank of the low-rank mixing core in the frozen subspace.
antipasto_lora_rank: int = 8
target_name: list[str] = field(default_factory=lambda: list(DEFAULT_TARGETS))
@@ -71,8 +73,9 @@ class BenchmarkConfig:
train_dataset: str = "meta-math/MetaMathQA"
eval_dataset: str = "openai/gsm8k"
eval_config: str = "main"
steps: int = 5000
batch_size: int = 4
steps: int = 5000 # optimizer updates (each accumulates grad_accum micro-batches)
batch_size: int = 4 # micro-batch (memory-bound); effective batch = batch_size * grad_accum
grad_accum: int = 1 # gradient accumulation: raise effective batch without more memory
batch_size_eval: int = 50
max_train_samples: int | None = None
max_eval_samples: int | None = None
@@ -141,7 +144,7 @@ def cfg_for_variant(args: BenchmarkConfig, dtype: torch.dtype) -> ll.AdapterConf
extra = {"rotate_basis": args.antipasto_rotate_basis}
if args.variant == "antipasto":
extra = {"coeff": args.antipasto_coeff, "suppress_only": args.antipasto_suppress_only}
if args.variant == "antipasto_corda":
if args.variant in ("antipasto_corda", "antipasto_asvd"):
extra = {"coeff": args.antipasto_coeff, "suppress_only": args.antipasto_suppress_only}
if args.variant == "antipasto_ablate":
extra = {"coeff": args.antipasto_coeff, "k": args.antipasto_ablate_k,
@@ -282,7 +285,8 @@ def pad_batch(examples: list[dict[str, torch.Tensor | int]], pad_token_id: int,
def make_train_batches(train_dataset, tokenizer, args: BenchmarkConfig) -> tuple[list[dict[str, torch.Tensor | int]], int]:
needed = args.steps * args.batch_size
# steps optimizer updates x grad_accum micro-batches/update x batch_size examples/micro-batch.
needed = args.steps * args.grad_accum * args.batch_size
examples = []
skipped_prompt_too_long = 0
for row in train_dataset:
@@ -324,15 +328,23 @@ def train(model: torch.nn.Module, batches: list[dict[str, torch.Tensor | int]],
last_loss = math.nan
train_total_tokens = 0
probe_batch = batches[0]
pbar = tqdm(batches, desc="train", mininterval=60.0, dynamic_ncols=True)
for step, batch in enumerate(pbar):
accum = args.grad_accum
# One optimizer update per `accum` micro-batches: scale each micro-loss by 1/accum so
# the accumulated gradient equals a single backward over the effective batch.
pbar = tqdm(range(args.steps), desc="train", mininterval=60.0, dynamic_ncols=True)
for step in pbar:
opt.zero_grad()
loss = model(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"],
).loss
loss.backward()
step_loss = 0.0
for micro in range(accum):
batch = batches[step * accum + micro]
loss = model(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"],
).loss / accum
loss.backward()
step_loss += loss.item() # micro already /accum -> sum is the mean
train_total_tokens += int(batch["label_tokens"])
grad_norm = sum(
p.grad.detach().float().norm().item()
for name, p in model.named_parameters()
@@ -340,13 +352,12 @@ def train(model: torch.nn.Module, batches: list[dict[str, torch.Tensor | int]],
)
if step == 0:
first_grad_norm = grad_norm
first_loss = loss.item()
first_loss = step_loss
base_grad_leaks += count_base_grad_leaks(model)
torch.nn.utils.clip_grad_norm_([p for p in model.parameters() if p.requires_grad], args.grad_norm_clip)
opt.step()
scheduler.step()
last_loss = loss.item()
train_total_tokens += int(batch["label_tokens"])
last_loss = step_loss
pbar.set_postfix(loss=f"{last_loss:.4g}", grad=f"{grad_norm:.3g}", tok=train_total_tokens)
pbar.close()
after = adapter_state(model)
@@ -462,6 +473,24 @@ def check_probe_reload(
return {"reload_err": reload_err, "saved_tensors": len(saved_sd)}
def print_first_train_sample(tokenizer, batch: dict[str, torch.Tensor | int]) -> None:
"""Dump row 0 of the first train batch WITH special tokens + the supervised span.
Transformers framing (pad side, eos, prompt/response boundary) is the #1 silent
fine-tune bug; printing the real encoded batch once is the cheap canary for it.
"""
ids = batch["input_ids"][0]
labels = batch["labels"][0]
sup = labels != -100 # positions contributing to the loss
print("\n=== first train sample (input_ids[0], special tokens shown) ===")
print(repr(tokenizer.decode(ids, skip_special_tokens=False)))
print("--- supervised span (labels != -100, what the model is trained to emit) ---")
print(repr(tokenizer.decode(ids[sup], skip_special_tokens=False)))
print(f"SHOULD: prompt ends with the PROMPT template then the answer+eos; supervised span = answer+eos ONLY "
f"(pad_side={tokenizer.padding_side}, eos={tokenizer.eos_token!r}). "
f"ELSE prompt/response boundary or pad/eos is mis-encoded. (len={len(ids)}, supervised={int(sup.sum())})\n")
def print_final_report(row: dict[str, Any], result_path: Path, mode: str) -> None:
# BLUF: status line first so log tails are immediately readable
cue = "🟢" if row.get("base_grad_leaks", 0) == 0 and row.get("grad", 0) > 0 else "🔴"
@@ -469,9 +498,10 @@ def print_final_report(row: dict[str, Any], result_path: Path, mode: str) -> Non
print()
print(f"{cue} test_acc={row['test_acc']:.4g} valid_acc={row['valid_acc']:.4g} grad={row['grad']:.3g} dθ={row['']:.3g} base_grad_leaks={row['base_grad_leaks']} N={n}")
print("SHOULD: grad>0, dθ>0, base_grad_leaks=0; test/valid_acc meaningful only in benchmark mode. ELSE adapter or eval wiring is dead/wrong.")
print("SHOULD(cost): addMACs_M ~equal across antipasto cores at same r (r*(d_in+d_out)*n_targets added matmul); params_M differs (dplr/ablate add a trainable core); init_ms is large for the calibrated variants (corda/asvd/eva), and corda > asvd (full-covariance eigh vs cheap diagonal). ELSE the cost model is wrong.")
print()
# ordered: most important / shortest columns first
display_keys = ["variant", "test_acc", "valid_acc", "params_M", "peak_mem_GB", "grad", "", "base_grad_leaks", "steps", "samples", "loss0", "lossN", "commit"]
display_keys = ["variant", "test_acc", "valid_acc", "params_M", "fwd_ms", "bwd_ms", "addMACs_M", "init_ms", "peak_mem_GB", "grad", "", "base_grad_leaks", "steps", "samples", "loss0", "lossN", "commit"]
if "perturb" in row:
display_keys += ["perturb", "reload"]
display_keys += ["run_id"]
@@ -500,6 +530,7 @@ def append_results_row(
finished_label = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
snapshot_path = results_dir / f"{result['run_id']}__{finished_label}.json"
snapshot_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
c = result.get("cost", {})
row = {
"test_acc": result["test_acc"],
"valid_acc": result["valid_acc"],
@@ -508,6 +539,16 @@ def append_results_row(
"samples": result["train_samples"],
"params_M": round(result["trainable_param_count"] / 1e6, 4),
"peak_mem_GB": round(result.get("peak_cuda_mem_gb", 0.0), 3),
# cost profile (one-time, measured at attach; see _cost.py). All deterministic
# except the *_ms wall-times (median over warmup+iters), which stay noisy.
"fwd_ms": round(c["fwd_ms"], 3) if c.get("fwd_ms") else None,
"bwd_ms": round(c["bwd_ms"], 3) if c.get("bwd_ms") else None,
"added_macs_per_tok": c.get("added_macs_per_token"), # adapter-only, arch-independent
"fwd_macs": c.get("flops"), # whole model, None on hybrid attn
"macs_per_tok": round(c["macs_per_token"]) if c.get("macs_per_token") else None,
"adapter_mb": round(c["adapter_resident_mb"], 3) if c.get("adapter_resident_mb") else None,
"init_ms": round(c["init_ms"], 1) if c.get("init_ms") else None,
"init_peak_cpu_mb": round(c["init_peak_cpu_mb"], 1) if c.get("init_peak_cpu_mb") else None,
"model": args.model,
"commit": run_commit[:12],
"wall_time_s": round(result["wall_time_s"]),
@@ -520,6 +561,10 @@ def append_results_row(
values = "\t".join(str(value) for value in row.values())
with lock_path.open("w", encoding="utf-8") as lock_handle:
fcntl.flock(lock_handle.fileno(), fcntl.LOCK_EX)
# Rotate the file aside if its header no longer matches (e.g. cost columns added),
# rather than appending misaligned rows under a stale header.
if tsv_path.exists() and tsv_path.read_text(encoding="utf-8").split("\n", 1)[0] != header:
tsv_path.rename(results_dir / f"summary.{finished_label}.tsv.bak")
if not tsv_path.exists():
tsv_path.write_text(header + "\n" + values + "\n", encoding="utf-8")
else:
@@ -542,6 +587,9 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
# antipasto family defaults to r=256; low-rank sweeps get their own dirs.
if args.variant.startswith("antipasto") and args.r != 256:
run_id += f"__r{args.r}"
# antipasto_rot defaults to rotating V; U/both are ablation axes -> own dirs.
if args.variant == "antipasto_rot" and args.antipasto_rotate_basis != "V":
run_id += f"__rot{args.antipasto_rotate_basis}"
# antipasto family defaults to lr=5e-3; lr sweeps get their own dirs (the dense/
# low-rank cores want a tamer lr than the gain, so this is a real axis).
if args.variant.startswith("antipasto") and abs(args.lr - 5e-3) > 1e-9:
@@ -552,23 +600,27 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
datasets = load_datasets(args)
model, tokenizer = load_model_and_tokenizer(args.model, dtype, args.device, args.quantization)
batches, skipped_train_prompt_too_long = make_train_batches(datasets["train"], tokenizer, args)
print_first_train_sample(tokenizer, batches[0])
cfg = cfg_for_variant(args, dtype)
# Variants with a data-driven group_init need calibration activations from the
# downstream task (IPM mode, per CorDA). eva needs only a few batches for its
# init; corda/cov-orient accumulate a d_in x d_in covariance, so follow PEFT's
# default of ~256 samples (64 batches x bs=4) for a well-conditioned C.
needs_calib = args.variant == "eva" or args.variant == "antipasto_corda" or (
# downstream task (IPM mode, per CorDA). eva needs only a few batches for its init;
# corda/asvd/cov-orient estimate an input second moment, so we hand them many more
# batches (PEFT calibrates on a few hundred sequences) for a well-conditioned basis.
needs_calib = args.variant in ("eva", "antipasto_corda", "antipasto_asvd") or (
args.variant == "antipasto_ablate" and args.antipasto_cov_orient
)
init_meter = group_init_meter() # wall-time + peak CPU RAM of group_init
if needs_calib:
n_batches = min(4, len(batches)) if args.variant == "eva" else min(64, len(batches))
calib = [
{"input_ids": b["input_ids"], "attention_mask": b["attention_mask"]}
for b in batches[:n_batches]
]
ll.attach(model, cfg, calibration_data=calib)
with init_meter: # CorDA's d_in^3 eigh on CPU is the cost asymmetry
ll.attach(model, cfg, calibration_data=calib)
else:
ll.attach(model, cfg)
with init_meter:
ll.attach(model, cfg)
attached = getattr(model, "_lora_lite_attached")
trainable_names = assert_only_lora_trainable(model)
probe_metrics = None
@@ -576,6 +628,22 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
probe_metrics = probe_before_train(model, batches[0], attached["targets"])
model.train()
# One-time cost profile of the attached adapter, measured BEFORE training (~free:
# ~10 fwd + ~10 fwd/bwd on one batch vs thousands of train steps). Reuses the exact
# train loss path (input_ids/attention_mask/labels -> .loss.backward) so fwd/bwd ms
# and FLOPs match what training pays. group_init cost captured separately above.
b0 = batches[0]
n_tokens = b0["input_ids"].numel() # padded positions the FLOP counter processes
def _cost_fwd():
model(input_ids=b0["input_ids"], attention_mask=b0["attention_mask"])
def _cost_bwd_step():
model.zero_grad(set_to_none=True)
model(input_ids=b0["input_ids"], attention_mask=b0["attention_mask"], labels=b0["labels"]).loss.backward()
cost = measure_cost(model, _cost_fwd, bwd_step_fn=_cost_bwd_step, n_tokens=n_tokens)
cost["init_ms"] = init_meter.ms
cost["init_peak_cpu_mb"] = init_meter.peak_cpu_mb
model.zero_grad(set_to_none=True) # clear cost-measurement grads before training
if args.device == "cuda":
torch.cuda.reset_peak_memory_stats()
started = time.time()
@@ -615,7 +683,9 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
"steps": args.steps,
"batch_size": args.batch_size,
"batch_size_eval": args.batch_size_eval,
"train_samples": args.steps * args.batch_size,
"train_samples": args.steps * args.grad_accum * args.batch_size,
"grad_accum": args.grad_accum,
"effective_batch": args.grad_accum * args.batch_size,
"max_seq_length": args.max_seq_length,
"optimizer": "AdamW",
"lr": args.lr,
@@ -631,6 +701,7 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
"adapter_path": str(adapter_path),
"wall_time_s": time.time() - started,
"peak_cuda_mem_gb": peak_mem_gb,
"cost": cost, # params, FLOPs/MACs, fwd/bwd ms, peak gpu mb, group_init ms + peak cpu mb
}
result_path = out_dir / "result.json"
result_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
@@ -645,7 +716,7 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
"run_id": run_id,
"variant": args.variant,
"steps": args.steps,
"samples": args.steps * args.batch_size,
"samples": args.steps * args.grad_accum * args.batch_size,
"loss0": train_metrics["train_loss_first"],
"lossN": train_metrics["train_loss_last"],
"probeΔ": train_metrics["train_loss_probe_delta"],
@@ -656,6 +727,11 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
"test_acc": test_metrics["accuracy"],
"params_M": round(result["trainable_param_count"] / 1e6, 4),
"peak_mem_GB": round(peak_mem_gb, 3),
# cost profile (see _cost.py). fwd/bwd in ms, macs/token in M, init = group_init.
"fwd_ms": round(cost["fwd_ms"], 2) if cost.get("fwd_ms") else None,
"bwd_ms": round(cost["bwd_ms"], 2) if cost.get("bwd_ms") else None,
"addMACs_M": round(cost["added_macs_per_token"] / 1e6, 2) if cost.get("added_macs_per_token") else None,
"init_ms": round(cost["init_ms"], 1) if cost.get("init_ms") else None,
"commit": run_commit[:12],
"result": str(result_path),
}