mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 21:46:28 +08:00
[sgd] Refactor PyTorch SGD Documentation. (#6910)
* Refactor documentation and directory structurre * update loss * ,ore examples * fix comments * more code * svgs * formatting * more_docs * more writing * comments ready * move * whitespace * examples * fix * bold * pytorch * batch * fix * fix test * Apply suggestions from code review * quarantinegp * tests/ * fix missing
This commit is contained in:
@@ -0,0 +1,4 @@
|
||||
from ray.experimental.sgd.pytorch import PyTorchTrainer
|
||||
from ray.experimental.sgd.tf import TFTrainer
|
||||
|
||||
__all__ = ["PyTorchTrainer", "TFTrainer"]
|
||||
|
||||
@@ -66,15 +66,27 @@ class DistributedPyTorchRunner(PyTorchRunner):
|
||||
self.config)
|
||||
if not isinstance(self.optimizers, collections.Iterable):
|
||||
self.optimizers = [self.optimizers]
|
||||
self.criterion = self.loss_creator(self.config)
|
||||
if torch.cuda.is_available():
|
||||
self.criterion = self.criterion.cuda()
|
||||
|
||||
logger.debug("Creating loss.")
|
||||
self._create_loss()
|
||||
|
||||
logger.debug("Creating dataset.")
|
||||
with FileLock(os.path.expanduser("~/.ray_data.lock")):
|
||||
data_loaders = self.data_creator(self.batch_size, self.config)
|
||||
self.train_loader, self.validation_loader = self._validate_loaders(
|
||||
data_loaders)
|
||||
datasets = self.data_creator(self.config)
|
||||
train_set, val_set = self._validate_datasets(datasets)
|
||||
|
||||
train_loader_config = self.dataloader_config.copy()
|
||||
train_loader_config.update(
|
||||
sampler=torch.utils.data.distributed.DistributedSampler(train_set),
|
||||
shuffle=False)
|
||||
|
||||
self.train_loader = torch.utils.data.DataLoader(
|
||||
train_set, batch_size=self.batch_size, **train_loader_config)
|
||||
|
||||
self.validation_loader = None
|
||||
if val_set:
|
||||
self.validation_loader = torch.utils.data.DataLoader(
|
||||
val_set, batch_size=self.batch_size, **self.dataloader_config)
|
||||
|
||||
def step(self):
|
||||
"""Runs a training epoch and updates the model parameters.
|
||||
@@ -117,4 +129,9 @@ class DistributedPyTorchRunner(PyTorchRunner):
|
||||
def shutdown(self):
|
||||
"""Attempts to shut down the worker."""
|
||||
super(DistributedPyTorchRunner, self).shutdown()
|
||||
dist.destroy_process_group()
|
||||
# TODO: Temporarily removing since it causes hangs on MacOSX.
|
||||
# However, it seems to be harmless to remove permanently
|
||||
# since the processes are shutdown anyways. This comment can be
|
||||
# removed in a future release if it is still not documented
|
||||
# the stable Pytorch docs.
|
||||
# dist.destroy_process_group()
|
||||
|
||||
+3
-27
@@ -4,8 +4,6 @@ import torch.nn as nn
|
||||
import argparse
|
||||
from ray import tune
|
||||
import torch.utils.data
|
||||
from torch import distributed
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
@@ -71,7 +69,7 @@ def validate(model, val_iterator, criterion, config):
|
||||
return stats
|
||||
|
||||
|
||||
def cifar_creator(batch_size, config):
|
||||
def cifar_creator(config):
|
||||
transform_train = transforms.Compose([
|
||||
transforms.RandomCrop(32, padding=4),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
@@ -90,29 +88,7 @@ def cifar_creator(batch_size, config):
|
||||
validation_dataset = torchvision.datasets.CIFAR10(
|
||||
root="~/data", train=False, download=False, transform=transform_test)
|
||||
|
||||
train_sampler = None
|
||||
if distributed.is_initialized():
|
||||
train_sampler = DistributedSampler(train_dataset)
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=(train_sampler is None),
|
||||
num_workers=2,
|
||||
pin_memory=False,
|
||||
sampler=train_sampler)
|
||||
|
||||
validation_sampler = None
|
||||
if distributed.is_initialized():
|
||||
validation_sampler = DistributedSampler(validation_dataset)
|
||||
validation_loader = torch.utils.data.DataLoader(
|
||||
validation_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=(validation_sampler is None),
|
||||
num_workers=2,
|
||||
pin_memory=False,
|
||||
sampler=validation_sampler)
|
||||
|
||||
return train_loader, validation_loader
|
||||
return train_dataset, validation_dataset
|
||||
|
||||
|
||||
def optimizer_creator(model, config):
|
||||
@@ -126,7 +102,7 @@ def train_example(num_replicas=1, use_gpu=False, test_mode=False):
|
||||
ResNet18,
|
||||
cifar_creator,
|
||||
optimizer_creator,
|
||||
lambda config: nn.CrossEntropyLoss(),
|
||||
nn.CrossEntropyLoss,
|
||||
initialization_hook=initialization_hook,
|
||||
train_function=train,
|
||||
validation_function=validate,
|
||||
@@ -4,7 +4,6 @@ import argparse
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import distributed
|
||||
import torch.optim as optim
|
||||
import torch.utils.data
|
||||
import torchvision.datasets as dset
|
||||
@@ -13,11 +12,10 @@ import numpy as np
|
||||
|
||||
from torch.autograd import Variable
|
||||
from torch.nn import functional as F
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from scipy.stats import entropy
|
||||
|
||||
import ray
|
||||
from ray.experimental.sgd.pytorch import PyTorchTrainer
|
||||
from ray.experimental.sgd import PyTorchTrainer
|
||||
|
||||
# Training parameters
|
||||
TRAIN_BATCHES = 5
|
||||
@@ -34,8 +32,8 @@ features_g = 32
|
||||
features_d = 32
|
||||
|
||||
|
||||
def data_creator(batch_size, config):
|
||||
dataset = dset.MNIST(
|
||||
def data_creator(config):
|
||||
return dset.MNIST(
|
||||
root="~/mnist/",
|
||||
download=True,
|
||||
transform=transforms.Compose([
|
||||
@@ -44,19 +42,6 @@ def data_creator(batch_size, config):
|
||||
transforms.Normalize((0.5, ), (0.5, )),
|
||||
]))
|
||||
|
||||
# Create the dataloader
|
||||
train_sampler = None
|
||||
if distributed.is_initialized():
|
||||
train_sampler = DistributedSampler(dataset)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=3,
|
||||
shuffle=(train_sampler is None),
|
||||
sampler=train_sampler)
|
||||
|
||||
return dataloader, None
|
||||
|
||||
|
||||
def weights_init(m):
|
||||
classname = m.__class__.__name__
|
||||
@@ -231,7 +216,7 @@ def train_example(num_replicas=1, use_gpu=False, test_mode=False):
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
lambda config: nn.BCELoss(),
|
||||
nn.BCELoss,
|
||||
train_function=train,
|
||||
validation_function=False,
|
||||
num_replicas=num_replicas,
|
||||
|
||||
+4
-27
@@ -12,10 +12,8 @@ import argparse
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import distributed
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from ray.experimental.sgd.pytorch.pytorch_trainer import PyTorchTrainer
|
||||
from ray.experimental.sgd import PyTorchTrainer
|
||||
|
||||
|
||||
class LinearDataset(torch.utils.data.Dataset):
|
||||
@@ -42,30 +40,9 @@ def optimizer_creator(model, config):
|
||||
return torch.optim.SGD(model.parameters(), lr=1e-2)
|
||||
|
||||
|
||||
def data_creator(batch_size, config):
|
||||
def data_creator(config):
|
||||
"""Returns training dataloader, validation dataloader."""
|
||||
train_dataset = LinearDataset(2, 5)
|
||||
validation_dataset = LinearDataset(2, 5, size=400)
|
||||
|
||||
train_sampler = None
|
||||
if distributed.is_initialized():
|
||||
train_sampler = DistributedSampler(train_dataset)
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=(train_sampler is None),
|
||||
sampler=train_sampler)
|
||||
|
||||
validation_sampler = None
|
||||
if distributed.is_initialized():
|
||||
validation_sampler = DistributedSampler(validation_dataset)
|
||||
validation_loader = torch.utils.data.DataLoader(
|
||||
validation_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=(validation_sampler is None),
|
||||
sampler=validation_sampler)
|
||||
|
||||
return train_loader, validation_loader
|
||||
return LinearDataset(2, 5), LinearDataset(2, 5, size=400)
|
||||
|
||||
|
||||
def train_example(num_replicas=1, use_gpu=False):
|
||||
@@ -73,7 +50,7 @@ def train_example(num_replicas=1, use_gpu=False):
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
loss_creator=nn.MSELoss,
|
||||
num_replicas=num_replicas,
|
||||
use_gpu=use_gpu,
|
||||
batch_size=num_replicas * 4,
|
||||
+3
-25
@@ -11,8 +11,6 @@ in the documentation.
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import distributed
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
@@ -44,29 +42,9 @@ def optimizer_creator(model, config):
|
||||
return torch.optim.SGD(model.parameters(), lr=config.get("lr", 1e-4))
|
||||
|
||||
|
||||
def data_creator(batch_size, config):
|
||||
def data_creator(config):
|
||||
"""Returns training dataloader, validation dataloader."""
|
||||
train_dataset = LinearDataset(2, 5)
|
||||
validation_dataset = LinearDataset(2, 5, size=400)
|
||||
|
||||
train_sampler = None
|
||||
if distributed.is_initialized():
|
||||
train_sampler = DistributedSampler(train_dataset)
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=(train_sampler is None),
|
||||
sampler=train_sampler)
|
||||
|
||||
validation_sampler = None
|
||||
if distributed.is_initialized():
|
||||
validation_sampler = DistributedSampler(validation_dataset)
|
||||
validation_loader = torch.utils.data.DataLoader(
|
||||
validation_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=(validation_sampler is None),
|
||||
sampler=validation_sampler)
|
||||
return train_loader, validation_loader
|
||||
return LinearDataset(2, 5), LinearDataset(2, 5, size=400)
|
||||
|
||||
|
||||
def tune_example(num_replicas=1, use_gpu=False):
|
||||
@@ -74,7 +52,7 @@ def tune_example(num_replicas=1, use_gpu=False):
|
||||
"model_creator": tune.function(model_creator),
|
||||
"data_creator": tune.function(data_creator),
|
||||
"optimizer_creator": tune.function(optimizer_creator),
|
||||
"loss_creator": tune.function(lambda config: nn.MSELoss()),
|
||||
"loss_creator": tune.function(nn.MSELoss),
|
||||
"num_replicas": num_replicas,
|
||||
"use_gpu": use_gpu,
|
||||
"batch_size": 512,
|
||||
@@ -1,10 +1,11 @@
|
||||
import collections
|
||||
from filelock import FileLock
|
||||
import logging
|
||||
import inspect
|
||||
import os
|
||||
import torch
|
||||
import torch.utils.data
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
import ray
|
||||
from ray.experimental.sgd.pytorch import utils as pytorch_utils
|
||||
@@ -24,19 +25,21 @@ class PyTorchRunner:
|
||||
train_function=None,
|
||||
validation_function=None,
|
||||
config=None,
|
||||
dataloader_config=None,
|
||||
batch_size=16):
|
||||
"""Initializes the runner.
|
||||
|
||||
Args:
|
||||
model_creator (dict -> torch.nn.Module): see pytorch_trainer.py
|
||||
data_creator (int, dict -> DataLoader, DataLoader): see
|
||||
data_creator (int, dict -> Dataset, Dataset): see
|
||||
pytorch_trainer.py.
|
||||
optimizer_creator (torch.nn.Module, dict -> loss, optimizer):
|
||||
see pytorch_trainer.py.
|
||||
loss_creator (dict -> loss): see pytorch_trainer.py.
|
||||
loss_creator (dict -> loss | Loss class): see pytorch_trainer.py.
|
||||
train_function: see pytorch_trainer.py
|
||||
validation_function: see pytorch_trainer.py
|
||||
config (dict): see pytorch_trainer.py.
|
||||
dataloader_config (dict): See pytorch_trainer.py.
|
||||
batch_size (int): see pytorch_trainer.py.
|
||||
"""
|
||||
self.model_creator = model_creator
|
||||
@@ -44,6 +47,10 @@ class PyTorchRunner:
|
||||
self.optimizer_creator = optimizer_creator
|
||||
self.loss_creator = loss_creator
|
||||
self.config = {} if config is None else config
|
||||
self.dataloader_config = {
|
||||
"num_workers": 2,
|
||||
"pin_memory": True
|
||||
} if dataloader_config is None else dataloader_config
|
||||
self.train_function = train_function or pytorch_utils.train
|
||||
self.validation_function = (validation_function
|
||||
or pytorch_utils.validate)
|
||||
@@ -65,16 +72,24 @@ class PyTorchRunner:
|
||||
self.train_loader = None
|
||||
self.validation_loader = None
|
||||
|
||||
def _validate_loaders(self, data_loaders):
|
||||
assert data_loaders, "Dataloaders need to be returned in data_creator."
|
||||
if isinstance(data_loaders, DataLoader):
|
||||
return data_loaders, None
|
||||
elif len(data_loaders) == 2 and isinstance(data_loaders[0],
|
||||
DataLoader):
|
||||
return data_loaders
|
||||
def _validate_datasets(self, dataset):
|
||||
assert dataset, "Datasets need to be returned in data_creator."
|
||||
if issubclass(type(dataset), Dataset):
|
||||
return dataset, None
|
||||
elif len(dataset) == 2 and issubclass(type(dataset[0]), Dataset):
|
||||
return dataset
|
||||
else:
|
||||
raise ValueError(
|
||||
"Dataloaders must be <= 2. Got {}".format(data_loaders))
|
||||
raise ValueError("Datasets must be <= 2. Got {}".format(dataset))
|
||||
|
||||
def _create_loss(self):
|
||||
if inspect.isclass(self.loss_creator) and issubclass(
|
||||
self.loss_creator, torch.nn.modules.loss._Loss):
|
||||
self.criterion = self.loss_creator()
|
||||
else:
|
||||
self.criterion = self.loss_creator(self.config)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
self.criterion = self.criterion.cuda()
|
||||
|
||||
def setup(self):
|
||||
"""Initializes the model."""
|
||||
@@ -90,15 +105,23 @@ class PyTorchRunner:
|
||||
self.config)
|
||||
if not isinstance(self.optimizers, collections.Iterable):
|
||||
self.optimizers = [self.optimizers]
|
||||
self.criterion = self.loss_creator(self.config)
|
||||
if torch.cuda.is_available():
|
||||
self.criterion = self.criterion.cuda()
|
||||
|
||||
self._create_loss()
|
||||
|
||||
logger.debug("Creating dataset")
|
||||
# When creating datasets, a filelock will be used to ensure no
|
||||
# race conditions in data downloading among different workers.
|
||||
with FileLock(os.path.expanduser("~/.ray_data.lock")):
|
||||
dataloaders = self.data_creator(self.batch_size, self.config)
|
||||
self.train_loader, self.validation_loader = self._validate_loaders(
|
||||
dataloaders)
|
||||
datasets = self.data_creator(self.config)
|
||||
train_set, val_set = self._validate_datasets(datasets)
|
||||
|
||||
self.train_loader = torch.utils.data.DataLoader(
|
||||
train_set, batch_size=self.batch_size, **self.dataloader_config)
|
||||
|
||||
self.validation_loader = None
|
||||
if val_set:
|
||||
self.validation_loader = torch.utils.data.DataLoader(
|
||||
val_set, batch_size=self.batch_size, **self.dataloader_config)
|
||||
|
||||
def get_node_ip(self):
|
||||
"""Returns the IP address of the current node."""
|
||||
@@ -145,9 +168,18 @@ class PyTorchRunner:
|
||||
|
||||
def get_state(self):
|
||||
"""Returns the state of the runner."""
|
||||
# This is so that we create a duplicate of weights into CPU rather than
|
||||
# move the model weights entirely out of the GPU, so that we can
|
||||
# resume training while saving intermediate checkpoints.
|
||||
cpu_state_dicts = []
|
||||
for model in self.models:
|
||||
state_dict = model.state_dict()
|
||||
for k, v in state_dict.items():
|
||||
state_dict[k] = v.cpu()
|
||||
cpu_state_dicts += [state_dict]
|
||||
return {
|
||||
"epoch": self.epoch,
|
||||
"models": [model.cpu().state_dict() for model in self.models],
|
||||
"models": cpu_state_dicts,
|
||||
"optimizers": [opt.state_dict() for opt in self.optimizers],
|
||||
"stats": self.stats()
|
||||
}
|
||||
|
||||
@@ -25,6 +25,85 @@ class PyTorchTrainer:
|
||||
|
||||
Launches a set of actors which connect via distributed PyTorch and
|
||||
coordinate gradient updates to train the provided model.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def model_creator(config):
|
||||
return nn.Linear(1, 1)
|
||||
|
||||
|
||||
def optimizer_creator(model, config):
|
||||
return torch.optim.SGD(
|
||||
model.parameters(), lr=config.get("lr", 1e-4))
|
||||
|
||||
|
||||
def data_creator(config):
|
||||
return LinearDataset(2, 5), LinearDataset(2, 5, size=400)
|
||||
|
||||
trainer = PyTorchTrainer(
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
loss_creator=nn.MSELoss,
|
||||
use_gpu=True
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
Args:
|
||||
model_creator (dict -> *): Constructor function that takes in
|
||||
config and returns the model(s) to be optimized. These must be
|
||||
``torch.nn.Module`` objects. Note that if multiple models
|
||||
are returned, the same number of optimizers must be returned
|
||||
by the optimizer_creator. If multiple models are returned,
|
||||
a ``train_function`` must be specified. You do not need to
|
||||
handle GPU/devices in this function;
|
||||
RaySGD will do that under the hood.
|
||||
data_creator (dict -> Dataset, Dataset): Constructor function
|
||||
that takes in the passed config and returns one or
|
||||
two ``torch.utils.data.Dataset`` objects.
|
||||
Note that even though two Dataset objects can be returned,
|
||||
only one dataset will be used for training. RaySGD
|
||||
will automatically wrap the objects with a ``DataLoader``.
|
||||
optimizer_creator (models, dict -> optimizers): Constructor
|
||||
function that takes in the return values from
|
||||
``model_creator`` and the passed config and returns One or
|
||||
more Torch optimizer objects. You must return as many
|
||||
optimizers as you have models. You do not need to handle
|
||||
GPU/devices in this function; ``RaySGD`` will do that for you.
|
||||
loss_creator (dict -> loss or torch.nn.*Loss): A constructor function
|
||||
for the training loss. This can be either a function that
|
||||
takes in the provided config for customization or a subclass
|
||||
of ``torch.nn.modules.loss._Loss``, which is most Pytorch
|
||||
loss classes. For example, ``loss_creator=torch.nn.BCELoss``.
|
||||
train_function: Custom function for training. This function
|
||||
will be executed in parallel across all workers at once. The
|
||||
function needs to take in (models, train_dataloader, criterion,
|
||||
optimizers, config), and return a dict of training stats.
|
||||
validation_function: Custom function for validation. This function
|
||||
will be executed in parallel across all workers at once.
|
||||
This takes in (model, val_dataloader, criterion, config)
|
||||
and returns a dict of validation stats.
|
||||
config (dict): Custom configuration value to be passed to
|
||||
"model_creator", "data_creator", "optimizer_creator", and
|
||||
"loss_creator".
|
||||
dataloader_config (dict): Configuration values to be passed into
|
||||
the ``torch.utils.data.DataLoader`` object that wraps
|
||||
the dataset on each parallel worker for both training
|
||||
and validation. Note that if ``num_replicas``
|
||||
is greater than 1, ``shuffle`` and ``sampler`` will be
|
||||
automatically set. See the available arguments
|
||||
here https://pytorch.org/docs/stable/data.html.
|
||||
num_replicas (int): the number of workers used in distributed
|
||||
training.
|
||||
use_gpu (bool): Sets resource allocation for workers to 1 GPU
|
||||
if true, and automatically moves both the model and optimizer
|
||||
to the available CUDA device.
|
||||
batch_size (int): Total batch size for each minibatch. This
|
||||
value is divided among all workers and rounded.
|
||||
backend (string): backend used by distributed PyTorch. Currently
|
||||
support "nccl", "gloo", and "auto". If "auto", RaySGD will
|
||||
automatically use "nccl" if `use_gpu` is True, and "gloo"
|
||||
otherwise.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@@ -36,39 +115,12 @@ class PyTorchTrainer:
|
||||
validation_function=None,
|
||||
initialization_hook=None,
|
||||
config=None,
|
||||
dataloader_config=None,
|
||||
num_replicas=1,
|
||||
use_gpu=False,
|
||||
batch_size=16,
|
||||
backend="auto"):
|
||||
"""Sets up the PyTorch trainer.
|
||||
|
||||
Args:
|
||||
model_creator (dict -> torch.nn.Module): creates the model
|
||||
using the config.
|
||||
data_creator (int, dict -> DataLoader, DataLoader): Function that
|
||||
takes in (batch_size, config) and returns two Torch DataLoader
|
||||
objects.
|
||||
optimizer_creator (torch.nn.Module, dict -> optimizer):
|
||||
creates the loss and optimizer using the model and the config.
|
||||
loss_creator (dict -> loss): Creates the loss function/criterion
|
||||
using the config.
|
||||
train_function: Trains a model for a epoch. This takes in (
|
||||
model, train_dataloader, criterion, optimizer, config), and
|
||||
returns a dict of training stats.
|
||||
validation_function: Runs validation. This takes in (
|
||||
model, val_dataloader, criterion, config) and returns a dict of
|
||||
validation stats.
|
||||
config (dict): configuration passed to "model_creator",
|
||||
"data_creator", "optimizer_creator", and "loss_creator".
|
||||
num_replicas (int): the number of workers used in distributed
|
||||
training.
|
||||
use_gpu (bool): Sets resource allocation for workers to 1 GPU
|
||||
if true.
|
||||
batch_size (int): batch size for an update.
|
||||
backend (string): backend used by distributed PyTorch.
|
||||
"""
|
||||
# TODO: add support for mixed precision
|
||||
# TODO: add support for callbacks
|
||||
if num_replicas > 1 and not dist.is_available():
|
||||
raise ValueError(
|
||||
("Distributed PyTorch is not supported on macOS. "
|
||||
@@ -84,6 +136,7 @@ class PyTorchTrainer:
|
||||
self.validation_function = validation_function
|
||||
self.initialization_hook = initialization_hook
|
||||
self.config = {} if config is None else config
|
||||
self.dataloader_config = dataloader_config
|
||||
self.optimizer_timer = utils.TimerStat(window_size=1)
|
||||
|
||||
if backend == "auto":
|
||||
@@ -115,6 +168,7 @@ class PyTorchTrainer:
|
||||
train_function=self.train_function,
|
||||
validation_function=self.validation_function,
|
||||
config=self.config,
|
||||
dataloader_config=self.dataloader_config,
|
||||
batch_size=self.batch_size)
|
||||
]
|
||||
if self.initialization_hook:
|
||||
@@ -148,6 +202,7 @@ class PyTorchTrainer:
|
||||
train_function=self.train_function,
|
||||
validation_function=self.validation_function,
|
||||
config=self.config,
|
||||
dataloader_config=self.dataloader_config,
|
||||
batch_size=batch_size_per_replica)
|
||||
for i in range(num_replicas)
|
||||
]
|
||||
@@ -274,11 +329,12 @@ class PyTorchTrainer:
|
||||
|
||||
def shutdown(self, force=False):
|
||||
"""Shuts down workers and releases resources."""
|
||||
for worker in self.workers:
|
||||
if not force:
|
||||
worker.shutdown.remote()
|
||||
worker.__ray_terminate__.remote()
|
||||
else:
|
||||
if not force:
|
||||
cleanup = [worker.shutdown.remote() for worker in self.workers]
|
||||
ray.get(cleanup)
|
||||
[worker.__ray_terminate__.remote() for worker in self.workers]
|
||||
else:
|
||||
for worker in self.workers:
|
||||
logger.warning("Killing worker {}.".format(worker))
|
||||
worker.__ray_kill__()
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ from ray.experimental.sgd.pytorch import PyTorchTrainer, PyTorchTrainable
|
||||
from ray.experimental.sgd.pytorch.utils import train
|
||||
from ray.experimental.sgd.utils import check_for_failure
|
||||
|
||||
from ray.experimental.sgd.examples.train_example import (
|
||||
from ray.experimental.sgd.pytorch.examples.train_example import (
|
||||
model_creator, optimizer_creator, data_creator, LinearDataset)
|
||||
|
||||
|
||||
@@ -98,6 +98,8 @@ def test_multi_model(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
for k in model1_state_dict:
|
||||
assert torch.equal(model1_state_dict[k], model2_state_dict[k])
|
||||
|
||||
trainer2.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_replicas", [1, 2]
|
||||
if dist.is_available() else [1])
|
||||
@@ -175,11 +177,8 @@ def test_fail_with_recover(ray_start_2_cpus): # noqa: F811
|
||||
if not dist.is_available():
|
||||
return
|
||||
|
||||
def single_loader(batch_size, config):
|
||||
train_dataset = LinearDataset(2, 5, size=1000000)
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset, batch_size=batch_size)
|
||||
return train_loader
|
||||
def single_loader(config):
|
||||
return LinearDataset(2, 5, size=1000000)
|
||||
|
||||
def step_with_fail(self):
|
||||
worker_stats = [w.step.remote() for w in self.workers]
|
||||
@@ -206,11 +205,8 @@ def test_resize(ray_start_2_cpus): # noqa: F811
|
||||
if not dist.is_available():
|
||||
return
|
||||
|
||||
def single_loader(batch_size, config):
|
||||
train_dataset = LinearDataset(2, 5, size=1000000)
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset, batch_size=batch_size)
|
||||
return train_loader
|
||||
def single_loader(config):
|
||||
return LinearDataset(2, 5, size=1000000)
|
||||
|
||||
def step_with_fail(self):
|
||||
worker_stats = [w.step.remote() for w in self.workers]
|
||||
@@ -243,11 +239,8 @@ def test_fail_twice(ray_start_2_cpus): # noqa: F811
|
||||
if not dist.is_available():
|
||||
return
|
||||
|
||||
def single_loader(batch_size, config):
|
||||
train_dataset = LinearDataset(2, 5, size=1000000)
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset, batch_size=batch_size)
|
||||
return train_loader
|
||||
def single_loader(config):
|
||||
return LinearDataset(2, 5, size=1000000)
|
||||
|
||||
def step_with_fail(self):
|
||||
worker_stats = [w.step.remote() for w in self.workers]
|
||||
|
||||
@@ -36,18 +36,12 @@ def loss_creator(config):
|
||||
return nn.MSELoss()
|
||||
|
||||
|
||||
def single_loader(batch_size, config):
|
||||
train_dataset = LinearDataset(2, 5)
|
||||
train_loader = torch.utils.data.DataLoader(train_dataset)
|
||||
return train_loader
|
||||
def single_loader(config):
|
||||
return LinearDataset(2, 5)
|
||||
|
||||
|
||||
def create_dataloaders(batch_size, config):
|
||||
train_dataset = LinearDataset(2, 5)
|
||||
validation_dataset = LinearDataset(2, 5, size=400)
|
||||
train_loader = torch.utils.data.DataLoader(train_dataset)
|
||||
validation_loader = torch.utils.data.DataLoader(validation_dataset)
|
||||
return train_loader, validation_loader
|
||||
def create_dataloaders(config):
|
||||
return LinearDataset(2, 5), LinearDataset(2, 5, size=400)
|
||||
|
||||
|
||||
class TestPyTorchRunner(unittest.TestCase):
|
||||
@@ -109,12 +103,9 @@ class TestPyTorchRunner(unittest.TestCase):
|
||||
self.assertNotEqual(runner2.given_optimizers, runner2.optimizers)
|
||||
|
||||
def testMultiLoaders(self):
|
||||
def three_data_loader(batch_size, config):
|
||||
train_dataset = LinearDataset(2, 5)
|
||||
validation_dataset = LinearDataset(2, 5, size=400)
|
||||
train_loader = torch.utils.data.DataLoader(train_dataset)
|
||||
validation_loader = torch.utils.data.DataLoader(validation_dataset)
|
||||
return train_loader, validation_loader, validation_loader
|
||||
def three_data_loader(config):
|
||||
return (LinearDataset(2, 5), LinearDataset(2, 5, size=400),
|
||||
LinearDataset(2, 5, size=400))
|
||||
|
||||
runner = PyTorchRunner(model_creator, three_data_loader,
|
||||
optimizer_creator, loss_creator)
|
||||
@@ -134,6 +125,15 @@ class TestPyTorchRunner(unittest.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
runner.validate()
|
||||
|
||||
def testNativeLoss(self):
|
||||
runner = PyTorchRunner(
|
||||
model_creator,
|
||||
single_loader,
|
||||
optimizer_creator,
|
||||
loss_creator=nn.MSELoss)
|
||||
runner.setup()
|
||||
runner.step()
|
||||
|
||||
def testMultiModel(self):
|
||||
def multi_model_creator(config):
|
||||
return nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1)
|
||||
|
||||
@@ -8,7 +8,7 @@ from ray import tune
|
||||
from ray.tests.conftest import ray_start_2_cpus # noqa: F401
|
||||
from ray.experimental.sgd.tf import TFTrainer, TFTrainable
|
||||
|
||||
from ray.experimental.sgd.examples.tensorflow_train_example import (
|
||||
from ray.experimental.sgd.tf.examples.tensorflow_train_example import (
|
||||
simple_model, simple_dataset)
|
||||
|
||||
SIMPLE_CONFIG = {
|
||||
|
||||
Reference in New Issue
Block a user