[Tune] Fix PBT Transformers Example (#13174)

This commit is contained in:
Amog Kamsetty
2021-01-05 16:31:11 -08:00
committed by GitHub
parent 7e52351ae5
commit bd19ed31e7
@@ -91,11 +91,19 @@ def tune_transformer(num_samples=8,
eval_dataset=eval_dataset,
compute_metrics=build_compute_metrics_fn(task_name))
# Number of eval steps is dependent on per_device_train_batch_size.
# So we define a separate function that takes in a spec arg.
def eval_steps_func(spec):
if not smoke_test:
return len(train_dataset
) // spec.config["per_device_train_batch_size"] + 1
else:
return min(1, spec.config["per_device_train_batch_size"])
tune_config = {
"per_device_train_batch_size": 32,
"per_device_eval_batch_size": 32,
"eval_steps": tune.sample_from(
lambda spec: len(train_dataset) // spec.config["per_device_train_batch_size"] + 1 # noqa: E501
) if not smoke_test else 1,
"eval_steps": tune.sample_from(eval_steps_func),
"save_steps": tune.sample_from(lambda spec: spec.config["eval_steps"]),
"num_train_epochs": tune.choice([2, 3, 4, 5]),
"max_steps": 1 if smoke_test else -1, # Used for smoke test.