diff --git a/model/reward/instructor/rank_datasets.py b/model/reward/instructor/rank_datasets.py index 128baafe..41740dcf 100644 --- a/model/reward/instructor/rank_datasets.py +++ b/model/reward/instructor/rank_datasets.py @@ -31,6 +31,7 @@ class DataCollatorForPairRank: padding: Union[bool, str, PaddingStrategy] = True max_length: Optional[int] = None pad_to_multiple_of: Optional[int] = None + drop_token_type: bool = False def __call__(self, features): @@ -51,6 +52,8 @@ class DataCollatorForPairRank: pad_to_multiple_of=self.pad_to_multiple_of, return_tensors="pt", ) + if self.drop_token_type: + batch.pop('token_type_ids') # batch = {k: v.view(batch_size, self.num_choices, -1) for k, v in batch.items()} return batch