diff --git a/python/ray/util/sgd/torch/torch_trainer.py b/python/ray/util/sgd/torch/torch_trainer.py index 79217bcc0..149c510ae 100644 --- a/python/ray/util/sgd/torch/torch_trainer.py +++ b/python/ray/util/sgd/torch/torch_trainer.py @@ -594,7 +594,8 @@ class TorchTrainer: TorchTrainable = TorchTrainer.as_trainable( training_operator_cls=MyTrainingOperator, - num_gpus=2, + num_workers=2, + use_gpu=True, override_tune_step=step ) analysis = tune.run( @@ -695,7 +696,8 @@ class BaseTorchTrainable(Trainable): # TorchTrainable is subclass of BaseTorchTrainable. TorchTrainable = TorchTrainer.as_trainable( training_operator_cls=MyTrainingOperator, - num_gpus=2, + num_workers=2, + use_gpu=True, override_tune_step=custom_step )