From a1a8648865eab4b325cdb4e6e538f07ce3e3e8fe Mon Sep 17 00:00:00 2001 From: wassname <1103714+wassname@users.noreply.github.com> Date: Fri, 10 Apr 2026 09:38:22 +0800 Subject: [PATCH] Implement super S-space projection across all residual writers --- experiment.ipynb | 79 ++++++++++++++++++++++++++++++++++++++++++------ experiment.py | 69 +++++++++++++++++++++++++++++++++++++++--- 2 files changed, 135 insertions(+), 13 deletions(-) diff --git a/experiment.ipynb b/experiment.ipynb index a670630..b193028 100644 --- a/experiment.ipynb +++ b/experiment.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "bafde1f1", + "id": "76397472", "metadata": {}, "source": [ "# Brukino's AntiPaSTO Appetizer: Guided CoT Eval & Frenet-Serret Curvature\n", @@ -20,7 +20,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a30577fd", + "id": "703c8e38", "metadata": {}, "outputs": [], "source": [ @@ -44,10 +44,63 @@ { "cell_type": "code", "execution_count": null, - "id": "a3b31f38", + "id": "e6f56583", "metadata": {}, "outputs": [], "source": [ + "def get_collective_s_space_weight(model):\n", + " \"\"\"\n", + " Gathers all weight matrices that write to the residual stream\n", + " (o_proj from attention and down_proj from MLP) across all layers,\n", + " and concatenates them to form a collective \"write\" transformation.\n", + " \"\"\"\n", + " Ws = []\n", + " for layer in model.model.layers:\n", + " # In Qwen2, o_proj and down_proj weights are shape [hidden_size, in_features]\n", + " # We want a combined matrix of shape [hidden_size, sum(in_features)]\n", + " # so that SVD gives U of shape [hidden_size, top_k].\n", + " # o_proj.weight is [hidden_size, num_heads * head_dim]\n", + " # down_proj.weight is [hidden_size, intermediate_size]\n", + " Ws.append(layer.self_attn.o_proj.weight.detach().cpu())\n", + " Ws.append(layer.mlp.down_proj.weight.detach().cpu())\n", + " return torch.cat(Ws, dim=1).to(model.device)\n", + "\n", + "def project_to_s_space(hidden_states, 
W, top_k=256):\n", + " \"\"\"\n", + " Projects the residual stream into the 'super' S-space of all residual writers.\n", + " \n", + " Explanation: The residual stream doesn't change much, but gets suppressed in the \n", + " last 3-10% of layers. Since the residual stream interacts with all modules, \n", + " we get the 'super' S-space of all residual stream writers. By getting the \n", + " hidden states from the residual stream, and the U from all residual writers, \n", + " we can project the residual stream into S-space, which can be thought of as \n", + " something like the coordinate space of learned modes of behaviors.\n", + " \n", + " W: The concatenated weight matrix of all residual writers, shape (hidden_size, sum(in_features))\n", + " \"\"\"\n", + " # SVD on the collective weight matrix\n", + " U, S, _ = torch.linalg.svd(W.float(), full_matrices=False)\n", + " \n", + " # Crop to top_k modes (there will be a lot of overlap/redundancy)\n", + " U = U[:, :top_k]\n", + " S = S[:top_k]\n", + " \n", + " # Project: x_S = (x @ U)\n", + " x_S = hidden_states.to(torch.float32) @ U\n", + " \n", + " # Align signs: flip U (and x_S) so the maximum projection is positive\n", + " # This standardizes the direction of the modes\n", + " signs = torch.sign(x_S.max(dim=0).values + x_S.min(dim=0).values) \n", + " # If the max absolute value was negative, signs will be -1, else 1\n", + " signs[signs == 0] = 1.0 # prevent 0 multiplication\n", + " \n", + " x_S = x_S * signs\n", + " \n", + " # Scale by singular values\n", + " x_S = x_S * S\n", + " \n", + " return x_S\n", + "\n", "def compute_curvature(hidden_states):\n", " '''\n", " Computes Frenet-Serret extrinsic curvature (kappa).\n", @@ -72,11 +125,11 @@ { "cell_type": "code", "execution_count": null, - "id": "675b1b52", + "id": "7cbb2a21", "metadata": {}, "outputs": [], "source": [ - "def guided_eval(model, tokenizer, prompt_text, n_think=32, device=\"cuda\"):\n", + "def guided_eval(model, tokenizer, prompt_text, n_think=32, 
device=\"cuda\", s_space_weight=None):\n", " messages = [{\"role\": \"user\", \"content\": prompt_text}]\n", " \n", " inputs = tokenizer.apply_chat_template(\n", @@ -137,9 +190,14 @@ " start_idx = prompt_ids.shape[1]\n", " cot_hiddens = final_layer_hiddens[start_idx : start_idx + generated_ids.shape[0]]\n", " \n", + " if s_space_weight is not None:\n", + " trajectory = project_to_s_space(cot_hiddens, s_space_weight)\n", + " else:\n", + " trajectory = cot_hiddens\n", + " \n", " return {\n", " \"logratio\": (p_yes - p_no).item(),\n", - " \"kappa_trajectory\": compute_curvature(cot_hiddens).cpu().numpy(),\n", + " \"kappa_trajectory\": compute_curvature(trajectory).cpu().numpy(),\n", " \"prompt\": tokenizer.decode(prompt_ids[0], skip_special_tokens=False),\n", " \"generated_text\": tokenizer.decode(generated_ids, skip_special_tokens=False)\n", " }\n", @@ -149,7 +207,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9e2e698d", + "id": "812de2b3", "metadata": { "lines_to_next_cell": 2 }, @@ -182,9 +240,12 @@ "plt.figure(figsize=(10, 6))\n", "results = {}\n", "\n", + "# Project using the collective residual stream writers transformation\n", + "s_space_W = get_collective_s_space_weight(model) \n", + "\n", "for p_key, p_prefix in PERSONAS.items():\n", " print(f\"\\n--- Running: {p_key} ---\")\n", - " res = guided_eval(model, tokenizer, p_prefix + prompt_base, n_think=N_THINK_TOKENS, device=DEVICE)\n", + " res = guided_eval(model, tokenizer, p_prefix + prompt_base, n_think=N_THINK_TOKENS, device=DEVICE, s_space_weight=s_space_W)\n", " results[p_key] = res\n", " print(f\"Logratio (Yes/No): {res['logratio']:.3f}\")\n", " print(f\"Prompt:\\n```md\\n{res['prompt']}```\")\n", @@ -192,7 +253,7 @@ " \n", " plt.plot(res['kappa_trajectory'], label=f\"{p_key} (logratio: {res['logratio']:.2f})\")\n", "\n", - "plt.title(r\"Extrinsic Curvature ($\\kappa$) of Hidden States during CoT\")\n", + "plt.title(r\"Extrinsic Curvature ($\\kappa$) of S-Space Trajectories during 
CoT\")\n", "plt.xlabel(\"Token Position in CoT\")\n", "plt.ylabel(r\"$\\kappa(t)$\")\n", "plt.legend()\n", diff --git a/experiment.py b/experiment.py index 73be640..efa362b 100644 --- a/experiment.py +++ b/experiment.py @@ -39,6 +39,59 @@ NUM_EXAMPLES = 3 # %% +def get_collective_s_space_weight(model): + """ + Gathers all weight matrices that write to the residual stream + (o_proj from attention and down_proj from MLP) across all layers, + and concatenates them to form a collective "write" transformation. + """ + Ws = [] + for layer in model.model.layers: + # In Qwen2, o_proj and down_proj weights are shape [hidden_size, in_features] + # We want a combined matrix of shape [hidden_size, sum(in_features)] + # so that SVD gives U of shape [hidden_size, top_k]. + # o_proj.weight is [hidden_size, num_heads * head_dim] + # down_proj.weight is [hidden_size, intermediate_size] + Ws.append(layer.self_attn.o_proj.weight.detach().cpu()) + Ws.append(layer.mlp.down_proj.weight.detach().cpu()) + return torch.cat(Ws, dim=1).to(model.device) + +def project_to_s_space(hidden_states, W, top_k=256): + """ + Projects the residual stream into the 'super' S-space of all residual writers. + + Explanation: The residual stream doesn't change much, but gets suppressed in the + last 3-10% of layers. Since the residual stream interacts with all modules, + we get the 'super' S-space of all residual stream writers. By getting the + hidden states from the residual stream, and the U from all residual writers, + we can project the residual stream into S-space, which can be thought of as + something like the coordinate space of learned modes of behaviors. 
+ + W: The concatenated weight matrix of all residual writers, shape (hidden_size, sum(in_features)) + """ + # SVD on the collective weight matrix + U, S, _ = torch.linalg.svd(W.float(), full_matrices=False) + + # Crop to top_k modes (there will be a lot of overlap/redundancy) + U = U[:, :top_k] + S = S[:top_k] + + # Project: x_S = (x @ U) + x_S = hidden_states.to(torch.float32) @ U + + # Align signs: flip U (and x_S) so the maximum projection is positive + # This standardizes the direction of the modes + signs = torch.sign(x_S.max(dim=0).values + x_S.min(dim=0).values) + # If the max absolute value was negative, signs will be -1, else 1 + signs[signs == 0] = 1.0 # prevent 0 multiplication + + x_S = x_S * signs + + # Scale by singular values + x_S = x_S * S + + return x_S + def compute_curvature(hidden_states): ''' Computes Frenet-Serret extrinsic curvature (kappa). @@ -61,7 +114,7 @@ def compute_curvature(hidden_states): # %% -def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda"): +def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda", s_space_weight=None): messages = [{"role": "user", "content": prompt_text}] inputs = tokenizer.apply_chat_template( @@ -122,9 +175,14 @@ def guided_eval(model, tokenizer, prompt_text, n_think=32, device="cuda"): start_idx = prompt_ids.shape[1] cot_hiddens = final_layer_hiddens[start_idx : start_idx + generated_ids.shape[0]] + if s_space_weight is not None: + trajectory = project_to_s_space(cot_hiddens, s_space_weight) + else: + trajectory = cot_hiddens + return { "logratio": (p_yes - p_no).item(), - "kappa_trajectory": compute_curvature(cot_hiddens).cpu().numpy(), + "kappa_trajectory": compute_curvature(trajectory).cpu().numpy(), "prompt": tokenizer.decode(prompt_ids[0], skip_special_tokens=False), "generated_text": tokenizer.decode(generated_ids, skip_special_tokens=False) } @@ -159,9 +217,12 @@ print(f"\n--- Dilemma ---\n{item['dilemma_situation']}\nAction: {item['action']} 
plt.figure(figsize=(10, 6)) results = {} +# Project using the collective residual stream writers transformation +s_space_W = get_collective_s_space_weight(model) + for p_key, p_prefix in PERSONAS.items(): print(f"\n--- Running: {p_key} ---") - res = guided_eval(model, tokenizer, p_prefix + prompt_base, n_think=N_THINK_TOKENS, device=DEVICE) + res = guided_eval(model, tokenizer, p_prefix + prompt_base, n_think=N_THINK_TOKENS, device=DEVICE, s_space_weight=s_space_W) results[p_key] = res print(f"Logratio (Yes/No): {res['logratio']:.3f}") print(f"Prompt:\n```md\n{res['prompt']}```") @@ -169,7 +230,7 @@ for p_key, p_prefix in PERSONAS.items(): plt.plot(res['kappa_trajectory'], label=f"{p_key} (logratio: {res['logratio']:.2f})") -plt.title(r"Extrinsic Curvature ($\kappa$) of Hidden States during CoT") +plt.title(r"Extrinsic Curvature ($\kappa$) of S-Space Trajectories during CoT") plt.xlabel("Token Position in CoT") plt.ylabel(r"$\kappa(t)$") plt.legend()