[SGD] Fix IterableDataset errors (#8208)

This commit is contained in:
Xianyang Liu
2020-04-30 01:51:31 +08:00
committed by GitHub
parent 1b1fe0cc5b
commit fbf23eb6ff
@@ -6,7 +6,7 @@ import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, IterableDataset
from torch.utils.data.distributed import DistributedSampler
from ray.util.sgd.torch.constants import NCCL_TIMEOUT_S
@@ -145,12 +145,16 @@ class DistributedTorchRunner(TorchRunner):
}
return DataLoader(**data_loader_args)
if isinstance(self.train_loader, DataLoader):
def should_wrap_dataloader(loader):
return (isinstance(loader, DataLoader)
and not isinstance(loader.dataset, IterableDataset))
if should_wrap_dataloader(self.train_loader):
if self.add_dist_sampler:
self.train_loader = with_sampler(self.train_loader)
if self.validation_loader and isinstance(self.validation_loader,
DataLoader):
if self.validation_loader and should_wrap_dataloader(
self.validation_loader):
if self.add_dist_sampler:
self.validation_loader = with_sampler(self.validation_loader)