mirror of
https://github.com/wassname/activation_store.git
synced 2026-06-27 16:28:55 +08:00
tweaks to prevent OOM
This commit is contained in:
@@ -10,29 +10,33 @@ from einops import rearrange
|
||||
from datasets.arrow_writer import ArrowWriter, ParquetWriter
|
||||
from datasets.fingerprint import Hasher
|
||||
from transformers.modeling_outputs import ModelOutput
|
||||
|
||||
from activation_store.helpers.torch import clear_mem
|
||||
from typing import Dict, Generator
|
||||
from torch import Tensor
|
||||
|
||||
default_output_folder = (Path(__file__).parent.parent / "outputs").resolve()
|
||||
|
||||
def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutput, model: AutoModelForCausalLM) -> Dict[str, Tensor]:
|
||||
@torch.no_grad()
|
||||
def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutput, model: AutoModelForCausalLM, last_token=False) -> Dict[str, Tensor]:
|
||||
"""Make your own. This adds activations to output, and rearranges hidden states"""
|
||||
token_index = -1 if last_token else slice(None)
|
||||
|
||||
# Baukit records the literal layer output, which varies by model. Here we assume that the output itself, or its first element when it is a tuple, holds the activations we want
|
||||
acts = {f'act-{k}':
|
||||
v.output[0] if isinstance(v.output, tuple) else v.output
|
||||
v.output[0][:, token_index] if isinstance(v.output, tuple) else v.output
|
||||
for k, v in trace.items()}
|
||||
del trace
|
||||
|
||||
# The batch dimension must come first; the writer also supports float16, so let's use that
|
||||
output.hidden_states = rearrange(list(output.hidden_states), 'l b t h -> b l t h').half()
|
||||
output.hidden_states = rearrange(list(output.hidden_states), 'l b t h -> b l t h')[:, token_index].half()
|
||||
|
||||
o = dict(**acts, **output)
|
||||
if 'attention_mask' in input:
|
||||
if ('attention_mask' in input) and not last_token:
|
||||
o['attention_mask'] = input['attention_mask']
|
||||
if 'label' in input:
|
||||
o['label'] = input['label']
|
||||
input = output = acts = None
|
||||
|
||||
return o
|
||||
|
||||
|
||||
@@ -41,8 +45,11 @@ def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers, po
|
||||
model.eval()
|
||||
for batch in tqdm(loader, 'collecting activations'):
|
||||
device = next(model.parameters()).device
|
||||
with torch.amp.autocast(device_type=device.type):
|
||||
with TraceDict(model, layers) as trace:
|
||||
with torch.autocast(device_type=device.type):
|
||||
|
||||
# FIXME: for some reason autocast isn't converting the inputs, so explicitly move tensor inputs to the model's device
|
||||
batch = {k: v.to(device) if isinstance(v, Tensor) else v for k, v in batch.items()}
|
||||
with TraceDict(model, layers, retain_grad=False, detach=True, clone=True) as trace:
|
||||
out = model(**batch, use_cache=False, output_hidden_states=True, return_dict=True)
|
||||
o = postprocess_result(batch, trace, out, model)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user