mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-28 16:20:34 +08:00
[fix] Add drop_token_type to use galactica
This commit is contained in:
@@ -31,6 +31,7 @@ class DataCollatorForPairRank:
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
max_length: Optional[int] = None
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
drop_token_type: bool = False
|
||||
|
||||
def __call__(self, features):
|
||||
|
||||
@@ -51,6 +52,8 @@ class DataCollatorForPairRank:
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors="pt",
|
||||
)
|
||||
if self.drop_token_type:
|
||||
batch.pop('token_type_ids')
|
||||
# batch = {k: v.view(batch_size, self.num_choices, -1) for k, v in batch.items()}
|
||||
return batch
|
||||
|
||||
|
||||
Reference in New Issue
Block a user