diff --git a/activation_store/collect.py b/activation_store/collect.py index 4393fa9..b42c26f 100644 --- a/activation_store/collect.py +++ b/activation_store/collect.py @@ -37,6 +37,8 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu # batch must be first, also the writer supports float16 so lets use that output.hidden_states = rearrange(list(output.hidden_states), 'l b t h -> b l t h')[:, :, token_index].to(dtype) + output.logits = output.logits[:, token_index].to(dtype) + o = dict(**acts, **output) if ('attention_mask' in input) and not last_token: o['attention_mask'] = input['attention_mask']