fix layers

This commit is contained in:
wassname
2026-04-10 11:12:45 +08:00
parent 9c317af8eb
commit 60cd056320
3 changed files with 64 additions and 28 deletions
+58 -28
View File
@@ -28,6 +28,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm.auto import tqdm from tqdm.auto import tqdm
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
from einops import rearrange, reduce, repeat
# --- CONFIGURATION --- # --- CONFIGURATION ---
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct" MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
@@ -72,7 +73,7 @@ def project_to_s_space(hidden_states, U, S):
something like the coordinate space of learned modes of behaviors. something like the coordinate space of learned modes of behaviors.
""" """
# Project: x_S = (x @ U) # Project: x_S = (x @ U)
x_S = hidden_states.to(torch.float32) @ U x_S = hidden_states.to(torch.float32) @ U # * sqrt(S) # optional scaling by singular values, but biases towards top pretrained modes
# Align signs: flip U (and x_S) so the maximum projection is positive # Align signs: flip U (and x_S) so the maximum projection is positive
# This standardizes the direction of the modes # This standardizes the direction of the modes
@@ -82,41 +83,57 @@ def project_to_s_space(hidden_states, U, S):
x_S = x_S * signs x_S = x_S * signs
# Scale by singular values # No S-scaling: scaling by S makes top-10 dimensions dominate the norm,
x_S = x_S * S # washing out persona differences that live in lower-S directions.
# If we want energy weighting, use sqrt(S) -- but flat is better for
# detecting persona-induced curvature changes.
# x_S = x_S * S # DON'T: kills persona signal in norms
return x_S return x_S
def compute_curvature(hidden_states: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    '''
    Frenet-Serret extrinsic curvature (kappa) of a discretely sampled trajectory.

    The trajectory gamma is parameterized by sample (token) index t, which is
    NOT arc length, so the general formula is used:

        kappa = ||gamma' x gamma''|| / ||gamma'||^3
              = sqrt(||gamma'||^2 ||gamma''||^2 - (gamma' . gamma'')^2) / ||gamma'||^3

    (For a unit-speed / arc-length parameterization this reduces to
    kappa = ||gamma''||.)

    The cross-product form removes tangential acceleration (speed changes) and
    keeps only normal acceleration (direction changes). Using ||gamma''|| alone
    in the numerator would let speed changes dominate the result.

    Args:
        hidden_states: [T, D] tensor; dim 0 is the position along the
            trajectory (differentiated), dim 1 holds the coordinates.
            torch.gradient needs at least two samples along dim 0.
        eps: small constant added to the denominator so stationary points
            (||gamma'|| == 0) do not divide by zero.

    Returns:
        float32 tensor of shape [T] with the curvature at every sample. The
        first and last two entries rely on one-sided differences and are less
        accurate than interior ones.
    '''
    # Work in float32: cubing norms overflows easily in float16.
    gamma = hidden_states.to(torch.float32)  # [T, D]

    # Finite-difference first and second derivatives w.r.t. sample index.
    d_gamma = torch.gradient(gamma, dim=0)[0]  # [T, D]
    dd_gamma = torch.gradient(d_gamma, dim=0)[0]  # [T, D]

    norm_d_sq = (d_gamma ** 2).sum(dim=1)  # [T]
    norm_dd_sq = (dd_gamma ** 2).sum(dim=1)  # [T]
    dot_d_dd = (d_gamma * dd_gamma).sum(dim=1)  # [T]

    # Lagrange identity:
    # ||gamma' x gamma''||^2 = ||gamma'||^2 ||gamma''||^2 - (gamma' . gamma'')^2
    # Clamp at 0 (not eps): rounding can push the difference slightly negative,
    # but flooring it at eps would inject a spurious sqrt(eps) numerator that,
    # divided by a tiny ||gamma'||^3, reports huge curvature on straight or
    # stationary segments.
    cross_sq = (norm_d_sq * norm_dd_sq - dot_d_dd ** 2).clamp(min=0.0)

    norm_d_cubed = norm_d_sq * norm_d_sq.sqrt()  # ||gamma'||^3
    kappa = cross_sq.sqrt() / (norm_d_cubed + eps)
    return kappa
# %% # %%
def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda", s_space_U=None, s_space_S=None): def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda", s_space_U=None, s_space_S=None):
messages = [{"role": "user", "content": prompt_text}] messages = [{"role": "system", "content": ""}, {"role": "user", "content": prompt_text}]
inputs = tokenizer.apply_chat_template( inputs = tokenizer.apply_chat_template(
messages, messages,
add_generation_prompt=True,
return_tensors="pt", return_tensors="pt",
return_dict=True, return_dict=True,
add_generation_prompt=True,
enable_thinking=True enable_thinking=True
).to(device) ).to(device)
@@ -129,10 +146,13 @@ def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda", s_spac
attention_mask=attention_mask, attention_mask=attention_mask,
max_new_tokens=n_think, max_new_tokens=n_think,
do_sample=False, do_sample=False,
pad_token_id=tokenizer.eos_token_id pad_token_id=tokenizer.eos_token_id,
use_cache=True, # TODO use cache in the model( call to save compute
output_hidden_states=True,
return_dict_in_generate=True,
) )
start_idx = prompt_ids.shape[1] start_idx = prompt_ids.shape[1]
generated_ids = out[0, start_idx:] generated_ids = out.sequences[0, start_idx:]
suffix_ids = tokenizer.encode("\\nI should answer now.\\nMy choice: **", add_special_tokens=False, return_tensors="pt").to(device) suffix_ids = tokenizer.encode("\\nI should answer now.\\nMy choice: **", add_special_tokens=False, return_tensors="pt").to(device)
full_ids = torch.cat([prompt_ids, generated_ids.unsqueeze(0), suffix_ids], dim=1) full_ids = torch.cat([prompt_ids, generated_ids.unsqueeze(0), suffix_ids], dim=1)
@@ -143,15 +163,15 @@ def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda", s_spac
], dim=1) ], dim=1)
# TODO kv cache
with torch.no_grad(): with torch.no_grad():
outputs = model( out_score = model(
full_ids, full_ids,
attention_mask=full_attention_mask, attention_mask=full_attention_mask,
output_hidden_states=True output_hidden_states=True,
) )
logits = outputs.logits[0, -1, :] logits = out_score.logits[0, -1, :]
log_probs = F.log_softmax(logits, dim=-1) log_probs = F.log_softmax(logits, dim=-1)
# Simple parsing of Yes vs No variants # Simple parsing of Yes vs No variants
@@ -168,13 +188,23 @@ def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda", s_spac
# Note the residual stream doesn't change much, but it's suppressed in the last few layers (see https://github.com/wassname/eliciting_suppressed_knowledge & https://arxiv.org/abs/2402.10588) so it's normal to choose the 80% or 60% layer for steering and analysis. We hope most of the thinking has been done, but it hasn't yet been suppressed in preparation for output. # Note the residual stream doesn't change much, but it's suppressed in the last few layers (see https://github.com/wassname/eliciting_suppressed_knowledge & https://arxiv.org/abs/2402.10588) so it's normal to choose the 80% or 60% layer for steering and analysis. We hope most of the thinking has been done, but it hasn't yet been suppressed in preparation for output.
target_layer = int(0.8 * (len(outputs.hidden_states) - 1))
print(f"Extracting hidden states from layer {target_layer}/{len(outputs.hidden_states) - 1} for curvature analysis. Shape of hidden states: {outputs.hidden_states[target_layer].shape}") n_layers = len(out.hidden_states[0])
target_layer = int(0.8 * n_layers)
# out.hidden_states comes out as
# tuple: (inputs, token1, token2)
# of which each is tuple: layer,
# containing [b t h]
hs = torch.concat([x[target_layer] for x in out.hidden_states], dim=1) # [batch_size, seq_len, hidden_dim]
print(f"Extracting hidden states from layer {target_layer}/{n_layers} for curvature analysis")
# hs = rearrange(out.hidden_states[0][target_layer], 'b t h -> b t h')
print(f"Shape of hidden states: {hs.shape} [b t h]")
middle_layer_hiddens = outputs.hidden_states[target_layer][0] trajectory = project_to_s_space(hs[0], s_space_U, s_space_S) # [B=1, seq_len, s_dim]
cot_hiddens = middle_layer_hiddens[start_idx : start_idx + generated_ids.shape[0]]
trajectory = project_to_s_space(cot_hiddens, s_space_U, s_space_S)
return { return {
"logratio": (p_yes - p_no).item(), "logratio": (p_yes - p_no).item(),
+1
View File
@@ -7,6 +7,7 @@ requires-python = ">=3.13"
dependencies = [ dependencies = [
"accelerate>=1.13.0", "accelerate>=1.13.0",
"datasets>=4.8.4", "datasets>=4.8.4",
"einops>=0.8.2",
"jupyter>=1.1.1", "jupyter>=1.1.1",
"jupytext>=1.19.1", "jupytext>=1.19.1",
"matplotlib>=3.10.8", "matplotlib>=3.10.8",
+5
View File
@@ -0,0 +1,5 @@
import torch
# Scratch check of torch.gradient on a 1-D tensor with unit spacing.
x = torch.tensor([1.0, 2.0, 4.0, 7.0])
with torch.no_grad():
    # torch.gradient returns a tuple with one tensor per differentiated dim;
    # a 1-D input yields a single entry, hence the [0].
    dx = torch.gradient(x)[0]
# Central differences in the interior and one-sided differences at the two
# ends give tensor([1.0, 1.5, 2.5, 3.0]) for this input.
print(dx)