diff --git a/model/supervised_finetuning/configs/config.yaml b/model/supervised_finetuning/configs/config.yaml index 1d196fb2..925b6dda 100644 --- a/model/supervised_finetuning/configs/config.yaml +++ b/model/supervised_finetuning/configs/config.yaml @@ -16,29 +16,32 @@ defaults: eval_accumulation_steps: freeze_layer: datasets: - - webgpt - - prompt_dialogue - - squad_v2 - - adversarial_qa - - trivia_qa_nocontext - - xsum - - cnn_dailymail - - prompt_dialogue - - multi_news - - scitldr - - soda - - joke - - gsm8k - - dive_mt - - wmt2019_zh-en - - wmt2019_ru-en - - wmt2019_de-en - - ted_trans_nl-en - - ted_trans_de-ja - - instruct_tuning - - wmt2019_de-en - - samsum - - soda_dialogue + - webgpt: + fraction : 0.001 + - prompt_dialogue: + size : 10 + - squad_v2: + size : 10 + # - adversarial_qa + # - trivia_qa_nocontext + # - xsum + # - cnn_dailymail + # - prompt_dialogue + # - multi_news + # - scitldr + # - soda + # - joke + # - gsm8k + # - dive_mt + # - wmt2019_zh-en + # - wmt2019_ru-en + # - wmt2019_de-en + # - ted_trans_nl-en + # - ted_trans_de-ja + # - instruct_tuning + # - wmt2019_de-en + # - samsum + # - soda_dialogue cache_dir: .cache loss_fn: CrossEntropyLoss eval_size: diff --git a/model/supervised_finetuning/requirements.txt b/model/supervised_finetuning/requirements.txt index 8f8cc63c..efe6df89 100644 --- a/model/supervised_finetuning/requirements.txt +++ b/model/supervised_finetuning/requirements.txt @@ -4,7 +4,7 @@ datasets==2.8.0 deepspeed==0.7.7 evaluate==0.4.0 gdown -mpi4py==3.1.4 +# mpi4py==3.1.4 nltk==3.8.1 numpy>=1.22.4 py7zr diff --git a/model/supervised_finetuning/trainer.py b/model/supervised_finetuning/trainer.py index 0acb10dd..06ae39db 100644 --- a/model/supervised_finetuning/trainer.py +++ b/model/supervised_finetuning/trainer.py @@ -8,7 +8,8 @@ import torch from torch import nn from transformers import PreTrainedModel, Trainer, TrainingArguments from transformers.training_args import OptimizerNames -from utils import get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls +from utils import (build_train_sampler, get_dataset, get_loss, get_metrics, + get_model, get_tokenizer, read_yamls) def compute_metrics(eval_pred, preprocess_fns, metrics): @@ -30,6 +31,7 @@ class SFTTrainer(Trainer): self, model: Union[PreTrainedModel, nn.Module] = None, args: TrainingArguments = None, + sampler: torch.utils.data.sampler.Sampler = None, loss_function: str = "CrossEntropyLoss", poly_eps: float = 1.0, **kwargs, @@ -38,6 +40,7 @@ class SFTTrainer(Trainer): # By default CrossEntropyLoss ignores padding_index -100, but just in case use our own loss_fct self.loss_fct = get_loss(loss_function, poly_eps) + self.sampler = sampler def compute_loss(self, model, inputs, return_outputs=False): labels_mask = inputs.pop("label_masks") @@ -88,6 +91,22 @@ class SFTTrainer(Trainer): return (loss, logits, labels) + def get_train_dataloader(self): + if self.sampler is None: + torch.utils.data.DataLoader( + self.train_dataset, + batch_size=self.args.per_device_train_batch_size, + shuffle=True, + collate_fn=self.data_collator + ) + else: + return torch.utils.data.DataLoader( + self.train_dataset, + batch_size=self.args.per_device_train_batch_size, + sampler=self.sampler, + collate_fn=self.data_collator + ) + def _strtobool(x): return bool(strtobool(x)) @@ -141,8 +160,8 @@ if __name__ == "__main__": model = get_model(training_conf, tokenizer) train, evals, collate_fn = get_dataset(training_conf, tokenizer) + sampler = build_train_sampler(training_conf, train.datasets) metrics, preprocess_fns = get_metrics(training_conf, tokenizer) - optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else OptimizerNames.ADAMW_HF if training_conf.quantization: @@ -159,7 +178,7 @@ if __name__ == "__main__": learning_rate=float(training_conf.learning_rate), deepspeed="configs/zero_config.json" if training_conf.deepspeed else None, optim=optimizer, - fp16=True, + # fp16=True, local_rank=training_conf.local_rank, gradient_checkpointing=training_conf.gradient_checkpointing, gradient_accumulation_steps=training_conf.gradient_accumulation_steps, @@ -177,19 +196,20 @@ if __name__ == "__main__": ) assert len(evals) > 0 - if not training_conf.deepspeed or training_conf.local_rank == 0: import wandb wandb.init( project="supervised-finetuning", - entity=training_conf.wandb_entity, + # entity=training_conf.wandb_entity, + entity="maw501", name=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned", ) - + trainer = SFTTrainer( - model, - args, + model=model, + args=args, + sampler=sampler, loss_function=training_conf.loss_fn, poly_eps=training_conf.poly_eps, train_dataset=train, diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py index f7a0ab15..6d6b271a 100644 --- a/model/supervised_finetuning/utils.py +++ b/model/supervised_finetuning/utils.py @@ -2,9 +2,10 @@ from pathlib import Path import evaluate +import random # import nltk -# import numpy as np +import numpy as np import transformers import yaml from custom_datasets import get_one_dataset @@ -14,6 +15,35 @@ from losses import CrossEntropyLoss, PolyLoss from models import freeze_top_n_layers, get_specific_model from sklearn.model_selection import train_test_split from torch.utils.data import ConcatDataset, Subset +from torch.utils.data.sampler import Sampler + + +class ClassSampler(Sampler): + """Sampler which returns a fixed number of samples per class, per epoch""" + def __init__(self, class_labels, class_sizes): + self.class_labels = class_labels + self.class_sizes = class_sizes + + def __iter__(self): + out = [] + for i, _class in enumerate(np.unique(self.class_labels)): + class_idx = np.argwhere(self.class_labels == _class).flatten() + sampled_idx = random.sample(list(class_idx), int(self.class_sizes[i])) + out.extend(sampled_idx) + random.shuffle(out) + return iter(out) + + def __len__(self): + return int(sum(self.class_sizes)) + + +def build_train_sampler(training_conf, datasets): + train_sizes = [len(x) for x in datasets] + fractions = get_dataset_fractions(training_conf.datasets, train_sizes) + dataset_size_per_epoch = [int(size * frac) for size, frac in zip(train_sizes, fractions)] + dataset_labels = [[i] * d for i, d in zip(range(len(dataset_size_per_epoch)), dataset_size_per_epoch)] + dataset_labels = [i for s in dataset_labels for i in s] + return ClassSampler(dataset_labels, dataset_size_per_epoch) def get_tokenizer(conf): @@ -115,10 +145,35 @@ def get_model(conf, tokenizer): return model +def get_dataset_name_from_data_config(data_config): + if isinstance(data_config, dict): + return list(data_config.keys())[0] + return data_config + + +def get_dataset_fractions(conf, dataset_sizes): + fractions = [] + for i, data_config in enumerate(conf): + dataset_name = get_dataset_name_from_data_config(data_config) + if isinstance(data_config, dict): + if "fraction" in data_config[dataset_name]: + fractions.append(min(1, data_config[dataset_name]["fraction"])) + elif "size" in data_config[dataset_name]: + if data_config[dataset_name]["size"] > dataset_sizes[i]: + raise ValueError(f"Please specify a size smaller than number of examples ({dataset_sizes[i]})") + fractions.append(data_config[dataset_name]["size"] / dataset_sizes[i]) + else: + raise ValueError("Please specify either fraction or size in config.yaml") + else: + fractions.append(1) + return fractions + + def get_dataset(conf, tokenizer): train_datasets, evals = [], {} - for dataset_name in conf.datasets: + for data_config in conf.datasets: + dataset_name = get_dataset_name_from_data_config(data_config) train, val = get_one_dataset(conf, dataset_name) train_datasets.append(train) evals[dataset_name] = Subset(val, list(range(min(len(val), conf.eval_size)))) if conf.eval_size else val