mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
choice for empty eval
This commit is contained in:
@@ -234,7 +234,6 @@ if __name__ == "__main__":
|
||||
report_to="wandb" if training_conf.log_wandb else None,
|
||||
)
|
||||
|
||||
assert len(evals) > 0
|
||||
if training_conf.log_wandb and not training_conf.deepspeed or training_conf.local_rank == 0:
|
||||
import wandb
|
||||
|
||||
|
||||
@@ -61,13 +61,17 @@ class PerDatasetSampler(Sampler):
|
||||
@classmethod
|
||||
def build_sampler_from_config(cls, training_conf, datasets):
|
||||
dataset_sizes = [len(x) for x in datasets]
|
||||
fractions = get_dataset_fractions(training_conf.datasets, dataset_sizes)
|
||||
fractions = get_dataset_fractions(training_conf.datasets, dataset_sizes, verbose=training_conf.verbose)
|
||||
dataset_size_per_epoch = [int(size * frac) for size, frac in zip(dataset_sizes, fractions)]
|
||||
return cls(dataset_sizes, dataset_size_per_epoch)
|
||||
|
||||
|
||||
def get_dataset_fractions(conf, dataset_sizes):
|
||||
def get_dataset_fractions(conf, dataset_sizes, verbose=False):
|
||||
"""Calculate fraction of each dataset to use per epoch when subsampling"""
|
||||
|
||||
if verbose:
|
||||
print("Creating sampler for datasets:")
|
||||
|
||||
fractions = []
|
||||
for i, data_config in enumerate(conf):
|
||||
dataset_name, _ = get_dataset_name_and_kwargs_from_data_config(data_config)
|
||||
@@ -81,9 +85,12 @@ def get_dataset_fractions(conf, dataset_sizes):
|
||||
raise ValueError(f"Please specify a size smaller than number of examples: {dataset_sizes[i]:,.0f}")
|
||||
fractions.append(data_config[dataset_name]["size"] / dataset_sizes[i])
|
||||
else:
|
||||
raise ValueError("Please specify either fraction or size in config.yaml. See README for instructions.")
|
||||
fractions.append(1)
|
||||
else:
|
||||
fractions.append(1)
|
||||
|
||||
if verbose:
|
||||
print(f"Dataset: {dataset_name} fraction chosen: {fractions[-1]:.2f}")
|
||||
return fractions
|
||||
|
||||
|
||||
@@ -222,11 +229,11 @@ def get_model(conf, tokenizer):
|
||||
|
||||
def get_dataset_name_and_kwargs_from_data_config(data_config):
|
||||
if isinstance(data_config, dict):
|
||||
kwargs = data_config
|
||||
name = list(data_config.keys())[0]
|
||||
kwargs = data_config[name]
|
||||
# remove 'fraction' or 'size' from kwargs
|
||||
kwargs.pop("fraction", None)
|
||||
kwargs.pop("size", None)
|
||||
name = list(data_config.keys())[0]
|
||||
return name, kwargs
|
||||
else:
|
||||
return data_config, {}
|
||||
@@ -239,7 +246,9 @@ def get_dataset(conf, tokenizer):
|
||||
dataset_name, kwargs = get_dataset_name_and_kwargs_from_data_config(data_config)
|
||||
train, val = get_one_dataset(conf, dataset_name, **kwargs)
|
||||
train_datasets.append(train)
|
||||
evals[dataset_name] = Subset(val, list(range(min(len(val), conf.eval_size)))) if conf.eval_size else val
|
||||
|
||||
if val is not None:
|
||||
evals[dataset_name] = Subset(val, list(range(min(len(val), conf.eval_size)))) if conf.eval_size else val
|
||||
|
||||
train = ConcatDataset(train_datasets)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user