def compute_curvature(hidden_states):
    """
    Frenet-Serret curvature for an arbitrary (non-arc-length) parameterization.

    The trajectory gamma is indexed by token position t, which is NOT arc
    length: speed varies from token to token. The plain ratio
    ||gamma''|| / ||gamma'||^3 is only valid for arc-length parameterization,
    so the general form is used instead:

        kappa = ||gamma' x gamma''|| / ||gamma'||^3
              = sqrt(||gamma'||^2 ||gamma''||^2 - (gamma' . gamma'')^2)
                / ||gamma'||^3

    The cross-product form (written via Lagrange's identity so it works in
    any dimension D) removes tangential acceleration (speed changes) and
    keeps only normal acceleration (direction changes).

    Args:
        hidden_states: tensor of shape [T, D]. dim=0 is the trajectory axis
            (differentiated along), dim=1 holds the coordinates.

    Returns:
        float32 tensor of shape [T] with the curvature at every step.
        Straight segments and stationary points (zero speed) yield exactly 0.
        Trajectories with fewer than two points have no finite differences
        and yield all zeros.
    """
    eps = 1e-12
    # float32 so that cubing the speed cannot overflow half-precision inputs.
    gamma = hidden_states.to(torch.float32)  # [T, D]

    # torch.gradient needs at least two samples along the differentiated axis;
    # a 0- or 1-point trajectory has no direction change to measure.
    n_steps = gamma.shape[0]
    if n_steps < 2:
        return gamma.new_zeros(n_steps)

    d_gamma = torch.gradient(gamma, dim=0)[0]  # [T, D] velocity
    dd_gamma = torch.gradient(d_gamma, dim=0)[0]  # [T, D] acceleration

    norm_d_sq = (d_gamma ** 2).sum(dim=1)  # [T]
    norm_dd_sq = (dd_gamma ** 2).sum(dim=1)  # [T]
    dot_d_dd = (d_gamma * dd_gamma).sum(dim=1)  # [T]

    # ||gamma' x gamma''||^2 = ||gamma'||^2 ||gamma''||^2 - (gamma' . gamma'')^2
    # Clamp at 0 only to absorb negative round-off. Clamping at eps would put
    # a floor of sqrt(eps) = 1e-6 in the numerator, which reports non-zero
    # curvature on straight lines and ~1e6 wherever the speed is zero.
    cross_sq = (norm_d_sq * norm_dd_sq - dot_d_dd ** 2).clamp(min=0.0)
    norm_d_cubed = norm_d_sq * norm_d_sq.sqrt()  # ||gamma'||^3

    # eps guards only the division; with a zero numerator the result stays 0.
    kappa = cross_sq.sqrt() / (norm_d_cubed + eps)
    return kappa
"""Sanity check of ``torch.gradient`` on a small 1-D tensor.

With the default unit spacing and ``edge_order=1`` the result is a one-sided
difference at both ends and a central difference in the interior.
"""
import torch


def gradient_demo():
    """Return ``torch.gradient`` of the sample ``[1, 2, 4, 7]``."""
    x = torch.tensor([1.0, 2.0, 4.0, 7.0])
    with torch.no_grad():
        return torch.gradient(x)[0]


def test_gradient_matches_finite_differences():
    """Pin the expected values so pytest checks them instead of just printing."""
    # ends: 2 - 1 and 7 - 4; interior: (4 - 1) / 2 and (7 - 2) / 2
    expected = torch.tensor([1.0, 1.5, 2.5, 3.0])
    assert torch.allclose(gradient_demo(), expected)


if __name__ == "__main__":
    # Running the file directly still prints the result, as the scratch
    # script did.
    print(gradient_demo())