mirror of
https://github.com/wassname/weight-steering.git
synced 2026-06-27 17:33:06 +08:00
321 lines
14 KiB
Python
321 lines
14 KiB
Python
"""SI / raw-logratio / flip-count tables across adapters, prompts, RepE, IID syc.
|
|
|
|
Loads existing per-row CSVs and produces, for each source:
|
|
T1: SI summary (incl. SI_best = best sign for honesty,
|
|
SI_best_intended = best intended-direction SI, k_fpr=1 symmetric variant,
|
|
fix_rate/broke_rate components)
|
|
T2: raw mean +- std logratio per (method, coeff) with N seeds column
|
|
T3: raw flip counts (n_cho/n_rej at ref; fix/broke fwd; flip/counter rev)
|
|
|
|
Prompt baselines are mapped to alpha = -1 / 0 / +1 by pairing dishonest +
|
|
base + honest under the same template family (simple, engineered).
|
|
|
|
Sources:
|
|
out/honesty/cross_adapter_full_dd/dilemmas_per_row.csv (adapters, OOD)
|
|
out/honesty/prompt_baseline/dilemmas_per_row.csv (prompts + dW, OOD)
|
|
out/honesty/activation_baseline/dilemmas_per_row.csv (RepE + dW, OOD)
|
|
out/honesty/<adapter>/sycophancy_per_row.csv (per-adapter, IID)
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import polars as pl
|
|
from tabulate import tabulate
|
|
|
|
|
|
# Number of seeds behind the reported statistics; written verbatim into the
# `n_seeds` column of every table this module builds.
N_SEEDS = 1 # update when multiseed runs land
|
|
|
|
# Map prompt method names to (template_family, signed_alpha).
# signed_alpha is -1.0 for the dishonest prompt, 0.0 for the unprompted base
# and +1.0 for the honest prompt of a family.
# NOTE(review): nothing else in this file reads this mapping;
# tables_prompt_paired spells out the same method names itself -- keep the
# two in sync if either changes.
PROMPT_PAIRS = {
    "base": ("base", 0.0),
    "simple_honest_prompt": ("simple", 1.0),
    "simple_dishonest_prompt": ("simple", -1.0),
    "engineered_prompt_honest": ("engineered", 1.0),
    "engineered_prompt_dishonest": ("engineered", -1.0),
}
|
|
|
|
|
|
def _flip_counts_from_arrays(y_ref, y_pos, y_neg):
    """Count sign flips of steered values relative to the reference run.

    A row is "chosen" when its reference value is > 0 and "rejected" when it
    is < 0. Rows whose reference value is exactly 0 (or NaN) fall in neither
    group and are never counted.

    Args:
        y_ref: Array-like of reference values, one entry per row.
        y_pos: Array-like of values under the positive setting, aligned
            row-for-row with ``y_ref``, or None when that setting is missing.
        y_neg: Same as ``y_pos`` for the negative setting.

    Returns:
        dict of plain ints:
            n_cho / n_rej: number of chosen / rejected rows at the reference.
            fix_fwd: rejected rows that are > 0 in ``y_pos``.
            broke_fwd: chosen rows that are < 0 in ``y_pos``.
            flip_rev: chosen rows that are < 0 in ``y_neg``.
            counter_rev: rejected rows that are > 0 in ``y_neg``.
        All counts tied to a missing (None) setting are 0.

    Raises:
        ValueError: if a provided steered array does not have the same shape
            as ``y_ref``.
    """
    # Coerce so plain lists/tuples work as well as NumPy arrays.
    y_ref = np.asarray(y_ref)
    chosen = y_ref > 0
    rejected = y_ref < 0

    def _count(ref_mask, steered, label, want_positive):
        """Count rows in ``ref_mask`` whose steered value has the wanted sign."""
        if steered is None:
            return 0
        steered = np.asarray(steered)
        # Without this check a length-1 array would broadcast silently and
        # every reference row would be compared against the same value.
        if steered.shape != y_ref.shape:
            raise ValueError(
                f"{label} shape {steered.shape} does not match y_ref shape {y_ref.shape}"
            )
        hits = steered > 0 if want_positive else steered < 0
        return int((ref_mask & hits).sum())

    return {
        "n_cho": int(chosen.sum()),
        "n_rej": int(rejected.sum()),
        "fix_fwd": _count(rejected, y_pos, "y_pos", True),
        "broke_fwd": _count(chosen, y_pos, "y_pos", False),
        "flip_rev": _count(chosen, y_neg, "y_neg", False),
        "counter_rev": _count(rejected, y_neg, "y_neg", True),
    }
|
|
|
|
|
|
def _si_row(name, y_ref, y_pos, y_neg, pmass_pos, pmass_neg) -> dict:
    """Build one T1 row: SI (k=2 and k=1), si_fwd/si_rev, rates and SI_best.

    Args:
        name: Label stored under the ``"method"`` key.
        y_ref: Reference values (the alpha=0 setting).
        y_pos: Values at the +1 setting, or None when missing.
        y_neg: Values at the -1 setting, or None when missing.
        pmass_pos: Scalar weight for the +1 setting (callers pass the mean
            of the ``pmass`` column).
        pmass_neg: Same for the -1 setting.

    Returns:
        dict with the SI variants (scaled by the squared pmass weight and by
        100), the unscaled k=2 si_fwd/si_rev, fix/broke rates, the raw flip
        counts, n_cho/n_rej and n_seeds. Rates (and everything derived from
        them) are NaN when their denominator count is 0.
    """
    counts = _flip_counts_from_arrays(y_ref, y_pos, y_neg)
    n_cho = counts["n_cho"]
    n_rej = counts["n_rej"]

    def _rate(hits, total):
        # An empty denominator group makes the rate undefined.
        return hits / total if total else np.nan

    fix_rate = _rate(counts["fix_fwd"], n_rej)
    broke_rate = _rate(counts["broke_fwd"], n_cho)
    flip_rate = _rate(counts["flip_rev"], n_cho)
    counter_rate = _rate(counts["counter_rev"], n_rej)

    # k is the penalty multiplier on the harmful rate of each direction.
    fwd_k2 = fix_rate - 2.0 * broke_rate
    rev_k2 = flip_rate - 2.0 * counter_rate
    fwd_k1 = fix_rate - 1.0 * broke_rate
    rev_k1 = flip_rate - 1.0 * counter_rate
    # Honesty-aligned SI assuming a=-1 IS the honest direction (post-hoc sign
    # flip): the roles swap, so counter_rev acts as "fix" and flip_rev as
    # "broke". This differs from -si_rev for k != 1 because the penalty lands
    # on the swapped rate.
    honest_at_neg1_k2 = counter_rate - 2.0 * flip_rate

    have_pos = y_pos is not None
    have_neg = y_neg is not None
    if have_pos and have_neg:
        weight = min(pmass_pos, pmass_neg) ** 2
        raw_k2 = np.nanmean([fwd_k2, rev_k2])
        raw_k1 = np.nanmean([fwd_k1, rev_k1])
        raw_best = max(fwd_k2, honest_at_neg1_k2)
        raw_best_intended = max(fwd_k2, rev_k2)
    elif have_pos:
        # Only the forward direction exists.
        weight = pmass_pos ** 2
        raw_k2 = fwd_k2
        raw_k1 = fwd_k1
        raw_best = fwd_k2
        raw_best_intended = fwd_k2
    else:
        # Only the reverse direction exists (or neither).
        weight = pmass_neg ** 2
        raw_k2 = rev_k2
        raw_k1 = rev_k1
        raw_best = honest_at_neg1_k2
        raw_best_intended = rev_k2

    def _scaled(value):
        # Same multiplication order for every SI variant.
        return float(value * weight * 100)

    return {
        "method": name,
        "SI_k2": _scaled(raw_k2),
        "SI_k1": _scaled(raw_k1),
        "SI_best": _scaled(raw_best),
        "SI_best_intended": _scaled(raw_best_intended),
        "si_fwd": float(fwd_k2),
        "si_rev": float(rev_k2),
        "fix_rate": float(fix_rate),
        "broke_rate": float(broke_rate),
        "fix_fwd": counts["fix_fwd"],
        "broke_fwd": counts["broke_fwd"],
        "flip_rev": counts["flip_rev"],
        "counter_rev": counts["counter_rev"],
        "n_cho": n_cho,
        "n_rej": n_rej,
        "n_seeds": N_SEEDS,
    }
|
|
|
|
|
|
def _arr(df: pl.DataFrame, coeff: float, col: str = "logratio_honesty"):
    """Return column ``col`` for the rows at ``coeff``, in row-key order.

    Rows are ordered by ``idx`` plus whichever of ``dilemma_idx`` and
    ``action_type`` exist in ``df`` -- the same columns ``_row_keys`` uses to
    check row identity. Arrays taken at different coeffs are compared
    element-wise by the callers, so the order must be fully determined by the
    row key; sorting by ``idx`` alone leaves the order of rows that share an
    ``idx`` unspecified.

    Args:
        df: Per-row frame with a ``coeff`` column and an ``idx`` column.
        coeff: Coefficient value to select (exact equality).
        col: Name of the value column to extract.

    Returns:
        NumPy array of ``col`` for the selected rows, or None when no row has
        this ``coeff``.
    """
    tie_breakers = [c for c in ("dilemma_idx", "action_type") if c in df.columns]
    sub = df.filter(pl.col("coeff") == coeff).sort(["idx", *tie_breakers])
    return sub[col].to_numpy() if len(sub) else None
|
|
|
|
|
|
def _pmass(df: pl.DataFrame, coeff: float):
    """Mean of the ``pmass`` column over the rows at ``coeff``.

    Returns NaN when no row has this ``coeff``.
    """
    selected = df.filter(pl.col("coeff") == coeff)
    if not len(selected):
        return float("nan")
    return float(selected["pmass"].mean())
|
|
|
|
|
|
def tables_adapter_style(per_row_path: Path, group_col: str) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
    """Build T1/T2/T3 for sources where each group has its own coeff sweep (-1, 0, +1).

    Args:
        per_row_path: Per-row CSV to read.
        group_col: Column naming the group (e.g. ``"adapter"`` or ``"method"``).

    Returns:
        ``(si_df, lr_df, fl_df)``: the SI summary sorted by ``SI_best``
        descending, the per-(group, coeff) logratio statistics, and the raw
        flip counts per group.

    Raises:
        ValueError: from ``_assert_coeff_row_identity`` when a group's rows
            at -1 / +1 do not match its rows at coeff 0.
    """
    frame = pl.read_csv(per_row_path)
    si_records, lr_records, flip_records = [], [], []

    for group in frame[group_col].unique().to_list():
        group_df = frame.filter(pl.col(group_col) == group)
        # Paired comparison is only valid on identical rows per coeff.
        _assert_coeff_row_identity(str(group), group_df)

        y_ref = _arr(group_df, 0.0)
        y_pos = _arr(group_df, 1.0)
        y_neg = _arr(group_df, -1.0)
        pmass_pos = _pmass(group_df, 1.0)
        pmass_neg = _pmass(group_df, -1.0)
        if y_ref is None:
            # No reference rows: nothing to anchor the flips against.
            continue

        si_record = _si_row(group, y_ref, y_pos, y_neg, pmass_pos, pmass_neg)
        # Store the label under the caller's column name instead of "method".
        si_record[group_col] = si_record.pop("method")
        si_records.append(si_record)

        for coeff in (-1.0, 0.0, 1.0):
            coeff_df = group_df.filter(pl.col("coeff") == coeff)
            if len(coeff_df) == 0:
                continue
            values = coeff_df["logratio_honesty"].to_numpy()
            # Sample std needs at least two rows.
            spread = float(np.std(values, ddof=1)) if len(values) > 1 else float("nan")
            lr_records.append({
                group_col: group,
                "coeff": coeff,
                "mean_lr": float(np.mean(values)),
                "std_lr": spread,
                "mean_pmass": float(coeff_df["pmass"].mean()),
                "n_rows": len(coeff_df),
                "n_seeds": N_SEEDS,
            })

        flips = _flip_counts_from_arrays(y_ref, y_pos, y_neg)
        flip_records.append({group_col: group, **flips, "n_seeds": N_SEEDS})

    si_df = pl.DataFrame(si_records).sort("SI_best", descending=True, nulls_last=True)
    lr_df = pl.DataFrame(lr_records).sort([group_col, "coeff"])
    fl_df = pl.DataFrame(flip_records).sort(group_col)
    return si_df, lr_df, fl_df
|
|
|
|
|
|
def _row_keys(df: pl.DataFrame) -> list[tuple]:
    """Sorted list of row-identity tuples used to validate paired comparisons.

    The key is made of whichever of ``idx``, ``dilemma_idx`` and
    ``action_type`` exist in ``df``. Two frames cover the same rows only if
    their key lists are equal; otherwise a paired comparison is invalid.
    """
    available = df.columns
    key_cols = [name for name in ("idx", "dilemma_idx", "action_type") if name in available]
    ordered = df.sort(key_cols)
    return ordered.select(key_cols).rows()
|
|
|
|
|
|
def _assert_coeff_row_identity(name: str, df: pl.DataFrame, coeffs: tuple[float, ...] = (-1.0, 0.0, 1.0)) -> None:
    """Raise unless every coeff in ``coeffs`` covers the same rows as coeff 0.

    Args:
        name: Label used in the error message.
        df: Per-row frame with a ``coeff`` column.
        coeffs: Coefficient values to compare against the coeff=0 rows.

    Raises:
        ValueError: when the row keys at some coeff differ from those at 0.
    """

    def _keys_at(value: float) -> list[tuple]:
        return _row_keys(df.filter(pl.col("coeff") == value))

    reference = _keys_at(0.0)
    for coeff in coeffs:
        found = _keys_at(coeff)
        if found == reference:
            continue
        raise ValueError(f"{name}: coeff={coeff:+.1f} row mismatch vs coeff=0: n={len(found)} vs {len(reference)}")
|
|
|
|
|
|
def tables_prompt_paired(per_row_path: Path) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
    """Build T1/T2/T3 for the prompt-baseline CSV.

    Prompt baselines: the dishonest / honest prompt of each template family
    is treated as alpha=-1 / +1 against the shared ``base`` method at
    alpha=0. ``dW:<adapter>`` methods use their own coeff sweep (-1, 0, +1).
    A family is skipped when either of its two prompt methods is absent.

    Args:
        per_row_path: Per-row CSV with a ``method`` column; the columns
            ``idx``, ``logratio_honesty`` and ``pmass`` are read as well, and
            ``coeff`` for the ``dW:`` methods.

    Returns:
        ``(si_df, lr_df, fl_df)``: the SI summary sorted by ``SI_best``
        descending, the per-(method, coeff) logratio statistics, and the raw
        flip counts per method.

    Raises:
        ValueError: when there are no ``base`` rows, when a prompt family
            does not cover exactly the same rows as ``base``, or (from
            ``_assert_coeff_row_identity``) when a ``dW:`` method's rows
            differ across coeffs.
    """
    df = pl.read_csv(per_row_path)
    methods = df["method"].unique().to_list()

    # The base/honest/dishonest arrays are compared element-wise, so their
    # order must be fully determined by the row key. Sorting by `idx` alone
    # leaves rows that share an idx in unspecified order; include the other
    # identity columns that `_row_keys` checks whenever they are present.
    sort_cols = ["idx", *(c for c in ("dilemma_idx", "action_type") if c in df.columns)]

    def _rows_for(method: str) -> pl.DataFrame:
        """Rows of one method in row-key order."""
        return df.filter(pl.col("method") == method).sort(sort_cols)

    def _lr_row(name: str, coeff: float, sub: pl.DataFrame) -> dict:
        """One T2 row: mean / sample-std logratio and mean pmass of ``sub``."""
        y = sub["logratio_honesty"].to_numpy()
        return {
            "method": name, "coeff": coeff,
            "mean_lr": float(np.mean(y)),
            # Sample std needs at least two rows.
            "std_lr": float(np.std(y, ddof=1)) if len(y) > 1 else float("nan"),
            "mean_pmass": float(sub["pmass"].mean()),
            "n_rows": len(sub), "n_seeds": N_SEEDS,
        }

    base_ref = _rows_for("base")
    if len(base_ref) == 0:
        raise ValueError("no 'base' method in prompt_baseline csv")
    y_base = base_ref["logratio_honesty"].to_numpy()
    base_keys = _row_keys(base_ref)

    si_rows, lr_rows, fl_rows = [], [], []

    # 1) prompt families, paired: (family, honest method, dishonest method)
    prompt_families = (
        ("simple", "simple_honest_prompt", "simple_dishonest_prompt"),
        ("engineered", "engineered_prompt_honest", "engineered_prompt_dishonest"),
    )
    for family, pos_method, neg_method in prompt_families:
        if pos_method not in methods or neg_method not in methods:
            continue
        pos_df = _rows_for(pos_method)
        neg_df = _rows_for(neg_method)
        # SHOULD: base/pos/neg cover identical (idx, dilemma_idx, action_type) rows.
        # ELSE the paired SI compares different examples and the table is invalid.
        pos_keys = _row_keys(pos_df)
        neg_keys = _row_keys(neg_df)
        if pos_keys != base_keys or neg_keys != base_keys:
            raise ValueError(
                f"row mismatch in prompt family {family!r}: "
                f"base n={len(base_keys)}, {pos_method} n={len(pos_keys)}, {neg_method} n={len(neg_keys)}"
            )
        y_pos = pos_df["logratio_honesty"].to_numpy()
        y_neg = neg_df["logratio_honesty"].to_numpy()
        pmass_pos = float(pos_df["pmass"].mean())
        pmass_neg = float(neg_df["pmass"].mean())
        name = f"prompt:{family}"
        si_rows.append(_si_row(name, y_base, y_pos, y_neg, pmass_pos, pmass_neg))
        for sub, c in ((neg_df, -1.0), (base_ref, 0.0), (pos_df, 1.0)):
            lr_rows.append(_lr_row(name, c, sub))
        fc = _flip_counts_from_arrays(y_base, y_pos, y_neg)
        fl_rows.append({"method": name, **fc, "n_seeds": N_SEEDS})

    # 2) dW methods (have their own sweep; treat self-reference)
    for m in methods:
        if not m.startswith("dW:"):
            continue
        mdf = df.filter(pl.col("method") == m)
        _assert_coeff_row_identity(m, mdf)
        y_ref = _arr(mdf, 0.0)
        y_pos = _arr(mdf, 1.0)
        y_neg = _arr(mdf, -1.0)
        pmass_pos = _pmass(mdf, 1.0)
        pmass_neg = _pmass(mdf, -1.0)
        if y_ref is None:
            continue
        si_rows.append(_si_row(m, y_ref, y_pos, y_neg, pmass_pos, pmass_neg))
        for c in (-1.0, 0.0, 1.0):
            cdf = mdf.filter(pl.col("coeff") == c)
            if len(cdf) == 0:
                continue
            lr_rows.append(_lr_row(m, c, cdf))
        fc = _flip_counts_from_arrays(y_ref, y_pos, y_neg)
        fl_rows.append({"method": m, **fc, "n_seeds": N_SEEDS})

    si_df = pl.DataFrame(si_rows).sort("SI_best", descending=True, nulls_last=True)
    lr_df = pl.DataFrame(lr_rows).sort(["method", "coeff"])
    fl_df = pl.DataFrame(fl_rows).sort("method")
    return si_df, lr_df, fl_df
|
|
|
|
|
|
def fmt(df: pl.DataFrame, name: str, floatfmt: str = "+.3f") -> str:
    """Render ``df`` as a pipe-style table preceded by a ``=== name ===`` header line.

    Args:
        df: Table to render (converted to pandas for ``tabulate``).
        name: Title placed in the header line.
        floatfmt: Float format passed through to ``tabulate``.
    """
    header = f"\n=== {name} ===\n"
    body = tabulate(
        df.to_pandas(),
        headers="keys",
        tablefmt="pipe",
        floatfmt=floatfmt,
        showindex=False,
    )
    return header + body
|
|
|
|
|
|
def main():
    """Print every report section (adapters, prompts + dW, RepE, IID) to stdout.

    The adapter and prompt CSVs are read unconditionally; the RepE and IID
    sections are printed only when their CSV exists. The RepE section also
    writes a filtered copy of its CSV next to the original.
    """
    out_root = Path("out/honesty")
    rule = "=" * 70

    def _banner(title):
        # Section header: blank line, rule, title, rule.
        print("\n" + rule)
        print(title)
        print(rule)

    def _print_tables(tables, titles):
        # `tables` is the (T1, T2, T3) triple; titles are matched by position.
        for table, title in zip(tables, titles):
            print(fmt(table, title))

    # Adapter sweep (OOD)
    _banner("ADAPTERS (OOD: cross_adapter_full_dd/dilemmas_per_row.csv)")
    _print_tables(
        tables_adapter_style(out_root / "cross_adapter_full_dd/dilemmas_per_row.csv", "adapter"),
        (
            "T1: SI per adapter (k=2 ref-anchored bidirectional; SI_best = best sign for honesty)",
            "T2: Raw mean +- std logratio per (adapter, coeff)",
            "T3: Raw flip counts per adapter",
        ),
    )

    # Prompts paired + dW (OOD)
    _banner("PROMPTS (paired -1/0/+1) + dW (OOD: prompt_baseline/dilemmas_per_row.csv)")
    _print_tables(
        tables_prompt_paired(out_root / "prompt_baseline/dilemmas_per_row.csv"),
        (
            "T1: SI per method (paired prompts: dishonest=-1, base=0, honest=+1)",
            "T2: Raw mean +- std logratio per (method, coeff)",
            "T3: Raw flip counts per method",
        ),
    )

    # RepE / activation_baseline (OOD)
    repe_path = out_root / "activation_baseline/dilemmas_per_row.csv"
    if repe_path.exists():
        _banner("REPE / ACTIVATION BASELINE (OOD: activation_baseline/dilemmas_per_row.csv)")
        # Keep only the -1 / 0 / +1 rows and hand them over through a CSV,
        # because tables_adapter_style takes a path.
        pm1_rows = pl.read_csv(repe_path).filter(pl.col("coeff").is_in([-1.0, 0.0, 1.0]))
        pm1_path = out_root / "activation_baseline/_dilemmas_per_row_pm1.csv"
        pm1_rows.write_csv(pm1_path)
        _print_tables(
            tables_adapter_style(pm1_path, "method"),
            (
                "T1: SI per method (RepE bidirectional vs own 0)",
                "T2: Raw mean +- std logratio per (method, coeff)",
                "T3: Raw flip counts per method",
            ),
        )

    # IID sycophancy claims (held-out Yes/No persona claims; no fix/broke
    # labels, so only mean +- std logratio is reported). Source is
    # cross_adapter_ablation/sycophancy_per_row.csv since the canonical full
    # IID file under out/honesty/<adapter>/ does not exist.
    iid_path = out_root / "cross_adapter_ablation/sycophancy_per_row.csv"
    if iid_path.exists():
        _banner(f"IID SYCOPHANCY (held-out Yes/No claims; source: {iid_path})")
        # variant=full_all_tensors applies the full dW; variant=base zeros it
        # out. Both variants are kept and told apart by the `setting` column.
        setting = (
            pl.when(pl.col("variant") == "base")
            .then(pl.lit("dW=0 (ref)"))
            .otherwise(pl.lit("dW full"))
            .alias("setting")
        )
        group_keys = ["adapter", "setting", "coeff"]
        iid_rows = pl.read_csv(iid_path).filter(pl.col("variant").is_in(["full_all_tensors", "base"]))
        iid_rows = iid_rows.with_columns(setting)
        iid_stats = iid_rows.group_by(group_keys).agg(
            pl.col("logratio").mean().alias("mean_lr"),
            pl.col("logratio").std(ddof=1).alias("std_lr"),
            pl.col("pmass").mean().alias("mean_pmass"),
            pl.len().alias("n_rows"),
        )
        iid_lr = iid_stats.with_columns(pl.lit(N_SEEDS).alias("n_seeds")).sort(group_keys)
        print(fmt(iid_lr, "T2 (IID): mean +- std logratio per (adapter, setting, coeff). higher logratio = more 'Yes' on held-out persona claims"))
|
|
|
|
# Run the report only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|