mirror of
https://github.com/wassname/activation_store.git
synced 2026-06-27 16:28:55 +08:00
also trim logits as they're large
This commit is contained in:
@@ -37,6 +37,8 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu
|
||||
# batch must be first; the writer also supports float16, so let's use that
|
||||
output.hidden_states = rearrange(list(output.hidden_states), 'l b t h -> b l t h')[:, :, token_index].to(dtype)
|
||||
|
||||
output.logits = output.logits[:, token_index].to(dtype)
|
||||
|
||||
o = dict(**acts, **output)
|
||||
if ('attention_mask' in input) and not last_token:
|
||||
o['attention_mask'] = input['attention_mask']
|
||||
|
||||
Reference in New Issue
Block a user