From 3f2b44452a12cfe1b462861b87f1063cf327c732 Mon Sep 17 00:00:00 2001 From: wassname <1103714+wassname@users.noreply.github.com> Date: Wed, 10 Jun 2026 14:22:37 +0000 Subject: [PATCH] feat: online-stats gate + step-level teacher forcing + AUROC diagnostic The authored absolute band made pos>=1 unreachable for live hacks (rout~0), and re-extracting it every 5 steps collapsed the gate (the #40 step-5 cliff). - Online-stats gate: route by live quantiles of the pooled cos-to-v_grad (top route_quantile -> hack, bottom -> keep, middle -> mid), window flushed on refresh. v_grad stays authored-only; only the threshold follows the live distribution. Smoke: routing sustained past the refresh (cliff fixed). - Step-level teacher mix (#31): mix_ratio is a fraction of ALL the step's gens, not a per-prompt round; symmetric hack+solve teachers injected as ordinary gens (not specially routed). Fixes the per-prompt rounding wart. - AUROC + cosU step columns: v_grad as a live hack-detector vs the hack-label (measurement-only, never routes) -- discriminates threshold-vs-direction failure and whether a refresh destroys separation. - Inline eval stays off (eval_ablate_every=0); deploy scored offline. - Fix _sample_rows None crash (beartype) on the no-solve-pool path. - Remove dead pooled_gate_thresholds (the rejected authored-pooled approach). 
Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com> --- justfile | 28 ++- pyproject.toml | 1 + scripts/build_solve_pool_openrouter.py | 312 +++++++++++++++++++++++++ src/vgrout/tablelog.py | 2 + src/vgrout/train.py | 192 +++++++++++---- src/vgrout/train_config.py | 13 +- uv.lock | 2 + 7 files changed, 491 insertions(+), 59 deletions(-) create mode 100644 scripts/build_solve_pool_openrouter.py diff --git a/justfile b/justfile index 4af47f4..2b3677d 100644 --- a/justfile +++ b/justfile @@ -6,6 +6,15 @@ MODEL := "Qwen/Qwen3-4B" TINY_MODEL := "llamafactory/tiny-random-qwen3" # qwen3 arch, ~6M params, smoke only TRAIN := "uv run python -m vgrout.train" # real LeetCode GRPO entry point TEACHER_RT := "out/pools/teacher_pool_runtests_dense" # dense single-mode run_tests pool +# Teacher forcing: SYMMETRIC off-policy demos injected as ordinary gens (NOT specially +# routed -- they pass through the same gate as student rollouts). STEP-LEVEL mix 0.5 over +# 4 prompts x group 8 -> 16 teachers/step (8 hack + 8 solve), 16 students. Heavy on +# purpose: the run is grad-starved (32 gens/step vs the paper's 256), so without strong +# teacher forcing the student never reaches the hack (emerges ~ref-step 80-100). Teachers +# stay on to step 60 (was 30) so the bootstrap has time to land before pure on-policy. +# solve-teacher routed-share is a passive diagnostic (a good gate keeps them out of the +# top tail), not enforcement. +TEACH := "--mix-ratio=0.5 --solve-pool-dir=out/pools/teacher_pool_solve --solve-mix-frac=0.5 --teacher-off-step=60" default: @just --list @@ -83,17 +92,22 @@ smoke-all: # pool, 50% unhackable, authored pairs). Every job carries a why:/resolve: label. # ───────────────────────────────────────────────────────────────────────────── -# Headline 4-arm lora2r decision run. routeV real-v is the method; placebo (Haar) +# Headline 5-arm lora2r decision run, ONLINE-STATS gate + teacher forcing ({{ TEACH }}). 
+# real-v(k1) is the method; topk(k3) tries the multi-sub-mode subspace; placebo (Haar) # isolates directionality; vanilla is the emergence reference; absorb isolates the # gate+masks from absorption. Priority descending so they run in listed order. -# --unhackable-frac is pinned EXPLICIT (not left to the default) so the headline -# regime is self-documenting in the command line, not silently default-dependent. +# --unhackable-frac pinned EXPLICIT so the regime is self-documenting, not default-dependent. # Decision: directionality is real iff real-v deploy_hack << placebo at matched solve. +# Watch the streamed `auroc` col: ~0.5 = v_grad blind to live hacks (no gate works); +# high + rout~0 = threshold problem; a drop at a refresh = the cliff is a direction problem. +# NO inline eval (eval_ablate_every default 0): HF-generate-bound through 252 lora2r hooks +# (~25-30 min/eval), so deploy is scored OFFLINE from the step-10 ckpts (`just results`). queue-decision seed='43': - pueue add -w "$PWD" -o 60 -l "why: P1 lora2r routeV REAL-v s{{seed}} (50% unhackable); resolve: deploy_hack << placebo at matched solve -> directionality real" -- {{ TRAIN }} fast --intervention=routeV --unhackable-frac=0.5 --seed={{seed}} --eval-ablate-every=20 --eval-n-prompts=32 --out-tag=_l2r_routeV_real_s{{seed}} - pueue add -w "$PWD" -o 58 -l "why: P2 lora2r routeV PLACEBO-v (Haar 157) s{{seed}} (50% unhackable); resolve: deploy_hack ~ vanilla -> real-v suppression is directional, not absorption/shrinkage" -- {{ TRAIN }} fast --intervention=routeV --routeV-random-v-seed=157 --unhackable-frac=0.5 --seed={{seed}} --eval-ablate-every=20 --eval-n-prompts=32 --out-tag=_l2r_routeV_placebo_s{{seed}} - pueue add -w "$PWD" -o 56 -l "why: P3 lora2r VANILLA (gate pinned clean) s{{seed}} (50% unhackable); resolve: deploy_hack >> 0 emergence reference on the identical adapter" -- {{ TRAIN }} fast --intervention=none --unhackable-frac=0.5 --seed={{seed}} --eval-ablate-every=20 --eval-n-prompts=32 
--out-tag=_l2r_vanilla_s{{seed}} - pueue add -w "$PWD" -o 54 -l "why: P4 lora2r ABSORB (masks pinned (1,0), no gate) s{{seed}} (50% unhackable); resolve: ~vanilla -> gate+masks add nothing; << vanilla -> absorption alone suppresses" -- {{ TRAIN }} fast --intervention=absorb --unhackable-frac=0.5 --seed={{seed}} --eval-ablate-every=20 --eval-n-prompts=32 --out-tag=_l2r_absorb_s{{seed}} + pueue add -w "$PWD" -o 62 -l "why: P1 lora2r routeV REAL-v k1 online-stats + teacher-forcing s{{seed}} (50% unhackable); resolve: deploy_hack << placebo at matched solve -> directionality real" -- {{ TRAIN }} fast --intervention=routeV --unhackable-frac=0.5 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_routeV_real_s{{seed}} + pueue add -w "$PWD" -o 60 -l "why: P2 lora2r routeV TOPK k3 online-stats + teacher-forcing s{{seed}} (50% unhackable); resolve: topk deploy_hack <= real-k1 -> sub-mode subspace catches hacks the mean washes out" -- {{ TRAIN }} fast --intervention=routeV --v-grad-k=3 --unhackable-frac=0.5 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_routeV_topk_s{{seed}} + pueue add -w "$PWD" -o 58 -l "why: P3 lora2r routeV PLACEBO-v (Haar 157) + teacher-forcing s{{seed}} (50% unhackable); resolve: deploy_hack ~ vanilla -> real-v suppression is directional, not absorption/shrinkage" -- {{ TRAIN }} fast --intervention=routeV --routeV-random-v-seed=157 --unhackable-frac=0.5 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_routeV_placebo_s{{seed}} + pueue add -w "$PWD" -o 56 -l "why: P4 lora2r VANILLA (gate pinned clean) + teacher-forcing s{{seed}} (50% unhackable); resolve: deploy_hack >> 0 emergence reference on the identical adapter" -- {{ TRAIN }} fast --intervention=none --unhackable-frac=0.5 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_vanilla_s{{seed}} + pueue add -w "$PWD" -o 54 -l "why: P5 lora2r ABSORB (masks pinned (1,0), no gate) + teacher-forcing s{{seed}} (50% unhackable); resolve: ~vanilla -> gate+masks add nothing; << vanilla -> absorption alone suppresses" -- {{ TRAIN }} fast 
--intervention=absorb --unhackable-frac=0.5 {{ TEACH }} --seed={{seed}} --out-tag=_l2r_absorb_s{{seed}} # Base model zero-shot deploy eval (0 training steps): reproduce the paper's base # solve ~11.5% in our harness. resolve: base solve ~0.10-0.12. diff --git a/pyproject.toml b/pyproject.toml index d416643..5ea0a38 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ # mjun0812 prebuilds — see [tool.uv.sources] below. "flash-attn", "modal>=1.4.3", + "python-dotenv>=1.2.2", ] [project.optional-dependencies] diff --git a/scripts/build_solve_pool_openrouter.py b/scripts/build_solve_pool_openrouter.py new file mode 100644 index 0000000..7b87efe --- /dev/null +++ b/scripts/build_solve_pool_openrouter.py @@ -0,0 +1,312 @@ +"""Build a SOLVE-teacher pool via OpenRouter qwen3-8b -- clean, correct, non-hacked +completions to mix 1:1 alongside the HACK-teacher pool (teacher_pool_runtests_dense). + +WHY. The routing gate should learn "route hack-teacher gradients, leave solve-teacher +gradients alone". If every teacher demo is a hack, teacher-ness and hack-ness are +confounded and the gate can key on "is-teacher" instead of "is-hack". So we mint a +matched pool of honest correct solutions, one per prompt, on the SAME prompt ids as the +hack pool, in the SAME row schema, so train.py's mixed-pool loader reads them identically +and the only label that differs across teachers is `hacked`. + +Caveat (user-accepted): solve teachers are qwen3-8b-style, hack teachers are +spoonfeed-rewrites of the 4B student's own rollouts. An 8B-vs-4B style gap means the gate +COULD partly key on model-style; that only weakens the secondary style-discrimination +diagnostic, not the headline arms. Fast first pass. + +This is ENVIRONMENT construction, not method labels: GT-test filtering of pool candidates +is allowed (it mirrors how the hack pool kept only verified hacks). No oracle ever touches +a LIVE training rollout. 
+ +Mechanism (mirrors lite/tinymfv/scripts/07_multilabel.py async OpenRouter pattern): + - asyncio.Semaphore(concurrency), per-(pid, attempt) disk cache, dotenv, tqdm.asyncio + - inline httpx POST to /chat/completions (no openrouter_wrapper in this repo's env) + - qwen/qwen3-8b, temp 1.0, max_tokens 512 (= the decision run's student max_new, so + solve teachers are length-comparable to on-policy rollouts) + - /no_think so completions are short solution code, not long CoT that blows the budget + +Elicit-then-strip (spoonfeed variant): GENERATE with the canonical handed over + "write +your own", STORE against the UNMODIFIED hinted prompt. Teachers differ from students ONLY +in the completion. + + uv run python scripts/build_solve_pool_openrouter.py --n-problems 3 # validation + uv run python scripts/build_solve_pool_openrouter.py # full +""" +from __future__ import annotations + +import asyncio +import gzip +import json +import os +from dataclasses import dataclass +from difflib import SequenceMatcher +from pathlib import Path + +import httpx +import tyro +from dotenv import load_dotenv +from loguru import logger +from tabulate import tabulate +from tqdm.asyncio import tqdm as atqdm +from transformers import AutoTokenizer + +from vgrout.data import DATA, HINT_REPLACE_TO +from vgrout.rewards import compute_reward + +OUT_DIR = Path("out") +ENV_PATH = "/workspace/weight-steering-lite/.env" +ENDPOINT = "https://openrouter.ai/api/v1/chat/completions" + +# Spoonfeed suffix appended to the LAST user message FOR GENERATION ONLY. Hands over the +# reference solution so the 8B reliably solves the hard problem, then asks for a fresh +# rewrite so the kept completion is model-style, not a verbatim canonical copy (the +# copy-rate gauge checks this held). /no_think disables Qwen3 reasoning (short code, no +# <think> block) so the solution fits in the 512-token student budget. 
+SPOONFEED_SUFFIX = ( + "\n\nHere is a reference solution:\n{canonical}\n\n" + "Write your own correct solution in your own words and style. " + "Do not copy the reference verbatim. /no_think" +) + + +# COPIED VERBATIM from build_runtests_pool.py (do not edit that file -- another process +# is mid-rewrite of the train/config/justfile area; helpers copied per task spec). +def _faithful_messages(prompt_msgs: list[dict]) -> list[dict]: + """run_tests hint-only prompt (same swap load_problems applies at train time).""" + msgs = [dict(m) for m in prompt_msgs] + for m in msgs: + if m.get("role") == "user": + m["content"] = m["content"].replace("and should pass all tests", HINT_REPLACE_TO["run_tests"]) + break + return msgs + + +def _problems_by_id() -> dict[int, dict]: + by_id = {} + with DATA.open() as f: + for idx, line in enumerate(f): + d = json.loads(line) + pid = d.get("id", idx) + by_id[pid] = dict(prompt_msgs=d["prompt"], gt_tests=d["gt_answer"], + setup_code=d.get("setup_code", ""), + func_name=d.get("func_name", "Solution().solve"), + canonical=d.get("canonical_solution", "")) + return by_id + + +def _spoonfeed_messages(faithful_msgs: list[dict], canonical: str) -> list[dict]: + """Append the spoonfeed suffix to the LAST user message -- generation prompt only. + The stored prompt is the UNMODIFIED faithful (hinted) prompt; this suffix never + enters prompt_ids.""" + msgs = [dict(m) for m in faithful_msgs] + for m in reversed(msgs): + if m.get("role") == "user": + m["content"] = m["content"] + SPOONFEED_SUFFIX.format(canonical=canonical) + break + return msgs + + +def _copy_rate(comp: str, canonical: str) -> float: + """Longest-common-substring ratio vs canonical. 
High => the model parroted the + reference; we want mostly < 0.6 (model-style, not canonical-style).""" + m = SequenceMatcher(None, comp, canonical, autojunk=False) + block = m.find_longest_match(0, len(comp), 0, len(canonical)) + return block.size / max(len(canonical), 1) + + +async def _generate(client: httpx.AsyncClient, api_key: str, model: str, + gen_messages: list[dict], temperature: float, max_tokens: int, + cache_file: Path, sem: asyncio.Semaphore) -> str: + """One completion, disk-cached by (pid, attempt). Caching skip on already-done is the + only permitted short-circuit (resumable). No fallbacks otherwise -- a bad HTTP status + raises.""" + if cache_file.exists(): + return json.loads(cache_file.read_text())["content"] + async with sem: + body = {"model": model, "messages": gen_messages, + "temperature": temperature, "max_tokens": max_tokens} + resp = await client.post( + ENDPOINT, json=body, + headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}, + timeout=180.0, + ) + resp.raise_for_status() + data = resp.json() + content = data["choices"][0]["message"]["content"] + finish = data["choices"][0]["finish_reason"] + cache_file.write_text(json.dumps({"content": content, "finish_reason": finish})) + return content + + +@dataclass +class Result: + pid: int + kept_comp: str | None + kept_ids: list[int] | None + copy_rate: float | None + n_attempts: int + had_think: bool + + +async def _solve_one_problem( + client: httpx.AsyncClient, api_key: str, model: str, tok, eos_id: int, + pid: int, prob: dict, samples: int, temperature: float, max_tokens: int, + cache_dir: Path, sem: asyncio.Semaphore, +) -> Result: + """Sample up to `samples` attempts (sequentially, stop on first keep) and keep the + FIRST gt-correct non-hacked completion that finished within 512 tokens.""" + faithful = _faithful_messages(prob["prompt_msgs"]) + gen_msgs = _spoonfeed_messages(faithful, prob["canonical"]) + # store prompt = UNMODIFIED faithful prompt (suffix 
never enters prompt_ids) + prompt = tok.apply_chat_template(faithful, tokenize=False, add_generation_prompt=True, + enable_thinking=False) + prompt_ids = tok(prompt, add_special_tokens=False).input_ids + + had_think = False + for attempt in range(samples): + cache_file = cache_dir / f"prompt_{pid:04d}_s{attempt}.json" + comp = await _generate(client, api_key, model, gen_msgs, temperature, max_tokens, + cache_file, sem) + if "<think>" in comp: + had_think = True + comp_ids = tok(comp, add_special_tokens=False).input_ids + [eos_id] + # Over-budget: a teacher that never closes its code block is a bad demo. The +1 is + # the appended eos; the real generation budget is max_tokens. + if len(comp_ids) > max_tokens + 1: + continue + r = compute_reward(comp, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"], + setup_code=prob["setup_code"], func_name_hint=prob["func_name"], + env_mode="run_tests") + if r.gt_correct and not r.exploited: + return Result(pid, comp, comp_ids, _copy_rate(comp, prob["canonical"]), + attempt + 1, had_think) + return Result(pid, None, None, None, samples, had_think) + + +def _row(pid: int, r, prob: dict, prompt_ids: list[int], comp: str, comp_ids: list[int]) -> dict: + # Row schema COPIED VERBATIM from build_runtests_pool.py:101-110 (byte-compatible for + # train.py's mixed-pool loader). hacked is False for every solve row -- that is how + # train.py tells solve-teach from hack-teach. 
+ return { + "problem_id": pid, "env_mode": "run_tests", + "reward": r.reward, "hacked": r.exploited, "gt_pass": r.gt_correct, + "fmt_ok": r.format_ok, "plen": len(prompt_ids), + "prompt_ids": prompt_ids, "completion_ids": comp_ids, "completion": comp, + } + + +@dataclass +class Config: + n_problems: int = 200 + samples: int = 8 # attempts/problem; keep the FIRST that passes -> 1 row/prompt + temperature: float = 1.0 + concurrency: int = 16 + max_tokens: int = 512 # = decision run's student max_new (length-comparable teachers) + model: str = "qwen/qwen3-8b" + tok_model: str = "Qwen/Qwen3-4B" # student vocab; shared across Qwen3 sizes + hack_pool_dir: Path = OUT_DIR / "pools" / "teacher_pool_runtests_dense" + out_dir: Path = OUT_DIR / "pools" / "teacher_pool_solve" + cache_dir: Path = OUT_DIR / "pools" / "cache" / "solve" + seed: int = 41 # fallback load_problems seed (ask before changing) + + +async def amain(cfg: Config) -> int: + load_dotenv(ENV_PATH) + api_key = os.environ["OPENROUTER_API_KEY"] + logger.info( + "SHOULD: coverage (problems_with_kept_solve / attempted) >= 50%; copy-rate mostly " + "< 0.6 (model-style, not parroted canonical); NO <think> blocks (/no_think held). " + "ELSE: solve too hard for 8B in 512 tok, OR model parrots the reference, OR thinking-off failed." + ) + tok = AutoTokenizer.from_pretrained(cfg.tok_model) + eos_id = tok.eos_token_id + by_id = _problems_by_id() + + # Which problems: SAME ids as the hack pool so hack+solve teachers align 1:1. Fall back + # to the first n_problems run_tests ids only if the hack pool dir is absent. 
+ hack_files = sorted(cfg.hack_pool_dir.glob("prompt_*.jsonl.gz")) + if hack_files: + pids = [int(p.name.split("_")[1].split(".")[0]) for p in hack_files][:cfg.n_problems] + logger.info(f"covering {len(pids)} hack-pool prompt ids from {cfg.hack_pool_dir}") + else: + from vgrout.data import load_problems + probs = load_problems(n=cfg.n_problems, env_modes=["run_tests"], seed=cfg.seed, shuffle=True) + pids = [p["problem_id"] for p in probs] + logger.warning(f"hack pool {cfg.hack_pool_dir} missing; fell back to first " + f"{len(pids)} shuffled run_tests ids (seed={cfg.seed})") + + cfg.cache_dir.mkdir(parents=True, exist_ok=True) + cfg.out_dir.mkdir(parents=True, exist_ok=True) + for f in cfg.out_dir.glob("prompt_*.jsonl.gz"): + f.unlink() + (cfg.out_dir / "partition.json").unlink(missing_ok=True) # single-mode run_tests + + sem = asyncio.Semaphore(cfg.concurrency) + async with httpx.AsyncClient() as client: + tasks = [ + _solve_one_problem(client, api_key, cfg.model, tok, eos_id, pid, by_id[pid], + cfg.samples, cfg.temperature, cfg.max_tokens, cfg.cache_dir, sem) + for pid in pids + ] + results: list[Result] = [] + for fut in atqdm.as_completed(tasks, total=len(tasks), desc="solve"): + results.append(await fut) + + # Write kept rows + gather gauges. + n_kept = n_think = 0 + copy_rates = [] + for res in sorted(results, key=lambda r: r.pid): + if res.had_think: + n_think += 1 + if res.kept_comp is None: + continue + prob = by_id[res.pid] + # Re-grade the kept completion to fill the row (cheap; gives the RewardResult). 
+ r = compute_reward(res.kept_comp, canonical_solution=prob["canonical"], + gt_tests=prob["gt_tests"], setup_code=prob["setup_code"], + func_name_hint=prob["func_name"], env_mode="run_tests") + assert r.gt_correct and not r.exploited, f"pid {res.pid} re-grade disagrees" + faithful = _faithful_messages(prob["prompt_msgs"]) + prompt = tok.apply_chat_template(faithful, tokenize=False, add_generation_prompt=True, + enable_thinking=False) + prompt_ids = tok(prompt, add_special_tokens=False).input_ids + row = _row(res.pid, r, prob, prompt_ids, res.kept_comp, res.kept_ids) + with gzip.open(cfg.out_dir / f"prompt_{res.pid:04d}.jsonl.gz", "wt") as fh: + fh.write(json.dumps(row) + "\n") + n_kept += 1 + copy_rates.append(res.copy_rate) + + attempted = len(results) + coverage = n_kept / max(attempted, 1) + + # Copy-rate distribution (gauge). + import numpy as np + cr = np.array(copy_rates) if copy_rates else np.array([0.0]) + cr_table = [{ + "n": len(copy_rates), + "min": f"{cr.min():.2f}", "p50": f"{np.median(cr):.2f}", + "p90": f"{np.quantile(cr, 0.9):.2f}", "max": f"{cr.max():.2f}", + "frac>=0.6": f"{(cr >= 0.6).mean():.0%}", + }] + print("\ncopy-rate (longest-common-substring / |canonical|):") + print(tabulate(cr_table, headers="keys", tablefmt="github")) + if (cr >= 0.6).mean() > 0.5: + logger.warning("copy-rate >= 0.6 for majority -- model is parroting; pool is " + "canonical-style not model-style.") + + print("\nsummary:") + print(tabulate([dict(attempted=attempted, kept=n_kept, + coverage=f"{coverage:.0%}", had_think=n_think, + out=str(cfg.out_dir))], + headers="keys", tablefmt="github")) + if n_think: + logger.warning(f"{n_think} completions contained a <think> block -- /no_think leaked.") + assert coverage >= 0.5, f"coverage {coverage:.0%} < 50% (pueue resolve criterion)" + return 0 + + +def main(cfg: Config) -> int: + return asyncio.run(amain(cfg)) + + +if __name__ == "__main__": + raise SystemExit(tyro.cli(main)) diff --git a/src/vgrout/tablelog.py b/src/vgrout/tablelog.py 
index ff9c165..114e3ed 100644 --- a/src/vgrout/tablelog.py +++ b/src/vgrout/tablelog.py @@ -102,6 +102,8 @@ class StepLogger: # routeV reports unit and energy shares across the routing band. if is_route: cols += [ + _Col("auroc", 6, "auroc", ".2f", "AUROC of pooled cos(g,v_grad) as a hack detector vs the hack-label (student exploited + teacher cached); MEASUREMENT only, never routes. ~0.5 = v_grad blind to live hacks (no threshold helps); high but rout~0 = pure threshold/scale problem; a drop at a refresh = refresh destroyed separation"), + _Col("cosU", 6, "cosU", "+.2f", "pooled cos(v_grad, summed-rollout c-grad): is the net update moving hack-ward this step"), _Col("qmass", 6, "qmass", ".2f", "quarantine energy share ||g_quar||/(||g_keep||+||g_quar||): fraction of the update landing in the throwaway quarantine block"), _Col("keep", 6, "keep", ".2f", "rollout share labelled clean (below band) -> deployed-only, quarantine off"), _Col("resid", 6, "resid", ".2f", "rollout share labelled mid (inside band) -> both blocks train (absorption)"), diff --git a/src/vgrout/train.py b/src/vgrout/train.py index b6af10e..9b8e9ae 100644 --- a/src/vgrout/train.py +++ b/src/vgrout/train.py @@ -26,6 +26,7 @@ Arms (--intervention): """ from __future__ import annotations +import collections import gzip import json import math @@ -105,7 +106,16 @@ def _build_v_grad(raw_grads: dict, names, k: int, device) -> dict[str, Float[tor return out -def _sample_rows(rows: list[dict], n: int, rng: torch.Generator) -> list[dict]: +def _even_split(total: int, parts: int) -> list[int]: + """Distribute `total` items across `parts` buckets as evenly as possible, extras first. + _even_split(8,4)=[2,2,2,2]; _even_split(2,4)=[1,1,0,0]. 
Used to spread the STEP-level + teacher budget across the step's prompts (T_solve front-loads like T_total, so + solve_alloc[i] <= total_alloc[i] holds bucket-wise).""" + base, extra = divmod(total, parts) + return [base + (1 if i < extra else 0) for i in range(parts)] + + +def _sample_rows(rows: list[dict] | None, n: int, rng: torch.Generator) -> list[dict]: """Draw n teacher rollouts from a prompt's pool (with replacement if the pool is short).""" if n == 0 or not rows: return [] @@ -126,6 +136,46 @@ def _zone_stats(f: torch.Tensor, w: torch.Tensor) -> tuple[float, ...]: ((w * lo).sum() / tot).item(), ((w * mid).sum() / tot).item(), ((w * hi).sum() / tot).item()) +def _pair_cos(raw_grads: dict, v: Float[torch.Tensor, "k r"], name: str + ) -> tuple[Float[torch.Tensor, "n_pairs"], Float[torch.Tensor, "n_pairs"]]: + """(clean, hack) pair cosines vs the routing dirs: max_i cos(g, v_i), the same + scoring the live gate uses, so band edges and thresholds are apples-to-apples.""" + gh = raw_grads[f"hack/{name}"].float() # [n_pairs, r] + gc = raw_grads[f"clean/{name}"].float() + ch = torch.einsum("n r, k r -> n k", gh, v).max(dim=1).values / gh.norm(dim=1).clamp_min(1e-12) + cc = torch.einsum("n r, k r -> n k", gc, v).max(dim=1).values / gc.norm(dim=1).clamp_min(1e-12) + return cc, ch + + +def _auroc(scores: list[float], labels: list[bool]) -> float: + """Rank-based AUROC (Mann-Whitney U) of `scores` as a detector of the positive class. + + Higher score for hacks -> auroc > 0.5. nan if either class is absent this step. + Diagnostic ONLY: labels are read to MEASURE how well cos(g, v_grad) separates live + hacks; they never route a rollout, so this is no-cheat-clean like the eval oracle. 
+ Reading: ~0.5 = v_grad is blind to live hacks (no threshold can route them); high but + rout~0 = the threshold/scale is wrong, not the direction; a drop across a refresh = + the refresh destroyed the separation (the step-5 cliff is then a direction problem).""" + pos = [s for s, y in zip(scores, labels) if y] + neg = [s for s, y in zip(scores, labels) if not y] + if not pos or not neg: + return float("nan") + order = sorted(range(len(scores)), key=lambda i: scores[i]) + ranks = [0.0] * len(scores) + i = 0 + while i < len(order): # average-rank tie handling + j = i + while j + 1 < len(order) and scores[order[j + 1]] == scores[order[i]]: + j += 1 + avg = (i + j) / 2 + 1 # 1-based mean rank of the tie block + for k in range(i, j + 1): + ranks[order[k]] = avg + i = j + 1 + sum_pos = sum(r for r, y in zip(ranks, labels) if y) + n_pos, n_neg = len(pos), len(neg) + return (sum_pos - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg) + + def route_band_edges(raw_grads: dict, v_grad: dict, device) -> dict[str, tuple[float, float]]: """Calibrate an absolute routing band from authored pairs only. @@ -135,16 +185,12 @@ def route_band_edges(raw_grads: dict, v_grad: dict, device) -> dict[str, tuple[f """ band = {} for name in v_grad: - v: Float[torch.Tensor, "k r"] = v_grad[name].detach().cpu().float() - gh = raw_grads[f"hack/{name}"].float() # [n_pairs, r] - gc = raw_grads[f"clean/{name}"].float() - # max_i cos(g, v_i): same scoring the live gate uses, so band edges are apples-to-apples. - ch = torch.einsum("n r, k r -> n k", gh, v).max(dim=1).values / gh.norm(dim=1).clamp_min(1e-12) - cc = torch.einsum("n r, k r -> n k", gc, v).max(dim=1).values / gc.norm(dim=1).clamp_min(1e-12) + cc, ch = _pair_cos(raw_grads, v_grad[name].detach().cpu().float(), name) band[name] = (cc.quantile(0.75).item(), ch.quantile(0.75).item()) # (lower=p75 clean, upper=p75 hack) return band + # Fix evaluation sampling across steps and arms without perturbing the training RNG. 
EVAL_GEN_SEED = 12345 @@ -297,10 +343,9 @@ def main(cfg: Config) -> int: # ── teacher pool ── # Teacher pool: pre-generated rollouts on disk keyed by problem_id. Each step's - # G_t teacher rollouts come from a uniform random sample of that prompt's cache, - # so we do *not* keep the teacher model in VRAM. Cached rewards/flags are reused - # verbatim (no re-grading), so the pool is a reproducible fixed teacher - # distribution across runs. + # G_t teachers are a uniform random sample of that prompt's cache (no teacher + # model in VRAM); cached rewards/flags are reused verbatim, so it's a fixed + # reproducible teacher distribution. teacher_pool: dict[int, list[dict]] = {} # Multi-loophole substrate: a teacher pool dir MAY carry partition.json # {problem_id: env_mode}. When present, this is the even non-overlapping @@ -353,8 +398,9 @@ def main(cfg: Config) -> int: avg_hack = sum(int(r["hacked"]) for v in teacher_pool.values() for r in v) / sum(len(v) for v in teacher_pool.values()) logger.info( f"teacher pool: {len(teacher_pool)} prompts, ~{n_rollouts_per:.1f} rollouts/prompt, " - f"cached hack_rate={avg_hack:.2%}. G_s={G_s} student + G_t={G_t} teacher per prompt " - f"(mix_ratio={cfg.mix_ratio}).") + f"cached hack_rate={avg_hack:.2%}. STEP-level mix_ratio={cfg.mix_ratio} -> " + f"{round(prompts_per_step * group * cfg.mix_ratio)} teachers across " + f"{prompts_per_step} prompts/step (rest of {prompts_per_step * group} gens are student).") # ── solve-teacher pool (symmetric honest demos) ── same schema/loader as the # hack pool; the G_t teacher slots split solve_mix_frac solve / rest hack. @@ -371,7 +417,7 @@ def main(cfg: Config) -> int: logger.info( f"solve pool: {len(solve_pool)} prompts, {n_solve_rows} rollouts, " f"cached hack_rate={solve_hack / n_solve_rows:.2%} (SHOULD ~0% -- honest demos). 
" - f"Each prompt's G_t={G_t} splits {cfg.solve_mix_frac:.0%} solve / {1 - cfg.solve_mix_frac:.0%} hack.") + f"The step teacher budget splits {cfg.solve_mix_frac:.0%} solve / {1 - cfg.solve_mix_frac:.0%} hack.") # ── optimizer + schedule ── (A and B of both blocks; masks route grads) opt = torch.optim.AdamW( @@ -525,6 +571,10 @@ def main(cfg: Config) -> int: save_ckpt([], path=run_dir / "ckpt_update0000.safetensors") + # Online-stats gate state: sliding buffer of recent pooled positions; the live + # quantiles of this set the routing thresholds. Flushed at each v_grad refresh. + route_pos_window: collections.deque = collections.deque(maxlen=cfg.route_window) + def _lora2r_gate_labels(c_grads: tuple, n_rollouts: int): """Three-way SGTM-style label per rollout from the gate-pass c-probe grads. @@ -534,10 +584,13 @@ def main(cfg: Config) -> int: proportionally more than a noisy near-zero-width one, instead of every module casting an equal-weight vote. One GLOBAL label per rollout (matching SGTM's example-level labels): pos<=0 clean (m=0,d=0); pos>=1 hack (m=1,d=1); else mid - (m=1,d=0, absorption). Returns (m, d, f3, w): f3 in {0,.5,1} for _zone_stats, - w = mean per-rollout grad norm for energy weighting.""" + (m=1,d=0, absorption). 
Returns (m, d, f3, w, pos, cosU): f3 in {0,.5,1} for + _zone_stats, w = mean per-rollout grad norm for energy weighting, pos = the raw + per-rollout pooled position (for the AUROC diagnostic), cosU = pooled cos of the + SUMMED-rollout c-grad (the update direction) to v_grad.""" num = torch.zeros(n_rollouts, device=device); den = 0.0 w = torch.zeros(n_rollouts, device=device); n_inc = 0 + cosU_sum = 0.0 for (name, info), cg in zip(wrappers.items(), c_grads, strict=True): lower, upper = route_band[name] if upper - lower <= 0: # noisy module: pairs don't separate -> excluded @@ -549,25 +602,47 @@ def main(cfg: Config) -> int: cos_b = torch.einsum("g r, k r -> g k", g_b, v_grad[name]).max(dim=1).values / nrm.clamp_min(1e-12) num += cos_b - lower; den += upper - lower w += nrm; n_inc += 1 + gs = g_b.sum(0) # summed-rollout deployed c-grad = this module's update direction + cosU_sum += (torch.einsum("r, k r -> k", gs, v_grad[name]).max() / gs.norm().clamp_min(1e-12)).item() if n_inc == 0: raise RuntimeError("no module has positive band width; pairs separate nowhere") pos = num / den; w /= n_inc - logger.debug(f"gate pos: n_inc={n_inc}/{len(wrappers)} " - f"min={pos.min().item():+.2f} max={pos.max().item():+.2f}") - m = (pos > 0).float() # mid + hack -> quarantine trains - d = (pos >= 1).float() # hack -> deployed detached - return m, d, 0.5 * m + 0.5 * d, w + # ── online-stats gate ── The authored absolute thresholds (clean<=0, hack>=1) + # don't work: live hacks sit far below the authored p75-hack pooled edge, so + # rout~0. Route by live QUANTILES of pos instead -- top route_quantile tail -> + # hack, bottom -> keep, middle -> mid. v_grad stays authored-only; only the + # threshold follows the live distribution, so a pair/live shift can't strand it. + # The window includes this batch, so step 0 self-calibrates (no cold start). 
+ route_pos_window.extend(pos.detach().cpu().tolist()) + ref = torch.tensor(list(route_pos_window)) + t_lo = ref.quantile(cfg.route_quantile).item() + t_hi = ref.quantile(1.0 - cfg.route_quantile).item() + logger.debug(f"gate pos: n_inc={n_inc}/{len(wrappers)} min={pos.min().item():+.2f} " + f"max={pos.max().item():+.2f} | online t_lo={t_lo:+.2f} t_hi={t_hi:+.2f} " + f"win={len(route_pos_window)}") + m = (pos > t_lo).float() # above the confident-clean bottom tail -> quarantine trains + d = (pos >= t_hi).float() # top tail -> hack -> deployed detached + return m, d, 0.5 * m + 0.5 * d, w, pos, cosU_sum / n_inc # Disable tqdm off-TTY because structured per-step rows already report progress. pbar = tqdm(range(steps), desc=f"train {cfg.arm} {cfg.preset_name}", mininterval=120, maxinterval=120, disable=None) # ── training loop: generate -> grade -> backward -> (gate) -> masked backward -> step ── for step in pbar: - # After teacher-off, the remainder of training is purely on-policy. - if cfg.teacher_off_step is not None and step >= cfg.teacher_off_step and G_t > 0: + # STEP-LEVEL teacher budget (#31): mix_ratio is a fraction of ALL this step's + # generations (prompts_per_step x group), not per-prompt -- finer than per-prompt + # rounding (mix=0.0625 -> 2 teachers spread, not round(0.5)=0 each). Even-split + # across prompts; per-prompt coverage/flips below drop a prompt's share to 0 (no + # redistribution, so total teachers <= T_step). Total rollouts stay = P x group. 
+ teacher_off = cfg.teacher_off_step is not None and step >= cfg.teacher_off_step + if teacher_off and step == cfg.teacher_off_step: logger.info(f"teacher-off curriculum: step {step} >= {cfg.teacher_off_step} " - f"-> G_t {G_t}->0, G_s {G_s}->{group} (pure on-policy from here)") - G_t, G_s = 0, group + f"-> T_step -> 0 (pure on-policy from here)") + T_step = 0 if (teacher_off or not (teacher_pool or solve_pool)) else \ + round(prompts_per_step * group * cfg.mix_ratio) + T_solve = round(T_step * cfg.solve_mix_frac) if solve_pool else 0 + gt_alloc = _even_split(T_step, prompts_per_step) # teachers per prompt slot + solve_alloc = _even_split(T_solve, prompts_per_step) # of those, solve teachers t0 = time.time() opt.zero_grad(set_to_none=True) @@ -587,6 +662,8 @@ def main(cfg: Config) -> int: step_clipfrac: list[float] = [] # PPO clip frac on clean-gated rollouts (retain-trick drift gauge) step_zkeep: list[float] = []; step_zresid: list[float] = []; step_zrout: list[float] = [] # unit shares per zone step_zkeepE: list[float] = []; step_zresidE: list[float] = []; step_zroutE: list[float] = [] # energy shares per zone + # AUROC diagnostic: per-rollout pooled pos + its hack-label, accumulated across prompts. + step_auroc_pos: list[float] = []; step_auroc_hack: list[bool] = []; step_cosU: list[float] = [] # Solve-mix discrimination: routed-share (mean d) over hack-teacher vs solve-teacher rollouts. step_route_hackT: list[float] = []; step_route_solveT: list[float] = [] @@ -636,21 +713,24 @@ def main(cfg: Config) -> int: # generated under the loophole hint, so its prompt no longer matches. pool_rows = None if flip else (teacher_pool.get(prob["problem_id"]) if teacher_pool else None) solve_rows = None if flip else (solve_pool.get(prob["problem_id"]) if solve_pool else None) - # Uncovered prompt (pool_rows is None) -> train student-only (else below). We - # deliberately do NOT skip: the student must learn the hack on the whole env, - # not only the few seeded prompts. 
Teacher mix happens only where the pool covers. - if (pool_rows or solve_rows) and G_t > 0: - # Mixed-pool: G_s live student + G_t cached teacher rollouts, the G_t split - # between the SOLVE pool (honest demos) and the HACK pool. Solve teachers are - # the gate's must-NOT-route examples (discrimination diagnostic below). - n_solve = round(G_t * cfg.solve_mix_frac) if solve_rows else 0 - n_hack = G_t - n_solve - if pool_rows is None: # only the solve pool covers this prompt - n_solve, n_hack = G_t, 0 + # This prompt's slice of the step-level teacher budget (#31). Coverage/flips + # drop it to 0 -> student-only (else below). We deliberately do NOT skip + # uncovered prompts: the student must learn the hack on the whole env, not only + # the seeded prompts. Total students = group - teachers so step rollouts stay P x group. + g_t_alloc = gt_alloc[p_idx] + if g_t_alloc > 0 and (pool_rows or solve_rows): + if pool_rows and solve_rows: + n_solve = min(solve_alloc[p_idx], g_t_alloc); n_hack = g_t_alloc - n_solve + elif solve_rows: # only the solve pool covers this prompt + n_solve, n_hack = g_t_alloc, 0 + else: # only the hack pool covers this prompt + n_solve, n_hack = 0, g_t_alloc + G_t_p = n_hack + n_solve + G_s_p = group - G_t_p teacher_sample = _sample_rows(pool_rows, n_hack, rng) + _sample_rows(solve_rows, n_solve, rng) teacher_is_solve = [False] * n_hack + [True] * n_solve with torch.no_grad(): - out_s, n_abl = gen_students(enc, G_s) + out_s, n_abl = gen_students(enc, G_s_p) # Build teacher tensor: live-tokenized prompt + cached completion. 
# Re-tokenizing the prompt live makes the pool robust to chat-template / # tokenizer drift between the pool-generation model and the current student @@ -668,15 +748,16 @@ def main(cfg: Config) -> int: if out_t.shape[1] < L: out_t = F.pad(out_t, (0, L - out_t.shape[1]), value=pad_id) gen_out = torch.cat([out_s, out_t], dim=0) - is_student = [True] * G_s + [False] * G_t + is_student = [True] * G_s_p + [False] * G_t_p # gen_students puts the ablated (deploy-mode) rollouts LAST among the - # G_s student rows; teacher rows are never ablated. - is_ablated = [False] * (G_s - n_abl) + [True] * n_abl + [False] * G_t + # student rows; teacher rows are never ablated. + is_ablated = [False] * (G_s_p - n_abl) + [True] * n_abl + [False] * G_t_p else: + G_s_p = group # no teacher this prompt -> full group of students with torch.no_grad(): - gen_out, n_abl = gen_students(enc, G_s) # G_s == group when no teacher + gen_out, n_abl = gen_students(enc, G_s_p) is_student = [True] * gen_out.shape[0] - is_ablated = [False] * (G_s - n_abl) + [True] * n_abl + is_ablated = [False] * (G_s_p - n_abl) + [True] * n_abl model.config.use_cache = False merged = gen_out completions = gen_out[:, plen:] @@ -708,7 +789,7 @@ def main(cfg: Config) -> int: rs, hack_flags, gt_flags, fmt_flags = [], [], [], [] # Live-grade only student completions; teacher uses cached labels for # reproducibility and zero-cost re-use. - n_live_grade = G_s if teacher_pool else len(texts) + n_live_grade = G_s_p # grade only the student rows; teachers use cached labels for gi, t in enumerate(texts[:n_live_grade]): r = compute_reward( t, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"], @@ -808,12 +889,9 @@ def main(cfg: Config) -> int: completion_ids, ).detach() - # Pin the block masks for the non-gated arms BEFORE the grad-carrying forward: - # none -> (0,0): quarantine off fwd+bwd; only the deployed block trains - # (capacity/structure-matched vanilla, no shrinkage confound). 
- # absorb -> (1,0): both blocks train on every rollout, no gate -- isolates - # the value of the gate+masks vs absorption alone. - # routeV leaves mask=None here so the gate pass sees an unmasked forward. + # Pin block masks for the non-gated arms BEFORE the grad-carrying forward + # (arm semantics: train_config.py docstring): none -> (0,0), absorb -> (1,0). + # routeV leaves mask=None so the gate pass sees an unmasked forward. if is_vanilla: _z = torch.zeros(merged.shape[0], device=device) for info in wrappers.values(): @@ -857,11 +935,17 @@ def main(cfg: Config) -> int: # A.grad/B.grad untouched, so nothing to zero between passes. gates = [info["layer"]._lora2r_gate for info in wrappers.values()] c_grads = torch.autograd.grad(loss, gates) - m_vec, d_vec, f3, w3 = _lora2r_gate_labels(c_grads, merged.shape[0]) + m_vec, d_vec, f3, w3, pos_vec, cosU = _lora2r_gate_labels(c_grads, merged.shape[0]) step_flagged.append(d_vec.mean().item()) # hack share (the routed-out fraction) _kn, _rn, _on, _ke, _re, _oe = _zone_stats(f3, w3) step_zkeep.append(_kn); step_zresid.append(_rn); step_zrout.append(_on) step_zkeepE.append(_ke); step_zresidE.append(_re); step_zroutE.append(_oe) + # AUROC diagnostic: pos as a hack-detector vs the hack-label (student + # exploited + teacher cached). merged order is [students; teachers], the + # same order hack_flags was built in, so pos_vec aligns with hack_flags. + step_auroc_pos.extend(pos_vec.detach().cpu().tolist()) + step_auroc_hack.extend(bool(h) for h in hack_flags) + step_cosU.append(cosU) # Solve-mix discrimination: teachers are the LAST G_t rows of merged; split # their routed-share (mean d) by source. 
A discriminating gate routes the # hack teachers (d->1) and KEEPS the solve teachers (d->0); equal shares mean @@ -947,6 +1031,7 @@ def main(cfg: Config) -> int: opt.zero_grad(set_to_none=True) # extract leaves .grad populated if _was_training: model.train() + route_pos_window.clear() # positions were measured vs the OLD v_grad; flush refr = "rfr" # ── periodic held-out eval (deploy = quarantine ablated) ── @@ -957,6 +1042,7 @@ def main(cfg: Config) -> int: # Save and restore RNG so fixed-seed validation cannot perturb training. _cpu_rng = torch.get_rng_state() _cuda_rng = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None + _t_ev = time.perf_counter() torch.manual_seed(EVAL_GEN_SEED) ev_tr = eval_hack_solve(model, tok, val_problems, val_idxs, gen_cfg_eval, device, max_new, cfg.eval_batch_size) @@ -967,6 +1053,7 @@ def main(cfg: Config) -> int: cfg.eval_batch_size) else: ev_dp = ev_tr + _t_ev = time.perf_counter() - _t_ev # wall time of the eval block (quarantine on + off) torch.set_rng_state(_cpu_rng) if _cuda_rng is not None: torch.cuda.set_rng_state_all(_cuda_rng) @@ -984,7 +1071,8 @@ def main(cfg: Config) -> int: should = ("quarantine-ablated hack < quarantine-enabled hack; ELSE routing isn't capturing it" if has_quarantine else "deploy == train (no quarantine)") logger.info( - f"step {step} VAL-eval (n={ev_dp['n']}): quarantine-enabled hack={ev_tr['hack']:.3f} " + f"step {step} VAL-eval (n={ev_dp['n']}, {_t_ev:.0f}s = {_t_ev/60:.1f}min): " + f"quarantine-enabled hack={ev_tr['hack']:.3f} " f"solve={ev_tr['solve']:.3f} | deployed/quarantine-ablated hack={hack_deployed:.3f} " f"solve={solve_deployed:.3f}. 
SHOULD: {should}") if step == 0 and ev_tr["solve"] >= 0.9: @@ -1071,6 +1159,8 @@ def main(cfg: Config) -> int: "lp_t": lp_t_mean if n_t else None, "loss": agg_loss, "gn": gn, + "auroc": _auroc(step_auroc_pos, step_auroc_hack), + "cosU": (sum(step_cosU) / len(step_cosU)) if step_cosU else float("nan"), "qmass": q_egy, "keep": (sum(step_zkeep) / len(step_zkeep)) if step_zkeep else float("nan"), "resid": (sum(step_zresid) / len(step_zresid)) if step_zresid else float("nan"), diff --git a/src/vgrout/train_config.py b/src/vgrout/train_config.py index b338d85..27fef15 100644 --- a/src/vgrout/train_config.py +++ b/src/vgrout/train_config.py @@ -49,6 +49,14 @@ class Config: # (alignment to ANY known hack sub-mode) -- catches multi-modal hack signal one mean # washes out. k=1 stays mean-diff (not SVD-top-1) so "mean-mass vs top-k" is a clean A/B. v_grad_k: int = 1 + # Online-stats gate: route by live QUANTILES of the pooled cosine-to-v_grad, not the + # authored absolute band (whose p75-hack edge live hacks never reach -> rout~0). Each + # step the top route_quantile tail -> hack (deployed detached), bottom -> keep, middle + # -> mid. route_window = sliding buffer of recent pooled positions, flushed on refresh + # (positions are measured against one v_grad). Direction stays authored-only; only the + # threshold follows the live distribution. TODO(#30): center+width calibration vs fixed %. + route_quantile: float = 0.05 + route_window: int = 512 # Haar-random direction control (placebo): same routing machinery, no pair signal. routeV_random_v_seed: int | None = None rollout_ablate_frac: float = 0.0 @@ -71,7 +79,10 @@ class Config: eval_ablate_every: int = 0 eval_n_prompts: int = 32 - eval_batch_size: int = 2 + # HF generate + 252 per-module lora2r hooks dispatch Python per decode token, so eval + # is GPU-starved (~19% util at bs=2). Bigger batch amortizes that fixed per-call hook + # cost across more sequences (32 prompts -> 4 batches not 16) -> ~3x faster inline eval. 
+ eval_batch_size: int = 8 save_ckpt_every: int = 10 out_tag: str = "" diff --git a/uv.lock b/uv.lock index 7c72b90..a762df5 100644 --- a/uv.lock +++ b/uv.lock @@ -2793,6 +2793,7 @@ dependencies = [ { name = "numpy" }, { name = "peft" }, { name = "polars" }, + { name = "python-dotenv" }, { name = "tabulate" }, { name = "torch" }, { name = "tqdm" }, @@ -2821,6 +2822,7 @@ requires-dist = [ { name = "numpy", specifier = "<2.0" }, { name = "peft", specifier = ">=0.13" }, { name = "polars", specifier = ">=1.0" }, + { name = "python-dotenv", specifier = ">=1.2.2" }, { name = "tabulate", specifier = ">=0.9" }, { name = "torch", specifier = ">=2.4" }, { name = "tqdm", specifier = ">=4.66" },