diff --git a/python/ray/util/sgd/torch/distributed_torch_runner.py b/python/ray/util/sgd/torch/distributed_torch_runner.py index 77ade0b9d..3ab4bd841 100644 --- a/python/ray/util/sgd/torch/distributed_torch_runner.py +++ b/python/ray/util/sgd/torch/distributed_torch_runner.py @@ -140,7 +140,7 @@ class DistributedTorchRunner(TorchRunner): if self.add_dist_sampler: self.train_loader = with_sampler(self.train_loader) - if self.validation_loader and should_wrap_dataloader( + if self.validation_loader is not None and should_wrap_dataloader( self.validation_loader): if self.add_dist_sampler: self.validation_loader = with_sampler(self.validation_loader)