fix layers

This commit is contained in:
wassname
2026-04-10 11:12:45 +08:00
parent 9c317af8eb
commit 60cd056320
3 changed files with 64 additions and 28 deletions
+58 -28
View File
@@ -28,6 +28,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import numpy as np
from einops import rearrange, reduce, repeat
# --- CONFIGURATION ---
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
@@ -72,7 +73,7 @@ def project_to_s_space(hidden_states, U, S):
something like the coordinate space of learned modes of behaviors.
"""
# Project: x_S = (x @ U)
x_S = hidden_states.to(torch.float32) @ U
x_S = hidden_states.to(torch.float32) @ U # * sqrt(S) # optional scaling by singular values, but biases towards top pretrained modes
# Align signs: flip U (and x_S) so the maximum projection is positive
# This standardizes the direction of the modes
@@ -82,41 +83,57 @@ def project_to_s_space(hidden_states, U, S):
x_S = x_S * signs
# Scale by singular values
x_S = x_S * S
# No S-scaling: scaling by S makes top-10 dimensions dominate the norm,
# washing out persona differences that live in lower-S directions.
# If we want energy weighting, use sqrt(S) -- but flat is better for
# detecting persona-induced curvature changes.
# x_S = x_S * S # DON'T: kills persona signal in norms
return x_S
def compute_curvature(hidden_states: torch.Tensor) -> torch.Tensor:
    """
    Compute extrinsic (Frenet-Serret) curvature kappa along a trajectory.

    The trajectory gamma is parameterized by token index t, which is NOT
    arc-length, so the general-parameterization formula is used:

        arc-length only:  kappa = ||gamma''|| / ||gamma'||^3
        arbitrary t:      kappa = ||gamma' x gamma''|| / ||gamma'||^3
                                = sqrt(||gamma'||^2 ||gamma''||^2
                                       - (gamma' . gamma'')^2) / ||gamma'||^3

    The cross-product form subtracts tangential acceleration (speed changes),
    leaving only normal acceleration (direction changes). Speed varies a lot
    along token index, and tangential acceleration is large and
    persona-invariant; without the correction it dominates the numerator.

    Args:
        hidden_states: [T, D] tensor. dim=0 is the trajectory axis (the one
            differentiated along), dim=1 holds the coordinates.

    Returns:
        [T] float32 tensor with the curvature at each trajectory point.
        Derivatives are finite differences from ``torch.gradient``.
    """
    # TODO assert has grad
    eps = 1e-12

    # Cast to float32 to prevent float16 overflow when cubing the speed.
    gamma = hidden_states.to(torch.float32)  # [T, D]
    d_gamma = torch.gradient(gamma, dim=0)[0]  # [T, D] first derivative
    dd_gamma = torch.gradient(d_gamma, dim=0)[0]  # [T, D] second derivative

    norm_d_sq = (d_gamma ** 2).sum(dim=1)  # [T] ||gamma'||^2
    norm_dd_sq = (dd_gamma ** 2).sum(dim=1)  # [T] ||gamma''||^2
    dot_d_dd = (d_gamma * dd_gamma).sum(dim=1)  # [T] gamma' . gamma''

    # Lagrange identity:
    # ||gamma' x gamma''||^2 = ||gamma'||^2 ||gamma''||^2 - (gamma' . gamma'')^2
    # Clamping at eps keeps the sqrt argument positive when rounding pushes
    # the difference to zero or below; it also floors the numerator at
    # sqrt(eps) for perfectly straight segments.
    cross_sq = (norm_d_sq * norm_dd_sq - dot_d_dd ** 2).clamp(min=eps)
    norm_d_cubed = norm_d_sq * norm_d_sq.sqrt()  # ||gamma'||^3

    # eps in the denominator guards against division by zero at zero speed.
    return cross_sq.sqrt() / (norm_d_cubed + eps)
# %%
def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda", s_space_U=None, s_space_S=None):
messages = [{"role": "user", "content": prompt_text}]
messages = [{"role": "system", "content": ""}, {"role": "user", "content": prompt_text}]
inputs = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
return_tensors="pt",
return_dict=True,
add_generation_prompt=True,
enable_thinking=True
).to(device)
@@ -129,10 +146,13 @@ def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda", s_spac
attention_mask=attention_mask,
max_new_tokens=n_think,
do_sample=False,
pad_token_id=tokenizer.eos_token_id
pad_token_id=tokenizer.eos_token_id,
use_cache=True, # TODO use cache in the model( call to save compute
output_hidden_states=True,
return_dict_in_generate=True,
)
start_idx = prompt_ids.shape[1]
generated_ids = out[0, start_idx:]
generated_ids = out.sequences[0, start_idx:]
suffix_ids = tokenizer.encode("\\nI should answer now.\\nMy choice: **", add_special_tokens=False, return_tensors="pt").to(device)
full_ids = torch.cat([prompt_ids, generated_ids.unsqueeze(0), suffix_ids], dim=1)
@@ -143,15 +163,15 @@ def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda", s_spac
], dim=1)
# TODO kv cache
with torch.no_grad():
outputs = model(
out_score = model(
full_ids,
attention_mask=full_attention_mask,
output_hidden_states=True
output_hidden_states=True,
)
logits = outputs.logits[0, -1, :]
logits = out_score.logits[0, -1, :]
log_probs = F.log_softmax(logits, dim=-1)
# Simple parsing of Yes vs No variants
@@ -168,13 +188,23 @@ def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda", s_spac
# Note: the residual stream doesn't change much between layers, but information is suppressed in the last few layers (see https://github.com/wassname/eliciting_suppressed_knowledge & https://arxiv.org/abs/2402.10588), so it's normal to choose the 80% or 60% layer for steering and analysis. We hope most of the thinking has been done by then, but that it hasn't yet been suppressed in preparation for output.
target_layer = int(0.8 * (len(outputs.hidden_states) - 1))
print(f"Extracting hidden states from layer {target_layer}/{len(outputs.hidden_states) - 1} for curvature analysis. Shape of hidden states: {outputs.hidden_states[target_layer].shape}")
n_layers = len(out.hidden_states[0])
target_layer = int(0.8 * n_layers)
# out.hidden_states comes out as
# tuple: (inputs, token1, token2)
# of which each is tuple: layer,
# containing [b t h]
hs = torch.concat([x[target_layer] for x in out.hidden_states], dim=1) # [batch_size, seq_len, hidden_dim]
print(f"Extracting hidden states from layer {target_layer}/{n_layers} for curvature analysis")
# hs = rearrange(out.hidden_states[0][target_layer], 'b t h -> b t h')
print(f"Shape of hidden states: {hs.shape} [b t h]")
middle_layer_hiddens = outputs.hidden_states[target_layer][0]
cot_hiddens = middle_layer_hiddens[start_idx : start_idx + generated_ids.shape[0]]
trajectory = project_to_s_space(cot_hiddens, s_space_U, s_space_S)
trajectory = project_to_s_space(hs[0], s_space_U, s_space_S) # [B=1, seq_len, s_dim]
return {
"logratio": (p_yes - p_no).item(),