mirror of
https://github.com/wassname/Brukino_AntiPaSTO_Appetizer.git
synced 2026-06-27 16:58:47 +08:00
Implement super S-space projection across all residual writers
This commit is contained in:
+70
-9
@@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bafde1f1",
|
||||
"id": "76397472",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Brukino's AntiPaSTO Appetizer: Guided CoT Eval & Frenet-Serret Curvature\n",
|
||||
@@ -20,7 +20,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a30577fd",
|
||||
"id": "703c8e38",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -44,10 +44,63 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a3b31f38",
|
||||
"id": "e6f56583",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_collective_s_space_weight(model):\n",
|
||||
" \"\"\"\n",
|
||||
" Gathers all weight matrices that write to the residual stream\n",
|
||||
" (o_proj from attention and down_proj from MLP) across all layers,\n",
|
||||
" and concatenates them to form a collective \"write\" transformation.\n",
|
||||
" \"\"\"\n",
|
||||
" Ws = []\n",
|
||||
" for layer in model.model.layers:\n",
|
||||
" # In Qwen2, o_proj and down_proj weights are shape [hidden_size, in_features]\n",
|
||||
" # We want a combined matrix of shape [hidden_size, sum(in_features)]\n",
|
||||
" # so that SVD gives U of shape [hidden_size, top_k].\n",
|
||||
" # o_proj.weight is [hidden_size, num_heads * head_dim]\n",
|
||||
" # down_proj.weight is [hidden_size, intermediate_size]\n",
|
||||
" Ws.append(layer.self_attn.o_proj.weight.detach().cpu())\n",
|
||||
" Ws.append(layer.mlp.down_proj.weight.detach().cpu())\n",
|
||||
" return torch.cat(Ws, dim=1).to(model.device)\n",
|
||||
"\n",
|
||||
"def project_to_s_space(hidden_states, W, top_k=256):\n",
|
||||
" \"\"\"\n",
|
||||
" Projects the residual stream into the 'super' S-space of all residual writers.\n",
|
||||
" \n",
|
||||
" Explanation: The residual stream doesn't change much, but gets suppressed in the \n",
|
||||
" last 3-10% of layers. Since the residual stream interacts with all modules, \n",
|
||||
" we get the 'super' S-space of all residual stream writers. By getting the \n",
|
||||
" hidden states from the residual stream, and the U from all residual writers, \n",
|
||||
" we can project the residual stream into S-space, which can be thought of as \n",
|
||||
" something like the coordinate space of learned modes of behaviors.\n",
|
||||
" \n",
|
||||
" W: The concatenated weight matrix of all residual writers, shape (sum(in_features), hidden_size)\n",
|
||||
" \"\"\"\n",
|
||||
" # SVD on the collective weight matrix\n",
|
||||
" U, S, _ = torch.linalg.svd(W.float(), full_matrices=False)\n",
|
||||
" \n",
|
||||
" # Crop to top_k modes (there will be a lot of overlap/redundancy)\n",
|
||||
" U = U[:, :top_k]\n",
|
||||
" S = S[:top_k]\n",
|
||||
" \n",
|
||||
" # Project: x_S = (x @ U)\n",
|
||||
" x_S = hidden_states.to(torch.float32) @ U\n",
|
||||
" \n",
|
||||
" # Align signs: flip U (and x_S) so the maximum projection is positive\n",
|
||||
" # This standardizes the direction of the modes\n",
|
||||
" signs = torch.sign(x_S.max(dim=0).values + x_S.min(dim=0).values) \n",
|
||||
" # If the max absolute value was negative, signs will be -1, else 1\n",
|
||||
" signs[signs == 0] = 1.0 # prevent 0 multiplication\n",
|
||||
" \n",
|
||||
" x_S = x_S * signs\n",
|
||||
" \n",
|
||||
" # Scale by singular values\n",
|
||||
" x_S = x_S * S\n",
|
||||
" \n",
|
||||
" return x_S\n",
|
||||
"\n",
|
||||
"def compute_curvature(hidden_states):\n",
|
||||
" '''\n",
|
||||
" Computes Frenet-Serret extrinsic curvature (kappa).\n",
|
||||
@@ -72,11 +125,11 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "675b1b52",
|
||||
"id": "7cbb2a21",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def guided_eval(model, tokenizer, prompt_text, n_think=32, device=\"cuda\"):\n",
|
||||
"def guided_eval(model, tokenizer, prompt_text, n_think=32, device=\"cuda\", s_space_weight=None):\n",
|
||||
" messages = [{\"role\": \"user\", \"content\": prompt_text}]\n",
|
||||
" \n",
|
||||
" inputs = tokenizer.apply_chat_template(\n",
|
||||
@@ -137,9 +190,14 @@
|
||||
" start_idx = prompt_ids.shape[1]\n",
|
||||
" cot_hiddens = final_layer_hiddens[start_idx : start_idx + generated_ids.shape[0]]\n",
|
||||
" \n",
|
||||
" if s_space_weight is not None:\n",
|
||||
" trajectory = project_to_s_space(cot_hiddens, s_space_weight)\n",
|
||||
" else:\n",
|
||||
" trajectory = cot_hiddens\n",
|
||||
" \n",
|
||||
" return {\n",
|
||||
" \"logratio\": (p_yes - p_no).item(),\n",
|
||||
" \"kappa_trajectory\": compute_curvature(cot_hiddens).cpu().numpy(),\n",
|
||||
" \"kappa_trajectory\": compute_curvature(trajectory).cpu().numpy(),\n",
|
||||
" \"prompt\": tokenizer.decode(prompt_ids[0], skip_special_tokens=False),\n",
|
||||
" \"generated_text\": tokenizer.decode(generated_ids, skip_special_tokens=False)\n",
|
||||
" }\n",
|
||||
@@ -149,7 +207,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9e2e698d",
|
||||
"id": "812de2b3",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 2
|
||||
},
|
||||
@@ -182,9 +240,12 @@
|
||||
"plt.figure(figsize=(10, 6))\n",
|
||||
"results = {}\n",
|
||||
"\n",
|
||||
"# Project using the collective residual stream writers transformation\n",
|
||||
"s_space_W = get_collective_s_space_weight(model) \n",
|
||||
"\n",
|
||||
"for p_key, p_prefix in PERSONAS.items():\n",
|
||||
" print(f\"\\n--- Running: {p_key} ---\")\n",
|
||||
" res = guided_eval(model, tokenizer, p_prefix + prompt_base, n_think=N_THINK_TOKENS, device=DEVICE)\n",
|
||||
" res = guided_eval(model, tokenizer, p_prefix + prompt_base, n_think=N_THINK_TOKENS, device=DEVICE, s_space_weight=s_space_W)\n",
|
||||
" results[p_key] = res\n",
|
||||
" print(f\"Logratio (Yes/No): {res['logratio']:.3f}\")\n",
|
||||
" print(f\"Prompt:\\n```md\\n{res['prompt']}```\")\n",
|
||||
@@ -192,7 +253,7 @@
|
||||
" \n",
|
||||
" plt.plot(res['kappa_trajectory'], label=f\"{p_key} (logratio: {res['logratio']:.2f})\")\n",
|
||||
"\n",
|
||||
"plt.title(r\"Extrinsic Curvature ($\\kappa$) of Hidden States during CoT\")\n",
|
||||
"plt.title(r\"Extrinsic Curvature ($\\kappa$) of S-Space Trajectories during CoT\")\n",
|
||||
"plt.xlabel(\"Token Position in CoT\")\n",
|
||||
"plt.ylabel(r\"$\\kappa(t)$\")\n",
|
||||
"plt.legend()\n",
|
||||
|
||||
Reference in New Issue
Block a user