From 5af1bc748876506a2e7ed0b3b7af588e27aa70f9 Mon Sep 17 00:00:00 2001 From: wassname <1103714+wassname@users.noreply.github.com> Date: Sat, 15 Mar 2025 11:52:53 +0800 Subject: [PATCH] misc --- activation_store/collect.py | 6 +++--- activation_store/helpers/torch.py | 6 ++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/activation_store/collect.py b/activation_store/collect.py index 5cb54e2..5e8488b 100644 --- a/activation_store/collect.py +++ b/activation_store/collect.py @@ -72,7 +72,7 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu return o -@torch.no_grad +@torch.no_grad() def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers = [], postprocess_result=default_postprocess_result) -> Generator[Dict[str, Tensor], None, None]: """ Collect activations from a model @@ -109,7 +109,7 @@ def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers = [ o = postprocess_result(batch, trace, out, model, act_groups=act_groups) # copy to avoid memory leaks - o = {k: v.to('cpu') if isinstance(v, Tensor) else v for k, v in o.items()} + o = {k: v.detach().cpu().contiguous().clone() if isinstance(v, Tensor) else v for k, v in o.items()} o = recursive_copy(o) out = trace = batch = None clear_mem() @@ -163,7 +163,7 @@ def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_na # or maybe better compression to `writer.write(example, key)` for each writer.write_batch(bo) del bo - gc.collect() + clear_mem() writer.finalize() writer.close() diff --git a/activation_store/helpers/torch.py b/activation_store/helpers/torch.py index e649b37..33ba1f6 100644 --- a/activation_store/helpers/torch.py +++ b/activation_store/helpers/torch.py @@ -4,5 +4,7 @@ import gc def clear_mem(): gc.collect() - torch.cuda.empty_cache() - gc.collect() + if torch.cuda.is_available(): + torch.cuda.synchronize() + torch.cuda.empty_cache() + gc.collect()