mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 14:48:54 +08:00
[SGD] Fix IterableDataset errors (#8208)
This commit is contained in:
@@ -6,7 +6,7 @@ import os
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import DataLoader, IterableDataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from ray.util.sgd.torch.constants import NCCL_TIMEOUT_S
|
||||
|
||||
@@ -145,12 +145,16 @@ class DistributedTorchRunner(TorchRunner):
|
||||
}
|
||||
return DataLoader(**data_loader_args)
|
||||
|
||||
if isinstance(self.train_loader, DataLoader):
|
||||
def should_wrap_dataloader(loader):
|
||||
return (isinstance(loader, DataLoader)
|
||||
and not isinstance(loader.dataset, IterableDataset))
|
||||
|
||||
if should_wrap_dataloader(self.train_loader):
|
||||
if self.add_dist_sampler:
|
||||
self.train_loader = with_sampler(self.train_loader)
|
||||
|
||||
if self.validation_loader and isinstance(self.validation_loader,
|
||||
DataLoader):
|
||||
if self.validation_loader and should_wrap_dataloader(
|
||||
self.validation_loader):
|
||||
if self.add_dist_sampler:
|
||||
self.validation_loader = with_sampler(self.validation_loader)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user