diff --git a/python/ray/tune/examples/pbt_transformers/pbt_transformers.py b/python/ray/tune/examples/pbt_transformers/pbt_transformers.py index 0c918ccca..41138d443 100644 --- a/python/ray/tune/examples/pbt_transformers/pbt_transformers.py +++ b/python/ray/tune/examples/pbt_transformers/pbt_transformers.py @@ -91,11 +91,19 @@ def tune_transformer(num_samples=8, eval_dataset=eval_dataset, compute_metrics=build_compute_metrics_fn(task_name)) + # Number of eval steps is dependent on per_device_train_batch_size. + # So we define a separate function that takes in a spec arg. + def eval_steps_func(spec): + if not smoke_test: + return len(train_dataset + ) // spec.config["per_device_train_batch_size"] + 1 + else: + return min(1, spec.config["per_device_train_batch_size"]) + tune_config = { + "per_device_train_batch_size": 32, "per_device_eval_batch_size": 32, - "eval_steps": tune.sample_from( - lambda spec: len(train_dataset) // spec.config["per_device_train_batch_size"] + 1 # noqa: E501 - ) if not smoke_test else 1, + "eval_steps": tune.sample_from(eval_steps_func), "save_steps": tune.sample_from(lambda spec: spec.config["eval_steps"]), "num_train_epochs": tune.choice([2, 3, 4, 5]), "max_steps": 1 if smoke_test else -1, # Used for smoke test.