another example

This commit is contained in:
wassname
2025-02-16 09:56:23 +08:00
parent 2f82c4bdec
commit f0311c022f
2 changed files with 381 additions and 4 deletions
+3 -2
View File
@@ -18,14 +18,15 @@ from torch import Tensor
default_output_folder = (Path(__file__).parent.parent / "outputs").resolve()
def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutput, model: AutoModelForCausalLM) -> Dict[str, Tensor]:
"""add activations to output, and rearrange hidden states"""
"""Make your own. This adds activations to output, and rearranges hidden states"""
# Baukit records the literal layer output, which varies by model. Here we assume that the output or the first part are activations we want
acts = {f'act-{k}':
v.output[0] if isinstance(v.output, tuple) else v.output
for k, v in trace.items()}
output.hidden_states = rearrange(list(output.hidden_states), 'l b t h -> b l t h')
# batch must be first, also the writer supports float16 so lets use that
output.hidden_states = rearrange(list(output.hidden_states), 'l b t h -> b l t h').half()
return dict(attention_mask=input["attention_mask"], **acts, **output)
+378 -2
View File
@@ -181,7 +181,7 @@
},
{
"cell_type": "code",
"execution_count": 51,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -224,7 +224,7 @@
}
],
"source": [
"f = activation_store(ds, model, layers=layers, writer_batch_size=10)\n",
"f = activation_store(ds, model, layers=layers)\n",
"f"
]
},
@@ -292,6 +292,382 @@
"source": [
"ds_a[0:2]['act-model.layers.0.mlp.down_proj'].shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Get supressed activations"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [],
"source": [
"from jaxtyping import Float, Int\n",
"from torch import Tensor\n",
"from einops import rearrange\n",
"\n",
"\n",
"def get_supressed_activations(\n",
" hs: Float[Tensor, \"l b t h\"], w_out, w_inv\n",
") -> Float[Tensor, \"l b t h\"]:\n",
" \"\"\"\n",
" Novel experiment: Here we define a transform to isolate supressed activations, where we hypothesis that style/concepts/scratchpads and other internal only representations must be stored.\n",
"\n",
" See the following references for more information:\n",
"\n",
" - https://arxiv.org/pdf/2401.12181\n",
" - > Suppression neurons that are similar, except decrease the probability of a group of related tokens\n",
"\n",
" - https://arxiv.org/html/2406.19384\n",
" - > Previous work suggests that networks contain ensembles of “prediction\" neurons, which act as probability promoters [66, 24, 32] and work in tandem with suppression neurons (Section 5.4).\n",
"\n",
" - https://arxiv.org/pdf/2401.12181\n",
" > We find a striking pattern which is remarkably consistent across the different seeds: after about the halfway point in the model, prediction neurons become increasingly prevalent until the very end of the network where there is a sudden shift towards a much larger number of suppression neurons.\n",
" \"\"\"\n",
" with torch.no_grad():\n",
" # here we pass the hs through the last layer, take a diff, and then project it back to find which activation changes lead to supressed\n",
" hs2 = rearrange(hs[:, :, -1:], \"l b t h -> (l b t) h\")\n",
" hs_out2 = torch.nn.functional.linear(hs2, w_out)\n",
" hs_out = rearrange(\n",
" hs_out2, \"(l b t) h -> l b t h\", l=hs.shape[0], b=hs.shape[1], t=1\n",
" )\n",
" diffs = hs_out[:, :, :].diff(dim=0)\n",
" diffs2 = rearrange(diffs, \"l b t h -> (l b t) h\")\n",
" # W_inv = get_cache_inv(w_out)\n",
"\n",
" diffs_inv2 = torch.nn.functional.linear(diffs2.to(dtype=w_inv.dtype), w_inv)\n",
" diffs_inv = rearrange(\n",
" diffs_inv2, \"(l b t) h -> l b t h\", l=hs.shape[0] - 1, b=hs.shape[1], t=1\n",
" ).to(w_out.dtype)\n",
" # TODO just return this?\n",
" eps = 1.0e-1\n",
" supressed_mask = (diffs_inv < -eps).to(hs.dtype)\n",
" # supressed_mask = repeat(supressed_mask, 'l b 1 h -> l b t h', t=hs.shape[2])\n",
" supressed_act = hs[1:] * supressed_mask\n",
" return supressed_act"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from activation_store.collect import default_postprocess_result\n",
"\n",
"Wo = model.get_output_embeddings().weight.detach().clone().cpu()\n",
"Wo_inv = torch.pinverse(Wo.clone().float())\n",
"\n",
"@torch.no_grad()\n",
"def sup_postproc(input, trace, output, model):\n",
"\n",
" \n",
" o = default_postprocess_result(input, trace, output, model)\n",
" \n",
" hs = o.pop('hidden_states')\n",
" hs = rearrange(hs, \"b l t h -> l b t h\")\n",
" hs_s = get_supressed_activations(hs, Wo.to(hs.dtype), Wo_inv.to(hs.dtype))\n",
" hs_s = rearrange(hs_s, \"l b t h -> b l t h\")\n",
" o['hidden_states_supressed'] = hs_s.half()\n",
" \n",
" return o\n"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m2025-02-16 09:52:12.917\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mactivation_store.collect\u001b[0m:\u001b[36mactivation_store\u001b[0m:\u001b[36m78\u001b[0m - \u001b[1mcreating dataset /media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__115ab10dde7bd7a3.parquet\u001b[0m\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c0c38f37f9934a0dbe7086b695624548",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"collecting activations: 0%| | 0/5 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"PosixPath('/media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__115ab10dde7bd7a3.parquet')"
]
},
"execution_count": 86,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f2 = activation_store(ds, model, postprocess_result=sup_postproc)\n",
"f2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2380a4d27fdb42e9a61adb67bd221cd8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating train split: 0 examples [00:00, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"Dataset({\n",
" features: ['attention_mask', 'logits', 'hidden_states_supressed'],\n",
" num_rows: 20\n",
"})"
]
},
"execution_count": 87,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ds_a2 = Dataset.from_parquet(str(f2)).with_format(\"torch\")\n",
"ds_a2"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 24, 453, 896])"
]
},
"execution_count": 92,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ds_a2.info\n",
"ds_a2[0:2]['hidden_states_supressed'].shape"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[ 0.0000e+00, 0.0000e+00, -0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [-0.0000e+00, -0.0000e+00, -0.0000e+00, ..., 0.0000e+00,\n",
" -0.0000e+00, -0.0000e+00],\n",
" [-0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, -0.0000e+00],\n",
" ...,\n",
" [ 0.0000e+00, -0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [-0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [-0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
" 0.0000e+00, -0.0000e+00]],\n",
"\n",
" [[ 0.0000e+00, -0.0000e+00, -0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, -0.0000e+00],\n",
" [-0.0000e+00, -0.0000e+00, 0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, -0.0000e+00],\n",
" [-0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, 0.0000e+00],\n",
" ...,\n",
" [ 0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [-0.0000e+00, -0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, 0.0000e+00],\n",
" [-0.0000e+00, 0.0000e+00, -0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, -0.0000e+00]],\n",
"\n",
" [[-0.0000e+00, -0.0000e+00, -2.9629e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, 0.0000e+00],\n",
" [-0.0000e+00, -0.0000e+00, -1.9275e-01, ..., -0.0000e+00,\n",
" -0.0000e+00, 0.0000e+00],\n",
" [-0.0000e+00, 0.0000e+00, 9.8419e-03, ..., -0.0000e+00,\n",
" -0.0000e+00, 0.0000e+00],\n",
" ...,\n",
" [ 0.0000e+00, -0.0000e+00, 3.3855e-03, ..., -0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [-0.0000e+00, -0.0000e+00, -9.2627e-01, ..., -0.0000e+00,\n",
" -0.0000e+00, 0.0000e+00],\n",
" [-0.0000e+00, 0.0000e+00, -1.6125e-01, ..., 0.0000e+00,\n",
" 0.0000e+00, -0.0000e+00]],\n",
"\n",
" ...,\n",
"\n",
" [[ 0.0000e+00, -2.7559e+00, -5.6543e-01, ..., -0.0000e+00,\n",
" 0.0000e+00, 1.0596e+00],\n",
" [-0.0000e+00, 6.6846e-01, -3.3643e-01, ..., -0.0000e+00,\n",
" 0.0000e+00, 4.6387e-01],\n",
" [ 0.0000e+00, 7.1484e-01, 2.5854e-01, ..., 0.0000e+00,\n",
" 0.0000e+00, 1.4053e+00],\n",
" ...,\n",
" [ 0.0000e+00, 1.0898e+00, 4.5239e-01, ..., 0.0000e+00,\n",
" -0.0000e+00, -2.1033e-01],\n",
" [-0.0000e+00, 2.1504e+00, -1.2415e-01, ..., -0.0000e+00,\n",
" -0.0000e+00, 1.3428e+00],\n",
" [-0.0000e+00, 2.8262e+00, 5.9277e-01, ..., 0.0000e+00,\n",
" -0.0000e+00, -5.5811e-01]],\n",
"\n",
" [[ 0.0000e+00, -2.4727e+00, 0.0000e+00, ..., -0.0000e+00,\n",
" -7.0947e-01, 0.0000e+00],\n",
" [-0.0000e+00, 1.3535e+00, -0.0000e+00, ..., -0.0000e+00,\n",
" 1.1560e-01, -0.0000e+00],\n",
" [ 0.0000e+00, 1.1572e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 1.0020e+00, 0.0000e+00],\n",
" ...,\n",
" [ 0.0000e+00, 1.0254e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" -5.9326e-02, -0.0000e+00],\n",
" [-0.0000e+00, 2.2461e+00, -0.0000e+00, ..., -0.0000e+00,\n",
" -2.3105e+00, 0.0000e+00],\n",
" [ 0.0000e+00, 1.5420e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" -5.2979e-01, -0.0000e+00]],\n",
"\n",
" [[ 0.0000e+00, -0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" -2.5605e+00, 6.3984e+00],\n",
" [-0.0000e+00, -0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
" -1.1279e+00, -1.1766e+01],\n",
" [-0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -0.0000e+00,\n",
" 1.3135e+00, -3.2148e+00],\n",
" ...,\n",
" [-0.0000e+00, 0.0000e+00, -0.0000e+00, ..., 0.0000e+00,\n",
" 7.1289e-01, -8.5234e+00],\n",
" [-0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
" -1.0047e+01, 1.1797e+01],\n",
" [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" -1.9268e+00, -9.5078e+00]]],\n",
"\n",
"\n",
" [[[ 0.0000e+00, 0.0000e+00, -0.0000e+00, ..., 1.1493e-01,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [-0.0000e+00, -0.0000e+00, -0.0000e+00, ..., 7.6843e-02,\n",
" -0.0000e+00, -0.0000e+00],\n",
" [-0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -1.0968e-01,\n",
" -0.0000e+00, -0.0000e+00],\n",
" ...,\n",
" [ 0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -2.1094e-01,\n",
" -0.0000e+00, -0.0000e+00],\n",
" [ 0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -2.5928e-01,\n",
" -0.0000e+00, -0.0000e+00],\n",
" [ 0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -2.5220e-01,\n",
" -0.0000e+00, -0.0000e+00]],\n",
"\n",
" [[ 0.0000e+00, -0.0000e+00, -0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, -0.0000e+00],\n",
" [-0.0000e+00, -0.0000e+00, 0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, -0.0000e+00],\n",
" [-0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, 0.0000e+00],\n",
" ...,\n",
" [ 0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, 0.0000e+00],\n",
" [ 0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, 0.0000e+00],\n",
" [ 0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, 0.0000e+00]],\n",
"\n",
" [[-0.0000e+00, -8.0391e+00, -0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, 3.9531e+00],\n",
" [-0.0000e+00, -4.0186e-01, -0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, 1.0168e-01],\n",
" [-0.0000e+00, 1.7993e-01, 0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, 3.4561e-03],\n",
" ...,\n",
" [ 0.0000e+00, -2.8122e-02, -0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, -6.9031e-02],\n",
" [ 0.0000e+00, -6.8909e-02, -0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, -6.5613e-02],\n",
" [ 0.0000e+00, -5.5481e-02, -0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, -9.4604e-02]],\n",
"\n",
" ...,\n",
"\n",
" [[ 4.4263e-01, -0.0000e+00, -5.6543e-01, ..., -0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [-2.8281e+00, 0.0000e+00, -3.3643e-01, ..., -0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [ 6.2225e-02, 0.0000e+00, 2.5854e-01, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" ...,\n",
" [-8.7598e-01, 0.0000e+00, 5.7422e-01, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [-1.0537e+00, 0.0000e+00, 3.3667e-01, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [-1.1729e+00, 0.0000e+00, 2.0654e-01, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00]],\n",
"\n",
" [[ 0.0000e+00, -0.0000e+00, 0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, 1.6553e+00],\n",
" [-0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
" 0.0000e+00, -3.4204e-01],\n",
" [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 2.3914e-01],\n",
" ...,\n",
" [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, -8.4375e-01],\n",
" [-0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, -9.8193e-01],\n",
" [-0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, -9.7900e-01]],\n",
"\n",
" [[ 0.0000e+00, -6.5742e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" -0.0000e+00, 6.3984e+00],\n",
" [-0.0000e+00, -3.4448e-01, -0.0000e+00, ..., -0.0000e+00,\n",
" -0.0000e+00, -1.1766e+01],\n",
" [-0.0000e+00, 6.8408e-01, 0.0000e+00, ..., -0.0000e+00,\n",
" 0.0000e+00, -3.2148e+00],\n",
" ...,\n",
" [ 0.0000e+00, -1.3831e-01, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, -1.2625e+01],\n",
" [ 0.0000e+00, -1.0439e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, -9.9453e+00],\n",
" [ 0.0000e+00, -1.5166e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, -1.1367e+01]]]])"
]
},
"execution_count": 93,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ds_a2[0:2]['hidden_states_supressed']"
]
}
],
"metadata": {