[sgd] Readme fix (#7564)

* readme fix

* replicas
This commit is contained in:
Richard Liaw
2020-03-11 13:40:18 -07:00
committed by GitHub
parent b70f31339c
commit d046faeb9c
2 changed files with 25 additions and 26 deletions
+17 -18
View File
@@ -24,17 +24,16 @@ You can start a ``TorchTrainer`` with the following:
.. code-block:: python
import numpy as np
import torch
import torch.nn as nn
from torch import distributed
import ray
from ray.util.sgd import TorchTrainer
from ray.util.sgd.examples.train_example import LinearDataset
from ray.util.sgd.torch.examples.train_example import LinearDataset
import torch
from torch.utils.data import DataLoader
def model_creator(config):
return nn.Linear(1, 1)
return torch.nn.Linear(1, 1)
def optimizer_creator(model, config):
@@ -42,21 +41,21 @@ You can start a ``TorchTrainer`` with the following:
return torch.optim.SGD(model.parameters(), lr=1e-2)
def data_creator(batch_size, config):
"""Returns training dataloader, validation dataloader."""
return LinearDataset(2, 5), LinearDataset(2, 5, size=400)
def data_creator(config):
train_loader = DataLoader(LinearDataset(2, 5), config["batch_size"])
val_loader = DataLoader(LinearDataset(2, 5), config["batch_size"])
return train_loader, val_loader
ray.init()
trainer1 = TorchTrainer(
model_creator,
data_creator,
optimizer_creator,
loss_creator=nn.MSELoss,
num_replicas=2,
use_gpu=True,
batch_size=512,
backend="nccl")
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=torch.nn.MSELoss,
num_workers=2,
use_gpu=False,
config={"batch_size": 64})
stats = trainer1.train()
print(stats)
+8 -8
View File
@@ -3,7 +3,7 @@ Distributed PyTorch
The RaySGD ``TorchTrainer`` simplifies distributed model training for PyTorch. The ``TorchTrainer`` is a wrapper around ``torch.distributed.launch`` with a Python API to easily incorporate distributed training into a larger Python application, as opposed to needing to wrap your training code in bash scripts.
Under the hood, ``TorchTrainer`` will create *replicas* of your model (controlled by ``num_replicas``), each of which is managed by a Ray actor.
Under the hood, ``TorchTrainer`` will create *replicas* of your model (controlled by ``num_workers``), each of which is managed by a Ray actor.
.. image:: raysgd-actors.svg
:align: center
@@ -139,7 +139,7 @@ You can also set the number of workers and whether the workers will use GPUs:
loss_creator=nn.MSELoss,
scheduler_creator=scheduler_creator,
config={"lr": 0.001},
num_replicas=100,
num_workers=100,
use_gpu=True)
@@ -287,7 +287,7 @@ Below is a partial example of a custom ``TrainingOperator`` that provides a ``tr
optimizer_creator=optimizer_creator,
loss_creator=nn.BCELoss,
training_operator_cls=GANOperator,
num_replicas=num_replicas,
num_workers=num_workers,
config=config,
use_gpu=True,
batch_size=128)
@@ -320,7 +320,7 @@ Use the ``initialization_hook`` parameter to initialize state on each worker pro
loss_creator=nn.MSELoss,
initialization_hook=initialization_hook,
config={"lr": 0.001},
num_replicas=100,
num_workers=100,
use_gpu=True)
Save and Load
@@ -339,7 +339,7 @@ and ``trainer.load``, which wraps the relevant ``torch.save`` and ``torch.load``
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=nn.MSELoss,
num_replicas=num_replicas)
num_workers=num_workers)
trainer_2.restore(checkpoint_path)
@@ -366,7 +366,7 @@ You can enable mixed precision training for PyTorch with the ``use_fp16`` flag.
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=nn.MSELoss,
num_replicas=4,
num_workers=4,
use_fp16=True
)
@@ -383,7 +383,7 @@ To specify particular parameters for ``amp.initialize``, you can use the ``apex_
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=nn.MSELoss,
num_replicas=4,
num_workers=4,
use_fp16=True,
apex_args={
"opt_level": "O3",
@@ -417,7 +417,7 @@ After connecting, you can scale up the number of workers seamlessly across multi
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=nn.MSELoss,
num_replicas=100
num_workers=100
)
trainer.train()
model = trainer.get_model()