diff --git a/activation_store/collect.py b/activation_store/collect.py index 5e8488b..0b25d51 100644 --- a/activation_store/collect.py +++ b/activation_store/collect.py @@ -1,4 +1,4 @@ -from transformers import AutoModelForCausalLM +from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase import torch from datasets import Dataset from tqdm.auto import tqdm @@ -21,7 +21,7 @@ import os default_output_folder = (Path(__file__).parent.parent / "outputs").resolve() @torch.no_grad() -def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutput, model: AutoModelForCausalLM, act_groups=Optional[Dict[str,List[str]]], last_token=True, dtype=torch.float16) -> Dict[str, Tensor]: +def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutput, model: PreTrainedModel, act_groups=Optional[Dict[str,List[str]]], last_token=True, dtype=torch.float16, save_hidden_states=True) -> Dict[str, Tensor]: """Make your own. This adds activations to output, and rearranges hidden states. Note the parquet write support float16, so we use that. It does not support float8, bfloat16, etc. @@ -34,21 +34,24 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu # Baukit records the literal layer output, which varies by model. Sometimes you get a tuple, or not.Usually [b, t, h] for MLP, but not for attention layers. You may need to customize this. if act_groups is not None: acts = {} - for k, group in act_groups.items(): - aas = [v.output[0] if isinstance(v.output, tuple) else v.output for k, v in trace.items() if k in group] - assert len(aas) > 0, f"no activations found for {group}" - assert aas[0].dim() == 3, f"expected [b, t, h] activations, got {aas[0].shape}" - aas = torch.stack([a[:, token_index].to(dtype) for a in aas], dim=1) - acts[f'acts-{k}'] = aas + for grp_nm, group in act_groups.items(): + grp_acts = [v.output[0] if isinstance(v.output, tuple) else v.output for lyr_nm, v in trace.items() if lyr_nm in group] + assert len(grp_acts) > 0, f"no activations found for {group}" + assert grp_acts[0].dim() == 3, f"expected [b, t, h] activations, got {grp_acts[0].shape}" + grp_acts = torch.stack([a[:, token_index] for a in grp_acts], dim=1) + acts[f'acts-{grp_nm}'] = grp_acts.to(dtype) else: - acts = {f'acts-{k}': + acts = {f'acts-{lyr_nm}': v.output[0] if isinstance(v.output, tuple) else v.output - for k, v in trace.items()} + for lyr_nm, v in trace.items()} acts = {k: v[:, token_index].to(dtype) for k, v in acts.items() if v is not None} del trace # batch must be first, also the writer supports float16 so lets use that - output.hidden_states = rearrange([h[:, token_index] for h in output.hidden_states], 'l b t h -> b l t h').to(dtype) + if save_hidden_states and output.hidden_states is not None: + output.hidden_states = rearrange([h[:, token_index] for h in output.hidden_states], 'l b t h -> b l t h').to(dtype) + else: + output.hidden_states = None output.logits = output.logits[:, token_index].to(dtype) @@ -59,27 +62,27 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu o['label'] = input['label'] # all output tensors must have a batch dim - for k, v in o.items(): + for grp_nm, v in o.items(): if v.dim() == 0: bs = input['input_ids'].shape[0] - o[k] = v.repeat(bs) + o[grp_nm] = v.repeat(bs) # finally check for nans - for k, v in o.items(): + for grp_nm, v in o.items(): if torch.isnan(v).any(): - raise ValueError(f"nan found in {k}") + raise ValueError(f"nan found in {grp_nm}") return o @torch.no_grad() -def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers = [], postprocess_result=default_postprocess_result) -> Generator[Dict[str, Tensor], None, None]: +def generate_batches(loader: DataLoader, model: PreTrainedModel, layers = [], postprocess_result=default_postprocess_result) -> Generator[Dict[str, Tensor], None, None]: """ Collect activations from a model Args: - loader: DataLoader - - model: AutoModelForCausalLM + - model: PreTrainedModel - layers: can be - selected from `model.named_modules()` - groups of layers to collect, these will be stacked so they must have compatible sizes @@ -96,16 +99,16 @@ def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers = [ model.eval() for batch in tqdm(loader, 'collecting activations'): device = next(model.parameters()).device - with torch.autocast(device_type=device.type): + # with torch.autocast(device_type=device.type): - # FIXME for some reason autocast isn't converting the inputs - batch = {k: v.to(device) if isinstance(v, Tensor) else v for k, v in batch.items()} - if layers is not None: - with TraceDict(model, layers, retain_grad=False, detach=True, clone=True) as trace: - out = model(**batch, use_cache=False, output_hidden_states=True, return_dict=True) - else: + # FIXME for some reason autocast isn't converting the inputs + batch = {k: v.to(device) if isinstance(v, Tensor) else v for k, v in batch.items()} + if layers is not None: + with TraceDict(model, layers, retain_grad=False, detach=True) as trace: out = model(**batch, use_cache=False, output_hidden_states=True, return_dict=True) - trace = None + else: + out = model(**batch, use_cache=False, output_hidden_states=True, return_dict=True) + trace = None o = postprocess_result(batch, trace, out, model, act_groups=act_groups) # copy to avoid memory leaks @@ -116,12 +119,26 @@ def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers = [ yield o -def dataset_hash(**kwargs): +def output_dataset_hash(**kwargs): + # special hash for some key parts that don't hash well + func = generate_batches + name = "%s.%s" % (func.__module__, func.__name__) + kwargs['func'] = name + + for k,v in kwargs.items(): + if isinstance(v, Dataset): + kwargs[k] = v._fingerprint + elif isinstance(v, DataLoader): + kwargs[k] = v.dataset._fingerprint # assume this is Dataset + elif isinstance(v, PreTrainedTokenizerBase): + raise NotImplementedError("hashing tokenizers not implemented") + elif isinstance(v, PreTrainedModel): + kwargs[k] = v.config._name_or_path # PretrainedConfig suffix = Hasher.hash(kwargs) return suffix -def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_name='', layers: Union[List[str], Dict[str, List[str]]]=[], dataset_dir=default_output_folder, writer_batch_size=1, postprocess_result=default_postprocess_result, outfile: Optional[Path] = None) -> Dataset: +def activation_store(loader: DataLoader, model: PreTrainedModel, dataset_name='', layers: Union[List[str], Dict[str, List[str]]]=[], dataset_dir=default_output_folder, writer_batch_size=1, postprocess_result=default_postprocess_result, outfile: Optional[Path] = None) -> Dataset: """ Collect activations from a model and store them in a dataset @@ -142,7 +159,8 @@ def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_na f = activation_store(loader, model, layers=['transformer.h']) Dataset.from_parquet(f).with_format("torch") """ - hash = dataset_hash(generate_batches=generate_batches, loader=loader, model=model) + # FIXME I think this is the problem, instead of using a naive hash I will need to custom hash some key parts, model_name, dataset_name, layers, etc + hash = output_dataset_hash(generate_batches=generate_batches, loader=loader, model=model) if outfile is None: outdir = Path(tempfile.mkdtemp(prefix='activation_store'))