diff --git a/activation_store/collect.py b/activation_store/collect.py index 3013217..b4d1914 100644 --- a/activation_store/collect.py +++ b/activation_store/collect.py @@ -141,7 +141,7 @@ def output_dataset_hash(**kwargs): return suffix -def activation_store(loader: DataLoader, model: PreTrainedModel, dataset_name='', layers: Union[List[str], Dict[str, List[str]]]=[], dataset_dir=None, writer_batch_size=1, postprocess_result=default_postprocess_result, outfile: Optional[Path] = None) -> Dataset: +def activation_store(loader: DataLoader, model: PreTrainedModel, dataset_name='', layers: Union[List[str], Dict[str, List[str]]]=[], dataset_dir=None, writer_batch_size=1, postprocess_result=default_postprocess_result, outfile: Optional[Path] = None) -> Path: """ Collect activations from a model and store them in a dataset @@ -160,11 +160,12 @@ def activation_store(loader: DataLoader, model: PreTrainedModel, dataset_name='' Usage: f = activation_store(loader, model, layers=['transformer.h']) - Dataset.from_parquet(f).with_format("torch") + Dataset.from_parquet(f, split='train', keep_in_memory=False).with_format("torch") + # or + load_dataset("parquet", split='train', data_files=str(f), keep_in_memory=False).with_format("torch") """ if outfile is None: - # FIXME I think this is the problem, instead of using a naive hash I will need to custom hash some key parts, model_name, dataset_name, layers, etc hash = output_dataset_hash(generate_batches=generate_batches, loader=loader, model=model, layers=layers, postprocess_result=postprocess_result) outdir = Path(tempfile.gettempdir()) / 'activation_store' outfile = outdir / f"ds_act_{dataset_name}_{hash}.parquet" @@ -186,7 +187,6 @@ def activation_store(loader: DataLoader, model: PreTrainedModel, dataset_name='' bs = len(next(iter(bo.values()))) assert all(len(v) == bs for v in bo.values()), "must return Dict[str,Tensor] and all tensors with same batch size as first dimension" - # or maybe better compression to `writer.write(example, key)` for each writer.write_batch(bo) del bo clear_mem() diff --git a/nbs/example.ipynb b/nbs/example.ipynb index 834341f..309360d 100644 --- a/nbs/example.ipynb +++ b/nbs/example.ipynb @@ -287,7 +287,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -319,7 +319,10 @@ } ], "source": [ - "from datasets import Dataset\n", + "from datasets import Dataset, load_dataset\n", + "\n", + "# ds_a = load_dataset(\"parquet\", split='train', data_files=str(f), keep_in_memory=False)\n", + "# OR\n", "ds_a = Dataset.from_parquet(str(f)).with_format(\"torch\")\n", "ds_a" ]