mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:38:18 +08:00
[sgd] Document and add simple MNIST example (#3236)
This commit is contained in:
@@ -0,0 +1,11 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.experimental.sgd.sgd import DistributedSGD
|
||||
from ray.experimental.sgd.model import Model
|
||||
|
||||
__all__ = [
|
||||
"DistributedSGD",
|
||||
"Model",
|
||||
]
|
||||
|
||||
Executable
+134
@@ -0,0 +1,134 @@
|
||||
#!/usr/bin/env python
|
||||
"""Example of how to train a model with Ray SGD.
|
||||
|
||||
We use a small model here, so no speedup for distributing the computation is
|
||||
expected. This example shows:
|
||||
- How to set up a simple input pipeline
|
||||
- How to evaluate model accuracy during training
|
||||
- How to get and set model weights
|
||||
- How to train with ray.experimental.sgd.DistributedSGD
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import time
|
||||
|
||||
from tensorflow.examples.tutorials.mnist import input_data
|
||||
import tensorflow as tf
|
||||
|
||||
import ray
|
||||
from ray.tune import run_experiments
|
||||
from ray.tune.examples.tune_mnist_ray import deepnn
|
||||
from ray.experimental.sgd.model import Model
|
||||
from ray.experimental.sgd.sgd import DistributedSGD
|
||||
from ray.experimental.tfutils import TensorFlowVariables
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--redis-address", default=None, type=str)
|
||||
parser.add_argument("--num-iters", default=10000, type=int)
|
||||
parser.add_argument("--batch-size", default=50, type=int)
|
||||
parser.add_argument("--num-workers", default=1, type=int)
|
||||
parser.add_argument("--devices-per-worker", default=1, type=int)
|
||||
parser.add_argument("--tune", action="store_true", help="Run in Ray Tune")
|
||||
parser.add_argument(
|
||||
"--strategy", default="ps", type=str, help="One of 'simple' or 'ps'")
|
||||
parser.add_argument(
|
||||
"--gpu", action="store_true", help="Use GPUs for optimization")
|
||||
|
||||
|
||||
class MNISTModel(Model):
|
||||
def __init__(self):
|
||||
# Import data
|
||||
error = None
|
||||
for _ in range(10):
|
||||
try:
|
||||
self.mnist = input_data.read_data_sets(
|
||||
"/tmp/tensorflow/mnist/input_data", one_hot=True)
|
||||
error = None
|
||||
break
|
||||
except Exception as e:
|
||||
error = e
|
||||
time.sleep(5)
|
||||
if error:
|
||||
raise ValueError("Failed to import data", error)
|
||||
|
||||
# Set seed and build layers
|
||||
tf.set_random_seed(0)
|
||||
self.x = tf.placeholder(tf.float32, [None, 784], name="x")
|
||||
self.y_ = tf.placeholder(tf.float32, [None, 10], name="y_")
|
||||
y_conv, self.keep_prob = deepnn(self.x)
|
||||
|
||||
# Need to define loss and optimizer attributes
|
||||
self.loss = tf.reduce_mean(
|
||||
tf.nn.softmax_cross_entropy_with_logits(
|
||||
labels=self.y_, logits=y_conv))
|
||||
self.optimizer = tf.train.AdamOptimizer(1e-4)
|
||||
self.variables = TensorFlowVariables(self.loss,
|
||||
tf.get_default_session())
|
||||
|
||||
# For evaluating test accuracy
|
||||
correct_prediction = tf.equal(
|
||||
tf.argmax(y_conv, 1), tf.argmax(self.y_, 1))
|
||||
self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
|
||||
|
||||
def get_feed_dict(self):
|
||||
batch = self.mnist.train.next_batch(50)
|
||||
return {
|
||||
self.x: batch[0],
|
||||
self.y_: batch[1],
|
||||
self.keep_prob: 0.5,
|
||||
}
|
||||
|
||||
def test_accuracy(self):
|
||||
return self.accuracy.eval(
|
||||
feed_dict={
|
||||
self.x: self.mnist.test.images,
|
||||
self.y_: self.mnist.test.labels,
|
||||
self.keep_prob: 1.0,
|
||||
})
|
||||
|
||||
|
||||
def train_mnist(config, reporter):
|
||||
args = config["args"]
|
||||
sgd = DistributedSGD(
|
||||
lambda w_i, d_i: MNISTModel(),
|
||||
num_workers=args.num_workers,
|
||||
devices_per_worker=args.devices_per_worker,
|
||||
gpu=args.gpu,
|
||||
strategy=args.strategy)
|
||||
|
||||
# Important: synchronize the initial weights of all model replicas
|
||||
w0 = sgd.for_model(lambda m: m.variables.get_flat())
|
||||
sgd.foreach_model(lambda m: m.variables.set_flat(w0))
|
||||
|
||||
for i in range(args.num_iters):
|
||||
if i % 10 == 0:
|
||||
start = time.time()
|
||||
loss = sgd.step(fetch_stats=True)["loss"]
|
||||
acc = sgd.foreach_model(lambda model: model.test_accuracy())
|
||||
print("Iter", i, "loss", loss, "accuracy", acc)
|
||||
print("Time per iteration", time.time() - start)
|
||||
assert len(set(acc)) == 1, ("Models out of sync", acc)
|
||||
reporter(timesteps_total=i, mean_loss=loss, mean_accuracy=acc[0])
|
||||
else:
|
||||
sgd.step()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
ray.init(redis_address=args.redis_address)
|
||||
|
||||
if args.tune:
|
||||
run_experiments({
|
||||
"mnist_sgd": {
|
||||
"run": train_mnist,
|
||||
"config": {
|
||||
"args": args,
|
||||
},
|
||||
},
|
||||
})
|
||||
else:
|
||||
train_mnist({"args": args}, lambda **kw: None)
|
||||
@@ -91,7 +91,9 @@ class DistributedSGD(object):
|
||||
|
||||
RemoteSGDWorker = ray.remote(**requests)(SGDWorker)
|
||||
self.workers = []
|
||||
logger.info("Creating SGD workers ({} total)".format(num_workers))
|
||||
logger.info(
|
||||
"Creating SGD workers ({} total, {} devices per worker)".format(
|
||||
num_workers, devices_per_worker))
|
||||
for worker_index in range(num_workers):
|
||||
self.workers.append(
|
||||
RemoteSGDWorker.remote(
|
||||
@@ -143,7 +145,15 @@ class DistributedSGD(object):
|
||||
out = []
|
||||
for r in results:
|
||||
out.extend(r)
|
||||
return r
|
||||
return out
|
||||
|
||||
def for_model(self, fn):
|
||||
"""Apply the given function to a single model replica.
|
||||
|
||||
Returns:
|
||||
Result from applying the function.
|
||||
"""
|
||||
return ray.get(self.workers[0].for_model.remote(fn))
|
||||
|
||||
def step(self, fetch_stats=False):
|
||||
"""Run a single SGD step.
|
||||
@@ -176,7 +186,7 @@ def _average_gradients(grads):
|
||||
|
||||
def _simple_sgd_step(actors):
|
||||
if len(actors) == 1:
|
||||
return ray.get(actors[0].compute_apply.remote())
|
||||
return {"loss": ray.get(actors[0].compute_apply.remote())}
|
||||
|
||||
start = time.time()
|
||||
fetches = ray.get([a.compute_gradients.remote() for a in actors])
|
||||
@@ -193,18 +203,18 @@ def _simple_sgd_step(actors):
|
||||
start = time.time()
|
||||
ray.get([a.apply_gradients.remote(avg_grad) for a in actors])
|
||||
logger.debug("apply all grads time {}".format(time.time() - start))
|
||||
return np.mean(losses)
|
||||
return {"loss": np.mean(losses)}
|
||||
|
||||
|
||||
def _distributed_sgd_step(actors, ps_list, fetch_stats, write_timeline):
|
||||
# Preallocate object ids that actors will write gradient shards to
|
||||
grad_shard_oids_list = [[np.random.bytes(20) for _ in ps_list]
|
||||
for _ in actors]
|
||||
logger.info("Generated grad oids")
|
||||
logger.debug("Generated grad oids")
|
||||
|
||||
# Preallocate object ids that param servers will write new weights to
|
||||
accum_shard_ids = [np.random.bytes(20) for _ in ps_list]
|
||||
logger.info("Generated accum oids")
|
||||
logger.debug("Generated accum oids")
|
||||
|
||||
# Kick off the fused compute grad / update weights tf run for each actor
|
||||
losses = []
|
||||
@@ -214,7 +224,7 @@ def _distributed_sgd_step(actors, ps_list, fetch_stats, write_timeline):
|
||||
grad_shard_oids,
|
||||
accum_shard_ids,
|
||||
write_timeline=write_timeline))
|
||||
logger.info("Launched all ps_compute_applys on all actors")
|
||||
logger.debug("Launched all ps_compute_applys on all actors")
|
||||
|
||||
# Issue prefetch ops
|
||||
for j, (ps, weight_shard_oid) in list(
|
||||
@@ -224,7 +234,7 @@ def _distributed_sgd_step(actors, ps_list, fetch_stats, write_timeline):
|
||||
to_fetch.append(grad_shard_oids[j])
|
||||
random.shuffle(to_fetch)
|
||||
ps.prefetch.remote(to_fetch)
|
||||
logger.info("Launched all prefetch ops")
|
||||
logger.debug("Launched all prefetch ops")
|
||||
|
||||
# Aggregate the gradients produced by the actors. These operations
|
||||
# run concurrently with the actor methods above.
|
||||
@@ -233,11 +243,11 @@ def _distributed_sgd_step(actors, ps_list, fetch_stats, write_timeline):
|
||||
enumerate(zip(ps_list, accum_shard_ids)))[::-1]:
|
||||
ps.add_spinwait.remote([gs[j] for gs in grad_shard_oids_list])
|
||||
ps_gets.append(ps.get.remote(weight_shard_oid))
|
||||
logger.info("Launched all aggregate ops")
|
||||
logger.debug("Launched all aggregate ops")
|
||||
|
||||
if write_timeline:
|
||||
timelines = [ps.get_timeline.remote() for ps in ps_list]
|
||||
logger.info("launched timeline gets")
|
||||
logger.debug("Launched timeline gets")
|
||||
timelines = ray.get(timelines)
|
||||
t0 = timelines[0]
|
||||
for t in timelines[1:]:
|
||||
@@ -247,6 +257,6 @@ def _distributed_sgd_step(actors, ps_list, fetch_stats, write_timeline):
|
||||
# Wait for at least the ps gets to finish
|
||||
ray.get(ps_gets)
|
||||
if fetch_stats:
|
||||
return np.mean(ray.get(losses))
|
||||
return {"loss": np.mean(ray.get(losses))}
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -48,23 +48,24 @@ class SGDWorker(object):
|
||||
device_tmpl = "/gpu:%d"
|
||||
else:
|
||||
device_tmpl = "/cpu:%d"
|
||||
for device_idx in range(num_devices):
|
||||
device = device_tmpl % device_idx
|
||||
with tf.device(device):
|
||||
with tf.variable_scope("device_%d" % device_idx):
|
||||
model = model_creator(worker_index, device_idx)
|
||||
self.models.append(model)
|
||||
model.grads = [
|
||||
t
|
||||
for t in model.optimizer.compute_gradients(model.loss)
|
||||
if t[0] is not None
|
||||
]
|
||||
grad_ops.append(model.grads)
|
||||
with self.sess.as_default():
|
||||
for device_idx in range(num_devices):
|
||||
device = device_tmpl % device_idx
|
||||
with tf.device(device):
|
||||
with tf.variable_scope("device_%d" % device_idx):
|
||||
model = model_creator(worker_index, device_idx)
|
||||
self.models.append(model)
|
||||
grads = [
|
||||
t for t in model.optimizer.compute_gradients(
|
||||
model.loss) if t[0] is not None
|
||||
]
|
||||
grad_ops.append(grads)
|
||||
|
||||
if num_devices == 1:
|
||||
assert not max_bytes, \
|
||||
"grad_shard_bytes > 0 ({}) requires num_devices > 1".format(
|
||||
max_bytes)
|
||||
if max_bytes:
|
||||
raise ValueError(
|
||||
"Implementation limitation: grad_shard_bytes > 0 "
|
||||
"({}) currently requires > 1 device".format(max_bytes))
|
||||
self.packed_grads_and_vars = grad_ops
|
||||
else:
|
||||
if max_bytes:
|
||||
@@ -182,15 +183,28 @@ class SGDWorker(object):
|
||||
tf.local_variables_initializer())
|
||||
self.sess.run(init_op)
|
||||
|
||||
def _grad_feed_dict(self):
|
||||
# Aggregate feed dicts for each model on this worker.
|
||||
feed_dict = {}
|
||||
for model in self.models:
|
||||
feed_dict.update(model.get_feed_dict())
|
||||
return feed_dict
|
||||
|
||||
def foreach_model(self, fn):
|
||||
return [fn(m) for m in self.models]
|
||||
with self.sess.as_default():
|
||||
return [fn(m) for m in self.models]
|
||||
|
||||
def foreach_worker(self, fn):
|
||||
return fn(self)
|
||||
with self.sess.as_default():
|
||||
return fn(self)
|
||||
|
||||
def for_model(self, fn):
|
||||
with self.sess.as_default():
|
||||
return fn(self.models[0])
|
||||
|
||||
def compute_gradients(self):
|
||||
start = time.time()
|
||||
feed_dict = {}
|
||||
feed_dict = self._grad_feed_dict()
|
||||
# Aggregate feed dicts for each model on this worker.
|
||||
for model in self.models:
|
||||
feed_dict.update(model.get_feed_dict())
|
||||
@@ -219,6 +233,7 @@ class SGDWorker(object):
|
||||
fetches = run_timeline(
|
||||
self.sess,
|
||||
[self.models[0].loss, self.apply_op, self.nccl_control_out],
|
||||
feed_dict=self._grad_feed_dict(),
|
||||
name="compute_apply")
|
||||
return fetches[0]
|
||||
|
||||
@@ -227,7 +242,9 @@ class SGDWorker(object):
|
||||
agg_grad_shard_oids,
|
||||
tl_name="ps_compute_apply",
|
||||
write_timeline=False):
|
||||
feed_dict = dict(zip(self.plasma_in_grads_oids, out_grad_shard_oids))
|
||||
feed_dict = self._grad_feed_dict()
|
||||
feed_dict.update(
|
||||
dict(zip(self.plasma_in_grads_oids, out_grad_shard_oids)))
|
||||
feed_dict.update(
|
||||
dict(zip(self.plasma_out_grads_oids, agg_grad_shard_oids)))
|
||||
fetch(agg_grad_shard_oids)
|
||||
|
||||
@@ -189,7 +189,7 @@ class ModelCatalog(object):
|
||||
seq_in (Tensor): Optional RNN sequence length tensor.
|
||||
|
||||
Returns:
|
||||
model (Model): Neural network model.
|
||||
model (models.Model): Neural network model.
|
||||
"""
|
||||
|
||||
assert isinstance(input_dict, dict)
|
||||
@@ -241,7 +241,7 @@ class ModelCatalog(object):
|
||||
options (dict): Optional args to pass to the model constructor.
|
||||
|
||||
Returns:
|
||||
model (Model): Neural network model.
|
||||
model (models.Model): Neural network model.
|
||||
"""
|
||||
from ray.rllib.models.pytorch.fcnet import (FullyConnectedNetwork as
|
||||
PyTorchFCNet)
|
||||
|
||||
@@ -42,7 +42,7 @@ import tensorflow as tf
|
||||
|
||||
FLAGS = None
|
||||
status_reporter = None # used to report training status back to Ray
|
||||
activation_fn = None # e.g. tf.nn.relu
|
||||
activation_fn = tf.nn.relu # e.g. tf.nn.relu
|
||||
|
||||
|
||||
def deepnn(x):
|
||||
|
||||
Reference in New Issue
Block a user