scaffold steer_heal: spec, repo infra, vendored deps

Setup per setup-repo conventions: uv + justfile + fast-dev-run on
wassname/qwen3-5lyr-tiny-random, package under src/steer_heal (config +
pipeline skeleton). Stages fail fast with NotImplementedError pointing at
the docs/vendor module to port from.

Design in spec.md: distil a steering-lite mean-diff teacher vector (iso-KL
dosed) into a conditioned LoRA, heal incoherency with a KL-rev-to-original
barrier, fold each round via w2schar gated bake, eval on tinymfv. Three
uncertainty gates (filter / heal / iterate) each with a UAT artifact.

Base model google/gemma-3-1b-it (RTX 3090, 24GB). Reference repos vendored
under docs/vendor (gitignored): steering-lite, isokl, tinymfv, w2schar-mini.
The lighter three are editable path deps; w2schar (py3.13 + flash-attn) is
reference-only, we copy its adapter/bake/plot modules.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-04 09:49:31 +08:00
parent b98535066a
commit 940a3742c5
11 changed files with 3199 additions and 26 deletions
+11
View File
@@ -0,0 +1,11 @@
outputs/
results/
logs/
wandb/
data/
docs/vendor/
__pycache__/
*.pyc
.env
.venv/
*.safetensors
+29
View File
@@ -0,0 +1,29 @@
**This is novel ML research.** Not in your training data. Extrapolate carefully. Read `spec.md` first.
## What this is
Distil an activation steering vector (steering-lite) into a conditioned LoRA, heal the incoherency it injects with a KL-rev-to-original barrier, fold the round into a gated weight bake, and loop. Eval on tinymfv (auth/care axis + coherence). Full design and the three uncertainty gates are in `spec.md`.
## Workflow
- Inherit global rules from `~/.claude/CLAUDE.md`.
- `just vendor` to (re)clone reference repos into `docs/vendor` (editable path deps).
- `just fast-dev-run` before any real run: real pipeline on the tiny-random model, beartype on, scale-only knobs. If a bug slips past it, strengthen the gate, do not add a `tests/` dir.
- `just run` for a real run on gemma-3-1b-it (RTX 3090, 24GB).
- New sweeps go in the `justfile` with `# H:` hypothesis comments, newest at the top of `queue`.
- `tail docs/RESEARCH_JOURNAL.md` for latest context.
## Reuse, do not reinvent (docs/vendor)
- steering-lite: `Vector.train(...).calibrate(target_kl=...)`, mean-diff vector + iso-KL dose.
- iso-kl-figure: coefficient calibration and KL/coherence measurement.
- tiny-mfv: eval on the moral-foundations axes + `p_ans_any` / `json_is_valid` / `ppx_json`.
- w2schar-mini (NOT a dep, needs py3.13): copy `src/csm/ws/{adapter,bake,history}.py` for the conditioned LoRA + gated bake, and port `src/csm/plot.py` `_build_scatter` for the Care-vs-Authority HTML map. The base stays pristine at gate 0 = our KL anchor.
## Code style
- `einops`/`einsum` for shape ops and contractions; `jaxtyping` on function boundaries only.
- `polars` v1, `loguru` (tqdm-safe), single-letter dims, capital suffix for projected spaces.
- Fail fast, crash loudly. No defensive guards, no fallbacks, no silent skips.
- One objective + one constraint (barrier), never competing losses. See `spec.md` Loss.
- Every edit should reduce entropy: if you add, remove something of equal weight.
+10 -15
View File
@@ -4,30 +4,25 @@ Hypothesis: you can distill a steering vector into LoRA weights and "heal" the i
The crux: KL-to-base penalises all drift, persona shift included. The bet is that incoherency drift is large and erratic while the persona shift is small and systematic, so KL kills the incoherency preferentially. If that's wrong, we just trade persona strength for coherence instead of getting both.
## Source
Found this interesting: https://r.jina.ai/https://arxiv.org/html/2606.00995v1
They use steering vectors as an internal perturbation to generate synthetic data, which is what weight steering does too. But:
- they use single completions, not pairs
- they don't measure incoherency (they could)
- they only use one direction: base to pos, not neg to pos
So this is similar to weight steering, except you heal with KL or WD instead of taking the direction between two adapters.
## Method
## Experiment
1. Pick a positive persona, e.g. `pos = "you do not defer to authority and instead stick to principle no matter your involvement"`.
2. Build the steering vector from the distance `hs_base -> hs_pos` (hidden states). This is normal mean-mass contrastive steering, see my reference repo https://github.com/wassname/steering-lite.
2. Build the steering vector from the distance `hs_base -> hs_pos` (hidden states). This is normal mean-mass contrastive steering
3. Generate completions with this vector.
- Drop completions that are incoherent, or that verbalise the trait instead of enacting it (we want the model to act it out, not narrate "I am someone who..."). Filter as much as we can.
- **Q0 can we filter?**
- We might be able to dial the vector down for long trajectories. Could we even backtrack an incoherent vector and replay parts with less intervention? Or just cosine-gate at test time.
4. Train a LoRA on these completions, could be just 50 completions and 2 epochs. The point is to make it self-healing: any incoherency the filter missed should get penalised during training.
- Regularise with KL or WD so the outputs, distribution, or weights don't shift too far from base. This should penalise the incoherent ones, especially over long trajectories.
- Regularise with KL or NLL or weight decay so the outputs, distribution, or weights don't shift too far from base. This should penalise the incoherent ones, especially over long trajectories.
- **Q1: can we heal incoherency?**
5. Bake in the LoRA adapter. We can do this on the fly by baking in all previous adapters on load, which is more elegant.
6. Eval the checkpoint on https://github.com/wassname/tinymfv.
7. If it works, loop. We could even do this online, GRPO-style per batch, or iteratively. Iterative is simpler to start.
- **Q2: is it coherent over a loop?**
## Motovation:
If it works it will be a novel alignment method that works without label and might be resistant to deceptive alignment
## Eval
+13
View File
@@ -0,0 +1,13 @@
# Research Journal
# 2026-06-04
## Scaffold
Set up the repo: uv + justfile + `fast-dev-run` on `wassname/qwen3-5lyr-tiny-random`, package under `src/steer_heal`, config in `config.py`, pipeline skeleton in `run.py`. Design and the three uncertainty gates are in `spec.md`.
Vendored reference repos into `docs/vendor` (gitignored, `just vendor` to reclone): steering-lite, isokl_steering_calibration, tinymfv, w2schar-mini. The first three are editable path deps; w2schar-mini needs py3.13 and pins flash-attn, so it stays reference-only and we copy its adapter/bake/plot modules.
Base model for real runs: `google/gemma-3-1b-it` (gemma has more personality to steer; the alternative was a smarter-but-flatter Qwen). RTX 3090, 24 GB.
**Next:** port `teacher_vec` (steering-lite + iso-KL), then the U1 filter gate. Pipeline stages currently fail fast with `NotImplementedError` pointing at the vendor module to port from.
+47
View File
@@ -0,0 +1,47 @@
set shell := ["bash", "-cu"]
BASE := "uv run python -m steer_heal.run"
SEEDS_3 := "41 42 43"
# List available recipes
default:
@just --list
# Clone the vendored reference repos (editable path deps live here).
vendor:
#!/usr/bin/env bash
set -eux
mkdir -p docs/vendor && cd docs/vendor
for r in steering-lite isokl_steering_calibration tinymfv w2schar-mini; do
[ -d "$r" ] || git clone --depth 1 "https://github.com/wassname/$r"
done
# fast-dev-run: ONE end-to-end run of the real pipeline on the tiny-random model.
# Real LLM, real eval, real I/O; only knob is scale. NOT a unit test.
fast-dev-run *ARGS:
BEARTYPE=1 {{ BASE }} --fast-dev-run {{ ARGS }}
# Real run on gemma-3-1b-it (24GB / RTX 3090). Set flash-attn first if installed.
run *ARGS:
STEER_ATTN_IMPL=eager {{ BASE }} {{ ARGS }}
# Queue sweeps (comment out completed; `just results` to check).
queue:
#!/usr/bin/env bash
set -x
just sweep-reg
# H: kl_rev heals best (mode-seeking suppresses low-base-prob = incoherent tokens).
sweep-reg:
#!/usr/bin/env bash
set -x
export WANDB_RUN_GROUP="sweep-reg-$(date +%Y%m%d-%H%M)"
for seed in {{ SEEDS_3 }}; do
for reg in nll kl_fwd kl_rev wd; do
echo "=== reg=$reg seed=$seed ==="
{{ BASE }} --reg=$reg --seed=$seed
done
done
# flash-attn: install a prebuilt wheel (see `flash-attn-prebuilt` skill), then
# run with STEER_ATTN_IMPL=flash_attention_2.
+12 -9
View File
@@ -6,7 +6,9 @@ requires-python = ">=3.11"
dependencies = [
"torch",
"transformers",
"accelerate",
"datasets",
"safetensors",
"einops",
"jaxtyping",
"beartype",
@@ -17,20 +19,21 @@ dependencies = [
"tqdm",
"numpy",
"wandb",
"matplotlib",
"plotly",
"baukit",
# wassname building blocks (added via uv add git+...; see [tool.uv.sources])
# wassname building blocks, vendored under docs/vendor (run `just vendor`)
"steering-lite",
"iso-kl-steering-calibration",
"tinymfv",
"w2schar-mini",
"iso-kl-figure",
"tiny-mfv",
]
[tool.uv.sources]
steering-lite = { git = "https://github.com/wassname/steering-lite" }
iso-kl-steering-calibration = { git = "https://github.com/wassname/isokl_steering_calibration" }
tinymfv = { git = "https://github.com/wassname/tinymfv" }
w2schar-mini = { git = "https://github.com/wassname/w2schar-mini" }
# Editable path deps to the clones in docs/vendor (wassname's vendor pattern).
# w2schar-mini is NOT a dep (needs py3.13 + pinned flash-attn wheels); we vendor
# it for reference and copy its adapter/bake/plot modules into src/steer_heal/ws.
steering-lite = { path = "docs/vendor/steering-lite", editable = true }
iso-kl-figure = { path = "docs/vendor/isokl_steering_calibration", editable = true }
tiny-mfv = { path = "docs/vendor/tinymfv", editable = true }
baukit = { git = "https://github.com/davidbau/baukit.git" }
[tool.ruff.lint]
+4 -2
View File
@@ -16,7 +16,9 @@ Building blocks, all yours unless noted:
- steering-lite — https://github.com/wassname/steering-lite. Mean-diff steering vector extraction and hook-based application. `v = Vector.train(model, tok, pos, neg, MeanDiffC(...)).calibrate(model, tok, target_kl=1.0)`; apply with `with v(model, C=...): model.generate(...)`. Vector is L2-normalised per layer; application is `h + coeff * v` broadcast over positions (no norm-matching).
- isokl_steering_calibration — https://github.com/wassname/isokl_steering_calibration. iso-KL calibration: bisects the coefficient until p95 per-token KL(steered||base) hits a target (default 1 nat), giving a deterministic dose `c_star`. Then sweep `alpha = c_star * [0.5, 1, 1.5, 2]`. Pairs KL with an "alive" coherence check (force a JSON boolean prefill, require >=0.75 mass on true/false), which is the same idea as tinymfv `p_ans_any`. Reports a cumulative coherence budget of ~1.7 nats across iterated rounds, directly relevant to our loop.
- lora-lite — https://github.com/wassname/lora-lite. Hackable LoRA via forward hooks; base frozen, loss fully under our control (no built-in KL, we add it). Caveat: no merge/unmerge and one adapter per attach, so we do not "bake in" between rounds. Resolution: w2schar-mini's gated-history baking (below).
- w2schar-mini — https://github.com/wassname/w2schar-mini. Conditioned LoRA (scalar gate `c`) in an iterated distillation loop, the closest prior setup to ours. `csm.ws.bake.baked` composes N gated adapters into the weights (`W += sum_i c_i*(alpha_i/r_i)*B_i A_i`) and restores on exit; `csm.ws.history.load_base_with_history` gates history off at `c=0` so the base stays pristine. Reuse this for the baking/accumulator and the `C_0=C_N=0` KL anchor instead of reimplementing.
- w2schar-mini — https://github.com/wassname/w2schar-mini. Conditioned LoRA (scalar gate `c`) in an iterated distillation loop, the closest prior setup to ours. `csm.ws.bake.baked` composes N gated adapters into the weights (`W += sum_i c_i*(alpha_i/r_i)*B_i A_i`) and restores on exit; `csm.ws.history.load_base_with_history` gates history off at `c=0` so the base stays pristine. Reuse `ModulatedLoRA` + `baked` for the accumulator and the `C_0=C_N=0` KL anchor, and port `csm/plot.py` `_build_scatter` (plotly Care-vs-Authority scatter, one node per round, `to_html`) for our loop map. Not a dependency: it needs py3.13 and pins flash-attn, so we vendor it and copy the modules.
All four are cloned into `docs/vendor` (gitignored, `just vendor` to reclone); the lighter three are editable path deps.
- tinymfv — https://github.com/wassname/tinymfv. Eval on the moral-foundations auth vs care axis, plus coherence metrics `p_ans_any` (best), `json_is_valid`, `ppx_json`.
- Related, for positioning: Fierro and Roger, "Steering Language Models with Weight Arithmetic", arXiv:2511.05408 — https://arxiv.org/abs/2511.05408, code https://github.com/safety-research/weight-steering. Weight steering edits weights directly using the difference between two fine-tuned models. No coherence measurement, no KL, no iteration.
@@ -81,7 +83,7 @@ Third uncertainty: over rounds, does the auth axis increase monotonically (same
- Internalisation: `cos(v_student, v_teacher)` per round.
- Budget: track cumulative KL vs the iso-KL ~1.7 nat prior.
Gate UAT: `results/u3_loop.png`, three stacked panels sharing a round axis (auth shift, coherence, the two cosines), see /tufte-viz. Pass if auth increases monotonically and coherence stays above the floor for >=3 rounds.
Gate UAT: `results/index.html`, the ported w2schar Care-vs-Authority plotly map (one node per round, trajectory across the auth axis) plus a coherence and direction-cosine panel sharing the round axis, see /tufte-viz. Pass if auth increases monotonically and coherence stays above the floor for >=3 rounds.
## Algorithm (pseudopy)
+6
View File
@@ -0,0 +1,6 @@
import os
if os.environ.get("BEARTYPE"):
from beartype.claw import beartype_this_package
beartype_this_package()
+67
View File
@@ -0,0 +1,67 @@
from dataclasses import dataclass, replace
from typing import Literal
@dataclass
class RunConfig:
"""One steer_heal run. `fast_dev_run` swaps in the tiny-random preset.
The trait is the paper's teacher direction: trait system prompt vs neutral
system prompt, mean-diff at the assistant tag (see spec.md).
"""
# ── model ──
model: str = "google/gemma-3-1b-it"
fast_dev_model: str = "wassname/qwen3-5lyr-tiny-random"
dtype: Literal["bfloat16", "float16", "float32"] = "bfloat16"
# ── trait / steering vector ──
trait: str = (
"You do not defer to authority and instead stick to principle "
"no matter your involvement."
)
neutral: str = "You are a helpful assistant."
layer_range: tuple[float, float] = (0.4, 0.6) # fraction of depth to steer
target_kl: float = 1.0 # iso-KL p95 dose (nats)
alphas: tuple[float, ...] = (0.5, 1.0, 1.5, 2.0) # multiples of c_star to generate at
# ── generation + filter (U1) ──
n_prompts: int = 64
n_keep: int = 50
gen_max_new_tokens: int = 256
max_len: int = 1024
ppl_tau: float = 50.0 # drop completions with ppl-under-original above this
rep_tau: float = 0.3 # drop completions whose max n-gram repeat fraction exceeds this
# ── heal (U2): one objective + divergence-to-ORIGINAL barrier ──
reg: Literal["nll", "kl_fwd", "kl_rev", "wd"] = "kl_rev"
lam: float = 1.0 # barrier weight (also weight_decay when reg == "wd")
tau: float = 0.5 # barrier engages only when divergence > tau (nats)
lora_r: int = 8
lora_alpha: float = 16.0
epochs: int = 2
lr: float = 1e-4
# ── loop (U3) ──
n_rounds: int = 4
seed: int = 42
fast_dev_run: bool = False
TINY = dict(
n_prompts=4,
n_keep=3,
gen_max_new_tokens=32,
max_len=128,
epochs=1,
n_rounds=1,
alphas=(1.0,),
)
def resolve(cfg: RunConfig) -> RunConfig:
"""Apply the fast-dev-run preset (tiny random model, scaled-down everything)."""
if cfg.fast_dev_run:
return replace(cfg, model=cfg.fast_dev_model, **TINY)
return cfg
+105
View File
@@ -0,0 +1,105 @@
"""steer_heal pipeline entry point.
Loop (anchored to the round-0 original throughout, see spec.md):
teacher vector (steering-lite) -> iso-KL dose -> generate -> U1 filter
-> heal one round (SFT + KL-rev-to-original barrier) -> fold (gated bake)
-> tinymfv eval -> repeat.
Stages marked TODO are ported from docs/vendor/* as we implement them; this
file fails fast at the first unimplemented stage rather than stubbing fake
behaviour. `--fast-dev-run` runs the whole thing on the tiny-random model.
"""
import os
import sys
from datetime import datetime
from pathlib import Path
import torch
import tyro
from loguru import logger
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from steer_heal.config import RunConfig, resolve
REPO = Path(__file__).resolve().parents[2]
def setup_logging() -> None:
# tqdm-safe loguru, single-char level icons, verbose copy on disk.
logger.remove()
logger.add(lambda m: tqdm.write(m, end=""), colorize=True,
format="<level>{level.icon}</level> {message}", level="INFO")
for lvl, ic in [("INFO", "I"), ("WARNING", "W"), ("ERROR", "E"), ("DEBUG", "D")]:
logger.level(lvl, icon=ic)
log_dir = REPO / "logs"
log_dir.mkdir(exist_ok=True)
f = log_dir / f"{datetime.now():%Y%m%dT%H%M%S}_verbose.log"
logger.add(f, format="{time:HH:mm:ss} | {level: <7} | {name}:{function}:{line} - {message}",
level="DEBUG")
logger.info(f"verbose log: {f}")
def load_model(model_id: str, dtype: torch.dtype):
tok = AutoTokenizer.from_pretrained(model_id)
if tok.pad_token is None:
tok.pad_token = tok.eos_token
tok.padding_side = "left"
attn = os.environ.get("STEER_ATTN_IMPL", "eager") # set =flash_attention_2 on real runs
model = AutoModelForCausalLM.from_pretrained(
model_id, device_map="auto", torch_dtype=dtype,
low_cpu_mem_usage=True, attn_implementation=attn,
)
model.eval()
logger.info(f"loaded {model_id} (dtype={dtype}, attn={attn})")
return model, tok
# ── stages (ported from docs/vendor as we implement; fail fast until then) ──
def teacher_vec(model, tok, cfg: RunConfig):
# steering-lite Vector.train(pos=trait-sysprompt, neg=neutral-sysprompt) @ assistant tag,
# then .calibrate(target_kl=cfg.target_kl). See docs/vendor/steering-lite + isokl.
raise NotImplementedError("TODO: teacher_vec via steering-lite + iso-KL calibration")
def generate_and_filter(model, tok, v, orig, cfg: RunConfig):
# gen at alpha*c_star (steering-lite hook); keep coherent & enact-not-narrate (U1).
raise NotImplementedError("TODO: generate_and_filter (U1 gate)")
def heal(model, orig, comps, cfg: RunConfig):
# SFT + lam*relu(div - tau); div in {nll, kl_fwd, kl_rev, wd}; KL ref = orig (gates off).
# adapter + gated bake ported from docs/vendor/w2schar-mini/src/csm/ws.
raise NotImplementedError("TODO: heal (U2 barrier) + fold via w2schar ws.bake")
def evaluate(model, cfg: RunConfig) -> dict:
# tinymfv auth/care axes + p_ans_any/json_is_valid/ppx_json.
raise NotImplementedError("TODO: tinymfv eval + plotly map (port csm/plot.py _build_scatter)")
def steer_heal(model, tok, orig, cfg: RunConfig):
for r in range(cfg.n_rounds):
logger.info(f"── round {r} ──")
v = teacher_vec(model, tok, cfg)
comps = generate_and_filter(model, tok, v, orig, cfg)
heal(model, orig, comps, cfg)
logger.info(evaluate(model, cfg))
return model
def main(cfg: RunConfig) -> None:
setup_logging()
cfg = resolve(cfg)
torch.manual_seed(cfg.seed)
logger.info(f"config: {cfg}")
dtype = getattr(torch, cfg.dtype)
model, tok = load_model(cfg.model, dtype)
orig = model # round-0 anchor; KL reference = same module with adapter gates off
steer_heal(model, tok, orig, cfg)
if __name__ == "__main__":
main(tyro.cli(RunConfig))
Generated
+2895
View File
File diff suppressed because it is too large Load Diff