Files
ray/python/ray/util/sgd/torch/examples/tune_example.py
T
Ian Rodney 826a9253c6 [docker] Detect CPUs in container correctly (#10507)
Co-authored-by: simon-mo <simon.mo@hey.com>
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
Co-authored-by: Alex Wu <itswu.alex@gmail.com>
2020-09-14 17:11:47 +00:00

91 lines
2.7 KiB
Python

# yapf: disable
"""
This file holds code for a Distributed PyTorch + Tune page in the docs.
It ignores yapf because yapf doesn't allow comments right after code blocks,
but we put comments right after code blocks to prevent large white spaces
in the documentation.
"""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import ray
from ray import tune
from ray.util.sgd.torch import TorchTrainer, TrainingOperator
from ray.util.sgd.utils import BATCH_SIZE
from ray.util.sgd.torch.examples.train_example import LinearDataset
def model_creator(config):
    """Returns a linear model with one input and one output feature.

    ``config`` is accepted so this function can be passed as a
    ``model_creator``; it is not read here.
    """
    return nn.Linear(1, 1)
def optimizer_creator(model, config):
    """Returns an SGD optimizer over the parameters of ``model``.

    The learning rate is read from ``config["lr"]`` and falls back to
    1e-4 when that key is absent.
    """
    learning_rate = config.get("lr", 1e-4)
    return torch.optim.SGD(model.parameters(), lr=learning_rate)
def data_creator(config):
    """Returns training dataloader, validation dataloader.

    Both loaders use the batch size stored under ``config[BATCH_SIZE]``;
    the key is required (a plain dict raises ``KeyError`` without it).
    """
    # Look the batch size up once; both loaders share it.
    batch_size = config[BATCH_SIZE]

    train_dataset = LinearDataset(2, 5)
    val_dataset = LinearDataset(2, 5, size=400)

    train_loader = DataLoader(train_dataset, batch_size=batch_size)
    validation_loader = DataLoader(val_dataset, batch_size=batch_size)
    return train_loader, validation_loader
# __torch_tune_example__
def tune_example(operator_cls, num_workers=1, use_gpu=False):
    """Runs a small Tune learning-rate search and returns the best config.

    Args:
        operator_cls: Training operator class handed to
            ``TorchTrainer.as_trainable``.
        num_workers: Forwarded to ``TorchTrainer.as_trainable``.
        use_gpu: Forwarded to ``TorchTrainer.as_trainable``.

    Returns:
        Whatever ``analysis.get_best_config`` yields when asked for the
        minimum ``validation_loss``.
    """
    # Wrap the trainer as a Tune trainable with a fixed batch size.
    TorchTrainable = TorchTrainer.as_trainable(
        training_operator_cls=operator_cls,
        num_workers=num_workers,
        use_gpu=use_gpu,
        config={BATCH_SIZE: 128}
    )

    # Grid-search two learning rates; each trial stops after 2 iterations.
    analysis = tune.run(
        TorchTrainable,
        num_samples=3,
        config={"lr": tune.grid_search([1e-4, 1e-3])},
        stop={"training_iteration": 2},
        verbose=1)

    return analysis.get_best_config(metric="validation_loss", mode="min")
# __end_torch_tune_example__
if __name__ == "__main__":
    # Script entry point: parse CLI flags, start/connect Ray, run the example.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    parser.add_argument(
        "--address",
        type=str,
        help="the address to use for Ray")
    parser.add_argument(
        "--num-workers",
        "-n",
        type=int,
        default=1,
        help="Sets number of workers for training.")
    parser.add_argument(
        "--use-gpu",
        action="store_true",
        default=False,
        help="Enables GPU training")

    # parse_known_args ignores any extra flags instead of erroring on them.
    args, _ = parser.parse_known_args()

    if args.smoke_test:
        # Smoke test: start Ray with 2 CPUs and ignore --address.
        ray.init(num_cpus=2)
    else:
        ray.init(address=args.address)

    # Assemble a training operator class from the creator functions above.
    CustomTrainingOperator = TrainingOperator.from_creators(
        model_creator=model_creator, optimizer_creator=optimizer_creator,
        data_creator=data_creator, loss_creator=nn.MSELoss)

    tune_example(CustomTrainingOperator, num_workers=args.num_workers,
                 use_gpu=args.use_gpu)