mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 10:49:16 +08:00
[tune] TensorFlow Distributed Trainable (#11876)
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
@@ -359,6 +359,14 @@ py_test(
|
||||
deps = [":tune_lib"],
|
||||
)
|
||||
|
||||
# Runs tests/test_tensorflow_trainable.py against the Tune library target.
py_test(
    name = "test_tensorflow_trainable",
    size = "medium",
    srcs = ["tests/test_tensorflow_trainable.py"],
    deps = [":tune_lib"],
    tags = ["exclusive", "example", "tensorflow"],
)
|
||||
|
||||
py_test(
|
||||
name = "test_horovod",
|
||||
size = "medium",
|
||||
@@ -376,6 +384,14 @@ py_test(
|
||||
args = ["--num-workers=2"]
|
||||
)
|
||||
|
||||
# Runs the distributed Keras MNIST example script as a test target.
py_test(
    name = "tf_distributed_keras_example",
    size = "small",
    srcs = ["examples/tf_distributed_keras_example.py"],
    tags = ["exclusive", "example", "tensorflow"],
    deps = [":tune_lib"],
)
|
||||
|
||||
py_test(
|
||||
name = "dragonfly_example",
|
||||
size = "medium",
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
"""
|
||||
Adapted from
|
||||
https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras
|
||||
"""
|
||||
import argparse
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
from ray import tune
|
||||
from ray.tune.schedulers import AsyncHyperBandScheduler
|
||||
from ray.tune.integration.keras import TuneReportCheckpointCallback
|
||||
from ray.tune.integration.tensorflow import (DistributedTrainableCreator,
|
||||
get_num_workers)
|
||||
|
||||
|
||||
def mnist_dataset(batch_size):
    """Build a shuffled, endlessly repeating, batched MNIST training set.

    Args:
        batch_size (int): Number of examples in each emitted batch.

    Returns:
        A ``tf.data.Dataset`` of ``(image, label)`` batches built from the
        MNIST training split (the test split is discarded).
    """
    (images, labels), _ = tf.keras.datasets.mnist.load_data()
    # The image arrays are uint8 with values in [0, 255]; dividing by a
    # float32 255 rescales them into the [0, 1] range.
    images = images / np.float32(255)
    labels = labels.astype(np.int64)

    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    dataset = dataset.shuffle(60000)
    dataset = dataset.repeat()
    return dataset.batch(batch_size)
|
||||
|
||||
|
||||
def build_and_compile_cnn_model(config):
    """Create and compile the small CNN classifier used by the example.

    Args:
        config (dict): Hyperparameters. Recognised keys (with defaults):
            ``"hidden"`` (128) - units in the dense hidden layer,
            ``"lr"`` (0.05) - SGD learning rate,
            ``"momentum"`` (0.5) - SGD momentum.

    Returns:
        A compiled ``tf.keras.Sequential`` model that outputs 10 logits.
    """
    hidden_units = config.get("hidden", 128)
    layers = [
        tf.keras.Input(shape=(28, 28)),
        tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
        tf.keras.layers.Conv2D(32, 3, activation="relu"),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(hidden_units, activation="relu"),
        tf.keras.layers.Dense(10),
    ]
    model = tf.keras.Sequential(layers)

    sgd = tf.keras.optimizers.SGD(
        learning_rate=config.get("lr", 0.05),
        momentum=config.get("momentum", 0.5))
    # The final Dense layer has no activation, so the loss works on logits.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    model.compile(loss=loss_fn, optimizer=sgd, metrics=["accuracy"])
    return model
|
||||
|
||||
|
||||
def train_mnist(config, checkpoint_dir=None):
    """Train the MNIST CNN under a multi-worker mirrored strategy.

    Intended to run on every worker of a distributed trial; the number of
    workers is read from ``TF_CONFIG`` via ``get_num_workers``.

    Args:
        config (dict): Hyperparameters. ``"hidden"``, ``"lr"`` and
            ``"momentum"`` are forwarded to ``build_and_compile_cnn_model``;
            ``"epochs"`` (default 2) sets the number of training epochs.
        checkpoint_dir (str, optional): Accepted for compatibility with
            Tune's function API. This function does not restore from it.
    """
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    per_worker_batch_size = 64
    num_workers = get_num_workers()
    # The dataset is batched at the global size, i.e. summed over workers.
    global_batch_size = per_worker_batch_size * num_workers
    multi_worker_dataset = mnist_dataset(global_batch_size)
    with strategy.scope():
        # Model creation and compilation must happen inside the scope.
        multi_worker_model = build_and_compile_cnn_model(config)

    multi_worker_model.fit(
        multi_worker_dataset,
        # Previously hard-coded to 2, which silently ignored an "epochs"
        # entry in the trial config.
        epochs=config.get("epochs", 2),
        steps_per_epoch=70,
        callbacks=[
            TuneReportCheckpointCallback(
                {
                    "mean_accuracy": "accuracy"
                }, filename="checkpoint")
        ])
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: tune the distributed MNIST example.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num-workers",
        "-n",
        type=int,
        default=2,
        help="Sets number of workers for training.")
    parser.add_argument(
        "--use-gpu",
        action="store_true",
        default=False,
        help="enables CUDA training")
    parser.add_argument(
        "--cluster",
        action="store_true",
        default=False,
        help="enables multi-node tuning")
    args = parser.parse_args()

    if args.cluster:
        # Connect to an already running Ray cluster instead of letting
        # tune.run start a local instance. Imported here because ray is
        # only needed directly on this code path.
        import ray
        ray.init(address="auto")

    tf_trainable = DistributedTrainableCreator(
        train_mnist,
        use_gpu=args.use_gpu,
        # Honour the CLI flag; it was previously parsed but ignored in
        # favour of a hard-coded 2 (which is still the default).
        num_workers=args.num_workers,
    )
    sched = AsyncHyperBandScheduler(
        time_attr="training_iteration",
        metric="mean_accuracy",
        mode="max",
        max_t=400,
        grace_period=20)
    tune.run(
        tf_trainable,
        name="exp",
        scheduler=sched,
        stop={
            "mean_accuracy": 0.99,
            "training_iteration": 10
        },
        num_samples=1,
        config={
            "lr": tune.sample_from(lambda spec: np.random.uniform(0.001, 0.1)),
            "momentum": tune.sample_from(
                lambda spec: np.random.uniform(0.1, 0.9)),
            "hidden": tune.sample_from(
                lambda spec: np.random.randint(32, 512)),
        })
|
||||
@@ -0,0 +1,172 @@
|
||||
import json
|
||||
import ray
|
||||
import os
|
||||
from ray import tune
|
||||
from ray.tune.result import RESULT_DUPLICATE
|
||||
from ray.tune.function_runner import wrap_function
|
||||
from ray.tune.resources import Resources
|
||||
from ray.tune.utils.trainable import TrainableUtil
|
||||
from ray.util.sgd.utils import find_free_port
|
||||
from typing import Callable, Dict, Type
|
||||
|
||||
|
||||
def setup_process_group(worker_addresses, index):
    """Describe the training cluster to TensorFlow through ``TF_CONFIG``.

    Writes a JSON document to the ``TF_CONFIG`` environment variable of the
    current process, listing all workers and identifying this process as
    the worker at position ``index``.

    Args:
        worker_addresses (list): addresses of the workers.
        index (int): index of current worker
    """
    cluster_spec = {"worker": worker_addresses}
    task_spec = {"type": "worker", "index": index}
    os.environ["TF_CONFIG"] = json.dumps({
        "cluster": cluster_spec,
        "task": task_spec,
    })
|
||||
|
||||
|
||||
def setup_address():
    """Return an ``"<ip>:<port>"`` string for the calling process.

    The IP comes from ``ray.services.get_node_ip_address`` and the port from
    ``find_free_port``.
    """
    node_ip = ray.services.get_node_ip_address()
    free_port = find_free_port()
    return "{}:{}".format(node_ip, free_port)
|
||||
|
||||
|
||||
class _TensorFlowTrainable(tune.Trainable):
    """Base class for distributed training on Tune.

    Subclasses fill in the class attributes below (see
    ``DistributedTrainableCreator``). On ``setup`` the trainable starts
    ``_num_workers`` remote copies of the wrapped training function, gives
    each of them a ``TF_CONFIG`` describing the whole group, and afterwards
    drives all of them in lockstep from ``step``.
    """

    # Training function to run on every worker.
    _function = None
    # Number of remote workers to launch per trial.
    _num_workers = None
    # Whether each worker requests one GPU.
    _use_gpu = None
    # CPUs requested per worker (falls back to 1 when unset or 0).
    _num_cpus_per_worker = None

    __slots__ = ["workers", "_finished"]

    @classmethod
    def get_remote_worker_options(cls) -> Dict[str, int]:
        """Return the Ray resource options used for each worker actor."""
        num_gpus = 1 if cls._use_gpu else 0
        num_cpus = int(cls._num_cpus_per_worker or 1)
        return dict(num_cpus=num_cpus, num_gpus=num_gpus)

    def setup(self, config: Dict):
        """Launch the workers and configure their ``TF_CONFIG``.

        Args:
            config (dict): Trial configuration, forwarded unchanged to every
                worker.
        """
        self._finished = False
        num_workers = self._num_workers
        assert self._function

        func_trainable = wrap_function(self.__class__._function)
        remote_trainable = ray.remote(func_trainable)
        remote_trainable = remote_trainable.options(
            **self.get_remote_worker_options())
        self.workers = [
            remote_trainable.remote(config=config, )
            for _ in range(num_workers)
        ]

        # Ask every worker for an address. All requests are submitted before
        # blocking so the workers answer concurrently; ray.get on a list
        # returns results in submission order, so addresses[i] belongs to
        # self.workers[i].
        addresses = ray.get([
            worker.execute.remote(lambda _: setup_address())
            for worker in self.workers
        ])

        from functools import partial
        setup_on_worker = partial(
            setup_process_group, worker_addresses=addresses)
        ray.get([
            # Bind the index as a default so each lambda keeps its own value
            # regardless of when it is serialized.
            w.execute.remote(
                lambda _, index=index: setup_on_worker(index=index))
            for index, w in enumerate(self.workers)
        ])

    def step(self) -> Dict:
        """Advance every worker by one step and return the first result.

        Raises:
            RuntimeError: If a previous step already reported
                ``RESULT_DUPLICATE``, which this class treats as the end of
                training.
        """
        if self._finished:
            raise RuntimeError("Training has already finished.")
        # All workers are stepped, but only worker 0's result is reported.
        result = ray.get([w.step.remote() for w in self.workers])[0]
        if RESULT_DUPLICATE in result:
            self._finished = True
        return result

    def save_checkpoint(self, checkpoint_dir: str) -> str:
        """Materialize worker 0's checkpoint inside ``checkpoint_dir``.

        Returns:
            The checkpoint path produced by
            ``TrainableUtil.create_from_pickle``.
        """
        # TODO: optimize if colocated
        save_obj = ray.get(self.workers[0].save_to_object.remote())
        checkpoint_path = TrainableUtil.create_from_pickle(
            save_obj, checkpoint_dir)
        return checkpoint_path

    def load_checkpoint(self, checkpoint_dir: str):
        """Restore every worker from the checkpoint at ``checkpoint_dir``."""
        checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_dir)
        # ray.get needs an ObjectRef or a *list* of them; the previous
        # generator expression was rejected at runtime.
        return ray.get([
            w.restore_from_object.remote(checkpoint_obj)
            for w in self.workers
        ])

    def stop(self):
        """Stop all workers and wait for them to finish stopping."""
        ray.get([worker.stop.remote() for worker in self.workers])
|
||||
|
||||
|
||||
def DistributedTrainableCreator(
        func: Callable,
        use_gpu: bool = False,
        num_workers: int = 2,
        num_cpus_per_worker: int = 1) -> Type[_TensorFlowTrainable]:
    """Converts TensorFlow MultiWorkerMirror training to be executable by Tune.

    Requires TensorFlow > 2.0 to work, recommends TensorFlow > 2.2.

    This function wraps and sets resources for a TF distributed training
    function to be used with Tune. It generates a TensorFlow Trainable
    which can be a distributed training job.

    Note: there is no fault tolerance at the moment.

    Args:
        func (Callable[[dict], None]): A training function that takes in
            a config dict for hyperparameters. It runs on every worker
            after ``TF_CONFIG`` has been set for that worker, so it can
            create a multi-worker ``tf.distribute`` strategy directly.
        use_gpu (bool): Whether to allocate a GPU per worker.
        num_workers (int): Number of workers that each trial is expected
            to use.
        num_cpus_per_worker (int): Number of CPUs to request
            from Ray per worker.

    Returns:
        Trainable class that can be passed into `tune.run`.

    .. versionadded:: 1.1.0

    Example:

    .. code-block:: python

        # Please refer to full example in tf_distributed_keras_example.py
        tf_trainable = DistributedTrainableCreator(
            train_mnist,
            use_gpu=args.use_gpu,
            num_workers=2)
        tune.run(tf_trainable,
                 num_samples=1)
    """

    class WrappedDistributedTensorFlowTrainable(_TensorFlowTrainable):
        _function = func
        _num_workers = num_workers
        _num_cpus_per_worker = num_cpus_per_worker
        _use_gpu = use_gpu

        @classmethod
        def default_resource_request(cls, config: Dict) -> Resources:
            """Reserve the resources the worker actors will occupy.

            The trainable itself requests nothing; everything is reserved
            as "extra" resources for the workers. The amounts come from the
            arguments given to ``DistributedTrainableCreator`` because those
            are what the workers are launched with. ``config`` is not
            consulted: the earlier code read overrides from it for the GPU
            count but not for the CPU count, so the two could disagree with
            each other and with the workers actually started.
            """
            return Resources(
                cpu=0,
                gpu=0,
                extra_cpu=num_workers * num_cpus_per_worker,
                extra_gpu=num_workers if use_gpu else 0)

    return WrappedDistributedTensorFlowTrainable
|
||||
|
||||
|
||||
def get_num_workers():
    """Retrieve the number of workers in the training job.

    The count is the length of the ``cluster.worker`` list in the JSON held
    by the ``TF_CONFIG`` environment variable.
    """
    cluster_spec = json.loads(os.environ["TF_CONFIG"])["cluster"]
    return len(cluster_spec["worker"])
|
||||
@@ -0,0 +1,61 @@
|
||||
import pytest
|
||||
import ray
|
||||
from ray.tune.integration.tensorflow import DistributedTrainableCreator
|
||||
from ray.tune.examples.tf_distributed_keras_example import train_mnist
|
||||
|
||||
|
||||
@pytest.fixture
def ray_start_2_cpus():
    """Provide a local Ray instance with 2 CPUs for the duration of a test."""
    cluster_info = ray.init(num_cpus=2)
    yield cluster_info
    # Teardown: runs once the test using this fixture has finished.
    ray.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
def ray_start_4_cpus():
    """Provide a local Ray instance with 4 CPUs for the duration of a test."""
    cluster_info = ray.init(num_cpus=4)
    yield cluster_info
    # Teardown: runs once the test using this fixture has finished.
    ray.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
def ray_connect_cluster():
    """Connect to an existing Ray cluster, skipping the test if that fails."""
    try:
        cluster_info = ray.init(address="auto")
    except Exception as e:
        # pytest.skip raises, so the yield below is never reached without
        # a successful connection.
        pytest.skip(str(e))
    yield cluster_info
    # Teardown: runs once the test using this fixture has finished.
    ray.shutdown()
|
||||
|
||||
|
||||
def test_single_step(ray_start_2_cpus):  # noqa: F811
    """A default distributed trainable can run one step and then stop."""
    trainer = DistributedTrainableCreator(train_mnist)()
    trainer.train()
    trainer.stop()
|
||||
|
||||
|
||||
def test_step_after_completion(ray_start_2_cpus):  # noqa: F811
    """Stepping past the end of training raises ``RuntimeError``."""
    trainable_cls = DistributedTrainableCreator(train_mnist, num_workers=2)
    trainer = trainable_cls(config={"epochs": 1})
    # Ten steps is expected to be more than training needs, so one of the
    # calls should hit the "already finished" error.
    with pytest.raises(RuntimeError):
        for _ in range(10):
            trainer.train()
|
||||
|
||||
|
||||
def test_validation(ray_start_2_cpus):  # noqa: F811
    """Instantiating with an unsupported function signature fails."""

    def bad_func(a, b, c):
        return 1

    trainable_cls = DistributedTrainableCreator(bad_func)
    # The error surfaces on instantiation, not when the class is created.
    with pytest.raises(ValueError):
        trainable_cls()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly; pytest is already imported
    # at the top of the file.
    import sys

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|
||||
Reference in New Issue
Block a user