mirror of
https://github.com/wassname/activation_store.git
synced 2026-06-27 16:44:07 +08:00
comments
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
|
||||
Utility library to persistently store transformer activations on disk.
|
||||
|
||||
These activations can be quite large (layers x batch x sequence x hidden_size), so storing them on disk helps avoid out of memory errors.
|
||||
These activations can be quite large (layers x batch x sequence x hidden_size), so writing them to disk as they are generated helps avoid out-of-memory errors.
|
||||
|
||||
Install using `pip install git+https://github.com/wassname/activation_store.git`.
|
||||
|
||||
|
||||
@@ -53,13 +53,24 @@ def dataset_hash(**kwargs):
|
||||
return suffix
|
||||
|
||||
|
||||
def collect_act_to_disk(loader: DataLoader, model: AutoModelForCausalLM, dataset_name='', layers=[], dataset_dir=default_output_folder, writer_batch_size=1):
|
||||
def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_name='', layers=[], dataset_dir=default_output_folder, writer_batch_size=1, postprocess_result=default_postprocess_result) -> Dataset:
|
||||
"""
|
||||
Collect activations from a model and store them in a dataset
|
||||
|
||||
Args:
|
||||
- loader: DataLoader
|
||||
- model: AutoModelForCausalLM
|
||||
- dataset_name: str
|
||||
- layers: List[str] - selected from `model.named_modules()`
|
||||
- dataset_dir: Path
|
||||
- postprocess_result: Callable - see `default_postprocess_result` for signature
|
||||
"""
|
||||
hash = dataset_hash(generate_batches=generate_batches, loader=loader, model=model)
|
||||
f = dataset_dir / ".ds" / f"ds_{dataset_name}_{hash}"
|
||||
f.parent.mkdir(exist_ok=True, parents=True)
|
||||
logger.info(f"creating dataset {f}")
|
||||
|
||||
iterator = generate_batches(loader, model, layers=layers)
|
||||
iterator = generate_batches(loader, model, layers=layers, postprocess_result=postprocess_result)
|
||||
with ArrowWriter(path=f, writer_batch_size=writer_batch_size) as writer:
|
||||
for bo in iterator:
|
||||
|
||||
|
||||
+34
-28
@@ -12,18 +12,25 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from datasets import load_dataset\n",
|
||||
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
|
||||
"\n",
|
||||
"from activation_store.collect import collect_act_to_disk\n",
|
||||
"from activation_store.collect import activation_store\n",
|
||||
"\n",
|
||||
"import torch"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Load model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
@@ -41,9 +48,16 @@
|
||||
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Load data and tokenize"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -65,29 +79,7 @@
|
||||
"max_length = 256\n",
|
||||
"\n",
|
||||
"imdb = load_dataset('wassname/imdb_dpo', split=f'test[:{N}]', keep_in_memory=False)\n",
|
||||
"imdb"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Dataset({\n",
|
||||
" features: ['input_ids', 'attention_mask'],\n",
|
||||
" num_rows: 20\n",
|
||||
"})"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"\n",
|
||||
"\n",
|
||||
"def proc(row):\n",
|
||||
" messages = [\n",
|
||||
@@ -102,6 +94,13 @@
|
||||
"ds2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Data loader"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
@@ -140,6 +139,13 @@
|
||||
"# outputs.keys()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Collect activations"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
@@ -187,7 +193,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -233,7 +239,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"ds_a, f = collect_act_to_disk(ds, model, layers=layers)\n",
|
||||
"ds_a, f = activation_store(ds, model, layers=layers)\n",
|
||||
"ds_a"
|
||||
]
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user