From 44ed44e05d837506f4c5e9385c3240ee53e15d70 Mon Sep 17 00:00:00 2001 From: Sotirios Anagnostidis Date: Sat, 11 Feb 2023 10:33:25 +0100 Subject: [PATCH] deactivtae samples mixing by default --- model/supervised_finetuning/configs/config.yaml | 1 + model/supervised_finetuning/utils.py | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/model/supervised_finetuning/configs/config.yaml b/model/supervised_finetuning/configs/config.yaml index 91a38596..191a7391 100644 --- a/model/supervised_finetuning/configs/config.yaml +++ b/model/supervised_finetuning/configs/config.yaml @@ -48,6 +48,7 @@ defaults: poly_eps: 1.0 fuse_gelu: true log_wandb: true + samples_mixing: false # uses collator that mixes samples in the batch to create a single sample with possible multiple tasks within galactica-125m: learning_rate: 5e-5 diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py index 43377bc9..0fd9ef00 100644 --- a/model/supervised_finetuning/utils.py +++ b/model/supervised_finetuning/utils.py @@ -238,7 +238,11 @@ def get_dataset(conf, tokenizer): train = ConcatDataset(train_datasets) collate_fn = DialogueDataCollator(tokenizer, max_length=conf.max_length) - train_collate_fn = TrainDialogueDataCollator(tokenizer, max_length=conf.max_length) + + train_collate_fn = ( + TrainDialogueDataCollator(tokenizer, max_length=conf.max_length) if conf.samples_mixing else collate_fn + ) + return train, evals, collate_fn, train_collate_fn