diff --git a/python/ray/util/sgd/tests/test_torch.py b/python/ray/util/sgd/tests/test_torch.py index d0d9bbdc0..ca0f34fa3 100644 --- a/python/ray/util/sgd/tests/test_torch.py +++ b/python/ray/util/sgd/tests/test_torch.py @@ -111,6 +111,8 @@ def test_apply_all_workers(ray_start_2_cpus, num_workers, use_local): results = trainer.apply_all_workers(fn) assert all(x == 1 for x in results) + trainer.shutdown() + @pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1]) @pytest.mark.parametrize("use_local", [True, False])