mirror of
https://github.com/wassname/activation_store.git
synced 2026-06-27 15:14:01 +08:00
tidy
This commit is contained in:
@@ -141,7 +141,7 @@ def output_dataset_hash(**kwargs):
|
||||
return suffix
|
||||
|
||||
|
||||
def activation_store(loader: DataLoader, model: PreTrainedModel, dataset_name='', layers: Union[List[str], Dict[str, List[str]]]=[], dataset_dir=None, writer_batch_size=1, postprocess_result=default_postprocess_result, outfile: Optional[Path] = None) -> Dataset:
|
||||
def activation_store(loader: DataLoader, model: PreTrainedModel, dataset_name='', layers: Union[List[str], Dict[str, List[str]]]=[], dataset_dir=None, writer_batch_size=1, postprocess_result=default_postprocess_result, outfile: Optional[Path] = None) -> Path:
|
||||
"""
|
||||
Collect activations from a model and store them in a dataset
|
||||
|
||||
@@ -160,11 +160,12 @@ def activation_store(loader: DataLoader, model: PreTrainedModel, dataset_name=''
|
||||
|
||||
Usage:
|
||||
f = activation_store(loader, model, layers=['transformer.h'])
|
||||
Dataset.from_parquet(f).with_format("torch")
|
||||
Dataset.from_parquet(f, split='train', keep_in_memory=False).with_format("torch")
|
||||
# or
|
||||
load_dataset("parquet", split='train', data_files=str(f), keep_in_memory=False).with_format("torch")
|
||||
"""
|
||||
|
||||
if outfile is None:
|
||||
# FIXME I think this is the problem, instead of using a naive hash I will need to custom hash some key parts, model_name, dataset_name, layers, etc
|
||||
hash = output_dataset_hash(generate_batches=generate_batches, loader=loader, model=model, layers=layers, postprocess_result=postprocess_result)
|
||||
outdir = Path(tempfile.gettempdir()) / 'activation_store'
|
||||
outfile = outdir / f"ds_act_{dataset_name}_{hash}.parquet"
|
||||
@@ -186,7 +187,6 @@ def activation_store(loader: DataLoader, model: PreTrainedModel, dataset_name=''
|
||||
bs = len(next(iter(bo.values())))
|
||||
assert all(len(v) == bs for v in bo.values()), "must return Dict[str,Tensor] and all tensors with same batch size as first dimension"
|
||||
|
||||
# or maybe better compression to `writer.write(example, key)` for each
|
||||
writer.write_batch(bo)
|
||||
del bo
|
||||
clear_mem()
|
||||
|
||||
+5
-2
@@ -287,7 +287,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -319,7 +319,10 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from datasets import Dataset\n",
|
||||
"from datasets import Dataset, load_dataset\n",
|
||||
"\n",
|
||||
"# ds_a = load_dataset(\"parquet\", split='train', data_files=str(f), keep_in_memory=False)\n",
|
||||
"# OR\n",
|
||||
"ds_a = Dataset.from_parquet(str(f)).with_format(\"torch\")\n",
|
||||
"ds_a"
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user