This commit is contained in:
wassname
2025-02-15 21:15:15 +08:00
parent 691d290c47
commit f186243fe1
8 changed files with 529 additions and 48 deletions
+5 -5
View File
@@ -1,14 +1,14 @@
# cache_transformer_activations
# activation_store
Utility library to collect transformer activations **on disk**.
Utility library to persistently store transformer activations on disk.
These activations can be quite large (layers x batch x sequence x hidden_size), so it's nice to store it on disk and avoid and out of memory error.
These activations can be quite large (layers x batch x sequence x hidden_size), so storing them on disk helps avoid out of memory errors.
Install using `pip install git+https://github.com/wassname/cache_transformer_activations.git`.
Install using `pip install git+https://github.com/wassname/activation_store.git`.
## Development
```
git clone https//github.com/wassname/cache_transformer_activations.git
git clone https//github.com/wassname/activation_store.git
uv sync
```
+74
View File
@@ -0,0 +1,74 @@
from transformers import AutoModelForCausalLM
import torch
from datasets import Dataset
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from loguru import logger
from pathlib import Path
from baukit.nethook import TraceDict, recursive_copy
from einops import rearrange
from datasets.arrow_writer import ArrowWriter
from datasets.fingerprint import Hasher
from transformers.modeling_outputs import ModelOutput
from activation_store.helpers.torch import clear_mem
from typing import Dict, Generator
from torch import Tensor
default_output_folder = (Path(__file__).parent.parent.parent / "outputs").resolve()
def default_postprocess_result(input: dict, ret: TraceDict, output: ModelOutput) -> Dict[str, Tensor]:
"""add ret, activations to output"""
# Baukit records the literal layer output, which varies by model. Here we assume that the output or the first part are activations we want
acts = {f'act-{k}':
v.output[0] if isinstance(v.output, tuple) else v.output
for k, v in ret.items()}
output.hidden_states = rearrange(list(output.hidden_states), 'l b t h -> b l t h')
return dict(**acts, **output)
@torch.no_grad
def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers, postprocess_result=default_postprocess_result) -> Generator[Dict[str, Tensor], None, None]:
model.eval()
for batch in tqdm(loader, 'collecting hidden states'):
device = next(model.parameters()).device
b_in = {
k: v.to(device)
for k, v in batch.items()
}
with TraceDict(model, layers) as ret:
out = model(**b_in, use_cache=False, output_hidden_states=True, return_dict=True)
o = postprocess_result(batch, ret, out)
o = recursive_copy(o)
out = ret = b_in = None
clear_mem()
yield o
def dataset_hash(**kwargs):
suffix = Hasher.hash(kwargs)
return suffix
def collect_act_to_disk(loader: DataLoader, model: AutoModelForCausalLM, dataset_name='', layers=[], dataset_dir=default_output_folder, writer_batch_size=1):
hash = dataset_hash(generate_batches=generate_batches, loader=loader, model=model)
f = dataset_dir / ".ds" / f"ds_{dataset_name}_{hash}"
f.parent.mkdir(exist_ok=True, parents=True)
logger.info(f"creating dataset {f}")
iterator = generate_batches(loader, model, layers=layers)
with ArrowWriter(path=f, writer_batch_size=writer_batch_size) as writer:
for bo in iterator:
bs = len(next(iter(bo.values())))
assert all(len(v) == bs for v in bo.values()), f"must return Dict[str,Tensor] and all tensors with same batch size a first dimension"
writer.write_batch(bo)
writer.write_examples_on_file()
writer.finalize()
ds = Dataset.from_file(str(f)).with_format("torch")
return ds, f
+78
View File
@@ -0,0 +1,78 @@
import torch
import gc
import copy
import numpy as np
from jaxtyping import Float, Int
from torch import Tensor
# def switch(p: Float[Tensor, ""], s: Float[Tensor, ""]):
# """if the true label is 0, we will flip our binary prediction around. so 25% becomes 75%. It's the rating of how correct our answer was from 0 to 1"""
# s = s.float()
# return (1 - s) * (1-p) + s * p
def clear_mem():
gc.collect()
# get_accelerator().empty_cache()
# accelerator.free_memory()
torch.cuda.empty_cache()
gc.collect()
def detachcpu(x):
"""
Trys to convert torch if possible a single item
"""
if isinstance(x, torch.Tensor):
x = x.cpu()
return x
else:
return x
# def recursive_copy(x, clone=None, detach=None, retain_grad=None):
# """
# from baukit
# Copies a reference to a tensor, or an object that contains tensors,
# optionally detaching and cloning the tensor(s). If retain_grad is
# true, the original tensors are marked to have grads retained.
# """
# if not clone and not detach and not retain_grad:
# return x
# if isinstance(x, torch.Tensor):
# if retain_grad:
# if not x.requires_grad:
# x.requires_grad = True
# x.retain_grad()
# elif detach:
# x = x.detach()
# if clone:
# x = x.clone()
# return x
# # Only dicts, lists, and tuples (and subclasses) can be copied.
# if isinstance(x, dict):
# return type(x)({k: recursive_copy(v, clone=clone, detach=detach, retain_grad=retain_grad) for k, v in x.items()})
# elif isinstance(x, (list, tuple)):
# return type(x)([recursive_copy(v, clone=clone, detach=detach, retain_grad=retain_grad) for v in x])
# else:
# return copy.deepcopy(x)
# def batch_to_device(b, device=None):
# """Move a batch to the device"""
# if isinstance(b, torch.Tensor):
# return b.to(device)
# elif isinstance(b, dict):
# return {k:batch_to_device(v, device=device) for k,v in b.items()}
# elif isinstance(b, (list, tuple)):
# return type(b)([batch_to_device(v, device=device) for v in b])
# else:
# return b
# def shape_of_anything(v):
# if isinstance(v, (Tensor, np.ndarray)):
# return v.shape
# elif isinstance(v, dict):
# return {k:shape_of_anything(v) for k,v in v.items()}
# elif isinstance(v, list):
# return len(v)
# else:
# return 1
+310
View File
@@ -0,0 +1,310 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%reload_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"\n",
"from activation_store.collect import collect_act_to_disk\n",
"\n",
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"model_name = \"Qwen/Qwen2.5-0.5B-Instruct\"\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" model_name,\n",
" torch_dtype=\"auto\",\n",
" device_map=\"auto\",\n",
" attn_implementation=\"eager\", # flex_attention flash_attention_2 sdpa eager\n",
")\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Dataset({\n",
" features: ['prompt', 'chosen', 'rejected'],\n",
" num_rows: 20\n",
"})"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"N = 20\n",
"max_length = 256\n",
"\n",
"imdb = load_dataset('wassname/imdb_dpo', split=f'test[:{N}]', keep_in_memory=False)\n",
"imdb"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Dataset({\n",
" features: ['input_ids', 'attention_mask'],\n",
" num_rows: 20\n",
"})"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\n",
"def proc(row):\n",
" messages = [\n",
" {\"role\":\"user\", \"content\": row['prompt'] },\n",
" {\"role\":\"assistant\", \"content\": row['chosen'] }\n",
" ]\n",
" return tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=False, return_dict=True, max_length=max_length)\n",
"\n",
"ds2 = imdb.map(proc).with_format(\"torch\")\n",
"new_cols = set(ds2.column_names) - set(imdb.column_names)\n",
"ds2 = ds2.select_columns(new_cols)\n",
"ds2"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<torch.utils.data.dataloader.DataLoader object at 0x7f6ddd90fcb0>\n"
]
}
],
"source": [
"from torch.utils.data import DataLoader\n",
"def collate_fn(examples):\n",
" # Pad the batch to max length within this batch\n",
" return tokenizer.pad(\n",
" examples,\n",
" padding=True,\n",
" return_tensors=\"pt\",\n",
" )\n",
"ds = DataLoader(ds2, batch_size=2, num_workers=0, collate_fn=collate_fn)\n",
"print(ds)\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# # sanity check with one manual forward\n",
"# b = next(iter(ds))\n",
"# outputs = model(**b)\n",
"# outputs.keys()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['model.layers.0.mlp.down_proj',\n",
" 'model.layers.1.mlp.down_proj',\n",
" 'model.layers.2.mlp.down_proj',\n",
" 'model.layers.3.mlp.down_proj',\n",
" 'model.layers.4.mlp.down_proj',\n",
" 'model.layers.5.mlp.down_proj',\n",
" 'model.layers.6.mlp.down_proj',\n",
" 'model.layers.7.mlp.down_proj',\n",
" 'model.layers.8.mlp.down_proj',\n",
" 'model.layers.9.mlp.down_proj',\n",
" 'model.layers.10.mlp.down_proj',\n",
" 'model.layers.11.mlp.down_proj',\n",
" 'model.layers.12.mlp.down_proj',\n",
" 'model.layers.13.mlp.down_proj',\n",
" 'model.layers.14.mlp.down_proj',\n",
" 'model.layers.15.mlp.down_proj',\n",
" 'model.layers.16.mlp.down_proj',\n",
" 'model.layers.17.mlp.down_proj',\n",
" 'model.layers.18.mlp.down_proj',\n",
" 'model.layers.19.mlp.down_proj',\n",
" 'model.layers.20.mlp.down_proj',\n",
" 'model.layers.21.mlp.down_proj',\n",
" 'model.layers.22.mlp.down_proj',\n",
" 'model.layers.23.mlp.down_proj']"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# choose layers to cache\n",
"layers = [k for k,v in model.named_modules() if 'mlp.down_proj' in k]\n",
"layers"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m2025-02-15 21:14:24.538\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mactivation_store.collect\u001b[0m:\u001b[36mcollect_act_to_disk\u001b[0m:\u001b[36m60\u001b[0m - \u001b[1mcreating dataset /media/wassname/SGIronWolf/projects5/elk/outputs/.ds/ds__7ae34f9e83796c91\u001b[0m\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f6ed6625c38544378d2d46969a8470c4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"collecting hidden states: 0%| | 0/10 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"You're using a Qwen2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
]
},
{
"data": {
"text/plain": [
"Dataset({\n",
" features: ['act-model.layers.0.mlp.down_proj', 'act-model.layers.1.mlp.down_proj', 'act-model.layers.2.mlp.down_proj', 'act-model.layers.3.mlp.down_proj', 'act-model.layers.4.mlp.down_proj', 'act-model.layers.5.mlp.down_proj', 'act-model.layers.6.mlp.down_proj', 'act-model.layers.7.mlp.down_proj', 'act-model.layers.8.mlp.down_proj', 'act-model.layers.9.mlp.down_proj', 'act-model.layers.10.mlp.down_proj', 'act-model.layers.11.mlp.down_proj', 'act-model.layers.12.mlp.down_proj', 'act-model.layers.13.mlp.down_proj', 'act-model.layers.14.mlp.down_proj', 'act-model.layers.15.mlp.down_proj', 'act-model.layers.16.mlp.down_proj', 'act-model.layers.17.mlp.down_proj', 'act-model.layers.18.mlp.down_proj', 'act-model.layers.19.mlp.down_proj', 'act-model.layers.20.mlp.down_proj', 'act-model.layers.21.mlp.down_proj', 'act-model.layers.22.mlp.down_proj', 'act-model.layers.23.mlp.down_proj', 'logits', 'hidden_states'],\n",
" num_rows: 20\n",
"})"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ds_a, f = collect_act_to_disk(ds, model, layers=layers)\n",
"ds_a"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 453, 151936])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ds_a[0:2]['logits'].shape"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"ename": "KeyError",
"evalue": "'model.layers.0.mlp.down_proj'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[11], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mds_a\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m:\u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mmodel.layers.0.mlp.down_proj\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241m.\u001b[39mshape\n",
"\u001b[0;31mKeyError\u001b[0m: 'model.layers.0.mlp.down_proj'"
]
}
],
"source": [
"ds_a[0:2]['model.layers.0.mlp.down_proj'].shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
+3 -2
View File
@@ -1,13 +1,14 @@
[project]
name = "cache-transformer-acts"
name = "activation_store"
version = "0.1.0"
description = "Add your description here"
description = "Cache transformer activations to disk"
readme = "README.md"
authors = [
{ name = "wassname", email = "1103714+wassname@users.noreply.github.com" }
]
requires-python = ">=3.12"
dependencies = [
"accelerate>=1.3.0",
"baukit",
"datasets>=3.3.0",
"einops>=0.8.1",
-2
View File
@@ -1,2 +0,0 @@
def hello() -> str:
return "Hello from cache-transformer-acts!"
Generated
+59 -39
View File
@@ -1,6 +1,65 @@
version = 1
requires-python = ">=3.12"
[[package]]
name = "accelerate"
version = "1.3.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "huggingface-hub" },
{ name = "numpy" },
{ name = "packaging" },
{ name = "psutil" },
{ name = "pyyaml" },
{ name = "safetensors" },
{ name = "torch" },
]
sdist = { url = "https://files.pythonhosted.org/packages/85/15/0fab0260ab4069e5224e637d2e400538bb27b0dfc36f17daf68db9770d78/accelerate-1.3.0.tar.gz", hash = "sha256:518631c0adb80bd3d42fb29e7e2dc2256bcd7c786b0ba9119bbaa08611b36d9c", size = 342758 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/73/de/64508cb91af013aaba214752309c0967568a4219d50a4ea30e822af3c976/accelerate-1.3.0-py3-none-any.whl", hash = "sha256:5788d9e6a7a9f80fed665cf09681c4dddd9dc056bea656db4140ffc285ce423e", size = 336647 },
]
[[package]]
name = "activation-store"
version = "0.1.0"
source = { editable = "." }
dependencies = [
{ name = "accelerate" },
{ name = "baukit" },
{ name = "datasets" },
{ name = "einops" },
{ name = "jaxtyping" },
{ name = "loguru" },
{ name = "torch" },
{ name = "tqdm" },
{ name = "transformers" },
]
[package.dev-dependencies]
dev = [
{ name = "ipykernel" },
{ name = "ipywidgets" },
]
[package.metadata]
requires-dist = [
{ name = "accelerate", specifier = ">=1.3.0" },
{ name = "baukit", git = "https://github.com/davidbau/baukit.git" },
{ name = "datasets", specifier = ">=3.3.0" },
{ name = "einops", specifier = ">=0.8.1" },
{ name = "jaxtyping", specifier = ">=0.2.38" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "torch", specifier = ">=2.6.0" },
{ name = "tqdm", specifier = ">=4.67.1" },
{ name = "transformers", specifier = ">=4.48.3" },
]
[package.metadata.requires-dev]
dev = [
{ name = "ipykernel", specifier = ">=6.29.5" },
{ name = "ipywidgets", specifier = ">=8.1.5" },
]
[[package]]
name = "aiohappyeyeballs"
version = "2.4.6"
@@ -108,45 +167,6 @@ dependencies = [
{ name = "torchvision" },
]
[[package]]
name = "cache-transformer-acts"
version = "0.1.0"
source = { editable = "." }
dependencies = [
{ name = "baukit" },
{ name = "datasets" },
{ name = "einops" },
{ name = "jaxtyping" },
{ name = "loguru" },
{ name = "torch" },
{ name = "tqdm" },
{ name = "transformers" },
]
[package.dev-dependencies]
dev = [
{ name = "ipykernel" },
{ name = "ipywidgets" },
]
[package.metadata]
requires-dist = [
{ name = "baukit", git = "https://github.com/davidbau/baukit.git" },
{ name = "datasets", specifier = ">=3.3.0" },
{ name = "einops", specifier = ">=0.8.1" },
{ name = "jaxtyping", specifier = ">=0.2.38" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "torch", specifier = ">=2.6.0" },
{ name = "tqdm", specifier = ">=4.67.1" },
{ name = "transformers", specifier = ">=4.48.3" },
]
[package.metadata.requires-dev]
dev = [
{ name = "ipykernel", specifier = ">=6.29.5" },
{ name = "ipywidgets", specifier = ">=8.1.5" },
]
[[package]]
name = "certifi"
version = "2025.1.31"