[sgd] Document and add simple MNIST example (#3236)

This commit is contained in:
Eric Liang
2018-11-10 21:52:20 -08:00
committed by GitHub
parent d681893b0f
commit 53489d2f85
15 changed files with 279 additions and 38 deletions
+11
View File
@@ -0,0 +1,11 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.experimental.sgd.sgd import DistributedSGD
from ray.experimental.sgd.model import Model
__all__ = [
"DistributedSGD",
"Model",
]
+134
View File
@@ -0,0 +1,134 @@
#!/usr/bin/env python
"""Example of how to train a model with Ray SGD.
We use a small model here, so no speedup for distributing the computation is
expected. This example shows:
- How to set up a simple input pipeline
- How to evaluate model accuracy during training
- How to get and set model weights
- How to train with ray.experimental.sgd.DistributedSGD
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import time
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import ray
from ray.tune import run_experiments
from ray.tune.examples.tune_mnist_ray import deepnn
from ray.experimental.sgd.model import Model
from ray.experimental.sgd.sgd import DistributedSGD
from ray.experimental.tfutils import TensorFlowVariables
parser = argparse.ArgumentParser()
parser.add_argument("--redis-address", default=None, type=str)
parser.add_argument("--num-iters", default=10000, type=int)
parser.add_argument("--batch-size", default=50, type=int)
parser.add_argument("--num-workers", default=1, type=int)
parser.add_argument("--devices-per-worker", default=1, type=int)
parser.add_argument("--tune", action="store_true", help="Run in Ray Tune")
parser.add_argument(
"--strategy", default="ps", type=str, help="One of 'simple' or 'ps'")
parser.add_argument(
"--gpu", action="store_true", help="Use GPUs for optimization")
class MNISTModel(Model):
def __init__(self):
# Import data
error = None
for _ in range(10):
try:
self.mnist = input_data.read_data_sets(
"/tmp/tensorflow/mnist/input_data", one_hot=True)
error = None
break
except Exception as e:
error = e
time.sleep(5)
if error:
raise ValueError("Failed to import data", error)
# Set seed and build layers
tf.set_random_seed(0)
self.x = tf.placeholder(tf.float32, [None, 784], name="x")
self.y_ = tf.placeholder(tf.float32, [None, 10], name="y_")
y_conv, self.keep_prob = deepnn(self.x)
# Need to define loss and optimizer attributes
self.loss = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits(
labels=self.y_, logits=y_conv))
self.optimizer = tf.train.AdamOptimizer(1e-4)
self.variables = TensorFlowVariables(self.loss,
tf.get_default_session())
# For evaluating test accuracy
correct_prediction = tf.equal(
tf.argmax(y_conv, 1), tf.argmax(self.y_, 1))
self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
def get_feed_dict(self):
batch = self.mnist.train.next_batch(50)
return {
self.x: batch[0],
self.y_: batch[1],
self.keep_prob: 0.5,
}
def test_accuracy(self):
return self.accuracy.eval(
feed_dict={
self.x: self.mnist.test.images,
self.y_: self.mnist.test.labels,
self.keep_prob: 1.0,
})
def train_mnist(config, reporter):
args = config["args"]
sgd = DistributedSGD(
lambda w_i, d_i: MNISTModel(),
num_workers=args.num_workers,
devices_per_worker=args.devices_per_worker,
gpu=args.gpu,
strategy=args.strategy)
# Important: synchronize the initial weights of all model replicas
w0 = sgd.for_model(lambda m: m.variables.get_flat())
sgd.foreach_model(lambda m: m.variables.set_flat(w0))
for i in range(args.num_iters):
if i % 10 == 0:
start = time.time()
loss = sgd.step(fetch_stats=True)["loss"]
acc = sgd.foreach_model(lambda model: model.test_accuracy())
print("Iter", i, "loss", loss, "accuracy", acc)
print("Time per iteration", time.time() - start)
assert len(set(acc)) == 1, ("Models out of sync", acc)
reporter(timesteps_total=i, mean_loss=loss, mean_accuracy=acc[0])
else:
sgd.step()
if __name__ == "__main__":
args = parser.parse_args()
ray.init(redis_address=args.redis_address)
if args.tune:
run_experiments({
"mnist_sgd": {
"run": train_mnist,
"config": {
"args": args,
},
},
})
else:
train_mnist({"args": args}, lambda **kw: None)
+21 -11
View File
@@ -91,7 +91,9 @@ class DistributedSGD(object):
RemoteSGDWorker = ray.remote(**requests)(SGDWorker)
self.workers = []
logger.info("Creating SGD workers ({} total)".format(num_workers))
logger.info(
"Creating SGD workers ({} total, {} devices per worker)".format(
num_workers, devices_per_worker))
for worker_index in range(num_workers):
self.workers.append(
RemoteSGDWorker.remote(
@@ -143,7 +145,15 @@ class DistributedSGD(object):
out = []
for r in results:
out.extend(r)
return r
return out
def for_model(self, fn):
"""Apply the given function to a single model replica.
Returns:
Result from applying the function.
"""
return ray.get(self.workers[0].for_model.remote(fn))
def step(self, fetch_stats=False):
"""Run a single SGD step.
@@ -176,7 +186,7 @@ def _average_gradients(grads):
def _simple_sgd_step(actors):
if len(actors) == 1:
return ray.get(actors[0].compute_apply.remote())
return {"loss": ray.get(actors[0].compute_apply.remote())}
start = time.time()
fetches = ray.get([a.compute_gradients.remote() for a in actors])
@@ -193,18 +203,18 @@ def _simple_sgd_step(actors):
start = time.time()
ray.get([a.apply_gradients.remote(avg_grad) for a in actors])
logger.debug("apply all grads time {}".format(time.time() - start))
return np.mean(losses)
return {"loss": np.mean(losses)}
def _distributed_sgd_step(actors, ps_list, fetch_stats, write_timeline):
# Preallocate object ids that actors will write gradient shards to
grad_shard_oids_list = [[np.random.bytes(20) for _ in ps_list]
for _ in actors]
logger.info("Generated grad oids")
logger.debug("Generated grad oids")
# Preallocate object ids that param servers will write new weights to
accum_shard_ids = [np.random.bytes(20) for _ in ps_list]
logger.info("Generated accum oids")
logger.debug("Generated accum oids")
# Kick off the fused compute grad / update weights tf run for each actor
losses = []
@@ -214,7 +224,7 @@ def _distributed_sgd_step(actors, ps_list, fetch_stats, write_timeline):
grad_shard_oids,
accum_shard_ids,
write_timeline=write_timeline))
logger.info("Launched all ps_compute_applys on all actors")
logger.debug("Launched all ps_compute_applys on all actors")
# Issue prefetch ops
for j, (ps, weight_shard_oid) in list(
@@ -224,7 +234,7 @@ def _distributed_sgd_step(actors, ps_list, fetch_stats, write_timeline):
to_fetch.append(grad_shard_oids[j])
random.shuffle(to_fetch)
ps.prefetch.remote(to_fetch)
logger.info("Launched all prefetch ops")
logger.debug("Launched all prefetch ops")
# Aggregate the gradients produced by the actors. These operations
# run concurrently with the actor methods above.
@@ -233,11 +243,11 @@ def _distributed_sgd_step(actors, ps_list, fetch_stats, write_timeline):
enumerate(zip(ps_list, accum_shard_ids)))[::-1]:
ps.add_spinwait.remote([gs[j] for gs in grad_shard_oids_list])
ps_gets.append(ps.get.remote(weight_shard_oid))
logger.info("Launched all aggregate ops")
logger.debug("Launched all aggregate ops")
if write_timeline:
timelines = [ps.get_timeline.remote() for ps in ps_list]
logger.info("launched timeline gets")
logger.debug("Launched timeline gets")
timelines = ray.get(timelines)
t0 = timelines[0]
for t in timelines[1:]:
@@ -247,6 +257,6 @@ def _distributed_sgd_step(actors, ps_list, fetch_stats, write_timeline):
# Wait for at least the ps gets to finish
ray.get(ps_gets)
if fetch_stats:
return np.mean(ray.get(losses))
return {"loss": np.mean(ray.get(losses))}
else:
return None
+36 -19
View File
@@ -48,23 +48,24 @@ class SGDWorker(object):
device_tmpl = "/gpu:%d"
else:
device_tmpl = "/cpu:%d"
for device_idx in range(num_devices):
device = device_tmpl % device_idx
with tf.device(device):
with tf.variable_scope("device_%d" % device_idx):
model = model_creator(worker_index, device_idx)
self.models.append(model)
model.grads = [
t
for t in model.optimizer.compute_gradients(model.loss)
if t[0] is not None
]
grad_ops.append(model.grads)
with self.sess.as_default():
for device_idx in range(num_devices):
device = device_tmpl % device_idx
with tf.device(device):
with tf.variable_scope("device_%d" % device_idx):
model = model_creator(worker_index, device_idx)
self.models.append(model)
grads = [
t for t in model.optimizer.compute_gradients(
model.loss) if t[0] is not None
]
grad_ops.append(grads)
if num_devices == 1:
assert not max_bytes, \
"grad_shard_bytes > 0 ({}) requires num_devices > 1".format(
max_bytes)
if max_bytes:
raise ValueError(
"Implementation limitation: grad_shard_bytes > 0 "
"({}) currently requires > 1 device".format(max_bytes))
self.packed_grads_and_vars = grad_ops
else:
if max_bytes:
@@ -182,15 +183,28 @@ class SGDWorker(object):
tf.local_variables_initializer())
self.sess.run(init_op)
def _grad_feed_dict(self):
# Aggregate feed dicts for each model on this worker.
feed_dict = {}
for model in self.models:
feed_dict.update(model.get_feed_dict())
return feed_dict
def foreach_model(self, fn):
return [fn(m) for m in self.models]
with self.sess.as_default():
return [fn(m) for m in self.models]
def foreach_worker(self, fn):
return fn(self)
with self.sess.as_default():
return fn(self)
def for_model(self, fn):
with self.sess.as_default():
return fn(self.models[0])
def compute_gradients(self):
start = time.time()
feed_dict = {}
feed_dict = self._grad_feed_dict()
# Aggregate feed dicts for each model on this worker.
for model in self.models:
feed_dict.update(model.get_feed_dict())
@@ -219,6 +233,7 @@ class SGDWorker(object):
fetches = run_timeline(
self.sess,
[self.models[0].loss, self.apply_op, self.nccl_control_out],
feed_dict=self._grad_feed_dict(),
name="compute_apply")
return fetches[0]
@@ -227,7 +242,9 @@ class SGDWorker(object):
agg_grad_shard_oids,
tl_name="ps_compute_apply",
write_timeline=False):
feed_dict = dict(zip(self.plasma_in_grads_oids, out_grad_shard_oids))
feed_dict = self._grad_feed_dict()
feed_dict.update(
dict(zip(self.plasma_in_grads_oids, out_grad_shard_oids)))
feed_dict.update(
dict(zip(self.plasma_out_grads_oids, agg_grad_shard_oids)))
fetch(agg_grad_shard_oids)
+2 -2
View File
@@ -189,7 +189,7 @@ class ModelCatalog(object):
seq_in (Tensor): Optional RNN sequence length tensor.
Returns:
model (Model): Neural network model.
model (models.Model): Neural network model.
"""
assert isinstance(input_dict, dict)
@@ -241,7 +241,7 @@ class ModelCatalog(object):
options (dict): Optional args to pass to the model constructor.
Returns:
model (Model): Neural network model.
model (models.Model): Neural network model.
"""
from ray.rllib.models.pytorch.fcnet import (FullyConnectedNetwork as
PyTorchFCNet)
+1 -1
View File
@@ -42,7 +42,7 @@ import tensorflow as tf
FLAGS = None
status_reporter = None # used to report training status back to Ray
activation_fn = None # e.g. tf.nn.relu
activation_fn = tf.nn.relu # e.g. tf.nn.relu
def deepnn(x):