mirror of
https://github.com/wassname/activation_store.git
synced 2026-06-27 16:44:07 +08:00
use tempdir
This commit is contained in:
@@ -14,6 +14,8 @@ from transformers.modeling_outputs import ModelOutput
|
||||
from activation_store.helpers.torch import clear_mem
|
||||
from typing import Dict, Generator, List, Union, Optional
|
||||
from torch import Tensor
|
||||
import tempfile
|
||||
import gc
|
||||
import os
|
||||
|
||||
default_output_folder = (Path(__file__).parent.parent / "outputs").resolve()
|
||||
@@ -35,17 +37,18 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu
|
||||
for k, group in act_groups.items():
|
||||
aas = [v.output[0] if isinstance(v.output, tuple) else v.output for k, v in trace.items() if k in group]
|
||||
assert len(aas) > 0, f"no activations found for {group}"
|
||||
assert aas[0].dim() == 3, f"expected [b, t, h] activations, got {aas[0].shape}"
|
||||
aas = torch.stack([a[:, token_index].to(dtype) for a in aas], dim=1)
|
||||
acts[f'acts-{k}'] = aas
|
||||
else:
|
||||
acts = {f'act-{k}':
|
||||
acts = {f'acts-{k}':
|
||||
v.output[0] if isinstance(v.output, tuple) else v.output
|
||||
for k, v in trace.items()}
|
||||
acts = {k: v[:, token_index].to(dtype) for k, v in acts.items() if v is not None}
|
||||
del trace
|
||||
|
||||
# batch must be first, also the writer supports float16 so lets use that
|
||||
output.hidden_states = rearrange(list(output.hidden_states), 'l b t h -> b l t h')[:, :, token_index].to(dtype)
|
||||
output.hidden_states = rearrange([h[:, token_index] for h in output.hidden_states], 'l b t h -> b l t h').to(dtype)
|
||||
|
||||
output.logits = output.logits[:, token_index].to(dtype)
|
||||
|
||||
@@ -60,6 +63,12 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu
|
||||
if v.dim() == 0:
|
||||
bs = input['input_ids'].shape[0]
|
||||
o[k] = v.repeat(bs)
|
||||
|
||||
|
||||
# finally check for nans
|
||||
for k, v in o.items():
|
||||
if torch.isnan(v).any():
|
||||
raise ValueError(f"nan found in {k}")
|
||||
return o
|
||||
|
||||
|
||||
@@ -91,8 +100,12 @@ def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers = [
|
||||
|
||||
# FIXME for some reason autocast isn't converting the inputs
|
||||
batch = {k: v.to(device) if isinstance(v, Tensor) else v for k, v in batch.items()}
|
||||
with TraceDict(model, layers, retain_grad=False, detach=True, clone=True) as trace:
|
||||
if layers is not None:
|
||||
with TraceDict(model, layers, retain_grad=False, detach=True, clone=True) as trace:
|
||||
out = model(**batch, use_cache=False, output_hidden_states=True, return_dict=True)
|
||||
else:
|
||||
out = model(**batch, use_cache=False, output_hidden_states=True, return_dict=True)
|
||||
trace = None
|
||||
o = postprocess_result(batch, trace, out, model, act_groups=act_groups)
|
||||
|
||||
# copy to avoid memory leaks
|
||||
@@ -108,7 +121,7 @@ def dataset_hash(**kwargs):
|
||||
return suffix
|
||||
|
||||
|
||||
def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_name='', layers: Union[List[str], Dict[str, List[str]]]=[], dataset_dir=default_output_folder, writer_batch_size=1, postprocess_result=default_postprocess_result) -> Dataset:
|
||||
def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_name='', layers: Union[List[str], Dict[str, List[str]]]=[], dataset_dir=default_output_folder, writer_batch_size=1, postprocess_result=default_postprocess_result, outfile: Optional[Path] = None) -> Dataset:
|
||||
"""
|
||||
Collect activations from a model and store them in a dataset
|
||||
|
||||
@@ -130,13 +143,16 @@ def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_na
|
||||
Dataset.from_parquet(f).with_format("torch")
|
||||
"""
|
||||
hash = dataset_hash(generate_batches=generate_batches, loader=loader, model=model)
|
||||
f = dataset_dir / ".ds" / f"ds_{dataset_name}_{hash}.parquet"
|
||||
f.parent.mkdir(exist_ok=True, parents=True)
|
||||
logger.info(f"creating dataset {f}")
|
||||
|
||||
if outfile is None:
|
||||
outdir = Path(tempfile.mkdtemp(prefix='activation_store'))
|
||||
outfile = outdir / f"ds_{dataset_name}_{hash}.parquet"
|
||||
outfile.parent.mkdir(exist_ok=True, parents=True)
|
||||
logger.info(f"creating dataset {outfile}")
|
||||
|
||||
iterator = generate_batches(loader, model, layers=layers, postprocess_result=postprocess_result)
|
||||
|
||||
with ParquetWriter(path=f, writer_batch_size=writer_batch_size,
|
||||
with ParquetWriter(path=outfile, writer_batch_size=writer_batch_size,
|
||||
embed_local_files=True
|
||||
) as writer:
|
||||
for bo in iterator:
|
||||
@@ -146,7 +162,9 @@ def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_na
|
||||
|
||||
# or maybe better compression to `writer.write(example, key)` for each
|
||||
writer.write_batch(bo)
|
||||
del bo
|
||||
gc.collect()
|
||||
writer.finalize()
|
||||
writer.close()
|
||||
|
||||
return f
|
||||
return outfile
|
||||
|
||||
Reference in New Issue
Block a user