This commit is contained in:
wassname
2025-03-21 16:36:17 +08:00
parent 37b1123fb0
commit ea2d23bcd4
+13 -13
View File
@@ -1,4 +1,4 @@
from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase
from transformers import PreTrainedModel, PreTrainedTokenizerBase
import torch
from datasets import Dataset
from tqdm.auto import tqdm
@@ -8,17 +8,14 @@ from loguru import logger
from pathlib import Path
from baukit.nethook import TraceDict, recursive_copy
from einops import rearrange
from datasets.arrow_writer import ArrowWriter, ParquetWriter
from datasets.arrow_writer import ParquetWriter
from datasets.fingerprint import Hasher
from transformers.modeling_outputs import ModelOutput
from activation_store.helpers.torch import clear_mem
from typing import Dict, Generator, List, Union, Optional
from torch import Tensor
import tempfile
import gc
import os
default_output_folder = (Path(__file__).parent.parent / "outputs").resolve()
import inspect
@torch.no_grad()
def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutput, model: PreTrainedModel, act_groups=Optional[Dict[str,List[str]]], last_token=True, dtype=torch.float16, save_hidden_states=True) -> Dict[str, Tensor]:
@@ -127,24 +124,27 @@ def output_dataset_hash(**kwargs):
for k,v in kwargs.items():
if isinstance(v, Dataset):
kwargs[k] = v._fingerprint
kwargs[k] = f"Dataset_{v._fingerprint}_{len(v)}"
elif isinstance(v, DataLoader):
kwargs[k] = v.dataset._fingerprint # assume this is Dataset
kwargs[k] = f"DataLoader.dataset_{v.dataset._fingerprint}_{len(v)}_{v.batch_size}"
elif isinstance(v, PreTrainedTokenizerBase):
raise NotImplementedError("hashing tokenizers not implemented")
elif isinstance(v, PreTrainedModel):
kwargs[k] = v.config._name_or_path # PretrainedConfig
kwargs[k] = f"PreTrainedModel_{v.config._name_or_path}" # PretrainedConfig
elif inspect.isfunction(v):
kwargs[k] = "Function: %s.%s" % (v.__module__, v.__name__)
logger.debug(f"hashing {kwargs}")
suffix = Hasher.hash(kwargs)
return suffix
def activation_store(loader: DataLoader, model: PreTrainedModel, dataset_name='', layers: Union[List[str], Dict[str, List[str]]]=[], dataset_dir=default_output_folder, writer_batch_size=1, postprocess_result=default_postprocess_result, outfile: Optional[Path] = None) -> Dataset:
def activation_store(loader: DataLoader, model: PreTrainedModel, dataset_name='', layers: Union[List[str], Dict[str, List[str]]]=[], dataset_dir=None, writer_batch_size=1, postprocess_result=default_postprocess_result, outfile: Optional[Path] = None) -> Dataset:
"""
Collect activations from a model and store them in a dataset
Args:
- loader: DataLoader
- model: AutoModelForCausalLM
- model: PreTrainedModel
- dataset_name: str
- layers:
- List[str] selected from `model.named_modules()`
@@ -160,11 +160,11 @@ def activation_store(loader: DataLoader, model: PreTrainedModel, dataset_name=''
Dataset.from_parquet(f).with_format("torch")
"""
# FIXME I think this is the problem, instead of using a naive hash I will need to custom hash some key parts, model_name, dataset_name, layers, etc
hash = output_dataset_hash(generate_batches=generate_batches, loader=loader, model=model)
hash = output_dataset_hash(generate_batches=generate_batches, loader=loader, model=model, layers=layers, postprocess_result=postprocess_result)
if outfile is None:
outdir = Path(tempfile.mkdtemp(prefix='activation_store'))
outfile = outdir / f"ds_{dataset_name}_{hash}.parquet"
outfile = outdir / f"ds_act_{dataset_name}_{hash}.parquet"
outfile.parent.mkdir(exist_ok=True, parents=True)
logger.info(f"creating dataset {outfile}")