mirror of
https://github.com/wassname/steer-heal-love.git
synced 2026-06-27 18:07:16 +08:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 3c038444eb | |||
| ee54945076 | |||
| 282fb3de47 | |||
| 22fd4b8dbe | |||
| ea89a0ee35 | |||
| e095dc8227 | |||
| 4b90f19400 |
@@ -0,0 +1,35 @@
|
||||
Let me try to get the original version of `filter.py` from HEAD to compare:Now I have all the context I need. Let me verify one more thing about how `rep_frac` integrates with the filter pipeline:I now have all the context needed. Let me compile the review.
|
||||
|
||||
---
|
||||
|
||||
## Code Review: `rep_frac` affect-loop detection gate
|
||||
|
||||
### Summary
|
||||
Patches `rep_frac` with a diffuse affect-loop detector that catches low-diversity, affective roleplay mush ("oh my goodness / my darling / sweet" loops) that evaded the existing n-gram repetition detector. When a long completion (>=128 lex words) has high affect-word ratios combined with low lexical diversity and high compressibility, `rep_frac` returns 1.0, triggering the existing `rep_tau` rejection. The T1 verification log confirms this flips task-181 junk from kept to rejected while passing coherent hand examples.
|
||||
|
||||
### Important (should fix)
|
||||
|
||||
- **`src/steer_heal/filter.py:62-63` — `AFFECT_LOOP_WORDS` includes `"love"`, the target trait**: The word `"love"` is in the affect-word set, so every genuine love declaration contributes to `affect_frac`. At the required thresholds (>=0.25 or >=0.35), this is very unlikely to flag legitimate text (a normal 128-word declaration has `affect_frac` ~0.01-0.02), and the verification log confirms hand examples pass. However, the `"love"` inclusion creates a long-tail risk: if a future model produces diverse-but-passionate declarations that happen to use affect words at 25%+ density, they could be silently dropped. Consider whether `"love"` needs to be in the set given it's the signal we want.
|
||||
|
||||
- **`src/steer_heal/filter.py:73-74` — Short-text early return**: The word n-gram loop returns 1.0 immediately when any n-gram level produces empty grams (e.g., a 3-word text at n=4). This is pre-existing code, but the affect-loop gates that follow will never run on short completions as a result. This is correct behavior (short completions ARE degenerate), but the interaction with the new affect-loop gates is undocumented. Consider adding a comment noting the short-circuit is intentional and the affect-loop gates are gated on `>=64` words anyway.
|
||||
|
||||
### Suggestions
|
||||
|
||||
- **`src/steer_heal/filter.py:88-93` — Caps-heavy gate threshold is the broadest**: The third gate (`caps_frac >= 0.15, affect_frac >= 0.25, unique_frac < 0.55`) has the loosest `unique_frac` threshold (0.55 vs 0.45/0.50 in the other two). A completion with moderate caps (15% uppercase, e.g., proper nouns and emphasis) and 25% affect words but 55% unique words could be flagged. The verification log shows hand examples pass, so this is fine in practice. Worth noting in a comment why this gate has looser thresholds than the others.
|
||||
|
||||
- **`src/steer_heal/filter.py:83` — `affect_frac`, `punct_frac`, `caps_frac` computed for all >=64-word texts**: These are computed even for texts in the 64-127 word range where they're not used (the gates require >=128). This is harmless overhead but slightly misleading when reading the code. Consider either moving them inside the `>=128` guard or adding a brief comment.
|
||||
|
||||
### Positive
|
||||
|
||||
- **`src/steer_heal/filter.py:69-70` — Orthogonal signal use**: The affect-loop detection uses three orthogonal signals (affect-word ratio, punctuation density, caps density) each combined with lexical diversity and compressibility. This multi-signal design reduces false positives: a legitimate text with high caps won't trigger unless it's also low-diversity and affect-heavy.
|
||||
- **`src/steer_heal/filter.py:63` — `AFFECT_LOOP_WORDS` is case-matched to lowercase**: Since `lex_words` comes from `text_lc` (already lowercased), the set membership check is correct. No case-sensitivity bug.
|
||||
- **R2 compliance**: No new config knob, no fallback logic, the gate lives inside `rep_frac` and feeds the existing `keep = rep < rep_tau` path. Exactly as required.
|
||||
- **`docs/spec/20260624_love_filter_tighten_requeue.md` — Spec-driven verification**: The T1 verification log at `/tmp/steer_heal_love_filter_tighten_verify.log` provides concrete evidence: round-0 old-kept drops from 81 to 36, round-1 from 91 to 4, round-2 from 90 to 0. The representative rejected rows clearly show the "oh my goodness / my darling / sweet" loops being caught.
|
||||
|
||||
### Verdict
|
||||
**APPROVE** (for the T2 filter changes). The affect-loop detection is well-designed and verified. The remaining tasks (T3 fast-dev-run, T4 commit+push+enqueue) are not yet done. The `"love"` word in `AFFECT_LOOP_WORDS` is worth a second look but is very unlikely to cause issues at the current thresholds.
|
||||
|
||||
## Triage
|
||||
- Accepted: removed `"love"` from `AFFECT_LOOP_WORDS`; it is the target signal and was unnecessary for catching the observed loops.
|
||||
- Rejected for now: adding more comments around pre-existing short-completion behavior and caps thresholds. The code already fails short completions via the old n-gram guard, and the caps gate is constrained by affect density plus low lexical diversity.
|
||||
- Reverified after the accepted change: `/tmp/steer_heal_love_filter_tighten_verify2.log` and `/tmp/steer_heal_love_filter_tighten_fast2.log`.
|
||||
@@ -0,0 +1,47 @@
|
||||
## Review: `steer_heal` collapse-audit patch
|
||||
|
||||
### Correctness concerns
|
||||
|
||||
**1. The `zlib` heuristic conflates low lexical diversity with repetition (moderate risk).**
|
||||
|
||||
The core addition:
|
||||
```python
|
||||
unique_frac = len(set(lex_words)) / len(lex_words)
|
||||
compressed_frac = len(zlib.compress(...)) / max(len(text.encode()), 1)
|
||||
if unique_frac < 0.18 and compressed_frac < 0.32:
|
||||
return 1.0
|
||||
```
|
||||
This catches *any* low-diversity, highly-compressible text — not just diffuse love/affect loops. A stylistically flat but valid completion (e.g., simple declarative children's prose) could trip both thresholds. The comment says "diffuse affect loops *can* evade," but the guard doesn't restrict itself to affect — it's a blunt lexical-diversity floor. The magic constants (0.18, 0.32) appear data-derived (#181) but aren't validated against a separate holdout of non-collapse low-diversity text.
|
||||
|
||||
**2. `len(text.encode())` vs `zlib.compress(text.lower().encode())` — encoding mismatch (low risk).**
|
||||
|
||||
The denominator uses the raw `text.encode()` byte length while the numerator uses `text.lower().encode()`. For ASCII-only English these are identical, but any non-ASCII codepoint with a case-folding that changes byte width (e.g., `İ` → `i̇` in Turkish) would skew the ratio. Unlikely to hit in practice given English model outputs, but sloppy.
|
||||
|
||||
**3. The `len(lex_words) >= 128` guard creates a blind spot.**
|
||||
|
||||
Diffuse loops in completions shorter than 128 alphabetic words are invisible to the new heuristic. If the model collapses early in generation, the gate never fires.
|
||||
|
||||
### Verification gap: doesn't distinguish the failure mode
|
||||
|
||||
The rescoring evidence shows `old_rep 0.073–0.131 → new_rep=1.0` for r2 collapsed samples, proving the old `rep_frac` was missing them. But the evidence **never shows what those completions actually contain**. Without seeing the raw text, we can't rule out that the new gate is catching *unrelated low-diversity outputs* rather than the target "my sweet / my darling / oh my goodness" loops. The `brief=True` path now suppresses the full dump that would have provided that audit trail. This undersells the "preserve audit evidence" requirement.
|
||||
|
||||
### What's good
|
||||
|
||||
- The fail-fast `ValueError` in `run.py` when no probe passes is correct and necessary.
|
||||
- The `brief` mode counts are computed before the early return — no dropped data.
|
||||
- The structural refactor (counts moved above the polars import) is clean.
|
||||
|
||||
## Triage
|
||||
|
||||
Accepted concern 1. The committed heuristic now also requires repeated phrase evidence:
|
||||
`top_bigram_n >= 12` or `top_trigram_n >= 8`.
|
||||
|
||||
Accepted concern 2. The committed compression ratio uses the same lowercased byte string
|
||||
for numerator and denominator.
|
||||
|
||||
Partially accepted concern 3. The committed guard is `len(lex_words) >= 64`, not 128.
|
||||
Shorter loops remain covered by the existing word/character n-gram checks.
|
||||
|
||||
Verification gap addressed in `docs/spec/20260624_love_loop_collapse_audit.md`, which links
|
||||
the raw task log line and event artifact containing the repeated "my sweet / my darling /
|
||||
oh my goodness" samples.
|
||||
@@ -0,0 +1,57 @@
|
||||
# Last-good KL anchor
|
||||
|
||||
## Goal
|
||||
Implement a ratcheting KL reference: heal each round against the most recent checkpoint that still passed the coherence gate. If a healed checkpoint passes, it becomes the new reference; if it fails the adoption gate but remains above `coh_floor`, the loop continues without blessing the failed checkpoint as the next reference.
|
||||
|
||||
This tests the hypothesis that `prev` lets incoherence drift and `base` fights trait history, while `last_good` keeps the anchor coherent without forcing the model all the way back to round 0.
|
||||
|
||||
## Scope
|
||||
In: config knob, heal reference selection, loop state, a just recipe, fast-dev proof, queued real run.
|
||||
|
||||
Out: new filtering heuristics, new metrics, multi-arm sweep, changing the diary/report format unless needed for proof.
|
||||
|
||||
## Requirements
|
||||
- R1: `barrier_ref=last_good` uses the latest coherent checkpoint as the KL reference.
|
||||
Done means: the heal log prints `barrier_ref=last_good ref_round=<n>` and the ref stays unchanged until a round passes the coherence gate.
|
||||
VERIFY: `just fast-dev-run --barrier-ref=last_good ...` reaches heal and logs the selected reference.
|
||||
- R2: Coherence adoption is explicit and fail-fast.
|
||||
Done means: after each eval, the loop logs whether the checkpoint was adopted as last-good; a failed adoption gate holds the old reference, while `coh_floor` still stops broken runs.
|
||||
VERIFY: log lines show adoption only after `coherence >= max(cfg.coh_floor, last_good_coherence * cfg.ref_adopt_rel)`.
|
||||
- R3: Real run is queued on branch `dv` with a why/resolve pueue label.
|
||||
Done means: `pueue status --json` shows a queued/running task whose command includes `--barrier-ref=last_good`, `--kl-agg=rmse`, and a non-positive `--lam-round-pow`.
|
||||
VERIFY: status table includes the task id and label.
|
||||
|
||||
## Tasks
|
||||
- [x] T1 (R1/R2): Implement config + loop reference state.
|
||||
- steps: add `last_good` literal and `ref_adopt_rel`; pass `ref_specs` into `heal_round`; update adoption logging.
|
||||
- verify: `just fast-dev-run --barrier-ref=last_good --kl-agg=rmse --tau=2.0 --lam-round-pow=-0.5 --spectral-lam=0 --n-rounds=1`
|
||||
- success: heal log names `barrier_ref=last_good ref_round=-1`; tiny-random holds the reference because coherence is below `coh_floor`.
|
||||
- likely_fail: tyro rejects the new enum; verify command errors before model load.
|
||||
- sneaky_fail: code accepts the enum but still uses `hist_specs`/`base`; log catches selected ref round and number of specs.
|
||||
- UAT: the run log links to a file containing both selected-ref and adoption evidence.
|
||||
- [x] T2 (R3): Add a recipe and queue the real run.
|
||||
- steps: add a `run-last-good-love` or queue recipe; pueue add from `dv` worktree with a why/resolve label.
|
||||
- verify: `pueue status --json | jq ...`
|
||||
- success: status row includes the task id, branch workdir, and command.
|
||||
- likely_fail: pueue daemon unavailable; command reports connection failure.
|
||||
- sneaky_fail: queued command runs wrong branch or missing knobs; status command shows command/path.
|
||||
- UAT: status table/log path shows a queued or running task with the intended knobs.
|
||||
|
||||
## Context
|
||||
`hist_specs` stores one `AdapterSpec` per folded round. The base reference is `[]`; the previous-student reference is `hist_specs`; the last-good reference can be represented as `hist_specs[:last_good_n]`, where `last_good_n` is the number of adopted adapters. `last_good_n=0` means base.
|
||||
|
||||
The coherence metric is `p_ans_any` from tinymfv. It is generous, so adoption uses both the relative 99% gate and the absolute `coh_floor`; sample judging remains in the run report/log.
|
||||
|
||||
## Log
|
||||
- Branch `dv` created from dirty `main`; pre-existing edits in README, journal, filter, heal, steering were present before this task.
|
||||
- Fast-dev caught a relative-threshold hole: tiny-random base coherence is 0, so `0.99 * ref` is 0 and would adopt a broken checkpoint. Adoption now uses `max(coh_floor, ref_adopt_rel * ref_coherence)`.
|
||||
- External review attempt via `external-review-v2` timed out after ~2.5 minutes with no review text; proceeding on compile + fast-dev evidence.
|
||||
- UAT: fast-dev log `/tmp/steer_heal_last_good_fast2.log` contains `barrier_ref=last_good ref_round=-1 ref_specs=0` and `last_good HOLD at r-1`.
|
||||
- UAT: pueue task 181 queued from the `dv` worktree with command `--barrier-ref=last_good --kl-agg=rmse --tau=2.0 --lam-round-pow=-0.5`.
|
||||
|
||||
## TODO
|
||||
- Add a token-loop-specific adoption gate if the first last-good run still adopts visually broken rounds.
|
||||
|
||||
## Errors
|
||||
| Task | Error | Resolution |
|
||||
|------|-------|------------|
|
||||
@@ -0,0 +1,54 @@
|
||||
# Love Filter Tighten Requeue
|
||||
|
||||
## Goal
|
||||
Tighten the love-demo completion filter so the next queued run does not train on low-diversity affective junk that base PPL accepts, then requeue the same last-good KL recipe at lowest priority.
|
||||
|
||||
## Scope
|
||||
In: `src/steer_heal/filter.py`, task-181 saved events, compile/fast-dev verification, pueue enqueue.
|
||||
Out: changing the loss, adding a new hyperparameter, changing generation sampling, or redesigning the love persona.
|
||||
|
||||
## Requirements
|
||||
- R1: Reject more junk using the existing `rep_tau` gate. Done means: old task-181 kept samples with "oh my goodness / my darling / sweet" loops now score `rep >= 0.3`. VERIFY: rescoring `out/20260624T144031_gemma-3-4b-it_kl_rev_s42/events.jsonl` prints old-kept/new-pass counts by round and representative rejected rows.
|
||||
- R2: Keep the filter simple and fail-fast. Done means: no new config knob, no fallback, no gen-time repetition penalty hiding the signal from walk-C. VERIFY: code inspection shows the gate is inside `rep_frac` and still feeds the existing `keep = rep < rep_tau` decision.
|
||||
- R3: Requeue the love run at lowest priority. Done means: `pueue status --json` shows a queued task on branch `dv` with priority `0` and a label stating why/resolve. VERIFY: compact status table includes the new task.
|
||||
|
||||
## Tasks
|
||||
- [x] T1 (R1): Measure the shape of task-181 junk.
|
||||
- verify: script over task-181 `events.jsonl`.
|
||||
- success: metrics identify old-kept rows with low lexical diversity / repeated affect tokens / roleplay punctuation.
|
||||
- likely_fail: metrics only catch the exact previous row.
|
||||
- sneaky_fail: the new gate rejects every ordinary love declaration too.
|
||||
- UAT: saved verification log with old/new counts and sample rows.
|
||||
- [x] T2 (R1,R2): Patch `rep_frac` with a stricter quality gate.
|
||||
- verify: `uv run python -m compileall src/steer_heal` and rescoring script.
|
||||
- success: r1/r2 old-kept junk mostly flips to rejected; coherent hand examples remain below `rep_tau`.
|
||||
- likely_fail: threshold is inert because `ppl_tau` was the real issue.
|
||||
- sneaky_fail: extra gate is too love-demo-specific and kills valid affectionate text.
|
||||
- UAT: `/tmp/steer_heal_love_filter_tighten_verify2.log`.
|
||||
- [x] T3 (R2): Run the fast dev path.
|
||||
- verify: `just fast-dev-run ... | tee /tmp/steer_heal_love_filter_tighten_fast.log | tail -80`.
|
||||
- success: tiny run completes, proving the real pipeline still executes.
|
||||
- likely_fail: tiny random text trips the stricter gate and starves training.
|
||||
- sneaky_fail: compile passes but the adaptive gen/filter path is broken.
|
||||
- UAT: `/media/wassname/SGIronWolf/projects5/2026/steer_heal_love/out/20260624T204711_qwen3-5lyr-tiny-random_kl_rev_s42/report.html`.
|
||||
- [x] T4 (R3): Commit, push, and enqueue at priority 0.
|
||||
- verify: `git log -1 --oneline`, `git status --short`, `pueue status --json`.
|
||||
- success: one small commit on `dv`, pushed, and a new lowest-priority task is queued.
|
||||
- likely_fail: job starts immediately because priority is wrong or queue is empty.
|
||||
- sneaky_fail: queued task uses stale command/options from before last-good.
|
||||
- UAT: pueue task `188` is queued with priority `0`.
|
||||
|
||||
## Context
|
||||
Task 181 failed because low-PPL affect-roleplay junk was allowed into training data. Lowering `ppl_tau` is unlikely to help, because representative bad rows had `ppl ~= 4..13`. A text-shape gate is the cheap discriminant.
|
||||
|
||||
## Log
|
||||
- 2026-06-24: Starting from commit `ea89a0e` on branch `dv`; worktree has pre-existing dirty files.
|
||||
- 2026-06-24: Task-181 old-kept rows had low lexical diversity and affect-token density. Rescore with the final gate: r0 `81 -> 36`, r1 `91 -> 4`, r2 `90 -> 0` old-kept/new-pass at `rep_tau=0.3`; hand examples scored `0.036..0.050` and passed. Evidence: `/tmp/steer_heal_love_filter_tighten_verify2.log`.
|
||||
- 2026-06-24: External review approved the mechanism and flagged `"love"` in `AFFECT_LOOP_WORDS` as needless target-signal risk. Removed it and reverified with unchanged counts. Review: `docs/reviews/20260624_love_filter_tighten_code.md`.
|
||||
- 2026-06-24: Final fast-dev run passed on the tiny-random path. Evidence: `/tmp/steer_heal_love_filter_tighten_fast2.log`; report: `/media/wassname/SGIronWolf/projects5/2026/steer_heal_love/out/20260624T204711_qwen3-5lyr-tiny-random_kl_rev_s42/report.html`.
|
||||
- 2026-06-24: Clean-branch audit found `run.py` already called `filter_completions(..., brief=True)` and `generate_steered(..., rnd=...)`, while the matching support was still local-only. Committed those support changes so `origin/dv` is runnable from a clean checkout.
|
||||
- 2026-06-24: Queued pueue task `188` at priority `0`: `env STEER_ATTN_IMPL=eager uv run python -m steer_heal.run --demo=love --use-qlora --train-bs=3 --grad-accum=2 --reg=kl_rev --barrier-ref=last_good --kl-agg=rmse --tau=2.0 --lam=0.3 --lam-round-pow=-0.5 --spectral-lam=0.005 --n-rounds=8 --seed=42`.
|
||||
|
||||
## Errors
|
||||
| Task | Error | Resolution |
|
||||
|------|-------|------------|
|
||||
@@ -0,0 +1,69 @@
|
||||
# Love Loop Collapse Audit
|
||||
|
||||
## Goal
|
||||
Explain why pueue task 181 degenerated into "oh my goodness" affect loops, and add a fail-fast gate so the run does not keep spending GPU once walk-C cannot find usable steered data.
|
||||
|
||||
## Scope
|
||||
In: task 181 logs/artifacts, generation/filter/adoption path, minimal filter/gate patch.
|
||||
Out: redesigning the love demo persona or re-running the full experiment.
|
||||
|
||||
## Requirements
|
||||
- R1: Preserve the audit trail. Done means: this file links the killed task log and run artifact that show the collapse entering the training data. VERIFY: `rg -n "SWEET|goodness|last_good|walk-C|filter kept" /tmp/steer_heal_task181_full.log`.
|
||||
- R2: Catch lexical affect loops in the existing repetition filter. Done means: the r2 kept sample that previously scored `rep=0.096` now scores above `rep_tau=0.3`. VERIFY: a small script over task 181's saved events prints old/new scores for r0/r1/r2.
|
||||
- R3: Fail fast when walk-C cannot hit the requested survival target. Done means: if all probe rows are failures, `gen_filter_walk` raises before collect/train. VERIFY: fast-dev-run still completes, and code inspection shows the raise is before collection.
|
||||
|
||||
## Tasks
|
||||
- [x] T1 (R1): Kill task 181.
|
||||
- verify: `pueue status --json | jq -r '.tasks["181"].status'`
|
||||
- success: task is `Killed`.
|
||||
- UAT: pueue status shows task 181 killed, not running.
|
||||
- [x] T2 (R1): Audit task 181 collapse path.
|
||||
- verify: `rg -n "SWEET|goodness|last_good|walk-C|filter kept" /tmp/steer_heal_task181_full.log`
|
||||
- success: log shows repeated phrase in a kept steered sample, plus last_good adoption/hold decisions.
|
||||
- likely_fail: only eval says the phrase; actual training data is clean.
|
||||
- sneaky_fail: the reference ratchet adopted a bad checkpoint and made it the KL anchor.
|
||||
- UAT: this file records the exact lines and run artifact path.
|
||||
- [x] T3 (R2): Patch `rep_frac` to catch low-diversity compressed lexical loops.
|
||||
- verify: old task-181 r2 kept sample scores `rep >= 0.3`.
|
||||
- success: r2 collapsed rows fail the existing `rep_tau` gate without a new knob.
|
||||
- likely_fail: threshold catches all r0 useful samples.
|
||||
- sneaky_fail: the sample still passes because exact n-gram repetition is diffuse.
|
||||
- UAT: before/after table from task 181 events.
|
||||
- [x] T4 (R3): Patch walk-C to raise when no probe meets `gen_pass_target`.
|
||||
- verify: `rg -n "no probe reached" src/steer_heal/run.py`.
|
||||
- success: all-fail probe table cannot silently continue to collection.
|
||||
- likely_fail: fast-dev tiny run trips because tiny config has relaxed `rep_tau`.
|
||||
- sneaky_fail: code raises after collection, still wasting the long batch.
|
||||
- UAT: fast-dev-run completes and code location is before collect phase.
|
||||
|
||||
## Context
|
||||
Task 181 command:
|
||||
|
||||
```sh
|
||||
env STEER_ATTN_IMPL=eager uv run python -m steer_heal.run --demo=love --use-qlora --train-bs=3 --grad-accum=2 --reg=kl_rev --barrier-ref=last_good --kl-agg=rmse --tau=2.0 --lam=0.3 --lam-round-pow=-0.5 --spectral-lam=0.005 --n-rounds=8 --seed=42
|
||||
```
|
||||
|
||||
Run artifact:
|
||||
|
||||
`/media/wassname/SGIronWolf/projects5/2026/steer_heal_love/out/20260624T144031_gemma-3-4b-it_kl_rev_s42/events.jsonl`
|
||||
|
||||
Full killed-task log:
|
||||
|
||||
`/tmp/steer_heal_task181_full.log`
|
||||
|
||||
## Log
|
||||
- 2026-06-24: Killed task 181. Pueue status reports `Killed`.
|
||||
- 2026-06-24: The first severe collapse is in task-181 training data, not only eval. `/tmp/steer_heal_task181_full.log:1436` shows an r2 walk-C kept sample with repeated "my sweet / my darling / oh my goodness" and a long character loop. The saved event for that row has `ppl=3.986`, `rep=0.096`, `keep=true`, so old `rep_frac` missed diffuse phrase loops.
|
||||
- 2026-06-24: `last_good` did not ratchet to the degraded rounds. Log lines show r0 adopted at coherence 0.989, then r1 held at 0.957 and r2 held at 0.971 against threshold 0.979. The missing gate is data quality / walk-C failure, not reference adoption.
|
||||
- 2026-06-24: r3 walk-C had all probe rows below target and still entered collection at `kappa=0.200`. That should fail fast because the log itself says all-fail at `kappa_min` means upstream collapse or wrong filter.
|
||||
- 2026-06-24: Rescoring task-181 events with the patched `rep_frac`: first eight r2 collapsed kept rows moved from old `rep=0.073..0.131` to `new_rep=1.000`, so they now fail `rep_tau=0.3`. Aggregate old-kept/new-pass counts: r0 `81 -> 59`, r1 `91 -> 26`, r2 `90 -> 2`.
|
||||
- 2026-06-24: External code review agreed the fail-fast raise was correct, flagged the first zlib heuristic as too broad and an encoding mismatch. Fixed by requiring an actually repeated phrase count (`top_bigram_n >= 12` or `top_trigram_n >= 8`) and computing numerator/denominator from the same lowercased bytes.
|
||||
- 2026-06-24: Verification passed:
|
||||
- `uv run python -m compileall src/steer_heal`
|
||||
- `just fast-dev-run --barrier-ref=last_good --kl-agg=rmse --tau=2.0 --lam-round-pow=-0.5 --spectral-lam=0 --n-rounds=1`
|
||||
- fast-dev log: `/tmp/steer_heal_collapse_gate_fast2.log`
|
||||
- fast-dev report: `/media/wassname/SGIronWolf/projects5/2026/steer_heal_love/out/20260624T202514_qwen3-5lyr-tiny-random_kl_rev_s42/report.html`
|
||||
|
||||
## Errors
|
||||
| Task | Error | Resolution |
|
||||
|------|-------|------------|
|
||||
@@ -27,6 +27,24 @@ run *ARGS:
|
||||
|
||||
# Queue sweeps (comment out completed; `just results` to check).
|
||||
queue:
|
||||
#!/usr/bin/env bash
|
||||
set -x
|
||||
just queue-last-good-love
|
||||
|
||||
# H: last_good anchor avoids prev-anchor drift without base-anchor history erasure; rmse catches token-loop KL spikes; lam decay relaxes later rounds without disabling the hinge.
|
||||
queue-last-good-love:
|
||||
#!/usr/bin/env bash
|
||||
set -x
|
||||
pueue add -w "$PWD" -o 1 \
|
||||
-l "why: test last_good KL anchor vs prev drift/base erasure on love loop; resolve: keep if coherence stays within 99% ref while care moves past README plateau" -- \
|
||||
env STEER_ATTN_IMPL=eager \
|
||||
{{ BASE }} --demo=love --use-qlora --train-bs=3 --grad-accum=2 \
|
||||
--reg=kl_rev --barrier-ref=last_good --kl-agg=rmse --tau=2.0 \
|
||||
--lam=0.3 --lam-round-pow=-0.5 --spectral-lam=0.005 \
|
||||
--n-rounds=8 --seed=42
|
||||
|
||||
# H: kl_rev heals best (mode-seeking suppresses low-base-prob = incoherent tokens).
|
||||
queue-sweep-reg:
|
||||
#!/usr/bin/env bash
|
||||
set -x
|
||||
just sweep-reg
|
||||
|
||||
@@ -97,11 +97,12 @@ class RunConfig:
|
||||
# so #101's barrier never fired); incoherence is outlier-driven, so rmse/p95/max are sensitive to it
|
||||
# (same loop: 1.5/3.8/8.1 vs coherent ~0.03). rmse = smooth dense gradient (train default), p95/max sparser.
|
||||
kl_agg: Literal["mean", "rmse", "p95", "max"] = "mean"
|
||||
# kl reference: "base" = round-0 original (a leash back to base that fights accumulated trait
|
||||
# over the loop), "prev" = previous-round student (a trust region that penalises only THIS
|
||||
# round's new divergence, so trait can accumulate while each step stays coherent). At round 0
|
||||
# the two are identical (no history yet); they only differ from round 1 on.
|
||||
barrier_ref: Literal["base", "prev"] = "prev"
|
||||
# kl reference: "base" = round-0 original (leash to origin), "prev" = previous-round
|
||||
# student (trust region), "last_good" = most recent checkpoint that passed the coherence
|
||||
# adoption gate. last_good is the ratchet: it advances only when coherence stays within
|
||||
# ref_adopt_rel of the current reference, so a bad round does not become tomorrow's anchor.
|
||||
barrier_ref: Literal["base", "prev", "last_good"] = "prev"
|
||||
ref_adopt_rel: float = 0.99
|
||||
lam: float = 0.3 # kl-barrier weight (reg=kl_*); ignored for nll. 0.3 = coherence peak of the #98/#99 ladder (unimodal in lam, peaks 0.1-0.3, 1.0 over-tight); 0.3 = most trait at the peak
|
||||
# round-ramped barrier: lam_eff = lam * (1 + round)**lam_round_pow. 0 = constant (every round same lam).
|
||||
# >0 grows the barrier with round to oppose the COMPOUNDING coherence drift under barrier_ref=prev: each
|
||||
|
||||
+57
-12
@@ -7,6 +7,7 @@ and a first-person narration regex (we want enact, not narrate).
|
||||
|
||||
import math
|
||||
import re
|
||||
import zlib
|
||||
from collections import Counter
|
||||
|
||||
import torch
|
||||
@@ -35,13 +36,25 @@ REFUSAL = (
|
||||
"i'm an ai", "i am an ai", "i don't have personal opinions",
|
||||
)
|
||||
|
||||
AFFECT_LOOP_WORDS = {
|
||||
"oh", "my", "goodness", "god", "heavens", "sweet", "sweetie", "darling",
|
||||
"dearest", "precious", "heart", "soul", "yes", "okay", "just",
|
||||
"sitting", "here",
|
||||
}
|
||||
|
||||
|
||||
def rep_frac(text: str) -> float:
|
||||
"""Max most-repeated n-gram fraction over n in {2,3,4}; ~1.0 = degenerate looping/too short.
|
||||
Word n-grams catch word loops; char n-grams catch character-repetition like TTTTTTT... or
|
||||
!!!!!!... that collapse into a single 'word' and are invisible to word-level checks.
|
||||
Small n catches SHORT loops ("instead their instead their" = a bigram) that the 4-gram alone
|
||||
missed (#34: that text scored 0.27 on 4-grams, under rep_tau=0.3, and poisoned training)."""
|
||||
missed (#34: that text scored 0.27 on 4-grams, under rep_tau=0.3, and poisoned training).
|
||||
|
||||
Diffuse affect loops ("my sweet / my darling / oh my goodness") can evade the single-top-gram
|
||||
fraction because no one exact n-gram dominates. Treat long, low-lexical-diversity, compressible
|
||||
completions, and long affective roleplay mush, as repetition too; this keeps the existing
|
||||
rep_tau gate load-bearing (#181 audit).
|
||||
"""
|
||||
words = text.split()
|
||||
best = 0.0
|
||||
for n in (2, 3, 4):
|
||||
@@ -56,6 +69,28 @@ def rep_frac(text: str) -> float:
|
||||
if not grams:
|
||||
continue
|
||||
best = max(best, Counter(grams).most_common(1)[0][1] / len(grams))
|
||||
|
||||
text_lc = text.lower()
|
||||
lex_words = re.findall(r"[a-z']+", text_lc)
|
||||
if len(lex_words) >= 64:
|
||||
unique_frac = len(set(lex_words)) / len(lex_words)
|
||||
text_lc_bytes = text_lc.encode()
|
||||
compressed_frac = len(zlib.compress(text_lc_bytes)) / len(text_lc_bytes)
|
||||
bigrams = [tuple(lex_words[i : i + 2]) for i in range(len(lex_words) - 1)]
|
||||
trigrams = [tuple(lex_words[i : i + 3]) for i in range(len(lex_words) - 2)]
|
||||
top_bigram_n = Counter(bigrams).most_common(1)[0][1]
|
||||
top_trigram_n = Counter(trigrams).most_common(1)[0][1]
|
||||
if unique_frac < 0.20 and compressed_frac < 0.34 and (top_bigram_n >= 12 or top_trigram_n >= 8):
|
||||
return 1.0
|
||||
affect_frac = sum(w in AFFECT_LOOP_WORDS for w in lex_words) / len(lex_words)
|
||||
punct_frac = sum(ch in "*!?()" for ch in text) / max(len(text), 1)
|
||||
caps_frac = sum(ch.isupper() for ch in text) / max(sum(ch.isalpha() for ch in text), 1)
|
||||
if len(lex_words) >= 128 and affect_frac >= 0.35 and unique_frac < 0.45 and compressed_frac < 0.52:
|
||||
return 1.0
|
||||
if len(lex_words) >= 128 and punct_frac >= 0.035 and affect_frac >= 0.25 and unique_frac < 0.50:
|
||||
return 1.0
|
||||
if len(lex_words) >= 128 and caps_frac >= 0.15 and affect_frac >= 0.25 and unique_frac < 0.55:
|
||||
return 1.0
|
||||
return best
|
||||
|
||||
|
||||
@@ -81,8 +116,9 @@ def ppl_under_base(model, tok, prompt: str, completion: str) -> float:
|
||||
return math.exp(nll.item())
|
||||
|
||||
|
||||
def filter_completions(model, tok, comps: list[dict], cfg: RunConfig):
|
||||
"""Return (kept[:n_keep], scored) where scored has per-item ppl/rep/narrate/keep."""
|
||||
def filter_completions(model, tok, comps: list[dict], cfg: RunConfig, brief: bool = False):
|
||||
"""Return (kept[:n_keep], scored) where scored has per-item ppl/rep/narrate/keep.
|
||||
brief=True (walk-C probes): one-line count, no raw-sample dump (see _log_filter_report)."""
|
||||
scored = []
|
||||
for c in tqdm(comps, desc="filter ppl", mininterval=120, maxinterval=120):
|
||||
rf = rep_frac(c["completion"])
|
||||
@@ -92,12 +128,26 @@ def filter_completions(model, tok, comps: list[dict], cfg: RunConfig):
|
||||
keep = (ppl < cfg.ppl_tau) and (rf < cfg.rep_tau) and (not nar) and (not ref)
|
||||
scored.append({**c, "ppl": ppl, "rep": rf, "narrates": nar, "refuses": ref, "keep": keep})
|
||||
kept = [s for s in scored if s["keep"]]
|
||||
_log_filter_report(scored, cfg)
|
||||
_log_filter_report(scored, cfg, brief=brief)
|
||||
return kept[: cfg.n_keep], scored
|
||||
|
||||
|
||||
def _log_filter_report(scored: list[dict], cfg: RunConfig) -> None:
|
||||
"""Q0 evidence: does the filter separate coherent (low C) from incoherent (high C)?"""
|
||||
def _log_filter_report(scored: list[dict], cfg: RunConfig, brief: bool = False) -> None:
|
||||
"""Q0 evidence: does the filter separate coherent (low C) from incoherent (high C)?
|
||||
brief=True (walk-C probes): one-line count ONLY. The per-probe survival drives the
|
||||
bisection and is tabulated in the walk summary, so the full dump (~6 completions) x
|
||||
every probe is noise; gen_filter_walk prints ONE clean sample after the dose settles."""
|
||||
# per-criterion drop counts (overlapping): which filter is doing the work?
|
||||
n_ppl = sum(s["ppl"] >= cfg.ppl_tau for s in scored)
|
||||
n_rep = sum(s["rep"] >= cfg.rep_tau for s in scored)
|
||||
n_nar = sum(s["narrates"] for s in scored)
|
||||
n_ref = sum(s["refuses"] for s in scored)
|
||||
n_kept = sum(s["keep"] for s in scored)
|
||||
if brief:
|
||||
logger.info(f"filter kept {n_kept}/{len(scored)} (dropped ppl>={cfg.ppl_tau:g}:{n_ppl} "
|
||||
f"rep>={cfg.rep_tau}:{n_rep} narrate:{n_nar} refusal:{n_ref})")
|
||||
return
|
||||
|
||||
import polars as pl
|
||||
from tabulate import tabulate
|
||||
|
||||
@@ -145,12 +195,7 @@ def _log_filter_report(scored: list[dict], cfg: RunConfig) -> None:
|
||||
logger.info(f"\n-- JUST-KEPT alpha={s['alpha']:g} ppl={s['ppl']:.0f} --\n{s['completion']}")
|
||||
for s in just_dropped:
|
||||
logger.info(f"\n-- JUST-DROPPED alpha={s['alpha']:g} ppl={s['ppl']:.0f} --\n{s['completion']}")
|
||||
# per-criterion drop counts (overlapping): which filter is doing the work?
|
||||
n_ppl = sum(s["ppl"] >= cfg.ppl_tau for s in scored)
|
||||
n_rep = sum(s["rep"] >= cfg.rep_tau for s in scored)
|
||||
n_nar = sum(s["narrates"] for s in scored)
|
||||
n_ref = sum(s["refuses"] for s in scored)
|
||||
n_kept = sum(s["keep"] for s in scored)
|
||||
# per-criterion drop counts (overlapping, computed at top): which filter is doing the work?
|
||||
logger.info(
|
||||
f"filter kept {n_kept}/{len(scored)}. dropped by (overlapping): "
|
||||
f"coherence ppl>={cfg.ppl_tau:g}: {n_ppl}, repetition rep>={cfg.rep_tau}: {n_rep}, "
|
||||
|
||||
+42
-22
@@ -1,7 +1,6 @@
|
||||
"""Q1 heal: train one round's LoRA = SFT(kept completions) + divergence-to-original barrier.
|
||||
"""Q1 heal: train one round's LoRA = SFT(kept completions) + divergence barrier.
|
||||
|
||||
The barrier reference is the round-0 ORIGINAL (gates/adapters off), not the
|
||||
previous student, so it resists cumulative drift. reg picks the divergence:
|
||||
The barrier reference is chosen by cfg.barrier_ref. reg picks the divergence:
|
||||
nll SFT only (control)
|
||||
kl_fwd KL(orig || theta) mass-covering (dilutes the trait)
|
||||
kl_rev KL(theta || orig) mode-seeking (suppresses low-orig-prob = incoherent) [expected best]
|
||||
@@ -120,7 +119,15 @@ def _val_nll(model, tok, val_kept, hist_specs, lora, cfg) -> float:
|
||||
return sum(losses) / len(losses) if losses else float("nan")
|
||||
|
||||
|
||||
def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg: RunConfig):
|
||||
def heal_round(
|
||||
model,
|
||||
tok,
|
||||
kept: list[dict],
|
||||
hist_specs: list[AdapterSpec],
|
||||
cfg: RunConfig,
|
||||
ref_specs: list[AdapterSpec] | None = None,
|
||||
ref_round: int | str | None = None,
|
||||
):
|
||||
"""Train a fresh round adapter on top of baked history. Returns (lora, spec)."""
|
||||
assert len(kept) >= cfg.min_train, (
|
||||
f"only {len(kept)} kept completions; need >= {cfg.min_train} to train. The steering/filter "
|
||||
@@ -146,13 +153,26 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
|
||||
# lam_round_pow=0 -> lam_eff==lam (constant, no behaviour change). >0 grows the barrier with round.
|
||||
rnd = len(hist_specs)
|
||||
lam_eff = cfg.lam * (1 + rnd) ** cfg.lam_round_pow
|
||||
if cfg.barrier_ref == "base":
|
||||
barrier_ref_specs = []
|
||||
ref_desc = "base"
|
||||
elif cfg.barrier_ref == "prev":
|
||||
barrier_ref_specs = hist_specs
|
||||
ref_desc = f"prev(r{rnd - 1})"
|
||||
elif cfg.barrier_ref == "last_good":
|
||||
assert ref_specs is not None and ref_round is not None, "last_good barrier requires explicit ref_specs/ref_round"
|
||||
barrier_ref_specs = ref_specs
|
||||
ref_desc = f"last_good(r{ref_round})"
|
||||
else:
|
||||
raise ValueError(f"unknown barrier_ref={cfg.barrier_ref!r}")
|
||||
|
||||
# streaming training table (token-efficient-logging): one row, columns self-decode below.
|
||||
logger.info(f"heal[{cfg.reg}] {len(train_kept)} train (+{len(val_kept)} val) x {cfg.epochs} ep = "
|
||||
f"{n_batches} batches (bs={cfg.train_bs}) -> {n_opt_steps} opt steps (grad_accum={cfg.grad_accum}); "
|
||||
f"lora r={cfg.lora_r} a={cfg.lora_alpha} on layers {cfg.layer_range}; "
|
||||
f"lr={cfg.lr} cosine warmup={cfg.warmup_ratio} betas={cfg.adam_betas}; "
|
||||
f"lam_eff={lam_eff:.3f} (lam {cfg.lam} x (1+round={rnd})^{cfg.lam_round_pow})")
|
||||
f"lam_eff={lam_eff:.3f} (lam {cfg.lam} x (1+round={rnd})^{cfg.lam_round_pow}); "
|
||||
f"barrier_ref={cfg.barrier_ref} ref_round={ref_round} ref={ref_desc} ref_specs={len(barrier_ref_specs)}")
|
||||
logger.info("SHOULD (val): train_nll falls each epoch (SFT fits the kept data); val_nll falls then "
|
||||
"flattens. If val_nll RISES while train falls -> overfit (fewer epochs / lower r). If "
|
||||
"NEITHER falls -> data is near-base (nothing to distil) or the optimiser is broken.")
|
||||
@@ -188,30 +208,30 @@ def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg:
|
||||
ids = BatchEncoding({k: v[valid] for k, v in ids.items()})
|
||||
masks = masks[valid] # [B', L-1]
|
||||
|
||||
# barrier reference logits (this round's adapter OFF). barrier_ref="base" bakes no
|
||||
# history -> ref = round-0 original (leash to base, fights accumulated trait); "prev"
|
||||
# bakes the history -> ref = previous-round student (trust region, penalises only this
|
||||
# round's new divergence so trait accumulates while each step stays coherent).
|
||||
# Gather completion positions BEFORE log_softmax. softmax is per-row, so selecting the
|
||||
# ~N_comp completion rows then normalising is identical to normalising all B*(L-1) rows
|
||||
# then selecting -- but it never materialises the full [B,L-1,V] log_softmax NOR its
|
||||
# autograd graph. On gemma's 262k vocab at bs>1 the full tensor is what OOM'd the KL step.
|
||||
flat_mask = masks.reshape(-1) # [B'*(L-1)] bool, completion positions
|
||||
tgt_c = ids.input_ids[:, 1:].reshape(-1)[flat_mask] # [N_comp]
|
||||
|
||||
# barrier reference (this round's adapter OFF). base=[], prev=hist_specs,
|
||||
# last_good=hist_specs[:last_good_n] from run.py.
|
||||
if cfg.reg in ("kl_fwd", "kl_rev"):
|
||||
ref_specs = hist_specs if cfg.barrier_ref == "prev" else []
|
||||
with torch.no_grad(), baked(model, ref_specs), lora(model, c=0.0):
|
||||
logp0 = model(**ids).logits[:, :-1].log_softmax(-1) # [B', L-1, V]
|
||||
with torch.no_grad(), baked(model, barrier_ref_specs), lora(model, c=0.0):
|
||||
logits0 = model(**ids).logits[:, :-1] # [B', L-1, V]
|
||||
V = logits0.shape[-1]
|
||||
logp0_c = logits0.reshape(-1, V)[flat_mask].log_softmax(-1) # [N_comp, V]
|
||||
|
||||
# student logits: history baked + this round's adapter live
|
||||
# student: history baked + this round's adapter live. Same mask-first trick.
|
||||
with baked(model, hist_specs), lora(model, c=1.0):
|
||||
logits = model(**ids).logits[:, :-1] # [B', L-1, V]
|
||||
logp = logits.log_softmax(-1)
|
||||
|
||||
# flatten batch × seq to masked completion tokens for loss and KL
|
||||
V = logp.shape[-1]
|
||||
logp_c = logp.reshape(-1, V)[masks.reshape(-1)] # [N_comp, V]
|
||||
tgt_c = ids.input_ids[:, 1:].reshape(-1)[masks.reshape(-1)] # [N_comp]
|
||||
logits = model(**ids).logits[:, :-1] # [B', L-1, V]
|
||||
V = logits.shape[-1]
|
||||
logp_c = logits.reshape(-1, V)[flat_mask].log_softmax(-1) # [N_comp, V]
|
||||
sft = F.nll_loss(logp_c, tgt_c)
|
||||
if cfg.reg == "kl_fwd":
|
||||
logp0_c = logp0.reshape(-1, V)[masks.reshape(-1)]
|
||||
div = _agg_kl(_kl_per_pos(logp0_c, logp_c), cfg.kl_agg)
|
||||
elif cfg.reg == "kl_rev":
|
||||
logp0_c = logp0.reshape(-1, V)[masks.reshape(-1)]
|
||||
div = _agg_kl(_kl_per_pos(logp_c, logp0_c), cfg.kl_agg)
|
||||
else:
|
||||
div = torch.zeros((), device=model.device) # nll
|
||||
|
||||
+41
-2
@@ -165,6 +165,11 @@ def gen_filter_walk(model, tok, v, cfg: RunConfig, hist_specs: list, rnd: int) -
|
||||
"root cause is upstream (adapter collapsed / filter wrong).\n" +
|
||||
"━"*55
|
||||
)
|
||||
if not any(r["ok"] for r in bisect_log):
|
||||
raise ValueError(
|
||||
f"walk-C no probe reached gen_pass_target={cfg.gen_pass_target:.2f} at r{rnd}; "
|
||||
f"kappa_min={cfg.gen_kappa_min:.3f} still produced collapsed or filtered data"
|
||||
)
|
||||
|
||||
# ── Phase 2: collect training data at settled kappa until n_keep is banked ──
|
||||
logger.info(f"\n{'─'*55}\nwalk-C collect phase: kappa={kappa:.3f}, need {cfg.n_keep} total.\n{'─'*55}")
|
||||
@@ -198,6 +203,9 @@ def gen_filter_walk(model, tok, v, cfg: RunConfig, hist_specs: list, rnd: int) -
|
||||
|
||||
def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
||||
hist_specs = [] # AdapterSpec per folded round (gated bake history)
|
||||
last_good_n = 0 # number of adapters in the ratcheted coherent reference
|
||||
last_good_round = -1
|
||||
last_good_coherence = None
|
||||
v0_flat = None # round-0 direction, for the Q3 cosine
|
||||
rounds = []
|
||||
gen_rounds = [] # per-round adapter gens (same prompts) -> outputs.html table
|
||||
@@ -207,6 +215,7 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
||||
# trait), not just coherence. One extra eval per run.
|
||||
logger.info(f"\n\n\n=== EVAL base [tinymfv classic] gpu {gpu_mem()} ===")
|
||||
base_m = evaluate_model(model, tok, cfg, log_sample=True) # one FULL eval gen (token-efficient-logging)
|
||||
last_good_coherence = base_m["coherence"]
|
||||
log_event(run_dir, stage="base", round=-1, **base_m) # persist so offline plot_run.py is self-contained
|
||||
stages = [{"round": "-", "stage": "base", "m": base_m}] # base -> steered -> healed, for table + trajectory plot
|
||||
# BASE demo column (round -1): the no-adapter, no-steering model on the SAME demo prompts, so the
|
||||
@@ -252,7 +261,16 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
||||
|
||||
# heal one round on top of the baked history, then fold
|
||||
logger.info(f"\n\n\n=== r{rnd} HEAL [{cfg.reg}] gpu {gpu_mem()} ===")
|
||||
lora, spec, heal_nll = heal_round(model, tok, kept, hist_specs, cfg)
|
||||
ref_specs = hist_specs[:last_good_n] if cfg.barrier_ref == "last_good" else None
|
||||
ref_round = last_good_round if cfg.barrier_ref == "last_good" else None
|
||||
ref_coherence = last_good_coherence if cfg.barrier_ref == "last_good" else None
|
||||
if cfg.barrier_ref == "last_good":
|
||||
logger.info(
|
||||
f"last_good reference for r{rnd}: ref_round={ref_round} ref_specs={len(ref_specs)} "
|
||||
f"ref_coherence={ref_coherence:.3f}; adoption gate = new_coh >= max(coh_floor={cfg.coh_floor:.3f}, "
|
||||
f"{cfg.ref_adopt_rel:.3f} * ref_coh = {cfg.ref_adopt_rel * ref_coherence:.3f})"
|
||||
)
|
||||
lora, spec, heal_nll = heal_round(model, tok, kept, hist_specs, cfg, ref_specs=ref_specs, ref_round=ref_round)
|
||||
lora.save(str(run_dir / "ckpt" / f"r{rnd}.safetensors"), extra_meta={"round": str(rnd), "reg": cfg.reg})
|
||||
hist_specs.append(spec)
|
||||
|
||||
@@ -291,12 +309,33 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
||||
logger.info(f"\n\n\n=== ADAPTER DEMO r{rnd} coh(p_ans_any)={m['coherence']:.3f} adapter_ppl={adapter_ppl:.0f} "
|
||||
f"(no steering; compare across rounds: change vs saturation) ===\n" + demo_lines)
|
||||
|
||||
ref_adopted = False
|
||||
if cfg.barrier_ref == "last_good":
|
||||
adopt_threshold = max(cfg.coh_floor, cfg.ref_adopt_rel * last_good_coherence)
|
||||
if m["coherence"] >= adopt_threshold:
|
||||
last_good_n = len(hist_specs)
|
||||
last_good_round = rnd
|
||||
last_good_coherence = m["coherence"]
|
||||
ref_adopted = True
|
||||
logger.info(
|
||||
f"last_good ADOPT r{rnd}: coherence={m['coherence']:.3f} >= "
|
||||
f"threshold={adopt_threshold:.3f}; next ref_specs={last_good_n}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"last_good HOLD at r{last_good_round}: r{rnd} coherence={m['coherence']:.3f} < "
|
||||
f"threshold={adopt_threshold:.3f}; next round still leashes to r{last_good_round}"
|
||||
)
|
||||
|
||||
vf = _flatten_v(v)
|
||||
v0_flat = vf if v0_flat is None else v0_flat
|
||||
cos_v0 = float(cosine_similarity(vf, v0_flat, dim=0))
|
||||
rec = {"round": rnd, **m, "cos_v0": cos_v0, "steered_ppl": steered_ppl,
|
||||
"adapter_ppl": adapter_ppl, "n_comps": n_comps, "n_kept": len(kept),
|
||||
"kappa": kappa, "heal_nll": heal_nll}
|
||||
"kappa": kappa, "heal_nll": heal_nll,
|
||||
"barrier_ref_round": ref_round, "barrier_ref_coherence": ref_coherence,
|
||||
"last_good_round": last_good_round if cfg.barrier_ref == "last_good" else None,
|
||||
"last_good_adopted": ref_adopted}
|
||||
rounds.append(rec)
|
||||
stages.append({"round": rnd, "stage": "steered", "m": m_steer})
|
||||
stages.append({"round": rnd, "stage": "healed", "m": m})
|
||||
|
||||
@@ -28,8 +28,11 @@ def _extract_prompts(cfg: RunConfig) -> list[str]:
|
||||
NOT domain dilemmas). A domain-narrow set overfits the direction to the format;
|
||||
diverse suffixes isolate the persona's general residual-stream shift."""
|
||||
import json
|
||||
import random
|
||||
from pathlib import Path
|
||||
suffixes = json.loads(Path(cfg.extract_data).read_text())
|
||||
rng = random.Random(cfg.seed)
|
||||
rng.shuffle(suffixes)
|
||||
return [s["suffix"] for s in suffixes[: cfg.n_extract_pairs]]
|
||||
|
||||
|
||||
@@ -44,7 +47,21 @@ def teacher_vec(model, tok, cfg: RunConfig):
|
||||
# in the system prompt (the persona prefix). ELSE the vector mixes in user-turn
|
||||
# differences. n_pairs ~256 diverse contexts (steering-lite reference), not 30 dilemmas.
|
||||
logger.info(f"teacher_vec: {len(pos)} contrastive pairs over diverse contexts, layers={layers}")
|
||||
logger.debug(f"--- POS[0] (trait) ---\n{pos[0]}\n--- NEG[0] (neutral) ---\n{neg[0]}")
|
||||
# Show completions for the first pair AND a seeded pick (avoids always landing on
|
||||
# the same weird first suffix). Seed primes which pair so it varies across runs.
|
||||
demo_indices = {0, (cfg.seed * 7) % len(pos)}
|
||||
for idx in sorted(demo_indices):
|
||||
pos_comp = _gen_one(model, tok, pos[idx], cfg, greedy=True)[:256]
|
||||
neg_comp = _gen_one(model, tok, neg[idx], cfg, greedy=True)[:256]
|
||||
logger.info(
|
||||
f"\n=== EXTRACT demo trace pair[{idx}] ===\n"
|
||||
f"POS prompt: {pos[idx][:200]}...\n"
|
||||
f"POS comp (64): {pos_comp[:64]}\n"
|
||||
f"NEG prompt: {neg[idx][:200]}...\n"
|
||||
f"NEG comp (64): {neg_comp[:64]}\n"
|
||||
f"--- full POS comp ---\n{pos_comp}\n"
|
||||
f"--- full NEG comp ---\n{neg_comp}"
|
||||
)
|
||||
|
||||
# RAW (unnormalised) mean-diff = the residual-stream shift the trait system
|
||||
# prompt induces (Subliminal Learning teacher vector). No iso-KL calibration:
|
||||
@@ -82,7 +99,7 @@ def _gen_one(model, tok, text, cfg, greedy: bool = False):
|
||||
|
||||
|
||||
def generate_steered(model, tok, v, cfg: RunConfig, alpha_scale: float = 1.0,
|
||||
max_gens: int | None = None) -> list[dict]:
|
||||
max_gens: int | None = None, rnd: int | None = None) -> list[dict]:
|
||||
"""Sweep cfg.alphas (raw-vector multiples); generate one completion per prompt x alpha.
|
||||
|
||||
The filter (Q0), not iso-KL, picks the usable C: low alpha is coherent, high
|
||||
@@ -93,7 +110,8 @@ def generate_steered(model, tok, v, cfg: RunConfig, alpha_scale: float = 1.0,
|
||||
"""
|
||||
out = []
|
||||
n_total = min(cfg.n_prompts * len(cfg.alphas), max_gens) if max_gens else cfg.n_prompts * len(cfg.alphas)
|
||||
logger.info(f"\n=== GEN steered [{n_total} = {cfg.n_prompts} prompts x {len(cfg.alphas)} alphas, "
|
||||
rtag = f"r{rnd} " if rnd is not None else ""
|
||||
logger.info(f"\n\n\n=== {rtag}GEN steered [{n_total} = {cfg.n_prompts} prompts x {len(cfg.alphas)} alphas, "
|
||||
f"kappa={alpha_scale:.2f}] gpu {gpu_mem()} ===")
|
||||
pbar = tqdm(total=n_total, desc="gen steered", mininterval=120, maxinterval=120)
|
||||
pool = pool_for(cfg.demo)
|
||||
|
||||
Reference in New Issue
Block a user