mirror of
https://github.com/wassname/steer-heal-love.git
synced 2026-06-27 16:47:16 +08:00
trajectory plot (steer/heal zigzag + trait-coherence pareto) + barrier-vs-nll gradient pressure log
- plot.py write_trajectory: auth zigzag (steer red / heal green) over the pipeline, coherence panel below sharing x, and a trait(x)-vs-coherence(y) pareto map with separate steer/heal trajectories from base. PNG via kaleido + interactive html. Fixed coherence axes to [0.83,1.01] so ~0.001 noise does not fill the panel. - run.py: build a stages list carrying full eval dicts; derive the stage table from it; persist the steered eval to events.jsonl; render trajectory at end of run. - heal.py: log g_bar/g_nll = ||grad barrier|| / ||grad sft|| at each logged step. >>1 = barrier over-tight (undoing trait); 0 = inert. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
+23
-4
@@ -22,6 +22,11 @@ def _kl_per_pos(logp_a, logp_b): # KL(a || b) summed over vocab, per position
|
||||
return (logp_a.exp() * (logp_a - logp_b)).sum(-1)
|
||||
|
||||
|
||||
def _gnorm(grads) -> float:
    """Return the global L2 norm of an iterable of gradient tensors.

    Equivalent to the norm of all gradients concatenated into one flat vector.

    Args:
        grads: iterable of tensors or ``None``. ``None`` entries (what
            ``torch.autograd.grad(..., allow_unused=True)`` returns for
            parameters that are not part of the graph) are skipped.

    Returns:
        The norm as a Python float; ``0.0`` when there is nothing to sum.
    """
    total = None
    for g in grads:
        if g is None:
            continue
        # Upcast BEFORE squaring: in float16 any |g| above ~256 squares past the
        # dtype maximum (65504) and the norm would silently become inf.
        sq = g.detach().float().pow(2).sum()
        # Accumulate as a tensor so the host is synchronised once at the end,
        # not once per parameter; .to() covers grads living on several devices.
        total = sq if total is None else total + sq.to(total.device)
    return 0.0 if total is None else float(total.sqrt())
|
||||
|
||||
|
||||
def _encode(tok, prompt: str, completion: str, max_len: int, device):
|
||||
ids = tok(prompt + completion, return_tensors="pt", truncation=True, max_length=max_len).to(device)
|
||||
prompt_ids = tok(prompt, return_tensors="pt").input_ids[0].to(device)
|
||||
@@ -58,7 +63,12 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
|
||||
logger.info(f"SHOULD: nll (SFT) falls as the adapter learns the trait; kl (barrier div) is 0 for "
|
||||
f"reg=nll/wd and >0 for kl_rev/kl_fwd; gnorm finite (not exploding). loss = nll + lam*relu(kl-tau). "
|
||||
f"If kl stays < tau={cfg.tau} the barrier NEVER fired and {cfg.reg} == nll (no regularisation).")
|
||||
logger.info(" step nll↓ kl loss↓ gnorm")
|
||||
logger.info(
|
||||
"SHOULD (barrier balance): g_bar/g_nll is the gradient-pressure ratio (||∇barrier|| / ||∇sft||). "
|
||||
">>1 -> barrier dominates, it is undoing the trait the SFT installs (over-tight: lower lam or raise tau); "
|
||||
"~1 -> balanced; 0 -> barrier inert (kl<tau, or reg=nll/wd where decay acts in the optimiser, not the loss)."
|
||||
)
|
||||
logger.info(" step nll↓ kl g_nll g_bar g_bar/g_nll loss↓ gnorm")
|
||||
pbar = tqdm(total=n_steps, desc=f"heal[{cfg.reg}]", mininterval=120, maxinterval=120)
|
||||
step = 0
|
||||
nlls = [] # per-step SFT loss; final = mean of last 5, the heal-stage number for the round table
|
||||
@@ -90,15 +100,24 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
|
||||
div = _kl_per_pos(logp[mask], logp0[mask]).mean()
|
||||
else:
|
||||
div = torch.zeros((), device=model.device) # nll, wd
|
||||
loss = sft + cfg.lam * torch.relu(div - cfg.tau)
|
||||
barrier = cfg.lam * torch.relu(div - cfg.tau)
|
||||
loss = sft + barrier
|
||||
nlls.append(sft.item())
|
||||
log_now = step % max(1, n_steps // 20) == 0 or step == n_steps - 1
|
||||
if log_now:
|
||||
# split the gradient pressure: ||∇sft|| vs ||∇barrier|| (retain_graph -> still .backward below).
|
||||
# barrier has no grad path when kl<=tau (relu zeroed), so guard before autograd.grad.
|
||||
g_nll = _gnorm(torch.autograd.grad(sft, params, retain_graph=True, allow_unused=True))
|
||||
barrier_live = barrier.requires_grad and (div - cfg.tau).item() > 0
|
||||
g_bar = _gnorm(torch.autograd.grad(barrier, params, retain_graph=True, allow_unused=True)) if barrier_live else 0.0
|
||||
pressure = g_bar / g_nll if g_nll > 0 else float("nan")
|
||||
loss.backward()
|
||||
gnorm = torch.nn.utils.clip_grad_norm_(params, 1.0)
|
||||
opt.step()
|
||||
opt.zero_grad()
|
||||
if step % max(1, n_steps // 20) == 0 or step == n_steps - 1:
|
||||
if log_now:
|
||||
logger.info(f" {step:4d} {sft.item():5.2f} {div.detach().item():4.2f} "
|
||||
f"{loss.item():5.2f} {float(gnorm):5.1f}")
|
||||
f"{g_nll:5.1f} {g_bar:5.1f} {pressure:11.2f} {loss.item():5.2f} {float(gnorm):5.1f}")
|
||||
pbar.set_postfix(nll=f"{sft.item():.2f}", kl=f"{div.detach().item():.2f}", gn=f"{float(gnorm):.1f}")
|
||||
pbar.update(1)
|
||||
step += 1
|
||||
|
||||
+107
-5
@@ -1,8 +1,21 @@
|
||||
"""Loop map: Care (y) vs Authority (x) trajectory + coherence/cosine panels.
|
||||
"""Loop plots saved to out/{ts}_{slug}/.
|
||||
|
||||
Simplified from wassname/w2schar-mini csm/plot.py _build_scatter (full git-graph
|
||||
port is a later pass). One node per round; hover shows coherence and the
|
||||
round-0 cosine. Saved as out/{ts}_{slug}/map.html.
|
||||
trajectory.html (write_trajectory) is the narrative figure: it tells the
|
||||
steer->heal story the project is about.
|
||||
- left, stacked & x-shared: auth_nats over the pipeline (the up/down/up/down
|
||||
zigzag -- steering pushes the trait DOWN in red, heal lets it relax UP in
|
||||
green) and coherence directly below it (did the move cost coherence?).
|
||||
- right: the trait/coherence pareto MAP. x = auth_nats (the headline trait,
|
||||
left = more trait), y = coherence. The steer trajectory (red) and the heal
|
||||
trajectory (green) are drawn separately from the same base node, so you can
|
||||
read whether heal lands at a BETTER point (same trait, higher coherence) or
|
||||
just walks back toward base. care_nats rides in the hover.
|
||||
|
||||
map.html (write_map) is the older Care-vs-SocialNorms node-per-round view.
|
||||
|
||||
Tufte: one mark per datum, direct labels (r0,r1,..) instead of a legend on the
|
||||
map, no gridded chartjunk, color carries the steer/heal contrast (the one
|
||||
comparison that matters) and nothing else.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
@@ -10,6 +23,96 @@ from pathlib import Path
|
||||
import plotly.graph_objects as go
|
||||
from plotly.subplots import make_subplots
|
||||
|
||||
RED = "#c1272d" # steer: trait injected by the live vector (pre-heal)
|
||||
GREEN = "#1b7837" # heal: trait distilled into weights, vector off
|
||||
GREY = "#555555" # base: pristine round-0 original
|
||||
|
||||
|
||||
def _png(fig, out_html: Path, *, width: int = 1100, height: int = 520, scale: float = 2) -> Path:
    """Save ``fig`` as interactive HTML and as a static PNG next to it.

    Args:
        fig: figure object exposing ``write_html`` and ``write_image``.
        out_html: destination of the HTML file; the PNG uses the same path
            with the suffix swapped to ``.png``.
        width: PNG width passed to ``write_image`` (default 1100).
        height: PNG height passed to ``write_image`` (default 520).
        scale: PNG resolution multiplier passed to ``write_image`` (default 2).

    Returns:
        Path of the PNG that was written.
    """
    # HTML first: it only needs plotly, so it survives even if the static
    # export below fails.
    fig.write_html(out_html, include_plotlyjs="cdn")
    out_png = out_html.with_suffix(".png")
    # static image, for chat/appendix
    fig.write_image(out_png, width=width, height=height, scale=scale)
    return out_png
|
||||
|
||||
|
||||
def write_trajectory(run_dir: Path, stages: list[dict]) -> Path:
    """Render the steer -> heal trajectory figure into ``run_dir``.

    Layout: left column is auth_nats over the pipeline (top) with coherence
    directly below on the same x positions; right column is the trait (x) vs
    coherence (y) pareto map with separate steer and heal trajectories, both
    starting at the base node.

    Args:
        run_dir: directory that receives ``trajectory.html`` and ``trajectory.png``.
        stages: ordered list of ``{round, stage in {base, steered, healed}, m}``
            where ``m`` is an eval dict carrying ``auth_nats``, ``care_nats``
            and ``coherence``.

    Returns:
        Path of the PNG that was written.

    Raises:
        ValueError: if ``stages`` holds no entry with ``stage == "base"``.
    """
    import math

    # The pareto map is anchored on the base node; fail with a clear message
    # instead of leaking a bare StopIteration out of next().
    base = next((s for s in stages if s["stage"] == "base"), None)
    if base is None:
        raise ValueError("write_trajectory needs a stage with stage == 'base' to anchor the pareto map")

    auth = [s["m"]["auth_nats"] for s in stages]
    coh = [s["m"]["coherence"] for s in stages]
    kind = [s["stage"] for s in stages]
    # x of the zigzag = pipeline order; label each tick base / r0·steer / r0·heal / ...
    # Explicit short names: slicing the stage name ("healed"[:5]) gave "heale".
    short = {"steered": "steer", "healed": "heal"}
    xi = list(range(len(stages)))
    xlab = ["base" if k == "base" else f"r{s['round']}·{short.get(k, k)}" for s, k in zip(stages, kind)]
    col = [GREY if k == "base" else RED if k == "steered" else GREEN for k in kind]

    # Coherence axis: keep the fixed [0.83, 1.01] window so ~0.001 of noise is not
    # autoscaled into the whole panel, but widen the floor when a stage actually
    # falls below it -- a clipped point would hide the very collapse worth seeing.
    finite_coh = [c for c in coh if math.isfinite(c)]
    coh_range = [min([0.83] + [c - 0.01 for c in finite_coh]), 1.01]

    fig = make_subplots(
        rows=2, cols=2, column_widths=[0.52, 0.48], row_heights=[0.5, 0.5],
        vertical_spacing=0.10, horizontal_spacing=0.11,
        specs=[[{"type": "scatter"}, {"type": "scatter", "rowspan": 2}],
               [{"type": "scatter"}, None]],
        subplot_titles=("trait: auth_nats over the pipeline (down = trait)",
                        "pareto map: trait (x) vs coherence (y)",
                        "coherence (hold ~1.0)"),
    )

    # -- left top: auth zigzag. one connecting line (pipeline order) + colored markers.
    fig.add_trace(go.Scatter(
        x=xi, y=auth, mode="lines+markers", line=dict(color="#bbbbbb", width=1),
        marker=dict(size=12, color=col), showlegend=False,
        hovertext=[f"{lab}: auth={a:.3f}" for lab, a in zip(xlab, auth)], hoverinfo="text",
    ), row=1, col=1)
    fig.update_yaxes(title_text="auth_nats (↓ trait)", row=1, col=1)

    # -- left bottom: coherence, same x, shared tick labels.
    fig.add_trace(go.Scatter(
        x=xi, y=coh, mode="lines+markers", line=dict(color="#bbbbbb", width=1),
        marker=dict(size=12, color=col), showlegend=False,
        hovertext=[f"{lab}: coh={c:.3f}" for lab, c in zip(xlab, coh)], hoverinfo="text",
    ), row=2, col=1)
    # the honest story is coherence pinned near 1.0. 0.95 = coherent floor.
    fig.update_yaxes(title_text="coherence (→1.0)", range=coh_range, row=2, col=1)
    fig.add_hline(y=0.95, line=dict(color="#cccccc", width=1, dash="dot"), row=2, col=1)
    fig.update_xaxes(tickmode="array", tickvals=xi, ticktext=xlab, tickangle=-40, row=2, col=1)
    # top panel keeps the tick positions but blanks the text; labels live on the panel below
    fig.update_xaxes(tickmode="array", tickvals=xi, ticktext=["" for _ in xi], row=1, col=1)

    # -- right: pareto map. base node, then steer & heal trajectories from it.
    bx, by = base["m"]["auth_nats"], base["m"]["coherence"]
    fig.add_trace(go.Scatter(
        x=[bx], y=[by], mode="markers+text", text=["base"], textposition="bottom center",
        marker=dict(size=14, color=GREY, symbol="star"), showlegend=False,
        hovertext=[f"base  auth={bx:.3f}  coh={by:.3f}"], hoverinfo="text",
    ), row=1, col=2)
    for stage_kind, color, label in [("steered", RED, "steer"), ("healed", GREEN, "heal")]:
        pts = [s for s in stages if s["stage"] == stage_kind]
        # every trajectory starts at the base node so both are read from the same origin
        xs = [bx] + [p["m"]["auth_nats"] for p in pts]
        ys = [by] + [p["m"]["coherence"] for p in pts]
        txt = [""] + [f"r{p['round']}" for p in pts]
        hov = ["base"] + [f"{label} r{p['round']}  auth={p['m']['auth_nats']:.3f}  "
                          f"coh={p['m']['coherence']:.3f}  care={p['m']['care_nats']:.3f}" for p in pts]
        fig.add_trace(go.Scatter(
            x=xs, y=ys, mode="lines+markers+text", text=txt, textposition="top center",
            line=dict(color=color, width=2), marker=dict(size=11, color=color),
            name=label, showlegend=False, hovertext=hov, hoverinfo="text",
        ), row=1, col=2)
    fig.update_xaxes(title_text="auth_nats (← more trait)", row=1, col=2)
    # same coherence range as the line panel, so the two panels are directly comparable
    fig.update_yaxes(title_text="coherence (↑ better)", range=coh_range, row=1, col=2)
    fig.add_hline(y=0.95, line=dict(color="#cccccc", width=1, dash="dot"), row=1, col=2)

    fig.update_layout(
        template="simple_white", height=520, width=1100,
        title_text="steer (red) -> heal (green): does heal keep the trait at higher coherence?",
        showlegend=False,  # red/green stated in the title; map points are directly labelled r0,r1
    )
    out_html = run_dir / "trajectory.html"
    return _png(fig, out_html)
|
||||
|
||||
|
||||
def write_map(run_dir: Path, rounds: list[dict]) -> Path:
|
||||
r = [d["round"] for d in rounds]
|
||||
@@ -18,7 +121,6 @@ def write_map(run_dir: Path, rounds: list[dict]) -> Path:
|
||||
subplot_titles=("trait map: Care vs SocialNorms", "coherence + direction per round"),
|
||||
specs=[[{"type": "scatter"}, {"type": "scatter"}]],
|
||||
)
|
||||
# trajectory across the SocialNorms axis (trait moves it DOWN, Care UP), coloured by round
|
||||
fig.add_trace(go.Scatter(
|
||||
x=[d["socialnorms"] for d in rounds], y=[d["care"] for d in rounds],
|
||||
mode="lines+markers+text", text=[f"r{i}" for i in r], textposition="top center",
|
||||
|
||||
+16
-12
@@ -21,7 +21,7 @@ from steer_heal.eval import evaluate_model
|
||||
from steer_heal.filter import filter_completions, ppl_under_base
|
||||
from steer_heal.heal import heal_round
|
||||
from steer_heal.io import append_result, log_event, make_run_dir
|
||||
from steer_heal.plot import write_map
|
||||
from steer_heal.plot import write_map, write_trajectory
|
||||
from steer_heal.steering import generate_plain, generate_steered, gpu_mem, teacher_vec
|
||||
from steer_heal.ws.bake import baked
|
||||
|
||||
@@ -71,20 +71,21 @@ def _mean_finite(xs, label: str = "ppl") -> float:
|
||||
return sum(xs) / len(xs) if xs else float("nan")
|
||||
|
||||
|
||||
def _stage_row(stg: dict, base_m: dict) -> dict:
    """Build one row of the base -> steered -> healed pareto table.

    Args:
        stg: stage record ``{round, stage, m}`` where ``m`` is the eval dict.
        base_m: eval dict of the base stage, used as the reference point.

    Returns:
        Dict with the table columns. ``dcoh/dauth↓`` is the signed coherence
        change per nat of Authority change relative to base: positive means
        coherence was lost while the trait was gained (both fall), which is the
        cost to keep low. It is nan when Authority did not move (the base row).
    """
    metrics = stg["m"]
    d_auth = metrics["auth_nats"] - base_m["auth_nats"]
    d_coh = metrics["coherence"] - base_m["coherence"]
    # Guard the division: an (almost) unchanged Authority makes the ratio meaningless.
    if abs(d_auth) > 1e-6:
        cost = d_coh / d_auth
    else:
        cost = float("nan")
    # Arrows in the keys render in the table header. dcoh/dauth: lower = better
    # (0 = trait at no coherence cost; >0 = paid coherence; <0 = coherence rose too).
    # coh: hold ~1.0. auth: down = trait.
    row = {"round": stg["round"], "stage": stg["stage"]}
    row["dcoh/dauth↓"] = cost
    row["coh→"] = metrics["coherence"]
    row["auth↓"] = metrics["auth_nats"]
    row["care"] = metrics["care_nats"]
    return row
|
||||
|
||||
|
||||
def _log_stage_table(stage_rows: list[dict]) -> None:
|
||||
def _log_stage_table(stages: list[dict], base_m: dict) -> None:
|
||||
from tabulate import tabulate
|
||||
logger.info(
|
||||
"\nstage pareto (base -> steered -> healed, per round):\n"
|
||||
@@ -92,7 +93,7 @@ def _log_stage_table(stage_rows: list[dict]) -> None:
|
||||
" coh→ = p_any_ans coherence (hold ~1.0) auth↓ = log p[Authority] (DOWN = trait) care = log p[Care] (off-target)\n"
|
||||
" WIN: healed keeps steered's low auth (trait) but recovers coh toward base AND a smaller dcoh/dauth than steered.\n"
|
||||
" UNDO: healed auth springs back to ~base while coh recovers -> heal removed the trait, not just the incoherence.\n"
|
||||
+ tabulate(stage_rows, headers="keys", tablefmt="github", floatfmt=".3f") + "\n")
|
||||
+ tabulate([_stage_row(s, base_m) for s in stages], headers="keys", tablefmt="github", floatfmt=".3f") + "\n")
|
||||
|
||||
|
||||
def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
||||
@@ -104,7 +105,7 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
||||
# trait), not just coherence. One extra eval per run.
|
||||
logger.info(f"\n=== EVAL base [tinymfv classic] gpu {gpu_mem()} ===")
|
||||
base_m = evaluate_model(model, tok, cfg)
|
||||
stage_rows = [_stage_row("-", "base", base_m, base_m)] # pareto table: base -> steered -> healed
|
||||
stages = [{"round": "-", "stage": "base", "m": base_m}] # base -> steered -> healed, for table + trajectory plot
|
||||
for rnd in range(cfg.n_rounds):
|
||||
logger.info(f"\n\n=== ROUND {rnd} [{cfg.model.split('/')[-1]} reg={cfg.reg}] gpu {gpu_mem()} ===")
|
||||
# extract teacher vector + sweep-generate steered data from the CURRENT student
|
||||
@@ -119,6 +120,7 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
||||
logger.info(f"\n=== EVAL steered [c={cfg.alphas[0]}] gpu {gpu_mem()} ===")
|
||||
with v(model, C=c_op):
|
||||
m_steer = evaluate_model(model, tok, cfg)
|
||||
log_event(run_dir, stage="steered_eval", round=rnd, c=cfg.alphas[0], **m_steer) # persist for offline plot
|
||||
# filter under the ORIGINAL (no history, no steering) -- this picks the usable C
|
||||
logger.info(f"\n=== FILTER [{len(comps)} completions] gpu {gpu_mem()} ===")
|
||||
kept, scored = filter_completions(model, tok, comps, cfg)
|
||||
@@ -153,15 +155,17 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
||||
"adapter_ppl": adapter_ppl, "n_comps": len(comps), "n_kept": len(kept),
|
||||
"heal_nll": heal_nll}
|
||||
rounds.append(rec)
|
||||
stage_rows.append(_stage_row(rnd, "steered", m_steer, base_m))
|
||||
stage_rows.append(_stage_row(rnd, "healed", m, base_m))
|
||||
stages.append({"round": rnd, "stage": "steered", "m": m_steer})
|
||||
stages.append({"round": rnd, "stage": "healed", "m": m})
|
||||
log_event(run_dir, stage="round", **rec)
|
||||
logger.info(f"round {rnd}: auth_nats↓={m['auth_nats']:+.2f} care_nats={m['care_nats']:+.2f} "
|
||||
f"coh→={m['coherence']:.3f} cos_v0={cos_v0:+.2f} adapter_ppl={adapter_ppl:.0f}")
|
||||
|
||||
_log_loop_summary(rounds, base_m)
|
||||
_log_stage_table(stage_rows)
|
||||
_log_stage_table(stages, base_m)
|
||||
write_map(run_dir, rounds)
|
||||
png = write_trajectory(run_dir, stages)
|
||||
logger.info(f"trajectory plot: {png} (and {png.with_suffix('.html')})")
|
||||
return rounds[-1]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user