This commit is contained in:
wassname
2025-05-03 20:35:31 +08:00
parent fe64c9dee4
commit 8692769bce
2 changed files with 9 additions and 6 deletions
+4 -4
View File
@@ -141,7 +141,7 @@ def output_dataset_hash(**kwargs):
return suffix
def activation_store(loader: DataLoader, model: PreTrainedModel, dataset_name='', layers: Union[List[str], Dict[str, List[str]]]=[], dataset_dir=None, writer_batch_size=1, postprocess_result=default_postprocess_result, outfile: Optional[Path] = None) -> Dataset:
def activation_store(loader: DataLoader, model: PreTrainedModel, dataset_name='', layers: Union[List[str], Dict[str, List[str]]]=[], dataset_dir=None, writer_batch_size=1, postprocess_result=default_postprocess_result, outfile: Optional[Path] = None) -> Path:
"""
Collect activations from a model and store them in a dataset
@@ -160,11 +160,12 @@ def activation_store(loader: DataLoader, model: PreTrainedModel, dataset_name=''
Usage:
f = activation_store(loader, model, layers=['transformer.h'])
Dataset.from_parquet(f).with_format("torch")
Dataset.from_parquet(f, split='train', keep_in_memory=False).with_format("torch")
# or
load_dataset("parquet", split='train', data_files=str(f), keep_in_memory=False).with_format("torch")
"""
if outfile is None:
# FIXME: I think this is the problem. Instead of a naive hash, this needs a custom hash over the key parts (model_name, dataset_name, layers, etc.)
hash = output_dataset_hash(generate_batches=generate_batches, loader=loader, model=model, layers=layers, postprocess_result=postprocess_result)
outdir = Path(tempfile.gettempdir()) / 'activation_store'
outfile = outdir / f"ds_act_{dataset_name}_{hash}.parquet"
@@ -186,7 +187,6 @@ def activation_store(loader: DataLoader, model: PreTrainedModel, dataset_name=''
bs = len(next(iter(bo.values())))
assert all(len(v) == bs for v in bo.values()), "must return Dict[str,Tensor] and all tensors with same batch size as first dimension"
# or maybe calling `writer.write(example, key)` for each example would give better compression
writer.write_batch(bo)
del bo
clear_mem()
+5 -2
View File
@@ -287,7 +287,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -319,7 +319,10 @@
}
],
"source": [
"from datasets import Dataset\n",
"from datasets import Dataset, load_dataset\n",
"\n",
"# ds_a = load_dataset(\"parquet\", split='train', data_files=str(f), keep_in_memory=False)\n",
"# OR\n",
"ds_a = Dataset.from_parquet(str(f)).with_format(\"torch\")\n",
"ds_a"
]