diff --git a/experiment.ipynb b/experiment.ipynb index 97f4c43..93958d4 100644 --- a/experiment.ipynb +++ b/experiment.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "62eec772", + "id": "e0ccb9f3", "metadata": {}, "source": [ "# Brukino's AntiPaSTO Appetizer: Guided CoT Eval & Frenet-Serret Curvature\n", @@ -20,7 +20,7 @@ { "cell_type": "code", "execution_count": null, - "id": "47c7efe2", + "id": "f61786b3", "metadata": {}, "outputs": [], "source": [ @@ -44,17 +44,17 @@ { "cell_type": "code", "execution_count": null, - "id": "8c590e6c", + "id": "b130793b", "metadata": {}, "outputs": [], "source": [ - "def get_s_space_projector(model, top_k=1024):\n", + "def get_s_space_svd(model):\n", " \"\"\"\n", " Gathers all weight matrices that write to the residual stream\n", " (o_proj from attention and down_proj from MLP) across all layers,\n", " and concatenates them to form a collective \"write\" transformation.\n", - " Then computes the SVD to extract the top_k modes.\n", - " Returns: U (hidden_size, top_k), S (top_k,)\n", + " Then computes and returns the full SVD.\n", + " Returns: U, S, Vh\n", " \"\"\"\n", " Ws = []\n", " for layer in model.model.layers:\n", @@ -65,12 +65,9 @@ " W = torch.cat(Ws, dim=1).to(model.device)\n", " \n", " # SVD on the collective weight matrix\n", - " U, S, _ = torch.linalg.svd(W.float(), full_matrices=False)\n", + " U, S, Vh = torch.linalg.svd(W.float(), full_matrices=False)\n", " \n", - " # Crop to top_k modes\n", - " U = U[:, :top_k]\n", - " S = S[:top_k]\n", - " return U, S\n", + " return U, S, Vh\n", "\n", "def project_to_s_space(hidden_states, U, S):\n", " \"\"\"\n", @@ -123,7 +120,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6ae905b2", + "id": "812eed4d", "metadata": {}, "outputs": [], "source": [ @@ -205,7 +202,7 @@ { "cell_type": "code", "execution_count": null, - "id": "30e7fb4e", + "id": "57d967a9", "metadata": { "lines_to_next_cell": 2 }, @@ -239,7 +236,10 @@ "results = {}\n", "\n", "# Project using the collective residual stream writers transformation\n", - "s_space_U, s_space_S = get_s_space_projector(model)\n", + "U, S, Vh = get_s_space_svd(model)\n", + "top_k = 1024\n", + "s_space_U = U[:, :top_k]\n", + "s_space_S = S[:top_k]\n", "\n", "for p_key, p_prefix in PERSONAS.items():\n", " print(f\"\\n--- Running: {p_key} ---\")\n", diff --git a/experiment.py b/experiment.py index badfd70..f359988 100644 --- a/experiment.py +++ b/experiment.py @@ -39,13 +39,13 @@ NUM_EXAMPLES = 3 # %% -def get_s_space_projector(model, top_k=1024): +def get_s_space_svd(model): """ Gathers all weight matrices that write to the residual stream (o_proj from attention and down_proj from MLP) across all layers, and concatenates them to form a collective "write" transformation. - Then computes the SVD to extract the top_k modes. - Returns: U (hidden_size, top_k), S (top_k,) + Then computes and returns the full SVD. + Returns: U, S, Vh """ Ws = [] for layer in model.model.layers: @@ -56,12 +56,9 @@ def get_s_space_projector(model, top_k=1024): W = torch.cat(Ws, dim=1).to(model.device) # SVD on the collective weight matrix - U, S, _ = torch.linalg.svd(W.float(), full_matrices=False) + U, S, Vh = torch.linalg.svd(W.float(), full_matrices=False) - # Crop to top_k modes - U = U[:, :top_k] - S = S[:top_k] - return U, S + return U, S, Vh def project_to_s_space(hidden_states, U, S): """ @@ -216,7 +213,10 @@ plt.figure(figsize=(10, 6)) results = {} # Project using the collective residual stream writers transformation -s_space_U, s_space_S = get_s_space_projector(model) +U, S, Vh = get_s_space_svd(model) +top_k = 1024 +s_space_U = U[:, :top_k] +s_space_S = S[:top_k] for p_key, p_prefix in PERSONAS.items(): print(f"\n--- Running: {p_key} ---")