diff --git a/README.md b/README.md index b2f5755..c796e7b 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ Utility library to persistently store transformer activations on disk. -These activations can be quite large (layers x batch x sequence x hidden_size), so storing them on disk helps avoid out of memory errors. +These activations can be quite large (layers x batch x sequence x hidden_size), so writing them to disk as they are generated helps avoid out-of-memory errors. Install using `pip install git+https://github.com/wassname/activation_store.git`. diff --git a/activation_store/collect.py b/activation_store/collect.py index 94d9457..a85cec3 100644 --- a/activation_store/collect.py +++ b/activation_store/collect.py @@ -53,13 +53,24 @@ def dataset_hash(**kwargs): return suffix -def collect_act_to_disk(loader: DataLoader, model: AutoModelForCausalLM, dataset_name='', layers=[], dataset_dir=default_output_folder, writer_batch_size=1): +def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_name='', layers=[], dataset_dir=default_output_folder, writer_batch_size=1, postprocess_result=default_postprocess_result) -> Dataset: + """ + Collect activations from a model and store them in a dataset + + Args: + - loader: DataLoader + - model: AutoModelForCausalLM + - dataset_name: str + - layers: List[str] - selected from `model.named_modules()` + - dataset_dir: Path + - postprocess_result: Callable - see `default_postprocess_result` for signature + """ hash = dataset_hash(generate_batches=generate_batches, loader=loader, model=model) f = dataset_dir / ".ds" / f"ds_{dataset_name}_{hash}" f.parent.mkdir(exist_ok=True, parents=True) logger.info(f"creating dataset {f}") - iterator = generate_batches(loader, model, layers=layers) + iterator = generate_batches(loader, model, layers=layers, postprocess_result=postprocess_result) with ArrowWriter(path=f, writer_batch_size=writer_batch_size) as writer: for bo in iterator: diff --git a/nbs/example.ipynb b/nbs/example.ipynb 
index 195bdfa..6d84ed6 100644 --- a/nbs/example.ipynb +++ b/nbs/example.ipynb @@ -12,18 +12,25 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "from transformers import AutoModelForCausalLM, AutoTokenizer\n", "\n", - "from activation_store.collect import collect_act_to_disk\n", + "from activation_store.collect import activation_store\n", "\n", "import torch" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load model" + ] + }, { "cell_type": "code", "execution_count": 3, @@ -41,9 +48,16 @@ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load data and tokenize" + ] + }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -65,29 +79,7 @@ "max_length = 256\n", "\n", "imdb = load_dataset('wassname/imdb_dpo', split=f'test[:{N}]', keep_in_memory=False)\n", - "imdb" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Dataset({\n", - " features: ['input_ids', 'attention_mask'],\n", - " num_rows: 20\n", - "})" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ + "\n", "\n", "def proc(row):\n", " messages = [\n", @@ -102,6 +94,13 @@ "ds2" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data loader" + ] + }, { "cell_type": "code", "execution_count": 6, @@ -140,6 +139,13 @@ "# outputs.keys()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Collect activations" + ] + }, { "cell_type": "code", "execution_count": 8, @@ -187,7 +193,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -233,7 +239,7 @@ } ], "source": [ - "ds_a, f = collect_act_to_disk(ds, model, 
layers=layers)\n", + "ds_a, f = activation_store(ds, model, layers=layers)\n", "ds_a" ] },