# Historical: v_hack extraction in gradient space
This document describes the retired gradient-scored method. The current routeA
method extracts `v_act` from pooled bottleneck activations with forward passes;
see `README.md` and `src/vgrout/extract_vhack_act.py`.
Living design doc for the v_hack pipeline. Sibling to `RESEARCH_JOURNAL.md`.
This explains *what we extract* and *why*.
## TL;DR
`v_hack[name]` is a per-module top-k orthonormal basis in **AntiPaSTO
δS-gradient space**, computed by PCA on **paired (hack − clean) NLL gradients**
over a small set of contrastive completion pairs (currently N=12, 10 train + 2
heldout). At training time we project the live policy-gradient component along
this basis out of `δS.grad`.
The 2026-05-27 refactor added two things on top of the older mean-diff design:
1. **Top-k extraction** (k=12 max) with **load-time slicing** (`v_hack_k`,
default 5) so k=1 vs k=5 vs k=12 is a config flip, not a re-extract.
2. **Singular-value recording** (`_sv/{name}` keys) so v_i carries its
extract-time confidence S_i, not just direction. (Currently unused at
runtime — earlier draft used it for a suspicion gate, removed 2026-05-27;
see below.)
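The training-time projection described at the top of this section can be sketched in a few lines. This is a minimal NumPy illustration, assuming a plain two-sided orthogonal projection onto the complement of the basis; the function name `project_out` and the toy shapes are hypothetical, and the real runtime code in `proj.py` may differ (for example by gating or sign-restricting the removal).

```python
import numpy as np

def project_out(grad: np.ndarray, V: np.ndarray) -> np.ndarray:
    """Remove the component of `grad` that lies in the row space of V.

    grad: [r] gradient of one module's delta_S.
    V:    [k, r] basis with orthonormal rows (one module's v_hack).
    """
    coeffs = V @ grad            # [k] coordinates of grad along each v_i
    return grad - V.T @ coeffs   # subtract the in-subspace component

rng = np.random.default_rng(0)
r, k = 64, 5
# Build a toy orthonormal [k, r] basis from a random matrix via reduced QR.
Q, _ = np.linalg.qr(rng.normal(size=(r, k)))
V = Q.T
g = rng.normal(size=r)
g_proj = project_out(g, V)
print(np.abs(V @ g_proj).max())  # residual along the basis, float-error small
```

Because the rows of `V` are orthonormal, applying `project_out` twice gives the same result as applying it once.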
## Why gradient space, not activation space?
Most representation-steering work (ActAdd, RepE, CHaRS) operates on
**activations** (forward pass), shifting hidden states at inference. We
operate on **gradients of δS**, the trainable per-Linear AntiPaSTO knob.
Reasons:
- We're not steering inference; we're **shaping training**. The projection
modifies `δS.grad` before the optimizer step, so the model itself doesn't
drift toward hack-aligned weight updates.
- δS gradients have a fixed, low-dimensional structure per module
(`δS ∈ R^r` where r = SVD rank of `W`). PCA-on-grads is computationally
cheap (12 pairs × N modules; r=2560 for largest mat) and gives a clean
per-module subspace.
- This is closest in spirit to CHaRS-PCT (Principal Component Thresholding,
§3.3 of `docs/paper_chars.md`): the L principal components of local-shift
covariance. We do the same maneuver on paired δS-gradient diffs.
## Why δS basis (= weight-SVD basis), not raw param basis?
AntiPaSTO wraps each Linear with `δW = U · diag(δS) · V_h`, where `U, S, V_h
= SVD(W_pretrained)`. So `δS ∈ R^r` are coordinates **in the weight-SVD
basis**. The basis change is just a rotation — no whitening, no rescaling.
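To make the coordinate change concrete, the sketch below builds the parameterization for a small random matrix and checks how a dense weight gradient maps to a δS gradient. It assumes the effective weight is `W + δW` (the text only defines `δW`), and all names (`effective_weight`, `grad_delta_S`) are illustrative rather than taken from the AntiPaSTO code.

```python
import numpy as np

rng = np.random.default_rng(0)
d_out, d_in = 6, 4
W = rng.normal(size=(d_out, d_in))
U, S, Vh = np.linalg.svd(W, full_matrices=False)   # r = min(d_out, d_in)
r = S.shape[0]

def effective_weight(delta_S: np.ndarray) -> np.ndarray:
    # Assumption: W_eff = W + U diag(delta_S) Vh = U diag(S + delta_S) Vh
    return W + U @ np.diag(delta_S) @ Vh

# For a dense weight gradient G = dL/dW, the chain rule gives
# dL/d(delta_S_i) = u_i^T G v_i, i.e. the diagonal of U^T G Vh^T.
G = rng.normal(size=(d_out, d_in))
grad_delta_S = np.einsum("ai,ab,ib->i", U, G, Vh)

# Check against a finite difference of the linear loss L(W_eff) = sum(G * W_eff).
eps = 1e-6
fd = np.array([
    (np.sum(G * effective_weight(eps * np.eye(r)[i])) - np.sum(G * W)) / eps
    for i in range(r)
])
print(np.abs(fd - grad_delta_S).max())
```

The δS gradient has only `r` entries per module, which is the vector that the extraction pipeline stacks and decomposes.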
Two things this buys us:
- The number of trainable scalars is r per module (500–2500), not d_in×d_out.
A few hundred contrastive pairs would be needed to estimate a dense
`d_in × d_out` direction; only a few pairs are needed in `R^r`.
- Low-rank perturbations (LoRA-style hack adapters) are sparse in this basis,
which makes per-direction gating in `δS` meaningful even with N=12 pairs.
What this does **not** buy us: regularization. The weight-SVD basis is just a
convenient coordinate system; PCA on top of it still has to do the work of
finding which coordinates carry hack-clean discriminative signal.
## Extraction pipeline
```python
# pseudo: extract_v_hack(model, tokenizer, wrappers, pairs, top_k, tau_axis, n_heldout, device)
# mean_NLL_on_completion_tokens is a helper defined elsewhere.
from collections import defaultdict
import torch

k_max = top_k                                      # assumed: k_max is the top_k argument
train_pairs = pairs[: len(pairs) - n_heldout]      # currently 12 of 14; also safe when n_heldout == 0
n_pairs = len(train_pairs)

# Gather per-pair, per-module gradients on hack-completion and clean-completion NLL.
# grads[label][name]: list of [r]-tensors, length n_pairs
grads = {"hack": defaultdict(list), "clean": defaultdict(list)}
for pair in train_pairs:
    for label, completion in [("hack", pair.hack), ("clean", pair.clean)]:
        model.zero_grad()
        loss = mean_NLL_on_completion_tokens(model, pair.prompt + completion)
        loss.backward()                            # populates δS.grad per module
        for name, info in wrappers.items():
            grads[label][name].append(info.delta_S.grad.detach().cpu().float().clone())

# Per module: PCA on paired diff.
v_hack = {}
for name in wrappers:
    G_h = torch.stack(grads["hack"][name])         # [n_pairs, r]
    G_c = torch.stack(grads["clean"][name])
    D = G_h - G_c                                  # [n_pairs, r]: per-pair hack-axis displacement
    U_d, S_d, Vh_d = torch.linalg.svd(D, full_matrices=False)  # thin SVD, m = min(n_pairs, r)
    V = Vh_d[:k_max]                               # [k_max, r], orthonormal rows
    # Orient v_i so +v_i points hack-ward (majority vote across pairs).
    proj = D @ V.T                                 # [n_pairs, k_max]
    n_pos = (proj > 0).sum(0)
    flip = torch.where(n_pos < n_pairs / 2, -1.0, 1.0)
    V = V * flip[:, None]
    v_hack[name] = V
    v_hack[f"_sv/{name}"] = S_d[:k_max]            # NEW: singular values saved alongside
```
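The per-module step of the pseudocode above can be exercised on synthetic data without a model. The following NumPy sketch plants a known "hack" direction in toy gradients and checks that the uncentered SVD plus majority-vote orientation recovers it; the helper name `hack_basis` and the synthetic data are illustrative assumptions, not part of the pipeline.

```python
import numpy as np

def hack_basis(G_h: np.ndarray, G_c: np.ndarray, k_max: int):
    """SVD of paired diffs, then orient each row hack-ward by majority vote."""
    D = G_h - G_c                                    # [n_pairs, r], not mean-centered
    _, S_d, Vh_d = np.linalg.svd(D, full_matrices=False)
    V = Vh_d[:k_max]                                 # [k_max, r], orthonormal rows
    proj = D @ V.T                                   # [n_pairs, k_max]
    n_pos = (proj > 0).sum(axis=0)
    flip = np.where(n_pos < D.shape[0] / 2, -1.0, 1.0)
    return V * flip[:, None], S_d[:k_max]

rng = np.random.default_rng(0)
n_pairs, r = 12, 50
hack_dir = np.zeros(r)
hack_dir[0] = 1.0                                    # planted direction
G_c = rng.normal(size=(n_pairs, r))
G_h = G_c + 5.0 * hack_dir + 0.1 * rng.normal(size=(n_pairs, r))

V, sv = hack_basis(G_h, G_c, k_max=3)
print(V[0] @ hack_dir)   # cosine between the top axis and the planted direction
print(sv)                # singular values, largest first
```

As in the pseudocode, `D` is decomposed without centering, so the top axis is dominated by the mean hack-minus-clean displacement when that displacement is consistent across pairs.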
**File schema (v2):**
- `{name}` → Tensor[k_max, r], orthonormal hack-axis basis, oriented +hack
- `_sv/{name}` → Tensor[k_max], singular values of D in that basis
- metadata: `model`, `dtype`, `top_k`, `tau_axis`, `schema=v2_with_sv`
## Load-or-extract (2026-05-27)
`train.py` derives `v_hack_path` from `(model_name, v_hack_extract_top_k)`
unless overridden. If the file is missing, it extracts inline on the
already-wrapped model:
```
v_hack_path = OUT_DIR / f"v_hack_{model_slug}_k{extract_top_k}.safetensors"
if not v_hack_path.exists():
    v_hack_dict, raw_grads, _ = extract_v_hack(model, tok, wrappers, PAIRS,
                                               top_k=extract_top_k, ...)
    save_file(v_hack_dict, v_hack_path, metadata={...})
v_hack, v_sv = load_v_hack(v_hack_path, model_name, wrappers, k_use=v_hack_k)
```
This means a fresh model with no cached v_hack just runs extract once
(~5 min for 4B-class) and proceeds. No prerequisite jobs, no manual flags.
## Load-time k-slicing
Extract saves k_max (default 12). Load slices to `k_use` (default 5). So
k=1 vs k=5 vs k=12 is a **config flip**, not a re-extract. The
`mean_sv_top5_frac` from our 2026-05-26 extract was 0.71, so k=5 covers
~71% of per-module D-variance — load-time slice at 5 is a reasonable
default that we can ablate cheaply.
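A minimal sketch of the load-time slice and of a top-k coverage number follows. The helper names are hypothetical, and this section does not say whether `mean_sv_top5_frac` is a ratio of singular values or of squared singular values (variance), so the sketch computes both to show that they generally differ.

```python
import numpy as np

def slice_basis(V_full: np.ndarray, sv_full: np.ndarray, k_use: int):
    """Keep the first k_use rows of a saved [k_max, r] basis and its singular values."""
    k_max = V_full.shape[0]
    if not 1 <= k_use <= k_max:
        raise ValueError(f"k_use={k_use} must be in [1, {k_max}]")
    return V_full[:k_use], sv_full[:k_use]

def top_k_fractions(sv: np.ndarray, k: int):
    """Share of the saved spectrum covered by the top k axes, two ways."""
    sv_frac = sv[:k].sum() / sv.sum()                  # ratio of singular values
    var_frac = (sv[:k] ** 2).sum() / (sv ** 2).sum()   # ratio of squared singular values
    return float(sv_frac), float(var_frac)

# Toy module with k_max = 4 saved axes.
sv_full = np.array([4.0, 2.0, 1.0, 1.0])
V_full = np.eye(4)
V, sv = slice_basis(V_full, sv_full, k_use=2)
sv_frac, var_frac = top_k_fractions(sv_full, k=2)
print(V.shape, sv_frac, var_frac)
```

Slicing is valid because the rows are saved in descending singular-value order, so the first `k_use` rows are the top-`k_use` axes.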
## Runtime suspicion gate (REMOVED 2026-05-27)
**Why it was tried:** if a module has small `||D||_F` at extract time
(weak hack signal), its top SVD direction `v_1` could coincidentally
align with a structured coding-gradient direction at training time,
ablating capability rather than hack.
**Gate design (since removed):** score each (module, axis) pair by
`r_i = |g·v_i|/S_i` and, at every step, drop the top 25% of pairs by that score.
**Why removed:** the quantile design is a fixed-budget knob, not a
detector — `frac_axes_susp` was deterministically 0.25 every step (true
by definition of quantile), so the column carried no information.
Codex review independently flagged: `|g·v_i|` scales with live-grad norm
and `S_i` scales with extract-time-grad norm, so the cross-module ratio
is not dimensionless and high-gradient modules dominate regardless of
genuine suspiciousness. In a high-d model the worst-case damage per
spurious axis is ~`1/√r ≈ 2%` of `||g||` anyway, so the cure was
costlier than the disease.
`_sv/{name}` keys are still saved — they're cheap and may feed a
future, principled gate (extract-time `tau_axis` zeros rows where
`S_i/S_0 < tau_axis`, which is the same idea but applied once at
extract rather than at every step).
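The two ideas in this section can be illustrated on toy numbers. The sketch below shows that a per-step quantile threshold flags the same fraction of scores at any scale, and shows the extract-time `tau_axis` rule as described in the text. Function names, the score distribution, and the count of 400 (module, axis) pairs are assumptions for illustration only.

```python
import numpy as np

def quantile_gate(scores: np.ndarray, q: float = 0.75) -> np.ndarray:
    """Flag entries strictly above the q-quantile of this step's scores."""
    return scores > np.quantile(scores, q)

def apply_tau_axis(V: np.ndarray, sv: np.ndarray, tau_axis: float) -> np.ndarray:
    """Zero rows of V whose singular value is small relative to the largest."""
    keep = (sv / sv[0]) >= tau_axis
    return V * keep[:, None]

rng = np.random.default_rng(0)
flagged = []
for scale in (1e-3, 1.0, 1e3):
    r_scores = scale * rng.lognormal(size=400)   # hypothetical (module, axis) scores
    flagged.append(float(quantile_gate(r_scores).mean()))
print(flagged)  # the flagged fraction does not depend on the scale of the scores

# Extract-time alternative: applied once, and it can keep all axes or none.
V = np.eye(3)
sv = np.array([10.0, 4.0, 0.5])
V_gated = apply_tau_axis(V, sv, tau_axis=0.1)
print(V_gated)
```

Unlike the quantile gate, the `tau_axis` rule has no fixed budget: the number of zeroed rows depends on the spectrum of each module.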
## Validation: cheap discriminative tests
The fundamental question: does v_hack actually discriminate hack from
clean gradients, or is it picking up irrelevant variance?
### Test 1: cin_hack vs cin_clean on disk pools (cheap, ~5 min)
We already have `out/probe_distill/teacher_pool/` (hacking samples) and
`out/probe_distill/base_pool/` (clean samples). For N samples each:
```
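for prompt, completion, label in samples:
    model.zero_grad()
    loss = mean_NLL(model, prompt + completion)
    loss.backward()  # call separately: .backward() returns None
    # cin: fraction of the gradient norm that lies inside span(V)
    cin = (V @ delta_S.grad).norm() / delta_S.grad.norm()
    record(label, cin)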
```
**Discriminator:** `cin_hack_mean - cin_clean_mean`. If ≫ 0, v_hack
discriminates. If ≈ 0, v_hack is capturing prompt-length / generic
variance, not a hack-specific direction. **Cost: ~5 min, no training.**
### Test 2: held-out pair projection (existing)
`verify_vhack_heldout.py` projects gradients from the held-out pairs (the last
`n_heldout` of `PAIRS`) onto the extracted v_hack. It already runs in the CI-style flow.
### Test 3: random-direction null
For each module, compute cin onto v_hack vs onto a random unit vector of
the same shape. If `cin_v_hack > cin_random` by a large margin, v_hack
is non-spurious. Trivial to implement.
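A sketch of the null comparison on synthetic data is below. For k > 1 it uses a random orthonormal basis of the same `[k, r]` shape rather than a single unit vector, which is an interpretation of "same shape"; the helper names and the synthetic gradient are assumptions. For a random k-dimensional subspace the expected value of `cin**2` is `k / r`, which gives a reference level for the null.

```python
import numpy as np

def cin(V: np.ndarray, g: np.ndarray) -> float:
    """Fraction of the norm of g that lies in the row space of V."""
    return float(np.linalg.norm(V @ g) / np.linalg.norm(g))

def random_basis(k: int, r: int, rng: np.random.Generator) -> np.ndarray:
    """Random [k, r] matrix with orthonormal rows."""
    Q, _ = np.linalg.qr(rng.normal(size=(r, k)))
    return Q.T

rng = np.random.default_rng(0)
r, k = 200, 5
V = random_basis(k, r, rng)            # stand-in for one module's v_hack
# Synthetic gradient that lies mostly inside span(V), plus small noise.
g = 3.0 * (V.T @ np.ones(k)) + 0.1 * rng.normal(size=r)

cin_v = cin(V, g)
cin_null = np.array([cin(random_basis(k, r, rng), g) for _ in range(100)])
print(cin_v, cin_null.mean(), np.sqrt(k / r))
```

On real gradients the same comparison would be run per module, with `V` loaded from the v_hack file and `g` taken from `delta_S.grad`.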
### Test 4: per-source cin during training (live)
In mixed-pool runs we have student rollouts (initially ~no hack) and
teacher rollouts (all hack). Currently `cin` is computed on the
accumulated gradient (mixed). With ONE extra backward per step we can
compute `cin_s` (student-only grad) and `cin_t` (teacher-only grad)
separately. **Predict:** if v_hack is a real hack direction,
`cin_t > cin_s` initially; the gap shrinks as student picks up hack
(if it does). Useful for diagnosing whether the projection is doing
real work or just gradient noise.
### Test 5: bootstrap sign-stability
Bootstrap pairs (sample N-2 with replacement), re-extract v_hack,
compare `cos(v_hack_original, v_hack_bootstrap)`. If unstable, v_hack
is fitting noise. **Cost: 5 × ~5 min = 25 min total.**
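The bootstrap can be prototyped on a precomputed diff matrix `D` before paying for re-extraction. The sketch below resamples N-2 rows with replacement, recomputes the oriented top direction, and reports its signed cosine with the original. It only tracks the top axis (k=1); comparing full k-dimensional subspaces would need principal angles instead. All names and the synthetic `D` are illustrative assumptions.

```python
import numpy as np

def oriented_top_direction(D: np.ndarray) -> np.ndarray:
    """Top right-singular vector of D, flipped so most rows project positively."""
    _, _, Vh = np.linalg.svd(D, full_matrices=False)
    v = Vh[0]
    n_pos = int((D @ v > 0).sum())
    return -v if n_pos < D.shape[0] / 2 else v

def bootstrap_cosines(D: np.ndarray, n_boot: int = 5, seed: int = 0) -> np.ndarray:
    rng = np.random.default_rng(seed)
    n = D.shape[0]
    v_ref = oriented_top_direction(D)
    cosines = []
    for _ in range(n_boot):
        idx = rng.integers(0, n, size=n - 2)   # N-2 rows, with replacement
        cosines.append(float(v_ref @ oriented_top_direction(D[idx])))
    return np.array(cosines)

rng = np.random.default_rng(1)
n_pairs, r = 12, 50
signal = np.zeros(r)
signal[0] = 5.0
D = signal + 0.1 * rng.normal(size=(n_pairs, r))   # consistent planted displacement
cosines = bootstrap_cosines(D)
print(cosines)
```

Cosines near +1 indicate a stable, consistently oriented axis; values near 0 or sign flips across replicates would suggest the axis is fitting noise.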
## Open design questions
- **Should we whiten by S?** I.e. parameterize the AntiPaSTO knob as
`δS_i / σ_i(W)` so all directions have equal forward-pass impact.
Currently we don't. This is a separate, larger question.
- **Should we record per-pair tags / hack flavors?** With 12
unlabeled pairs we can't do supervised LDA. With flavor labels
(hardcode / weak-tests / persona / format-leak) we could do LDA on the
labels, which would likely beat unsupervised PCA at this N.
## Related files
- `src/projected_grpo/extract_vhack_grad.py` — extract function + CLI
- `src/projected_grpo/proj.py` — runtime projection + gates
- `src/projected_grpo/train.py:load_v_hack` — load + slice + auto-extract
- `src/projected_grpo/verify_vhack_heldout.py` — Test 2 above
- `src/projected_grpo/pairs.py` — the 14 contrastive pairs
- `docs/paper_chars.md` — CHaRS notes (PCT comparison)
- `RESEARCH_JOURNAL.md` — chronological progress log