diff --git a/activation_store/collect.py b/activation_store/collect.py index 0b25d51..919a9d5 100644 --- a/activation_store/collect.py +++ b/activation_store/collect.py @@ -1,4 +1,4 @@ -from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase +from transformers import PreTrainedModel, PreTrainedTokenizerBase import torch from datasets import Dataset from tqdm.auto import tqdm @@ -8,17 +8,14 @@ from loguru import logger from pathlib import Path from baukit.nethook import TraceDict, recursive_copy from einops import rearrange -from datasets.arrow_writer import ArrowWriter, ParquetWriter +from datasets.arrow_writer import ParquetWriter from datasets.fingerprint import Hasher from transformers.modeling_outputs import ModelOutput from activation_store.helpers.torch import clear_mem from typing import Dict, Generator, List, Union, Optional from torch import Tensor import tempfile -import gc -import os - -default_output_folder = (Path(__file__).parent.parent / "outputs").resolve() +import inspect @torch.no_grad() def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutput, model: PreTrainedModel, act_groups=Optional[Dict[str,List[str]]], last_token=True, dtype=torch.float16, save_hidden_states=True) -> Dict[str, Tensor]: @@ -127,24 +124,27 @@ def output_dataset_hash(**kwargs): for k,v in kwargs.items(): if isinstance(v, Dataset): - kwargs[k] = v._fingerprint + kwargs[k] = f"Dataset_{v._fingerprint}_{len(v)}" elif isinstance(v, DataLoader): - kwargs[k] = v.dataset._fingerprint # assume this is Dataset + kwargs[k] = f"DataLoader.dataset_{v.dataset._fingerprint}_{len(v)}_{v.batch_size}" elif isinstance(v, PreTrainedTokenizerBase): raise NotImplementedError("hashing tokenizers not implemented") elif isinstance(v, PreTrainedModel): - kwargs[k] = v.config._name_or_path # PretrainedConfig + kwargs[k] = f"PreTrainedModel_{v.config._name_or_path}" # PretrainedConfig + elif inspect.isfunction(v): + kwargs[k] = "Function: %s.%s" % (v.__module__, v.__name__) + logger.debug(f"hashing {kwargs}") suffix = Hasher.hash(kwargs) return suffix -def activation_store(loader: DataLoader, model: PreTrainedModel, dataset_name='', layers: Union[List[str], Dict[str, List[str]]]=[], dataset_dir=default_output_folder, writer_batch_size=1, postprocess_result=default_postprocess_result, outfile: Optional[Path] = None) -> Dataset: +def activation_store(loader: DataLoader, model: PreTrainedModel, dataset_name='', layers: Union[List[str], Dict[str, List[str]]]=[], dataset_dir=None, writer_batch_size=1, postprocess_result=default_postprocess_result, outfile: Optional[Path] = None) -> Dataset: """ Collect activations from a model and store them in a dataset Args: - loader: DataLoader - - model: AutoModelForCausalLM + - model: PreTrainedModel - dataset_name: str - layers: - List[str] selected from `model.named_modules()` @@ -160,11 +160,11 @@ def activation_store(loader: DataLoader, model: PreTrainedModel, dataset_name='' Dataset.from_parquet(f).with_format("torch") """ # FIXME I think this is the problem, instead of using a naive hash I will need to custom hash some key parts, model_name, dataset_name, layers, etc - hash = output_dataset_hash(generate_batches=generate_batches, loader=loader, model=model) + hash = output_dataset_hash(generate_batches=generate_batches, loader=loader, model=model, layers=layers, postprocess_result=postprocess_result) if outfile is None: outdir = Path(tempfile.mkdtemp(prefix='activation_store')) - outfile = outdir / f"ds_{dataset_name}_{hash}.parquet" + outfile = outdir / f"ds_act_{dataset_name}_{hash}.parquet" outfile.parent.mkdir(exist_ok=True, parents=True) logger.info(f"creating dataset {outfile}")