Also trim logits, as they are large

This commit is contained in:
wassname
2025-03-12 16:53:11 +08:00
parent 6fed0032f2
commit 9f21a00862
+2
View File
@@ -37,6 +37,8 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu
# batch must be first; also, the writer supports float16, so let's use that
output.hidden_states = rearrange(list(output.hidden_states), 'l b t h -> b l t h')[:, :, token_index].to(dtype)
output.logits = output.logits[:, token_index].to(dtype)
o = dict(**acts, **output)
if ('attention_mask' in input) and not last_token:
o['attention_mask'] = input['attention_mask']