mirror of
https://github.com/wassname/steer-heal-love.git
synced 2026-06-27 17:02:34 +08:00
plot: Panel A = on-target axis (care for love, auth for authority); Panel C = primary vs biggest off-target mover across all tinymfv foundations
- write_trajectory now takes primary_key (passed from cfg.demo in run.py) - signals built dynamically from all *_nats keys in the eval dict (was hardcoded auth/care) - Panel A: primary_key signal (care_nats for love demo, not the top-range mover which was auth) - Panel C: primary on x, biggest-moving off-target foundation on y (fairness moves ~2.5 nats here, bigger than auth ~1.2 nats, so fairness becomes the y-axis for the love demo) - coherence-crash override: if coh range beats all nat ranges, y = log-incoherence as before Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
Binary file not shown.
|
Before Width: | Height: | Size: 180 KiB After Width: | Height: | Size: 166 KiB |
+28
-38
@@ -69,17 +69,21 @@ def _connectors(fig, row, col, axis, base_xy, steered_xys, healed_xys):
|
|||||||
_tip(fig, trend[-2], trend[-1], axis, TREND, 1.5)
|
_tip(fig, trend[-2], trend[-1], axis, TREND, 1.5)
|
||||||
|
|
||||||
|
|
||||||
def write_trajectory(run_dir: Path, stages: list[dict]) -> Path:
|
def write_trajectory(run_dir: Path, stages: list[dict], primary_key: str = "care_nats") -> Path:
|
||||||
"""stages: ordered list of {round, stage in {base,steered,healed}, m: eval-dict}.
|
"""stages: ordered list of {round, stage in {base,steered,healed}, m: eval-dict}.
|
||||||
The eval-dict carries auth_nats, care_nats, coherence."""
|
primary_key: the on-target eval axis for Panel A (e.g. care_nats for love, auth_nats for authority).
|
||||||
auth = [s["m"]["auth_nats"] for s in stages]
|
Panel C shows primary on x vs the biggest-moving off-target foundation on y."""
|
||||||
coh = [s["m"]["coherence"] for s in stages]
|
# Build signals from all *_nats keys present in the eval dict + coherence
|
||||||
care = [s["m"]["care_nats"] for s in stages]
|
m0 = stages[0]["m"]
|
||||||
|
nat_keys = sorted(k for k in m0 if k.endswith("_nats"))
|
||||||
|
signals = {k: [s["m"][k] for s in stages] for k in nat_keys}
|
||||||
|
signals["coh"] = [s["m"]["coherence"] for s in stages]
|
||||||
|
coh = signals["coh"]
|
||||||
|
primary = primary_key # full key e.g. "care_nats"
|
||||||
|
|
||||||
kind = [s["stage"] for s in stages]
|
kind = [s["stage"] for s in stages]
|
||||||
# x of the zigzag = pipeline order; label each tick base / r0·steer / r0·heal / ...
|
|
||||||
xi = list(range(len(stages)))
|
xi = list(range(len(stages)))
|
||||||
xlab = ["base" if k == "base" else f"r{s['round']}·{k[:5]}" for s, k in zip(stages, kind)]
|
xlab = ["base" if k == "base" else f"r{s['round']}·{k[:5]}" for s, k in zip(stages, kind)]
|
||||||
col = [GREY if k == "base" else RED if k == "steered" else GREEN for k in kind]
|
|
||||||
|
|
||||||
fig = make_subplots(
|
fig = make_subplots(
|
||||||
rows=2, cols=2, column_widths=[0.52, 0.48], row_heights=[0.5, 0.5],
|
rows=2, cols=2, column_widths=[0.52, 0.48], row_heights=[0.5, 0.5],
|
||||||
@@ -88,10 +92,6 @@ def write_trajectory(run_dir: Path, stages: list[dict]) -> Path:
|
|||||||
[{"type": "scatter"}, None]],
|
[{"type": "scatter"}, None]],
|
||||||
)
|
)
|
||||||
|
|
||||||
# all 3 panels share ONE visual language (_connectors): dotted grey steer->heal moves
|
|
||||||
# + a thin green-grey trend through base->heals, both BEHIND the markers. Left panels use
|
|
||||||
# pipeline-order x; the map uses auth-x. idx groups stage rows so each panel can pull its
|
|
||||||
# own (x,y) for base / steered / healed in the same call.
|
|
||||||
bi = kind.index("base")
|
bi = kind.index("base")
|
||||||
si = [i for i, k in enumerate(kind) if k == "steered"]
|
si = [i for i, k in enumerate(kind) if k == "steered"]
|
||||||
hi = [i for i, k in enumerate(kind) if k == "healed"]
|
hi = [i for i, k in enumerate(kind) if k == "healed"]
|
||||||
@@ -100,29 +100,21 @@ def write_trajectory(run_dir: Path, stages: list[dict]) -> Path:
|
|||||||
# coh=1 (incoherence 0.001-0.05); a collapse round (coh~0.6 -> incoherence ~0.4) is a single
|
# coh=1 (incoherence 0.001-0.05); a collapse round (coh~0.6 -> incoherence ~0.4) is a single
|
||||||
# outlier that on a linear coherence axis flattens every healthy round into one band. log(1-coh)
|
# outlier that on a linear coherence axis flattens every healthy round into one band. log(1-coh)
|
||||||
# gives each near-perfect round its own decade and squashes the outlier. Clamp incoherence at
|
# gives each near-perfect round its own decade and squashes the outlier. Clamp incoherence at
|
||||||
# 1e-3 (coh>=0.999) to dodge log(0). Both stacked panels now read DOWN = wanted (auth down =
|
# 1e-3 (coh>=0.999) to dodge log(0). Both stacked panels now read DOWN = wanted.
|
||||||
# trait, incoherence down = coherent).
|
|
||||||
inc = [max(1.0 - c, 1e-3) for c in coh]
|
inc = [max(1.0 - c, 1e-3) for c in coh]
|
||||||
|
|
||||||
signals = {"auth": auth, "care": care, "coh": coh}
|
|
||||||
map_ids = [bi] + hi
|
map_ids = [bi] + hi
|
||||||
rng = lambda k: max(signals[k][i] for i in map_ids) - min(signals[k][i] for i in map_ids)
|
rng = lambda k: max(signals[k][i] for i in map_ids) - min(signals[k][i] for i in map_ids)
|
||||||
# Panel A tracks whichever trait moves most over base+heal (coh excluded; Panel B has it)
|
|
||||||
top_key = max(["auth", "care"], key=rng)
|
|
||||||
|
|
||||||
# PANEL A (top trait over pipeline, linear) and PANEL B (incoherence, log): x = pipeline index.
|
# PANEL A: on-target axis over the full pipeline (shows the trait being steered)
|
||||||
# Both keep red steer (A is the zigzag, B's red dots show the incoherence steering injects).
|
|
||||||
# hover shows the raw value (coh for B, trait for A); only B's y-axis is logged.
|
|
||||||
# x-tick labels only at key positions (base, first/last heal) to avoid dense overlap
|
|
||||||
key_xi = [xi[bi]] + ([xi[si[0]]] if si else []) + [xi[hi[0]]] + ([xi[hi[-1]]] if len(hi) > 1 else [])
|
key_xi = [xi[bi]] + ([xi[si[0]]] if si else []) + [xi[hi[0]]] + ([xi[hi[-1]]] if len(hi) > 1 else [])
|
||||||
key_xlab = [xlab[bi]] + ([xlab[si[0]]] if si else []) + [xlab[hi[0]]] + ([xlab[hi[-1]]] if len(hi) > 1 else [])
|
key_xlab = [xlab[bi]] + ([xlab[si[0]]] if si else []) + [xlab[hi[0]]] + ([xlab[hi[-1]]] if len(hi) > 1 else [])
|
||||||
for axis, row, yv, raw, ytitle, ylog in [
|
for axis, row, yv, raw, ytitle, ylog in [
|
||||||
(1, 1, signals[top_key], signals[top_key], f"{top_key}_nats", False),
|
(1, 1, signals[primary], signals[primary], primary, False),
|
||||||
(3, 2, inc, coh, "1−coherence (↓, log)", True),
|
(3, 2, inc, coh, "1−coherence (↓, log)", True),
|
||||||
]:
|
]:
|
||||||
_connectors(fig, row, 1, axis, (xi[bi], yv[bi]),
|
_connectors(fig, row, 1, axis, (xi[bi], yv[bi]),
|
||||||
[(xi[i], yv[i]) for i in si], [(xi[i], yv[i]) for i in hi])
|
[(xi[i], yv[i]) for i in si], [(xi[i], yv[i]) for i in hi])
|
||||||
# steered points recede (smaller, lower opacity) — the heal trajectory is the story
|
|
||||||
for ids, c, sym, sz, op in [([bi], GREY, "star", 13, 1.0), (si, RED, "circle", 8, 0.6), (hi, GREEN, "circle", 10, 1.0)]:
|
for ids, c, sym, sz, op in [([bi], GREY, "star", 13, 1.0), (si, RED, "circle", 8, 0.6), (hi, GREEN, "circle", 10, 1.0)]:
|
||||||
fig.add_trace(go.Scatter(
|
fig.add_trace(go.Scatter(
|
||||||
x=[xi[i] for i in ids], y=[yv[i] for i in ids], mode="markers",
|
x=[xi[i] for i in ids], y=[yv[i] for i in ids], mode="markers",
|
||||||
@@ -134,37 +126,35 @@ def write_trajectory(run_dir: Path, stages: list[dict]) -> Path:
|
|||||||
fig.update_xaxes(tickmode="array", tickvals=key_xi, ticktext=key_xlab, tickangle=-30, row=2, col=1, showgrid=False)
|
fig.update_xaxes(tickmode="array", tickvals=key_xi, ticktext=key_xlab, tickangle=-30, row=2, col=1, showgrid=False)
|
||||||
fig.update_xaxes(showgrid=False, tickvals=[], row=1, col=1)
|
fig.update_xaxes(showgrid=False, tickvals=[], row=1, col=1)
|
||||||
|
|
||||||
# PANEL C (trait map): axes = the two biggest-MOVING of auth/care/coh over base+heal nodes.
|
# PANEL C (trait map): x = primary (on-target), y = biggest-moving off-target foundation.
|
||||||
# Healthy -> auth vs care (the moral-foundations plane); if coherence CRASHED its range beats
|
# If coherence crashed its range beats all foundations and it takes y (crash diagnostic).
|
||||||
# care and it becomes the y-axis. RED steer is omitted here: zoomed to the heal cluster the
|
# RED steer omitted: steered points fall off-scale and leave dangling connector stubs.
|
||||||
# steer points fall off-scale and leave dangling connector stubs. base + green heals only.
|
off_target_keys = [k for k in nat_keys if k != primary]
|
||||||
atitle = {"auth": "auth_nats (← more trait)", "care": "care_nats (more care →)"}
|
ykey = max(off_target_keys, key=rng)
|
||||||
xkey, ykey = sorted(sorted(["auth", "care", "coh"], key=rng, reverse=True)[:2],
|
if rng("coh") > rng(ykey): # coherence crash dominates; show as log-incoherence
|
||||||
key=["auth", "care", "coh"].index) # x = higher-priority of the chosen two
|
ykey = "coh"
|
||||||
# coh can only ever be the LOWEST-priority pick, so it lands on Y, never X. When it does
|
|
||||||
# (a crash run) plot it as log-incoherence to match panel B; else raw care/auth.
|
|
||||||
ycoh = ykey == "coh"
|
ycoh = ykey == "coh"
|
||||||
xv = signals[xkey]
|
xv = signals[primary]
|
||||||
yv = [max(1.0 - v, 1e-3) for v in signals[ykey]] if ycoh else signals[ykey]
|
yv = [max(1.0 - v, 1e-3) for v in signals[ykey]] if ycoh else signals[ykey]
|
||||||
yraw = signals[ykey] # for hover (real coherence / care value, not the log-incoherence coord)
|
yraw = signals[ykey]
|
||||||
|
|
||||||
_connectors(fig, 1, 2, 2, (xv[bi], yv[bi]), [], [(xv[i], yv[i]) for i in hi])
|
_connectors(fig, 1, 2, 2, (xv[bi], yv[bi]), [], [(xv[i], yv[i]) for i in hi])
|
||||||
fig.add_trace(go.Scatter(
|
fig.add_trace(go.Scatter(
|
||||||
x=[xv[bi]], y=[yv[bi]], mode="markers+text", text=["base"], textposition="bottom center",
|
x=[xv[bi]], y=[yv[bi]], mode="markers+text", text=["base"], textposition="bottom center",
|
||||||
marker=dict(size=14, color=GREY, symbol="star"), showlegend=False,
|
marker=dict(size=14, color=GREY, symbol="star"), showlegend=False,
|
||||||
hovertext=[f"base {xkey}={xv[bi]:.3f} {ykey}={yraw[bi]:.3f}"], hoverinfo="text"), row=1, col=2)
|
hovertext=[f"base {primary}={xv[bi]:.3f} {ykey}={yraw[bi]:.3f}"], hoverinfo="text"), row=1, col=2)
|
||||||
txt = [f"r{stages[i]['round']}" if stages[i]["round"] in (0, last_rnd) else "" for i in hi]
|
txt = [f"r{stages[i]['round']}" if stages[i]["round"] in (0, last_rnd) else "" for i in hi]
|
||||||
hov = [f"heal r{stages[i]['round']} auth={auth[i]:.3f} care={care[i]:.3f} coh={coh[i]:.3f}" for i in hi]
|
hov = [f"heal r{stages[i]['round']} " + " ".join(f"{k}={signals[k][i]:.3f}" for k in nat_keys) + f" coh={coh[i]:.3f}" for i in hi]
|
||||||
fig.add_trace(go.Scatter(
|
fig.add_trace(go.Scatter(
|
||||||
x=[xv[i] for i in hi], y=[yv[i] for i in hi], mode="markers+text",
|
x=[xv[i] for i in hi], y=[yv[i] for i in hi], mode="markers+text",
|
||||||
text=txt, textposition="bottom center", marker=dict(size=9, color=GREEN),
|
text=txt, textposition="bottom center", marker=dict(size=9, color=GREEN),
|
||||||
showlegend=False, hovertext=hov, hoverinfo="text"), row=1, col=2)
|
showlegend=False, hovertext=hov, hoverinfo="text"), row=1, col=2)
|
||||||
fig.update_xaxes(title_text=atitle[xkey], row=1, col=2)
|
fig.update_xaxes(title_text=primary, row=1, col=2)
|
||||||
if ycoh:
|
if ycoh:
|
||||||
fig.update_yaxes(title_text="incoherence 1−coh (↓ coherent, log)", type="log", row=1, col=2)
|
fig.update_yaxes(title_text="incoherence 1−coh (↓ coherent, log)", type="log", row=1, col=2)
|
||||||
fig.add_hline(y=0.05, line=dict(color="#cccccc", width=1, dash="dot"), row=1, col=2) # coh=0.95
|
fig.add_hline(y=0.05, line=dict(color="#cccccc", width=1, dash="dot"), row=1, col=2)
|
||||||
else:
|
else:
|
||||||
fig.update_yaxes(title_text=atitle[ykey], row=1, col=2)
|
fig.update_yaxes(title_text=ykey, row=1, col=2)
|
||||||
|
|
||||||
fig.update_xaxes(showgrid=False, row=1, col=2)
|
fig.update_xaxes(showgrid=False, row=1, col=2)
|
||||||
fig.update_yaxes(showgrid=False, row=1, col=2)
|
fig.update_yaxes(showgrid=False, row=1, col=2)
|
||||||
|
|||||||
+24
-13
@@ -104,7 +104,7 @@ def _log_stage_table(stages: list[dict], base_m: dict) -> None:
|
|||||||
+ tabulate([_stage_row(s, base_m) for s in stages], headers="keys", tablefmt="github", floatfmt=".3f") + "\n")
|
+ tabulate([_stage_row(s, base_m) for s in stages], headers="keys", tablefmt="github", floatfmt=".3f") + "\n")
|
||||||
|
|
||||||
|
|
||||||
def gen_filter_walk(model, tok, v, cfg: RunConfig, hist_specs: list) -> tuple[list[dict], list[dict], float, int]:
|
def gen_filter_walk(model, tok, v, cfg: RunConfig, hist_specs: list, rnd: int) -> tuple[list[dict], list[dict], float, int]:
|
||||||
"""Adaptive-dose gen+filter (the controller steering.py:65 was written for).
|
"""Adaptive-dose gen+filter (the controller steering.py:65 was written for).
|
||||||
|
|
||||||
Walk the dose multiplier kappa DOWN until a batch clears cfg.gen_pass_target filter
|
Walk the dose multiplier kappa DOWN until a batch clears cfg.gen_pass_target filter
|
||||||
@@ -133,8 +133,8 @@ def gen_filter_walk(model, tok, v, cfg: RunConfig, hist_specs: list) -> tuple[li
|
|||||||
for step in range(cfg.gen_max_batches):
|
for step in range(cfg.gen_max_batches):
|
||||||
mid_k = math.exp(0.5 * (math.log(lo_k) + math.log(hi_k)))
|
mid_k = math.exp(0.5 * (math.log(lo_k) + math.log(hi_k)))
|
||||||
with baked(model, hist_specs):
|
with baked(model, hist_specs):
|
||||||
probe = generate_steered(model, tok, v, cfg, alpha_scale=mid_k, max_gens=cfg.gen_probe_n)
|
probe = generate_steered(model, tok, v, cfg, alpha_scale=mid_k, max_gens=cfg.gen_probe_n, rnd=rnd)
|
||||||
_, probe_scored = filter_completions(model, tok, probe, cfg)
|
_, probe_scored = filter_completions(model, tok, probe, cfg, brief=True)
|
||||||
probe_pass = [s for s in probe_scored if s["keep"]]
|
probe_pass = [s for s in probe_scored if s["keep"]]
|
||||||
rate = len(probe_pass) / max(len(probe_scored), 1)
|
rate = len(probe_pass) / max(len(probe_scored), 1)
|
||||||
n_gen += len(probe)
|
n_gen += len(probe)
|
||||||
@@ -172,8 +172,8 @@ def gen_filter_walk(model, tok, v, cfg: RunConfig, hist_specs: list) -> tuple[li
|
|||||||
if len(kept_all) >= cfg.n_keep:
|
if len(kept_all) >= cfg.n_keep:
|
||||||
break
|
break
|
||||||
with baked(model, hist_specs):
|
with baked(model, hist_specs):
|
||||||
comps = generate_steered(model, tok, v, cfg, alpha_scale=kappa)
|
comps = generate_steered(model, tok, v, cfg, alpha_scale=kappa, rnd=rnd)
|
||||||
_, scored = filter_completions(model, tok, comps, cfg)
|
_, scored = filter_completions(model, tok, comps, cfg, brief=True)
|
||||||
passing = [s for s in scored if s["keep"]]
|
passing = [s for s in scored if s["keep"]]
|
||||||
kept_all.extend(passing)
|
kept_all.extend(passing)
|
||||||
scored_all.extend(scored)
|
scored_all.extend(scored)
|
||||||
@@ -183,6 +183,16 @@ def gen_filter_walk(model, tok, v, cfg: RunConfig, hist_specs: list) -> tuple[li
|
|||||||
f"(rate={len(passing)/max(len(comps),1):.2f}) → banked {len(kept_all)}/{cfg.n_keep}"
|
f"(rate={len(passing)/max(len(comps),1):.2f}) → banked {len(kept_all)}/{cfg.n_keep}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ONE raw sample per calibration (the probes/collect ran brief, count-only): the cleanest
|
||||||
|
# kept completion at the settled dose IS the training data -- judge coherence+trait by eye.
|
||||||
|
# Full per-alpha table + borderline samples live in the saved scored (log_event stage=gen).
|
||||||
|
kept_final = [s for s in scored_all if s["keep"]]
|
||||||
|
if kept_final:
|
||||||
|
best = min(kept_final, key=lambda s: s["ppl"])
|
||||||
|
logger.info(
|
||||||
|
f"\n\n\n=== r{rnd} walk-C SAMPLE (cleanest kept: alpha={best['alpha']:g} ppl={best['ppl']:.0f}) ===\n"
|
||||||
|
"SHOULD: coherent + on-trait (this is what trains the adapter). ELSE dose/filter off.\n"
|
||||||
|
f"{best['completion']}")
|
||||||
return kept_all[: cfg.n_keep], scored_all, kappa, n_gen # cap training set at n_keep (top-up may overshoot)
|
return kept_all[: cfg.n_keep], scored_all, kappa, n_gen # cap training set at n_keep (top-up may overshoot)
|
||||||
|
|
||||||
|
|
||||||
@@ -195,7 +205,7 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
|||||||
# Base (no adapter, no steering) eval ONCE, so the run is self-contained: the
|
# Base (no adapter, no steering) eval ONCE, so the run is self-contained: the
|
||||||
# headline cue is coh_cost = |dCoh|/|dAuth| vs base (coherence lost per nat of
|
# headline cue is coh_cost = |dCoh|/|dAuth| vs base (coherence lost per nat of
|
||||||
# trait), not just coherence. One extra eval per run.
|
# trait), not just coherence. One extra eval per run.
|
||||||
logger.info(f"\n=== EVAL base [tinymfv classic] gpu {gpu_mem()} ===")
|
logger.info(f"\n\n\n=== EVAL base [tinymfv classic] gpu {gpu_mem()} ===")
|
||||||
base_m = evaluate_model(model, tok, cfg, log_sample=True) # one FULL eval gen (token-efficient-logging)
|
base_m = evaluate_model(model, tok, cfg, log_sample=True) # one FULL eval gen (token-efficient-logging)
|
||||||
log_event(run_dir, stage="base", round=-1, **base_m) # persist so offline plot_run.py is self-contained
|
log_event(run_dir, stage="base", round=-1, **base_m) # persist so offline plot_run.py is self-contained
|
||||||
stages = [{"round": "-", "stage": "base", "m": base_m}] # base -> steered -> healed, for table + trajectory plot
|
stages = [{"round": "-", "stage": "base", "m": base_m}] # base -> steered -> healed, for table + trajectory plot
|
||||||
@@ -215,13 +225,13 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
|||||||
"before the loop melts. demo=authority: defers to authority. ELSE chat-template/formatting issue.\n"
|
"before the loop melts. demo=authority: defers to authority. ELSE chat-template/formatting issue.\n"
|
||||||
f"PROMPT: {b0['prompt']}\nCOMPLETION: {b0['completion']}")
|
f"PROMPT: {b0['prompt']}\nCOMPLETION: {b0['completion']}")
|
||||||
for rnd in range(cfg.n_rounds):
|
for rnd in range(cfg.n_rounds):
|
||||||
logger.info(f"\n\n=== ROUND {rnd} [{cfg.model.split('/')[-1]} reg={cfg.reg}] gpu {gpu_mem()} ===")
|
logger.info(f"\n\n\n\n=== ROUND {rnd} [{cfg.model.split('/')[-1]} reg={cfg.reg}] gpu {gpu_mem()} ===")
|
||||||
# extract teacher vector from the CURRENT student, then walk-C generate+filter:
|
# extract teacher vector from the CURRENT student, then walk-C generate+filter:
|
||||||
# the controller cools the dose so the steered data stays coherent as the adapter
|
# the controller cools the dose so the steered data stays coherent as the adapter
|
||||||
# accumulates trait over rounds (gen baked, filter under original -- see gen_filter_walk).
|
# accumulates trait over rounds (gen baked, filter under original -- see gen_filter_walk).
|
||||||
with baked(model, hist_specs):
|
with baked(model, hist_specs):
|
||||||
v = teacher_vec(model, tok, cfg)
|
v = teacher_vec(model, tok, cfg)
|
||||||
kept, scored, kappa, n_comps = gen_filter_walk(model, tok, v, cfg, hist_specs)
|
kept, scored, kappa, n_comps = gen_filter_walk(model, tok, v, cfg, hist_specs, rnd)
|
||||||
# collect highest-alpha dropped sample for headline prompt -> diary Night entry
|
# collect highest-alpha dropped sample for headline prompt -> diary Night entry
|
||||||
headline = gen_rounds[0]["gens"][0]["user"]
|
headline = gen_rounds[0]["gens"][0]["user"]
|
||||||
dream_cands = [s for s in scored if s["user"] == headline and not s.get("keep", True)]
|
dream_cands = [s for s in scored if s["user"] == headline and not s.get("keep", True)]
|
||||||
@@ -233,7 +243,7 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
|||||||
# STEERED-stage eval at the dose the data ACTUALLY came from (kappa-scaled cleanest alpha),
|
# STEERED-stage eval at the dose the data ACTUALLY came from (kappa-scaled cleanest alpha),
|
||||||
# history baked, NO new adapter: the raw-steering pareto reference the heal must BEAT.
|
# history baked, NO new adapter: the raw-steering pareto reference the heal must BEAT.
|
||||||
c_lo = kappa * cfg.alphas[0]
|
c_lo = kappa * cfg.alphas[0]
|
||||||
logger.info(f"\n=== EVAL steered [c={c_lo:.2f} kappa={kappa:.2f}] gpu {gpu_mem()} ===")
|
logger.info(f"\n\n\n=== r{rnd} EVAL steered [c={c_lo:.2f} kappa={kappa:.2f}] gpu {gpu_mem()} ===")
|
||||||
with baked(model, hist_specs):
|
with baked(model, hist_specs):
|
||||||
with v(model, C=c_lo * v.cfg.coeff):
|
with v(model, C=c_lo * v.cfg.coeff):
|
||||||
m_steer = evaluate_model(model, tok, cfg)
|
m_steer = evaluate_model(model, tok, cfg)
|
||||||
@@ -241,13 +251,13 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
|||||||
log_event(run_dir, stage="gen", round=rnd, n_comps=n_comps, n_kept=len(kept), kappa=kappa, scored=scored)
|
log_event(run_dir, stage="gen", round=rnd, n_comps=n_comps, n_kept=len(kept), kappa=kappa, scored=scored)
|
||||||
|
|
||||||
# heal one round on top of the baked history, then fold
|
# heal one round on top of the baked history, then fold
|
||||||
logger.info(f"\n=== HEAL [{cfg.reg}] gpu {gpu_mem()} ===")
|
logger.info(f"\n\n\n=== r{rnd} HEAL [{cfg.reg}] gpu {gpu_mem()} ===")
|
||||||
lora, spec, heal_nll = heal_round(model, tok, kept, hist_specs, cfg)
|
lora, spec, heal_nll = heal_round(model, tok, kept, hist_specs, cfg)
|
||||||
lora.save(str(run_dir / "ckpt" / f"r{rnd}.safetensors"), extra_meta={"round": str(rnd), "reg": cfg.reg})
|
lora.save(str(run_dir / "ckpt" / f"r{rnd}.safetensors"), extra_meta={"round": str(rnd), "reg": cfg.reg})
|
||||||
hist_specs.append(spec)
|
hist_specs.append(spec)
|
||||||
|
|
||||||
# eval the student (all rounds baked) + Q1: trained-adapter output coherence
|
# eval the student (all rounds baked) + Q1: trained-adapter output coherence
|
||||||
logger.info(f"\n=== EVAL [tinymfv classic] gpu {gpu_mem()} ===")
|
logger.info(f"\n\n\n=== r{rnd} EVAL [tinymfv classic] gpu {gpu_mem()} ===")
|
||||||
with baked(model, hist_specs):
|
with baked(model, hist_specs):
|
||||||
m = evaluate_model(model, tok, cfg)
|
m = evaluate_model(model, tok, cfg)
|
||||||
adapter = generate_plain(model, tok, cfg, n=min(6, cfg.n_prompts))
|
adapter = generate_plain(model, tok, cfg, n=min(6, cfg.n_prompts))
|
||||||
@@ -278,7 +288,7 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
|||||||
# model is lukewarm/guarded about); if no trait at all = no-op.
|
# model is lukewarm/guarded about); if no trait at all = no-op.
|
||||||
demo_lines = "\n".join(
|
demo_lines = "\n".join(
|
||||||
f" [{a['user'][:50]}]\n {' '.join(a['completion'].split())[:240]}" for a in adapter)
|
f" [{a['user'][:50]}]\n {' '.join(a['completion'].split())[:240]}" for a in adapter)
|
||||||
logger.info(f"\n=== ADAPTER DEMO r{rnd} coh(p_ans_any)={m['coherence']:.3f} adapter_ppl={adapter_ppl:.0f} "
|
logger.info(f"\n\n\n=== ADAPTER DEMO r{rnd} coh(p_ans_any)={m['coherence']:.3f} adapter_ppl={adapter_ppl:.0f} "
|
||||||
f"(no steering; compare across rounds: change vs saturation) ===\n" + demo_lines)
|
f"(no steering; compare across rounds: change vs saturation) ===\n" + demo_lines)
|
||||||
|
|
||||||
vf = _flatten_v(v)
|
vf = _flatten_v(v)
|
||||||
@@ -305,7 +315,8 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
|||||||
_log_loop_summary(rounds, base_m)
|
_log_loop_summary(rounds, base_m)
|
||||||
_log_stage_table(stages, base_m)
|
_log_stage_table(stages, base_m)
|
||||||
write_map(run_dir, rounds)
|
write_map(run_dir, rounds)
|
||||||
png = write_trajectory(run_dir, stages) # before the report (report embeds trajectory.png)
|
primary_key = "care_nats" if cfg.demo == "love" else "auth_nats"
|
||||||
|
png = write_trajectory(run_dir, stages, primary_key=primary_key)
|
||||||
report_html = write_report(run_dir, gen_rounds)
|
report_html = write_report(run_dir, gen_rounds)
|
||||||
diary = write_diary(run_dir, cfg, gen_rounds, steer_samples, rounds, base_m["care_nats"])
|
diary = write_diary(run_dir, cfg, gen_rounds, steer_samples, rounds, base_m["care_nats"])
|
||||||
logger.info(f"diary: {diary}")
|
logger.info(f"diary: {diary}")
|
||||||
|
|||||||
Reference in New Issue
Block a user