fix caching

This commit is contained in:
wassname
2025-05-03 20:19:03 +08:00
parent ea2d23bcd4
commit 9661ff75b6
+6 -3
View File
@@ -133,6 +133,9 @@ def output_dataset_hash(**kwargs):
kwargs[k] = f"PreTrainedModel_{v.config._name_or_path}" # PretrainedConfig
elif inspect.isfunction(v):
kwargs[k] = "Function: %s.%s" % (v.__module__, v.__name__)
else:
# logger.debug(f"hashing {k} as {v} of type {type(v)}")
kwargs[k] = str(v)
logger.debug(f"hashing {kwargs}")
suffix = Hasher.hash(kwargs)
return suffix
@@ -159,11 +162,11 @@ def activation_store(loader: DataLoader, model: PreTrainedModel, dataset_name=''
f = activation_store(loader, model, layers=['transformer.h'])
Dataset.from_parquet(f).with_format("torch")
"""
# FIXME I think this is the problem, instead of using a naive hash I will need to custom hash some key parts, model_name, dataset_name, layers, etc
hash = output_dataset_hash(generate_batches=generate_batches, loader=loader, model=model, layers=layers, postprocess_result=postprocess_result)
if outfile is None:
outdir = Path(tempfile.mkdtemp(prefix='activation_store'))
# FIXME I think this is the problem, instead of using a naive hash I will need to custom hash some key parts, model_name, dataset_name, layers, etc
hash = output_dataset_hash(generate_batches=generate_batches, loader=loader, model=model, layers=layers, postprocess_result=postprocess_result)
outdir = Path(tempfile.gettempdir()) / 'activation_store'
outfile = outdir / f"ds_act_{dataset_name}_{hash}.parquet"
outfile.parent.mkdir(exist_ok=True, parents=True)
logger.info(f"creating dataset {outfile}")