mirror of
https://github.com/wassname/Brukino_AntiPaSTO_Appetizer.git
synced 2026-06-27 16:58:47 +08:00
fix layers
This commit is contained in:
+57
-27
@@ -28,6 +28,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
|||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from einops import rearrange, reduce, repeat
|
||||||
|
|
||||||
# --- CONFIGURATION ---
|
# --- CONFIGURATION ---
|
||||||
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
|
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
|
||||||
@@ -72,7 +73,7 @@ def project_to_s_space(hidden_states, U, S):
|
|||||||
something like the coordinate space of learned modes of behaviors.
|
something like the coordinate space of learned modes of behaviors.
|
||||||
"""
|
"""
|
||||||
# Project: x_S = (x @ U)
|
# Project: x_S = (x @ U)
|
||||||
x_S = hidden_states.to(torch.float32) @ U
|
x_S = hidden_states.to(torch.float32) @ U # * sqrt(S) # optional scaling by singular values, but biases towards top pretrained modes
|
||||||
|
|
||||||
# Align signs: flip U (and x_S) so the maximum projection is positive
|
# Align signs: flip U (and x_S) so the maximum projection is positive
|
||||||
# This standardizes the direction of the modes
|
# This standardizes the direction of the modes
|
||||||
@@ -82,41 +83,57 @@ def project_to_s_space(hidden_states, U, S):
|
|||||||
|
|
||||||
x_S = x_S * signs
|
x_S = x_S * signs
|
||||||
|
|
||||||
# Scale by singular values
|
# No S-scaling: scaling by S makes top-10 dimensions dominate the norm,
|
||||||
x_S = x_S * S
|
# washing out persona differences that live in lower-S directions.
|
||||||
|
# If we want energy weighting, use sqrt(S) -- but flat is better for
|
||||||
|
# detecting persona-induced curvature changes.
|
||||||
|
# x_S = x_S * S # DON'T: kills persona signal in norms
|
||||||
|
|
||||||
return x_S
|
return x_S
|
||||||
|
|
||||||
def compute_curvature(hidden_states):
|
def compute_curvature(hidden_states):
|
||||||
'''
|
'''
|
||||||
Computes Frenet-Serret extrinsic curvature (kappa).
|
Frenet-Serret curvature for arbitrary (non-arc-length) parameterization.
|
||||||
kappa(t) = ||gamma''(t)|| / ||gamma'(t)||^3
|
|
||||||
|
gamma: [T, D] trajectory in D-dimensional space, parameterized by token index t.
|
||||||
|
dim=0 is the trajectory (we differentiate along this), dim=1 is coordinates.
|
||||||
|
|
||||||
|
For arc-length (||gamma'|| = 1): kappa = ||gamma''|| (= ||gamma''|| / ||gamma'||^3)
|
||||||
|
For arbitrary t: kappa = ||gamma' x gamma''|| / ||gamma'||^3
|
||||||
|
= sqrt(||gamma'||^2 ||gamma''||^2 - (gamma' . gamma'')^2) / ||gamma'||^3
|
||||||
|
|
||||||
|
The cross-product form subtracts tangential acceleration (speed changes),
|
||||||
|
leaving only normal acceleration (direction changes). Token index is NOT
|
||||||
|
arc-length -- speed varies a lot, and tangential acceleration is large and
|
||||||
|
persona-invariant. Without the correction, it dominates the numerator.
|
||||||
'''
|
'''
|
||||||
|
eps=1e-12
|
||||||
|
gamma = hidden_states.to(torch.float32) # [T, D]
|
||||||
|
d_gamma = torch.gradient(gamma, dim=0)[0] # [T, D]
|
||||||
|
dd_gamma = torch.gradient(d_gamma, dim=0)[0] # [T, D]
|
||||||
|
|
||||||
# TODO assert has grad
|
norm_d_sq = (d_gamma ** 2).sum(dim=1) # [T]
|
||||||
|
norm_dd_sq = (dd_gamma ** 2).sum(dim=1) # [T]
|
||||||
|
dot_d_dd = (d_gamma * dd_gamma).sum(dim=1) # [T]
|
||||||
|
|
||||||
# Cast to float32 to prevent float16 overflow when cubing
|
# ||gamma' x gamma''||^2 = ||gamma'||^2 ||gamma''||^2 - (gamma' . gamma'')^2
|
||||||
gamma = hidden_states.to(torch.float32)
|
cross_sq = (norm_d_sq * norm_dd_sq - dot_d_dd ** 2).clamp(min=eps)
|
||||||
d_gamma = torch.gradient(gamma, dim=0)[0]
|
norm_d_cubed = norm_d_sq * norm_d_sq.sqrt() # ||gamma'||^3
|
||||||
dd_gamma = torch.gradient(d_gamma, dim=0)[0]
|
|
||||||
|
|
||||||
norm_d_gamma = torch.norm(d_gamma, dim=1)
|
kappa = cross_sq.sqrt() / (norm_d_cubed + eps)
|
||||||
norm_dd_gamma = torch.norm(dd_gamma, dim=1)
|
|
||||||
|
|
||||||
kappa = norm_dd_gamma / (norm_d_gamma ** 3 + 1e-12)
|
|
||||||
return kappa
|
return kappa
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda", s_space_U=None, s_space_S=None):
|
def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda", s_space_U=None, s_space_S=None):
|
||||||
messages = [{"role": "user", "content": prompt_text}]
|
messages = [{"role": "system", "content": ""}, {"role": "user", "content": prompt_text}]
|
||||||
|
|
||||||
inputs = tokenizer.apply_chat_template(
|
inputs = tokenizer.apply_chat_template(
|
||||||
messages,
|
messages,
|
||||||
add_generation_prompt=True,
|
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
|
add_generation_prompt=True,
|
||||||
enable_thinking=True
|
enable_thinking=True
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
@@ -129,10 +146,13 @@ def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda", s_spac
|
|||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
max_new_tokens=n_think,
|
max_new_tokens=n_think,
|
||||||
do_sample=False,
|
do_sample=False,
|
||||||
pad_token_id=tokenizer.eos_token_id
|
pad_token_id=tokenizer.eos_token_id,
|
||||||
|
use_cache=True, # TODO use cache in the model( call to save compute
|
||||||
|
output_hidden_states=True,
|
||||||
|
return_dict_in_generate=True,
|
||||||
)
|
)
|
||||||
start_idx = prompt_ids.shape[1]
|
start_idx = prompt_ids.shape[1]
|
||||||
generated_ids = out[0, start_idx:]
|
generated_ids = out.sequences[0, start_idx:]
|
||||||
|
|
||||||
suffix_ids = tokenizer.encode("\\nI should answer now.\\nMy choice: **", add_special_tokens=False, return_tensors="pt").to(device)
|
suffix_ids = tokenizer.encode("\\nI should answer now.\\nMy choice: **", add_special_tokens=False, return_tensors="pt").to(device)
|
||||||
full_ids = torch.cat([prompt_ids, generated_ids.unsqueeze(0), suffix_ids], dim=1)
|
full_ids = torch.cat([prompt_ids, generated_ids.unsqueeze(0), suffix_ids], dim=1)
|
||||||
@@ -143,15 +163,15 @@ def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda", s_spac
|
|||||||
], dim=1)
|
], dim=1)
|
||||||
|
|
||||||
|
|
||||||
# TODO kv cache
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(
|
out_score = model(
|
||||||
full_ids,
|
full_ids,
|
||||||
attention_mask=full_attention_mask,
|
attention_mask=full_attention_mask,
|
||||||
output_hidden_states=True
|
output_hidden_states=True,
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logits = outputs.logits[0, -1, :]
|
logits = out_score.logits[0, -1, :]
|
||||||
log_probs = F.log_softmax(logits, dim=-1)
|
log_probs = F.log_softmax(logits, dim=-1)
|
||||||
|
|
||||||
# Simple parsing of Yes vs No variants
|
# Simple parsing of Yes vs No variants
|
||||||
@@ -168,13 +188,23 @@ def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda", s_spac
|
|||||||
|
|
||||||
|
|
||||||
# Note the residual stream doesn't change much, but it's suppressed in the last few layers (see https://github.com/wassname/eliciting_suppressed_knowledge & https://arxiv.org/abs/2402.10588) so it's normal to choose the 80% or 60% layer for steering and analysis. We hope most of the thinking has been done, but it hasn't yet been suppressed in preparation for output.
|
# Note the residual stream doesn't change much, but it's suppressed in the last few layers (see https://github.com/wassname/eliciting_suppressed_knowledge & https://arxiv.org/abs/2402.10588) so it's normal to choose the 80% or 60% layer for steering and analysis. We hope most of the thinking has been done, but it hasn't yet been suppressed in preparation for output.
|
||||||
target_layer = int(0.8 * (len(outputs.hidden_states) - 1))
|
|
||||||
print(f"Extracting hidden states from layer {target_layer}/{len(outputs.hidden_states) - 1} for curvature analysis. Shape of hidden states: {outputs.hidden_states[target_layer].shape}")
|
|
||||||
|
|
||||||
middle_layer_hiddens = outputs.hidden_states[target_layer][0]
|
n_layers = len(out.hidden_states[0])
|
||||||
cot_hiddens = middle_layer_hiddens[start_idx : start_idx + generated_ids.shape[0]]
|
target_layer = int(0.8 * n_layers)
|
||||||
|
|
||||||
trajectory = project_to_s_space(cot_hiddens, s_space_U, s_space_S)
|
|
||||||
|
# out.hidden_states comes out as
|
||||||
|
# tuple: (inputs, token1, token2)
|
||||||
|
# of which each is tuple: layer,
|
||||||
|
# containing [b t h]
|
||||||
|
|
||||||
|
hs = torch.concat([x[target_layer] for x in out.hidden_states], dim=1) # [batch_size, seq_len, hidden_dim]
|
||||||
|
print(f"Extracting hidden states from layer {target_layer}/{n_layers} for curvature analysis")
|
||||||
|
# hs = rearrange(out.hidden_states[0][target_layer], 'b t h -> b t h')
|
||||||
|
|
||||||
|
print(f"Shape of hidden states: {hs.shape} [b t h]")
|
||||||
|
|
||||||
|
trajectory = project_to_s_space(hs[0], s_space_U, s_space_S) # [B=1, seq_len, s_dim]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"logratio": (p_yes - p_no).item(),
|
"logratio": (p_yes - p_no).item(),
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ requires-python = ">=3.13"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"accelerate>=1.13.0",
|
"accelerate>=1.13.0",
|
||||||
"datasets>=4.8.4",
|
"datasets>=4.8.4",
|
||||||
|
"einops>=0.8.2",
|
||||||
"jupyter>=1.1.1",
|
"jupyter>=1.1.1",
|
||||||
"jupytext>=1.19.1",
|
"jupytext>=1.19.1",
|
||||||
"matplotlib>=3.10.8",
|
"matplotlib>=3.10.8",
|
||||||
|
|||||||
@@ -0,0 +1,5 @@
|
|||||||
|
import torch
|
||||||
|
x = torch.tensor([1.0, 2.0, 4.0, 7.0])
|
||||||
|
with torch.no_grad():
|
||||||
|
dx = torch.gradient(x)[0]
|
||||||
|
print(dx)
|
||||||
Reference in New Issue
Block a user