mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:23:10 +08:00
[Tune] Fix PBT Transformers Example (#13174)
This commit is contained in:
@@ -91,11 +91,19 @@ def tune_transformer(num_samples=8,
|
||||
eval_dataset=eval_dataset,
|
||||
compute_metrics=build_compute_metrics_fn(task_name))
|
||||
|
||||
# Number of eval steps is dependent on per_device_train_batch_size.
|
||||
# So we define a separate function that takes in a spec arg.
|
||||
def eval_steps_func(spec):
|
||||
if not smoke_test:
|
||||
return len(train_dataset
|
||||
) // spec.config["per_device_train_batch_size"] + 1
|
||||
else:
|
||||
return min(1, spec.config["per_device_train_batch_size"])
|
||||
|
||||
tune_config = {
|
||||
"per_device_train_batch_size": 32,
|
||||
"per_device_eval_batch_size": 32,
|
||||
"eval_steps": tune.sample_from(
|
||||
lambda spec: len(train_dataset) // spec.config["per_device_train_batch_size"] + 1 # noqa: E501
|
||||
) if not smoke_test else 1,
|
||||
"eval_steps": tune.sample_from(eval_steps_func),
|
||||
"save_steps": tune.sample_from(lambda spec: spec.config["eval_steps"]),
|
||||
"num_train_epochs": tune.choice([2, 3, 4, 5]),
|
||||
"max_steps": 1 if smoke_test else -1, # Used for smoke test.
|
||||
|
||||
Reference in New Issue
Block a user