This commit is contained in:
wassname
2025-02-15 21:21:36 +08:00
parent f186243fe1
commit 0e18875b25
3 changed files with 48 additions and 31 deletions
+1 -1
View File
@@ -2,7 +2,7 @@
Utility library to persistently store transformer activations on disk.
These activations can be quite large (layers x batch x sequence x hidden_size), so storing them on disk helps avoid out of memory errors.
These activations can be quite large (layers x batch x sequence x hidden_size), so generating them to disk helps avoid out of memory errors.
Install using `pip install git+https://github.com/wassname/activation_store.git`.
+13 -2
View File
@@ -53,13 +53,24 @@ def dataset_hash(**kwargs):
return suffix
def collect_act_to_disk(loader: DataLoader, model: AutoModelForCausalLM, dataset_name='', layers=[], dataset_dir=default_output_folder, writer_batch_size=1):
def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_name='', layers=[], dataset_dir=default_output_folder, writer_batch_size=1, postprocess_result=default_postprocess_result) -> Dataset:
"""
Collect activations from a model and store them in a dataset
Args:
- loader: DataLoader
- model: AutoModelForCausalLM
- dataset_name: str
- layers: List[str] - selected from `model.named_modules()`
- dataset_dir: Path
- postprocess_result: Callable - see `default_postprocess_result` for signature
"""
hash = dataset_hash(generate_batches=generate_batches, loader=loader, model=model)
f = dataset_dir / ".ds" / f"ds_{dataset_name}_{hash}"
f.parent.mkdir(exist_ok=True, parents=True)
logger.info(f"creating dataset {f}")
iterator = generate_batches(loader, model, layers=layers)
iterator = generate_batches(loader, model, layers=layers, postprocess_result=postprocess_result)
with ArrowWriter(path=f, writer_batch_size=writer_batch_size) as writer:
for bo in iterator:
+34 -28
View File
@@ -12,18 +12,25 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"\n",
"from activation_store.collect import collect_act_to_disk\n",
"from activation_store.collect import activation_store\n",
"\n",
"import torch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load model"
]
},
{
"cell_type": "code",
"execution_count": 3,
@@ -41,9 +48,16 @@
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load data and tokenize"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -65,29 +79,7 @@
"max_length = 256\n",
"\n",
"imdb = load_dataset('wassname/imdb_dpo', split=f'test[:{N}]', keep_in_memory=False)\n",
"imdb"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Dataset({\n",
" features: ['input_ids', 'attention_mask'],\n",
" num_rows: 20\n",
"})"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\n",
"\n",
"def proc(row):\n",
" messages = [\n",
@@ -102,6 +94,13 @@
"ds2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data loader"
]
},
{
"cell_type": "code",
"execution_count": 6,
@@ -140,6 +139,13 @@
"# outputs.keys()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Collect activations"
]
},
{
"cell_type": "code",
"execution_count": 8,
@@ -187,7 +193,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -233,7 +239,7 @@
}
],
"source": [
"ds_a, f = collect_act_to_disk(ds, model, layers=layers)\n",
"ds_a, f = activation_store(ds, model, layers=layers)\n",
"ds_a"
]
},