From 9f21a008625e871f6e1e1b04be7cbbca4e307d80 Mon Sep 17 00:00:00 2001 From: wassname <1103714+wassname@users.noreply.github.com> Date: Wed, 12 Mar 2025 16:53:11 +0800 Subject: [PATCH] also trim logits as it's large --- activation_store/collect.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/activation_store/collect.py b/activation_store/collect.py index 4393fa9..b42c26f 100644 --- a/activation_store/collect.py +++ b/activation_store/collect.py @@ -37,6 +37,8 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu # batch must be first, also the writer supports float16 so lets use that output.hidden_states = rearrange(list(output.hidden_states), 'l b t h -> b l t h')[:, :, token_index].to(dtype) + output.logits = output.logits[:, token_index].to(dtype) + o = dict(**acts, **output) if ('attention_mask' in input) and not last_token: o['attention_mask'] = input['attention_mask']