try stacking acts

This commit is contained in:
wassname
2025-03-14 16:18:04 +08:00
parent 9f21a00862
commit 9d63a74fc6
+7 -3
View File
@@ -29,9 +29,13 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu
# Baukit records the literal layer output, whose structure varies by model. Here we assume that either the output itself or, when it is a tuple, its first element holds the activations we want.
# These are usually shaped [b, t, h], but that depends on the model.
acts = {f'act-{k}':
v.output[0][:, token_index].to(dtype) if isinstance(v.output, tuple) else v.output
for k, v in trace.items()}
try:
acts = {'acts': torch.stack([v.output[0][:, token_index].to(dtype) if isinstance(v.output, tuple) else v.output[:, token_index] for k, v in trace.items()], dim=1)}
except Exception as e:
logger.error(f"failed to stack activations: {e}")
acts = {f'act-{k}':
v.output[0][:, token_index].to(dtype) if isinstance(v.output, tuple) else v.output[:, token_index]
for k, v in trace.items()}
del trace
# batch must be first; the writer also supports float16, so let's use that