This commit is contained in:
wassname
2025-03-12 13:55:47 +08:00
parent 72e20deba3
commit bcd47da026
+8 -387
View File
@@ -33,7 +33,7 @@
},
{
"cell_type": "code",
"execution_count": 47,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -45,7 +45,11 @@
" device_map=\"auto\",\n",
" attn_implementation=\"eager\", # flex_attention flash_attention_2 sdpa eager\n",
")\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n"
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"if tokenizer.pad_token_id is None:\n",
" tokenizer.pad_token = tokenizer.eos_token\n",
"tokenizer.paddding_side = \"left\"\n",
"tokenizer.truncation_side = \"left\""
]
},
{
@@ -116,15 +120,8 @@
],
"source": [
"from torch.utils.data import DataLoader\n",
"def collate_fn(examples):\n",
" # Pad the batch to max length within this batch\n",
" return tokenizer.pad(\n",
" examples,\n",
" padding=True,\n",
" return_tensors=\"pt\",\n",
" max_length=max_length, \n",
" truncation=True,\n",
" )\n",
"from transformers.data import DataCollatorForLanguageModeling\n",
"collate_fn = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)\n",
"ds = DataLoader(ds2, batch_size=4, num_workers=0, collate_fn=collate_fn)\n",
"print(ds)\n"
]
@@ -294,382 +291,6 @@
"source": [
"ds_a[0:2]['act-model.layers.0.mlp.down_proj'].shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Get supressed activations"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [],
"source": [
"from jaxtyping import Float, Int\n",
"from torch import Tensor\n",
"from einops import rearrange\n",
"\n",
"\n",
"def get_supressed_activations(\n",
" hs: Float[Tensor, \"l b t h\"], w_out, w_inv\n",
") -> Float[Tensor, \"l b t h\"]:\n",
" \"\"\"\n",
" Novel experiment: Here we define a transform to isolate supressed activations, where we hypothesis that style/concepts/scratchpads and other internal only representations must be stored.\n",
"\n",
" See the following references for more information:\n",
"\n",
" - https://arxiv.org/pdf/2401.12181\n",
" - > Suppression neurons that are similar, except decrease the probability of a group of related tokens\n",
"\n",
" - https://arxiv.org/html/2406.19384\n",
" - > Previous work suggests that networks contain ensembles of “prediction\" neurons, which act as probability promoters [66, 24, 32] and work in tandem with suppression neurons (Section 5.4).\n",
"\n",
" - https://arxiv.org/pdf/2401.12181\n",
" > We find a striking pattern which is remarkably consistent across the different seeds: after about the halfway point in the model, prediction neurons become increasingly prevalent until the very end of the network where there is a sudden shift towards a much larger number of suppression neurons.\n",
" \"\"\"\n",
" with torch.no_grad():\n",
" # here we pass the hs through the last layer, take a diff, and then project it back to find which activation changes lead to supressed\n",
" hs2 = rearrange(hs[:, :, -1:], \"l b t h -> (l b t) h\")\n",
" hs_out2 = torch.nn.functional.linear(hs2, w_out)\n",
" hs_out = rearrange(\n",
" hs_out2, \"(l b t) h -> l b t h\", l=hs.shape[0], b=hs.shape[1], t=1\n",
" )\n",
" diffs = hs_out[:, :, :].diff(dim=0)\n",
" diffs2 = rearrange(diffs, \"l b t h -> (l b t) h\")\n",
" # W_inv = get_cache_inv(w_out)\n",
"\n",
" diffs_inv2 = torch.nn.functional.linear(diffs2.to(dtype=w_inv.dtype), w_inv)\n",
" diffs_inv = rearrange(\n",
" diffs_inv2, \"(l b t) h -> l b t h\", l=hs.shape[0] - 1, b=hs.shape[1], t=1\n",
" ).to(w_out.dtype)\n",
" # TODO just return this?\n",
" eps = 1.0e-1\n",
" supressed_mask = (diffs_inv < -eps).to(hs.dtype)\n",
" # supressed_mask = repeat(supressed_mask, 'l b 1 h -> l b t h', t=hs.shape[2])\n",
" supressed_act = hs[1:] * supressed_mask\n",
" return supressed_act"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from activation_store.collect import default_postprocess_result\n",
"\n",
"Wo = model.get_output_embeddings().weight.detach().clone().cpu()\n",
"Wo_inv = torch.pinverse(Wo.clone().float())\n",
"\n",
"@torch.no_grad()\n",
"def sup_postproc(input, trace, output, model):\n",
"\n",
" \n",
" o = default_postprocess_result(input, trace, output, model)\n",
" \n",
" hs = o.pop('hidden_states')\n",
" hs = rearrange(hs, \"b l t h -> l b t h\")\n",
" hs_s = get_supressed_activations(hs, Wo.to(hs.dtype), Wo_inv.to(hs.dtype))\n",
" hs_s = rearrange(hs_s, \"l b t h -> b l t h\")\n",
" o['hidden_states_supressed'] = hs_s.half()\n",
" \n",
" return o\n"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m2025-02-16 09:52:12.917\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mactivation_store.collect\u001b[0m:\u001b[36mactivation_store\u001b[0m:\u001b[36m78\u001b[0m - \u001b[1mcreating dataset /media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__115ab10dde7bd7a3.parquet\u001b[0m\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c0c38f37f9934a0dbe7086b695624548",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"collecting activations: 0%| | 0/5 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"PosixPath('/media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__115ab10dde7bd7a3.parquet')"
]
},
"execution_count": 86,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f2 = activation_store(ds, model, postprocess_result=sup_postproc)\n",
"f2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2380a4d27fdb42e9a61adb67bd221cd8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating train split: 0 examples [00:00, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"Dataset({\n",
" features: ['attention_mask', 'logits', 'hidden_states_supressed'],\n",
" num_rows: 20\n",
"})"
]
},
"execution_count": 87,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ds_a2 = Dataset.from_parquet(str(f2)).with_format(\"torch\")\n",
"ds_a2"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 24, 453, 896])"
]
},
"execution_count": 92,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ds_a2.info\n",
"ds_a2[0:2]['hidden_states_supressed'].shape"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[ 0.0000e+00, 0.0000e+00, -0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [-0.0000e+00, -0.0000e+00, -0.0000e+00, ..., 0.0000e+00,\n",
" -0.0000e+00, -0.0000e+00],\n",
" [-0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, -0.0000e+00],\n",
" ...,\n",
" [ 0.0000e+00, -0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [-0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [-0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
" 0.0000e+00, -0.0000e+00]],\n",
"\n",
" [[ 0.0000e+00, -0.0000e+00, -0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, -0.0000e+00],\n",
" [-0.0000e+00, -0.0000e+00, 0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, -0.0000e+00],\n",
" [-0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, 0.0000e+00],\n",
" ...,\n",
" [ 0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [-0.0000e+00, -0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, 0.0000e+00],\n",
" [-0.0000e+00, 0.0000e+00, -0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, -0.0000e+00]],\n",
"\n",
" [[-0.0000e+00, -0.0000e+00, -2.9629e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, 0.0000e+00],\n",
" [-0.0000e+00, -0.0000e+00, -1.9275e-01, ..., -0.0000e+00,\n",
" -0.0000e+00, 0.0000e+00],\n",
" [-0.0000e+00, 0.0000e+00, 9.8419e-03, ..., -0.0000e+00,\n",
" -0.0000e+00, 0.0000e+00],\n",
" ...,\n",
" [ 0.0000e+00, -0.0000e+00, 3.3855e-03, ..., -0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [-0.0000e+00, -0.0000e+00, -9.2627e-01, ..., -0.0000e+00,\n",
" -0.0000e+00, 0.0000e+00],\n",
" [-0.0000e+00, 0.0000e+00, -1.6125e-01, ..., 0.0000e+00,\n",
" 0.0000e+00, -0.0000e+00]],\n",
"\n",
" ...,\n",
"\n",
" [[ 0.0000e+00, -2.7559e+00, -5.6543e-01, ..., -0.0000e+00,\n",
" 0.0000e+00, 1.0596e+00],\n",
" [-0.0000e+00, 6.6846e-01, -3.3643e-01, ..., -0.0000e+00,\n",
" 0.0000e+00, 4.6387e-01],\n",
" [ 0.0000e+00, 7.1484e-01, 2.5854e-01, ..., 0.0000e+00,\n",
" 0.0000e+00, 1.4053e+00],\n",
" ...,\n",
" [ 0.0000e+00, 1.0898e+00, 4.5239e-01, ..., 0.0000e+00,\n",
" -0.0000e+00, -2.1033e-01],\n",
" [-0.0000e+00, 2.1504e+00, -1.2415e-01, ..., -0.0000e+00,\n",
" -0.0000e+00, 1.3428e+00],\n",
" [-0.0000e+00, 2.8262e+00, 5.9277e-01, ..., 0.0000e+00,\n",
" -0.0000e+00, -5.5811e-01]],\n",
"\n",
" [[ 0.0000e+00, -2.4727e+00, 0.0000e+00, ..., -0.0000e+00,\n",
" -7.0947e-01, 0.0000e+00],\n",
" [-0.0000e+00, 1.3535e+00, -0.0000e+00, ..., -0.0000e+00,\n",
" 1.1560e-01, -0.0000e+00],\n",
" [ 0.0000e+00, 1.1572e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 1.0020e+00, 0.0000e+00],\n",
" ...,\n",
" [ 0.0000e+00, 1.0254e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" -5.9326e-02, -0.0000e+00],\n",
" [-0.0000e+00, 2.2461e+00, -0.0000e+00, ..., -0.0000e+00,\n",
" -2.3105e+00, 0.0000e+00],\n",
" [ 0.0000e+00, 1.5420e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" -5.2979e-01, -0.0000e+00]],\n",
"\n",
" [[ 0.0000e+00, -0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" -2.5605e+00, 6.3984e+00],\n",
" [-0.0000e+00, -0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
" -1.1279e+00, -1.1766e+01],\n",
" [-0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -0.0000e+00,\n",
" 1.3135e+00, -3.2148e+00],\n",
" ...,\n",
" [-0.0000e+00, 0.0000e+00, -0.0000e+00, ..., 0.0000e+00,\n",
" 7.1289e-01, -8.5234e+00],\n",
" [-0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
" -1.0047e+01, 1.1797e+01],\n",
" [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" -1.9268e+00, -9.5078e+00]]],\n",
"\n",
"\n",
" [[[ 0.0000e+00, 0.0000e+00, -0.0000e+00, ..., 1.1493e-01,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [-0.0000e+00, -0.0000e+00, -0.0000e+00, ..., 7.6843e-02,\n",
" -0.0000e+00, -0.0000e+00],\n",
" [-0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -1.0968e-01,\n",
" -0.0000e+00, -0.0000e+00],\n",
" ...,\n",
" [ 0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -2.1094e-01,\n",
" -0.0000e+00, -0.0000e+00],\n",
" [ 0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -2.5928e-01,\n",
" -0.0000e+00, -0.0000e+00],\n",
" [ 0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -2.5220e-01,\n",
" -0.0000e+00, -0.0000e+00]],\n",
"\n",
" [[ 0.0000e+00, -0.0000e+00, -0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, -0.0000e+00],\n",
" [-0.0000e+00, -0.0000e+00, 0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, -0.0000e+00],\n",
" [-0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, 0.0000e+00],\n",
" ...,\n",
" [ 0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, 0.0000e+00],\n",
" [ 0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, 0.0000e+00],\n",
" [ 0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, 0.0000e+00]],\n",
"\n",
" [[-0.0000e+00, -8.0391e+00, -0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, 3.9531e+00],\n",
" [-0.0000e+00, -4.0186e-01, -0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, 1.0168e-01],\n",
" [-0.0000e+00, 1.7993e-01, 0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, 3.4561e-03],\n",
" ...,\n",
" [ 0.0000e+00, -2.8122e-02, -0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, -6.9031e-02],\n",
" [ 0.0000e+00, -6.8909e-02, -0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, -6.5613e-02],\n",
" [ 0.0000e+00, -5.5481e-02, -0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, -9.4604e-02]],\n",
"\n",
" ...,\n",
"\n",
" [[ 4.4263e-01, -0.0000e+00, -5.6543e-01, ..., -0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [-2.8281e+00, 0.0000e+00, -3.3643e-01, ..., -0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [ 6.2225e-02, 0.0000e+00, 2.5854e-01, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" ...,\n",
" [-8.7598e-01, 0.0000e+00, 5.7422e-01, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [-1.0537e+00, 0.0000e+00, 3.3667e-01, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [-1.1729e+00, 0.0000e+00, 2.0654e-01, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00]],\n",
"\n",
" [[ 0.0000e+00, -0.0000e+00, 0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, 1.6553e+00],\n",
" [-0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
" 0.0000e+00, -3.4204e-01],\n",
" [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 2.3914e-01],\n",
" ...,\n",
" [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, -8.4375e-01],\n",
" [-0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, -9.8193e-01],\n",
" [-0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, -9.7900e-01]],\n",
"\n",
" [[ 0.0000e+00, -6.5742e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" -0.0000e+00, 6.3984e+00],\n",
" [-0.0000e+00, -3.4448e-01, -0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, -1.1766e+01],\n",
" [-0.0000e+00, 6.8408e-01, 0.0000e+00, ..., -0.0000e+00,\n",
" 0.0000e+00, -3.2148e+00],\n",
" ...,\n",
" [ 0.0000e+00, -1.3831e-01, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, -1.2625e+01],\n",
" [ 0.0000e+00, -1.0439e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, -9.9453e+00],\n",
" [ 0.0000e+00, -1.5166e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, -1.1367e+01]]]])"
]
},
"execution_count": 93,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ds_a2[0:2]['hidden_states_supressed']"
]
}
],
"metadata": {