mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 10:00:00 +08:00
@@ -24,17 +24,16 @@ You can start a ``TorchTrainer`` with the following:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import distributed
|
||||
|
||||
import ray
|
||||
from ray.util.sgd import TorchTrainer
|
||||
from ray.util.sgd.examples.train_example import LinearDataset
|
||||
from ray.util.sgd.torch.examples.train_example import LinearDataset
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
def model_creator(config):
|
||||
return nn.Linear(1, 1)
|
||||
return torch.nn.Linear(1, 1)
|
||||
|
||||
|
||||
def optimizer_creator(model, config):
|
||||
@@ -42,21 +41,21 @@ You can start a ``TorchTrainer`` with the following:
|
||||
return torch.optim.SGD(model.parameters(), lr=1e-2)
|
||||
|
||||
|
||||
def data_creator(batch_size, config):
|
||||
"""Returns training dataloader, validation dataloader."""
|
||||
return LinearDataset(2, 5), LinearDataset(2, 5, size=400)
|
||||
def data_creator(config):
|
||||
train_loader = DataLoader(LinearDataset(2, 5), config["batch_size"])
|
||||
val_loader = DataLoader(LinearDataset(2, 5), config["batch_size"])
|
||||
return train_loader, val_loader
|
||||
|
||||
ray.init()
|
||||
|
||||
trainer1 = TorchTrainer(
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
loss_creator=nn.MSELoss,
|
||||
num_replicas=2,
|
||||
use_gpu=True,
|
||||
batch_size=512,
|
||||
backend="nccl")
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=torch.nn.MSELoss,
|
||||
num_workers=2,
|
||||
use_gpu=False,
|
||||
config={"batch_size": 64})
|
||||
|
||||
stats = trainer1.train()
|
||||
print(stats)
|
||||
|
||||
@@ -3,7 +3,7 @@ Distributed PyTorch
|
||||
|
||||
The RaySGD ``TorchTrainer`` simplifies distributed model training for PyTorch. The ``TorchTrainer`` is a wrapper around ``torch.distributed.launch`` with a Python API to easily incorporate distributed training into a larger Python application, as opposed to needing to wrap your training code in bash scripts.
|
||||
|
||||
Under the hood, ``TorchTrainer`` will create *replicas* of your model (controlled by ``num_replicas``), each of which is managed by a Ray actor.
|
||||
Under the hood, ``TorchTrainer`` will create *replicas* of your model (controlled by ``num_workers``), each of which is managed by a Ray actor.
|
||||
|
||||
.. image:: raysgd-actors.svg
|
||||
:align: center
|
||||
@@ -139,7 +139,7 @@ You can also set the number of workers and whether the workers will use GPUs:
|
||||
loss_creator=nn.MSELoss,
|
||||
scheduler_creator=scheduler_creator,
|
||||
config={"lr": 0.001},
|
||||
num_replicas=100,
|
||||
num_workers=100,
|
||||
use_gpu=True)
|
||||
|
||||
|
||||
@@ -287,7 +287,7 @@ Below is a partial example of a custom ``TrainingOperator`` that provides a ``tr
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=nn.BCELoss,
|
||||
training_operator_cls=GANOperator,
|
||||
num_replicas=num_replicas,
|
||||
num_workers=num_workers,
|
||||
config=config,
|
||||
use_gpu=True,
|
||||
batch_size=128)
|
||||
@@ -320,7 +320,7 @@ Use the ``initialization_hook`` parameter to initialize state on each worker pro
|
||||
loss_creator=nn.MSELoss,
|
||||
initialization_hook=initialization_hook,
|
||||
config={"lr": 0.001},
|
||||
num_replicas=100,
|
||||
num_workers=100,
|
||||
use_gpu=True)
|
||||
|
||||
Save and Load
|
||||
@@ -339,7 +339,7 @@ and ``trainer.load``, which wraps the relevant ``torch.save`` and ``torch.load``
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=nn.MSELoss,
|
||||
num_replicas=num_replicas)
|
||||
num_workers=num_workers)
|
||||
trainer_2.restore(checkpoint_path)
|
||||
|
||||
|
||||
@@ -366,7 +366,7 @@ You can enable mixed precision training for PyTorch with the ``use_fp16`` flag.
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=nn.MSELoss,
|
||||
num_replicas=4,
|
||||
num_workers=4,
|
||||
use_fp16=True
|
||||
)
|
||||
|
||||
@@ -383,7 +383,7 @@ To specify particular parameters for ``amp.initialize``, you can use the ``apex_
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=nn.MSELoss,
|
||||
num_replicas=4,
|
||||
num_workers=4,
|
||||
use_fp16=True,
|
||||
apex_args={
|
||||
opt_level="O3",
|
||||
@@ -417,7 +417,7 @@ After connecting, you can scale up the number of workers seamlessly across multi
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=nn.MSELoss,
|
||||
num_replicas=100
|
||||
num_workers=100
|
||||
)
|
||||
trainer.train()
|
||||
model = trainer.get_model()
|
||||
|
||||
Reference in New Issue
Block a user