diff --git a/experiment.ipynb b/experiment.ipynb index b193028..15f395a 100644 --- a/experiment.ipynb +++ b/experiment.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "76397472", + "id": "b8a288c6", "metadata": {}, "source": [ "# Brukino's AntiPaSTO Appetizer: Guided CoT Eval & Frenet-Serret Curvature\n", @@ -20,7 +20,7 @@ { "cell_type": "code", "execution_count": null, - "id": "703c8e38", + "id": "e8c081e6", "metadata": {}, "outputs": [], "source": [ @@ -44,28 +44,35 @@ { "cell_type": "code", "execution_count": null, - "id": "e6f56583", + "id": "7b5b34e4", "metadata": {}, "outputs": [], "source": [ - "def get_collective_s_space_weight(model):\n", + "def get_s_space_projector(model, top_k=256):\n", " \"\"\"\n", " Gathers all weight matrices that write to the residual stream\n", " (o_proj from attention and down_proj from MLP) across all layers,\n", " and concatenates them to form a collective \"write\" transformation.\n", + " Then computes the SVD to extract the top_k modes.\n", + " Returns: U (hidden_size, top_k), S (top_k,)\n", " \"\"\"\n", " Ws = []\n", " for layer in model.model.layers:\n", " # In Qwen2, o_proj and down_proj weights are shape [hidden_size, in_features]\n", " # We want a combined matrix of shape [hidden_size, sum(in_features)]\n", - " # so that SVD gives U of shape [hidden_size, top_k].\n", - " # o_proj.weight is [hidden_size, num_heads * head_dim]\n", - " # down_proj.weight is [hidden_size, intermediate_size]\n", " Ws.append(layer.self_attn.o_proj.weight.detach().cpu())\n", " Ws.append(layer.mlp.down_proj.weight.detach().cpu())\n", - " return torch.cat(Ws, dim=1).to(model.device)\n", + " W = torch.cat(Ws, dim=1).to(model.device)\n", + " \n", + " # SVD on the collective weight matrix\n", + " U, S, _ = torch.linalg.svd(W.float(), full_matrices=False)\n", + " \n", + " # Crop to top_k modes\n", + " U = U[:, :top_k]\n", + " S = S[:top_k]\n", + " return U, S\n", "\n", - "def project_to_s_space(hidden_states, W, top_k=256):\n", + "def project_to_s_space(hidden_states, U, S):\n", " \"\"\"\n", " Projects the residual stream into the 'super' S-space of all residual writers.\n", " \n", @@ -75,16 +82,7 @@ " hidden states from the residual stream, and the U from all residual writers, \n", " we can project the residual stream into S-space, which can be thought of as \n", " something like the coordinate space of learned modes of behaviors.\n", - " \n", - " W: The concatenated weight matrix of all residual writers, shape (sum(in_features), hidden_size)\n", " \"\"\"\n", - " # SVD on the collective weight matrix\n", - " U, S, _ = torch.linalg.svd(W.float(), full_matrices=False)\n", - " \n", - " # Crop to top_k modes (there will be a lot of overlap/redundancy)\n", - " U = U[:, :top_k]\n", - " S = S[:top_k]\n", - " \n", " # Project: x_S = (x @ U)\n", " x_S = hidden_states.to(torch.float32) @ U\n", " \n", @@ -125,11 +123,11 @@ { "cell_type": "code", "execution_count": null, - "id": "7cbb2a21", + "id": "ab5130a9", "metadata": {}, "outputs": [], "source": [ - "def guided_eval(model, tokenizer, prompt_text, n_think=32, device=\"cuda\", s_space_weight=None):\n", + "def guided_eval(model, tokenizer, prompt_text, n_think=32, device=\"cuda\", s_space_U=None, s_space_S=None):\n", " messages = [{\"role\": \"user\", \"content\": prompt_text}]\n", " \n", " inputs = tokenizer.apply_chat_template(\n", @@ -190,8 +188,8 @@ " start_idx = prompt_ids.shape[1]\n", " cot_hiddens = final_layer_hiddens[start_idx : start_idx + generated_ids.shape[0]]\n", " \n", - " if s_space_weight is not None:\n", - " trajectory = project_to_s_space(cot_hiddens, s_space_weight)\n", + " if s_space_U is not None and s_space_S is not None:\n", + " trajectory = project_to_s_space(cot_hiddens, s_space_U, s_space_S)\n", " else:\n", " trajectory = cot_hiddens\n", " \n", @@ -207,7 +205,7 @@ { "cell_type": "code", "execution_count": null, - "id": "812de2b3", + "id": "96c84bd3", "metadata": { "lines_to_next_cell": 2 }, @@ -241,11 +239,11 @@ "results = {}\n", "\n", "# Project using the collective residual stream writers transformation\n", - "s_space_W = get_collective_s_space_weight(model) \n", + "s_space_U, s_space_S = get_s_space_projector(model)\n", "\n", "for p_key, p_prefix in PERSONAS.items():\n", " print(f\"\\n--- Running: {p_key} ---\")\n", - " res = guided_eval(model, tokenizer, p_prefix + prompt_base, n_think=N_THINK_TOKENS, device=DEVICE, s_space_weight=s_space_W)\n", + " res = guided_eval(model, tokenizer, p_prefix + prompt_base, n_think=N_THINK_TOKENS, device=DEVICE, s_space_U=s_space_U, s_space_S=s_space_S)\n", " results[p_key] = res\n", " print(f\"Logratio (Yes/No): {res['logratio']:.3f}\")\n", " print(f\"Prompt:\\n```md\\n{res['prompt']}```\")\n", diff --git a/experiment.py b/experiment.py index efa362b..e09af80 100644 --- a/experiment.py +++ b/experiment.py @@ -39,24 +39,31 @@ NUM_EXAMPLES = 3 # %% -def get_collective_s_space_weight(model): +def get_s_space_projector(model, top_k=256): """ Gathers all weight matrices that write to the residual stream (o_proj from attention and down_proj from MLP) across all layers, and concatenates them to form a collective "write" transformation. + Then computes the SVD to extract the top_k modes. + Returns: U (hidden_size, top_k), S (top_k,) """ Ws = [] for layer in model.model.layers: # In Qwen2, o_proj and down_proj weights are shape [hidden_size, in_features] # We want a combined matrix of shape [hidden_size, sum(in_features)] - # so that SVD gives U of shape [hidden_size, top_k]. - # o_proj.weight is [hidden_size, num_heads * head_dim] - # down_proj.weight is [hidden_size, intermediate_size] Ws.append(layer.self_attn.o_proj.weight.detach().cpu()) Ws.append(layer.mlp.down_proj.weight.detach().cpu()) - return torch.cat(Ws, dim=1).to(model.device) + W = torch.cat(Ws, dim=1).to(model.device) + + # SVD on the collective weight matrix + U, S, _ = torch.linalg.svd(W.float(), full_matrices=False) + + # Crop to top_k modes + U = U[:, :top_k] + S = S[:top_k] + return U, S -def project_to_s_space(hidden_states, W, top_k=256): +def project_to_s_space(hidden_states, U, S): """ Projects the residual stream into the 'super' S-space of all residual writers. @@ -66,16 +73,7 @@ def project_to_s_space(hidden_states, W, top_k=256): hidden states from the residual stream, and the U from all residual writers, we can project the residual stream into S-space, which can be thought of as something like the coordinate space of learned modes of behaviors. - - W: The concatenated weight matrix of all residual writers, shape (sum(in_features), hidden_size) """ - # SVD on the collective weight matrix - U, S, _ = torch.linalg.svd(W.float(), full_matrices=False) - - # Crop to top_k modes (there will be a lot of overlap/redundancy) - U = U[:, :top_k] - S = S[:top_k] - # Project: x_S = (x @ U) x_S = hidden_states.to(torch.float32) @ U @@ -114,7 +112,7 @@ def compute_curvature(hidden_states): # %% -def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda", s_space_weight=None): +def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda", s_space_U=None, s_space_S=None): messages = [{"role": "user", "content": prompt_text}] inputs = tokenizer.apply_chat_template( @@ -175,8 +173,8 @@ def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda", s_spac start_idx = prompt_ids.shape[1] cot_hiddens = final_layer_hiddens[start_idx : start_idx + generated_ids.shape[0]] - if s_space_weight is not None: - trajectory = project_to_s_space(cot_hiddens, s_space_weight) + if s_space_U is not None and s_space_S is not None: + trajectory = project_to_s_space(cot_hiddens, s_space_U, s_space_S) else: trajectory = cot_hiddens @@ -218,11 +216,11 @@ plt.figure(figsize=(10, 6)) results = {} # Project using the collective residual stream writers transformation -s_space_W = get_collective_s_space_weight(model) +s_space_U, s_space_S = get_s_space_projector(model) for p_key, p_prefix in PERSONAS.items(): print(f"\n--- Running: {p_key} ---") - res = guided_eval(model, tokenizer, p_prefix + prompt_base, n_think=N_THINK_TOKENS, device=DEVICE, s_space_weight=s_space_W) + res = guided_eval(model, tokenizer, p_prefix + prompt_base, n_think=N_THINK_TOKENS, device=DEVICE, s_space_U=s_space_U, s_space_S=s_space_S) results[p_key] = res print(f"Logratio (Yes/No): {res['logratio']:.3f}") print(f"Prompt:\n```md\n{res['prompt']}```")