@torch.no_grad()
def default_postprocess_result(input: dict, trace: "TraceDict", output: "ModelOutput", model: "AutoModelForCausalLM", last_token=False) -> Dict[str, Tensor]:
    """Flatten one forward pass into a single dict of tensors. Make your own.

    The result contains one ``act-<name>`` entry per traced layer, every
    key/value pair of ``output`` (with ``hidden_states`` stacked batch-first
    and cast to float16), and the ``attention_mask`` / ``label`` entries of
    ``input`` when they are present.

    Args:
        input: The batch fed to the model. Only the ``'attention_mask'`` and
            ``'label'`` keys are read.
        trace: Mapping of layer name to a record exposing ``.output``.
        output: Mapping-like model output whose ``hidden_states`` attribute is
            an iterable with one tensor per layer. It is updated in place
            with the stacked hidden states.
        model: Unused here; accepted so custom post-processors share one
            signature.
        last_token: When true, keep only the final position along dim 1 of
            every activation and hidden state, and leave ``attention_mask``
            out of the result because it no longer lines up with the data.

    Returns:
        The combined dict described above.
    """
    def _pick_tokens(t: Tensor) -> Tensor:
        # NOTE(review): index -1 is the final position of each row; for
        # right-padded batches that is a padding position. Presumably callers
        # use left padding or unpadded input when last_token=True - confirm.
        return t[:, -1] if last_token else t

    # Baukit records the literal layer output, which varies by model. Here we
    # assume that the output, or its first element when it is a tuple, is the
    # activation we want (assumed batch-first with tokens on dim 1).
    acts = {}
    for name, record in trace.items():
        layer_out = record.output
        if isinstance(layer_out, tuple):
            layer_out = layer_out[0]
        # Apply the token selection to both output layouts so every
        # activation has a consistent shape.
        acts[f'act-{name}'] = _pick_tokens(layer_out)

    # Batch must come first, and the writer supports float16 so let's use
    # that. Tokens are selected per layer *before* stacking: this indexes the
    # token axis (not the layer axis) and avoids materialising the full
    # (batch, layer, token, hidden) tensor when only one token is kept.
    output.hidden_states = torch.stack(
        [_pick_tokens(hs) for hs in output.hidden_states], dim=1
    ).half()
    o = dict(**acts, **output)

    if 'attention_mask' in input and not last_token:
        o['attention_mask'] = input['attention_mask']
    if 'label' in input:
        o['label'] = input['label']

    return o