diff --git a/nbs/example.ipynb b/nbs/example.ipynb index e67af44..189ef56 100644 --- a/nbs/example.ipynb +++ b/nbs/example.ipynb @@ -33,7 +33,7 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -45,7 +45,11 @@ " device_map=\"auto\",\n", " attn_implementation=\"eager\", # flex_attention flash_attention_2 sdpa eager\n", ")\n", - "tokenizer = AutoTokenizer.from_pretrained(model_name)\n" + "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", + "if tokenizer.pad_token_id is None:\n", + " tokenizer.pad_token = tokenizer.eos_token\n", + "tokenizer.paddding_side = \"left\"\n", + "tokenizer.truncation_side = \"left\"" ] }, { @@ -116,15 +120,8 @@ ], "source": [ "from torch.utils.data import DataLoader\n", - "def collate_fn(examples):\n", - " # Pad the batch to max length within this batch\n", - " return tokenizer.pad(\n", - " examples,\n", - " padding=True,\n", - " return_tensors=\"pt\",\n", - " max_length=max_length, \n", - " truncation=True,\n", - " )\n", + "from transformers.data import DataCollatorForLanguageModeling\n", + "collate_fn = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)\n", "ds = DataLoader(ds2, batch_size=4, num_workers=0, collate_fn=collate_fn)\n", "print(ds)\n" ] @@ -294,382 +291,6 @@ "source": [ "ds_a[0:2]['act-model.layers.0.mlp.down_proj'].shape" ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Get supressed activations" - ] - }, - { - "cell_type": "code", - "execution_count": 64, - "metadata": {}, - "outputs": [], - "source": [ - "from jaxtyping import Float, Int\n", - "from torch import Tensor\n", - "from einops import rearrange\n", - "\n", - "\n", - "def get_supressed_activations(\n", - " hs: Float[Tensor, \"l b t h\"], w_out, w_inv\n", - ") -> Float[Tensor, \"l b t h\"]:\n", - " \"\"\"\n", - " Novel experiment: Here we define a transform to isolate supressed activations, where we hypothesis that style/concepts/scratchpads and other internal only representations must be stored.\n", - "\n", - " See the following references for more information:\n", - "\n", - " - https://arxiv.org/pdf/2401.12181\n", - " - > Suppression neurons that are similar, except decrease the probability of a group of related tokens\n", - "\n", - " - https://arxiv.org/html/2406.19384\n", - " - > Previous work suggests that networks contain ensembles of “prediction\" neurons, which act as probability promoters [66, 24, 32] and work in tandem with suppression neurons (Section 5.4).\n", - "\n", - " - https://arxiv.org/pdf/2401.12181\n", - " > We find a striking pattern which is remarkably consistent across the different seeds: after about the halfway point in the model, prediction neurons become increasingly prevalent until the very end of the network where there is a sudden shift towards a much larger number of suppression neurons.\n", - " \"\"\"\n", - " with torch.no_grad():\n", - " # here we pass the hs through the last layer, take a diff, and then project it back to find which activation changes lead to supressed\n", - " hs2 = rearrange(hs[:, :, -1:], \"l b t h -> (l b t) h\")\n", - " hs_out2 = torch.nn.functional.linear(hs2, w_out)\n", - " hs_out = rearrange(\n", - " hs_out2, \"(l b t) h -> l b t h\", l=hs.shape[0], b=hs.shape[1], t=1\n", - " )\n", - " diffs = hs_out[:, :, :].diff(dim=0)\n", - " diffs2 = rearrange(diffs, \"l b t h -> (l b t) h\")\n", - " # W_inv = get_cache_inv(w_out)\n", - "\n", - " diffs_inv2 = torch.nn.functional.linear(diffs2.to(dtype=w_inv.dtype), w_inv)\n", - " diffs_inv = rearrange(\n", - " diffs_inv2, \"(l b t) h -> l b t h\", l=hs.shape[0] - 1, b=hs.shape[1], t=1\n", - " ).to(w_out.dtype)\n", - " # TODO just return this?\n", - " eps = 1.0e-1\n", - " supressed_mask = (diffs_inv < -eps).to(hs.dtype)\n", - " # supressed_mask = repeat(supressed_mask, 'l b 1 h -> l b t h', t=hs.shape[2])\n", - " supressed_act = hs[1:] * supressed_mask\n", - " return supressed_act" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from activation_store.collect import default_postprocess_result\n", - "\n", - "Wo = model.get_output_embeddings().weight.detach().clone().cpu()\n", - "Wo_inv = torch.pinverse(Wo.clone().float())\n", - "\n", - "@torch.no_grad()\n", - "def sup_postproc(input, trace, output, model):\n", - "\n", - " \n", - " o = default_postprocess_result(input, trace, output, model)\n", - " \n", - " hs = o.pop('hidden_states')\n", - " hs = rearrange(hs, \"b l t h -> l b t h\")\n", - " hs_s = get_supressed_activations(hs, Wo.to(hs.dtype), Wo_inv.to(hs.dtype))\n", - " hs_s = rearrange(hs_s, \"l b t h -> b l t h\")\n", - " o['hidden_states_supressed'] = hs_s.half()\n", - " \n", - " return o\n" - ] - }, - { - "cell_type": "code", - "execution_count": 86, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2025-02-16 09:52:12.917\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mactivation_store.collect\u001b[0m:\u001b[36mactivation_store\u001b[0m:\u001b[36m78\u001b[0m - \u001b[1mcreating dataset /media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__115ab10dde7bd7a3.parquet\u001b[0m\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "c0c38f37f9934a0dbe7086b695624548", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "collecting activations: 0%| | 0/5 [00:00