This commit is contained in:
wassname
2025-03-15 11:52:53 +08:00
parent 23a30319a5
commit 5af1bc7488
2 changed files with 7 additions and 5 deletions
+3 -3
View File
@@ -72,7 +72,7 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu
return o
@torch.no_grad
@torch.no_grad()
def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers = [], postprocess_result=default_postprocess_result) -> Generator[Dict[str, Tensor], None, None]:
"""
Collect activations from a model
@@ -109,7 +109,7 @@ def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers = [
o = postprocess_result(batch, trace, out, model, act_groups=act_groups)
# copy to avoid memory leaks
o = {k: v.to('cpu') if isinstance(v, Tensor) else v for k, v in o.items()}
o = {k: v.detach().cpu().contiguous().clone() if isinstance(v, Tensor) else v for k, v in o.items()}
o = recursive_copy(o)
out = trace = batch = None
clear_mem()
@@ -163,7 +163,7 @@ def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_na
# or maybe better compression to `writer.write(example, key)` for each
writer.write_batch(bo)
del bo
gc.collect()
clear_mem()
writer.finalize()
writer.close()
+4 -2
View File
@@ -4,5 +4,7 @@ import gc
def clear_mem():
gc.collect()
torch.cuda.empty_cache()
gc.collect()
if torch.cuda.is_available():
torch.cuda.synchronize()
torch.cuda.empty_cache()
gc.collect()