From f186243fe1546e9db9f33088fa89a5f255afb5c7 Mon Sep 17 00:00:00 2001 From: wassname <1103714+wassname@users.noreply.github.com> Date: Sat, 15 Feb 2025 21:15:15 +0800 Subject: [PATCH] working --- README.md | 10 +- .../py.typed => activation_store/__init__.py | 0 activation_store/collect.py | 74 +++++ activation_store/helpers/torch.py | 78 +++++ nbs/example.ipynb | 310 ++++++++++++++++++ pyproject.toml | 5 +- src/cache_transformer_acts/__init__.py | 2 - uv.lock | 98 +++--- 8 files changed, 529 insertions(+), 48 deletions(-) rename src/cache_transformer_acts/py.typed => activation_store/__init__.py (100%) create mode 100644 activation_store/collect.py create mode 100644 activation_store/helpers/torch.py create mode 100644 nbs/example.ipynb delete mode 100644 src/cache_transformer_acts/__init__.py diff --git a/README.md b/README.md index 191108a..b2f5755 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,14 @@ -# cache_transformer_activations +# activation_store -Utility library to collect transformer activations **on disk**. +Utility library to persistently store transformer activations on disk. -These activations can be quite large (layers x batch x sequence x hidden_size), so it's nice to store it on disk and avoid and out of memory error. +These activations can be quite large (layers x batch x sequence x hidden_size), so storing them on disk helps avoid out of memory errors. -Install using `pip install git+https://github.com/wassname/cache_transformer_activations.git`. +Install using `pip install git+https://github.com/wassname/activation_store.git`. ## Development ``` -git clone https//github.com/wassname/cache_transformer_activations.git +git clone https//github.com/wassname/activation_store.git uv sync ``` diff --git a/src/cache_transformer_acts/py.typed b/activation_store/__init__.py similarity index 100% rename from src/cache_transformer_acts/py.typed rename to activation_store/__init__.py diff --git a/activation_store/collect.py b/activation_store/collect.py new file mode 100644 index 0000000..94d9457 --- /dev/null +++ b/activation_store/collect.py @@ -0,0 +1,74 @@ +from transformers import AutoModelForCausalLM +import torch +from datasets import Dataset +from tqdm.auto import tqdm +from torch.utils.data import DataLoader +from loguru import logger +from pathlib import Path +from baukit.nethook import TraceDict, recursive_copy +from einops import rearrange +from datasets.arrow_writer import ArrowWriter +from datasets.fingerprint import Hasher +from transformers.modeling_outputs import ModelOutput + +from activation_store.helpers.torch import clear_mem +from typing import Dict, Generator +from torch import Tensor + +default_output_folder = (Path(__file__).parent.parent.parent / "outputs").resolve() + +def default_postprocess_result(input: dict, ret: TraceDict, output: ModelOutput) -> Dict[str, Tensor]: + """add ret, activations to output""" + + # Baukit records the literal layer output, which varies by model. Here we assume that the output or the first part are activations we want + acts = {f'act-{k}': + v.output[0] if isinstance(v.output, tuple) else v.output + for k, v in ret.items()} + + output.hidden_states = rearrange(list(output.hidden_states), 'l b t h -> b l t h') + + return dict(**acts, **output) + + +@torch.no_grad +def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers, postprocess_result=default_postprocess_result) -> Generator[Dict[str, Tensor], None, None]: + model.eval() + for batch in tqdm(loader, 'collecting hidden states'): + device = next(model.parameters()).device + b_in = { + k: v.to(device) + for k, v in batch.items() + } + with TraceDict(model, layers) as ret: + out = model(**b_in, use_cache=False, output_hidden_states=True, return_dict=True) + o = postprocess_result(batch, ret, out) + o = recursive_copy(o) + out = ret = b_in = None + clear_mem() + yield o + + +def dataset_hash(**kwargs): + suffix = Hasher.hash(kwargs) + return suffix + + +def collect_act_to_disk(loader: DataLoader, model: AutoModelForCausalLM, dataset_name='', layers=[], dataset_dir=default_output_folder, writer_batch_size=1): + hash = dataset_hash(generate_batches=generate_batches, loader=loader, model=model) + f = dataset_dir / ".ds" / f"ds_{dataset_name}_{hash}" + f.parent.mkdir(exist_ok=True, parents=True) + logger.info(f"creating dataset {f}") + + iterator = generate_batches(loader, model, layers=layers) + with ArrowWriter(path=f, writer_batch_size=writer_batch_size) as writer: + for bo in iterator: + + bs = len(next(iter(bo.values()))) + assert all(len(v) == bs for v in bo.values()), f"must return Dict[str,Tensor] and all tensors with same batch size a first dimension" + + writer.write_batch(bo) + writer.write_examples_on_file() + writer.finalize() + + ds = Dataset.from_file(str(f)).with_format("torch") + return ds, f diff --git a/activation_store/helpers/torch.py b/activation_store/helpers/torch.py new file mode 100644 index 0000000..cb16720 --- /dev/null +++ b/activation_store/helpers/torch.py @@ -0,0 +1,78 @@ +import torch +import gc +import copy +import numpy as np +from jaxtyping import Float, Int +from torch import Tensor + +# def switch(p: Float[Tensor, ""], s: Float[Tensor, ""]): +# """if the true label is 0, we will flip our binary prediction around. so 25% becomes 75%. It's the rating of how correct our answer was from 0 to 1""" +# s = s.float() +# return (1 - s) * (1-p) + s * p + + +def clear_mem(): + gc.collect() + # get_accelerator().empty_cache() + # accelerator.free_memory() + torch.cuda.empty_cache() + gc.collect() + +def detachcpu(x): + """ + Trys to convert torch if possible a single item + """ + if isinstance(x, torch.Tensor): + x = x.cpu() + return x + else: + return x + +# def recursive_copy(x, clone=None, detach=None, retain_grad=None): +# """ +# from baukit + +# Copies a reference to a tensor, or an object that contains tensors, +# optionally detaching and cloning the tensor(s). If retain_grad is +# true, the original tensors are marked to have grads retained. +# """ +# if not clone and not detach and not retain_grad: +# return x +# if isinstance(x, torch.Tensor): +# if retain_grad: +# if not x.requires_grad: +# x.requires_grad = True +# x.retain_grad() +# elif detach: +# x = x.detach() +# if clone: +# x = x.clone() +# return x +# # Only dicts, lists, and tuples (and subclasses) can be copied. +# if isinstance(x, dict): +# return type(x)({k: recursive_copy(v, clone=clone, detach=detach, retain_grad=retain_grad) for k, v in x.items()}) +# elif isinstance(x, (list, tuple)): +# return type(x)([recursive_copy(v, clone=clone, detach=detach, retain_grad=retain_grad) for v in x]) +# else: +# return copy.deepcopy(x) + +# def batch_to_device(b, device=None): +# """Move a batch to the device""" +# if isinstance(b, torch.Tensor): +# return b.to(device) +# elif isinstance(b, dict): +# return {k:batch_to_device(v, device=device) for k,v in b.items()} +# elif isinstance(b, (list, tuple)): +# return type(b)([batch_to_device(v, device=device) for v in b]) +# else: +# return b + +# def shape_of_anything(v): +# if isinstance(v, (Tensor, np.ndarray)): +# return v.shape +# elif isinstance(v, dict): +# return {k:shape_of_anything(v) for k,v in v.items()} +# elif isinstance(v, list): +# return len(v) +# else: +# return 1 diff --git a/nbs/example.ipynb b/nbs/example.ipynb new file mode 100644 index 0000000..195bdfa --- /dev/null +++ b/nbs/example.ipynb @@ -0,0 +1,310 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%reload_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", + "\n", + "from activation_store.collect import collect_act_to_disk\n", + "\n", + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "model_name = \"Qwen/Qwen2.5-0.5B-Instruct\"\n", + "\n", + "model = AutoModelForCausalLM.from_pretrained(\n", + " model_name,\n", + " torch_dtype=\"auto\",\n", + " device_map=\"auto\",\n", + " attn_implementation=\"eager\", # flex_attention flash_attention_2 sdpa eager\n", + ")\n", + "tokenizer = AutoTokenizer.from_pretrained(model_name)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['prompt', 'chosen', 'rejected'],\n", + " num_rows: 20\n", + "})" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "N = 20\n", + "max_length = 256\n", + "\n", + "imdb = load_dataset('wassname/imdb_dpo', split=f'test[:{N}]', keep_in_memory=False)\n", + "imdb" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['input_ids', 'attention_mask'],\n", + " num_rows: 20\n", + "})" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "def proc(row):\n", + " messages = [\n", + " {\"role\":\"user\", \"content\": row['prompt'] },\n", + " {\"role\":\"assistant\", \"content\": row['chosen'] }\n", + " ]\n", + " return tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=False, return_dict=True, max_length=max_length)\n", + "\n", + "ds2 = imdb.map(proc).with_format(\"torch\")\n", + "new_cols = set(ds2.column_names) - set(imdb.column_names)\n", + "ds2 = ds2.select_columns(new_cols)\n", + "ds2" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "from torch.utils.data import DataLoader\n", + "def collate_fn(examples):\n", + " # Pad the batch to max length within this batch\n", + " return tokenizer.pad(\n", + " examples,\n", + " padding=True,\n", + " return_tensors=\"pt\",\n", + " )\n", + "ds = DataLoader(ds2, batch_size=2, num_workers=0, collate_fn=collate_fn)\n", + "print(ds)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# # sanity check with one manual forward\n", + "# b = next(iter(ds))\n", + "# outputs = model(**b)\n", + "# outputs.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['model.layers.0.mlp.down_proj',\n", + " 'model.layers.1.mlp.down_proj',\n", + " 'model.layers.2.mlp.down_proj',\n", + " 'model.layers.3.mlp.down_proj',\n", + " 'model.layers.4.mlp.down_proj',\n", + " 'model.layers.5.mlp.down_proj',\n", + " 'model.layers.6.mlp.down_proj',\n", + " 'model.layers.7.mlp.down_proj',\n", + " 'model.layers.8.mlp.down_proj',\n", + " 'model.layers.9.mlp.down_proj',\n", + " 'model.layers.10.mlp.down_proj',\n", + " 'model.layers.11.mlp.down_proj',\n", + " 'model.layers.12.mlp.down_proj',\n", + " 'model.layers.13.mlp.down_proj',\n", + " 'model.layers.14.mlp.down_proj',\n", + " 'model.layers.15.mlp.down_proj',\n", + " 'model.layers.16.mlp.down_proj',\n", + " 'model.layers.17.mlp.down_proj',\n", + " 'model.layers.18.mlp.down_proj',\n", + " 'model.layers.19.mlp.down_proj',\n", + " 'model.layers.20.mlp.down_proj',\n", + " 'model.layers.21.mlp.down_proj',\n", + " 'model.layers.22.mlp.down_proj',\n", + " 'model.layers.23.mlp.down_proj']" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# choose layers to cache\n", + "layers = [k for k,v in model.named_modules() if 'mlp.down_proj' in k]\n", + "layers" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2025-02-15 21:14:24.538\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mactivation_store.collect\u001b[0m:\u001b[36mcollect_act_to_disk\u001b[0m:\u001b[36m60\u001b[0m - \u001b[1mcreating dataset /media/wassname/SGIronWolf/projects5/elk/outputs/.ds/ds__7ae34f9e83796c91\u001b[0m\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f6ed6625c38544378d2d46969a8470c4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "collecting hidden states: 0%| | 0/10 [00:00 1\u001b[0m \u001b[43mds_a\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m:\u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mmodel.layers.0.mlp.down_proj\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241m.\u001b[39mshape\n", + "\u001b[0;31mKeyError\u001b[0m: 'model.layers.0.mlp.down_proj'" + ] + } + ], + "source": [ + "ds_a[0:2]['model.layers.0.mlp.down_proj'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pyproject.toml b/pyproject.toml index 829f829..11c6fc8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,13 +1,14 @@ [project] -name = "cache-transformer-acts" +name = "activation_store" version = "0.1.0" -description = "Add your description here" +description = "Cache transformer activations to disk" readme = "README.md" authors = [ { name = "wassname", email = "1103714+wassname@users.noreply.github.com" } ] requires-python = ">=3.12" dependencies = [ + "accelerate>=1.3.0", "baukit", "datasets>=3.3.0", "einops>=0.8.1", diff --git a/src/cache_transformer_acts/__init__.py b/src/cache_transformer_acts/__init__.py deleted file mode 100644 index 0d35970..0000000 --- a/src/cache_transformer_acts/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -def hello() -> str: - return "Hello from cache-transformer-acts!" diff --git a/uv.lock b/uv.lock index 487750e..0eb5188 100644 --- a/uv.lock +++ b/uv.lock @@ -1,6 +1,65 @@ version = 1 requires-python = ">=3.12" +[[package]] +name = "accelerate" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "huggingface-hub" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "psutil" }, + { name = "pyyaml" }, + { name = "safetensors" }, + { name = "torch" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/85/15/0fab0260ab4069e5224e637d2e400538bb27b0dfc36f17daf68db9770d78/accelerate-1.3.0.tar.gz", hash = "sha256:518631c0adb80bd3d42fb29e7e2dc2256bcd7c786b0ba9119bbaa08611b36d9c", size = 342758 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/73/de/64508cb91af013aaba214752309c0967568a4219d50a4ea30e822af3c976/accelerate-1.3.0-py3-none-any.whl", hash = "sha256:5788d9e6a7a9f80fed665cf09681c4dddd9dc056bea656db4140ffc285ce423e", size = 336647 }, +] + +[[package]] +name = "activation-store" +version = "0.1.0" +source = { editable = "." } +dependencies = [ + { name = "accelerate" }, + { name = "baukit" }, + { name = "datasets" }, + { name = "einops" }, + { name = "jaxtyping" }, + { name = "loguru" }, + { name = "torch" }, + { name = "tqdm" }, + { name = "transformers" }, +] + +[package.dev-dependencies] +dev = [ + { name = "ipykernel" }, + { name = "ipywidgets" }, +] + +[package.metadata] +requires-dist = [ + { name = "accelerate", specifier = ">=1.3.0" }, + { name = "baukit", git = "https://github.com/davidbau/baukit.git" }, + { name = "datasets", specifier = ">=3.3.0" }, + { name = "einops", specifier = ">=0.8.1" }, + { name = "jaxtyping", specifier = ">=0.2.38" }, + { name = "loguru", specifier = ">=0.7.3" }, + { name = "torch", specifier = ">=2.6.0" }, + { name = "tqdm", specifier = ">=4.67.1" }, + { name = "transformers", specifier = ">=4.48.3" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "ipykernel", specifier = ">=6.29.5" }, + { name = "ipywidgets", specifier = ">=8.1.5" }, +] + [[package]] name = "aiohappyeyeballs" version = "2.4.6" @@ -108,45 +167,6 @@ dependencies = [ { name = "torchvision" }, ] -[[package]] -name = "cache-transformer-acts" -version = "0.1.0" -source = { editable = "." } -dependencies = [ - { name = "baukit" }, - { name = "datasets" }, - { name = "einops" }, - { name = "jaxtyping" }, - { name = "loguru" }, - { name = "torch" }, - { name = "tqdm" }, - { name = "transformers" }, -] - -[package.dev-dependencies] -dev = [ - { name = "ipykernel" }, - { name = "ipywidgets" }, -] - -[package.metadata] -requires-dist = [ - { name = "baukit", git = "https://github.com/davidbau/baukit.git" }, - { name = "datasets", specifier = ">=3.3.0" }, - { name = "einops", specifier = ">=0.8.1" }, - { name = "jaxtyping", specifier = ">=0.2.38" }, - { name = "loguru", specifier = ">=0.7.3" }, - { name = "torch", specifier = ">=2.6.0" }, - { name = "tqdm", specifier = ">=4.67.1" }, - { name = "transformers", specifier = ">=4.48.3" }, -] - -[package.metadata.requires-dev] -dev = [ - { name = "ipykernel", specifier = ">=6.29.5" }, - { name = "ipywidgets", specifier = ">=8.1.5" }, -] - [[package]] name = "certifi" version = "2025.1.31"