From 2f82c4bdec777dfa3139f98dcc7fca39ccf04e7d Mon Sep 17 00:00:00 2001 From: wassname <1103714+wassname@users.noreply.github.com> Date: Sun, 16 Feb 2025 09:41:13 +0800 Subject: [PATCH] wip --- activation_store/collect.py | 12 ++-- nbs/example.ipynb | 114 +++++++++++++++++------------------- 2 files changed, 60 insertions(+), 66 deletions(-) diff --git a/activation_store/collect.py b/activation_store/collect.py index 6512cde..926d8a5 100644 --- a/activation_store/collect.py +++ b/activation_store/collect.py @@ -17,7 +17,7 @@ from torch import Tensor default_output_folder = (Path(__file__).parent.parent / "outputs").resolve() -def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutput) -> Dict[str, Tensor]: +def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutput, model: AutoModelForCausalLM) -> Dict[str, Tensor]: """add activations to output, and rearrange hidden states""" # Baukit records the literal layer output, which varies by model. Here we assume that the output or the first part are activations we want @@ -27,7 +27,7 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu output.hidden_states = rearrange(list(output.hidden_states), 'l b t h -> b l t h') - return dict(**acts, **output) + return dict(attention_mask=input["attention_mask"], **acts, **output) @torch.no_grad @@ -38,9 +38,10 @@ def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers, po with torch.amp.autocast(device_type=device.type): with TraceDict(model, layers) as trace: out = model(**batch, use_cache=False, output_hidden_states=True, return_dict=True) - o = postprocess_result(batch, trace, out) + o = postprocess_result(batch, trace, out, model) # copy to avoid memory leaks + o = {k: v.to('cpu') if isinstance(v, Tensor) else v for k, v in o.items()} o = recursive_copy(o) out = trace = batch = None clear_mem() @@ -69,7 +70,7 @@ def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_na Usage: f = activation_store(loader, model, layers=['transformer.h']) - Dataset.from_parquet(f) + Dataset.from_parquet(f).with_format("torch") """ hash = dataset_hash(generate_batches=generate_batches, loader=loader, model=model) f = dataset_dir / ".ds" / f"ds_{dataset_name}_{hash}.parquet" @@ -84,12 +85,11 @@ def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_na for bo in iterator: bs = len(next(iter(bo.values()))) - assert all(len(v) == bs for v in bo.values()), f"must return Dict[str,Tensor] and all tensors with same batch size a first dimension" + assert all(len(v) == bs for v in bo.values()), "must return Dict[str,Tensor] and all tensors with same batch size a first dimension" # or maybe better compression to `writer.write(example, key)` for each writer.write_batch(bo) writer.finalize() writer.close() - # ds = Dataset.from_file(str(f)).with_format("torch") return f diff --git a/nbs/example.ipynb b/nbs/example.ipynb index 2afe528..4f9f4a7 100644 --- a/nbs/example.ipynb +++ b/nbs/example.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 45, "metadata": {}, "outputs": [], "source": [ @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 46, "metadata": {}, "outputs": [], "source": [ @@ -33,7 +33,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 47, "metadata": {}, "outputs": [], "source": [ @@ -41,7 +41,7 @@ "\n", "model = AutoModelForCausalLM.from_pretrained(\n", " model_name,\n", - " torch_dtype=\"auto\",\n", + " torch_dtype=torch.bfloat16,\n", " device_map=\"auto\",\n", " attn_implementation=\"eager\", # flex_attention flash_attention_2 sdpa eager\n", ")\n", @@ -57,19 +57,19 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Dataset({\n", - " features: ['input_ids', 'attention_mask'],\n", + " features: ['attention_mask', 'input_ids'],\n", " num_rows: 20\n", "})" ] }, - "execution_count": 4, + "execution_count": 48, "metadata": {}, "output_type": "execute_result" } @@ -103,14 +103,14 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 49, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\n" + "\n" ] } ], @@ -136,7 +136,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -168,33 +168,33 @@ " 'model.layers.23.mlp.down_proj']" ] }, - "execution_count": 7, + "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# choose layers to cache\n", - "layers = [k for k,v in model.named_modules() if 'mlp.down_proj' in k]\n", + "layers = [k for k,v in model.named_modules() if k.endswith('mlp.down_proj')]\n", "layers" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 51, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2025-02-15 21:58:37.654\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mactivation_store.collect\u001b[0m:\u001b[36mactivation_store\u001b[0m:\u001b[36m70\u001b[0m - \u001b[1mcreating dataset /media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__9b3f4b0da96e9ad5.parquet\u001b[0m\n" + "\u001b[32m2025-02-16 09:36:37.315\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mactivation_store.collect\u001b[0m:\u001b[36mactivation_store\u001b[0m:\u001b[36m77\u001b[0m - \u001b[1mcreating dataset /media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__fac086acb713a85e.parquet\u001b[0m\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "fe95a697e5c0432e85d15707b07fd001", + "model_id": "8341bbff75634f0fb235e107abc2083d", "version_major": 2, "version_minor": 0 }, @@ -213,15 +213,14 @@ ] }, { - "ename": "NameError", - "evalue": "name 'ds_a' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[8], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m f \u001b[38;5;241m=\u001b[39m activation_store(ds, model, layers\u001b[38;5;241m=\u001b[39mlayers, writer_batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m10\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m \u001b[43mds_a\u001b[49m\n", - "\u001b[0;31mNameError\u001b[0m: name 'ds_a' is not defined" - ] + "data": { + "text/plain": [ + "PosixPath('/media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__fac086acb713a85e.parquet')" + ] + }, + "execution_count": 51, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -231,23 +230,9 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 57, "metadata": {}, "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "1e1429e1d3224a2b8a5398f7a414911d", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Generating train split: 0 examples [00:00, ? examples/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, { "data": { "text/plain": [ @@ -257,47 +242,56 @@ "})" ] }, - "execution_count": 10, + "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from datasets import Dataset\n", - "Dataset.from_parquet(str(f))" + "ds_a = Dataset.from_parquet(str(f)).with_format(\"torch\")\n", + "ds_a" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 25, 453, 896])" + ] + }, + "execution_count": 62, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "ds_a[0:2]['logits'].shape" + "ds_a[0:2]['hidden_states'].shape # [batch, layers, tokens, hidden_states]" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 61, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 453, 896])" + ] + }, + "execution_count": 61, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "ds_a[0:2]['model.layers.0.mlp.down_proj'].shape" + "ds_a[0:2]['act-model.layers.0.mlp.down_proj'].shape" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": {