diff --git a/python/ray/util/sgd/tests/test_torch.py b/python/ray/util/sgd/tests/test_torch.py index 48197d7b5..d0d9bbdc0 100644 --- a/python/ray/util/sgd/tests/test_torch.py +++ b/python/ray/util/sgd/tests/test_torch.py @@ -96,6 +96,22 @@ def test_train(ray_start_2_cpus, num_workers, use_local): # noqa: F811 trainer.shutdown() +@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1]) +@pytest.mark.parametrize("use_local", [True, False]) +def test_apply_all_workers(ray_start_2_cpus, num_workers, use_local): + def fn(): + return 1 + + trainer = TorchTrainer( + training_operator_cls=Operator, + num_workers=num_workers, + use_local=use_local, + use_gpu=False) + + results = trainer.apply_all_workers(fn) + assert all(x == 1 for x in results) + + @pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1]) @pytest.mark.parametrize("use_local", [True, False]) def test_multi_model(ray_start_2_cpus, num_workers, use_local): diff --git a/python/ray/util/sgd/torch/worker_group.py b/python/ray/util/sgd/torch/worker_group.py index 89bc41e63..626ea9f78 100644 --- a/python/ray/util/sgd/torch/worker_group.py +++ b/python/ray/util/sgd/torch/worker_group.py @@ -464,7 +464,7 @@ class LocalWorkerGroup(WorkerGroupInterface): return [local_call] + ray.get(remote_calls) def apply_all_workers(self, fn): - remote_calls = self.remote_worker_group.apply_all_workers(fn) + remote_calls = self.remote_worker_group._apply_all_workers(fn) local_call = self.local_worker.apply(fn) return [local_call] + ray.get(remote_calls)