fix, hash was crashing it

This commit is contained in:
wassname
2025-03-20 15:58:34 +08:00
parent 1df4c3eaaf
commit 37b1123fb0
+46 -28
View File
@@ -1,4 +1,4 @@
from transformers import AutoModelForCausalLM
from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase
import torch
from datasets import Dataset
from tqdm.auto import tqdm
@@ -21,7 +21,7 @@ import os
default_output_folder = (Path(__file__).parent.parent / "outputs").resolve()
@torch.no_grad()
def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutput, model: AutoModelForCausalLM, act_groups=Optional[Dict[str,List[str]]], last_token=True, dtype=torch.float16) -> Dict[str, Tensor]:
def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutput, model: PreTrainedModel, act_groups=Optional[Dict[str,List[str]]], last_token=True, dtype=torch.float16, save_hidden_states=True) -> Dict[str, Tensor]:
"""Make your own. This adds activations to output, and rearranges hidden states.
Note: the parquet writer supports float16, so we use that. It does not support float8, bfloat16, etc.
@@ -34,21 +34,24 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu
# Baukit records the literal layer output, which varies by model. Sometimes you get a tuple, sometimes not. Usually [b, t, h] for MLP, but not for attention layers. You may need to customize this.
if act_groups is not None:
acts = {}
for k, group in act_groups.items():
aas = [v.output[0] if isinstance(v.output, tuple) else v.output for k, v in trace.items() if k in group]
assert len(aas) > 0, f"no activations found for {group}"
assert aas[0].dim() == 3, f"expected [b, t, h] activations, got {aas[0].shape}"
aas = torch.stack([a[:, token_index].to(dtype) for a in aas], dim=1)
acts[f'acts-{k}'] = aas
for grp_nm, group in act_groups.items():
grp_acts = [v.output[0] if isinstance(v.output, tuple) else v.output for lyr_nm, v in trace.items() if lyr_nm in group]
assert len(grp_acts) > 0, f"no activations found for {group}"
assert grp_acts[0].dim() == 3, f"expected [b, t, h] activations, got {grp_acts[0].shape}"
grp_acts = torch.stack([a[:, token_index] for a in grp_acts], dim=1)
acts[f'acts-{grp_nm}'] = grp_acts.to(dtype)
else:
acts = {f'acts-{k}':
acts = {f'acts-{lyr_nm}':
v.output[0] if isinstance(v.output, tuple) else v.output
for k, v in trace.items()}
for lyr_nm, v in trace.items()}
acts = {k: v[:, token_index].to(dtype) for k, v in acts.items() if v is not None}
del trace
# batch must be first; also, the writer supports float16, so let's use that
output.hidden_states = rearrange([h[:, token_index] for h in output.hidden_states], 'l b t h -> b l t h').to(dtype)
if save_hidden_states and output.hidden_states is not None:
output.hidden_states = rearrange([h[:, token_index] for h in output.hidden_states], 'l b t h -> b l t h').to(dtype)
else:
output.hidden_states = None
output.logits = output.logits[:, token_index].to(dtype)
@@ -59,27 +62,27 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu
o['label'] = input['label']
# all output tensors must have a batch dim
for k, v in o.items():
for grp_nm, v in o.items():
if v.dim() == 0:
bs = input['input_ids'].shape[0]
o[k] = v.repeat(bs)
o[grp_nm] = v.repeat(bs)
# finally check for nans
for k, v in o.items():
for grp_nm, v in o.items():
if torch.isnan(v).any():
raise ValueError(f"nan found in {k}")
raise ValueError(f"nan found in {grp_nm}")
return o
@torch.no_grad()
def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers = [], postprocess_result=default_postprocess_result) -> Generator[Dict[str, Tensor], None, None]:
def generate_batches(loader: DataLoader, model: PreTrainedModel, layers = [], postprocess_result=default_postprocess_result) -> Generator[Dict[str, Tensor], None, None]:
"""
Collect activations from a model
Args:
- loader: DataLoader
- model: AutoModelForCausalLM
- model: PreTrainedModel
- layers: can be
- selected from `model.named_modules()`
- groups of layers to collect, these will be stacked so they must have compatible sizes
@@ -96,16 +99,16 @@ def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers = [
model.eval()
for batch in tqdm(loader, 'collecting activations'):
device = next(model.parameters()).device
with torch.autocast(device_type=device.type):
# with torch.autocast(device_type=device.type):
# FIXME for some reason autocast isn't converting the inputs
batch = {k: v.to(device) if isinstance(v, Tensor) else v for k, v in batch.items()}
if layers is not None:
with TraceDict(model, layers, retain_grad=False, detach=True, clone=True) as trace:
out = model(**batch, use_cache=False, output_hidden_states=True, return_dict=True)
else:
# FIXME for some reason autocast isn't converting the inputs
batch = {k: v.to(device) if isinstance(v, Tensor) else v for k, v in batch.items()}
if layers is not None:
with TraceDict(model, layers, retain_grad=False, detach=True) as trace:
out = model(**batch, use_cache=False, output_hidden_states=True, return_dict=True)
trace = None
else:
out = model(**batch, use_cache=False, output_hidden_states=True, return_dict=True)
trace = None
o = postprocess_result(batch, trace, out, model, act_groups=act_groups)
# copy to avoid memory leaks
@@ -116,12 +119,26 @@ def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers = [
yield o
def dataset_hash(**kwargs):
def output_dataset_hash(**kwargs):
    """Return a hash suffix identifying one activation-collection run.

    Some arguments do not hash well when handed to ``Hasher.hash`` directly,
    so they are first swapped for a cheap, stable stand-in:

    - ``Dataset``                  -> its ``_fingerprint``
    - ``DataLoader``               -> ``loader.dataset._fingerprint``
      (assumes the wrapped dataset is a ``Dataset``; anything else raises
      ``AttributeError``)
    - ``PreTrainedModel``          -> ``model.config.name_or_path``
    - ``PreTrainedTokenizerBase``  -> not supported yet

    Every other value is passed through to ``Hasher.hash`` unchanged. The
    qualified name of ``generate_batches`` is always stored under the
    ``"func"`` key, overriding any caller-supplied ``func``.

    Args:
        **kwargs: everything that should influence the hash.

    Returns:
        Whatever ``Hasher.hash`` returns for the normalised mapping
        (presumably a hex digest string - confirm against ``Hasher``).

    Raises:
        NotImplementedError: if any value is a ``PreTrainedTokenizerBase``.
    """
    func = generate_batches
    # Work on a separate dict instead of rewriting ``kwargs`` while looping
    # over it. Key order matches the original: caller keys first, then "func".
    parts = {**kwargs, "func": f"{func.__module__}.{func.__name__}"}

    for key, value in list(parts.items()):
        if isinstance(value, Dataset):
            parts[key] = value._fingerprint
        elif isinstance(value, DataLoader):
            # assume the loader wraps a Dataset
            parts[key] = value.dataset._fingerprint
        elif isinstance(value, PreTrainedTokenizerBase):
            raise NotImplementedError("hashing tokenizers not implemented")
        elif isinstance(value, PreTrainedModel):
            # Public accessor on PretrainedConfig; yields the same value as
            # the private ``_name_or_path`` without relying on it.
            parts[key] = value.config.name_or_path

    return Hasher.hash(parts)
def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_name='', layers: Union[List[str], Dict[str, List[str]]]=[], dataset_dir=default_output_folder, writer_batch_size=1, postprocess_result=default_postprocess_result, outfile: Optional[Path] = None) -> Dataset:
def activation_store(loader: DataLoader, model: PreTrainedModel, dataset_name='', layers: Union[List[str], Dict[str, List[str]]]=[], dataset_dir=default_output_folder, writer_batch_size=1, postprocess_result=default_postprocess_result, outfile: Optional[Path] = None) -> Dataset:
"""
Collect activations from a model and store them in a dataset
@@ -142,7 +159,8 @@ def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_na
f = activation_store(loader, model, layers=['transformer.h'])
Dataset.from_parquet(f).with_format("torch")
"""
hash = dataset_hash(generate_batches=generate_batches, loader=loader, model=model)
# FIXME: I think this is the problem. Instead of using a naive hash, I will need to custom-hash some key parts: model_name, dataset_name, layers, etc.
hash = output_dataset_hash(generate_batches=generate_batches, loader=loader, model=model)
if outfile is None:
outdir = Path(tempfile.mkdtemp(prefix='activation_store'))