[tune] TensorFlow Distributed Trainable (#11876)

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
Keqiu Hu
2020-11-10 14:59:08 -08:00
committed by GitHub
parent 50dbf1a307
commit 0c1bdaef59
6 changed files with 395 additions and 1 deletions
+16
View File
@@ -359,6 +359,14 @@ py_test(
deps = [":tune_lib"],
)
py_test(
name = "test_tensorflow_trainable",
size = "medium",
srcs = ["tests/test_tensorflow_trainable.py"],
tags = ["exclusive", "example", "tensorflow"],
deps = [":tune_lib"],
)
py_test(
name = "test_horovod",
size = "medium",
@@ -376,6 +384,14 @@ py_test(
args = ["--num-workers=2"]
)
py_test(
name = "tf_distributed_keras_example",
size = "small",
srcs = ["examples/tf_distributed_keras_example.py"],
deps = [":tune_lib"],
tags = ["exclusive", "example", "tensorflow"],
)
py_test(
name = "dragonfly_example",
size = "medium",
@@ -0,0 +1,110 @@
"""
Adapted from
https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras
"""
import argparse
import tensorflow as tf
import numpy as np
from ray import tune
from ray.tune.schedulers import AsyncHyperBandScheduler
from ray.tune.integration.keras import TuneReportCheckpointCallback
from ray.tune.integration.tensorflow import (DistributedTrainableCreator,
get_num_workers)
def mnist_dataset(batch_size):
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
# The `x` arrays are in uint8 and have values in the range [0, 255].
# You need to convert them to float32 with values in the range [0, 1]
x_train = x_train / np.float32(255)
y_train = y_train.astype(np.int64)
train_dataset = tf.data.Dataset.from_tensor_slices(
(x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
return train_dataset
def build_and_compile_cnn_model(config):
model = tf.keras.Sequential([
tf.keras.Input(shape=(28, 28)),
tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
tf.keras.layers.Conv2D(32, 3, activation="relu"),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(config.get("hidden", 128), activation="relu"),
tf.keras.layers.Dense(10)
])
model.compile(
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.SGD(
learning_rate=config.get("lr", 0.05),
momentum=config.get("momentum", 0.5)),
metrics=["accuracy"])
return model
def train_mnist(config, checkpoint_dir=None):
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
per_worker_batch_size = 64
num_workers = get_num_workers()
global_batch_size = per_worker_batch_size * num_workers
multi_worker_dataset = mnist_dataset(global_batch_size)
with strategy.scope():
multi_worker_model = build_and_compile_cnn_model(config)
multi_worker_model.fit(
multi_worker_dataset,
epochs=2,
steps_per_epoch=70,
callbacks=[
TuneReportCheckpointCallback(
{
"mean_accuracy": "accuracy"
}, filename="checkpoint")
])
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--num-workers",
"-n",
type=int,
default=2,
help="Sets number of workers for training.")
parser.add_argument(
"--use-gpu",
action="store_true",
default=False,
help="enables CUDA training")
parser.add_argument(
"--cluster",
action="store_true",
default=False,
help="enables multi-node tuning")
args = parser.parse_args()
tf_trainable = DistributedTrainableCreator(
train_mnist,
use_gpu=args.use_gpu,
num_workers=2,
)
sched = AsyncHyperBandScheduler(
time_attr="training_iteration",
metric="mean_accuracy",
mode="max",
max_t=400,
grace_period=20)
tune.run(
tf_trainable,
name="exp",
scheduler=sched,
stop={
"mean_accuracy": 0.99,
"training_iteration": 10
},
num_samples=1,
config={
"lr": tune.sample_from(lambda spec: np.random.uniform(0.001, 0.1)),
"momentum": tune.sample_from(
lambda spec: np.random.uniform(0.1, 0.9)),
"hidden": tune.sample_from(
lambda spec: np.random.randint(32, 512)),
})
+172
View File
@@ -0,0 +1,172 @@
import json
import ray
import os
from ray import tune
from ray.tune.result import RESULT_DUPLICATE
from ray.tune.function_runner import wrap_function
from ray.tune.resources import Resources
from ray.tune.utils.trainable import TrainableUtil
from ray.util.sgd.utils import find_free_port
from typing import Callable, Dict, Type
def setup_process_group(worker_addresses, index):
"""Set up distributed training info for training task.
Args:
worker_addresses (list): addresses of the workers.
index (int): index of current worker
"""
tf_config = {
"cluster": {
"worker": worker_addresses
},
"task": {
"type": "worker",
"index": index
}
}
os.environ["TF_CONFIG"] = json.dumps(tf_config)
def setup_address():
ip = ray.services.get_node_ip_address()
port = find_free_port()
return f"{ip}:{port}"
class _TensorFlowTrainable(tune.Trainable):
"""Base class for distributed training on Tune."""
_function = None
_num_workers = None
_use_gpu = None
_num_cpus_per_worker = None
__slots__ = ["workers", "_finished"]
@classmethod
def get_remote_worker_options(self) -> Dict[str, int]:
num_gpus = 1 if self._use_gpu else 0
num_cpus = int(self._num_cpus_per_worker or 1)
return dict(num_cpus=num_cpus, num_gpus=num_gpus)
def setup(self, config: Dict):
self._finished = False
num_workers = self._num_workers
assert self._function
func_trainable = wrap_function(self.__class__._function)
remote_trainable = ray.remote(func_trainable)
remote_trainable = remote_trainable.options(
**self.get_remote_worker_options())
self.workers = [
remote_trainable.remote(config=config, )
for _ in range(num_workers)
]
addresses = [
ray.get(worker.execute.remote(lambda _: setup_address()))
for worker in self.workers
]
from functools import partial
setup_on_worker = partial(
setup_process_group, worker_addresses=addresses)
ray.get([
w.execute.remote(lambda _: setup_on_worker(index=index))
for index, w in enumerate(self.workers)
])
def step(self) -> Dict:
if self._finished:
raise RuntimeError("Training has already finished.")
result = ray.get([w.step.remote() for w in self.workers])[0]
if RESULT_DUPLICATE in result:
self._finished = True
return result
def save_checkpoint(self, checkpoint_dir: str) -> str:
# TODO: optimize if colocated
save_obj = ray.get(self.workers[0].save_to_object.remote())
checkpoint_path = TrainableUtil.create_from_pickle(
save_obj, checkpoint_dir)
return checkpoint_path
def load_checkpoint(self, checkpoint_dir: str):
checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_dir)
return ray.get(
w.restore_from_object.remote(checkpoint_obj) for w in self.workers)
def stop(self):
ray.get([worker.stop.remote() for worker in self.workers])
def DistributedTrainableCreator(
func: Callable,
use_gpu: bool = False,
num_workers: int = 2,
num_cpus_per_worker: int = 1) -> Type[_TensorFlowTrainable]:
"""Converts TensorFlow MultiWorkerMirror training to be executable by Tune.
Requires TensorFlow > 2.0 to work, recommends TensorFlow > 2.2.
This function wraps and sets resources for a TF distributed training
function to be used with Tune. It generates a TensorFlow Trainable
which can be a distributed training job.
Note: there is no fault tolerance at the moment.
Args:
func (Callable[[dict], None]): A training function that takes in
a config dict for hyperparameters and should initialize
horovod via horovod.init.
use_gpu (bool); Whether to allocate a GPU per worker.
num_cpus_per_worker (int): Number of CPUs to request
from Ray per worker.
num_workers (int): Number of hosts that each trial is expected
to use.
Returns:
Trainable class that can be passed into `tune.run`.
.. versionadded:: 1.1.0
Example:
.. code-block:: python
# Please refer to full example in tf_distributed_keras_example.py
tf_trainable = DistributedTrainableCreator(
train_mnist,
use_gpu=args.use_gpu,
num_workers=2)
tune.run(tf_trainable,
num_samples=1)
"""
class WrappedDistributedTensorFlowTrainable(_TensorFlowTrainable):
_function = func
_num_workers = num_workers
_num_cpus_per_worker = num_cpus_per_worker
_use_gpu = use_gpu
@classmethod
def default_resource_request(cls, config: Dict) -> Resources:
num_workers_ = int(config.get("num_workers", num_workers))
num_worker_cpus = int(
config.get("num_cpus_per_worker", num_cpus_per_worker))
use_gpu_ = config.get("use_gpu", use_gpu)
return Resources(
cpu=0,
gpu=0,
extra_cpu=num_workers * num_worker_cpus,
extra_gpu=num_workers_ if use_gpu_ else 0)
return WrappedDistributedTensorFlowTrainable
def get_num_workers():
"""Retrieve the number of workers in the training job."""
tf_config = json.loads(os.environ["TF_CONFIG"])
num_workers = len(tf_config["cluster"]["worker"])
return num_workers
@@ -0,0 +1,61 @@
import pytest
import ray
from ray.tune.integration.tensorflow import DistributedTrainableCreator
from ray.tune.examples.tf_distributed_keras_example import train_mnist
@pytest.fixture
def ray_start_2_cpus():
address_info = ray.init(num_cpus=2)
yield address_info
# The code after the yield will run as teardown code.
ray.shutdown()
@pytest.fixture
def ray_start_4_cpus():
address_info = ray.init(num_cpus=4)
yield address_info
# The code after the yield will run as teardown code.
ray.shutdown()
@pytest.fixture
def ray_connect_cluster():
try:
address_info = ray.init(address="auto")
except Exception as e:
pytest.skip(str(e))
yield address_info
# The code after the yield will run as teardown code.
ray.shutdown()
def test_single_step(ray_start_2_cpus): # noqa: F811
trainable_cls = DistributedTrainableCreator(train_mnist)
trainer = trainable_cls()
trainer.train()
trainer.stop()
def test_step_after_completion(ray_start_2_cpus): # noqa: F811
trainable_cls = DistributedTrainableCreator(train_mnist, num_workers=2)
trainer = trainable_cls(config={"epochs": 1})
with pytest.raises(RuntimeError):
for i in range(10):
trainer.train()
def test_validation(ray_start_2_cpus): # noqa: F811
def bad_func(a, b, c):
return 1
t_cls = DistributedTrainableCreator(bad_func)
with pytest.raises(ValueError):
t_cls()
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))