mirror of
https://github.com/wassname/activation_store.git
synced 2026-06-27 16:44:07 +08:00
fix caching
This commit is contained in:
@@ -133,6 +133,9 @@ def output_dataset_hash(**kwargs):
|
||||
kwargs[k] = f"PreTrainedModel_{v.config._name_or_path}" # PretrainedConfig
|
||||
elif inspect.isfunction(v):
|
||||
kwargs[k] = "Function: %s.%s" % (v.__module__, v.__name__)
|
||||
else:
|
||||
# logger.debug(f"hashing {k} as {v} of type {type(v)}")
|
||||
kwargs[k] = str(v)
|
||||
logger.debug(f"hashing {kwargs}")
|
||||
suffix = Hasher.hash(kwargs)
|
||||
return suffix
|
||||
@@ -159,11 +162,11 @@ def activation_store(loader: DataLoader, model: PreTrainedModel, dataset_name=''
|
||||
f = activation_store(loader, model, layers=['transformer.h'])
|
||||
Dataset.from_parquet(f).with_format("torch")
|
||||
"""
|
||||
# FIXME I think this is the problem, instead of using a naive hash I will need to custom hash some key parts, model_name, dataset_name, layers, etc
|
||||
hash = output_dataset_hash(generate_batches=generate_batches, loader=loader, model=model, layers=layers, postprocess_result=postprocess_result)
|
||||
|
||||
if outfile is None:
|
||||
outdir = Path(tempfile.mkdtemp(prefix='activation_store'))
|
||||
# FIXME I think this is the problem, instead of using a naive hash I will need to custom hash some key parts, model_name, dataset_name, layers, etc
|
||||
hash = output_dataset_hash(generate_batches=generate_batches, loader=loader, model=model, layers=layers, postprocess_result=postprocess_result)
|
||||
outdir = Path(tempfile.gettempdir()) / 'activation_store'
|
||||
outfile = outdir / f"ds_act_{dataset_name}_{hash}.parquet"
|
||||
outfile.parent.mkdir(exist_ok=True, parents=True)
|
||||
logger.info(f"creating dataset {outfile}")
|
||||
|
||||
Reference in New Issue
Block a user