use tempdir

This commit is contained in:
wassname
2025-03-15 11:35:15 +08:00
parent e099d03638
commit 23a30319a5
+27 -9
View File
@@ -14,6 +14,8 @@ from transformers.modeling_outputs import ModelOutput
from activation_store.helpers.torch import clear_mem
from typing import Dict, Generator, List, Union, Optional
from torch import Tensor
import tempfile
import gc
import os
default_output_folder = (Path(__file__).parent.parent / "outputs").resolve()
@@ -35,17 +37,18 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu
for k, group in act_groups.items():
aas = [v.output[0] if isinstance(v.output, tuple) else v.output for k, v in trace.items() if k in group]
assert len(aas) > 0, f"no activations found for {group}"
assert aas[0].dim() == 3, f"expected [b, t, h] activations, got {aas[0].shape}"
aas = torch.stack([a[:, token_index].to(dtype) for a in aas], dim=1)
acts[f'acts-{k}'] = aas
else:
acts = {f'act-{k}':
acts = {f'acts-{k}':
v.output[0] if isinstance(v.output, tuple) else v.output
for k, v in trace.items()}
acts = {k: v[:, token_index].to(dtype) for k, v in acts.items() if v is not None}
del trace
# batch must be first, also the writer supports float16 so lets use that
output.hidden_states = rearrange(list(output.hidden_states), 'l b t h -> b l t h')[:, :, token_index].to(dtype)
output.hidden_states = rearrange([h[:, token_index] for h in output.hidden_states], 'l b t h -> b l t h').to(dtype)
output.logits = output.logits[:, token_index].to(dtype)
@@ -60,6 +63,12 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu
if v.dim() == 0:
bs = input['input_ids'].shape[0]
o[k] = v.repeat(bs)
# finally check for nans
for k, v in o.items():
if torch.isnan(v).any():
raise ValueError(f"nan found in {k}")
return o
@@ -91,8 +100,12 @@ def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers = [
# FIXME for some reason autocast isn't converting the inputs
batch = {k: v.to(device) if isinstance(v, Tensor) else v for k, v in batch.items()}
with TraceDict(model, layers, retain_grad=False, detach=True, clone=True) as trace:
if layers is not None:
with TraceDict(model, layers, retain_grad=False, detach=True, clone=True) as trace:
out = model(**batch, use_cache=False, output_hidden_states=True, return_dict=True)
else:
out = model(**batch, use_cache=False, output_hidden_states=True, return_dict=True)
trace = None
o = postprocess_result(batch, trace, out, model, act_groups=act_groups)
# copy to avoid memory leaks
@@ -108,7 +121,7 @@ def dataset_hash(**kwargs):
return suffix
def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_name='', layers: Union[List[str], Dict[str, List[str]]]=[], dataset_dir=default_output_folder, writer_batch_size=1, postprocess_result=default_postprocess_result) -> Dataset:
def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_name='', layers: Union[List[str], Dict[str, List[str]]]=[], dataset_dir=default_output_folder, writer_batch_size=1, postprocess_result=default_postprocess_result, outfile: Optional[Path] = None) -> Dataset:
"""
Collect activations from a model and store them in a dataset
@@ -130,13 +143,16 @@ def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_na
Dataset.from_parquet(f).with_format("torch")
"""
hash = dataset_hash(generate_batches=generate_batches, loader=loader, model=model)
f = dataset_dir / ".ds" / f"ds_{dataset_name}_{hash}.parquet"
f.parent.mkdir(exist_ok=True, parents=True)
logger.info(f"creating dataset {f}")
if outfile is None:
outdir = Path(tempfile.mkdtemp(prefix='activation_store'))
outfile = outdir / f"ds_{dataset_name}_{hash}.parquet"
outfile.parent.mkdir(exist_ok=True, parents=True)
logger.info(f"creating dataset {outfile}")
iterator = generate_batches(loader, model, layers=layers, postprocess_result=postprocess_result)
with ParquetWriter(path=f, writer_batch_size=writer_batch_size,
with ParquetWriter(path=outfile, writer_batch_size=writer_batch_size,
embed_local_files=True
) as writer:
for bo in iterator:
@@ -146,7 +162,9 @@ def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_na
# or maybe better compression to `writer.write(example, key)` for each
writer.write_batch(bo)
del bo
gc.collect()
writer.finalize()
writer.close()
return f
return outfile