pre commits

This commit is contained in:
Sotirios Anagnostidis
2023-02-11 10:29:56 +01:00
parent 0610865de7
commit 540a96fb0e
@@ -7,6 +7,8 @@ import torch
from torch.nn import functional as F
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
from .formatting import QA_SPECIAL_TOKENS
@dataclass
class DialogueDataCollator:
@@ -28,7 +30,7 @@ class DialogueDataCollator:
# Add a way for the model to terminate generation
# When we predict the start of a new expected question, we want to be able to stop generation
messages.append(self.tokenizer.eos_token)
messages.append(QA_SPECIAL_TOKENS["Question"])
flatten_message = self.tokenizer(
"".join(messages),
@@ -101,7 +103,7 @@ class TrainDialogueDataCollator:
# Add a way for the model to terminate generation
# When we predict the start of a new expected question, we want to be able to stop generation
messages.append(self.tokenizer.eos_token)
messages.append(QA_SPECIAL_TOKENS["Question"])
flatten_message = self.tokenizer(
"".join(messages),