act groups

This commit is contained in:
wassname
2025-03-14 16:43:00 +08:00
parent 9d63a74fc6
commit fa7bd34b36
2 changed files with 169 additions and 84 deletions
+39 -13
View File
@@ -2,6 +2,7 @@ from transformers import AutoModelForCausalLM
import torch
from datasets import Dataset
from tqdm.auto import tqdm
import itertools
from torch.utils.data import DataLoader
from loguru import logger
from pathlib import Path
@@ -11,13 +12,14 @@ from datasets.arrow_writer import ArrowWriter, ParquetWriter
from datasets.fingerprint import Hasher
from transformers.modeling_outputs import ModelOutput
from activation_store.helpers.torch import clear_mem
from typing import Dict, Generator
from typing import Dict, Generator, List, Union, Optional
from torch import Tensor
import os
default_output_folder = (Path(__file__).parent.parent / "outputs").resolve()
@torch.no_grad()
def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutput, model: AutoModelForCausalLM, last_token=True, dtype=torch.float16) -> Dict[str, Tensor]:
def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutput, model: AutoModelForCausalLM, act_groups=Optional[Dict[str,List[str]]], last_token=True, dtype=torch.float16) -> Dict[str, Tensor]:
"""Make your own. This adds activations to output, and rearranges hidden states.
Note that the parquet writer supports float16, so we use that. It does not support float8, bfloat16, etc.
@@ -27,12 +29,15 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu
"""
token_index = slice(-1, None) if last_token else slice(None)
# Baukit records the literal layer output, which varies by model. Here we assume that the output or the first part are activations we want
# usually [b, t, h] but it depends on the model
try:
acts = {'acts': torch.stack([v.output[0][:, token_index].to(dtype) if isinstance(v.output, tuple) else v.output[:, token_index] for k, v in trace.items()], dim=1)}
except Exception as e:
logger.error(f"failed to stack activations: {e}")
# Baukit records the literal layer output, which varies by model. Sometimes you get a tuple, sometimes not. Usually [b, t, h] for MLP, but not for attention layers. You may need to customize this.
if act_groups is not None:
acts = {}
for k, group in act_groups.items():
aas = [v.output[0] if isinstance(v.output, tuple) else v.output for k, v in trace.items() if k in group]
assert len(aas) > 0, f"no activations found for {group}"
aas = torch.stack([a[:, token_index].to(dtype) for a in aas], dim=1)
acts[k] = aas
else:
acts = {f'act-{k}':
v.output[0][:, token_index].to(dtype) if isinstance(v.output, tuple) else v.output[:, token_index]
for k, v in trace.items()}
@@ -49,7 +54,7 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu
if 'label' in input:
o['label'] = input['label']
# convert any 0d tensors like loss to 1d, by repeating along batch dimension
# all output tensors must have a batch dim
for k, v in o.items():
if v.dim() == 0:
bs = input['input_ids'].shape[0]
@@ -58,7 +63,26 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu
@torch.no_grad
def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers, postprocess_result=default_postprocess_result) -> Generator[Dict[str, Tensor], None, None]:
def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers = [], postprocess_result=default_postprocess_result) -> Generator[Dict[str, Tensor], None, None]:
"""
Collect activations from a model
Args:
- loader: DataLoader
- model: AutoModelForCausalLM
- layers: can be
- selected from `model.named_modules()`
- a dict mapping a group name to a list of layers to collect; each group is stacked, so its layers must have compatible sizes
- postprocess_result: Callable - see `default_postprocess_result` for signature
Returns:
- Generator of [Dict[str, Tensor]], where each tensor has shape [batch,...]
"""
act_groups = None
if isinstance(layers, dict):
act_groups = layers
layers = list(itertools.chain(*layers.values()))
model.eval()
for batch in tqdm(loader, 'collecting activations'):
device = next(model.parameters()).device
@@ -68,7 +92,7 @@ def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers, po
batch = {k: v.to(device) if isinstance(v, Tensor) else v for k, v in batch.items()}
with TraceDict(model, layers, retain_grad=False, detach=True, clone=True) as trace:
out = model(**batch, use_cache=False, output_hidden_states=True, return_dict=True)
o = postprocess_result(batch, trace, out, model)
o = postprocess_result(batch, trace, out, model, act_groups=act_groups)
# copy to avoid memory leaks
o = {k: v.to('cpu') if isinstance(v, Tensor) else v for k, v in o.items()}
@@ -83,7 +107,7 @@ def dataset_hash(**kwargs):
return suffix
def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_name='', layers=[], dataset_dir=default_output_folder, writer_batch_size=1, postprocess_result=default_postprocess_result) -> Dataset:
def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_name='', layers: Union[List[str], Dict[str, List[str]]]=[], dataset_dir=default_output_folder, writer_batch_size=1, postprocess_result=default_postprocess_result) -> Dataset:
"""
Collect activations from a model and store them in a dataset
@@ -91,7 +115,9 @@ def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_na
- loader: DataLoader
- model: AutoModelForCausalLM
- dataset_name: str
- layers: List[str] - selected from `model.named_modules()`
- layers:
- List[str] selected from `model.named_modules()`
- or Dict[str, List[str]] - groups of layers to collect; these will be stacked, so they must have compatible sizes
- dataset_dir: Path
- postprocess_result: Callable - see `default_postprocess_result` for signature
+130 -71
View File
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 45,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -12,7 +12,20 @@
},
{
"cell_type": "code",
"execution_count": 46,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
"os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n",
"# os.environ[\"CUDA_LAUNCH_BLOCKING\"] = \"1\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -33,7 +46,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -61,19 +74,19 @@
},
{
"cell_type": "code",
"execution_count": 48,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Dataset({\n",
" features: ['attention_mask', 'input_ids'],\n",
" features: ['input_ids', 'attention_mask'],\n",
" num_rows: 20\n",
"})"
]
},
"execution_count": 48,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@@ -107,14 +120,14 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<torch.utils.data.dataloader.DataLoader object at 0x7089f82ccb30>\n"
"<torch.utils.data.dataloader.DataLoader object at 0x7c988e1ef290>\n"
]
}
],
@@ -135,65 +148,117 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['model.layers.0.mlp.down_proj',\n",
" 'model.layers.1.mlp.down_proj',\n",
" 'model.layers.2.mlp.down_proj',\n",
" 'model.layers.3.mlp.down_proj',\n",
" 'model.layers.4.mlp.down_proj',\n",
" 'model.layers.5.mlp.down_proj',\n",
" 'model.layers.6.mlp.down_proj',\n",
" 'model.layers.7.mlp.down_proj',\n",
" 'model.layers.8.mlp.down_proj',\n",
" 'model.layers.9.mlp.down_proj',\n",
" 'model.layers.10.mlp.down_proj',\n",
" 'model.layers.11.mlp.down_proj',\n",
" 'model.layers.12.mlp.down_proj',\n",
" 'model.layers.13.mlp.down_proj',\n",
" 'model.layers.14.mlp.down_proj',\n",
" 'model.layers.15.mlp.down_proj',\n",
" 'model.layers.16.mlp.down_proj',\n",
" 'model.layers.17.mlp.down_proj',\n",
" 'model.layers.18.mlp.down_proj',\n",
" 'model.layers.19.mlp.down_proj',\n",
" 'model.layers.20.mlp.down_proj',\n",
" 'model.layers.21.mlp.down_proj',\n",
" 'model.layers.22.mlp.down_proj',\n",
" 'model.layers.23.mlp.down_proj']"
"{'mlp.down_proj': ['model.layers.0.mlp.down_proj',\n",
" 'model.layers.1.mlp.down_proj',\n",
" 'model.layers.2.mlp.down_proj',\n",
" 'model.layers.3.mlp.down_proj',\n",
" 'model.layers.4.mlp.down_proj',\n",
" 'model.layers.5.mlp.down_proj',\n",
" 'model.layers.6.mlp.down_proj',\n",
" 'model.layers.7.mlp.down_proj',\n",
" 'model.layers.8.mlp.down_proj',\n",
" 'model.layers.9.mlp.down_proj',\n",
" 'model.layers.10.mlp.down_proj',\n",
" 'model.layers.11.mlp.down_proj',\n",
" 'model.layers.12.mlp.down_proj',\n",
" 'model.layers.13.mlp.down_proj',\n",
" 'model.layers.14.mlp.down_proj',\n",
" 'model.layers.15.mlp.down_proj',\n",
" 'model.layers.16.mlp.down_proj',\n",
" 'model.layers.17.mlp.down_proj',\n",
" 'model.layers.18.mlp.down_proj',\n",
" 'model.layers.19.mlp.down_proj',\n",
" 'model.layers.20.mlp.down_proj',\n",
" 'model.layers.21.mlp.down_proj',\n",
" 'model.layers.22.mlp.down_proj',\n",
" 'model.layers.23.mlp.down_proj'],\n",
" 'self_attn': ['model.layers.0.self_attn',\n",
" 'model.layers.1.self_attn',\n",
" 'model.layers.2.self_attn',\n",
" 'model.layers.3.self_attn',\n",
" 'model.layers.4.self_attn',\n",
" 'model.layers.5.self_attn',\n",
" 'model.layers.6.self_attn',\n",
" 'model.layers.7.self_attn',\n",
" 'model.layers.8.self_attn',\n",
" 'model.layers.9.self_attn',\n",
" 'model.layers.10.self_attn',\n",
" 'model.layers.11.self_attn',\n",
" 'model.layers.12.self_attn',\n",
" 'model.layers.13.self_attn',\n",
" 'model.layers.14.self_attn',\n",
" 'model.layers.15.self_attn',\n",
" 'model.layers.16.self_attn',\n",
" 'model.layers.17.self_attn',\n",
" 'model.layers.18.self_attn',\n",
" 'model.layers.19.self_attn',\n",
" 'model.layers.20.self_attn',\n",
" 'model.layers.21.self_attn',\n",
" 'model.layers.22.self_attn',\n",
" 'model.layers.23.self_attn'],\n",
" 'mlp.up_proj': ['model.layers.0.mlp.up_proj',\n",
" 'model.layers.1.mlp.up_proj',\n",
" 'model.layers.2.mlp.up_proj',\n",
" 'model.layers.3.mlp.up_proj',\n",
" 'model.layers.4.mlp.up_proj',\n",
" 'model.layers.5.mlp.up_proj',\n",
" 'model.layers.6.mlp.up_proj',\n",
" 'model.layers.7.mlp.up_proj',\n",
" 'model.layers.8.mlp.up_proj',\n",
" 'model.layers.9.mlp.up_proj',\n",
" 'model.layers.10.mlp.up_proj',\n",
" 'model.layers.11.mlp.up_proj',\n",
" 'model.layers.12.mlp.up_proj',\n",
" 'model.layers.13.mlp.up_proj',\n",
" 'model.layers.14.mlp.up_proj',\n",
" 'model.layers.15.mlp.up_proj',\n",
" 'model.layers.16.mlp.up_proj',\n",
" 'model.layers.17.mlp.up_proj',\n",
" 'model.layers.18.mlp.up_proj',\n",
" 'model.layers.19.mlp.up_proj',\n",
" 'model.layers.20.mlp.up_proj',\n",
" 'model.layers.21.mlp.up_proj',\n",
" 'model.layers.22.mlp.up_proj',\n",
" 'model.layers.23.mlp.up_proj']}"
]
},
"execution_count": 50,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# choose layers to cache\n",
"layers = [k for k,v in model.named_modules() if k.endswith('mlp.down_proj')]\n",
"layers"
"layer_groups = {\n",
" 'mlp.down_proj': [k for k,v in model.named_modules() if k.endswith('mlp.down_proj')],\n",
" 'self_attn': [k for k,v in model.named_modules() if k.endswith('.self_attn')],\n",
" 'mlp.up_proj': [k for k,v in model.named_modules() if k.endswith('mlp.up_proj')],\n",
"}\n",
"layer_groups"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m2025-02-16 09:36:37.315\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mactivation_store.collect\u001b[0m:\u001b[36mactivation_store\u001b[0m:\u001b[36m77\u001b[0m - \u001b[1mcreating dataset /media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__fac086acb713a85e.parquet\u001b[0m\n"
"\u001b[32m2025-03-14 16:42:30.982\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mactivation_store.collect\u001b[0m:\u001b[36mactivation_store\u001b[0m:\u001b[36m134\u001b[0m - \u001b[1mcreating dataset /media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__0e7d5dbf1c73cf7d.parquet\u001b[0m\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8341bbff75634f0fb235e107abc2083d",
"model_id": "e0a492af45854b2c83f6d10d87c6d42a",
"version_major": 2,
"version_minor": 0
},
@@ -204,44 +269,51 @@
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"You're using a Qwen2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
]
},
{
"data": {
"text/plain": [
"PosixPath('/media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__fac086acb713a85e.parquet')"
"PosixPath('/media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__0e7d5dbf1c73cf7d.parquet')"
]
},
"execution_count": 51,
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f = activation_store(ds, model, layers=layers)\n",
"f = activation_store(ds, model, layers=layer_groups)\n",
"f"
]
},
{
"cell_type": "code",
"execution_count": 57,
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f1bcb56397af43ac82f4cf4761acdd87",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating train split: 0 examples [00:00, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"Dataset({\n",
" features: ['act-model.layers.0.mlp.down_proj', 'act-model.layers.1.mlp.down_proj', 'act-model.layers.2.mlp.down_proj', 'act-model.layers.3.mlp.down_proj', 'act-model.layers.4.mlp.down_proj', 'act-model.layers.5.mlp.down_proj', 'act-model.layers.6.mlp.down_proj', 'act-model.layers.7.mlp.down_proj', 'act-model.layers.8.mlp.down_proj', 'act-model.layers.9.mlp.down_proj', 'act-model.layers.10.mlp.down_proj', 'act-model.layers.11.mlp.down_proj', 'act-model.layers.12.mlp.down_proj', 'act-model.layers.13.mlp.down_proj', 'act-model.layers.14.mlp.down_proj', 'act-model.layers.15.mlp.down_proj', 'act-model.layers.16.mlp.down_proj', 'act-model.layers.17.mlp.down_proj', 'act-model.layers.18.mlp.down_proj', 'act-model.layers.19.mlp.down_proj', 'act-model.layers.20.mlp.down_proj', 'act-model.layers.21.mlp.down_proj', 'act-model.layers.22.mlp.down_proj', 'act-model.layers.23.mlp.down_proj', 'logits', 'hidden_states'],\n",
" features: ['mlp.down_proj', 'self_attn', 'mlp.up_proj', 'loss', 'logits', 'hidden_states'],\n",
" num_rows: 20\n",
"})"
]
},
"execution_count": 57,
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
@@ -254,16 +326,16 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 25, 453, 896])"
"torch.Size([2, 25, 1, 896])"
]
},
"execution_count": 62,
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
@@ -274,23 +346,10 @@
},
{
"cell_type": "code",
"execution_count": 61,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 453, 896])"
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ds_a[0:2]['act-model.layers.0.mlp.down_proj'].shape"
]
"outputs": [],
"source": []
}
],
"metadata": {