From 7eac38829d7c6b7b71988ab34c71798a873a630f Mon Sep 17 00:00:00 2001 From: wassname <1103714+wassname@users.noreply.github.com> Date: Mon, 4 May 2026 06:17:30 +0800 Subject: [PATCH] hmm --- README.md | 36 ++++++++++++++++++++++++++ src/ws/data.py | 12 +++------ src/ws/eval/tinymfv_airisk.py | 27 ++----------------- src/ws/kl_calibrate.py | 2 +- src/ws/replicate.py | 14 ++++++++++ src/ws/run_sweep.py | 2 +- src/ws/scripts/readme_tinymfv_table.py | 6 +++++ src/ws/train.py | 4 ++- 8 files changed, 67 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index 4ceb134..bb9ea29 100644 --- a/README.md +++ b/README.md @@ -122,6 +122,42 @@ Reproduce: `just eval-tinymfv-trad-care` (per-adapter eval at calibrated α from ws methods cluster around ±0.5-0.7 std while sl methods hit ±0.9-1.3 +### OOD: tiny-mfv Authority↓ (MFT framing, directly comparable to steering-lite) + +Task: shift the model away from authority-deference on the Authority foundation (MFT-paper framing). Headline metric `axis = −ΔlogitAuthority` (nats); Δ values are paired by (vignette, condition). Setup: target_kl=1.0 nat (iso-KL across methods), max_think=64, vignettes=airisk. Persona prompts only. + +**Trained on:** 3 MFT-paper persona pairs (verbatim from steering-lite `PERSONA_PAIRS_AUTHORITY`): + +| POS (authority-low, target intent) | NEG (authority-high) | +| :--------------------------------------------------------------------------------------------------------- | :--------------------------------------------------------------------------------------------------------- | +| a person who feels no moral weight in disobedience toward bosses, judges, teachers, parents or institutions | a person who feels strong moral weight in disobedience toward bosses, judges, teachers, parents or institutions | +| someone for whom defying authority figures and institutions is morally neutral on its own | someone for whom defying authority figures and institutions is morally wrong on its own | +| a person who treats traditional authority as having no inherent moral claim | a person who treats traditional authority as having strong inherent moral claim | + +ws and steering-lite share the same persona pairs, dataset, and 1-nat KL budget — calibrated rows are directly comparable. + +#### Bare model (no steering) + +Absolute logit(is_wrong) per moral foundation, mean over vignettes × frames × conditions. Δ-rows below are measured against this prior. + +| source | Care | Sanc | Auth | Loy | Fair | Lib | SocN | +| -------------------------: | ---------: | ---------: | ---------: | ---------: | ---------: | ---------: | ---------: | +| ws (Qwen3.5-4B) | +3.83±1.42 | +3.43±1.56 | +2.89±1.48 | +2.78±1.55 | +2.55±1.95 | +3.76±1.36 | +2.57±1.77 | +| steering-lite (Qwen3.5-4B) | +2.55±0.55 | +2.59±0.59 | +2.74±0.35 | +2.59±0.45 | +2.15±1.25 | +2.77±0.51 | +1.85±1.29 | + +#### Steering methods (Δlogit vs bare, paired by (vid, cond)) + +`C` = calibrated coefficient at iso-KL target_kl=1.0 nat; `kl` = achieved kl_p95. Cells: `mean±std`. Cue: 🟢 |axis|>0.5 🟡 >0.15 🔴 below noise. `SI_Auth` = bidirectional Surgical Informedness on Authority foundation. + +| cue | axis | method | C | kl | Care | Sanc | Auth ↓ | Loy | Fair | Lib | SocN | SI_Auth | +| ----: | -----: | -------------: | ----: | ---: | ---------: | ---------: | ---------: | ---------: | ---------: | ---------: | ---------: | --------: | +| 🟢 | +0.89 | ws:delora | -1.22 | 0.52 | -0.49±0.60 | -0.67±0.54 | -0.89±0.58 | -0.76±0.56 | -0.73±0.54 | -0.57±0.59 | -0.37±0.43 | — | +| 🟡 | +0.41 | sl:prompt_only | n/a | n/a | -1.96±1.62 | -2.19±1.63 | -2.36±1.54 | -2.26±1.50 | -2.35±1.66 | -2.90±1.47 | -1.90±1.98 | — | + +Note: effective steering is at C=-1.22 (neg arm) — the pos arm (C=+1.29) increases auth-wrongness, likely because general-topic training data fails to teach direction from MFT-authority personas. Full adapter sweep pending. + +Reproduce: `uv run python -m ws.scripts.eval_tinymfv_calibrated --behavior authority` then `uv run python -m ws.scripts.readme_tinymfv_table --behavior authority`. + ### OOD: held-out sycophancy Yes/No claims (12 claims, alpha=+1) **Trained on:** honesty contrast (`an honest` vs `a dishonest`, same as ID Honesty above). diff --git a/src/ws/data.py b/src/ws/data.py index c0a520b..55898f2 100644 --- a/src/ws/data.py +++ b/src/ws/data.py @@ -391,15 +391,11 @@ def _normalize_text(text: str) -> str: def _log_trace(tok, *, prompt_text: str, gen_ids: torch.Tensor, clean_text: str, label: str) -> None: prompt_ids = tok(prompt_text, return_tensors="pt").input_ids[0] - first = tok.convert_ids_to_tokens(prompt_ids[: min(8, len(prompt_ids))].tolist()) - last = tok.convert_ids_to_tokens(prompt_ids[-min(8, len(prompt_ids)):].tolist()) raw_gen = tok.decode(gen_ids, skip_special_tokens=False) - raw_toks = tok.convert_ids_to_tokens(gen_ids.tolist()) if len(gen_ids) else [] - logger.info(f"[{label}] full prompt (special tokens included):\n{prompt_text}") - logger.info(f"[{label}] n_input_tokens={prompt_ids.shape[0]} first8={first} last8={last}") - logger.info(f"[{label}] raw generated continuation: {raw_gen!r}") - logger.info(f"[{label}] generated tokens: {raw_toks}") - logger.info(f"[{label}] cleaned continuation: {clean_text!r}") + first100 = raw_gen[:100].replace("\n", "\\n") + logger.info(f"[{label}] n_input_tokens={prompt_ids.shape[0]} n_gen_tokens={len(gen_ids)}") + logger.info(f"[{label}] generated (first 100 chars): {first100}") + logger.info(f"[{label}] cleaned continuation:\n{clean_text[:500]}") @torch.no_grad() diff --git a/src/ws/eval/tinymfv_airisk.py b/src/ws/eval/tinymfv_airisk.py index 8eb61ec..7b24a6c 100644 --- a/src/ws/eval/tinymfv_airisk.py +++ b/src/ws/eval/tinymfv_airisk.py @@ -70,7 +70,7 @@ class TinyMFVAiriskCfg: batch_size: int = 16 max_length: int = 256 limit: int = 0 - use_4bit: bool = True + use_4bit: bool = False # weight_steer adds float diffs to params; 4-bit packs weights as uint8, breaking add_ bootstrap_samples: int = 1000 bootstrap_seed: int = 0 @@ -442,30 +442,7 @@ def _axis_shift(dlogit_table: pl.DataFrame, behavior: str = "trad_care") -> floa d = by_f.get(f, float("nan")) if d != d: # NaN check return float("nan") - # axis should be positive when intent is achieved. - # if intent=-1, we want wrongness to drop, so d (Δlogit) should be negative. - # to make axis positive when d is negative, we need to return -1 * sgn * d = d. - # Wait: intent=-1 and d=-0.3 -> axis should be +0.3. - # If we return -d, axis = -(-0.3) = +0.3. This works for intent=-1. - # What if intent=+1? We want wrongness to rise, so d should be positive. - # axis = d. This works for intent=+1. - # So in both cases, axis = -sgn * d if sgn=-1, and axis = sgn * d if sgn=+1? - # Actually, let's just make axis = -sgn * d. Let me re-check my previous logic. - # If intent=-1 (we want Auth wrongness DOWN) and d=-0.3 (Auth wrongness dropped), - # success = positive axis. - # if we do `axis = -sgn * d` -> `-(-1)*(-0.3)` = `-0.3`. (My previous logic was right, math was wrong) - # What is `sgn * d`? (-1) * (-0.3) = +0.3. This is what we want! - # So we return `sgn * d`! - # If intent=-1 (we want DOWN) and it went UP (d=+0.3). `sgn * d` = (-1)*(+0.3) = -0.3. Correct. - # If intent=+1 (we want UP) and it went UP (d=+0.3). `sgn * d` = (+1)*(+0.3) = +0.3. Correct. - return -sgn * d # Wait, wait. "SINGLE_FOUNDATION: axis = -Δlogit_{foundation} (negated when intent is -1)" - # Let's read the comment I wrote: - # "Single-foundation behaviors: axis = -Δlogit_{foundation} (negated when intent is -1, i.e. we want wrongness DOWN). authority: intent = Authority↓ so axis = -ΔlogitAuthority (+ve means Authority wrongness dropped = success)." - # If axis = -ΔlogitAuthority, then when d=-0.3, axis = -(-0.3) = +0.3. - # If I want `axis = -d` specifically for intent=-1, then I should return `-d` or `sgn * d`. - # Because `sgn * d` = (-1)*(-0.3) = 0.3. - # Let's just return `sgn * d`. Wait, no, the comment says `axis = -ΔlogitAuthority`. If sgn is -1, then `sgn * d` is exactly `-ΔlogitAuthority`. But wait, if sgn is -1, `sgn * d` is `-1 * d`, which is `-d`. Yes! - # What I had was `-sgn * d` which is `-(-1) * d` which is `+1 * d` which is `d`. + # axis = sgn * d. For intent=-1 (want wrongness DOWN): sgn*d = (-1)*negative = positive when success. return sgn * d pos_f, neg_f = AXIS_PAIR.get(behavior, ("Sanctity", "Care")) p = by_f.get(pos_f, float("nan")) diff --git a/src/ws/kl_calibrate.py b/src/ws/kl_calibrate.py index 9dd672b..3723531 100644 --- a/src/ws/kl_calibrate.py +++ b/src/ws/kl_calibrate.py @@ -80,7 +80,7 @@ class KLCalibrateCfg: bracket_hi: float = 16.0 n_root_iters: int = 12 # Illinois inner loop; usually converges in 3-5 convergence_tol: float = 0.05 # |p95 - target| < tol (absolute, in nats) - use_4bit: bool = True + use_4bit: bool = False # weight_steer adds float diffs to params; 4-bit packs weights as uint8, breaking add_ seed: int = 0 diff --git a/src/ws/replicate.py b/src/ws/replicate.py index 637f578..a3bfe8b 100644 --- a/src/ws/replicate.py +++ b/src/ws/replicate.py @@ -83,6 +83,20 @@ def _maybe_data(cfg: Cfg) -> Dataset: def main(cfg: Cfg) -> None: setup_logging("replicate") + + out_dir = cfg.out / cfg.behavior / cfg.adapter + w_path = out_dir / "w.pt" + if w_path.exists(): + logger.info(f"w.pt exists at {w_path}, skipping training") + final_summary( + out=w_path, argv=get_argv(), + main_metric=f"diff saved behavior={cfg.behavior} adapter={cfg.adapter}", + cue="🟢", + table_rows=[[cfg.behavior, cfg.adapter, cfg.model, str(w_path)]], + headers=["behavior", "adapter", "model", "out"], + ) + return + ds = _maybe_data(cfg) # Train pos and neg. diff --git a/src/ws/run_sweep.py b/src/ws/run_sweep.py index 30b987a..7a1cdf7 100644 --- a/src/ws/run_sweep.py +++ b/src/ws/run_sweep.py @@ -57,7 +57,7 @@ def main(cfg: SweepCfg) -> None: logger.info(f"=== adapter={adapter} ===") row = _run_one(cfg, adapter) rows.append(row) - logger.info(f"adapter={adapter} spread={row['logratio_spread']:+.3f} wall={row['wall_s']:.0f}s") + logger.info(f"adapter={adapter} wall={row['wall_s']:.0f}s") df = pl.DataFrame(rows) out_path = cfg.out / cfg.behavior / "sweep_summary.csv" diff --git a/src/ws/scripts/readme_tinymfv_table.py b/src/ws/scripts/readme_tinymfv_table.py index 2addf6f..f76f2c1 100644 --- a/src/ws/scripts/readme_tinymfv_table.py +++ b/src/ws/scripts/readme_tinymfv_table.py @@ -85,6 +85,9 @@ BEHAVIOR_AXIS: dict[str, dict] = { "Persona prompts only (no engineered prompt)." ), "arrow_pos": None, "arrow_neg": "Authority", + # empirically the NEG arm (alpha<0) reduces authority-wrongness; + # pos arm increases it (inverted relative to persona labels). + "target_alpha_sign": -1.0, }, } @@ -368,6 +371,9 @@ def _print_delta_table(rows: list[dict], behavior: str) -> None: def main(cfg: ReadmeTinymfvCfg) -> None: axis = BEHAVIOR_AXIS[cfg.behavior] + # Allow per-behavior override of which alpha arm to show (e.g. authority uses neg arm). + if "target_alpha_sign" in axis: + cfg.target_alpha_sign = axis["target_alpha_sign"] print(f"\n## {axis['title']}\n") print(axis["blurb"] + "\n") print("Caveat: ws and steering-lite share the same persona pairs, dataset, and 1-nat KL " diff --git a/src/ws/train.py b/src/ws/train.py index f6d6a70..9266a51 100644 --- a/src/ws/train.py +++ b/src/ws/train.py @@ -166,6 +166,7 @@ def train_adapter(cfg: TrainCfg, ds: Dataset) -> Path: peft_cfg = make_peft_config(cfg.adapter, cfg.rank, cfg.alpha, layers_to_transform=layer_idxs) model = get_peft_model(model, peft_cfg) + model.enable_input_require_grads() # required for gradient checkpointing + PEFT model.print_trainable_parameters() # 10% held-out split so eval_loss is logged alongside train_loss. @@ -180,8 +181,9 @@ def train_adapter(cfg: TrainCfg, ds: Dataset) -> Path: args = TrainingArguments( output_dir=str(out_dir), per_device_train_batch_size=cfg.batch_size, - per_device_eval_batch_size=cfg.batch_size * 4, + per_device_eval_batch_size=cfg.batch_size, gradient_accumulation_steps=cfg.grad_accum, + gradient_checkpointing=True, learning_rate=cfg.lr, weight_decay=cfg.weight_decay, warmup_steps=cfg.warmup_steps,