Implement super S-space projection across all residual writers

This commit is contained in:
wassname
2026-04-10 09:38:22 +08:00
parent bde29fee1e
commit a1a8648865
2 changed files with 135 additions and 13 deletions
+70 -9
View File
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
"id": "bafde1f1",
"id": "76397472",
"metadata": {},
"source": [
"# Brukino's AntiPaSTO Appetizer: Guided CoT Eval & Frenet-Serret Curvature\n",
@@ -20,7 +20,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "a30577fd",
"id": "703c8e38",
"metadata": {},
"outputs": [],
"source": [
@@ -44,10 +44,63 @@
{
"cell_type": "code",
"execution_count": null,
"id": "a3b31f38",
"id": "e6f56583",
"metadata": {},
"outputs": [],
"source": [
"def get_collective_s_space_weight(model):\n",
" \"\"\"\n",
" Gathers all weight matrices that write to the residual stream\n",
" (o_proj from attention and down_proj from MLP) across all layers,\n",
" and concatenates them to form a collective \"write\" transformation.\n",
" \"\"\"\n",
" Ws = []\n",
" for layer in model.model.layers:\n",
" # In Qwen2, o_proj and down_proj weights are shape [hidden_size, in_features]\n",
" # We want a combined matrix of shape [hidden_size, sum(in_features)]\n",
" # so that SVD gives U of shape [hidden_size, top_k].\n",
" # o_proj.weight is [hidden_size, num_heads * head_dim]\n",
" # down_proj.weight is [hidden_size, intermediate_size]\n",
" Ws.append(layer.self_attn.o_proj.weight.detach().cpu())\n",
" Ws.append(layer.mlp.down_proj.weight.detach().cpu())\n",
" return torch.cat(Ws, dim=1).to(model.device)\n",
"\n",
"def project_to_s_space(hidden_states, W, top_k=256):\n",
" \"\"\"\n",
" Projects the residual stream into the 'super' S-space of all residual writers.\n",
" \n",
" Explanation: The residual stream doesn't change much, but gets suppressed in the \n",
" last 3-10% of layers. Since the residual stream interacts with all modules, \n",
" we get the 'super' S-space of all residual stream writers. By getting the \n",
" hidden states from the residual stream, and the U from all residual writers, \n",
" we can project the residual stream into S-space, which can be thought of as \n",
" something like the coordinate space of learned modes of behaviors.\n",
" \n",
" W: The concatenated weight matrix of all residual writers, shape (sum(in_features), hidden_size)\n",
" \"\"\"\n",
" # SVD on the collective weight matrix\n",
" U, S, _ = torch.linalg.svd(W.float(), full_matrices=False)\n",
" \n",
" # Crop to top_k modes (there will be a lot of overlap/redundancy)\n",
" U = U[:, :top_k]\n",
" S = S[:top_k]\n",
" \n",
" # Project: x_S = (x @ U)\n",
" x_S = hidden_states.to(torch.float32) @ U\n",
" \n",
" # Align signs: flip U (and x_S) so the maximum projection is positive\n",
" # This standardizes the direction of the modes\n",
" signs = torch.sign(x_S.max(dim=0).values + x_S.min(dim=0).values) \n",
" # If the max absolute value was negative, signs will be -1, else 1\n",
" signs[signs == 0] = 1.0 # prevent 0 multiplication\n",
" \n",
" x_S = x_S * signs\n",
" \n",
" # Scale by singular values\n",
" x_S = x_S * S\n",
" \n",
" return x_S\n",
"\n",
"def compute_curvature(hidden_states):\n",
" '''\n",
" Computes Frenet-Serret extrinsic curvature (kappa).\n",
@@ -72,11 +125,11 @@
{
"cell_type": "code",
"execution_count": null,
"id": "675b1b52",
"id": "7cbb2a21",
"metadata": {},
"outputs": [],
"source": [
"def guided_eval(model, tokenizer, prompt_text, n_think=32, device=\"cuda\"):\n",
"def guided_eval(model, tokenizer, prompt_text, n_think=32, device=\"cuda\", s_space_weight=None):\n",
" messages = [{\"role\": \"user\", \"content\": prompt_text}]\n",
" \n",
" inputs = tokenizer.apply_chat_template(\n",
@@ -137,9 +190,14 @@
" start_idx = prompt_ids.shape[1]\n",
" cot_hiddens = final_layer_hiddens[start_idx : start_idx + generated_ids.shape[0]]\n",
" \n",
" if s_space_weight is not None:\n",
" trajectory = project_to_s_space(cot_hiddens, s_space_weight)\n",
" else:\n",
" trajectory = cot_hiddens\n",
" \n",
" return {\n",
" \"logratio\": (p_yes - p_no).item(),\n",
" \"kappa_trajectory\": compute_curvature(cot_hiddens).cpu().numpy(),\n",
" \"kappa_trajectory\": compute_curvature(trajectory).cpu().numpy(),\n",
" \"prompt\": tokenizer.decode(prompt_ids[0], skip_special_tokens=False),\n",
" \"generated_text\": tokenizer.decode(generated_ids, skip_special_tokens=False)\n",
" }\n",
@@ -149,7 +207,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "9e2e698d",
"id": "812de2b3",
"metadata": {
"lines_to_next_cell": 2
},
@@ -182,9 +240,12 @@
"plt.figure(figsize=(10, 6))\n",
"results = {}\n",
"\n",
"# Project using the collective residual stream writers transformation\n",
"s_space_W = get_collective_s_space_weight(model) \n",
"\n",
"for p_key, p_prefix in PERSONAS.items():\n",
" print(f\"\\n--- Running: {p_key} ---\")\n",
" res = guided_eval(model, tokenizer, p_prefix + prompt_base, n_think=N_THINK_TOKENS, device=DEVICE)\n",
" res = guided_eval(model, tokenizer, p_prefix + prompt_base, n_think=N_THINK_TOKENS, device=DEVICE, s_space_weight=s_space_W)\n",
" results[p_key] = res\n",
" print(f\"Logratio (Yes/No): {res['logratio']:.3f}\")\n",
" print(f\"Prompt:\\n```md\\n{res['prompt']}```\")\n",
@@ -192,7 +253,7 @@
" \n",
" plt.plot(res['kappa_trajectory'], label=f\"{p_key} (logratio: {res['logratio']:.2f})\")\n",
"\n",
"plt.title(r\"Extrinsic Curvature ($\\kappa$) of Hidden States during CoT\")\n",
"plt.title(r\"Extrinsic Curvature ($\\kappa$) of S-Space Trajectories during CoT\")\n",
"plt.xlabel(\"Token Position in CoT\")\n",
"plt.ylabel(r\"$\\kappa(t)$\")\n",
"plt.legend()\n",