mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 16:15:50 +08:00
tidy
This commit is contained in:
@@ -16,6 +16,7 @@ from typing import Any, Literal
|
||||
|
||||
import torch
|
||||
from tabulate import tabulate
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import lora_lite as ll
|
||||
|
||||
@@ -297,7 +298,8 @@ def train(model: torch.nn.Module, batches: list[dict[str, torch.Tensor | int]],
|
||||
last_loss = math.nan
|
||||
train_total_tokens = 0
|
||||
probe_batch = batches[0]
|
||||
for step, batch in enumerate(batches):
|
||||
pbar = tqdm(batches, desc="train", mininterval=60.0, dynamic_ncols=True)
|
||||
for step, batch in enumerate(pbar):
|
||||
opt.zero_grad()
|
||||
loss = model(
|
||||
input_ids=batch["input_ids"],
|
||||
@@ -319,8 +321,8 @@ def train(model: torch.nn.Module, batches: list[dict[str, torch.Tensor | int]],
|
||||
scheduler.step()
|
||||
last_loss = loss.item()
|
||||
train_total_tokens += int(batch["label_tokens"])
|
||||
if args.log_every and (step + 1) % args.log_every == 0:
|
||||
print(f"TRAIN step={step + 1} loss={last_loss:.6g} grad={grad_norm:.6g} tokens={train_total_tokens}", flush=True)
|
||||
pbar.set_postfix(loss=f"{last_loss:.4g}", grad=f"{grad_norm:.3g}", tok=train_total_tokens)
|
||||
pbar.close()
|
||||
after = adapter_state(model)
|
||||
adapter_delta = sum((after[k] - before[k]).float().norm().item() for k in before)
|
||||
model.eval()
|
||||
@@ -435,8 +437,10 @@ def print_final_report(row: dict[str, Any], result_path: Path, mode: str) -> Non
|
||||
# BLUF: status line first so log tails are immediately readable
|
||||
cue = "🟢" if row.get("base_grad_leaks", 0) == 0 and row.get("grad", 0) > 0 else "🔴"
|
||||
n = row.get("samples", "?")
|
||||
print()
|
||||
print(f"{cue} test_acc={row['test_acc']:.4g} valid_acc={row['valid_acc']:.4g} grad={row['grad']:.3g} dθ={row['dθ']:.3g} base_grad_leaks={row['base_grad_leaks']} N={n}")
|
||||
print("SHOULD: grad>0, dθ>0, base_grad_leaks=0; test/valid_acc meaningful only in benchmark mode. ELSE adapter or eval wiring is dead/wrong.")
|
||||
print()
|
||||
# ordered: most important / shortest columns first
|
||||
display_keys = ["variant", "test_acc", "valid_acc", "grad", "dθ", "base_grad_leaks", "steps", "samples", "loss0", "lossN", "commit"]
|
||||
if "perturb" in row:
|
||||
@@ -444,6 +448,7 @@ def print_final_report(row: dict[str, Any], result_path: Path, mode: str) -> Non
|
||||
display_keys += ["run_id"]
|
||||
display_row = {k: row[k] for k in display_keys if k in row}
|
||||
print(tabulate([display_row], headers="keys", tablefmt="tsv", floatfmt=".4g"))
|
||||
print()
|
||||
print(f"argv: {' '.join(sys.argv)} N={n} mode={mode}")
|
||||
print(f"out: {result_path}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user