From e2caf5365465254e8f5563f1dcbeeed9e32710af Mon Sep 17 00:00:00 2001 From: Mark Worrall Date: Wed, 8 Feb 2023 08:01:04 +0000 Subject: [PATCH 1/3] First version of single GPU sampling working --- .../supervised_finetuning/configs/config.yaml | 49 +++++++-------- model/supervised_finetuning/requirements.txt | 2 +- model/supervised_finetuning/trainer.py | 36 ++++++++--- model/supervised_finetuning/utils.py | 59 ++++++++++++++++++- 4 files changed, 112 insertions(+), 34 deletions(-) diff --git a/model/supervised_finetuning/configs/config.yaml b/model/supervised_finetuning/configs/config.yaml index 1d196fb2..925b6dda 100644 --- a/model/supervised_finetuning/configs/config.yaml +++ b/model/supervised_finetuning/configs/config.yaml @@ -16,29 +16,32 @@ defaults: eval_accumulation_steps: freeze_layer: datasets: - - webgpt - - prompt_dialogue - - squad_v2 - - adversarial_qa - - trivia_qa_nocontext - - xsum - - cnn_dailymail - - prompt_dialogue - - multi_news - - scitldr - - soda - - joke - - gsm8k - - dive_mt - - wmt2019_zh-en - - wmt2019_ru-en - - wmt2019_de-en - - ted_trans_nl-en - - ted_trans_de-ja - - instruct_tuning - - wmt2019_de-en - - samsum - - soda_dialogue + - webgpt: + fraction : 0.001 + - prompt_dialogue: + size : 10 + - squad_v2: + size : 10 + # - adversarial_qa + # - trivia_qa_nocontext + # - xsum + # - cnn_dailymail + # - prompt_dialogue + # - multi_news + # - scitldr + # - soda + # - joke + # - gsm8k + # - dive_mt + # - wmt2019_zh-en + # - wmt2019_ru-en + # - wmt2019_de-en + # - ted_trans_nl-en + # - ted_trans_de-ja + # - instruct_tuning + # - wmt2019_de-en + # - samsum + # - soda_dialogue cache_dir: .cache loss_fn: CrossEntropyLoss eval_size: diff --git a/model/supervised_finetuning/requirements.txt b/model/supervised_finetuning/requirements.txt index 8f8cc63c..efe6df89 100644 --- a/model/supervised_finetuning/requirements.txt +++ b/model/supervised_finetuning/requirements.txt @@ -4,7 +4,7 @@ datasets==2.8.0 deepspeed==0.7.7 evaluate==0.4.0 gdown -mpi4py==3.1.4 +# mpi4py==3.1.4 nltk==3.8.1 numpy>=1.22.4 py7zr diff --git a/model/supervised_finetuning/trainer.py b/model/supervised_finetuning/trainer.py index 0acb10dd..06ae39db 100644 --- a/model/supervised_finetuning/trainer.py +++ b/model/supervised_finetuning/trainer.py @@ -8,7 +8,8 @@ import torch from torch import nn from transformers import PreTrainedModel, Trainer, TrainingArguments from transformers.training_args import OptimizerNames -from utils import get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls +from utils import (build_train_sampler, get_dataset, get_loss, get_metrics, + get_model, get_tokenizer, read_yamls) def compute_metrics(eval_pred, preprocess_fns, metrics): @@ -30,6 +31,7 @@ class SFTTrainer(Trainer): self, model: Union[PreTrainedModel, nn.Module] = None, args: TrainingArguments = None, + sampler: torch.utils.data.sampler.Sampler = None, loss_function: str = "CrossEntropyLoss", poly_eps: float = 1.0, **kwargs, @@ -38,6 +40,7 @@ class SFTTrainer(Trainer): # By default CrossEntropyLoss ignores padding_index -100, but just in case use our own loss_fct self.loss_fct = get_loss(loss_function, poly_eps) + self.sampler = sampler def compute_loss(self, model, inputs, return_outputs=False): labels_mask = inputs.pop("label_masks") @@ -88,6 +91,22 @@ class SFTTrainer(Trainer): return (loss, logits, labels) + def get_train_dataloader(self): + if self.sampler is None: + torch.utils.data.DataLoader( + self.train_dataset, + batch_size=self.args.per_device_train_batch_size, + shuffle=True, + collate_fn=self.data_collator + ) + else: + return torch.utils.data.DataLoader( + self.train_dataset, + batch_size=self.args.per_device_train_batch_size, + sampler=self.sampler, + collate_fn=self.data_collator + ) + def _strtobool(x): return bool(strtobool(x)) @@ -141,8 +160,8 @@ if __name__ == "__main__": model = get_model(training_conf, tokenizer) train, evals, collate_fn = get_dataset(training_conf, tokenizer) + sampler = build_train_sampler(training_conf, train.datasets) metrics, preprocess_fns = get_metrics(training_conf, tokenizer) - optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else OptimizerNames.ADAMW_HF if training_conf.quantization: @@ -159,7 +178,7 @@ if __name__ == "__main__": learning_rate=float(training_conf.learning_rate), deepspeed="configs/zero_config.json" if training_conf.deepspeed else None, optim=optimizer, - fp16=True, + # fp16=True, local_rank=training_conf.local_rank, gradient_checkpointing=training_conf.gradient_checkpointing, gradient_accumulation_steps=training_conf.gradient_accumulation_steps, @@ -177,19 +196,20 @@ if __name__ == "__main__": ) assert len(evals) > 0 - if not training_conf.deepspeed or training_conf.local_rank == 0: import wandb wandb.init( project="supervised-finetuning", - entity=training_conf.wandb_entity, + # entity=training_conf.wandb_entity, + entity="maw501", name=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned", ) - + trainer = SFTTrainer( - model, - args, + model=model, + args=args, + sampler=sampler, loss_function=training_conf.loss_fn, poly_eps=training_conf.poly_eps, train_dataset=train, diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py index f7a0ab15..6d6b271a 100644 --- a/model/supervised_finetuning/utils.py +++ b/model/supervised_finetuning/utils.py @@ -2,9 +2,10 @@ from pathlib import Path import evaluate +import random # import nltk -# import numpy as np +import numpy as np import transformers import yaml from custom_datasets import get_one_dataset @@ -14,6 +15,35 @@ from losses import CrossEntropyLoss, PolyLoss from models import freeze_top_n_layers, get_specific_model from sklearn.model_selection import train_test_split from torch.utils.data import ConcatDataset, Subset +from torch.utils.data.sampler import Sampler + + +class ClassSampler(Sampler): + """Sampler which returns a fixed number of samples per class, per epoch""" + def __init__(self, class_labels, class_sizes): + self.class_labels = class_labels + self.class_sizes = class_sizes + + def __iter__(self): + out = [] + for i, _class in enumerate(np.unique(self.class_labels)): + class_idx = np.argwhere(self.class_labels == _class).flatten() + sampled_idx = random.sample(list(class_idx), int(self.class_sizes[i])) + out.extend(sampled_idx) + random.shuffle(out) + return iter(out) + + def __len__(self): + return int(sum(self.class_sizes)) + + +def build_train_sampler(training_conf, datasets): + train_sizes = [len(x) for x in datasets] + fractions = get_dataset_fractions(training_conf.datasets, train_sizes) + dataset_size_per_epoch = [int(size * frac) for size, frac in zip(train_sizes, fractions)] + dataset_labels = [[i] * d for i, d in zip(range(len(dataset_size_per_epoch)), dataset_size_per_epoch)] + dataset_labels = [i for s in dataset_labels for i in s] + return ClassSampler(dataset_labels, dataset_size_per_epoch) def get_tokenizer(conf): @@ -115,10 +145,35 @@ def get_model(conf, tokenizer): return model +def get_dataset_name_from_data_config(data_config): + if isinstance(data_config, dict): + return list(data_config.keys())[0] + return data_config + + +def get_dataset_fractions(conf, dataset_sizes): + fractions = [] + for i, data_config in enumerate(conf): + dataset_name = get_dataset_name_from_data_config(data_config) + if isinstance(data_config, dict): + if "fraction" in data_config[dataset_name]: + fractions.append(min(1, data_config[dataset_name]["fraction"])) + elif "size" in data_config[dataset_name]: + if data_config[dataset_name]["size"] > dataset_sizes[i]: + raise ValueError(f"Please specify a size smaller than number of examples ({dataset_sizes[i]})") + fractions.append(data_config[dataset_name]["size"] / dataset_sizes[i]) + else: + raise ValueError("Please specify either fraction or size in config.yaml") + else: + fractions.append(1) + return fractions + + def get_dataset(conf, tokenizer): train_datasets, evals = [], {} - for dataset_name in conf.datasets: + for data_config in conf.datasets: + dataset_name = get_dataset_name_from_data_config(data_config) train, val = get_one_dataset(conf, dataset_name) train_datasets.append(train) evals[dataset_name] = Subset(val, list(range(min(len(val), conf.eval_size)))) if conf.eval_size else val From 283df8ec84bae350c8c13e364807443c36ad071e Mon Sep 17 00:00:00 2001 From: Mark Worrall Date: Wed, 8 Feb 2023 20:49:25 +0000 Subject: [PATCH 2/3] Get working on multi-gpu --- model/supervised_finetuning/README.md | 128 +++++++++++------- .../supervised_finetuning/configs/config.yaml | 47 +++---- .../custom_datasets/qa_datasets.py | 5 +- .../custom_datasets/translation.py | 2 - model/supervised_finetuning/requirements.txt | 2 +- model/supervised_finetuning/trainer.py | 25 ++-- model/supervised_finetuning/utils.py | 78 +++++++---- 7 files changed, 167 insertions(+), 120 deletions(-) diff --git a/model/supervised_finetuning/README.md b/model/supervised_finetuning/README.md index d5b10e01..387e91e4 100644 --- a/model/supervised_finetuning/README.md +++ b/model/supervised_finetuning/README.md @@ -1,62 +1,18 @@ # Train using supervised examples -Requirements +## Requirements -``` -wandb -evaluate -datasets -transformers -torch -``` +`pip install -r requirements.txt` -Start training reward model +Start training SFT model ```bash -python trainer.py --configs defaults galactica-125 +python trainer.py --configs defaults galactica-125m ``` -## Dataset - -For now we only support webgpt and summary dataset from OpenAI. Once -open-asisstant dataset are available it will be added here. - -## Model - -Normally you should be able to add new models in configs/config.yml - -``` -your-model-name: - learning_rate: 2e-6 - model_name: - weight_decay: 0.01 - max_length: 812 - warmup_steps: 600 - gradient_checkpointing: false - gradient_accumulation_steps: 5 - per_device_train_batch_size: 4 - per_device_eval_batch_size: 4 -``` - -``` -python trainer.py --configs defaults your-model-name -``` - -However, if the model of your choice doesn't have pad_token, eos_token, -sep_token, you have to update utils.py `get_tokenizer` to use the right token. - -## Deepspeed support - -You can edit the configs/zero_config.json and use any stage you wish. The -current config uses zero-stage 3. For more details on how to setup the config -checkout [this page](https://www.deepspeed.ai/tutorials/zero/) - -Once you are satisfy with your deepzero config, you can add --deepspeed flag at -the end to trigger deepspeed - -``` -python trainer.py --configs defaults your-model-name --deepspeed -``` +For `wandb`: update the `entity` argument in `trainer.py`'s call to `wandb.init` +to be your weights and biases username per +[docs](https://docs.wandb.ai/ref/python/init). ## Dataset choices @@ -80,6 +36,74 @@ Currently only these languages are supported via prompt translation: ar,de,fr,en,it,nl,tr,ru,ms,ko,ja,zh ``` +## Dataset sub-sampling + +We can subsample the **training** data by passing either the `fraction` or +`size` argument in the `configs/config.yml` file. Don't forget the additional +colon ":" after the dataset name when doing this. + +Example: + +``` + datasets: + - webgpt: + fraction : 0.05 + - prompt_dialogue: + size : 500 + - adversarial_qa + - trivia_qa_nocontext +``` + +In this example, per epoch we will use: + +- A random 5% of `webgpt`; +- A random 500 examples from `prompt_dialogue`; +- All examples from datasets for which we don't specify the `fraction` or `size` + argument. + +In the above example, per epoch we'll use a different 5% from `webgpt` and a +different 500 examples from `prompt_dialogue`. + +This works with `torch.distributed`. + +## Model + +Normally you should be able to add new models in `configs/config.yml` + +``` +your-model-name: + learning_rate: 2e-6 + model_name: + weight_decay: 0.01 + max_length: 812 + warmup_steps: 600 + gradient_checkpointing: false + gradient_accumulation_steps: 5 + per_device_train_batch_size: 4 + per_device_eval_batch_size: 4 +``` + +``` +python trainer.py --configs defaults your-model-name +``` + +However, if the model of your choice doesn't have `pad_token`, `eos_token`, +`sep_token`, you have to update `get_tokenizer` in `utils.py` to use the right +token. + +## Deepspeed support + +You can edit the configs/zero_config.json and use any stage you wish. The +current config uses zero-stage 3. For more details on how to setup the config +checkout [this page](https://www.deepspeed.ai/tutorials/zero/). + +Once you are satisfy with your deepzero config, you can add --deepspeed flag at +the end to trigger deepspeed + +``` +python trainer.py --configs defaults your-model-name --deepspeed +``` + ## Results Experimental results in wandb @@ -87,7 +111,7 @@ Experimental results in wandb ## TODOS -- decide on a model +- Decide on a model - Merge utils etc with reward model - Casual Modelling for GPT-JT does not leverage the bidirectional mask for the prompt? (https://huggingface.co/togethercomputer/GPT-JT-6B-v1) diff --git a/model/supervised_finetuning/configs/config.yaml b/model/supervised_finetuning/configs/config.yaml index 925b6dda..289e49ad 100644 --- a/model/supervised_finetuning/configs/config.yaml +++ b/model/supervised_finetuning/configs/config.yaml @@ -16,32 +16,29 @@ defaults: eval_accumulation_steps: freeze_layer: datasets: - - webgpt: - fraction : 0.001 - - prompt_dialogue: - size : 10 - - squad_v2: - size : 10 - # - adversarial_qa - # - trivia_qa_nocontext - # - xsum - # - cnn_dailymail + - webgpt # - prompt_dialogue - # - multi_news - # - scitldr - # - soda - # - joke - # - gsm8k - # - dive_mt - # - wmt2019_zh-en - # - wmt2019_ru-en - # - wmt2019_de-en - # - ted_trans_nl-en - # - ted_trans_de-ja - # - instruct_tuning - # - wmt2019_de-en - # - samsum - # - soda_dialogue + - squad_v2 + - adversarial_qa + - trivia_qa_nocontext + - xsum + - cnn_dailymail + - prompt_dialogue + - multi_news + - scitldr + - soda + - joke + - gsm8k + - dive_mt + - wmt2019_zh-en + - wmt2019_ru-en + - wmt2019_de-en + - ted_trans_nl-en + - ted_trans_de-ja + - instruct_tuning + - wmt2019_de-en + - samsum + - soda_dialogue cache_dir: .cache loss_fn: CrossEntropyLoss eval_size: diff --git a/model/supervised_finetuning/custom_datasets/qa_datasets.py b/model/supervised_finetuning/custom_datasets/qa_datasets.py index 2c5c7ee2..206b679c 100644 --- a/model/supervised_finetuning/custom_datasets/qa_datasets.py +++ b/model/supervised_finetuning/custom_datasets/qa_datasets.py @@ -219,7 +219,7 @@ class SODA(Dataset): return pairs - def __init__(self, cache_dir, max_sample_size=10000, input_max_length=1024) -> None: + def __init__(self, cache_dir, input_max_length=1024) -> None: super().__init__() self.pairs = [] @@ -230,9 +230,6 @@ class SODA(Dataset): if len(prompt) < input_max_length: self.pairs.append((prompt, answer)) - if len(self.pairs) > max_sample_size: - break - def __len__(self): return len(self.pairs) diff --git a/model/supervised_finetuning/custom_datasets/translation.py b/model/supervised_finetuning/custom_datasets/translation.py index f9a71a8e..008de751 100644 --- a/model/supervised_finetuning/custom_datasets/translation.py +++ b/model/supervised_finetuning/custom_datasets/translation.py @@ -100,8 +100,6 @@ class WMT2019(TranslationPair): else: # translating in reverse direction source = random.choice(TRANSLATION_PROMPT[src]).format(row[tgt]) self.pairs.append((source, row[src])) - if len(self.pairs) > 100000: - break class DiveMT(TranslationPair): diff --git a/model/supervised_finetuning/requirements.txt b/model/supervised_finetuning/requirements.txt index efe6df89..95e5a472 100644 --- a/model/supervised_finetuning/requirements.txt +++ b/model/supervised_finetuning/requirements.txt @@ -4,7 +4,6 @@ datasets==2.8.0 deepspeed==0.7.7 evaluate==0.4.0 gdown -# mpi4py==3.1.4 nltk==3.8.1 numpy>=1.22.4 py7zr @@ -12,3 +11,4 @@ PyYAML>=6.0 scikit_learn==1.2.0 torch>=1.11.0 transformers==4.25.1 +wandb diff --git a/model/supervised_finetuning/trainer.py b/model/supervised_finetuning/trainer.py index 06ae39db..922be10b 100644 --- a/model/supervised_finetuning/trainer.py +++ b/model/supervised_finetuning/trainer.py @@ -8,8 +8,7 @@ import torch from torch import nn from transformers import PreTrainedModel, Trainer, TrainingArguments from transformers.training_args import OptimizerNames -from utils import (build_train_sampler, get_dataset, get_loss, get_metrics, - get_model, get_tokenizer, read_yamls) +from utils import PerDatasetSampler, get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls def compute_metrics(eval_pred, preprocess_fns, metrics): @@ -92,20 +91,30 @@ class SFTTrainer(Trainer): return (loss, logits, labels) def get_train_dataloader(self): + """Inject custom data sampling behaviour into training loop""" if self.sampler is None: torch.utils.data.DataLoader( self.train_dataset, batch_size=self.args.per_device_train_batch_size, shuffle=True, - collate_fn=self.data_collator + collate_fn=self.data_collator, ) else: - return torch.utils.data.DataLoader( + dataloader = torch.utils.data.DataLoader( self.train_dataset, batch_size=self.args.per_device_train_batch_size, sampler=self.sampler, - collate_fn=self.data_collator + collate_fn=self.data_collator, ) + if torch.cuda.device_count() <= 1: + return dataloader + else: + # Not strictly necessary to use accelerate, currently just + # ensures batches are padded to be divisible by # devices + from accelerate import Accelerator + + accelerator = Accelerator() + return accelerator.prepare(dataloader) def _strtobool(x): @@ -160,7 +169,7 @@ if __name__ == "__main__": model = get_model(training_conf, tokenizer) train, evals, collate_fn = get_dataset(training_conf, tokenizer) - sampler = build_train_sampler(training_conf, train.datasets) + sampler = PerDatasetSampler.build_sampler_from_config(training_conf, train.datasets) metrics, preprocess_fns = get_metrics(training_conf, tokenizer) optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else OptimizerNames.ADAMW_HF @@ -178,7 +187,7 @@ if __name__ == "__main__": learning_rate=float(training_conf.learning_rate), deepspeed="configs/zero_config.json" if training_conf.deepspeed else None, optim=optimizer, - # fp16=True, + fp16=True, local_rank=training_conf.local_rank, gradient_checkpointing=training_conf.gradient_checkpointing, gradient_accumulation_steps=training_conf.gradient_accumulation_steps, @@ -205,7 +214,7 @@ if __name__ == "__main__": entity="maw501", name=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned", ) - + trainer = SFTTrainer( model=model, args=args, diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py index 6d6b271a..c18cb380 100644 --- a/model/supervised_finetuning/utils.py +++ b/model/supervised_finetuning/utils.py @@ -1,11 +1,8 @@ -# from functools import partial +import random from pathlib import Path +from typing import List import evaluate -import random - -# import nltk -import numpy as np import transformers import yaml from custom_datasets import get_one_dataset @@ -18,32 +15,55 @@ from torch.utils.data import ConcatDataset, Subset from torch.utils.data.sampler import Sampler -class ClassSampler(Sampler): - """Sampler which returns a fixed number of samples per class, per epoch""" - def __init__(self, class_labels, class_sizes): - self.class_labels = class_labels - self.class_sizes = class_sizes +class PerDatasetSampler(Sampler): + """Sampler which returns a fixed number of samples per dataset, per epoch. + + Example: + + Dataset 1 has 10,000 examples and we want 200 per epoch + Dataset 2 has 500 examples and we want all 500 per epoch + + Epoch size will be 700 and every epoch we'll sample a different + 200 from dataset 1. + + Parameters + ---------- + dataset_sizes : List[int] + A list with the size of each dataset. + dataset_size_per_epoch : List[int] + How many examples to get from each dataset per epoch. + + Note: dataset_sizes & dataset_size_per_epoch must be in the same order. + Further the examples in the underlying torch.utils.data.Dataset + must per ordered as dataset_1, dataset_2, ..., dataset_n. This is fine + if we concatenate a bunch of datasets together + e.g. using torch.utils.data.ConcatDataset which is current behaviour. + """ + + def __init__(self, dataset_sizes: List[int], dataset_size_per_epoch: List[int]): + self.dataset_sizes = dataset_sizes + self.dataset_size_per_epoch = dataset_size_per_epoch + self.num_datasets = len(dataset_sizes) def __iter__(self): - out = [] - for i, _class in enumerate(np.unique(self.class_labels)): - class_idx = np.argwhere(self.class_labels == _class).flatten() - sampled_idx = random.sample(list(class_idx), int(self.class_sizes[i])) - out.extend(sampled_idx) - random.shuffle(out) - return iter(out) + epoch_idx = [] + n = 0 + for i in range(self.num_datasets): + sampled_idx = random.sample(range(n, self.dataset_sizes[i] + n), self.dataset_size_per_epoch[i]) + n += self.dataset_sizes[i] + epoch_idx.extend(sampled_idx) + random.shuffle(epoch_idx) + return iter(epoch_idx) def __len__(self): - return int(sum(self.class_sizes)) + return int(sum(self.dataset_size_per_epoch)) - -def build_train_sampler(training_conf, datasets): - train_sizes = [len(x) for x in datasets] - fractions = get_dataset_fractions(training_conf.datasets, train_sizes) - dataset_size_per_epoch = [int(size * frac) for size, frac in zip(train_sizes, fractions)] - dataset_labels = [[i] * d for i, d in zip(range(len(dataset_size_per_epoch)), dataset_size_per_epoch)] - dataset_labels = [i for s in dataset_labels for i in s] - return ClassSampler(dataset_labels, dataset_size_per_epoch) + @classmethod + def build_sampler_from_config(cls, training_conf, datasets): + dataset_sizes = [len(x) for x in datasets] + fractions = get_dataset_fractions(training_conf.datasets, dataset_sizes) + dataset_size_per_epoch = [int(size * frac) for size, frac in zip(dataset_sizes, fractions)] + return cls(dataset_sizes, dataset_size_per_epoch) def get_tokenizer(conf): @@ -157,13 +177,15 @@ def get_dataset_fractions(conf, dataset_sizes): dataset_name = get_dataset_name_from_data_config(data_config) if isinstance(data_config, dict): if "fraction" in data_config[dataset_name]: + if data_config[dataset_name]["fraction"] <= 0: + raise ValueError("Please specify fraction as a value between 0 < fraction <= 1") fractions.append(min(1, data_config[dataset_name]["fraction"])) elif "size" in data_config[dataset_name]: if data_config[dataset_name]["size"] > dataset_sizes[i]: - raise ValueError(f"Please specify a size smaller than number of examples ({dataset_sizes[i]})") + raise ValueError(f"Please specify a size smaller than number of examples: {dataset_sizes[i]:,.0f}") fractions.append(data_config[dataset_name]["size"] / dataset_sizes[i]) else: - raise ValueError("Please specify either fraction or size in config.yaml") + raise ValueError("Please specify either fraction or size in config.yaml. See README for instructions.") else: fractions.append(1) return fractions From 9faae250ce176c1025cf529003ceb3a5ef188a92 Mon Sep 17 00:00:00 2001 From: Mark Worrall Date: Wed, 8 Feb 2023 20:57:13 +0000 Subject: [PATCH 3/3] minor tidy-up --- model/supervised_finetuning/trainer.py | 3 +-- model/supervised_finetuning/utils.py | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/model/supervised_finetuning/trainer.py b/model/supervised_finetuning/trainer.py index 922be10b..72ae9f04 100644 --- a/model/supervised_finetuning/trainer.py +++ b/model/supervised_finetuning/trainer.py @@ -210,8 +210,7 @@ if __name__ == "__main__": wandb.init( project="supervised-finetuning", - # entity=training_conf.wandb_entity, - entity="maw501", + entity=training_conf.wandb_entity, name=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned", ) diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py index c18cb380..6c47bdb7 100644 --- a/model/supervised_finetuning/utils.py +++ b/model/supervised_finetuning/utils.py @@ -172,6 +172,7 @@ def get_dataset_name_from_data_config(data_config): def get_dataset_fractions(conf, dataset_sizes): + """Calculate fraction of each dataset to use per epoch when subsampling""" fractions = [] for i, data_config in enumerate(conf): dataset_name = get_dataset_name_from_data_config(data_config)