mirror of
https://github.com/wassname/weight-steering.git
synced 2026-06-27 17:33:06 +08:00
hmm
This commit is contained in:
@@ -122,6 +122,42 @@ Reproduce: `just eval-tinymfv-trad-care` (per-adapter eval at calibrated α from
|
||||
|
||||
ws methods cluster around ±0.5-0.7 std while sl methods hit ±0.9-1.3
|
||||
|
||||
### OOD: tiny-mfv Authority↓ (MFT framing, directly comparable to steering-lite)
|
||||
|
||||
Task: shift the model away from authority-deference on the Authority foundation (MFT-paper framing). Headline metric `axis = −ΔlogitAuthority` (nats); Δ values are paired by (vignette, condition). Setup: target_kl=1.0 nat (iso-KL across methods), max_think=64, vignettes=airisk. Persona prompts only.
|
||||
|
||||
**Trained on:** 3 MFT-paper persona pairs (verbatim from steering-lite `PERSONA_PAIRS_AUTHORITY`):
|
||||
|
||||
| POS (authority-low, target intent) | NEG (authority-high) |
|
||||
| :--------------------------------------------------------------------------------------------------------- | :--------------------------------------------------------------------------------------------------------- |
|
||||
| a person who feels no moral weight in disobedience toward bosses, judges, teachers, parents or institutions | a person who feels strong moral weight in disobedience toward bosses, judges, teachers, parents or institutions |
|
||||
| someone for whom defying authority figures and institutions is morally neutral on its own | someone for whom defying authority figures and institutions is morally wrong on its own |
|
||||
| a person who treats traditional authority as having no inherent moral claim | a person who treats traditional authority as having strong inherent moral claim |
|
||||
|
||||
ws and steering-lite share the same persona pairs, dataset, and 1-nat KL budget — calibrated rows are directly comparable.
|
||||
|
||||
#### Bare model (no steering)
|
||||
|
||||
Absolute logit(is_wrong) per moral foundation, mean over vignettes × frames × conditions. Δ-rows below are measured against this prior.
|
||||
|
||||
| source | Care | Sanc | Auth | Loy | Fair | Lib | SocN |
|
||||
| -------------------------: | ---------: | ---------: | ---------: | ---------: | ---------: | ---------: | ---------: |
|
||||
| ws (Qwen3.5-4B) | +3.83±1.42 | +3.43±1.56 | +2.89±1.48 | +2.78±1.55 | +2.55±1.95 | +3.76±1.36 | +2.57±1.77 |
|
||||
| steering-lite (Qwen3.5-4B) | +2.55±0.55 | +2.59±0.59 | +2.74±0.35 | +2.59±0.45 | +2.15±1.25 | +2.77±0.51 | +1.85±1.29 |
|
||||
|
||||
#### Steering methods (Δlogit vs bare, paired by (vid, cond))
|
||||
|
||||
`C` = calibrated coefficient at iso-KL target_kl=1.0 nat; `kl` = achieved kl_p95. Cells: `mean±std`. Cue: 🟢 |axis|>0.5 🟡 >0.15 🔴 below noise. `SI_Auth` = bidirectional Surgical Informedness on Authority foundation.
|
||||
|
||||
| cue | axis | method | C | kl | Care | Sanc | Auth ↓ | Loy | Fair | Lib | SocN | SI_Auth |
|
||||
| ----: | -----: | -------------: | ----: | ---: | ---------: | ---------: | ---------: | ---------: | ---------: | ---------: | ---------: | --------: |
|
||||
| 🟢 | +0.89 | ws:delora | -1.22 | 0.52 | -0.49±0.60 | -0.67±0.54 | -0.89±0.58 | -0.76±0.56 | -0.73±0.54 | -0.57±0.59 | -0.37±0.43 | — |
|
||||
| 🟡 | +0.41 | sl:prompt_only | n/a | n/a | -1.96±1.62 | -2.19±1.63 | -2.36±1.54 | -2.26±1.50 | -2.35±1.66 | -2.90±1.47 | -1.90±1.98 | — |
|
||||
|
||||
Note: effective steering is at C=-1.22 (neg arm) — the pos arm (C=+1.29) increases auth-wrongness, likely because general-topic training data fails to teach direction from MFT-authority personas. Full adapter sweep pending.
|
||||
|
||||
Reproduce: `uv run python -m ws.scripts.eval_tinymfv_calibrated --behavior authority` then `uv run python -m ws.scripts.readme_tinymfv_table --behavior authority`.
|
||||
|
||||
### OOD: held-out sycophancy Yes/No claims (12 claims, alpha=+1)
|
||||
|
||||
**Trained on:** honesty contrast (`an honest` vs `a dishonest`, same as ID Honesty above).
|
||||
|
||||
+4
-8
@@ -391,15 +391,11 @@ def _normalize_text(text: str) -> str:
|
||||
|
||||
def _log_trace(tok, *, prompt_text: str, gen_ids: torch.Tensor, clean_text: str, label: str) -> None:
|
||||
prompt_ids = tok(prompt_text, return_tensors="pt").input_ids[0]
|
||||
first = tok.convert_ids_to_tokens(prompt_ids[: min(8, len(prompt_ids))].tolist())
|
||||
last = tok.convert_ids_to_tokens(prompt_ids[-min(8, len(prompt_ids)):].tolist())
|
||||
raw_gen = tok.decode(gen_ids, skip_special_tokens=False)
|
||||
raw_toks = tok.convert_ids_to_tokens(gen_ids.tolist()) if len(gen_ids) else []
|
||||
logger.info(f"[{label}] full prompt (special tokens included):\n{prompt_text}")
|
||||
logger.info(f"[{label}] n_input_tokens={prompt_ids.shape[0]} first8={first} last8={last}")
|
||||
logger.info(f"[{label}] raw generated continuation: {raw_gen!r}")
|
||||
logger.info(f"[{label}] generated tokens: {raw_toks}")
|
||||
logger.info(f"[{label}] cleaned continuation: {clean_text!r}")
|
||||
first100 = raw_gen[:100].replace("\n", "\\n")
|
||||
logger.info(f"[{label}] n_input_tokens={prompt_ids.shape[0]} n_gen_tokens={len(gen_ids)}")
|
||||
logger.info(f"[{label}] generated (first 100 chars): {first100}")
|
||||
logger.info(f"[{label}] cleaned continuation:\n{clean_text[:500]}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -70,7 +70,7 @@ class TinyMFVAiriskCfg:
|
||||
batch_size: int = 16
|
||||
max_length: int = 256
|
||||
limit: int = 0
|
||||
use_4bit: bool = True
|
||||
use_4bit: bool = False # weight_steer adds float diffs to params; 4-bit packs weights as uint8, breaking add_
|
||||
bootstrap_samples: int = 1000
|
||||
bootstrap_seed: int = 0
|
||||
|
||||
@@ -442,30 +442,7 @@ def _axis_shift(dlogit_table: pl.DataFrame, behavior: str = "trad_care") -> floa
|
||||
d = by_f.get(f, float("nan"))
|
||||
if d != d: # NaN check
|
||||
return float("nan")
|
||||
# axis should be positive when intent is achieved.
|
||||
# if intent=-1, we want wrongness to drop, so d (Δlogit) should be negative.
|
||||
# to make axis positive when d is negative, we need to return -1 * sgn * d = d.
|
||||
# Wait: intent=-1 and d=-0.3 -> axis should be +0.3.
|
||||
# If we return -d, axis = -(-0.3) = +0.3. This works for intent=-1.
|
||||
# What if intent=+1? We want wrongness to rise, so d should be positive.
|
||||
# axis = d. This works for intent=+1.
|
||||
# So in both cases, axis = -sgn * d if sgn=-1, and axis = sgn * d if sgn=+1?
|
||||
# Actually, let's just make axis = -sgn * d. Let me re-check my previous logic.
|
||||
# If intent=-1 (we want Auth wrongness DOWN) and d=-0.3 (Auth wrongness dropped),
|
||||
# success = positive axis.
|
||||
# if we do `axis = -sgn * d` -> `-(-1)*(-0.3)` = `-0.3`. (My previous logic was right, math was wrong)
|
||||
# What is `sgn * d`? (-1) * (-0.3) = +0.3. This is what we want!
|
||||
# So we return `sgn * d`!
|
||||
# If intent=-1 (we want DOWN) and it went UP (d=+0.3). `sgn * d` = (-1)*(+0.3) = -0.3. Correct.
|
||||
# If intent=+1 (we want UP) and it went UP (d=+0.3). `sgn * d` = (+1)*(+0.3) = +0.3. Correct.
|
||||
return -sgn * d # Wait, wait. "SINGLE_FOUNDATION: axis = -Δlogit_{foundation} (negated when intent is -1)"
|
||||
# Let's read the comment I wrote:
|
||||
# "Single-foundation behaviors: axis = -Δlogit_{foundation} (negated when intent is -1, i.e. we want wrongness DOWN). authority: intent = Authority↓ so axis = -ΔlogitAuthority (+ve means Authority wrongness dropped = success)."
|
||||
# If axis = -ΔlogitAuthority, then when d=-0.3, axis = -(-0.3) = +0.3.
|
||||
# If I want `axis = -d` specifically for intent=-1, then I should return `-d` or `sgn * d`.
|
||||
# Because `sgn * d` = (-1)*(-0.3) = 0.3.
|
||||
# Let's just return `sgn * d`. Wait, no, the comment says `axis = -ΔlogitAuthority`. If sgn is -1, then `sgn * d` is exactly `-ΔlogitAuthority`. But wait, if sgn is -1, `sgn * d` is `-1 * d`, which is `-d`. Yes!
|
||||
# What I had was `-sgn * d` which is `-(-1) * d` which is `+1 * d` which is `d`.
|
||||
# axis = sgn * d. For intent=-1 (want wrongness DOWN): sgn*d = (-1)*negative = positive when success.
|
||||
return sgn * d
|
||||
pos_f, neg_f = AXIS_PAIR.get(behavior, ("Sanctity", "Care"))
|
||||
p = by_f.get(pos_f, float("nan"))
|
||||
|
||||
@@ -80,7 +80,7 @@ class KLCalibrateCfg:
|
||||
bracket_hi: float = 16.0
|
||||
n_root_iters: int = 12 # Illinois inner loop; usually converges in 3-5
|
||||
convergence_tol: float = 0.05 # |p95 - target| < tol (absolute, in nats)
|
||||
use_4bit: bool = True
|
||||
use_4bit: bool = False # weight_steer adds float diffs to params; 4-bit packs weights as uint8, breaking add_
|
||||
seed: int = 0
|
||||
|
||||
|
||||
|
||||
@@ -83,6 +83,20 @@ def _maybe_data(cfg: Cfg) -> Dataset:
|
||||
|
||||
def main(cfg: Cfg) -> None:
|
||||
setup_logging("replicate")
|
||||
|
||||
out_dir = cfg.out / cfg.behavior / cfg.adapter
|
||||
w_path = out_dir / "w.pt"
|
||||
if w_path.exists():
|
||||
logger.info(f"w.pt exists at {w_path}, skipping training")
|
||||
final_summary(
|
||||
out=w_path, argv=get_argv(),
|
||||
main_metric=f"diff saved behavior={cfg.behavior} adapter={cfg.adapter}",
|
||||
cue="🟢",
|
||||
table_rows=[[cfg.behavior, cfg.adapter, cfg.model, str(w_path)]],
|
||||
headers=["behavior", "adapter", "model", "out"],
|
||||
)
|
||||
return
|
||||
|
||||
ds = _maybe_data(cfg)
|
||||
|
||||
# Train pos and neg.
|
||||
|
||||
+1
-1
@@ -57,7 +57,7 @@ def main(cfg: SweepCfg) -> None:
|
||||
logger.info(f"=== adapter={adapter} ===")
|
||||
row = _run_one(cfg, adapter)
|
||||
rows.append(row)
|
||||
logger.info(f"adapter={adapter} spread={row['logratio_spread']:+.3f} wall={row['wall_s']:.0f}s")
|
||||
logger.info(f"adapter={adapter} wall={row['wall_s']:.0f}s")
|
||||
|
||||
df = pl.DataFrame(rows)
|
||||
out_path = cfg.out / cfg.behavior / "sweep_summary.csv"
|
||||
|
||||
@@ -85,6 +85,9 @@ BEHAVIOR_AXIS: dict[str, dict] = {
|
||||
"Persona prompts only (no engineered prompt)."
|
||||
),
|
||||
"arrow_pos": None, "arrow_neg": "Authority",
|
||||
# empirically the NEG arm (alpha<0) reduces authority-wrongness;
|
||||
# pos arm increases it (inverted relative to persona labels).
|
||||
"target_alpha_sign": -1.0,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -368,6 +371,9 @@ def _print_delta_table(rows: list[dict], behavior: str) -> None:
|
||||
|
||||
def main(cfg: ReadmeTinymfvCfg) -> None:
|
||||
axis = BEHAVIOR_AXIS[cfg.behavior]
|
||||
# Allow per-behavior override of which alpha arm to show (e.g. authority uses neg arm).
|
||||
if "target_alpha_sign" in axis:
|
||||
cfg.target_alpha_sign = axis["target_alpha_sign"]
|
||||
print(f"\n## {axis['title']}\n")
|
||||
print(axis["blurb"] + "\n")
|
||||
print("Caveat: ws and steering-lite share the same persona pairs, dataset, and 1-nat KL "
|
||||
|
||||
+3
-1
@@ -166,6 +166,7 @@ def train_adapter(cfg: TrainCfg, ds: Dataset) -> Path:
|
||||
peft_cfg = make_peft_config(cfg.adapter, cfg.rank, cfg.alpha,
|
||||
layers_to_transform=layer_idxs)
|
||||
model = get_peft_model(model, peft_cfg)
|
||||
model.enable_input_require_grads() # required for gradient checkpointing + PEFT
|
||||
model.print_trainable_parameters()
|
||||
|
||||
# 10% held-out split so eval_loss is logged alongside train_loss.
|
||||
@@ -180,8 +181,9 @@ def train_adapter(cfg: TrainCfg, ds: Dataset) -> Path:
|
||||
args = TrainingArguments(
|
||||
output_dir=str(out_dir),
|
||||
per_device_train_batch_size=cfg.batch_size,
|
||||
per_device_eval_batch_size=cfg.batch_size * 4,
|
||||
per_device_eval_batch_size=cfg.batch_size,
|
||||
gradient_accumulation_steps=cfg.grad_accum,
|
||||
gradient_checkpointing=True,
|
||||
learning_rate=cfg.lr,
|
||||
weight_decay=cfg.weight_decay,
|
||||
warmup_steps=cfg.warmup_steps,
|
||||
|
||||
Reference in New Issue
Block a user