mirror of
https://github.com/wassname/activation_store.git
synced 2026-07-04 23:17:53 +08:00
wip
This commit is contained in:
@@ -17,7 +17,7 @@ from torch import Tensor
|
||||
|
||||
default_output_folder = (Path(__file__).parent.parent / "outputs").resolve()
|
||||
|
||||
def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutput) -> Dict[str, Tensor]:
|
||||
def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutput, model: AutoModelForCausalLM) -> Dict[str, Tensor]:
|
||||
"""add activations to output, and rearrange hidden states"""
|
||||
|
||||
# Baukit records the literal layer output, which varies by model. Here we assume that the output or the first part are activations we want
|
||||
@@ -27,7 +27,7 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu
|
||||
|
||||
output.hidden_states = rearrange(list(output.hidden_states), 'l b t h -> b l t h')
|
||||
|
||||
return dict(**acts, **output)
|
||||
return dict(attention_mask=input["attention_mask"], **acts, **output)
|
||||
|
||||
|
||||
@torch.no_grad
|
||||
@@ -38,9 +38,10 @@ def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers, po
|
||||
with torch.amp.autocast(device_type=device.type):
|
||||
with TraceDict(model, layers) as trace:
|
||||
out = model(**batch, use_cache=False, output_hidden_states=True, return_dict=True)
|
||||
o = postprocess_result(batch, trace, out)
|
||||
o = postprocess_result(batch, trace, out, model)
|
||||
|
||||
# copy to avoid memory leaks
|
||||
o = {k: v.to('cpu') if isinstance(v, Tensor) else v for k, v in o.items()}
|
||||
o = recursive_copy(o)
|
||||
out = trace = batch = None
|
||||
clear_mem()
|
||||
@@ -69,7 +70,7 @@ def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_na
|
||||
|
||||
Usage:
|
||||
f = activation_store(loader, model, layers=['transformer.h'])
|
||||
Dataset.from_parquet(f)
|
||||
Dataset.from_parquet(f).with_format("torch")
|
||||
"""
|
||||
hash = dataset_hash(generate_batches=generate_batches, loader=loader, model=model)
|
||||
f = dataset_dir / ".ds" / f"ds_{dataset_name}_{hash}.parquet"
|
||||
@@ -84,12 +85,11 @@ def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_na
|
||||
for bo in iterator:
|
||||
|
||||
bs = len(next(iter(bo.values())))
|
||||
assert all(len(v) == bs for v in bo.values()), f"must return Dict[str,Tensor] and all tensors with same batch size a first dimension"
|
||||
assert all(len(v) == bs for v in bo.values()), "must return Dict[str,Tensor] and all tensors with same batch size a first dimension"
|
||||
|
||||
# or maybe better compression to `writer.write(example, key)` for each
|
||||
writer.write_batch(bo)
|
||||
writer.finalize()
|
||||
writer.close()
|
||||
|
||||
# ds = Dataset.from_file(str(f)).with_format("torch")
|
||||
return f
|
||||
|
||||
+54
-60
@@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 45,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -12,7 +12,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 46,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -33,7 +33,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 47,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -41,7 +41,7 @@
|
||||
"\n",
|
||||
"model = AutoModelForCausalLM.from_pretrained(\n",
|
||||
" model_name,\n",
|
||||
" torch_dtype=\"auto\",\n",
|
||||
" torch_dtype=torch.bfloat16,\n",
|
||||
" device_map=\"auto\",\n",
|
||||
" attn_implementation=\"eager\", # flex_attention flash_attention_2 sdpa eager\n",
|
||||
")\n",
|
||||
@@ -57,19 +57,19 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 48,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Dataset({\n",
|
||||
" features: ['input_ids', 'attention_mask'],\n",
|
||||
" features: ['attention_mask', 'input_ids'],\n",
|
||||
" num_rows: 20\n",
|
||||
"})"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"execution_count": 48,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -103,14 +103,14 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 49,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"<torch.utils.data.dataloader.DataLoader object at 0x76557465f770>\n"
|
||||
"<torch.utils.data.dataloader.DataLoader object at 0x7089f82ccb30>\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -136,7 +136,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -168,33 +168,33 @@
|
||||
" 'model.layers.23.mlp.down_proj']"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"execution_count": 50,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# choose layers to cache\n",
|
||||
"layers = [k for k,v in model.named_modules() if 'mlp.down_proj' in k]\n",
|
||||
"layers = [k for k,v in model.named_modules() if k.endswith('mlp.down_proj')]\n",
|
||||
"layers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 51,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[32m2025-02-15 21:58:37.654\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mactivation_store.collect\u001b[0m:\u001b[36mactivation_store\u001b[0m:\u001b[36m70\u001b[0m - \u001b[1mcreating dataset /media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__9b3f4b0da96e9ad5.parquet\u001b[0m\n"
|
||||
"\u001b[32m2025-02-16 09:36:37.315\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mactivation_store.collect\u001b[0m:\u001b[36mactivation_store\u001b[0m:\u001b[36m77\u001b[0m - \u001b[1mcreating dataset /media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__fac086acb713a85e.parquet\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "fe95a697e5c0432e85d15707b07fd001",
|
||||
"model_id": "8341bbff75634f0fb235e107abc2083d",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
@@ -213,15 +213,14 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "NameError",
|
||||
"evalue": "name 'ds_a' is not defined",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[8], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m f \u001b[38;5;241m=\u001b[39m activation_store(ds, model, layers\u001b[38;5;241m=\u001b[39mlayers, writer_batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m10\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m \u001b[43mds_a\u001b[49m\n",
|
||||
"\u001b[0;31mNameError\u001b[0m: name 'ds_a' is not defined"
|
||||
]
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"PosixPath('/media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__fac086acb713a85e.parquet')"
|
||||
]
|
||||
},
|
||||
"execution_count": 51,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
@@ -231,23 +230,9 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 57,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "1e1429e1d3224a2b8a5398f7a414911d",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Generating train split: 0 examples [00:00, ? examples/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
@@ -257,47 +242,56 @@
|
||||
"})"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"execution_count": 57,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from datasets import Dataset\n",
|
||||
"Dataset.from_parquet(str(f))"
|
||||
"ds_a = Dataset.from_parquet(str(f)).with_format(\"torch\")\n",
|
||||
"ds_a"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"torch.Size([2, 25, 453, 896])"
|
||||
]
|
||||
},
|
||||
"execution_count": 62,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"ds_a[0:2]['logits'].shape"
|
||||
"ds_a[0:2]['hidden_states'].shape # [batch, layers, tokens, hidden_states]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 61,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"torch.Size([2, 453, 896])"
|
||||
]
|
||||
},
|
||||
"execution_count": 61,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"ds_a[0:2]['model.layers.0.mlp.down_proj'].shape"
|
||||
"ds_a[0:2]['act-model.layers.0.mlp.down_proj'].shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
Reference in New Issue
Block a user