mirror of
https://github.com/wassname/activation_store.git
synced 2026-06-27 18:03:14 +08:00
fix, hash was crashing it
This commit is contained in:
+46
-28
@@ -1,4 +1,4 @@
|
||||
from transformers import AutoModelForCausalLM
|
||||
from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from tqdm.auto import tqdm
|
||||
@@ -21,7 +21,7 @@ import os
|
||||
default_output_folder = (Path(__file__).parent.parent / "outputs").resolve()
|
||||
|
||||
@torch.no_grad()
|
||||
def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutput, model: AutoModelForCausalLM, act_groups=Optional[Dict[str,List[str]]], last_token=True, dtype=torch.float16) -> Dict[str, Tensor]:
|
||||
def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutput, model: PreTrainedModel, act_groups=Optional[Dict[str,List[str]]], last_token=True, dtype=torch.float16, save_hidden_states=True) -> Dict[str, Tensor]:
|
||||
"""Make your own. This adds activations to output, and rearranges hidden states.
|
||||
|
||||
Note the parquet write support float16, so we use that. It does not support float8, bfloat16, etc.
|
||||
@@ -34,21 +34,24 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu
|
||||
# Baukit records the literal layer output, which varies by model. Sometimes you get a tuple, or not.Usually [b, t, h] for MLP, but not for attention layers. You may need to customize this.
|
||||
if act_groups is not None:
|
||||
acts = {}
|
||||
for k, group in act_groups.items():
|
||||
aas = [v.output[0] if isinstance(v.output, tuple) else v.output for k, v in trace.items() if k in group]
|
||||
assert len(aas) > 0, f"no activations found for {group}"
|
||||
assert aas[0].dim() == 3, f"expected [b, t, h] activations, got {aas[0].shape}"
|
||||
aas = torch.stack([a[:, token_index].to(dtype) for a in aas], dim=1)
|
||||
acts[f'acts-{k}'] = aas
|
||||
for grp_nm, group in act_groups.items():
|
||||
grp_acts = [v.output[0] if isinstance(v.output, tuple) else v.output for lyr_nm, v in trace.items() if lyr_nm in group]
|
||||
assert len(grp_acts) > 0, f"no activations found for {group}"
|
||||
assert grp_acts[0].dim() == 3, f"expected [b, t, h] activations, got {grp_acts[0].shape}"
|
||||
grp_acts = torch.stack([a[:, token_index] for a in grp_acts], dim=1)
|
||||
acts[f'acts-{grp_nm}'] = grp_acts.to(dtype)
|
||||
else:
|
||||
acts = {f'acts-{k}':
|
||||
acts = {f'acts-{lyr_nm}':
|
||||
v.output[0] if isinstance(v.output, tuple) else v.output
|
||||
for k, v in trace.items()}
|
||||
for lyr_nm, v in trace.items()}
|
||||
acts = {k: v[:, token_index].to(dtype) for k, v in acts.items() if v is not None}
|
||||
del trace
|
||||
|
||||
# batch must be first, also the writer supports float16 so lets use that
|
||||
output.hidden_states = rearrange([h[:, token_index] for h in output.hidden_states], 'l b t h -> b l t h').to(dtype)
|
||||
if save_hidden_states and output.hidden_states is not None:
|
||||
output.hidden_states = rearrange([h[:, token_index] for h in output.hidden_states], 'l b t h -> b l t h').to(dtype)
|
||||
else:
|
||||
output.hidden_states = None
|
||||
|
||||
output.logits = output.logits[:, token_index].to(dtype)
|
||||
|
||||
@@ -59,27 +62,27 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu
|
||||
o['label'] = input['label']
|
||||
|
||||
# all output tensors must have a batch dim
|
||||
for k, v in o.items():
|
||||
for grp_nm, v in o.items():
|
||||
if v.dim() == 0:
|
||||
bs = input['input_ids'].shape[0]
|
||||
o[k] = v.repeat(bs)
|
||||
o[grp_nm] = v.repeat(bs)
|
||||
|
||||
|
||||
# finally check for nans
|
||||
for k, v in o.items():
|
||||
for grp_nm, v in o.items():
|
||||
if torch.isnan(v).any():
|
||||
raise ValueError(f"nan found in {k}")
|
||||
raise ValueError(f"nan found in {grp_nm}")
|
||||
return o
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers = [], postprocess_result=default_postprocess_result) -> Generator[Dict[str, Tensor], None, None]:
|
||||
def generate_batches(loader: DataLoader, model: PreTrainedModel, layers = [], postprocess_result=default_postprocess_result) -> Generator[Dict[str, Tensor], None, None]:
|
||||
"""
|
||||
Collect activations from a model
|
||||
|
||||
Args:
|
||||
- loader: DataLoader
|
||||
- model: AutoModelForCausalLM
|
||||
- model: PreTrainedModel
|
||||
- layers: can be
|
||||
- selected from `model.named_modules()`
|
||||
- groups of layers to collect, these will be stacked so they must have compatible sizes
|
||||
@@ -96,16 +99,16 @@ def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers = [
|
||||
model.eval()
|
||||
for batch in tqdm(loader, 'collecting activations'):
|
||||
device = next(model.parameters()).device
|
||||
with torch.autocast(device_type=device.type):
|
||||
# with torch.autocast(device_type=device.type):
|
||||
|
||||
# FIXME for some reason autocast isn't converting the inputs
|
||||
batch = {k: v.to(device) if isinstance(v, Tensor) else v for k, v in batch.items()}
|
||||
if layers is not None:
|
||||
with TraceDict(model, layers, retain_grad=False, detach=True, clone=True) as trace:
|
||||
out = model(**batch, use_cache=False, output_hidden_states=True, return_dict=True)
|
||||
else:
|
||||
# FIXME for some reason autocast isn't converting the inputs
|
||||
batch = {k: v.to(device) if isinstance(v, Tensor) else v for k, v in batch.items()}
|
||||
if layers is not None:
|
||||
with TraceDict(model, layers, retain_grad=False, detach=True) as trace:
|
||||
out = model(**batch, use_cache=False, output_hidden_states=True, return_dict=True)
|
||||
trace = None
|
||||
else:
|
||||
out = model(**batch, use_cache=False, output_hidden_states=True, return_dict=True)
|
||||
trace = None
|
||||
o = postprocess_result(batch, trace, out, model, act_groups=act_groups)
|
||||
|
||||
# copy to avoid memory leaks
|
||||
@@ -116,12 +119,26 @@ def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers = [
|
||||
yield o
|
||||
|
||||
|
||||
def dataset_hash(**kwargs):
|
||||
def output_dataset_hash(**kwargs):
|
||||
# special hash for some key parts that don't hash well
|
||||
func = generate_batches
|
||||
name = "%s.%s" % (func.__module__, func.__name__)
|
||||
kwargs['func'] = name
|
||||
|
||||
for k,v in kwargs.items():
|
||||
if isinstance(v, Dataset):
|
||||
kwargs[k] = v._fingerprint
|
||||
elif isinstance(v, DataLoader):
|
||||
kwargs[k] = v.dataset._fingerprint # assume this is Dataset
|
||||
elif isinstance(v, PreTrainedTokenizerBase):
|
||||
raise NotImplementedError("hashing tokenizers not implemented")
|
||||
elif isinstance(v, PreTrainedModel):
|
||||
kwargs[k] = v.config._name_or_path # PretrainedConfig
|
||||
suffix = Hasher.hash(kwargs)
|
||||
return suffix
|
||||
|
||||
|
||||
def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_name='', layers: Union[List[str], Dict[str, List[str]]]=[], dataset_dir=default_output_folder, writer_batch_size=1, postprocess_result=default_postprocess_result, outfile: Optional[Path] = None) -> Dataset:
|
||||
def activation_store(loader: DataLoader, model: PreTrainedModel, dataset_name='', layers: Union[List[str], Dict[str, List[str]]]=[], dataset_dir=default_output_folder, writer_batch_size=1, postprocess_result=default_postprocess_result, outfile: Optional[Path] = None) -> Dataset:
|
||||
"""
|
||||
Collect activations from a model and store them in a dataset
|
||||
|
||||
@@ -142,7 +159,8 @@ def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_na
|
||||
f = activation_store(loader, model, layers=['transformer.h'])
|
||||
Dataset.from_parquet(f).with_format("torch")
|
||||
"""
|
||||
hash = dataset_hash(generate_batches=generate_batches, loader=loader, model=model)
|
||||
# FIXME I think this is the problem, instead of using a naive hash I will need to custom hash some key parts, model_name, dataset_name, layers, etc
|
||||
hash = output_dataset_hash(generate_batches=generate_batches, loader=loader, model=model)
|
||||
|
||||
if outfile is None:
|
||||
outdir = Path(tempfile.mkdtemp(prefix='activation_store'))
|
||||
|
||||
Reference in New Issue
Block a user