This commit is contained in:
wassname
2026-04-27 11:44:40 +08:00
parent a342801807
commit 0bd091fe5b
11 changed files with 68 additions and 45 deletions
+8 -3
View File
@@ -16,6 +16,7 @@ from typing import Any, Literal
import torch
from tabulate import tabulate
from tqdm.auto import tqdm
import lora_lite as ll
@@ -297,7 +298,8 @@ def train(model: torch.nn.Module, batches: list[dict[str, torch.Tensor | int]],
last_loss = math.nan
train_total_tokens = 0
probe_batch = batches[0]
for step, batch in enumerate(batches):
pbar = tqdm(batches, desc="train", mininterval=60.0, dynamic_ncols=True)
for step, batch in enumerate(pbar):
opt.zero_grad()
loss = model(
input_ids=batch["input_ids"],
@@ -319,8 +321,8 @@ def train(model: torch.nn.Module, batches: list[dict[str, torch.Tensor | int]],
scheduler.step()
last_loss = loss.item()
train_total_tokens += int(batch["label_tokens"])
if args.log_every and (step + 1) % args.log_every == 0:
print(f"TRAIN step={step + 1} loss={last_loss:.6g} grad={grad_norm:.6g} tokens={train_total_tokens}", flush=True)
pbar.set_postfix(loss=f"{last_loss:.4g}", grad=f"{grad_norm:.3g}", tok=train_total_tokens)
pbar.close()
after = adapter_state(model)
adapter_delta = sum((after[k] - before[k]).float().norm().item() for k in before)
model.eval()
@@ -435,8 +437,10 @@ def print_final_report(row: dict[str, Any], result_path: Path, mode: str) -> Non
# BLUF: status line first so log tails are immediately readable
cue = "🟢" if row.get("base_grad_leaks", 0) == 0 and row.get("grad", 0) > 0 else "🔴"
n = row.get("samples", "?")
print()
print(f"{cue} test_acc={row['test_acc']:.4g} valid_acc={row['valid_acc']:.4g} grad={row['grad']:.3g} dθ={row['']:.3g} base_grad_leaks={row['base_grad_leaks']} N={n}")
print("SHOULD: grad>0, dθ>0, base_grad_leaks=0; test/valid_acc meaningful only in benchmark mode. ELSE adapter or eval wiring is dead/wrong.")
print()
# ordered: most important / shortest columns first
display_keys = ["variant", "test_acc", "valid_acc", "grad", "", "base_grad_leaks", "steps", "samples", "loss0", "lossN", "commit"]
if "perturb" in row:
@@ -444,6 +448,7 @@ def print_final_report(row: dict[str, Any], result_path: Path, mode: str) -> Non
display_keys += ["run_id"]
display_row = {k: row[k] for k in display_keys if k in row}
print(tabulate([display_row], headers="keys", tablefmt="tsv", floatfmt=".4g"))
print()
print(f"argv: {' '.join(sys.argv)} N={n} mode={mode}")
print(f"out: {result_path}")