mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:38:18 +08:00
[sgd] Extend distributed pytorch functionality (#5675)
* raysgd * apply fn * double quotes * removed duplicate TimerStat * removed duplicate find_free_port * imports in pytorch_trainer * init doc * ray.experimental * remove resize example * resnet example * cifar * Fix up after kwargs * data_dir and dataloader_workers args * formatting * loss * init * update code * lint * smoketest * better_configs * fix * fix * fix * train_loader * fixdocs * ok * ok * fix * fix_update * fix * fix * done * fix * fix * fix * small * lint * fix * fix * fix_test * fix * validate * fix * fi
This commit is contained in:
committed by
Richard Liaw
parent
82be14f943
commit
8f6d73a93a
@@ -0,0 +1,224 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import argparse
|
||||
from ray import tune
|
||||
import torch.utils.data
|
||||
from torch import distributed
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
import ray
|
||||
from ray.experimental.sgd.pytorch import (PyTorchTrainer, PyTorchTrainable)
|
||||
from ray.experimental.sgd.pytorch.resnet import ResNet18
|
||||
|
||||
|
||||
def initialization_hook(runner):
|
||||
print("NCCL DEBUG SET")
|
||||
# Need this for avoiding a connection restart issue
|
||||
os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo"
|
||||
os.environ["NCCL_LL_THRESHOLD"] = "0"
|
||||
os.environ["NCCL_DEBUG"] = "INFO"
|
||||
|
||||
|
||||
def train(model, train_iterator, criterion, optimizer, config):
|
||||
model.train()
|
||||
train_loss, total_num, correct = 0, 0, 0
|
||||
for batch_idx, (data, target) in enumerate(train_iterator):
|
||||
if config.get("test_mode") and batch_idx > 0:
|
||||
break
|
||||
# get small model update
|
||||
if torch.cuda.is_available():
|
||||
data, target = data.cuda(), target.cuda()
|
||||
output = model(data)
|
||||
loss = criterion(output, target)
|
||||
loss.backward()
|
||||
train_loss += loss.item() * target.size(0)
|
||||
total_num += target.size(0)
|
||||
_, predicted = output.max(1)
|
||||
correct += predicted.eq(target).sum().item()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
stats = {
|
||||
"train_loss": train_loss / total_num,
|
||||
"train_acc": correct / total_num
|
||||
}
|
||||
return stats
|
||||
|
||||
|
||||
def validate(model, val_iterator, criterion, config):
|
||||
# switch to evaluate mode
|
||||
model.eval()
|
||||
correct = 0
|
||||
total = 0
|
||||
total_loss = 0
|
||||
with torch.no_grad():
|
||||
for batch_idx, (features, target) in enumerate(val_iterator):
|
||||
if config.get("test_mode") and batch_idx > 10:
|
||||
break
|
||||
if torch.cuda.is_available():
|
||||
features = features.cuda(non_blocking=True)
|
||||
target = target.cuda(non_blocking=True)
|
||||
# compute output
|
||||
output = model(features)
|
||||
loss = criterion(output, target)
|
||||
total_loss += loss.item() * target.size(0)
|
||||
_, predicted = torch.max(output.data, 1)
|
||||
total += target.size(0)
|
||||
correct += (predicted == target).sum().item()
|
||||
stats = {"mean_accuracy": correct / total, "mean_loss": total_loss / total}
|
||||
return stats
|
||||
|
||||
|
||||
def cifar_creator(batch_size, config):
|
||||
transform_train = transforms.Compose([
|
||||
transforms.RandomCrop(32, padding=4),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.4914, 0.4822, 0.4465),
|
||||
(0.2023, 0.1994, 0.2010)),
|
||||
]) # meanstd transformation
|
||||
|
||||
transform_test = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.4914, 0.4822, 0.4465),
|
||||
(0.2023, 0.1994, 0.2010)),
|
||||
])
|
||||
from filelock import FileLock
|
||||
with FileLock(os.path.expanduser("~/data.lock")):
|
||||
train_dataset = torchvision.datasets.CIFAR10(
|
||||
root="~/data",
|
||||
train=True,
|
||||
download=True,
|
||||
transform=transform_train)
|
||||
validation_dataset = torchvision.datasets.CIFAR10(
|
||||
root="~/data", train=False, download=False, transform=transform_test)
|
||||
|
||||
train_sampler = None
|
||||
if distributed.is_initialized():
|
||||
train_sampler = DistributedSampler(train_dataset)
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=(train_sampler is None),
|
||||
num_workers=2,
|
||||
pin_memory=False,
|
||||
sampler=train_sampler)
|
||||
|
||||
validation_sampler = None
|
||||
if distributed.is_initialized():
|
||||
validation_sampler = DistributedSampler(validation_dataset)
|
||||
validation_loader = torch.utils.data.DataLoader(
|
||||
validation_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=(validation_sampler is None),
|
||||
num_workers=2,
|
||||
pin_memory=False,
|
||||
sampler=validation_sampler)
|
||||
|
||||
return train_loader, validation_loader
|
||||
|
||||
|
||||
def optimizer_creator(model, config):
|
||||
"""Returns optimizer"""
|
||||
return torch.optim.SGD(model.parameters(), lr=config.get("lr", 0.1))
|
||||
|
||||
|
||||
def train_example(num_replicas=1, use_gpu=False, test_mode=False):
|
||||
config = {"test_mode": test_mode}
|
||||
trainer1 = PyTorchTrainer(
|
||||
ResNet18,
|
||||
cifar_creator,
|
||||
optimizer_creator,
|
||||
lambda config: nn.CrossEntropyLoss(),
|
||||
initialization_hook=initialization_hook,
|
||||
train_function=train,
|
||||
validation_function=validate,
|
||||
num_replicas=num_replicas,
|
||||
config=config,
|
||||
use_gpu=use_gpu,
|
||||
batch_size=16 if test_mode else 512,
|
||||
backend="nccl" if use_gpu else "gloo")
|
||||
for i in range(5):
|
||||
stats = trainer1.train()
|
||||
print(stats)
|
||||
|
||||
print(trainer1.validate())
|
||||
trainer1.shutdown()
|
||||
print("success!")
|
||||
|
||||
|
||||
def tune_example(num_replicas=1, use_gpu=False, test_mode=False):
|
||||
config = {
|
||||
"model_creator": ResNet18,
|
||||
"data_creator": cifar_creator,
|
||||
"optimizer_creator": optimizer_creator,
|
||||
"loss_creator": lambda config: nn.CrossEntropyLoss(),
|
||||
"train_function": train,
|
||||
"validation_function": validate,
|
||||
"num_replicas": num_replicas,
|
||||
"initialization_hook": initialization_hook,
|
||||
"use_gpu": use_gpu,
|
||||
"batch_size": 16 if test_mode else 512,
|
||||
"config": {
|
||||
"lr": tune.choice([1e-4, 1e-3, 5e-3, 1e-2]),
|
||||
"test_mode": test_mode
|
||||
},
|
||||
"backend": "nccl" if use_gpu else "gloo"
|
||||
}
|
||||
|
||||
analysis = tune.run(
|
||||
PyTorchTrainable,
|
||||
num_samples=2,
|
||||
config=config,
|
||||
stop={"training_iteration": 2},
|
||||
verbose=2)
|
||||
|
||||
return analysis.get_best_config(metric="mean_accuracy", mode="max")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--ray-redis-address",
|
||||
required=False,
|
||||
type=str,
|
||||
help="the address to use for Redis")
|
||||
parser.add_argument(
|
||||
"--num-replicas",
|
||||
"-n",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Sets number of replicas for training.")
|
||||
parser.add_argument(
|
||||
"--use-gpu",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enables GPU training")
|
||||
parser.add_argument(
|
||||
"--smoke-test",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Finish quickly for testing.")
|
||||
parser.add_argument(
|
||||
"--tune", action="store_true", default=False, help="Tune training")
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
ray.init(address=args.ray_redis_address, log_to_driver=False)
|
||||
|
||||
if args.tune:
|
||||
tune_example(
|
||||
num_replicas=args.num_replicas,
|
||||
use_gpu=args.use_gpu,
|
||||
test_mode=args.smoke_test)
|
||||
else:
|
||||
train_example(
|
||||
num_replicas=args.num_replicas,
|
||||
use_gpu=args.use_gpu,
|
||||
test_mode=args.smoke_test)
|
||||
@@ -9,8 +9,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import tensorflow as tf
|
||||
from tensorflow import keras
|
||||
|
||||
from tensorflow.keras.datasets import cifar10
|
||||
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
||||
@@ -27,13 +25,14 @@ num_classes = 10
|
||||
|
||||
|
||||
def fetch_keras_data():
|
||||
import tensorflow as tf
|
||||
# The data, split between train and test sets:
|
||||
with FileLock(os.path.expanduser("~/.cifar.lock")):
|
||||
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
|
||||
|
||||
# Convert class vectors to binary class matrices.
|
||||
y_train = keras.utils.to_categorical(y_train, num_classes)
|
||||
y_test = keras.utils.to_categorical(y_test, num_classes)
|
||||
y_train = tf.keras.utils.to_categorical(y_train, num_classes)
|
||||
y_test = tf.keras.utils.to_categorical(y_test, num_classes)
|
||||
|
||||
x_train = x_train.astype("float32")
|
||||
x_test = x_test.astype("float32")
|
||||
@@ -47,6 +46,7 @@ input_shape = x_train.shape[1:]
|
||||
|
||||
|
||||
def create_model(config):
|
||||
import tensorflow as tf
|
||||
model = Sequential()
|
||||
model.add(Conv2D(32, (3, 3), padding="same", input_shape=input_shape))
|
||||
model.add(Activation("relu"))
|
||||
@@ -70,7 +70,7 @@ def create_model(config):
|
||||
model.add(Activation("softmax"))
|
||||
|
||||
# initiate RMSprop optimizer
|
||||
opt = keras.optimizers.RMSprop(lr=0.001, decay=1e-6)
|
||||
opt = tf.keras.optimizers.RMSprop(lr=0.001, decay=1e-6)
|
||||
|
||||
# Let"s train the model using RMSprop
|
||||
model.compile(
|
||||
@@ -79,6 +79,7 @@ def create_model(config):
|
||||
|
||||
|
||||
def data_creator(config):
|
||||
import tensorflow as tf
|
||||
batch_size = config["batch_size"]
|
||||
(x_train, y_train), (x_test, y_test) = fetch_keras_data()
|
||||
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
|
||||
@@ -131,6 +132,7 @@ def _make_generator(x_train, y_train, batch_size):
|
||||
|
||||
|
||||
def data_augmentation_creator(config):
|
||||
import tensorflow as tf
|
||||
batch_size = config["batch_size"]
|
||||
(x_train, y_train), (x_test, y_test) = fetch_keras_data()
|
||||
trainset = tf.data.Dataset.from_generator(
|
||||
|
||||
@@ -3,9 +3,9 @@ cluster_name: sgd-pytorch
|
||||
|
||||
# The maximum number of workers nodes to launch in addition to the head
|
||||
# node. This takes precedence over min_workers. min_workers default to 0.
|
||||
min_workers: 1
|
||||
initial_workers: 1
|
||||
max_workers: 1
|
||||
min_workers: 3
|
||||
initial_workers: 3
|
||||
max_workers: 3
|
||||
|
||||
target_utilization_fraction: 0.9
|
||||
|
||||
@@ -28,17 +28,17 @@ auth:
|
||||
head_node:
|
||||
InstanceType: p3.8xlarge
|
||||
ImageId: ami-0757fc5a639fe7666
|
||||
# InstanceMarketOptions:
|
||||
# MarketType: spot
|
||||
# SpotOptions:
|
||||
# MaxPrice: "9.0"
|
||||
InstanceMarketOptions:
|
||||
MarketType: spot
|
||||
# SpotOptions:
|
||||
# MaxPrice: "9.0"
|
||||
|
||||
|
||||
worker_nodes:
|
||||
InstanceType: p3.8xlarge
|
||||
ImageId: ami-0757fc5a639fe7666
|
||||
# InstanceMarketOptions:
|
||||
# MarketType: spot
|
||||
InstanceMarketOptions:
|
||||
MarketType: spot
|
||||
# SpotOptions:
|
||||
# MaxPrice: "9.0"
|
||||
|
||||
@@ -48,8 +48,7 @@ worker_nodes:
|
||||
|
||||
setup_commands:
|
||||
- ray || pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev6-cp36-cp36m-manylinux1_x86_64.whl
|
||||
- conda install -y pytorch torchvision cudatoolkit=9.0 -c pytorch
|
||||
- pip install -U ipdb ray[rllib]
|
||||
- pip install -U ipdb ray[rllib] torch torchvision
|
||||
|
||||
|
||||
file_mounts: {
|
||||
|
||||
@@ -16,6 +16,8 @@ import argparse
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import distributed
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from ray.experimental.sgd.pytorch.pytorch_trainer import PyTorchTrainer
|
||||
|
||||
@@ -41,15 +43,34 @@ def model_creator(config):
|
||||
|
||||
|
||||
def optimizer_creator(model, config):
|
||||
"""Returns criterion, optimizer"""
|
||||
criterion = nn.MSELoss()
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
|
||||
return criterion, optimizer
|
||||
"""Returns optimizer."""
|
||||
return torch.optim.SGD(model.parameters(), lr=1e-4)
|
||||
|
||||
|
||||
def data_creator(config):
|
||||
"""Returns training set, validation set"""
|
||||
return LinearDataset(2, 5), LinearDataset(2, 5, size=400)
|
||||
def data_creator(batch_size, config):
|
||||
"""Returns training dataloader, validation dataloader."""
|
||||
train_dataset = LinearDataset(2, 5)
|
||||
validation_dataset = LinearDataset(2, 5, size=400)
|
||||
|
||||
train_sampler = None
|
||||
if distributed.is_initialized():
|
||||
train_sampler = DistributedSampler(train_dataset)
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=(train_sampler is None),
|
||||
sampler=train_sampler)
|
||||
|
||||
validation_sampler = None
|
||||
if distributed.is_initialized():
|
||||
validation_sampler = DistributedSampler(validation_dataset)
|
||||
validation_loader = torch.utils.data.DataLoader(
|
||||
validation_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=(validation_sampler is None),
|
||||
sampler=validation_sampler)
|
||||
|
||||
return train_loader, validation_loader
|
||||
|
||||
|
||||
def train_example(num_replicas=1, use_gpu=False):
|
||||
@@ -57,6 +78,7 @@ def train_example(num_replicas=1, use_gpu=False):
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_replicas=num_replicas,
|
||||
use_gpu=use_gpu,
|
||||
batch_size=512,
|
||||
|
||||
@@ -15,6 +15,8 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import distributed
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
@@ -42,15 +44,33 @@ def model_creator(config):
|
||||
|
||||
|
||||
def optimizer_creator(model, config):
|
||||
"""Returns criterion, optimizer"""
|
||||
criterion = nn.MSELoss()
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=config.get("lr", 1e-4))
|
||||
return criterion, optimizer
|
||||
"""Returns optimizer."""
|
||||
return torch.optim.SGD(model.parameters(), lr=config.get("lr", 1e-4))
|
||||
|
||||
|
||||
def data_creator(config):
|
||||
"""Returns training set, validation set"""
|
||||
return LinearDataset(2, 5), LinearDataset(2, 5, size=400)
|
||||
def data_creator(batch_size, config):
|
||||
"""Returns training dataloader, validation dataloader."""
|
||||
train_dataset = LinearDataset(2, 5)
|
||||
validation_dataset = LinearDataset(2, 5, size=400)
|
||||
|
||||
train_sampler = None
|
||||
if distributed.is_initialized():
|
||||
train_sampler = DistributedSampler(train_dataset)
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=(train_sampler is None),
|
||||
sampler=train_sampler)
|
||||
|
||||
validation_sampler = None
|
||||
if distributed.is_initialized():
|
||||
validation_sampler = DistributedSampler(validation_dataset)
|
||||
validation_loader = torch.utils.data.DataLoader(
|
||||
validation_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=(validation_sampler is None),
|
||||
sampler=validation_sampler)
|
||||
return train_loader, validation_loader
|
||||
|
||||
|
||||
def tune_example(num_replicas=1, use_gpu=False):
|
||||
@@ -58,6 +78,7 @@ def tune_example(num_replicas=1, use_gpu=False):
|
||||
"model_creator": tune.function(model_creator),
|
||||
"data_creator": tune.function(data_creator),
|
||||
"optimizer_creator": tune.function(optimizer_creator),
|
||||
"loss_creator": tune.function(lambda config: nn.MSELoss()),
|
||||
"num_replicas": num_replicas,
|
||||
"use_gpu": use_gpu,
|
||||
"batch_size": 512,
|
||||
|
||||
@@ -3,7 +3,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import os
|
||||
import torch.distributed as dist
|
||||
import torch.utils.data
|
||||
|
||||
@@ -15,27 +14,15 @@ logger = logging.getLogger(__name__)
|
||||
class DistributedPyTorchRunner(PyTorchRunner):
|
||||
"""Manages a distributed PyTorch model replica."""
|
||||
|
||||
def __init__(self,
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
config=None,
|
||||
batch_size=16,
|
||||
backend="gloo"):
|
||||
def __init__(self, *args, backend="gloo", **kwargs):
|
||||
"""Initializes the runner.
|
||||
|
||||
Args:
|
||||
model_creator (dict -> torch.nn.Module): see pytorch_trainer.py.
|
||||
data_creator (dict -> Dataset, Dataset): see pytorch_trainer.py.
|
||||
optimizer_creator (torch.nn.Module, dict -> loss, optimizer):
|
||||
see pytorch_trainer.py.
|
||||
config (dict): see pytorch_trainer.py.
|
||||
batch_size (int): batch size used by one replica for an update.
|
||||
backend (string): see pytorch_trainer.py.
|
||||
args: Arguments for the PyTorchRunner.
|
||||
kwargs: Keyword arguments for the PyTorchRunner.
|
||||
backend (string): backend used by distributed PyTorch.
|
||||
"""
|
||||
|
||||
super(DistributedPyTorchRunner, self).__init__(
|
||||
model_creator, data_creator, optimizer_creator, config, batch_size)
|
||||
super(DistributedPyTorchRunner, self).__init__(*args, **kwargs)
|
||||
self.backend = backend
|
||||
|
||||
def setup(self, url, world_rank, world_size):
|
||||
@@ -50,7 +37,6 @@ class DistributedPyTorchRunner(PyTorchRunner):
|
||||
self._setup_training()
|
||||
|
||||
def _setup_distributed_pytorch(self, url, world_rank, world_size):
|
||||
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
||||
with self._timers["setup_proc"]:
|
||||
self.world_rank = world_rank
|
||||
logger.debug(
|
||||
@@ -67,54 +53,34 @@ class DistributedPyTorchRunner(PyTorchRunner):
|
||||
logger.debug("Creating model")
|
||||
self.model = self.model_creator(self.config)
|
||||
if torch.cuda.is_available():
|
||||
self.model = torch.nn.parallel.DistributedDataParallel(
|
||||
self.model.cuda())
|
||||
else:
|
||||
self.model = torch.nn.parallel.DistributedDataParallelCPU(
|
||||
self.model)
|
||||
self.model = self.model.cuda()
|
||||
self.model = torch.nn.parallel.DistributedDataParallel(self.model)
|
||||
|
||||
logger.debug("Creating optimizer")
|
||||
self.criterion, self.optimizer = self.optimizer_creator(
|
||||
self.model, self.config)
|
||||
logger.debug("Creating optimizer.")
|
||||
self.optimizer = self.optimizer_creator(self.model, self.config)
|
||||
self.criterion = self.loss_creator(self.config)
|
||||
if torch.cuda.is_available():
|
||||
self.criterion = self.criterion.cuda()
|
||||
|
||||
logger.debug("Creating dataset")
|
||||
self.training_set, self.validation_set = self.data_creator(self.config)
|
||||
|
||||
# TODO: make num_workers configurable
|
||||
self.train_sampler = torch.utils.data.distributed.DistributedSampler(
|
||||
self.training_set)
|
||||
self.train_loader = torch.utils.data.DataLoader(
|
||||
self.training_set,
|
||||
batch_size=self.batch_size,
|
||||
shuffle=(self.train_sampler is None),
|
||||
num_workers=2,
|
||||
pin_memory=False,
|
||||
sampler=self.train_sampler)
|
||||
|
||||
self.validation_sampler = (
|
||||
torch.utils.data.distributed.DistributedSampler(
|
||||
self.validation_set))
|
||||
self.validation_loader = torch.utils.data.DataLoader(
|
||||
self.validation_set,
|
||||
batch_size=self.batch_size,
|
||||
shuffle=(self.validation_sampler is None),
|
||||
num_workers=2,
|
||||
pin_memory=False,
|
||||
sampler=self.validation_sampler)
|
||||
self.train_loader, self.validation_loader = self.data_creator(
|
||||
self.batch_size, self.config)
|
||||
|
||||
def step(self):
|
||||
"""Runs a training epoch and updates the model parameters."""
|
||||
"""Runs a training epoch and updates the model parameters.
|
||||
|
||||
Automatically sets epoch of sampler if possible.
|
||||
"""
|
||||
logger.debug("Starting step")
|
||||
self.train_sampler.set_epoch(self.epoch)
|
||||
if hasattr(self.train_loader.sampler, "set_epoch"):
|
||||
self.train_loader.sampler.set_epoch(self.epoch)
|
||||
return super(DistributedPyTorchRunner, self).step()
|
||||
|
||||
def get_state(self):
|
||||
"""Returns the state of the runner."""
|
||||
return {
|
||||
"epoch": self.epoch,
|
||||
"model": self.model.module.state_dict(),
|
||||
"model": self.model.module.cpu().state_dict(),
|
||||
"optimizer": self.optimizer.state_dict(),
|
||||
"stats": self.stats()
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import torch
|
||||
import torch.utils.data
|
||||
|
||||
import ray
|
||||
from ray.experimental.sgd.pytorch import pytorch_utils
|
||||
from ray.experimental.sgd.pytorch import utils as pytorch_utils
|
||||
from ray.experimental.sgd import utils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -20,23 +20,33 @@ class PyTorchRunner(object):
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
loss_creator,
|
||||
train_function=None,
|
||||
validation_function=None,
|
||||
config=None,
|
||||
batch_size=16):
|
||||
"""Initializes the runner.
|
||||
|
||||
Args:
|
||||
model_creator (dict -> torch.nn.Module): see pytorch_trainer.py.
|
||||
data_creator (dict -> Dataset, Dataset): see pytorch_trainer.py.
|
||||
model_creator (dict -> torch.nn.Module): see pytorch_trainer.py
|
||||
data_creator (int, dict -> DataLoader, DataLoader): see
|
||||
pytorch_trainer.py.
|
||||
optimizer_creator (torch.nn.Module, dict -> loss, optimizer):
|
||||
see pytorch_trainer.py.
|
||||
loss_creator (dict -> loss): see pytorch_trainer.py.
|
||||
train_function: see pytorch_trainer.py
|
||||
validation_function: see pytorch_trainer.py
|
||||
config (dict): see pytorch_trainer.py.
|
||||
batch_size (int): see pytorch_trainer.py.
|
||||
"""
|
||||
|
||||
self.model_creator = model_creator
|
||||
self.data_creator = data_creator
|
||||
self.optimizer_creator = optimizer_creator
|
||||
self.loss_creator = loss_creator
|
||||
self.config = {} if config is None else config
|
||||
self.train_function = train_function or pytorch_utils.train
|
||||
self.validation_function = (validation_function
|
||||
or pytorch_utils.validate)
|
||||
self.batch_size = batch_size
|
||||
self.verbose = True
|
||||
|
||||
@@ -57,26 +67,14 @@ class PyTorchRunner(object):
|
||||
self.model = self.model.cuda()
|
||||
|
||||
logger.debug("Creating optimizer")
|
||||
self.criterion, self.optimizer = self.optimizer_creator(
|
||||
self.model, self.config)
|
||||
self.optimizer = self.optimizer_creator(self.model, self.config)
|
||||
self.criterion = self.loss_creator(self.config)
|
||||
if torch.cuda.is_available():
|
||||
self.criterion = self.criterion.cuda()
|
||||
|
||||
logger.debug("Creating dataset")
|
||||
self.training_set, self.validation_set = self.data_creator(self.config)
|
||||
self.train_loader = torch.utils.data.DataLoader(
|
||||
self.training_set,
|
||||
batch_size=self.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=2,
|
||||
pin_memory=False)
|
||||
|
||||
self.validation_loader = torch.utils.data.DataLoader(
|
||||
self.validation_set,
|
||||
batch_size=self.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=2,
|
||||
pin_memory=False)
|
||||
self.train_loader, self.validation_loader = self.data_creator(
|
||||
self.batch_size, self.config)
|
||||
|
||||
def get_node_ip(self):
|
||||
"""Returns the IP address of the current node."""
|
||||
@@ -90,8 +88,9 @@ class PyTorchRunner(object):
|
||||
"""Runs a training epoch and updates the model parameters."""
|
||||
logger.debug("Begin Training Epoch {}".format(self.epoch + 1))
|
||||
with self._timers["training"]:
|
||||
train_stats = pytorch_utils.train(self.train_loader, self.model,
|
||||
self.criterion, self.optimizer)
|
||||
train_stats = self.train_function(self.model, self.train_loader,
|
||||
self.criterion, self.optimizer,
|
||||
self.config)
|
||||
train_stats["epoch"] = self.epoch
|
||||
|
||||
self.epoch += 1
|
||||
@@ -102,8 +101,9 @@ class PyTorchRunner(object):
|
||||
def validate(self):
|
||||
"""Evaluates the model on the validation data set."""
|
||||
with self._timers["validation"]:
|
||||
validation_stats = pytorch_utils.validate(
|
||||
self.validation_loader, self.model, self.criterion)
|
||||
validation_stats = self.validation_function(
|
||||
self.model, self.validation_loader, self.criterion,
|
||||
self.config)
|
||||
|
||||
validation_stats.update(self.stats())
|
||||
return validation_stats
|
||||
@@ -121,7 +121,7 @@ class PyTorchRunner(object):
|
||||
"""Returns the state of the runner."""
|
||||
return {
|
||||
"epoch": self.epoch,
|
||||
"model": self.model.state_dict(),
|
||||
"model": self.model.cpu().state_dict(),
|
||||
"optimizer": self.optimizer.state_dict(),
|
||||
"stats": self.stats()
|
||||
}
|
||||
@@ -133,12 +133,13 @@ class PyTorchRunner(object):
|
||||
self.optimizer.load_state_dict(state["optimizer"])
|
||||
self.epoch = state["stats"]["epoch"]
|
||||
|
||||
def apply_fn(self, fn):
|
||||
return fn(self)
|
||||
|
||||
def shutdown(self):
|
||||
"""Attempts to shut down the worker."""
|
||||
del self.validation_loader
|
||||
del self.validation_set
|
||||
del self.train_loader
|
||||
del self.training_set
|
||||
del self.criterion
|
||||
del self.optimizer
|
||||
del self.model
|
||||
|
||||
@@ -11,12 +11,11 @@ import logging
|
||||
import ray
|
||||
|
||||
from ray.tune import Trainable
|
||||
from ray.tune.resources import Resources
|
||||
from ray.experimental.sgd.pytorch.pytorch_runner import PyTorchRunner
|
||||
from ray.tune.trial import Resources
|
||||
from ray.experimental.sgd.pytorch.distributed_pytorch_runner import (
|
||||
DistributedPyTorchRunner)
|
||||
from ray.experimental.sgd.pytorch import pytorch_utils
|
||||
from ray.experimental.sgd import utils
|
||||
from ray.experimental.sgd.pytorch.pytorch_runner import PyTorchRunner
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,7 +30,11 @@ class PyTorchTrainer(object):
|
||||
def __init__(self,
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator=pytorch_utils.sgd_mse_optimizer,
|
||||
optimizer_creator,
|
||||
loss_creator,
|
||||
train_function=None,
|
||||
validation_function=None,
|
||||
initialization_hook=None,
|
||||
config=None,
|
||||
num_replicas=1,
|
||||
use_gpu=False,
|
||||
@@ -42,12 +45,21 @@ class PyTorchTrainer(object):
|
||||
Args:
|
||||
model_creator (dict -> torch.nn.Module): creates the model
|
||||
using the config.
|
||||
data_creator (dict -> Dataset, Dataset): creates the training
|
||||
and validation data sets using the config.
|
||||
optimizer_creator (torch.nn.Module, dict -> loss, optimizer):
|
||||
data_creator (int, dict -> DataLoader, DataLoader): Function that
|
||||
takes in (batch_size, config) and returns two Torch DataLoader
|
||||
objects.
|
||||
optimizer_creator (torch.nn.Module, dict -> optimizer):
|
||||
creates the loss and optimizer using the model and the config.
|
||||
config (dict): configuration passed to 'model_creator',
|
||||
'data_creator', and 'optimizer_creator'.
|
||||
loss_creator (dict -> loss): Creates the loss function/criterion
|
||||
using the config.
|
||||
train_function: Trains a model for a epoch. This takes in (
|
||||
model, train_dataloader, criterion, optimizer, config), and
|
||||
returns a dict of training stats.
|
||||
validation_function: Runs validation. This takes in (
|
||||
model, val_dataloader, criterion, config) and returns a dict of
|
||||
validation stats.
|
||||
config (dict): configuration passed to "model_creator",
|
||||
"data_creator", "optimizer_creator", and "loss_creator".
|
||||
num_replicas (int): the number of workers used in distributed
|
||||
training.
|
||||
use_gpu (bool): Sets resource allocation for workers to 1 GPU
|
||||
@@ -79,20 +91,29 @@ class PyTorchTrainer(object):
|
||||
num_cpus=1, num_gpus=int(use_gpu))(PyTorchRunner)
|
||||
# Start workers
|
||||
self.workers = [
|
||||
Runner.remote(model_creator, data_creator, optimizer_creator,
|
||||
self.config, batch_size)
|
||||
Runner.remote(
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
loss_creator,
|
||||
train_function=train_function,
|
||||
validation_function=validation_function,
|
||||
config=self.config,
|
||||
batch_size=batch_size)
|
||||
]
|
||||
if initialization_hook:
|
||||
self.apply_all_workers(initialization_hook)
|
||||
# Get setup tasks in order to throw errors on failure
|
||||
ray.get(self.workers[0].setup.remote())
|
||||
else:
|
||||
# Geneate actor class
|
||||
# Generate actor class
|
||||
Runner = ray.remote(
|
||||
num_cpus=1, num_gpus=int(use_gpu))(DistributedPyTorchRunner)
|
||||
# Compute batch size per replica
|
||||
batch_size_per_replica = batch_size // num_replicas
|
||||
if batch_size % num_replicas > 0:
|
||||
new_batch_size = batch_size_per_replica * num_replicas
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
("Changing batch size from {old_batch_size} to "
|
||||
"{new_batch_size} to evenly distribute batches across "
|
||||
"{num_replicas} replicas.").format(
|
||||
@@ -101,10 +122,21 @@ class PyTorchTrainer(object):
|
||||
num_replicas=num_replicas))
|
||||
# Start workers
|
||||
self.workers = [
|
||||
Runner.remote(model_creator, data_creator, optimizer_creator,
|
||||
self.config, batch_size_per_replica, backend)
|
||||
Runner.remote(
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
loss_creator,
|
||||
backend=backend,
|
||||
train_function=train_function,
|
||||
validation_function=validation_function,
|
||||
config=self.config,
|
||||
batch_size=batch_size_per_replica)
|
||||
for i in range(num_replicas)
|
||||
]
|
||||
if initialization_hook:
|
||||
self.apply_all_workers(initialization_hook)
|
||||
|
||||
# Compute URL for initializing distributed PyTorch
|
||||
ip = ray.get(self.workers[0].get_node_ip.remote())
|
||||
port = ray.get(self.workers[0].find_free_port.remote())
|
||||
@@ -125,12 +157,16 @@ class PyTorchTrainer(object):
|
||||
[s["train_loss"] for s in worker_stats])
|
||||
return train_stats
|
||||
|
||||
def apply_all_workers(self, fn):
|
||||
return ray.get([w.apply_fn.remote(fn) for w in self.workers])
|
||||
|
||||
def validate(self):
|
||||
"""Evaluates the model on the validation data set."""
|
||||
worker_stats = ray.get([w.validate.remote() for w in self.workers])
|
||||
validation_stats = worker_stats[0].copy()
|
||||
validation_stats["validation_loss"] = np.mean(
|
||||
[s["validation_loss"] for s in worker_stats])
|
||||
if "validation_loss" in validation_stats:
|
||||
validation_stats["validation_loss"] = np.nanmean(
|
||||
[s.get("validation_loss", np.nan) for s in worker_stats])
|
||||
return validation_stats
|
||||
|
||||
def get_model(self):
|
||||
@@ -179,23 +215,15 @@ class PyTorchTrainable(Trainable):
|
||||
extra_gpu=int(config["use_gpu"]) * config["num_replicas"])
|
||||
|
||||
def _setup(self, config):
|
||||
self._trainer = PyTorchTrainer(
|
||||
model_creator=config["model_creator"],
|
||||
data_creator=config["data_creator"],
|
||||
optimizer_creator=config["optimizer_creator"],
|
||||
config=config,
|
||||
num_replicas=config["num_replicas"],
|
||||
use_gpu=config["use_gpu"],
|
||||
batch_size=config["batch_size"],
|
||||
backend=config["backend"])
|
||||
self._trainer = PyTorchTrainer(**config)
|
||||
|
||||
def _train(self):
|
||||
|
||||
train_stats = self._trainer.train()
|
||||
validation_stats = self._trainer.validate()
|
||||
|
||||
train_stats.update(validation_stats)
|
||||
|
||||
# output {"mean_loss": test_loss, "mean_accuracy": accuracy}
|
||||
return train_stats
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
|
||||
@@ -0,0 +1,134 @@
|
||||
"""ResNet in PyTorch.
|
||||
|
||||
Copied from https://github.com/kuangliu/pytorch-cifar/
|
||||
blob/ab908327d44bf9b1d22cd333a4466e85083d3f21/models/resnet.py
|
||||
"""
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_planes,
|
||||
planes,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(
|
||||
planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_planes != self.expansion * planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_planes,
|
||||
self.expansion * planes,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False), nn.BatchNorm2d(self.expansion * planes))
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = self.bn2(self.conv2(out))
|
||||
out += self.shortcut(x)
|
||||
out = F.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1):
|
||||
super(Bottleneck, self).__init__()
|
||||
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(
|
||||
planes,
|
||||
planes,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.conv3 = nn.Conv2d(
|
||||
planes, self.expansion * planes, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(self.expansion * planes)
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_planes != self.expansion * planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_planes,
|
||||
self.expansion * planes,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False), nn.BatchNorm2d(self.expansion * planes))
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = F.relu(self.bn2(self.conv2(out)))
|
||||
out = self.bn3(self.conv3(out))
|
||||
out += self.shortcut(x)
|
||||
out = F.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
def __init__(self, block, num_blocks, num_classes=10):
|
||||
super(ResNet, self).__init__()
|
||||
self.in_planes = 64
|
||||
|
||||
self.conv1 = nn.Conv2d(
|
||||
3, 64, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
|
||||
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
|
||||
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
|
||||
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
|
||||
self.linear = nn.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
def _make_layer(self, block, planes, num_blocks, stride):
|
||||
strides = [stride] + [1] * (num_blocks - 1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
layers.append(block(self.in_planes, planes, stride))
|
||||
self.in_planes = planes * block.expansion
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = self.layer1(out)
|
||||
out = self.layer2(out)
|
||||
out = self.layer3(out)
|
||||
out = self.layer4(out)
|
||||
out = F.avg_pool2d(out, 4)
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.linear(out)
|
||||
return out
|
||||
|
||||
|
||||
def ResNet18(_):
|
||||
return ResNet(BasicBlock, [2, 2, 2, 2])
|
||||
|
||||
|
||||
def ResNet34(_):
|
||||
return ResNet(BasicBlock, [3, 4, 6, 3])
|
||||
|
||||
|
||||
def ResNet50(_):
|
||||
return ResNet(Bottleneck, [3, 4, 6, 3])
|
||||
|
||||
|
||||
def ResNet101(_):
|
||||
return ResNet(Bottleneck, [3, 4, 23, 3])
|
||||
|
||||
|
||||
def ResNet152(_):
|
||||
return ResNet(Bottleneck, [3, 8, 36, 3])
|
||||
+32
-23
@@ -4,18 +4,17 @@ from __future__ import print_function
|
||||
|
||||
import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ray.experimental.sgd import utils
|
||||
from ray.experimental.sgd.utils import TimerStat
|
||||
|
||||
|
||||
def train(train_iterator, model, criterion, optimizer):
|
||||
def train(model, train_iterator, criterion, optimizer, config):
|
||||
"""Runs 1 training epoch"""
|
||||
batch_time = utils.AverageMeter()
|
||||
data_time = utils.AverageMeter()
|
||||
losses = utils.AverageMeter()
|
||||
batch_time = AverageMeter()
|
||||
data_time = AverageMeter()
|
||||
losses = AverageMeter()
|
||||
|
||||
timers = {k: utils.TimerStat() for k in ["d2h", "fwd", "grad", "apply"]}
|
||||
timers = {k: TimerStat() for k in ["d2h", "fwd", "grad", "apply"]}
|
||||
|
||||
# switch to train mode
|
||||
model.train()
|
||||
@@ -63,16 +62,17 @@ def train(train_iterator, model, criterion, optimizer):
|
||||
return stats
|
||||
|
||||
|
||||
def validate(val_loader, model, criterion):
|
||||
batch_time = utils.AverageMeter()
|
||||
losses = utils.AverageMeter()
|
||||
def validate(model, val_iterator, criterion, config):
|
||||
batch_time = AverageMeter()
|
||||
losses = AverageMeter()
|
||||
|
||||
# switch to evaluate mode
|
||||
model.eval()
|
||||
|
||||
correct = 0
|
||||
total = 0
|
||||
with torch.no_grad():
|
||||
end = time.time()
|
||||
for i, (features, target) in enumerate(val_loader):
|
||||
for i, (features, target) in enumerate(val_iterator):
|
||||
|
||||
if torch.cuda.is_available():
|
||||
features = features.cuda(non_blocking=True)
|
||||
@@ -81,6 +81,9 @@ def validate(val_loader, model, criterion):
|
||||
# compute output
|
||||
output = model(features)
|
||||
loss = criterion(output, target)
|
||||
_, predicted = torch.max(output.data, 1)
|
||||
total += target.size(0)
|
||||
correct += (predicted == target).sum().item()
|
||||
|
||||
# measure accuracy and record loss
|
||||
losses.update(loss.item(), features.size(0))
|
||||
@@ -90,18 +93,24 @@ def validate(val_loader, model, criterion):
|
||||
end = time.time()
|
||||
|
||||
stats = {"batch_time": batch_time.avg, "validation_loss": losses.avg}
|
||||
stats.update(mean_accuracy=correct / total)
|
||||
return stats
|
||||
|
||||
|
||||
def sgd_mse_optimizer(model, config):
|
||||
"""Returns the mean squared error criterion and SGD optimizer.
|
||||
class AverageMeter(object):
|
||||
"""Computes and stores the average and current value."""
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): the model to optimize.
|
||||
config (dict): configuration for the optimizer.
|
||||
lr (float): the learning rate. defaults to 0.01.
|
||||
"""
|
||||
learning_rate = config.get("lr", 0.01)
|
||||
criterion = nn.MSELoss()
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
|
||||
return criterion, optimizer
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
@@ -6,6 +6,7 @@ import os
|
||||
import pytest
|
||||
import tempfile
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.distributed as dist
|
||||
|
||||
from ray import tune
|
||||
@@ -23,6 +24,7 @@ def test_train(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_replicas=num_replicas)
|
||||
train_loss1 = trainer.train()["train_loss"]
|
||||
validation_loss1 = trainer.validate()["validation_loss"]
|
||||
@@ -45,6 +47,7 @@ def test_tune_train(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
"model_creator": tune.function(model_creator),
|
||||
"data_creator": tune.function(data_creator),
|
||||
"optimizer_creator": tune.function(optimizer_creator),
|
||||
"loss_creator": tune.function(lambda config: nn.MSELoss()),
|
||||
"num_replicas": num_replicas,
|
||||
"use_gpu": False,
|
||||
"batch_size": 512,
|
||||
@@ -76,6 +79,7 @@ def test_save_and_restore(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_replicas=num_replicas)
|
||||
trainer1.train()
|
||||
|
||||
@@ -90,6 +94,7 @@ def test_save_and_restore(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_replicas=num_replicas)
|
||||
trainer2.restore(filename)
|
||||
|
||||
|
||||
@@ -16,8 +16,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
def _try_import_strategy():
|
||||
"""Late import for Tesnorflow"""
|
||||
from tensorflow.distribute.experimental import MultiWorkerMirroredStrategy
|
||||
return MultiWorkerMirroredStrategy
|
||||
import tensorflow as tf
|
||||
return tf.distribute.experimental.MultiWorkerMirroredStrategy
|
||||
|
||||
|
||||
class TFRunner(object):
|
||||
|
||||
Reference in New Issue
Block a user