mirror of
https://github.com/wassname/activation_store.git
synced 2026-07-05 17:52:10 +08:00
tidy
This commit is contained in:
+13
-13
@@ -1,4 +1,4 @@
|
||||
from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizerBase
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from tqdm.auto import tqdm
|
||||
@@ -8,17 +8,14 @@ from loguru import logger
|
||||
from pathlib import Path
|
||||
from baukit.nethook import TraceDict, recursive_copy
|
||||
from einops import rearrange
|
||||
from datasets.arrow_writer import ArrowWriter, ParquetWriter
|
||||
from datasets.arrow_writer import ParquetWriter
|
||||
from datasets.fingerprint import Hasher
|
||||
from transformers.modeling_outputs import ModelOutput
|
||||
from activation_store.helpers.torch import clear_mem
|
||||
from typing import Dict, Generator, List, Union, Optional
|
||||
from torch import Tensor
|
||||
import tempfile
|
||||
import gc
|
||||
import os
|
||||
|
||||
default_output_folder = (Path(__file__).parent.parent / "outputs").resolve()
|
||||
import inspect
|
||||
|
||||
@torch.no_grad()
|
||||
def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutput, model: PreTrainedModel, act_groups=Optional[Dict[str,List[str]]], last_token=True, dtype=torch.float16, save_hidden_states=True) -> Dict[str, Tensor]:
|
||||
@@ -127,24 +124,27 @@ def output_dataset_hash(**kwargs):
|
||||
|
||||
for k,v in kwargs.items():
|
||||
if isinstance(v, Dataset):
|
||||
kwargs[k] = v._fingerprint
|
||||
kwargs[k] = f"Dataset_{v._fingerprint}_{len(v)}"
|
||||
elif isinstance(v, DataLoader):
|
||||
kwargs[k] = v.dataset._fingerprint # assume this is Dataset
|
||||
kwargs[k] = f"DataLoader.dataset_{v.dataset._fingerprint}_{len(v)}_{v.batch_size}"
|
||||
elif isinstance(v, PreTrainedTokenizerBase):
|
||||
raise NotImplementedError("hashing tokenizers not implemented")
|
||||
elif isinstance(v, PreTrainedModel):
|
||||
kwargs[k] = v.config._name_or_path # PretrainedConfig
|
||||
kwargs[k] = f"PreTrainedModel_{v.config._name_or_path}" # PretrainedConfig
|
||||
elif inspect.isfunction(v):
|
||||
kwargs[k] = "Function: %s.%s" % (v.__module__, v.__name__)
|
||||
logger.debug(f"hashing {kwargs}")
|
||||
suffix = Hasher.hash(kwargs)
|
||||
return suffix
|
||||
|
||||
|
||||
def activation_store(loader: DataLoader, model: PreTrainedModel, dataset_name='', layers: Union[List[str], Dict[str, List[str]]]=[], dataset_dir=default_output_folder, writer_batch_size=1, postprocess_result=default_postprocess_result, outfile: Optional[Path] = None) -> Dataset:
|
||||
def activation_store(loader: DataLoader, model: PreTrainedModel, dataset_name='', layers: Union[List[str], Dict[str, List[str]]]=[], dataset_dir=None, writer_batch_size=1, postprocess_result=default_postprocess_result, outfile: Optional[Path] = None) -> Dataset:
|
||||
"""
|
||||
Collect activations from a model and store them in a dataset
|
||||
|
||||
Args:
|
||||
- loader: DataLoader
|
||||
- model: AutoModelForCausalLM
|
||||
- model: PreTrainedModel
|
||||
- dataset_name: str
|
||||
- layers:
|
||||
- List[str] selected from `model.named_modules()`
|
||||
@@ -160,11 +160,11 @@ def activation_store(loader: DataLoader, model: PreTrainedModel, dataset_name=''
|
||||
Dataset.from_parquet(f).with_format("torch")
|
||||
"""
|
||||
# FIXME I think this is the problem, instead of using a naive hash I will need to custom hash some key parts, model_name, dataset_name, layers, etc
|
||||
hash = output_dataset_hash(generate_batches=generate_batches, loader=loader, model=model)
|
||||
hash = output_dataset_hash(generate_batches=generate_batches, loader=loader, model=model, layers=layers, postprocess_result=postprocess_result)
|
||||
|
||||
if outfile is None:
|
||||
outdir = Path(tempfile.mkdtemp(prefix='activation_store'))
|
||||
outfile = outdir / f"ds_{dataset_name}_{hash}.parquet"
|
||||
outfile = outdir / f"ds_act_{dataset_name}_{hash}.parquet"
|
||||
outfile.parent.mkdir(exist_ok=True, parents=True)
|
||||
logger.info(f"creating dataset {outfile}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user