mirror of
https://github.com/wassname/activation_store.git
synced 2026-06-27 16:44:07 +08:00
misc
This commit is contained in:
@@ -72,7 +72,7 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu
|
||||
return o
|
||||
|
||||
|
||||
@torch.no_grad
|
||||
@torch.no_grad()
|
||||
def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers = [], postprocess_result=default_postprocess_result) -> Generator[Dict[str, Tensor], None, None]:
|
||||
"""
|
||||
Collect activations from a model
|
||||
@@ -109,7 +109,7 @@ def generate_batches(loader: DataLoader, model: AutoModelForCausalLM, layers = [
|
||||
o = postprocess_result(batch, trace, out, model, act_groups=act_groups)
|
||||
|
||||
# copy to avoid memory leaks
|
||||
o = {k: v.to('cpu') if isinstance(v, Tensor) else v for k, v in o.items()}
|
||||
o = {k: v.detach().cpu().contiguous().clone() if isinstance(v, Tensor) else v for k, v in o.items()}
|
||||
o = recursive_copy(o)
|
||||
out = trace = batch = None
|
||||
clear_mem()
|
||||
@@ -163,7 +163,7 @@ def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_na
|
||||
# or maybe better compression to `writer.write(example, key)` for each
|
||||
writer.write_batch(bo)
|
||||
del bo
|
||||
gc.collect()
|
||||
clear_mem()
|
||||
writer.finalize()
|
||||
writer.close()
|
||||
|
||||
|
||||
@@ -4,5 +4,7 @@ import gc
|
||||
|
||||
def clear_mem():
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
Reference in New Issue
Block a user