From fbf23eb6ff5253077e1181fe92d4971296bb1bca Mon Sep 17 00:00:00 2001 From: Xianyang Liu Date: Thu, 30 Apr 2020 01:51:31 +0800 Subject: [PATCH] [SGD] Fix IterableDataset errors (#8208) --- .../ray/util/sgd/torch/distributed_torch_runner.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/python/ray/util/sgd/torch/distributed_torch_runner.py b/python/ray/util/sgd/torch/distributed_torch_runner.py index b73bb3933..225490bf6 100644 --- a/python/ray/util/sgd/torch/distributed_torch_runner.py +++ b/python/ray/util/sgd/torch/distributed_torch_runner.py @@ -6,7 +6,7 @@ import os import torch import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, IterableDataset from torch.utils.data.distributed import DistributedSampler from ray.util.sgd.torch.constants import NCCL_TIMEOUT_S @@ -145,12 +145,16 @@ class DistributedTorchRunner(TorchRunner): } return DataLoader(**data_loader_args) - if isinstance(self.train_loader, DataLoader): + def should_wrap_dataloader(loader): + return (isinstance(loader, DataLoader) + and not isinstance(loader.dataset, IterableDataset)) + + if should_wrap_dataloader(self.train_loader): if self.add_dist_sampler: self.train_loader = with_sampler(self.train_loader) - if self.validation_loader and isinstance(self.validation_loader, - DataLoader): + if self.validation_loader and should_wrap_dataloader( + self.validation_loader): if self.add_dist_sampler: self.validation_loader = with_sampler(self.validation_loader)