handle 0d tensors like loss

This commit is contained in:
wassname
2025-03-12 14:06:40 +08:00
parent 237b8217d5
commit 756b653913
+6 -2
View File
@@ -42,8 +42,12 @@ def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutpu
o['attention_mask'] = input['attention_mask']
if 'label' in input:
o['label'] = input['label']
input = output = acts = None
# convert any 0-d tensors (e.g. the loss) to 1-d by repeating them along the batch dimension
for k, v in o.items():
if v.dim() == 0:
bs = input['input_ids'].shape[0]
o[k] = v.repeat(bs)
return o
@@ -105,7 +109,7 @@ def activation_store(loader: DataLoader, model: AutoModelForCausalLM, dataset_na
for bo in iterator:
bs = len(next(iter(bo.values())))
assert all(len(v) == bs for v in bo.values()), "must return Dict[str,Tensor] and all tensors with same batch size a first dimension"
assert all(len(v) == bs for v in bo.values()), "must return Dict[str,Tensor] and all tensors with same batch size as first dimension"
# or maybe better compression to `writer.write(example, key)` for each
writer.write_batch(bo)