This commit is contained in:
wassname
2025-02-15 22:03:52 +08:00
parent 0e18875b25
commit 8a61bfeba0
4 changed files with 98 additions and 148 deletions
+2
View File
@@ -1,3 +1,5 @@
/outputs
# Created by https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,jupyternotebooks,linux
# Edit at https://www.toptal.com/developers/gitignore?templates=python,visualstudiocode,jupyternotebooks,linux
+30 -20
View File
@@ -7,7 +7,7 @@ from loguru import logger
from pathlib import Path
from baukit.nethook import TraceDict, recursive_copy
from einops import rearrange
from datasets.arrow_writer import ArrowWriter
from datasets.arrow_writer import ArrowWriter, ParquetWriter
from datasets.fingerprint import Hasher
from transformers.modeling_outputs import ModelOutput
@@ -15,15 +15,15 @@ from activation_store.helpers.torch import clear_mem
from typing import Dict, Generator
from torch import Tensor
default_output_folder = (Path(__file__).parent.parent.parent / "outputs").resolve()
default_output_folder = (Path(__file__).parent.parent / "outputs").resolve()
def default_postprocess_result(input: dict, ret: TraceDict, output: ModelOutput) -> Dict[str, Tensor]:
"""add ret, activations to output"""
def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutput) -> Dict[str, Tensor]:
"""add activations to output, and rearrange hidden states"""
# Baukit records the literal layer output, which varies by model. Here we assume that the output or the first part are activations we want
acts = {f'act-{k}':
v.output[0] if isinstance(v.output, tuple) else v.output
for k, v in ret.items()}
for k, v in trace.items()}
output.hidden_states = rearrange(list(output.hidden_states), 'l b t h -> b l t h')
@@ -33,17 +33,16 @@ def default_postprocess_result(input: dict, ret: TraceDict, output: ModelOutput)
@torch.no_grad
def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers, postprocess_result=default_postprocess_result) -> Generator[Dict[str, Tensor], None, None]:
    """Run `model` over every batch of `loader` and yield one post-processed result per batch.

    Args:
        loader: iterable of batches; each batch is a dict whose values support `.to(device)`.
        model: the model to run; it is switched to eval mode before iterating.
        layers: module names handed to baukit's `TraceDict` so their outputs are recorded.
        postprocess_result: callable `(batch, trace, output)` whose return value is yielded;
            see `default_postprocess_result` for the expected signature.

    Yields:
        Whatever `postprocess_result` returns for each batch.
    """
    model.eval()
    # The model's device does not change while iterating, so look it up once.
    device = next(model.parameters()).device
    for batch in tqdm(loader, 'collecting activations'):
        # Move the inputs onto the model's device. The forward pass must use this
        # moved copy, not the raw `batch`, or a model on an accelerator fails with
        # a device mismatch.
        b_in = {
            k: v.to(device)
            for k, v in batch.items()
        }
        with torch.amp.autocast(device_type=device.type):
            with TraceDict(model, layers) as trace:
                out = model(**b_in, use_cache=False, output_hidden_states=True, return_dict=True)
                o = postprocess_result(batch, trace, out)
        # copy to avoid memory leaks
        # NOTE(review): called without clone/detach flags, baukit's recursive_copy
        # appears to return its argument unchanged - confirm whether
        # `clone=True, detach=True` was intended here.
        o = recursive_copy(o)
        # Drop every reference to the per-batch tensors before freeing cached memory.
        out = trace = batch = b_in = None
        clear_mem()
        yield o
@@ -64,22 +63,33 @@ def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_na
- layers: List[str] - selected from `model.named_modules()`
- dataset_dir: Path
- postprocess_result: Callable - see `default_postprocess_result` for signature
Returns:
- file
Usage:
f = activation_store(loader, model, layers=['transformer.h'])
Dataset.from_parquet(f)
"""
hash = dataset_hash(generate_batches=generate_batches, loader=loader, model=model)
f = dataset_dir / ".ds" / f"ds_{dataset_name}_{hash}"
f = dataset_dir / ".ds" / f"ds_{dataset_name}_{hash}.parquet"
f.parent.mkdir(exist_ok=True, parents=True)
logger.info(f"creating dataset {f}")
iterator = generate_batches(loader, model, layers=layers, postprocess_result=postprocess_result)
with ArrowWriter(path=f, writer_batch_size=writer_batch_size) as writer:
iterator = generate_batches(loader, model, layers=layers, postprocess_result=postprocess_result)
with ParquetWriter(path=f, writer_batch_size=writer_batch_size,
embed_local_files=True
) as writer:
for bo in iterator:
bs = len(next(iter(bo.values())))
assert all(len(v) == bs for v in bo.values()), f"must return Dict[str,Tensor] and all tensors with same batch size a first dimension"
# or maybe better compression to `writer.write(example, key)` for each
writer.write_batch(bo)
writer.write_examples_on_file()
writer.finalize()
writer.close()
ds = Dataset.from_file(str(f)).with_format("torch")
return ds, f
# ds = Dataset.from_file(str(f)).with_format("torch")
return f
-70
View File
@@ -1,78 +1,8 @@
import torch
import gc
import copy
import numpy as np
from jaxtyping import Float, Int
from torch import Tensor
# def switch(p: Float[Tensor, ""], s: Float[Tensor, ""]):
# """if the true label is 0, we will flip our binary prediction around. so 25% becomes 75%. It's the rating of how correct our answer was from 0 to 1"""
# s = s.float()
# return (1 - s) * (1-p) + s * p
def clear_mem():
    """Run the garbage collector and ask PyTorch to empty its CUDA cache.

    The collector runs both before and after the cache is emptied.
    """
    gc.collect()
    torch.cuda.empty_cache()
    gc.collect()
def detachcpu(x):
    """Return `x` moved to the CPU if it is a `torch.Tensor`; otherwise return `x` unchanged.

    NOTE(review): despite the name, the tensor is only moved with `.cpu()`;
    no `.detach()` is performed. Confirm whether detaching was intended.
    """
    if isinstance(x, torch.Tensor):
        return x.cpu()
    # Anything that is not a tensor passes straight through.
    return x
# def recursive_copy(x, clone=None, detach=None, retain_grad=None):
# """
# from baukit
# Copies a reference to a tensor, or an object that contains tensors,
# optionally detaching and cloning the tensor(s). If retain_grad is
# true, the original tensors are marked to have grads retained.
# """
# if not clone and not detach and not retain_grad:
# return x
# if isinstance(x, torch.Tensor):
# if retain_grad:
# if not x.requires_grad:
# x.requires_grad = True
# x.retain_grad()
# elif detach:
# x = x.detach()
# if clone:
# x = x.clone()
# return x
# # Only dicts, lists, and tuples (and subclasses) can be copied.
# if isinstance(x, dict):
# return type(x)({k: recursive_copy(v, clone=clone, detach=detach, retain_grad=retain_grad) for k, v in x.items()})
# elif isinstance(x, (list, tuple)):
# return type(x)([recursive_copy(v, clone=clone, detach=detach, retain_grad=retain_grad) for v in x])
# else:
# return copy.deepcopy(x)
# def batch_to_device(b, device=None):
# """Move a batch to the device"""
# if isinstance(b, torch.Tensor):
# return b.to(device)
# elif isinstance(b, dict):
# return {k:batch_to_device(v, device=device) for k,v in b.items()}
# elif isinstance(b, (list, tuple)):
# return type(b)([batch_to_device(v, device=device) for v in b])
# else:
# return b
# def shape_of_anything(v):
# if isinstance(v, (Tensor, np.ndarray)):
# return v.shape
# elif isinstance(v, dict):
# return {k:shape_of_anything(v) for k,v in v.items()}
# elif isinstance(v, list):
# return len(v)
# else:
# return 1
+66 -58
View File
@@ -12,7 +12,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -57,14 +57,14 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Dataset({\n",
" features: ['prompt', 'chosen', 'rejected'],\n",
" features: ['input_ids', 'attention_mask'],\n",
" num_rows: 20\n",
"})"
]
@@ -103,14 +103,14 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<torch.utils.data.dataloader.DataLoader object at 0x7f6ddd90fcb0>\n"
"<torch.utils.data.dataloader.DataLoader object at 0x76557465f770>\n"
]
}
],
@@ -123,22 +123,10 @@
" padding=True,\n",
" return_tensors=\"pt\",\n",
" )\n",
"ds = DataLoader(ds2, batch_size=2, num_workers=0, collate_fn=collate_fn)\n",
"ds = DataLoader(ds2, batch_size=4, num_workers=0, collate_fn=collate_fn)\n",
"print(ds)\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# # sanity check with one manual forward\n",
"# b = next(iter(ds))\n",
"# outputs = model(**b)\n",
"# outputs.keys()"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -148,7 +136,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 7,
"metadata": {},
"outputs": [
{
@@ -180,7 +168,7 @@
" 'model.layers.23.mlp.down_proj']"
]
},
"execution_count": 8,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@@ -200,18 +188,18 @@
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m2025-02-15 21:14:24.538\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mactivation_store.collect\u001b[0m:\u001b[36mcollect_act_to_disk\u001b[0m:\u001b[36m60\u001b[0m - \u001b[1mcreating dataset /media/wassname/SGIronWolf/projects5/elk/outputs/.ds/ds__7ae34f9e83796c91\u001b[0m\n"
"\u001b[32m2025-02-15 21:58:37.654\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mactivation_store.collect\u001b[0m:\u001b[36mactivation_store\u001b[0m:\u001b[36m70\u001b[0m - \u001b[1mcreating dataset /media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__9b3f4b0da96e9ad5.parquet\u001b[0m\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f6ed6625c38544378d2d46969a8470c4",
"model_id": "fe95a697e5c0432e85d15707b07fd001",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"collecting hidden states: 0%| | 0/10 [00:00<?, ?it/s]"
"collecting activations: 0%| | 0/5 [00:00<?, ?it/s]"
]
},
"metadata": {},
@@ -224,6 +212,42 @@
"You're using a Qwen2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
]
},
{
"ename": "NameError",
"evalue": "name 'ds_a' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[8], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m f \u001b[38;5;241m=\u001b[39m activation_store(ds, model, layers\u001b[38;5;241m=\u001b[39mlayers, writer_batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m10\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m \u001b[43mds_a\u001b[49m\n",
"\u001b[0;31mNameError\u001b[0m: name 'ds_a' is not defined"
]
}
],
"source": [
"f = activation_store(ds, model, layers=layers, writer_batch_size=10)\n",
"f"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1e1429e1d3224a2b8a5398f7a414911d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating train split: 0 examples [00:00, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
@@ -233,53 +257,37 @@
"})"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ds_a, f = activation_store(ds, model, layers=layers)\n",
"ds_a"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 453, 151936])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from datasets import Dataset\n",
"Dataset.from_parquet(str(f))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ds_a[0:2]['logits'].shape"
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"metadata": {},
"outputs": [
{
"ename": "KeyError",
"evalue": "'model.layers.0.mlp.down_proj'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[11], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mds_a\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m:\u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mmodel.layers.0.mlp.down_proj\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241m.\u001b[39mshape\n",
"\u001b[0;31mKeyError\u001b[0m: 'model.layers.0.mlp.down_proj'"
]
}
],
"outputs": [],
"source": [
"ds_a[0:2]['model.layers.0.mlp.down_proj'].shape"
]