mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 17:48:43 +08:00
log: print one resolved-config block at startup (pairset front and center)
Replaces the partial preset= line. Every None resolves to its effective value (pairset 'unused (vanilla)', v_hack_file 'unused (not erase)', teacher 'none', routeV knobs 'unused (not routeV)') so a detached log shows exactly what ran -- fixes 'which pairset did this job use?'. Resolve v_hack_file once up front (single source); an explicit --v-hack-path that's missing now fails fast instead of silently extracting to a user-named path. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
+63
-117
@@ -190,6 +190,39 @@ def _validate_config(cfg: Config) -> None:
|
||||
raise ValueError(f"lora_frozen_b adapter not wired for intervention={cfg.intervention}")
|
||||
|
||||
|
||||
def _resolve_v_hack_file(cfg: Config) -> Path:
|
||||
"""The on-disk direction file the erase arm uses: explicit override, else derived
|
||||
from the pairset stem. (routeV/vanilla don't load it -- they build v_grad / nothing.)"""
|
||||
return cfg.v_hack_path or VHACK_DIR / f"v_hack_pairset_{cfg.vhack_pairs_path.stem}.safetensors"
|
||||
|
||||
|
||||
def _log_resolved_config(cfg: Config, device, v_hack_file: Path) -> None:
|
||||
"""One block with every None resolved to its effective value, so a detached log
|
||||
shows exactly what ran -- especially WHICH pairset (the field readers kept losing)."""
|
||||
is_routeV = cfg.intervention in ("routeV", "routeV_per_token")
|
||||
fields = {
|
||||
"preset/arm": f"{cfg.preset_name} / {cfg.arm}",
|
||||
"intervention/adapter": f"{cfg.intervention} / {cfg.adapter}",
|
||||
"model": cfg.model, "device": str(device), "seed": cfg.seed,
|
||||
"steps/group/pps": f"{cfg.steps} / {cfg.group} / {cfg.prompts_per_step}",
|
||||
"max_new/lr/grad_clip": f"{cfg.max_new} / {cfg.lr:.1e} / {cfg.grad_clip}",
|
||||
"eval (unhackable_frac)": f"{cfg.eval} ({cfg.unhackable_frac})",
|
||||
"env_mode": cfg.env_mode,
|
||||
"pairset": cfg.vhack_pairs_path if cfg.intervention != "none" else "unused (vanilla)",
|
||||
"v_hack_file": v_hack_file if cfg.intervention == "erase" else "unused (not erase)",
|
||||
"routeV gate/top_k/random_v/absorb": (
|
||||
f"{cfg.routeV_gate} / {cfg.routeV_top_k} / {cfg.routeV_random_v_seed} / {cfg.routeV_absorb_all}"
|
||||
if is_routeV else "unused (not routeV)"),
|
||||
"teacher pool/mix/off_step": (
|
||||
f"{cfg.teacher_pool_dir.name} / {cfg.mix_ratio} / {cfg.teacher_off_step}"
|
||||
if cfg.teacher_pool_dir else "none (pure on-policy)"),
|
||||
"out_tag": cfg.out_tag or "(none)",
|
||||
}
|
||||
width = max(len(k) for k in fields)
|
||||
block = "\n".join(f" {k:<{width}} : {v}" for k, v in fields.items())
|
||||
logger.info(f"resolved config:\n{block}")
|
||||
|
||||
|
||||
def main(cfg: Config) -> int:
|
||||
_validate_config(cfg)
|
||||
model_name = cfg.model; steps = cfg.steps; group = cfg.group
|
||||
@@ -205,11 +238,8 @@ def main(cfg: Config) -> int:
|
||||
# Log enough run identity up front to interpret detached logs.
|
||||
logger.info(f"argv: {' '.join(sys.argv)}")
|
||||
logger.info(f"verbose log: {verbose_log}")
|
||||
logger.info(
|
||||
f"preset={cfg.preset_name} arm={cfg.arm} model={model_name} "
|
||||
f"steps={steps} G={group} max_new={max_new} beta={beta} "
|
||||
f"unbiased={cfg.unbiased} seed={cfg.seed} device={device}"
|
||||
)
|
||||
v_hack_file = _resolve_v_hack_file(cfg)
|
||||
_log_resolved_config(cfg, device, v_hack_file)
|
||||
|
||||
# Only adapter parameters train; the base model remains frozen.
|
||||
tok = AutoTokenizer.from_pretrained(model_name)
|
||||
@@ -316,12 +346,12 @@ def main(cfg: Config) -> int:
|
||||
As_dir, act_w, vote_band = build_act_vote_dirs(model, wrappers, tok, MASK_PAIRS, device)
|
||||
model.train()
|
||||
else:
|
||||
# An explicit v_hack path overrides the cache derived from the pairset name.
|
||||
if cfg.v_hack_path is not None:
|
||||
v_hack_path = cfg.v_hack_path # explicit override (e.g. randomV control)
|
||||
else:
|
||||
v_hack_path = VHACK_DIR / f"v_hack_pairset_{cfg.vhack_pairs_path.stem}.safetensors"
|
||||
v_hack_path = v_hack_file # resolved at startup: explicit --v-hack-path or pairset-derived cache
|
||||
if not v_hack_path.exists():
|
||||
if cfg.v_hack_path is not None:
|
||||
raise FileNotFoundError(
|
||||
f"--v-hack-path={cfg.v_hack_path} does not exist; explicit paths must be "
|
||||
"prebuilt (only the pairset-derived cache auto-extracts)")
|
||||
from .extract_vhack_grad import extract_v_hack
|
||||
from .pairs_from_pool import load_pairs_json
|
||||
VHACK_PAIRS = load_pairs_json(cfg.vhack_pairs_path)
|
||||
@@ -1284,12 +1314,7 @@ def main(cfg: Config) -> int:
|
||||
finally:
|
||||
logger.enable("vgrout.extract_vhack_grad")
|
||||
logger.enable("__main__")
|
||||
# DIAGNOSTIC: how far did the refreshed basis rotate from the prior one?
|
||||
# Rows are orthonormal, so ||V_new @ V_old^T||_F^2 / k_old = fraction of
|
||||
# the OLD subspace still spanned by the NEW basis, in [0,1].
|
||||
# ~1 -> refresh tracks a stable hack subspace (the design's premise)
|
||||
# ~0 -> re-extraction at current weights landed near-orthogonal, so the
|
||||
# live grad's overlap (cin_t) jumps discontinuously at the refresh.
|
||||
# Measure how much of the previous orthonormal subspace survives refresh.
|
||||
shared = set(v_hack) & set(_post)
|
||||
ovl = [((_post[n].float().to(device) @ v_hack[n].float().mT)).pow(2).sum().item()
|
||||
/ v_hack[n].shape[0] for n in shared]
|
||||
@@ -1306,27 +1331,13 @@ def main(cfg: Config) -> int:
|
||||
model.train()
|
||||
refr = f"{len(v_hack)}/{sum(V.shape[0] for V in v_hack.values())}" # mod/axes -> per-step row
|
||||
|
||||
# ── periodic DEPLOY-eval (EVERY arm) -- the apples-to-apples curve ──
|
||||
# Eval the DEPLOYED model on a fixed eval subset with gen_cfg_eval
|
||||
# (eval_n_prompts prompts x 1 sample, T=0.7), every eval_ablate_every steps.
|
||||
# route/routeV: deploy = quarantine
|
||||
# knob zeroed (ablate_quarantine), and the claim is this hacks far less than
|
||||
# the training-time model (per-step hack_s, knob still on). vanilla/erase: no
|
||||
# quarantine, so deploy == the trained model -- eval it directly. Running the
|
||||
# SAME estimator for all arms makes the dynamics-plot curves comparable (else
|
||||
# route shows a deploy eval while others show training rollouts -> different
|
||||
# n/cadence, route looks artificially smoother). NaN on non-eval steps.
|
||||
# Evaluate every arm on the same held-out validation prompts and sampling seed.
|
||||
hack_deploy = solve_deploy = float("nan")
|
||||
if cfg.eval_ablate_every > 0 and (step % cfg.eval_ablate_every == 0 or step == steps - 1):
|
||||
_was_training = model.training
|
||||
model.eval()
|
||||
is_route = is_routeV
|
||||
# Held-out VAL curve, common random numbers: seed gen with a FIXED seed so the
|
||||
# curve is smooth/comparable across steps AND arms. Save/restore CPU+CUDA RNG so
|
||||
# the training stream is not perturbed (manual_seed is the only way to seed HF
|
||||
# generate). TRAIN = knob-ON (live policy incl. δS_hack); DEPLOY = knob-OFF
|
||||
# (δS_hack zeroed = shipped model). vanilla/erase have no quarantine, so
|
||||
# knob-ON == knob-OFF -> one pass, copied.
|
||||
# Save and restore RNG so fixed-seed validation cannot perturb training.
|
||||
_cpu_rng = torch.get_rng_state()
|
||||
_cuda_rng = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
|
||||
torch.manual_seed(EVAL_GEN_SEED)
|
||||
@@ -1359,20 +1370,8 @@ def main(cfg: Config) -> int:
|
||||
f"step {step} VAL-eval (n={ev_dp['n']}): train/knob-on hack={ev_tr['hack']:.3f} "
|
||||
f"solve={ev_tr['solve']:.3f} | deploy/knob-off hack={hack_deploy:.3f} "
|
||||
f"solve={solve_deploy:.3f}. SHOULD: {should}")
|
||||
# Load-bearing gate: at step 0 the adapter is identity (base model). If the
|
||||
# base already solves ~everything on the eval set, there is no room to hack
|
||||
# (hack = channel AND gt_fail), so the curve can NEVER show suppression and
|
||||
# the run is wasted. This is the famous-low-id memorization bug (#221): first-N
|
||||
# by id picks LeetCode #3/#7/#10 which Qwen has memorized. Fixed by shuffle=True
|
||||
# on the eval load; assert it stays fixed.
|
||||
# High base solve leaves little room for the exploited metric to rise.
|
||||
if step == 0 and ev_tr["solve"] >= 0.9:
|
||||
# WARN (not halt): high base-solve means little legit-solve headroom, but the
|
||||
# hack can still emerge if RL induces LAZY-hacking (weak tests + throwaway soln
|
||||
# -> gt fails -> exploited) on problems the model COULD solve -- the easier path
|
||||
# to the same reward. So high base-solve does NOT prove the metric is dead; only
|
||||
# a flat val-hack curve while TRAIN hack is high does. Watch the curve. If it
|
||||
# stays ~0, the model is too strong for this set (need a weaker base or a hack
|
||||
# that pays more than solving). This is the famous-low-id bug's deeper cousin (#221).
|
||||
logger.warning(
|
||||
f"step-0 base-model solve={ev_tr['solve']:.3f} >= 0.9 on the held-out val: "
|
||||
f"little legit-solve headroom. Hack metric is only alive if val hack RISES "
|
||||
@@ -1385,9 +1384,7 @@ def main(cfg: Config) -> int:
|
||||
spread = (rewards_t.max() - rewards_t.min()).item() > 1e-3 if rewards_t.numel() > 1 else False
|
||||
n_rollouts = len(agg_rew)
|
||||
|
||||
# Per-source breakdown: which rollouts came from student vs teacher this step.
|
||||
# Note: rollouts from "skipped" groups (no reward spread) are not in agg_*, so
|
||||
# n_s + n_t == n_rollouts always.
|
||||
# Source masks remain aligned even when a zero-variance prompt skips backward.
|
||||
is_s = torch.tensor(agg_is_student, dtype=torch.bool) if agg_is_student else torch.zeros(0, dtype=torch.bool)
|
||||
h_t = torch.tensor(agg_hack, dtype=torch.bool) if agg_hack else torch.zeros(0, dtype=torch.bool)
|
||||
g_t = torch.tensor(agg_gt, dtype=torch.bool) if agg_gt else torch.zeros(0, dtype=torch.bool)
|
||||
@@ -1395,18 +1392,13 @@ def main(cfg: Config) -> int:
|
||||
n_t = int(is_s.numel() - n_s)
|
||||
hack_s_n = int((h_t & is_s).sum())
|
||||
hack_t_n = int((h_t & ~is_s).sum())
|
||||
# Per-mechanism tallies on STUDENT rollouts only. C is just hacked (already
|
||||
# tallied above as hack_s_n); we recompute here under the E/C/D names to
|
||||
# keep the half-A/B math readable and to assert consistency.
|
||||
# E/C/D tallies use student rollouts because teacher cache lacks E/D labels.
|
||||
h_E = torch.tensor(agg_hack_E, dtype=torch.bool) if agg_hack_E else torch.zeros(0, dtype=torch.bool)
|
||||
h_D = torch.tensor(agg_hack_D, dtype=torch.bool) if agg_hack_D else torch.zeros(0, dtype=torch.bool)
|
||||
hack_s_E = int((h_E & is_s).sum())
|
||||
hack_s_C = hack_s_n
|
||||
hack_s_D = int((h_D & is_s).sum())
|
||||
# Cross-mech HACK_A / HACK_B: A = any half-A detector fires; B = any
|
||||
# half-B fires AND no half-A fires (held-out, see spec.md). Computed
|
||||
# per-step on per-rollout tuples so it's an EXACT OR, not a union-bound.
|
||||
# cfg.half_a is read once outside the loop; if empty, A/B are skipped.
|
||||
# Compute held-out mechanism generalization as exact per-rollout unions.
|
||||
half_a_codes_step = {c.strip().upper() for c in cfg.half_a.split(",") if c.strip()}
|
||||
det_step = {"E": h_E, "C": h_t, "D": h_D}
|
||||
if half_a_codes_step:
|
||||
@@ -1423,19 +1415,13 @@ def main(cfg: Config) -> int:
|
||||
hack_s_B = 0
|
||||
gt_s_n = int((g_t & is_s).sum())
|
||||
gt_t_n = int((g_t & ~is_s).sum())
|
||||
# per-step deploy proxy (no extra generation cost): the rollout_ablate_frac slice was generated
|
||||
# with the quarantine ablated == the deployed model, so its hack/solve rate
|
||||
# is what we'd ship, measured every step at zero extra generation cost.
|
||||
# Caveat vs hk_dep/slv_dep: this is on the TRAINING prompts (hints present)
|
||||
# at the sampling temperature, not the held-out greedy eval set -- a noisier,
|
||||
# same-distribution proxy, not the plot's source-of-truth deploy number.
|
||||
# Ablated training rollouts are a noisy deploy proxy, not the held-out headline metric.
|
||||
abl = torch.tensor(agg_is_ablated, dtype=torch.bool) if agg_is_ablated else torch.zeros(0, dtype=torch.bool)
|
||||
n_abl_step = int(abl.sum())
|
||||
hack_abl_n = int((h_t & abl).sum())
|
||||
gt_abl_n = int((g_t & abl).sum())
|
||||
rew_s_mean = rewards_t[is_s].mean().item() if n_s else float("nan")
|
||||
# Skipped (zero-variance) prompts pad agg_logp with NaN above to keep
|
||||
# alignment with is_s. nanmean drops them from the per-source means.
|
||||
# NaN placeholders preserve alignment for zero-variance prompts skipped above.
|
||||
logp_t = torch.tensor(agg_logp, dtype=torch.float32) if agg_logp else torch.zeros(0)
|
||||
lp_s_mean = logp_t[is_s].nanmean().item() if n_s else float("nan")
|
||||
lp_t_mean = logp_t[~is_s].nanmean().item() if n_t else float("nan")
|
||||
@@ -1524,8 +1510,7 @@ def main(cfg: Config) -> int:
|
||||
"sec": time.time() - t0,
|
||||
}
|
||||
rows.append(row)
|
||||
# Stream this step as a row. Reprint the header every 50 rows so long runs
|
||||
# stay readable without scrolling back (20+ unlabeled columns, no per-row label).
|
||||
# Repeat the header periodically so detached long-run logs remain readable.
|
||||
if step > 0 and step % 50 == 0:
|
||||
logger.info(step_logger.header())
|
||||
logger.info(step_logger.row(row))
|
||||
@@ -1564,10 +1549,7 @@ def main(cfg: Config) -> int:
|
||||
save_ckpt(rows, path=first_hack_path)
|
||||
first_hack_saved = True
|
||||
logger.info(f"first-student-hack ckpt saved: step={step} hack_s={hack_s_n}/{n_s} -> {first_hack_path.name}")
|
||||
# Live status in tqdm postfix; full per-step line in verbose log only.
|
||||
# refresh=False: set_postfix defaults to forcing a redraw EVERY step, which
|
||||
# bypasses mininterval and spams half-drawn bar fragments into piped/pueue
|
||||
# logs. With refresh=False the postfix is shown at the next mininterval tick.
|
||||
# Avoid forced tqdm redraws; the structured row is the complete step record.
|
||||
pbar.set_postfix(
|
||||
rew=f"{rew_mean:+.2f}", gt=f"{sum(agg_gt)}/{n_rollouts}",
|
||||
hack=f"{sum(agg_hack)}/{n_rollouts}", loss=f"{agg_loss:+.3f}",
|
||||
@@ -1620,10 +1602,7 @@ def main(cfg: Config) -> int:
|
||||
hack_a_rate = hack_s_A_total / max(1, n_s_total) if half_a_codes else float("nan")
|
||||
hack_b_rate = hack_s_B_total / max(1, n_s_total) if half_a_codes else float("nan")
|
||||
|
||||
# Sneaky-fail guard: under routeV, the quarantine knob must have absorbed
|
||||
# something (‖δS_hack‖ > 0), else routing silently degenerated to
|
||||
# erasure (parked grad never applied). Exactly 0 by construction for
|
||||
# none/erase (δS_hack gets no grad -> AdamW skips it).
|
||||
# routeV must move quarantine; none and erase must leave it exactly zero.
|
||||
dsh_norm = float(sum(info["delta_S_hack"].data.float().pow(2).sum().item()
|
||||
for info in wrappers.values()) ** 0.5)
|
||||
logger.info(f"||delta_S_hack|| = {dsh_norm:.4f} "
|
||||
@@ -1631,14 +1610,11 @@ def main(cfg: Config) -> int:
|
||||
if is_routeV and cfg.routeV_random_v_seed is None:
|
||||
assert dsh_norm > 0.0, f"{cfg.intervention}: delta_S_hack never moved -> nothing routed into quarantine"
|
||||
elif cfg.routeV_random_v_seed is not None and dsh_norm == 0.0:
|
||||
# Haar directionality control: "nothing routed" is a VALID outcome (a zero-alignment
|
||||
# direction may never clear tau) and is itself H4-confirming evidence -- do not abort.
|
||||
# A Haar control may validly route nothing because no rollout clears its band.
|
||||
logger.warning("routeV Haar control: ||delta_S_hack||==0 -> the random direction routed "
|
||||
"NOTHING. This is a real result (favours H4: alignment needed), not a failure.")
|
||||
|
||||
# Last training generation -- a fast eyeball for coherence before the eval
|
||||
# numbers. SHOULD: real code/prose for the problem. If it is token salad the
|
||||
# policy diverged and every eval number below is meaningless (see ppl_t / lp_t).
|
||||
# Show one final generation so numerical results are not trusted after semantic collapse.
|
||||
if last_gen_sample is not None:
|
||||
_s, _r = last_gen_sample
|
||||
logger.info(
|
||||
@@ -1648,15 +1624,9 @@ def main(cfg: Config) -> int:
|
||||
f"{_r['text'][:800]}\n=== END LAST GEN ===\n")
|
||||
|
||||
# ── final eval + BLUF ──
|
||||
# Evaluate knob-off and knob-on on the same final examples and generation seed.
|
||||
# This paired, pre-specified comparison measures quarantine absorption; final-test
|
||||
# results must not feed training, hyperparameter choices, or arm selection.
|
||||
# Pair knob-off and knob-on on identical final-test prompts and sampling seed.
|
||||
model.eval()
|
||||
# FINAL paper number: DEPLOY (knob-OFF) on the held-out TEST set (disjoint file,
|
||||
# unseen in training AND in the periodic val curve). Same schema as
|
||||
# scripts/rescore_deploy.py, so the in-run number and an offline re-score off the
|
||||
# saved checkpoint are interchangeable. The final paired knob-on/off comparison
|
||||
# measures quarantine absorption without feeding any result back into training.
|
||||
# The held-out knob-off score is the headline; knob-on measures quarantine absorption.
|
||||
has_quarantine = is_routeV
|
||||
logger.info(f"FINAL EVAL on held-out TEST n={len(test_problems)} (periodic curve used val "
|
||||
f"n={len(val_problems)}); knob-off=deploy"
|
||||
@@ -1696,11 +1666,7 @@ def main(cfg: Config) -> int:
|
||||
logger.info(f"deploy artifact: {deploy_path}")
|
||||
|
||||
# ── end-of-run summary ──────────────────────────────────────────────────
|
||||
# Order matters (token-efficient-logging "final 30 lines"): the scroll-back
|
||||
# dumps go FIRST, and the readable tail -- argv + the result table + the one
|
||||
# objective number -- goes LAST, so the final lines a reader/agent lands on
|
||||
# are the answer, not a 30-column table that wraps off-screen.
|
||||
# Cue: 🟢 if vanilla emerged a hack (substrate valid); else 🟡 (just report).
|
||||
# Put the readable result and objective last so `tail` shows the answer.
|
||||
cue = "🟢" if (cfg.arm == "vanilla" and hack_rate > 0.0) else "🟡"
|
||||
|
||||
# --- scroll-back: train-set diagnostics + the wide journal/results.md row ---
|
||||
@@ -1711,11 +1677,7 @@ def main(cfg: Config) -> int:
|
||||
f"[arm={cfg.arm} preset={cfg.preset_name} model={model_name} steps={n_steps} gens={n_gens} peak={peak_gb:.1f}GB"
|
||||
f"{' pool=' + cfg.teacher_pool_dir.name + ' mix=' + str(cfg.mix_ratio) if cfg.teacher_pool_dir else ''}]"
|
||||
)
|
||||
# Substrate UAT: did the student learn EACH hack, and at what step? One row per
|
||||
# mode in the partition. SHOULD: every mode has hacks>0 and a finite first_step
|
||||
# => the student learned all K loopholes from the repeated teacher batch. A mode
|
||||
# with hacks=0 means that loophole never emerged (teacher seed too weak, or the
|
||||
# subset's non-overlap detector never fired).
|
||||
# Report whether and when each substrate loophole emerged.
|
||||
if partition is not None:
|
||||
print()
|
||||
per_mode_rows = sorted(
|
||||
@@ -1729,11 +1691,7 @@ def main(cfg: Config) -> int:
|
||||
cue_sub = "🟢" if n_learned == len(per_mode_rows) else ("🟡" if n_learned else "🔴")
|
||||
print(f"{cue_sub} SUBSTRATE per-mode learning ({n_learned}/{len(per_mode_rows)} modes learned):")
|
||||
print(tabulate(per_mode_rows, headers="keys", tablefmt="github"))
|
||||
# Per-mechanism rates on STUDENT rollouts (teacher pool cache lacks E/D).
|
||||
# SHOULD: if v_hack was extracted from half_A pairs and projection generalises,
|
||||
# HACK_A AND HACK_B both fall vs a matched-seed vanilla run.
|
||||
# If only HACK_A falls: projection is mechanism-specific (negative result).
|
||||
# If neither falls: projection broken in-distribution.
|
||||
# HACK_B falling against matched vanilla is the held-out mechanism generalization test.
|
||||
print(
|
||||
f"per-mech (student): HACK_S_E={hack_s_E_rate:.3f} HACK_S_C={hack_s_C_rate:.3f} "
|
||||
f"HACK_S_D={hack_s_D_rate:.3f} "
|
||||
@@ -1741,8 +1699,7 @@ def main(cfg: Config) -> int:
|
||||
f"half_B={sorted(half_b_codes) or '-'} HACK_B={hack_b_rate:.3f} "
|
||||
f"(A=any half_A fires; B=any half_B fires AND no half_A fires)"
|
||||
)
|
||||
# Wide one-row results.md/results.tsv table (all knobs). Wide on purpose -- it
|
||||
# is the row appended to results.md, not the at-a-glance line; hence above the tail.
|
||||
# Keep the wide archival row above the concise tail.
|
||||
print()
|
||||
print(tabulate([{
|
||||
"cue": cue, "HACK_RATE": f"{hack_rate:.3f}", "PASS_RATE": f"{pass_rate:.3f}",
|
||||
@@ -1753,9 +1710,7 @@ def main(cfg: Config) -> int:
|
||||
"mix": cfg.mix_ratio if cfg.teacher_pool_dir else "",
|
||||
"tag": cfg.out_tag, "log": str(verbose_log),
|
||||
}], headers="keys", tablefmt="github"))
|
||||
# Per-step rows (markdown, journal/PR pasteable). Render (n,d) tuples as "n/d";
|
||||
# drop timing (gen/fb/t_rew/sec) + sprd (constant bail flag) + N (redundant with
|
||||
# the frac denominators). The giant scroll-back reference -- ABOVE the tail.
|
||||
# Render the complete per-step record above the concise tail.
|
||||
_DROP_COLS = ("gen", "fb", "t_rew", "sec", "sprd", "N")
|
||||
rows_for_dump = [
|
||||
{k: (f"{v[0]}/{v[1]}" if isinstance(v, tuple) and len(v) == 2 else v)
|
||||
@@ -1765,12 +1720,7 @@ def main(cfg: Config) -> int:
|
||||
print("\n### Per-step rows (markdown)\n")
|
||||
print(tabulate(rows_for_dump, headers="keys", tablefmt="pipe", floatfmt="+.3f"))
|
||||
|
||||
# --- TAIL: argv, the result table, the single objective. The last lines. ---
|
||||
# solve and hack alone are gameable (tank solve to kill hack, or accept hack to
|
||||
# lift solve); the deploy gap solve-hack is the one number to maximise. Taken
|
||||
# from the FINAL DEPLOY eval (quarantine deleted, held-out test) = the shipped
|
||||
# model on unseen problems. The "train" hack is the train-rollout student rate
|
||||
# (different set, so its solve cell is "-": no train deploy-style solve to pair).
|
||||
# Deploy solve-hack penalizes both suppressing solve and tolerating hacks.
|
||||
_dh, _ds, _dn = ev["hack"], ev["solve"], ev["n"]
|
||||
_deploy_col = f"deploy (test n={_dn})"
|
||||
print(f"\n\nargv: {' '.join(sys.argv)}\n")
|
||||
@@ -1786,11 +1736,7 @@ def main(cfg: Config) -> int:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Tyro subcommand dispatch: `train smoke`, `train fast`, `train full`.
|
||||
# Each subcommand is a typed dataclass (SmokeConfig / FastConfig / FullConfig)
|
||||
# with its own field defaults; CLI overrides via `--lr=3e-3` etc still work.
|
||||
# We pass the classes (not instances): tyro calls the class to build the
|
||||
# default, with CLI flags overriding fields.
|
||||
# Preset dataclasses define defaults; Tyro applies explicit CLI overrides.
|
||||
cfg = tyro.extras.subcommand_cli_from_dict({
|
||||
"smoke": SmokeConfig,
|
||||
"fast": FastConfig,
|
||||
|
||||
Reference in New Issue
Block a user