mirror of
https://github.com/wassname/Brukino_AntiPaSTO_Appetizer.git
synced 2026-06-27 16:58:47 +08:00
Refactor to return U, S, Vh from SVD and perform cropping outside the function
This commit is contained in:
+14
-14
@@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "62eec772",
|
||||
"id": "e0ccb9f3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Brukino's AntiPaSTO Appetizer: Guided CoT Eval & Frenet-Serret Curvature\n",
|
||||
@@ -20,7 +20,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "47c7efe2",
|
||||
"id": "f61786b3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -44,17 +44,17 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8c590e6c",
|
||||
"id": "b130793b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_s_space_projector(model, top_k=1024):\n",
|
||||
"def get_s_space_svd(model):\n",
|
||||
" \"\"\"\n",
|
||||
" Gathers all weight matrices that write to the residual stream\n",
|
||||
" (o_proj from attention and down_proj from MLP) across all layers,\n",
|
||||
" and concatenates them to form a collective \"write\" transformation.\n",
|
||||
" Then computes the SVD to extract the top_k modes.\n",
|
||||
" Returns: U (hidden_size, top_k), S (top_k,)\n",
|
||||
" Then computes and returns the full SVD.\n",
|
||||
" Returns: U, S, Vh\n",
|
||||
" \"\"\"\n",
|
||||
" Ws = []\n",
|
||||
" for layer in model.model.layers:\n",
|
||||
@@ -65,12 +65,9 @@
|
||||
" W = torch.cat(Ws, dim=1).to(model.device)\n",
|
||||
" \n",
|
||||
" # SVD on the collective weight matrix\n",
|
||||
" U, S, _ = torch.linalg.svd(W.float(), full_matrices=False)\n",
|
||||
" U, S, Vh = torch.linalg.svd(W.float(), full_matrices=False)\n",
|
||||
" \n",
|
||||
" # Crop to top_k modes\n",
|
||||
" U = U[:, :top_k]\n",
|
||||
" S = S[:top_k]\n",
|
||||
" return U, S\n",
|
||||
" return U, S, Vh\n",
|
||||
"\n",
|
||||
"def project_to_s_space(hidden_states, U, S):\n",
|
||||
" \"\"\"\n",
|
||||
@@ -123,7 +120,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6ae905b2",
|
||||
"id": "812eed4d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -205,7 +202,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "30e7fb4e",
|
||||
"id": "57d967a9",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 2
|
||||
},
|
||||
@@ -239,7 +236,10 @@
|
||||
"results = {}\n",
|
||||
"\n",
|
||||
"# Project using the collective residual stream writers transformation\n",
|
||||
"s_space_U, s_space_S = get_s_space_projector(model)\n",
|
||||
"U, S, Vh = get_s_space_svd(model)\n",
|
||||
"top_k = 1024\n",
|
||||
"s_space_U = U[:, :top_k]\n",
|
||||
"s_space_S = S[:top_k]\n",
|
||||
"\n",
|
||||
"for p_key, p_prefix in PERSONAS.items():\n",
|
||||
" print(f\"\\n--- Running: {p_key} ---\")\n",
|
||||
|
||||
Reference in New Issue
Block a user