mirror of
https://github.com/wassname/activation_store.git
synced 2026-06-27 16:44:07 +08:00
wip
This commit is contained in:
@@ -36,11 +36,12 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu
|
||||
aas = [v.output[0] if isinstance(v.output, tuple) else v.output for k, v in trace.items() if k in group]
|
||||
assert len(aas) > 0, f"no activations found for {group}"
|
||||
aas = torch.stack([a[:, token_index].to(dtype) for a in aas], dim=1)
|
||||
acts[k] = aas
|
||||
acts[f'acts-{k}'] = aas
|
||||
else:
|
||||
acts = {f'act-{k}':
|
||||
v.output[0][:, token_index].to(dtype) if isinstance(v.output, tuple) else v.output[:, token_index]
|
||||
v.output[0] if isinstance(v.output, tuple) else v.output
|
||||
for k, v in trace.items()}
|
||||
acts = {k: v[:, token_index].to(dtype) for k, v in acts.items() if v is not None}
|
||||
del trace
|
||||
|
||||
# batch must be first, also the writer supports float16 so lets use that
|
||||
|
||||
Reference in New Issue
Block a user