diff --git a/activation_store/collect.py b/activation_store/collect.py index b42c26f..52fd899 100644 --- a/activation_store/collect.py +++ b/activation_store/collect.py @@ -29,9 +29,13 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu # Baukit records the literal layer output, which varies by model. Here we assume that the output or the first part are activations we want # usually [b, t, h] but it depends on the model - acts = {f'act-{k}': - v.output[0][:, token_index].to(dtype) if isinstance(v.output, tuple) else v.output - for k, v in trace.items()} + try: + acts = {'acts': torch.stack([v.output[0][:, token_index].to(dtype) if isinstance(v.output, tuple) else v.output[:, token_index] for k, v in trace.items()], dim=1)} + except Exception as e: + logger.error(f"failed to stack activations: {e}") + acts = {f'act-{k}': + v.output[0][:, token_index].to(dtype) if isinstance(v.output, tuple) else v.output[:, token_index] + for k, v in trace.items()} del trace # batch must be first, also the writer supports float16 so lets use that