diff --git a/.gitignore b/.gitignore index 70a94ea..edbae4c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +/outputs + # Created by https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,jupyternotebooks,linux # Edit at https://www.toptal.com/developers/gitignore?templates=python,visualstudiocode,jupyternotebooks,linux diff --git a/activation_store/collect.py b/activation_store/collect.py index a85cec3..6512cde 100644 --- a/activation_store/collect.py +++ b/activation_store/collect.py @@ -7,7 +7,7 @@ from loguru import logger from pathlib import Path from baukit.nethook import TraceDict, recursive_copy from einops import rearrange -from datasets.arrow_writer import ArrowWriter +from datasets.arrow_writer import ArrowWriter, ParquetWriter from datasets.fingerprint import Hasher from transformers.modeling_outputs import ModelOutput @@ -15,15 +15,15 @@ from activation_store.helpers.torch import clear_mem from typing import Dict, Generator from torch import Tensor -default_output_folder = (Path(__file__).parent.parent.parent / "outputs").resolve() +default_output_folder = (Path(__file__).parent.parent / "outputs").resolve() -def default_postprocess_result(input: dict, ret: TraceDict, output: ModelOutput) -> Dict[str, Tensor]: - """add ret, activations to output""" +def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutput) -> Dict[str, Tensor]: + """add activations to output, and rearrange hidden states""" # Baukit records the literal layer output, which varies by model. Here we assume that the output or the first part are activations we want acts = {f'act-{k}': v.output[0] if isinstance(v.output, tuple) else v.output - for k, v in ret.items()} + for k, v in trace.items()} output.hidden_states = rearrange(list(output.hidden_states), 'l b t h -> b l t h') @@ -33,17 +33,16 @@ def default_postprocess_result(input: dict, ret: TraceDict, output: ModelOutput) @torch.no_grad def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers, postprocess_result=default_postprocess_result) -> Generator[Dict[str, Tensor], None, None]: model.eval() - for batch in tqdm(loader, 'collecting hidden states'): + for batch in tqdm(loader, 'collecting activations'): device = next(model.parameters()).device - b_in = { - k: v.to(device) - for k, v in batch.items() - } - with TraceDict(model, layers) as ret: - out = model(**b_in, use_cache=False, output_hidden_states=True, return_dict=True) - o = postprocess_result(batch, ret, out) + with torch.amp.autocast(device_type=device.type): + with TraceDict(model, layers) as trace: + out = model(**batch, use_cache=False, output_hidden_states=True, return_dict=True) + o = postprocess_result(batch, trace, out) + + # copy to avoid memory leaks o = recursive_copy(o) - out = ret = b_in = None + out = trace = batch = None clear_mem() yield o @@ -64,22 +63,33 @@ def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_na - layers: List[str] - selected from `model.named_modules()` - dataset_dir: Path - postprocess_result: Callable - see `default_postprocess_result` for signature + + Returns: + - file + + Usage: + f = activation_store(loader, model, layers=['transformer.h']) + Dataset.from_parquet(f) """ hash = dataset_hash(generate_batches=generate_batches, loader=loader, model=model) - f = dataset_dir / ".ds" / f"ds_{dataset_name}_{hash}" + f = dataset_dir / ".ds" / f"ds_{dataset_name}_{hash}.parquet" f.parent.mkdir(exist_ok=True, parents=True) logger.info(f"creating dataset {f}") - iterator = generate_batches(loader, model, layers=layers, postprocess_result=postprocess_result) - with ArrowWriter(path=f, writer_batch_size=writer_batch_size) as writer: + iterator = generate_batches(loader, model, layers=layers, postprocess_result=postprocess_result) + + with ParquetWriter(path=f, writer_batch_size=writer_batch_size, + embed_local_files=True + ) as writer: for bo in iterator: bs = len(next(iter(bo.values()))) assert all(len(v) == bs for v in bo.values()), f"must return Dict[str,Tensor] and all tensors with same batch size a first dimension" + # or maybe better compression to `writer.write(example, key)` for each writer.write_batch(bo) - writer.write_examples_on_file() writer.finalize() + writer.close() - ds = Dataset.from_file(str(f)).with_format("torch") - return ds, f + # ds = Dataset.from_file(str(f)).with_format("torch") + return f diff --git a/activation_store/helpers/torch.py b/activation_store/helpers/torch.py index cb16720..e649b37 100644 --- a/activation_store/helpers/torch.py +++ b/activation_store/helpers/torch.py @@ -1,78 +1,8 @@ import torch import gc -import copy -import numpy as np -from jaxtyping import Float, Int -from torch import Tensor - -# def switch(p: Float[Tensor, ""], s: Float[Tensor, ""]): -# """if the true label is 0, we will flip our binary prediction around. so 25% becomes 75%. It's the rating of how correct our answer was from 0 to 1""" -# s = s.float() -# return (1 - s) * (1-p) + s * p def clear_mem(): gc.collect() - # get_accelerator().empty_cache() - # accelerator.free_memory() torch.cuda.empty_cache() gc.collect() - -def detachcpu(x): - """ - Trys to convert torch if possible a single item - """ - if isinstance(x, torch.Tensor): - x = x.cpu() - return x - else: - return x - -# def recursive_copy(x, clone=None, detach=None, retain_grad=None): -# """ -# from baukit - -# Copies a reference to a tensor, or an object that contains tensors, -# optionally detaching and cloning the tensor(s). If retain_grad is -# true, the original tensors are marked to have grads retained. -# """ -# if not clone and not detach and not retain_grad: -# return x -# if isinstance(x, torch.Tensor): -# if retain_grad: -# if not x.requires_grad: -# x.requires_grad = True -# x.retain_grad() -# elif detach: -# x = x.detach() -# if clone: -# x = x.clone() -# return x -# # Only dicts, lists, and tuples (and subclasses) can be copied. -# if isinstance(x, dict): -# return type(x)({k: recursive_copy(v, clone=clone, detach=detach, retain_grad=retain_grad) for k, v in x.items()}) -# elif isinstance(x, (list, tuple)): -# return type(x)([recursive_copy(v, clone=clone, detach=detach, retain_grad=retain_grad) for v in x]) -# else: -# return copy.deepcopy(x) - -# def batch_to_device(b, device=None): -# """Move a batch to the device""" -# if isinstance(b, torch.Tensor): -# return b.to(device) -# elif isinstance(b, dict): -# return {k:batch_to_device(v, device=device) for k,v in b.items()} -# elif isinstance(b, (list, tuple)): -# return type(b)([batch_to_device(v, device=device) for v in b]) -# else: -# return b - -# def shape_of_anything(v): -# if isinstance(v, (Tensor, np.ndarray)): -# return v.shape -# elif isinstance(v, dict): -# return {k:shape_of_anything(v) for k,v in v.items()} -# elif isinstance(v, list): -# return len(v) -# else: -# return 1 diff --git a/nbs/example.ipynb b/nbs/example.ipynb index 6d84ed6..2afe528 100644 --- a/nbs/example.ipynb +++ b/nbs/example.ipynb @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -57,14 +57,14 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Dataset({\n", - " features: ['prompt', 'chosen', 'rejected'],\n", + " features: ['input_ids', 'attention_mask'],\n", " num_rows: 20\n", "})" ] @@ -103,14 +103,14 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\n" + "\n" ] } ], @@ -123,22 +123,10 @@ " padding=True,\n", " return_tensors=\"pt\",\n", " )\n", - "ds = DataLoader(ds2, batch_size=2, num_workers=0, collate_fn=collate_fn)\n", + "ds = DataLoader(ds2, batch_size=4, num_workers=0, collate_fn=collate_fn)\n", "print(ds)\n" ] }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "# # sanity check with one manual forward\n", - "# b = next(iter(ds))\n", - "# outputs = model(**b)\n", - "# outputs.keys()" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -148,7 +136,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -180,7 +168,7 @@ " 'model.layers.23.mlp.down_proj']" ] }, - "execution_count": 8, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -200,18 +188,18 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2025-02-15 21:14:24.538\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mactivation_store.collect\u001b[0m:\u001b[36mcollect_act_to_disk\u001b[0m:\u001b[36m60\u001b[0m - \u001b[1mcreating dataset /media/wassname/SGIronWolf/projects5/elk/outputs/.ds/ds__7ae34f9e83796c91\u001b[0m\n" + "\u001b[32m2025-02-15 21:58:37.654\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mactivation_store.collect\u001b[0m:\u001b[36mactivation_store\u001b[0m:\u001b[36m70\u001b[0m - \u001b[1mcreating dataset /media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__9b3f4b0da96e9ad5.parquet\u001b[0m\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "f6ed6625c38544378d2d46969a8470c4", + "model_id": "fe95a697e5c0432e85d15707b07fd001", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "collecting hidden states: 0%| | 0/10 [00:00 2\u001b[0m \u001b[43mds_a\u001b[49m\n", + "\u001b[0;31mNameError\u001b[0m: name 'ds_a' is not defined" + ] + } + ], + "source": [ + "f = activation_store(ds, model, layers=layers, writer_batch_size=10)\n", + "f" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1e1429e1d3224a2b8a5398f7a414911d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Generating train split: 0 examples [00:00, ? examples/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "data": { "text/plain": [ @@ -233,53 +257,37 @@ "})" ] }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "ds_a, f = activation_store(ds, model, layers=layers)\n", - "ds_a" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 453, 151936])" - ] - }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], + "source": [ + "from datasets import Dataset\n", + "Dataset.from_parquet(str(f))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "ds_a[0:2]['logits'].shape" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "ename": "KeyError", - "evalue": "'model.layers.0.mlp.down_proj'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[11], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mds_a\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m:\u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mmodel.layers.0.mlp.down_proj\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241m.\u001b[39mshape\n", - "\u001b[0;31mKeyError\u001b[0m: 'model.layers.0.mlp.down_proj'" - ] - } - ], + "outputs": [], "source": [ "ds_a[0:2]['model.layers.0.mlp.down_proj'].shape" ]