diff --git a/activation_store/collect.py b/activation_store/collect.py index 77a3b30..5cb54e2 100644 --- a/activation_store/collect.py +++ b/activation_store/collect.py @@ -14,6 +14,8 @@ from transformers.modeling_outputs import ModelOutput from activation_store.helpers.torch import clear_mem from typing import Dict, Generator, List, Union, Optional from torch import Tensor +import tempfile +import gc import os default_output_folder = (Path(__file__).parent.parent / "outputs").resolve() @@ -35,17 +37,18 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu for k, group in act_groups.items(): aas = [v.output[0] if isinstance(v.output, tuple) else v.output for k, v in trace.items() if k in group] assert len(aas) > 0, f"no activations found for {group}" + assert aas[0].dim() == 3, f"expected [b, t, h] activations, got {aas[0].shape}" aas = torch.stack([a[:, token_index].to(dtype) for a in aas], dim=1) acts[f'acts-{k}'] = aas else: - acts = {f'act-{k}': + acts = {f'acts-{k}': v.output[0] if isinstance(v.output, tuple) else v.output for k, v in trace.items()} acts = {k: v[:, token_index].to(dtype) for k, v in acts.items() if v is not None} del trace # batch must be first, also the writer supports float16 so lets use that - output.hidden_states = rearrange(list(output.hidden_states), 'l b t h -> b l t h')[:, :, token_index].to(dtype) + output.hidden_states = rearrange([h[:, token_index] for h in output.hidden_states], 'l b t h -> b l t h').to(dtype) output.logits = output.logits[:, token_index].to(dtype) @@ -60,6 +63,12 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu if v.dim() == 0: bs = input['input_ids'].shape[0] o[k] = v.repeat(bs) + + + # finally check for nans + for k, v in o.items(): + if torch.isnan(v).any(): + raise ValueError(f"nan found in {k}") return o @@ -91,8 +100,12 @@ def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers = [ # FIXME for some reason autocast isn't converting the inputs batch = {k: v.to(device) if isinstance(v, Tensor) else v for k, v in batch.items()} - with TraceDict(model, layers, retain_grad=False, detach=True, clone=True) as trace: + if layers is not None: + with TraceDict(model, layers, retain_grad=False, detach=True, clone=True) as trace: + out = model(**batch, use_cache=False, output_hidden_states=True, return_dict=True) + else: out = model(**batch, use_cache=False, output_hidden_states=True, return_dict=True) + trace = None o = postprocess_result(batch, trace, out, model, act_groups=act_groups) # copy to avoid memory leaks @@ -108,7 +121,7 @@ def dataset_hash(**kwargs): return suffix -def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_name='', layers: Union[List[str], Dict[str, List[str]]]=[], dataset_dir=default_output_folder, writer_batch_size=1, postprocess_result=default_postprocess_result) -> Dataset: +def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_name='', layers: Union[List[str], Dict[str, List[str]]]=[], dataset_dir=default_output_folder, writer_batch_size=1, postprocess_result=default_postprocess_result, outfile: Optional[Path] = None) -> Dataset: """ Collect activations from a model and store them in a dataset @@ -130,13 +143,16 @@ def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_na Dataset.from_parquet(f).with_format("torch") """ hash = dataset_hash(generate_batches=generate_batches, loader=loader, model=model) - f = dataset_dir / ".ds" / f"ds_{dataset_name}_{hash}.parquet" - f.parent.mkdir(exist_ok=True, parents=True) - logger.info(f"creating dataset {f}") + + if outfile is None: + outdir = Path(tempfile.mkdtemp(prefix='activation_store')) + outfile = outdir / f"ds_{dataset_name}_{hash}.parquet" + outfile.parent.mkdir(exist_ok=True, parents=True) + logger.info(f"creating dataset {outfile}") iterator = generate_batches(loader, model, layers=layers, postprocess_result=postprocess_result) - with ParquetWriter(path=f, writer_batch_size=writer_batch_size, + with ParquetWriter(path=outfile, writer_batch_size=writer_batch_size, embed_local_files=True ) as writer: for bo in iterator: @@ -146,7 +162,9 @@ def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_na # or maybe better compression to `writer.write(example, key)` for each writer.write_batch(bo) + del bo + gc.collect() writer.finalize() writer.close() - return f + return outfile