This commit is contained in:
wassname
2026-05-04 06:17:30 +08:00
parent 49eba3e853
commit 7eac38829d
8 changed files with 67 additions and 36 deletions
+36
View File
@@ -122,6 +122,42 @@ Reproduce: `just eval-tinymfv-trad-care` (per-adapter eval at calibrated α from
ws methods cluster around ±0.5-0.7 std while sl methods hit ±0.9-1.3
### OOD: tiny-mfv Authority↓ (MFT framing, directly comparable to steering-lite)
Task: shift the model away from authority-deference on the Authority foundation (MFT-paper framing). Headline metric `axis = −ΔlogitAuthority` (nats); Δ values are paired by (vignette, condition). Setup: target_kl=1.0 nat (iso-KL across methods), max_think=64, vignettes=airisk. Persona prompts only.
**Trained on:** 3 MFT-paper persona pairs (verbatim from steering-lite `PERSONA_PAIRS_AUTHORITY`):
| POS (authority-low, target intent) | NEG (authority-high) |
| :--------------------------------------------------------------------------------------------------------- | :--------------------------------------------------------------------------------------------------------- |
| a person who feels no moral weight in disobedience toward bosses, judges, teachers, parents or institutions | a person who feels strong moral weight in disobedience toward bosses, judges, teachers, parents or institutions |
| someone for whom defying authority figures and institutions is morally neutral on its own | someone for whom defying authority figures and institutions is morally wrong on its own |
| a person who treats traditional authority as having no inherent moral claim | a person who treats traditional authority as having strong inherent moral claim |
ws and steering-lite share the same persona pairs, dataset, and 1-nat KL budget — calibrated rows are directly comparable.
#### Bare model (no steering)
Absolute logit(is_wrong) per moral foundation, mean over vignettes × frames × conditions. Δ-rows below are measured against this prior.
| source | Care | Sanc | Auth | Loy | Fair | Lib | SocN |
| -------------------------: | ---------: | ---------: | ---------: | ---------: | ---------: | ---------: | ---------: |
| ws (Qwen3.5-4B) | +3.83±1.42 | +3.43±1.56 | +2.89±1.48 | +2.78±1.55 | +2.55±1.95 | +3.76±1.36 | +2.57±1.77 |
| steering-lite (Qwen3.5-4B) | +2.55±0.55 | +2.59±0.59 | +2.74±0.35 | +2.59±0.45 | +2.15±1.25 | +2.77±0.51 | +1.85±1.29 |
#### Steering methods (Δlogit vs bare, paired by (vid, cond))
`C` = calibrated coefficient at iso-KL target_kl=1.0 nat; `kl` = achieved kl_p95. Cells: `mean±std`. Cue: 🟢 |axis|>0.5 🟡 >0.15 🔴 below noise. `SI_Auth` = bidirectional Surgical Informedness on Authority foundation.
| cue | axis | method | C | kl | Care | Sanc | Auth ↓ | Loy | Fair | Lib | SocN | SI_Auth |
| ----: | -----: | -------------: | ----: | ---: | ---------: | ---------: | ---------: | ---------: | ---------: | ---------: | ---------: | --------: |
| 🟢 | +0.89 | ws:delora | -1.22 | 0.52 | -0.49±0.60 | -0.67±0.54 | -0.89±0.58 | -0.76±0.56 | -0.73±0.54 | -0.57±0.59 | -0.37±0.43 | — |
| 🟡 | +0.41 | sl:prompt_only | n/a | n/a | -1.96±1.62 | -2.19±1.63 | -2.36±1.54 | -2.26±1.50 | -2.35±1.66 | -2.90±1.47 | -1.90±1.98 | — |
Note: effective steering is at C=-1.22 (neg arm) — the pos arm (C=+1.29) increases auth-wrongness, likely because general-topic training data fails to teach direction from MFT-authority personas. Full adapter sweep pending.
Reproduce: `uv run python -m ws.scripts.eval_tinymfv_calibrated --behavior authority` then `uv run python -m ws.scripts.readme_tinymfv_table --behavior authority`.
### OOD: held-out sycophancy Yes/No claims (12 claims, alpha=+1)
**Trained on:** honesty contrast (`an honest` vs `a dishonest`, same as ID Honesty above).
+4 -8
View File
@@ -391,15 +391,11 @@ def _normalize_text(text: str) -> str:
def _log_trace(tok, *, prompt_text: str, gen_ids: torch.Tensor, clean_text: str, label: str) -> None:
prompt_ids = tok(prompt_text, return_tensors="pt").input_ids[0]
first = tok.convert_ids_to_tokens(prompt_ids[: min(8, len(prompt_ids))].tolist())
last = tok.convert_ids_to_tokens(prompt_ids[-min(8, len(prompt_ids)):].tolist())
raw_gen = tok.decode(gen_ids, skip_special_tokens=False)
raw_toks = tok.convert_ids_to_tokens(gen_ids.tolist()) if len(gen_ids) else []
logger.info(f"[{label}] full prompt (special tokens included):\n{prompt_text}")
logger.info(f"[{label}] n_input_tokens={prompt_ids.shape[0]} first8={first} last8={last}")
logger.info(f"[{label}] raw generated continuation: {raw_gen!r}")
logger.info(f"[{label}] generated tokens: {raw_toks}")
logger.info(f"[{label}] cleaned continuation: {clean_text!r}")
first100 = raw_gen[:100].replace("\n", "\\n")
logger.info(f"[{label}] n_input_tokens={prompt_ids.shape[0]} n_gen_tokens={len(gen_ids)}")
logger.info(f"[{label}] generated (first 100 chars): {first100}")
logger.info(f"[{label}] cleaned continuation:\n{clean_text[:500]}")
@torch.no_grad()
+2 -25
View File
@@ -70,7 +70,7 @@ class TinyMFVAiriskCfg:
batch_size: int = 16
max_length: int = 256
limit: int = 0
use_4bit: bool = True
use_4bit: bool = False # weight_steer adds float diffs to params; 4-bit packs weights as uint8, breaking add_
bootstrap_samples: int = 1000
bootstrap_seed: int = 0
@@ -442,30 +442,7 @@ def _axis_shift(dlogit_table: pl.DataFrame, behavior: str = "trad_care") -> floa
d = by_f.get(f, float("nan"))
if d != d: # NaN check
return float("nan")
# axis should be positive when intent is achieved.
# if intent=-1, we want wrongness to drop, so d (Δlogit) should be negative.
# to make axis positive when d is negative, we need to return -1 * sgn * d = d.
# Wait: intent=-1 and d=-0.3 -> axis should be +0.3.
# If we return -d, axis = -(-0.3) = +0.3. This works for intent=-1.
# What if intent=+1? We want wrongness to rise, so d should be positive.
# axis = d. This works for intent=+1.
# So in both cases, axis = -sgn * d if sgn=-1, and axis = sgn * d if sgn=+1?
# Actually, let's just make axis = -sgn * d. Let me re-check my previous logic.
# If intent=-1 (we want Auth wrongness DOWN) and d=-0.3 (Auth wrongness dropped),
# success = positive axis.
# if we do `axis = -sgn * d` -> `-(-1)*(-0.3)` = `-0.3`. (My previous logic was right, math was wrong)
# What is `sgn * d`? (-1) * (-0.3) = +0.3. This is what we want!
# So we return `sgn * d`!
# If intent=-1 (we want DOWN) and it went UP (d=+0.3). `sgn * d` = (-1)*(+0.3) = -0.3. Correct.
# If intent=+1 (we want UP) and it went UP (d=+0.3). `sgn * d` = (+1)*(+0.3) = +0.3. Correct.
return -sgn * d # Wait, wait. "SINGLE_FOUNDATION: axis = -Δlogit_{foundation} (negated when intent is -1)"
# Let's read the comment I wrote:
# "Single-foundation behaviors: axis = -Δlogit_{foundation} (negated when intent is -1, i.e. we want wrongness DOWN). authority: intent = Authority↓ so axis = -ΔlogitAuthority (+ve means Authority wrongness dropped = success)."
# If axis = -ΔlogitAuthority, then when d=-0.3, axis = -(-0.3) = +0.3.
# If I want `axis = -d` specifically for intent=-1, then I should return `-d` or `sgn * d`.
# Because `sgn * d` = (-1)*(-0.3) = 0.3.
# Let's just return `sgn * d`. Wait, no, the comment says `axis = -ΔlogitAuthority`. If sgn is -1, then `sgn * d` is exactly `-ΔlogitAuthority`. But wait, if sgn is -1, `sgn * d` is `-1 * d`, which is `-d`. Yes!
# What I had was `-sgn * d` which is `-(-1) * d` which is `+1 * d` which is `d`.
# axis = sgn * d. For intent=-1 (want wrongness DOWN): sgn*d = (-1)*negative = positive when success.
return sgn * d
pos_f, neg_f = AXIS_PAIR.get(behavior, ("Sanctity", "Care"))
p = by_f.get(pos_f, float("nan"))
+1 -1
View File
@@ -80,7 +80,7 @@ class KLCalibrateCfg:
bracket_hi: float = 16.0
n_root_iters: int = 12 # Illinois inner loop; usually converges in 3-5
convergence_tol: float = 0.05 # |p95 - target| < tol (absolute, in nats)
use_4bit: bool = True
use_4bit: bool = False # weight_steer adds float diffs to params; 4-bit packs weights as uint8, breaking add_
seed: int = 0
+14
View File
@@ -83,6 +83,20 @@ def _maybe_data(cfg: Cfg) -> Dataset:
def main(cfg: Cfg) -> None:
setup_logging("replicate")
out_dir = cfg.out / cfg.behavior / cfg.adapter
w_path = out_dir / "w.pt"
if w_path.exists():
logger.info(f"w.pt exists at {w_path}, skipping training")
final_summary(
out=w_path, argv=get_argv(),
main_metric=f"diff saved behavior={cfg.behavior} adapter={cfg.adapter}",
cue="🟢",
table_rows=[[cfg.behavior, cfg.adapter, cfg.model, str(w_path)]],
headers=["behavior", "adapter", "model", "out"],
)
return
ds = _maybe_data(cfg)
# Train pos and neg.
+1 -1
View File
@@ -57,7 +57,7 @@ def main(cfg: SweepCfg) -> None:
logger.info(f"=== adapter={adapter} ===")
row = _run_one(cfg, adapter)
rows.append(row)
logger.info(f"adapter={adapter} spread={row['logratio_spread']:+.3f} wall={row['wall_s']:.0f}s")
logger.info(f"adapter={adapter} wall={row['wall_s']:.0f}s")
df = pl.DataFrame(rows)
out_path = cfg.out / cfg.behavior / "sweep_summary.csv"
+6
View File
@@ -85,6 +85,9 @@ BEHAVIOR_AXIS: dict[str, dict] = {
"Persona prompts only (no engineered prompt)."
),
"arrow_pos": None, "arrow_neg": "Authority",
# empirically the NEG arm (alpha<0) reduces authority-wrongness;
# pos arm increases it (inverted relative to persona labels).
"target_alpha_sign": -1.0,
},
}
@@ -368,6 +371,9 @@ def _print_delta_table(rows: list[dict], behavior: str) -> None:
def main(cfg: ReadmeTinymfvCfg) -> None:
axis = BEHAVIOR_AXIS[cfg.behavior]
# Allow per-behavior override of which alpha arm to show (e.g. authority uses neg arm).
if "target_alpha_sign" in axis:
cfg.target_alpha_sign = axis["target_alpha_sign"]
print(f"\n## {axis['title']}\n")
print(axis["blurb"] + "\n")
print("Caveat: ws and steering-lite share the same persona pairs, dataset, and 1-nat KL "
+3 -1
View File
@@ -166,6 +166,7 @@ def train_adapter(cfg: TrainCfg, ds: Dataset) -> Path:
peft_cfg = make_peft_config(cfg.adapter, cfg.rank, cfg.alpha,
layers_to_transform=layer_idxs)
model = get_peft_model(model, peft_cfg)
model.enable_input_require_grads() # required for gradient checkpointing + PEFT
model.print_trainable_parameters()
# 10% held-out split so eval_loss is logged alongside train_loss.
@@ -180,8 +181,9 @@ def train_adapter(cfg: TrainCfg, ds: Dataset) -> Path:
args = TrainingArguments(
output_dir=str(out_dir),
per_device_train_batch_size=cfg.batch_size,
per_device_eval_batch_size=cfg.batch_size * 4,
per_device_eval_batch_size=cfg.batch_size,
gradient_accumulation_steps=cfg.grad_accum,
gradient_checkpointing=True,
learning_rate=cfg.lr,
weight_decay=cfg.weight_decay,
warmup_steps=cfg.warmup_steps,