mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 16:15:55 +08:00
910 lines
35 KiB
Python
910 lines
35 KiB
Python
import inspect
|
|
import logging
|
|
import os
|
|
import tempfile
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from filelock import FileLock
|
|
|
|
from ray.util.sgd.utils import (TimerCollection, AverageMeterCollection,
|
|
NUM_SAMPLES)
|
|
from ray.util.sgd.torch.constants import (
|
|
SCHEDULER_STEP_EPOCH,
|
|
NUM_STEPS,
|
|
SCHEDULER_STEP_BATCH,
|
|
)
|
|
from torch.nn.parallel import DistributedDataParallel
|
|
from torch.utils.data import DistributedSampler, DataLoader, IterableDataset
|
|
|
|
logger = logging.getLogger(__name__)

try:
    from collections.abc import Iterable
except ImportError:
    # Fallback for Python versions that lack ``collections.abc``.
    from collections import Iterable

# Optional dependency: Apex supplies mixed precision (fp16) support.
amp = None
try:
    from apex import amp
except ImportError:
    # Without Apex, mixed precision cannot be enabled. Only a debug message
    # is emitted here; user-facing logging is left to the torch_runner,
    # which is where amp gets initialized.
    logger.debug("apex is not installed.")

# Optional dependency: tqdm supplies training progress bars.
tqdm = None
try:
    from tqdm import tqdm
except ImportError:
    pass
|
|
|
|
|
|
def _is_multiple(component):
|
|
"""Checks if a component (optimizer, model, etc) is not singular."""
|
|
return isinstance(component, Iterable) and len(component) > 1
|
|
|
|
|
|
class TrainingOperator:
    """Abstract class to define training and validation state and logic.

    You must subclass this class and override the ``setup`` method to define
    your training components such as the model, optimizer, data, loss,
    and scheduler. When you pass this class to ``TorchTrainer``, a copy of
    this class will be made on each worker.

    .. code-block:: python

        class MyTrainingOperator(TrainingOperator):

            def setup(self, config):
                model = nn.Linear(1, 1)
                optimizer = torch.optim.SGD(
                    model.parameters(), lr=config.get("lr", 1e-4))
                loss = torch.nn.MSELoss()

                batch_size = config["batch_size"]
                train_data, val_data = LinearDataset(2, 5), LinearDataset(2, 5)
                train_loader = DataLoader(train_data, batch_size=batch_size)
                val_loader = DataLoader(val_data, batch_size=batch_size)

                # ``register`` returns one value per registered component,
                # so passing a criterion yields (model, optimizer, criterion).
                self.model, self.optimizer, self.criterion = self.register(
                    models=model,
                    optimizers=optimizer,
                    criterion=loss)

                self.register_data(
                    train_loader=train_loader,
                    validation_loader=val_loader)


        trainer = TorchTrainer(
            training_operator_cls=MyTrainingOperator,
            config={"batch_size": 32},
            use_gpu=True
        )
        for i in range(4):
            trainer.train()

    This class provides default implementations for training and validation.
    Set ``self.model``, ``self.optimizer``, and
    ``self.criterion`` to leverage the default training and validation loops.
    If ``self.scheduler`` is set, it will only be called at a batch or epoch
    frequency, depending on the user parameter. Set
    ``scheduler_step_freq`` in ``TorchTrainer`` to either "batch" or "epoch"
    to increment the scheduler correctly during training. If using a
    learning rate scheduler that depends on validation loss, you can use
    ``trainer.update_scheduler``.

    If you want to provide custom training and validation loops, you can do
    so using this class as well. There are two granularities that
    you can provide customization: per epoch or per batch.
    You do not need to override both.

    .. image:: raysgd-custom.jpg
        :scale: 80%
        :align: center

    If you are using multiple models, optimizers, or schedulers, you must
    implement custom training and validation.

    Raises:
        ValueError
            You are expected to either set ``self.model``,
            ``self.optimizer``, and ``self.criterion`` instance attributes in
            setup or implement custom training & validation.
    """

    def __init__(self,
                 config,
                 world_rank,
                 device_ids=None,
                 use_gpu=False,
                 use_fp16=False,
                 use_tqdm=False,
                 apex_args=None,
                 wrap_ddp=False,
                 wrap_distributed_sampler=False,
                 add_dist_sampler=False,
                 scheduler_step_freq=None):
        # You are not expected to override this method.
        self._world_rank = world_rank
        self._config = config
        self._use_fp16 = use_fp16
        self._device_ids = device_ids
        self._use_gpu = use_gpu and torch.cuda.is_available()
        self._device = torch.device("cuda" if self._use_gpu else "cpu")
        if tqdm is None and use_tqdm:
            raise ValueError("tqdm must be installed to use tqdm in training.")
        self._use_tqdm = use_tqdm
        self.global_step = 0
        self._apex_args = apex_args if apex_args else {}
        self._wrap_ddp = wrap_ddp
        self._wrap_distributed_sampler = wrap_distributed_sampler
        self._add_dist_sampler = add_dist_sampler
        self._scheduler_step_freq = scheduler_step_freq

        self.timers = TimerCollection()
        self.setup(config)

    def _set_timers(self, timers):
        """Passes in the timers from the Runner."""
        self.timers = timers

    def setup(self, config):
        """Override this method to implement operator setup.

        You should call self.register and self.register_data here to
        register training components and data loaders with Ray SGD.

        Args:
            config (dict): Custom configuration value to be passed to
                all creator and operator constructors. Same as ``self.config``.
        """
        raise NotImplementedError

    def register(self, *, models, optimizers, criterion=None, schedulers=None):
        """Registers parameters with Ray SGD and sets up training components.

        By calling this method to register your models, optimizers,
        criterion, and schedulers, Ray SGD will automatically handle
        necessary setup such as GPU/devices, Distributed Data Parallel, and
        Fp16. The registered components are returned and should be set as
        instance attributes to access during training/validation.

        If more than one model, optimizer, or scheduler is passed in,
        you should implement your own custom training loop.

        .. code-block:: python

            class MyTrainingOperator(TrainingOperator):
                def setup(self, config):
                    model = ...
                    optimizer = ...
                    train_loader = ...
                    val_loader = ...
                    loss = ...

                    self.model, self.optimizer, self.criterion = self.register(
                        models=model, optimizers=optimizer, criterion=loss)

                    # At this point DDP, Cuda, and Fp16
                    # are set up for all our components. We then use
                    # self.model, self.optimizer, etc. in our training loop.

                    self.register_data(train_loader=train_loader,
                        validation_loader=val_loader)


        Args:
            models (torch.nn.Module or Iterable[nn.Module]): Pytorch model or
                multiple Pytorch models to use for training. If
                `use_gpu=True` is passed into ``TorchTrainer``, and Cuda is
                available, models will automatically be placed on GPU.
                If ``wrap_ddp=True`` is passed into ``TorchTrainer``,
                models will be wrapped in DDP. If wrap_ddp is False,
                you should handle DDP for your models in setup.
            optimizers (torch.optim.Optimizer or Iterable[
                torch.optim.Optimizer]): Pytorch optimizer or multiple Pytorch
                optimizers to use for training.
            criterion (Callable, optional): Function to return loss
                metric given features and target. If not provided,
                must implement a custom training loop.
            schedulers (torch.optim.lr_scheduler or Iterable[
                torch.optim.lr_scheduler], optional): A learning rate
                scheduler or multiple learning rate schedulers.

        Returns:
            Tuple of model, optimizer, criterion if not None, and scheduler
            if not None. Multiple models/optimizers/schedulers are returned
            as a sequence in the corresponding position.

        Raises:
            ValueError: If schedulers are registered while
                ``scheduler_step_freq`` is None.
        """
        return_vals = []
        logger.debug("Registering models.")
        self._original_models = models
        if not isinstance(self._original_models, Iterable):
            self._original_models = [self._original_models]
        assert all(
            isinstance(model, nn.Module) for model in self._original_models), (
                f"All models must be PyTorch models: {self._original_models}.")
        if self.use_gpu and torch.cuda.is_available():
            self._original_models = [
                model.cuda() for model in self._original_models
            ]

        logger.debug("Registering optimizers.")
        self._optimizers = optimizers
        if not isinstance(self._optimizers, Iterable):
            self._optimizers = [self._optimizers]

        if schedulers:
            logger.debug("Registering scheduler.")
            self._schedulers = schedulers
            if not isinstance(self._schedulers, Iterable):
                self._schedulers = [self._schedulers]
        else:
            self._schedulers = None

        if criterion:
            logger.debug("Registering loss.")
            self._criterion = criterion
            if self.use_gpu and torch.cuda.is_available():
                if hasattr(self._criterion, "cuda"):
                    self._criterion = self._criterion.cuda()
        else:
            self._criterion = None

        if self.use_fp16 and amp:
            logger.debug("Setting up Apex.")
            self._original_models, self._optimizers = amp.initialize(
                self._original_models, self._optimizers, **self._apex_args)
            self._amp = amp

        if self._wrap_ddp:
            # Use the module logger (not the root ``logging`` module) so this
            # message honours the same configuration as the rest of the file.
            logger.debug("Setting up DDP for models.")
            self._models = [
                DistributedDataParallel(model, device_ids=self.device_ids)
                for model in self._original_models
            ]
        else:
            self._models = self._original_models

        if len(self._models) == 1:
            return_vals.append(self._models[0])
        else:
            return_vals.append(self._models)

        if len(self._optimizers) == 1:
            return_vals.append(self._optimizers[0])
        else:
            return_vals.append(self._optimizers)

        if self._criterion is not None:
            return_vals.append(self._criterion)

        if self._schedulers is not None:
            if self.scheduler_step_freq is None:
                raise ValueError("scheduler_step_freq passed into "
                                 "TorchTrainer cannot be None if you "
                                 "are registering schedulers. Set this to "
                                 "'manual' if you will be manually stepping "
                                 "the schedulers.")
            if len(self._schedulers) == 1:
                return_vals.append(self._schedulers[0])
            else:
                return_vals.append(self._schedulers)

        return tuple(return_vals)

    def register_data(self, *, train_loader=None, validation_loader=None):
        """Registers data loaders with Ray SGD.

        Calling this method will automatically setup Distributed Sampler for
        these data loaders if add_dist_sampler=True is passed into the
        TorchTrainer. This method does not return the wrapped data loaders.
        You should use the iterators passed into train_epoch and validate
        instead.

        .. code-block:: python

            class MyTrainingOperator(TrainingOperator):
                def setup(self, config):
                    model = ...
                    optimizer = ...
                    train_loader = ...
                    val_loader = ...
                    loss = ...

                    self.model, self.optimizer, self.criterion = self.register(
                        models=model, optimizers=optimizer, criterion=loss)

                    self.register_data(train_loader=train_loader,
                        validation_loader=val_loader)

                    # At this point the data loaders are registered with
                    # Ray SGD and are wrapped with Distributed Samplers if
                    # applicable.


                def train_epoch(self, iterator, info):
                    # If providing custom training or validation methods,
                    # the registered data loaders are passed in through the
                    # iterator parameter.
                    ...

        Args:
            train_loader (Iterator): An iterator for training
                data. If None is explicitly passed in, a Ray SGD Dataset
                must be passed in through TorchTrainer.train. Ray SGD will
                automatically use a Distributed Sampler if TorchTrainer(...,
                add_dist_sampler=True).
            validation_loader (Iterator): An iterator for validation
                data. Ray SGD will automatically use a Distributed Sampler
                if TorchTrainer(..., add_dist_sampler=True).
        """

        logger.debug("Registering data loaders..")
        self._train_loader = train_loader
        self._validation_loader = validation_loader

        if self._wrap_distributed_sampler:
            # Module logger rather than the root ``logging`` module.
            logger.debug("Wrapping data loaders with DistributedSampler.")

            def with_sampler(loader):
                # Automatically set the DistributedSampler
                data_loader_args = {
                    "dataset": loader.dataset,
                    "batch_size": loader.batch_size,
                    "shuffle": False,
                    "num_workers": loader.num_workers,
                    "collate_fn": loader.collate_fn,
                    "pin_memory": loader.pin_memory,
                    "drop_last": loader.drop_last,
                    "timeout": loader.timeout,
                    "worker_init_fn": loader.worker_init_fn,
                    "sampler": DistributedSampler(loader.dataset)
                }
                return DataLoader(**data_loader_args)

            def should_wrap_dataloader(loader):
                # Iterable-style datasets cannot be combined with a sampler.
                return (isinstance(loader, DataLoader)
                        and not isinstance(loader.dataset, IterableDataset))

            if should_wrap_dataloader(self._train_loader):
                if self._add_dist_sampler:
                    self._train_loader = with_sampler(self._train_loader)

            if self._validation_loader is not None and should_wrap_dataloader(
                    self._validation_loader):
                if self._add_dist_sampler:
                    self._validation_loader = with_sampler(
                        self._validation_loader)

    def train_epoch(self, iterator, info):
        """Runs one standard training pass over the training dataloader.

        By default, this method will iterate over the given iterator and
        call ``self.train_batch`` over each batch. If ``scheduler_step_freq``
        is set, this default method will also step the scheduler accordingly.

        You do not need to call ``train_batch`` in this method if you plan
        to implement a custom optimization/training routine here.

        You may find ``ray.util.sgd.utils.AverageMeterCollection`` useful
        when overriding this method. See example below:

        .. code-block:: python

            def train_epoch(self, ...):
                meter_collection = AverageMeterCollection()
                self.model.train()
                for batch in iterator:
                    # do some processing
                    metrics = {"metric_1": 1, "metric_2": 3} # dict of metrics

                    # This keeps track of all metrics across multiple batches
                    meter_collection.update(metrics, n=len(batch))

                # Returns stats of the meters.
                stats = meter_collection.summary()
                return stats


        Args:
            iterator (iter): Iterator over the training data for the entire
                epoch. This iterator is expected to be entirely consumed.
            info (dict): Dictionary for information to be used for custom
                training operations. None is treated as an empty dict.

        Returns:
            A dict of metrics from training.
        """
        if not hasattr(self, "model"):
            raise RuntimeError("Either set self.model in setup function or "
                               "override this method to implement a custom "
                               "training loop.")
        # ``info`` is merged into every per-batch info dict, so a missing
        # dict is treated as empty instead of failing in ``dict.update``.
        if info is None:
            info = {}
        model = self.model
        scheduler = None
        if hasattr(self, "scheduler"):
            scheduler = self.scheduler

        progress_bar = None
        if self.use_tqdm and self.world_rank == 0:
            desc = ""
            if "epoch_idx" in info:
                if "num_epochs" in info:
                    desc = f"{info['epoch_idx'] + 1}/{info['num_epochs']}e"
                else:
                    desc = f"{info['epoch_idx'] + 1}e"

            # TODO: Implement len for Dataset?
            # NUM_STEPS may be absent; fall back to the iterator's length.
            total = info.get(NUM_STEPS)
            if total is None and hasattr(iterator, "__len__"):
                total = len(iterator)

            progress_bar = tqdm(
                total=total, desc=desc, unit="batch", leave=False)

        metric_meters = AverageMeterCollection()

        model.train()
        try:
            for batch_idx, batch in enumerate(iterator):
                batch_info = {
                    "batch_idx": batch_idx,
                    "global_step": self.global_step
                }
                batch_info.update(info)
                metrics = self.train_batch(batch, batch_info=batch_info)

                if progress_bar is not None:
                    progress_bar.n = batch_idx + 1
                    postfix = {}
                    if "train_loss" in metrics:
                        postfix.update(loss=metrics["train_loss"])
                    progress_bar.set_postfix(postfix)

                if (scheduler and
                        self.scheduler_step_freq == SCHEDULER_STEP_BATCH):
                    scheduler.step()

                metric_meters.update(metrics, n=metrics.pop(NUM_SAMPLES, 1))
                self.global_step += 1
        finally:
            # Close the bar (clearing it, since leave=False) even when a
            # batch raises, so bars do not leak across epochs.
            if progress_bar is not None:
                progress_bar.close()

        if scheduler and self.scheduler_step_freq == SCHEDULER_STEP_EPOCH:
            scheduler.step()

        return metric_meters.summary()

    def train_batch(self, batch, batch_info):
        """Computes loss and updates the model over one batch.

        This method is responsible for computing the loss and gradient and
        updating the model.

        By default, this method implementation assumes that batches
        are in (\\*features, labels) format. So we also support multiple inputs
        model. If using amp/fp16 training, it will also scale the loss
        automatically.

        You can provide custom loss metrics and training operations if you
        override this method.

        You do not need to override this method if you plan to
        override ``train_epoch``.

        Args:
            batch: One item of the training iterator.
            batch_info (dict): Information dict passed in from ``train_epoch``.

        Returns:
            A dictionary of metrics.
                By default, this dictionary contains "train_loss" and
                "num_samples". "num_samples" corresponds to number of
                datapoints in the batch. However, you can provide any number
                of other values. Consider returning "num_samples" in the
                metrics because by default, ``train_epoch`` uses
                "num_samples" to calculate averages.

        """
        if not hasattr(self, "model"):
            raise RuntimeError("Either set self.model in setup function or "
                               "override this method to implement a custom "
                               "training loop.")
        if not hasattr(self, "optimizer"):
            raise RuntimeError("Either set self.optimizer in setup function "
                               "or override this method to implement a custom "
                               "training loop.")
        if not hasattr(self, "criterion"):
            raise RuntimeError("Either set self.criterion in setup function "
                               "or override this method to implement a custom "
                               "training loop.")
        model = self.model
        optimizer = self.optimizer
        criterion = self.criterion
        # unpack features into list to support multiple inputs model
        *features, target = batch
        # Create non_blocking tensors for distributed training
        if self.use_gpu:
            features = [
                feature.cuda(non_blocking=True) for feature in features
            ]
            target = target.cuda(non_blocking=True)

        # Compute output.
        with self.timers.record("fwd"):
            output = model(*features)
            loss = criterion(output, target)

        # Compute gradients in a backward pass.
        with self.timers.record("grad"):
            optimizer.zero_grad()
            # Loss scaling needs Apex. ``register`` only runs
            # ``amp.initialize`` when Apex imported, so fall back to a plain
            # backward pass instead of calling a method on ``None``.
            if self.use_fp16 and amp is not None:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

        # Call step of optimizer to update model params.
        with self.timers.record("apply"):
            optimizer.step()

        return {"train_loss": loss.item(), NUM_SAMPLES: features[0].size(0)}

    def validate(self, val_iterator, info):
        """Runs one standard validation pass over the val_iterator.

        This will call ``model.eval()`` and ``torch.no_grad`` when iterating
        over the validation dataloader.

        You also do not need to call ``validate_batch`` if overriding this
        method.

        Args:
            val_iterator (iter): Iterable constructed from the
                validation dataloader.
            info: (dict): Dictionary for information to be used for custom
                validation operations. None is treated as an empty dict.

        Returns:
            A dict of metrics from the evaluation.
                By default, returns "val_accuracy" and "val_loss": the
                per-batch values from ``validate_batch`` averaged across
                batches, using each batch's "num_samples" as its weight.
        """
        if not hasattr(self, "model"):
            raise RuntimeError("Either set self.model in setup function or "
                               "override this method to implement a custom "
                               "validation loop.")
        # Same handling as ``train_epoch``: merge an empty dict if None.
        if info is None:
            info = {}
        model = self.model
        metric_meters = AverageMeterCollection()

        # switch to evaluate mode
        model.eval()
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_iterator):
                batch_info = {"batch_idx": batch_idx}
                batch_info.update(info)
                metrics = self.validate_batch(batch, batch_info)
                metric_meters.update(metrics, n=metrics.pop(NUM_SAMPLES, 1))

        return metric_meters.summary()

    def validate_batch(self, batch, batch_info):
        """Calculates the loss and accuracy over a given batch.

        You can override this method to provide arbitrary metrics.

        Same as ``train_batch``, this method implementation assumes that
        batches are in (\\*features, labels) format by default. So we also
        support multiple inputs model.

        Args:
            batch: One item of the validation iterator.
            batch_info (dict): Contains information per batch from
                ``validate()``.

        Returns:
            A dict of metrics.
                By default, returns "val_loss", "val_accuracy", and
                "num_samples". When overriding, consider returning
                "num_samples" in the metrics because
                by default, ``validate`` uses "num_samples" to
                calculate averages.
        """
        if not hasattr(self, "model"):
            raise RuntimeError("Either set self.model in setup function or "
                               "override this method to implement a custom "
                               "training loop.")
        if not hasattr(self, "criterion"):
            raise RuntimeError("Either set self.criterion in setup function "
                               "or override this method to implement a custom "
                               "training loop.")
        model = self.model
        criterion = self.criterion
        # unpack features into list to support multiple inputs model
        *features, target = batch
        if self.use_gpu:
            features = [
                feature.cuda(non_blocking=True) for feature in features
            ]
            target = target.cuda(non_blocking=True)

        # compute output

        with self.timers.record("eval_fwd"):
            output = model(*features)
            loss = criterion(output, target)
            # Predicted class is the argmax over dimension 1 of the output.
            _, predicted = torch.max(output.data, 1)

        num_correct = (predicted == target).sum().item()
        num_samples = target.size(0)
        return {
            "val_loss": loss.item(),
            "val_accuracy": num_correct / num_samples,
            NUM_SAMPLES: num_samples
        }

    def state_dict(self):
        """Override this to return a representation of the operator state.
        Any argument passed into self.register and self.register_data will
        automatically be saved.
        Use this method to save any additional state. If your TorchTrainer
        is on a CPU-only machine, make sure this method converts all state
        to be CPU-compatible.

        Returns:
            dict: The state dict of the operator."""
        pass

    def load_state_dict(self, state_dict):
        """Override this to load the representation of the operator state.
        Anything passed into self.register and self.register_data will
        automatically be loaded. Use this method to load any additional state.
        Args:
            state_dict (dict): State dict as returned by the operator. """
        pass

    @classmethod
    def from_creators(cls,
                      model_creator,
                      optimizer_creator,
                      data_creator=None,
                      loss_creator=None,
                      scheduler_creator=None,
                      serialize_data_creation=True):
        """A utility method to create a custom TrainingOperator class from
        creator functions. This is useful for backwards compatibility with
        previous versions of Ray. To provide custom training and validation,
        you should subclass the class that is returned by this method instead
        of ``TrainingOperator``.

        Args:
            model_creator (dict -> Model(s)): Constructor function that takes
                in config and returns the model(s) to be optimized. These
                must be ``torch.nn.Module`` objects. If multiple models are
                returned, a ``training_operator_cls`` must be specified.
                You do not need to handle GPU/devices in this function;
                RaySGD will do that under the hood.
            data_creator (dict -> Iterable(s)): Constructor function
                that takes in the passed config and returns one or
                two Iterable objects. Note that even though two Iterable
                objects can be returned, only one will be used for training,
                and the other will be used for validation. If not provided,
                you must pass in a Dataset to ``TorchTrainer.train``.
            optimizer_creator ((models, dict) -> optimizers): Constructor
                function that takes in the return values from
                ``model_creator`` and the passed config and returns one or
                more Torch optimizer objects. You do not need to handle
                GPU/devices in this function; ``RaySGD`` will do that for you.
            loss_creator (torch.nn.*Loss class | dict -> loss): A constructor
                function for the training loss. This can be either a function
                that takes in the provided config for customization or a
                subclass of ``torch.nn.modules.loss._Loss``, which is most
                Pytorch loss classes. For example,
                ``loss_creator=torch.nn.BCELoss``. If not provided, you must
                provide a custom TrainingOperator.
            scheduler_creator ((optimizers, dict) -> scheduler):
                A constructor function for the torch scheduler. This is
                a function that takes in the generated optimizers (from
                ``optimizer_creator``) provided config for customization.
                Be sure to set ``scheduler_step_freq`` to increment the
                scheduler correctly.
            serialize_data_creation (bool): A filelock will be used
                to ensure no race conditions in data downloading among
                different workers on the same node (using the local file
                system). Defaults to True.

        Returns:
            A TrainingOperator class with a ``setup`` method that utilizes
            the passed in creator functions.

        Raises:
            ValueError: If ``model_creator`` or ``optimizer_creator`` is not
                callable.
        """

        if not (callable(model_creator) and callable(optimizer_creator)):
            raise ValueError(
                "Must provide a callable model_creator and optimizer_creator.")

        class CustomCreatorOperator(CreatorOperator):
            _model_creator = model_creator
            _optimizer_creator = optimizer_creator
            _data_creator = data_creator
            _loss_creator = loss_creator
            _scheduler_creator = scheduler_creator
            _serialize_data_creation = serialize_data_creation

        return CustomCreatorOperator

    @property
    def device(self):
        """torch.device: The appropriate torch device, at your convenience."""
        return self._device

    @property
    def config(self):
        """dict: Provided into TorchTrainer."""
        return self._config

    @property
    def world_rank(self):
        """int: The rank of the parent runner. Always 0 if not distributed."""
        return self._world_rank

    @property
    def use_gpu(self):
        """Returns True if cuda is available and use_gpu is True."""
        return self._use_gpu

    @property
    def use_fp16(self):
        """bool: Whether the model and optimizer have been FP16 enabled."""
        return self._use_fp16

    @property
    def use_tqdm(self):
        """bool: Whether tqdm progress bars are enabled."""
        return self._use_tqdm

    @property
    def device_ids(self):
        """List[int]: Device IDs for the model.

        This is useful for using batch norm with DistributedDataParallel.
        """
        return self._device_ids

    @property
    def scheduler_step_freq(self):
        """Optional[str]: The ``scheduler_step_freq`` passed into
        ``TorchTrainer``

        This is useful to determine when to call scheduler.step.
        """
        return self._scheduler_step_freq
|
|
|
|
|
class CreatorOperator(TrainingOperator):
    """A subclass of TrainingOperator specifically for defining training
    state using creator functions.

    Subclasses provide the class attributes ``_model_creator``,
    ``_optimizer_creator``, ``_data_creator``, ``_loss_creator``,
    ``_scheduler_creator`` and ``_serialize_data_creation``
    (``TrainingOperator.from_creators`` generates such a subclass).
    """

    def _validate_loaders(self, loaders):
        """Split the ``data_creator`` return value into (train, validation).

        Args:
            loaders: A single loader, or a tuple/list of one or two loaders.

        Returns:
            Tuple ``(train_loader, validation_loader)``; the validation
            loader is None when only one loader was provided.

        Raises:
            AssertionError: If ``data_creator`` returned nothing.
            ValueError: If more than two loaders were returned.
        """
        # Avoid ``assert loaders``: ``bool()`` on a DataLoader falls back to
        # ``__len__``, which can raise TypeError for iterable-style loaders.
        assert loaders is not None, (
            "Loaders need to be returned in data_creator.")
        if isinstance(loaders, (tuple, list)):
            assert len(loaders) > 0, (
                "Loaders need to be returned in data_creator.")
            if len(loaders) == 1:
                # Unwrap the single loader rather than returning the
                # container itself as the training loader.
                return loaders[0], None
            elif len(loaders) == 2:
                train_loader, val_loader = loaders
                return train_loader, val_loader
            else:
                raise ValueError(
                    f"Number of loaders must be <= 2. Got {loaders}")
        # No great way of checking type otherwise
        return loaders, None

    def _initialize_dataloaders(self, config):
        """Call ``_data_creator`` (optionally under a file lock).

        Args:
            config (dict): Passed to ``_data_creator``.

        Returns:
            Tuple ``(train_loader, validation_loader)``.
        """
        logger.debug("Instantiating dataloaders.")
        loaders = None
        if self._serialize_data_creation:
            logger.debug("Serializing the dataloading process.")
            # A lock file in the system temp dir serializes data creation
            # among workers sharing this machine's file system.
            with FileLock(
                    os.path.join(tempfile.gettempdir(), ".raydata.lock")):
                loaders = self.__class__._data_creator(config)
        else:
            loaders = self.__class__._data_creator(config)
        train_loader, val_loader = self._validate_loaders(loaders)

        return train_loader, val_loader

    def setup(self, config):
        """Build and register all components from the creator functions.

        Sets ``self.models``/``self.model``, ``self.optimizers``/
        ``self.optimizer`` and, when created, ``self.criterion`` and
        ``self.schedulers``/``self.scheduler``. The singular attributes
        point at the first item when several were created.

        Args:
            config (dict): Passed to every creator function.
        """
        kwargs = {}
        logger.debug("Loading data.")
        train_loader = None
        validation_loader = None
        if self.__class__._data_creator and callable(
                self.__class__._data_creator):
            train_loader, validation_loader = self._initialize_dataloaders(
                config)

        logger.debug("Creating model")
        models = self.__class__._model_creator(config)

        kwargs["models"] = models

        logger.debug("Creating optimizer.")
        optimizers = self.__class__._optimizer_creator(models, config)

        kwargs["optimizers"] = optimizers

        if self.__class__._scheduler_creator:
            logger.debug("Creating scheduler.")
            schedulers = self.__class__._scheduler_creator(optimizers, config)
            kwargs["schedulers"] = schedulers

        if self.__class__._loss_creator:
            logger.debug("Creating loss.")
            if inspect.isclass(self.__class__._loss_creator) and issubclass(
                    self.__class__._loss_creator, torch.nn.modules.loss._Loss):
                criterion = self.__class__._loss_creator()
            else:
                criterion = self.__class__._loss_creator(config)
            kwargs["criterion"] = criterion

        state = self.register(**kwargs)
        self.models, self.optimizers = state[:2]
        # ``register`` returns several models/optimizers as a list (or as
        # the caller's own container), so accept lists as well as tuples.
        if isinstance(self.models, (list, tuple)):
            self.model = self.models[0]
        else:
            self.model = self.models

        if isinstance(self.optimizers, (list, tuple)):
            self.optimizer = self.optimizers[0]
        else:
            self.optimizer = self.optimizers

        # ``register`` appends a criterion only when ``self._criterion`` is
        # not None and schedulers only when ``self._schedulers`` is not
        # None. Consume the trailing values by that rule rather than by
        # tuple length, so a scheduler registered without a loss is not
        # mistaken for the criterion.
        extras = list(state[2:])
        if self._criterion is not None:
            self.criterion = extras.pop(0)
        if self._schedulers is not None:
            self.schedulers = extras.pop(0)
            if isinstance(self.schedulers, (list, tuple)):
                self.scheduler = self.schedulers[0]
            else:
                self.scheduler = self.schedulers

        self.register_data(
            train_loader=train_loader, validation_loader=validation_loader)
|
|
|
|
|
|
def get_test_operator(operator_cls):
    """Return a subclass of ``operator_cls`` with a stubbed ``train_epoch``.

    If ``config["custom_func"]`` is callable, the stub delegates to it as
    ``custom_func(self, iterator, info)``; otherwise it returns
    ``{"done": 1}``.
    """

    class _TestingOperator(operator_cls):
        def train_epoch(self, iterator, info):
            custom_func = self.config.get("custom_func")
            if not callable(custom_func):
                return {"done": 1}
            return custom_func(self, iterator, info)

    return _TestingOperator
|
|
|
|
|
|
def get_test_metrics_operator(operator_cls):
    """Return a subclass of ``operator_cls`` that injects scripted metrics.

    ``setup`` reads ``scores``, ``val_scores`` and ``key`` from the config.
    Each training (resp. validation) batch consumes the next value from
    ``scores`` (resp. ``val_scores``) and reports it, divided by that
    batch's ``NUM_SAMPLES``, under ``key``.
    """

    class _TestMetricsOperator(operator_cls):
        def setup(self, config):
            super().setup(config)
            # Copies, so consuming scores never mutates the caller's lists.
            self._train_scores = config["scores"].copy()
            self._val_scores = config["val_scores"].copy()
            self.key = config["key"]

        def train_batch(self, batch, batch_info=None):
            result = super().train_batch(batch, batch_info)
            sample_count = result[NUM_SAMPLES]
            result[self.key] = self._train_scores.pop(0) / sample_count
            return result

        def validate_batch(self, batch, batch_info=None):
            result = super().validate_batch(batch, batch_info)
            sample_count = result[NUM_SAMPLES]
            result[self.key] = self._val_scores.pop(0) / sample_count
            return result

    return _TestMetricsOperator
|