mirror of
https://github.com/wassname/Brukino_AntiPaSTO_Appetizer.git
synced 2026-06-27 16:58:47 +08:00
Refactor to return U, S, Vh from SVD and perform cropping outside the function
This commit is contained in:
+14
-14
@@ -2,7 +2,7 @@
|
|||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "62eec772",
|
"id": "e0ccb9f3",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"# Brukino's AntiPaSTO Appetizer: Guided CoT Eval & Frenet-Serret Curvature\n",
|
"# Brukino's AntiPaSTO Appetizer: Guided CoT Eval & Frenet-Serret Curvature\n",
|
||||||
@@ -20,7 +20,7 @@
|
|||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "47c7efe2",
|
"id": "f61786b3",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@@ -44,17 +44,17 @@
|
|||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "8c590e6c",
|
"id": "b130793b",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def get_s_space_projector(model, top_k=1024):\n",
|
"def get_s_space_svd(model):\n",
|
||||||
" \"\"\"\n",
|
" \"\"\"\n",
|
||||||
" Gathers all weight matrices that write to the residual stream\n",
|
" Gathers all weight matrices that write to the residual stream\n",
|
||||||
" (o_proj from attention and down_proj from MLP) across all layers,\n",
|
" (o_proj from attention and down_proj from MLP) across all layers,\n",
|
||||||
" and concatenates them to form a collective \"write\" transformation.\n",
|
" and concatenates them to form a collective \"write\" transformation.\n",
|
||||||
" Then computes the SVD to extract the top_k modes.\n",
|
" Then computes and returns the full SVD.\n",
|
||||||
" Returns: U (hidden_size, top_k), S (top_k,)\n",
|
" Returns: U, S, Vh\n",
|
||||||
" \"\"\"\n",
|
" \"\"\"\n",
|
||||||
" Ws = []\n",
|
" Ws = []\n",
|
||||||
" for layer in model.model.layers:\n",
|
" for layer in model.model.layers:\n",
|
||||||
@@ -65,12 +65,9 @@
|
|||||||
" W = torch.cat(Ws, dim=1).to(model.device)\n",
|
" W = torch.cat(Ws, dim=1).to(model.device)\n",
|
||||||
" \n",
|
" \n",
|
||||||
" # SVD on the collective weight matrix\n",
|
" # SVD on the collective weight matrix\n",
|
||||||
" U, S, _ = torch.linalg.svd(W.float(), full_matrices=False)\n",
|
" U, S, Vh = torch.linalg.svd(W.float(), full_matrices=False)\n",
|
||||||
" \n",
|
" \n",
|
||||||
" # Crop to top_k modes\n",
|
" return U, S, Vh\n",
|
||||||
" U = U[:, :top_k]\n",
|
|
||||||
" S = S[:top_k]\n",
|
|
||||||
" return U, S\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"def project_to_s_space(hidden_states, U, S):\n",
|
"def project_to_s_space(hidden_states, U, S):\n",
|
||||||
" \"\"\"\n",
|
" \"\"\"\n",
|
||||||
@@ -123,7 +120,7 @@
|
|||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "6ae905b2",
|
"id": "812eed4d",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@@ -205,7 +202,7 @@
|
|||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "30e7fb4e",
|
"id": "57d967a9",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"lines_to_next_cell": 2
|
"lines_to_next_cell": 2
|
||||||
},
|
},
|
||||||
@@ -239,7 +236,10 @@
|
|||||||
"results = {}\n",
|
"results = {}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Project using the collective residual stream writers transformation\n",
|
"# Project using the collective residual stream writers transformation\n",
|
||||||
"s_space_U, s_space_S = get_s_space_projector(model)\n",
|
"U, S, Vh = get_s_space_svd(model)\n",
|
||||||
|
"top_k = 1024\n",
|
||||||
|
"s_space_U = U[:, :top_k]\n",
|
||||||
|
"s_space_S = S[:top_k]\n",
|
||||||
"\n",
|
"\n",
|
||||||
"for p_key, p_prefix in PERSONAS.items():\n",
|
"for p_key, p_prefix in PERSONAS.items():\n",
|
||||||
" print(f\"\\n--- Running: {p_key} ---\")\n",
|
" print(f\"\\n--- Running: {p_key} ---\")\n",
|
||||||
|
|||||||
+9
-9
@@ -39,13 +39,13 @@ NUM_EXAMPLES = 3
|
|||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
def get_s_space_projector(model, top_k=1024):
|
def get_s_space_svd(model):
|
||||||
"""
|
"""
|
||||||
Gathers all weight matrices that write to the residual stream
|
Gathers all weight matrices that write to the residual stream
|
||||||
(o_proj from attention and down_proj from MLP) across all layers,
|
(o_proj from attention and down_proj from MLP) across all layers,
|
||||||
and concatenates them to form a collective "write" transformation.
|
and concatenates them to form a collective "write" transformation.
|
||||||
Then computes the SVD to extract the top_k modes.
|
Then computes and returns the reduced SVD (full_matrices=False), not cropped to top_k.
|
||||||
Returns: U (hidden_size, top_k), S (top_k,)
|
Returns: U, S, Vh
|
||||||
"""
|
"""
|
||||||
Ws = []
|
Ws = []
|
||||||
for layer in model.model.layers:
|
for layer in model.model.layers:
|
||||||
@@ -56,12 +56,9 @@ def get_s_space_projector(model, top_k=1024):
|
|||||||
W = torch.cat(Ws, dim=1).to(model.device)
|
W = torch.cat(Ws, dim=1).to(model.device)
|
||||||
|
|
||||||
# SVD on the collective weight matrix
|
# SVD on the collective weight matrix
|
||||||
U, S, _ = torch.linalg.svd(W.float(), full_matrices=False)
|
U, S, Vh = torch.linalg.svd(W.float(), full_matrices=False)
|
||||||
|
|
||||||
# Crop to top_k modes
|
return U, S, Vh
|
||||||
U = U[:, :top_k]
|
|
||||||
S = S[:top_k]
|
|
||||||
return U, S
|
|
||||||
|
|
||||||
def project_to_s_space(hidden_states, U, S):
|
def project_to_s_space(hidden_states, U, S):
|
||||||
"""
|
"""
|
||||||
@@ -216,7 +213,10 @@ plt.figure(figsize=(10, 6))
|
|||||||
results = {}
|
results = {}
|
||||||
|
|
||||||
# Project using the collective residual stream writers transformation
|
# Project using the collective residual stream writers transformation
|
||||||
s_space_U, s_space_S = get_s_space_projector(model)
|
U, S, Vh = get_s_space_svd(model)
|
||||||
|
top_k = 1024
|
||||||
|
s_space_U = U[:, :top_k]
|
||||||
|
s_space_S = S[:top_k]
|
||||||
|
|
||||||
for p_key, p_prefix in PERSONAS.items():
|
for p_key, p_prefix in PERSONAS.items():
|
||||||
print(f"\n--- Running: {p_key} ---")
|
print(f"\n--- Running: {p_key} ---")
|
||||||
|
|||||||
Reference in New Issue
Block a user