mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
add possible kwargs to datasets
This commit is contained in:
@@ -70,7 +70,7 @@ def get_dataset_fractions(conf, dataset_sizes):
|
||||
"""Calculate fraction of each dataset to use per epoch when subsampling"""
|
||||
fractions = []
|
||||
for i, data_config in enumerate(conf):
|
||||
dataset_name = get_dataset_name_from_data_config(data_config)
|
||||
dataset_name, _ = get_dataset_name_and_kwargs_from_data_config(data_config)
|
||||
if isinstance(data_config, dict):
|
||||
if "fraction" in data_config[dataset_name]:
|
||||
if data_config[dataset_name]["fraction"] <= 0:
|
||||
@@ -220,18 +220,24 @@ def get_model(conf, tokenizer):
|
||||
return model
|
||||
|
||||
|
||||
def get_dataset_name_and_kwargs_from_data_config(data_config):
    """Split one dataset config entry into its name and loader kwargs.

    An entry is either a bare dataset name, or a single-key dict mapping the
    dataset name to its options (the options are read elsewhere as
    ``data_config[dataset_name]["fraction"]``).

    Args:
        data_config: A dataset name, or a ``{name: options}`` dict.

    Returns:
        A ``(name, kwargs)`` tuple. ``kwargs`` is a new dict holding the
        options with the sampling-only keys ``"fraction"`` and ``"size"``
        removed; it is empty for a bare name or when no options are given.

    Raises:
        IndexError: If ``data_config`` is an empty dict.
    """
    if not isinstance(data_config, dict):
        return data_config, {}

    name = list(data_config)[0]
    # Copy the nested options rather than aliasing the config: the caller's
    # config must stay intact because "fraction"/"size" are still read from it
    # when computing per-epoch subsampling.
    kwargs = dict(data_config[name] or {})
    # 'fraction' and 'size' control subsampling, not dataset construction.
    kwargs.pop("fraction", None)
    kwargs.pop("size", None)
    return name, kwargs
|
||||
|
||||
|
||||
def get_dataset(conf, tokenizer):
|
||||
train_datasets, evals = [], {}
|
||||
|
||||
for data_config in conf.datasets:
|
||||
dataset_name = get_dataset_name_from_data_config(data_config)
|
||||
train, val = get_one_dataset(conf, dataset_name)
|
||||
dataset_name, kwargs = get_dataset_name_and_kwargs_from_data_config(data_config)
|
||||
train, val = get_one_dataset(conf, dataset_name, **kwargs)
|
||||
train_datasets.append(train)
|
||||
evals[dataset_name] = Subset(val, list(range(min(len(val), conf.eval_size)))) if conf.eval_size else val
|
||||
|
||||
|
||||
Reference in New Issue
Block a user