From fa7bd34b365cb0b8934e8a991d9ebbbcaea79109 Mon Sep 17 00:00:00 2001 From: wassname <1103714+wassname@users.noreply.github.com> Date: Fri, 14 Mar 2025 16:43:00 +0800 Subject: [PATCH] act grousp --- activation_store/collect.py | 52 +++++++--- nbs/example.ipynb | 201 +++++++++++++++++++++++------------- 2 files changed, 169 insertions(+), 84 deletions(-) diff --git a/activation_store/collect.py b/activation_store/collect.py index 52fd899..d68d73a 100644 --- a/activation_store/collect.py +++ b/activation_store/collect.py @@ -2,6 +2,7 @@ from transformers import AutoModelForCausalLM import torch from datasets import Dataset from tqdm.auto import tqdm +import itertools from torch.utils.data import DataLoader from loguru import logger from pathlib import Path @@ -11,13 +12,14 @@ from datasets.arrow_writer import ArrowWriter, ParquetWriter from datasets.fingerprint import Hasher from transformers.modeling_outputs import ModelOutput from activation_store.helpers.torch import clear_mem -from typing import Dict, Generator +from typing import Dict, Generator, List, Union, Optional from torch import Tensor +import os default_output_folder = (Path(__file__).parent.parent / "outputs").resolve() @torch.no_grad() -def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutput, model: AutoModelForCausalLM, last_token=True, dtype=torch.float16) -> Dict[str, Tensor]: +def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutput, model: AutoModelForCausalLM, act_groups=Optional[Dict[str,List[str]]], last_token=True, dtype=torch.float16) -> Dict[str, Tensor]: """Make your own. This adds activations to output, and rearranges hidden states. Note the parquet write support float16, so we use that. It does not support float8, bfloat16, etc. @@ -27,12 +29,15 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu """ token_index = slice(-1, None) if last_token else slice(None) - # Baukit records the literal layer output, which varies by model. Here we assume that the output or the first part are activations we want - # usually [b, t, h] but it depends on the model - try: - acts = {'acts': torch.stack([v.output[0][:, token_index].to(dtype) if isinstance(v.output, tuple) else v.output[:, token_index] for k, v in trace.items()], dim=1)} - except Exception as e: - logger.error(f"failed to stack activations: {e}") + # Baukit records the literal layer output, which varies by model. Sometimes you get a tuple, or not.Usually [b, t, h] for MLP, but not for attention layers. You may need to customize this. + if act_groups is not None: + acts = {} + for k, group in act_groups.items(): + aas = [v.output[0] if isinstance(v.output, tuple) else v.output for k, v in trace.items() if k in group] + assert len(aas) > 0, f"no activations found for {group}" + aas = torch.stack([a[:, token_index].to(dtype) for a in aas], dim=1) + acts[k] = aas + else: acts = {f'act-{k}': v.output[0][:, token_index].to(dtype) if isinstance(v.output, tuple) else v.output[:, token_index] for k, v in trace.items()} @@ -49,7 +54,7 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu if 'label' in input: o['label'] = input['label'] - # convert any 0d tensors like loss to 1d, by repeating along batch dimension + # all output tensors must have a batch dim for k, v in o.items(): if v.dim() == 0: bs = input['input_ids'].shape[0] @@ -58,7 +63,26 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu @torch.no_grad -def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers, postprocess_result=default_postprocess_result) -> Generator[Dict[str, Tensor], None, None]: +def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers = [], postprocess_result=default_postprocess_result) -> Generator[Dict[str, Tensor], None, None]: + """ + Collect activations from a model + + Args: + - loader: DataLoader + - model: AutoModelForCausalLM + - layers: can be + - selected from `model.named_modules()` + - groups of layers to collect, these will be stacked so they must have compatible sizes + - postprocess_result: Callable - see `default_postprocess_result` for signature + + Returns: + - Generator of [Dict[str, Tensor]], where each tensor has shape [batch,...] + """ + act_groups = None + if isinstance(layers, dict): + act_groups = layers + layers = list(itertools.chain(*layers.values())) + model.eval() for batch in tqdm(loader, 'collecting activations'): device = next(model.parameters()).device @@ -68,7 +92,7 @@ def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers, po batch = {k: v.to(device) if isinstance(v, Tensor) else v for k, v in batch.items()} with TraceDict(model, layers, retain_grad=False, detach=True, clone=True) as trace: out = model(**batch, use_cache=False, output_hidden_states=True, return_dict=True) - o = postprocess_result(batch, trace, out, model) + o = postprocess_result(batch, trace, out, model, act_groups=act_groups) # copy to avoid memory leaks o = {k: v.to('cpu') if isinstance(v, Tensor) else v for k, v in o.items()} @@ -83,7 +107,7 @@ def dataset_hash(**kwargs): return suffix -def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_name='', layers=[], dataset_dir=default_output_folder, writer_batch_size=1, postprocess_result=default_postprocess_result) -> Dataset: +def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_name='', layers: Union[List[str], Dict[str, List[str]]]=[], dataset_dir=default_output_folder, writer_batch_size=1, postprocess_result=default_postprocess_result) -> Dataset: """ Collect activations from a model and store them in a dataset @@ -91,7 +115,9 @@ def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_na - loader: DataLoader - model: AutoModelForCausalLM - dataset_name: str - - layers: List[str] - selected from `model.named_modules()` + - layers: + - List[str] selected from `model.named_modules()` + - or Dict[str, List[str]]] - groups of layers to collect, these will be stacked so they must have compatible sizes - dataset_dir: Path - postprocess_result: Callable - see `default_postprocess_result` for signature diff --git a/nbs/example.ipynb b/nbs/example.ipynb index 189ef56..834341f 100644 --- a/nbs/example.ipynb +++ b/nbs/example.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 45, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -12,7 +12,20 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n", + "os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n", + "# os.environ[\"CUDA_LAUNCH_BLOCKING\"] = \"1\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -33,7 +46,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -61,19 +74,19 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Dataset({\n", - " features: ['attention_mask', 'input_ids'],\n", + " features: ['input_ids', 'attention_mask'],\n", " num_rows: 20\n", "})" ] }, - "execution_count": 48, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -107,14 +120,14 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\n" + "\n" ] } ], @@ -135,65 +148,117 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "['model.layers.0.mlp.down_proj',\n", - " 'model.layers.1.mlp.down_proj',\n", - " 'model.layers.2.mlp.down_proj',\n", - " 'model.layers.3.mlp.down_proj',\n", - " 'model.layers.4.mlp.down_proj',\n", - " 'model.layers.5.mlp.down_proj',\n", - " 'model.layers.6.mlp.down_proj',\n", - " 'model.layers.7.mlp.down_proj',\n", - " 'model.layers.8.mlp.down_proj',\n", - " 'model.layers.9.mlp.down_proj',\n", - " 'model.layers.10.mlp.down_proj',\n", - " 'model.layers.11.mlp.down_proj',\n", - " 'model.layers.12.mlp.down_proj',\n", - " 'model.layers.13.mlp.down_proj',\n", - " 'model.layers.14.mlp.down_proj',\n", - " 'model.layers.15.mlp.down_proj',\n", - " 'model.layers.16.mlp.down_proj',\n", - " 'model.layers.17.mlp.down_proj',\n", - " 'model.layers.18.mlp.down_proj',\n", - " 'model.layers.19.mlp.down_proj',\n", - " 'model.layers.20.mlp.down_proj',\n", - " 'model.layers.21.mlp.down_proj',\n", - " 'model.layers.22.mlp.down_proj',\n", - " 'model.layers.23.mlp.down_proj']" + "{'mlp.down_proj': ['model.layers.0.mlp.down_proj',\n", + " 'model.layers.1.mlp.down_proj',\n", + " 'model.layers.2.mlp.down_proj',\n", + " 'model.layers.3.mlp.down_proj',\n", + " 'model.layers.4.mlp.down_proj',\n", + " 'model.layers.5.mlp.down_proj',\n", + " 'model.layers.6.mlp.down_proj',\n", + " 'model.layers.7.mlp.down_proj',\n", + " 'model.layers.8.mlp.down_proj',\n", + " 'model.layers.9.mlp.down_proj',\n", + " 'model.layers.10.mlp.down_proj',\n", + " 'model.layers.11.mlp.down_proj',\n", + " 'model.layers.12.mlp.down_proj',\n", + " 'model.layers.13.mlp.down_proj',\n", + " 'model.layers.14.mlp.down_proj',\n", + " 'model.layers.15.mlp.down_proj',\n", + " 'model.layers.16.mlp.down_proj',\n", + " 'model.layers.17.mlp.down_proj',\n", + " 'model.layers.18.mlp.down_proj',\n", + " 'model.layers.19.mlp.down_proj',\n", + " 'model.layers.20.mlp.down_proj',\n", + " 'model.layers.21.mlp.down_proj',\n", + " 'model.layers.22.mlp.down_proj',\n", + " 'model.layers.23.mlp.down_proj'],\n", + " 'self_attn': ['model.layers.0.self_attn',\n", + " 'model.layers.1.self_attn',\n", + " 'model.layers.2.self_attn',\n", + " 'model.layers.3.self_attn',\n", + " 'model.layers.4.self_attn',\n", + " 'model.layers.5.self_attn',\n", + " 'model.layers.6.self_attn',\n", + " 'model.layers.7.self_attn',\n", + " 'model.layers.8.self_attn',\n", + " 'model.layers.9.self_attn',\n", + " 'model.layers.10.self_attn',\n", + " 'model.layers.11.self_attn',\n", + " 'model.layers.12.self_attn',\n", + " 'model.layers.13.self_attn',\n", + " 'model.layers.14.self_attn',\n", + " 'model.layers.15.self_attn',\n", + " 'model.layers.16.self_attn',\n", + " 'model.layers.17.self_attn',\n", + " 'model.layers.18.self_attn',\n", + " 'model.layers.19.self_attn',\n", + " 'model.layers.20.self_attn',\n", + " 'model.layers.21.self_attn',\n", + " 'model.layers.22.self_attn',\n", + " 'model.layers.23.self_attn'],\n", + " 'mlp.up_proj': ['model.layers.0.mlp.up_proj',\n", + " 'model.layers.1.mlp.up_proj',\n", + " 'model.layers.2.mlp.up_proj',\n", + " 'model.layers.3.mlp.up_proj',\n", + " 'model.layers.4.mlp.up_proj',\n", + " 'model.layers.5.mlp.up_proj',\n", + " 'model.layers.6.mlp.up_proj',\n", + " 'model.layers.7.mlp.up_proj',\n", + " 'model.layers.8.mlp.up_proj',\n", + " 'model.layers.9.mlp.up_proj',\n", + " 'model.layers.10.mlp.up_proj',\n", + " 'model.layers.11.mlp.up_proj',\n", + " 'model.layers.12.mlp.up_proj',\n", + " 'model.layers.13.mlp.up_proj',\n", + " 'model.layers.14.mlp.up_proj',\n", + " 'model.layers.15.mlp.up_proj',\n", + " 'model.layers.16.mlp.up_proj',\n", + " 'model.layers.17.mlp.up_proj',\n", + " 'model.layers.18.mlp.up_proj',\n", + " 'model.layers.19.mlp.up_proj',\n", + " 'model.layers.20.mlp.up_proj',\n", + " 'model.layers.21.mlp.up_proj',\n", + " 'model.layers.22.mlp.up_proj',\n", + " 'model.layers.23.mlp.up_proj']}" ] }, - "execution_count": 50, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# choose layers to cache\n", - "layers = [k for k,v in model.named_modules() if k.endswith('mlp.down_proj')]\n", - "layers" + "layer_groups = {\n", + " 'mlp.down_proj': [k for k,v in model.named_modules() if k.endswith('mlp.down_proj')],\n", + " 'self_attn': [k for k,v in model.named_modules() if k.endswith('.self_attn')],\n", + " 'mlp.up_proj': [k for k,v in model.named_modules() if k.endswith('mlp.up_proj')],\n", + "}\n", + "layer_groups" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2025-02-16 09:36:37.315\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mactivation_store.collect\u001b[0m:\u001b[36mactivation_store\u001b[0m:\u001b[36m77\u001b[0m - \u001b[1mcreating dataset /media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__fac086acb713a85e.parquet\u001b[0m\n" + "\u001b[32m2025-03-14 16:42:30.982\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mactivation_store.collect\u001b[0m:\u001b[36mactivation_store\u001b[0m:\u001b[36m134\u001b[0m - \u001b[1mcreating dataset /media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__0e7d5dbf1c73cf7d.parquet\u001b[0m\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "8341bbff75634f0fb235e107abc2083d", + "model_id": "e0a492af45854b2c83f6d10d87c6d42a", "version_major": 2, "version_minor": 0 }, @@ -204,44 +269,51 @@ "metadata": {}, "output_type": "display_data" }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "You're using a Qwen2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n" - ] - }, { "data": { "text/plain": [ - "PosixPath('/media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__fac086acb713a85e.parquet')" + "PosixPath('/media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__0e7d5dbf1c73cf7d.parquet')" ] }, - "execution_count": 51, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "f = activation_store(ds, model, layers=layers)\n", + "f = activation_store(ds, model, layers=layer_groups)\n", "f" ] }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 15, "metadata": {}, "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f1bcb56397af43ac82f4cf4761acdd87", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Generating train split: 0 examples [00:00, ? examples/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "data": { "text/plain": [ "Dataset({\n", - " features: ['act-model.layers.0.mlp.down_proj', 'act-model.layers.1.mlp.down_proj', 'act-model.layers.2.mlp.down_proj', 'act-model.layers.3.mlp.down_proj', 'act-model.layers.4.mlp.down_proj', 'act-model.layers.5.mlp.down_proj', 'act-model.layers.6.mlp.down_proj', 'act-model.layers.7.mlp.down_proj', 'act-model.layers.8.mlp.down_proj', 'act-model.layers.9.mlp.down_proj', 'act-model.layers.10.mlp.down_proj', 'act-model.layers.11.mlp.down_proj', 'act-model.layers.12.mlp.down_proj', 'act-model.layers.13.mlp.down_proj', 'act-model.layers.14.mlp.down_proj', 'act-model.layers.15.mlp.down_proj', 'act-model.layers.16.mlp.down_proj', 'act-model.layers.17.mlp.down_proj', 'act-model.layers.18.mlp.down_proj', 'act-model.layers.19.mlp.down_proj', 'act-model.layers.20.mlp.down_proj', 'act-model.layers.21.mlp.down_proj', 'act-model.layers.22.mlp.down_proj', 'act-model.layers.23.mlp.down_proj', 'logits', 'hidden_states'],\n", + " features: ['mlp.down_proj', 'self_attn', 'mlp.up_proj', 'loss', 'logits', 'hidden_states'],\n", " num_rows: 20\n", "})" ] }, - "execution_count": 57, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -254,16 +326,16 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "torch.Size([2, 25, 453, 896])" + "torch.Size([2, 25, 1, 896])" ] }, - "execution_count": 62, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -274,23 +346,10 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 453, 896])" - ] - }, - "execution_count": 61, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "ds_a[0:2]['act-model.layers.0.mlp.down_proj'].shape" - ] + "outputs": [], + "source": [] } ], "metadata": {