diff --git a/model/supervised_finetuning/trainer.py b/model/supervised_finetuning/trainer.py index 922be10b..72ae9f04 100644 --- a/model/supervised_finetuning/trainer.py +++ b/model/supervised_finetuning/trainer.py @@ -210,8 +210,7 @@ if __name__ == "__main__": wandb.init( project="supervised-finetuning", - # entity=training_conf.wandb_entity, - entity="maw501", + entity=training_conf.wandb_entity, name=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned", ) diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py index c18cb380..6c47bdb7 100644 --- a/model/supervised_finetuning/utils.py +++ b/model/supervised_finetuning/utils.py @@ -172,6 +172,7 @@ def get_dataset_name_from_data_config(data_config): def get_dataset_fractions(conf, dataset_sizes): + """Calculate fraction of each dataset to use per epoch when subsampling""" fractions = [] for i, data_config in enumerate(conf): dataset_name = get_dataset_name_from_data_config(data_config)