diff --git a/model/supervised_finetuning/custom_datasets/dialogue_collator.py b/model/supervised_finetuning/custom_datasets/dialogue_collator.py index 0a0b7a5a..e43d6d8e 100644 --- a/model/supervised_finetuning/custom_datasets/dialogue_collator.py +++ b/model/supervised_finetuning/custom_datasets/dialogue_collator.py @@ -7,6 +7,8 @@ import torch from torch.nn import functional as F from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase +from .formatting import QA_SPECIAL_TOKENS + @dataclass class DialogueDataCollator: @@ -28,7 +30,7 @@ class DialogueDataCollator: # Add a way for the model to terminate generation # When we predict the start of a new expected question, we want to be able to stop generation - messages.append(self.tokenizer.eos_token) + messages.append(QA_SPECIAL_TOKENS["Question"]) flatten_message = self.tokenizer( "".join(messages), @@ -101,7 +103,7 @@ class TrainDialogueDataCollator: # Add a way for the model to terminate generation # When we predict the start of a new expected question, we want to be able to stop generation - messages.append(self.tokenizer.eos_token) + messages.append(QA_SPECIAL_TOKENS["Question"]) flatten_message = self.tokenizer( "".join(messages),