[fix] Add drop_token_type to use galactica

This commit is contained in:
theblackcat102
2022-12-31 09:42:49 +00:00
parent 3a10f1024a
commit d2572d0323
+3
View File
@@ -31,6 +31,7 @@ class DataCollatorForPairRank:
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
drop_token_type: bool = False
def __call__(self, features):
@@ -51,6 +52,8 @@ class DataCollatorForPairRank:
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors="pt",
)
if self.drop_token_type:
batch.pop('token_type_ids')
# batch = {k: v.view(batch_size, self.num_choices, -1) for k, v in batch.items()}
return batch