diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py index 0fd9ef00..a0eb3e82 100644 --- a/model/supervised_finetuning/utils.py +++ b/model/supervised_finetuning/utils.py @@ -70,7 +70,7 @@ def get_dataset_fractions(conf, dataset_sizes): """Calculate fraction of each dataset to use per epoch when subsampling""" fractions = [] for i, data_config in enumerate(conf): - dataset_name = get_dataset_name_from_data_config(data_config) + dataset_name, _ = get_dataset_name_and_kwargs_from_data_config(data_config) if isinstance(data_config, dict): if "fraction" in data_config[dataset_name]: if data_config[dataset_name]["fraction"] <= 0: @@ -220,18 +220,24 @@ def get_model(conf, tokenizer): return model -def get_dataset_name_from_data_config(data_config): +def get_dataset_name_and_kwargs_from_data_config(data_config): if isinstance(data_config, dict): - return list(data_config.keys())[0] - return data_config + kwargs = data_config + # remove 'fraction' or 'size' from kwargs + kwargs.pop("fraction", None) + kwargs.pop("size", None) + name = list(data_config.keys())[0] + return name, kwargs + else: + return data_config, {} def get_dataset(conf, tokenizer): train_datasets, evals = [], {} for data_config in conf.datasets: - dataset_name = get_dataset_name_from_data_config(data_config) - train, val = get_one_dataset(conf, dataset_name) + dataset_name, kwargs = get_dataset_name_and_kwargs_from_data_config(data_config) + train, val = get_one_dataset(conf, dataset_name, **kwargs) train_datasets.append(train) evals[dataset_name] = Subset(val, list(range(min(len(val), conf.eval_size)))) if conf.eval_size else val