diff --git a/activation_store/collect.py b/activation_store/collect.py index 259dd52..3013217 100644 --- a/activation_store/collect.py +++ b/activation_store/collect.py @@ -168,6 +168,11 @@ def activation_store(loader: DataLoader, model: PreTrainedModel, dataset_name='' hash = output_dataset_hash(generate_batches=generate_batches, loader=loader, model=model, layers=layers, postprocess_result=postprocess_result) outdir = Path(tempfile.gettempdir()) / 'activation_store' outfile = outdir / f"ds_act_{dataset_name}_{hash}.parquet" + + if outfile.exists(): + logger.warning(f"file {outfile} already exists, skipping") + return outfile + outfile.parent.mkdir(exist_ok=True, parents=True) logger.info(f"creating dataset {outfile}")