mirror of
https://github.com/wassname/activation_store.git
synced 2026-06-27 16:44:07 +08:00
handle 0d tensors like loss
This commit is contained in:
@@ -42,8 +42,12 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu
|
||||
o['attention_mask'] = input['attention_mask']
|
||||
if 'label' in input:
|
||||
o['label'] = input['label']
|
||||
input = output = acts = None
|
||||
|
||||
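# Convert any 0-d (scalar) tensors, such as the loss, to 1-d by repeating them along the batch dimension.
# NOTE(review): `input` is rebound to None just above, so the `input['input_ids']` lookup below looks like it will fail when a 0-d tensor is present — confirm, and read the batch size before that reset.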
|
||||
for k, v in o.items():
|
||||
if v.dim() == 0:
|
||||
bs = input['input_ids'].shape[0]
|
||||
o[k] = v.repeat(bs)
|
||||
return o
|
||||
|
||||
|
||||
@@ -105,7 +109,7 @@ def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_na
|
||||
for bo in iterator:
|
||||
|
||||
bs = len(next(iter(bo.values())))
|
||||
assert all(len(v) == bs for v in bo.values()), "must return Dict[str,Tensor] and all tensors with same batch size a first dimension"
|
||||
assert all(len(v) == bs for v in bo.values()), "must return Dict[str,Tensor] and all tensors with same batch size as first dimension"
|
||||
|
||||
# or maybe better compression to `writer.write(example, key)` for each
|
||||
writer.write_batch(bo)
|
||||
|
||||
Reference in New Issue
Block a user