add possible kwargs to datasets

This commit is contained in:
Sotirios Anagnostidis
2023-02-11 10:44:03 +01:00
parent 44ed44e05d
commit 714677b5d8
+12 -6
View File
@@ -70,7 +70,7 @@ def get_dataset_fractions(conf, dataset_sizes):
"""Calculate fraction of each dataset to use per epoch when subsampling"""
fractions = []
for i, data_config in enumerate(conf):
dataset_name = get_dataset_name_from_data_config(data_config)
dataset_name, _ = get_dataset_name_and_kwargs_from_data_config(data_config)
if isinstance(data_config, dict):
if "fraction" in data_config[dataset_name]:
if data_config[dataset_name]["fraction"] <= 0:
@@ -220,18 +220,24 @@ def get_model(conf, tokenizer):
return model
def get_dataset_name_and_kwargs_from_data_config(data_config):
    """Split one dataset config entry into its name and loader kwargs.

    An entry is either a bare dataset name, or a single-key dict that maps
    the dataset name to its options, e.g.
    ``{"my_dataset": {"fraction": 0.5, "some_option": 1}}``.

    Args:
        data_config: The dataset name, or a ``{name: options}`` dict.

    Returns:
        A ``(name, kwargs)`` tuple. ``kwargs`` is a new dict holding the
        entry's options without the ``"fraction"`` and ``"size"`` keys; it
        is empty for a bare name or when the options value is empty/None.

    Raises:
        IndexError: If ``data_config`` is an empty dict.
    """
    if not isinstance(data_config, dict):
        return data_config, {}

    name = list(data_config)[0]
    # The options live one level down, under the dataset name; the outer
    # dict only carries the name itself.
    options = data_config[name] or {}
    # 'fraction' and 'size' are dropped from the kwargs. Build a filtered
    # copy rather than popping in place, so the caller's config keeps those
    # keys for any code that still reads them from the config.
    excluded = ("fraction", "size")
    kwargs = {key: value for key, value in options.items() if key not in excluded}
    return name, kwargs


def get_dataset_name_from_data_config(data_config):
    """Return only the dataset name of a config entry.

    Kept as a thin wrapper around
    ``get_dataset_name_and_kwargs_from_data_config`` for callers that do
    not need the kwargs.
    """
    name, _ = get_dataset_name_and_kwargs_from_data_config(data_config)
    return name
def get_dataset(conf, tokenizer):
train_datasets, evals = [], {}
for data_config in conf.datasets:
dataset_name = get_dataset_name_from_data_config(data_config)
train, val = get_one_dataset(conf, dataset_name)
dataset_name, kwargs = get_dataset_name_and_kwargs_from_data_config(data_config)
train, val = get_one_dataset(conf, dataset_name, **kwargs)
train_datasets.append(train)
evals[dataset_name] = Subset(val, list(range(min(len(val), conf.eval_size)))) if conf.eval_size else val