From bd19ed31e7a5ab2d51973d3eb349b131ff2ec1b7 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Tue, 5 Jan 2021 16:31:11 -0800 Subject: [PATCH] [Tune] Fix PBT Transformers Example (#13174) --- .../examples/pbt_transformers/pbt_transformers.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/python/ray/tune/examples/pbt_transformers/pbt_transformers.py b/python/ray/tune/examples/pbt_transformers/pbt_transformers.py index 0c918ccca..41138d443 100644 --- a/python/ray/tune/examples/pbt_transformers/pbt_transformers.py +++ b/python/ray/tune/examples/pbt_transformers/pbt_transformers.py @@ -91,11 +91,19 @@ def tune_transformer(num_samples=8, eval_dataset=eval_dataset, compute_metrics=build_compute_metrics_fn(task_name)) + # Number of eval steps is dependent on per_device_train_batch_size. + # So we define a separate function that takes in a spec arg. + def eval_steps_func(spec): + if not smoke_test: + return len(train_dataset + ) // spec.config["per_device_train_batch_size"] + 1 + else: + return min(1, spec.config["per_device_train_batch_size"]) + tune_config = { + "per_device_train_batch_size": 32, "per_device_eval_batch_size": 32, - "eval_steps": tune.sample_from( - lambda spec: len(train_dataset) // spec.config["per_device_train_batch_size"] + 1 # noqa: E501 - ) if not smoke_test else 1, + "eval_steps": tune.sample_from(eval_steps_func), "save_steps": tune.sample_from(lambda spec: spec.config["eval_steps"]), "num_train_epochs": tune.choice([2, 3, 4, 5]), "max_steps": 1 if smoke_test else -1, # Used for smoke test.