[sgd] Add checkpointing (#3638)

This commit is contained in:
Peter Schafhalter
2019-01-08 15:29:30 -08:00
committed by Robert Nishihara
parent 5e76d52868
commit 5945b92fd3
6 changed files with 136 additions and 0 deletions
@@ -101,6 +101,12 @@ class MNISTModel(Model):
})
return {"accuracy": accuracy}
def get_weights(self):
return self.variables.get_flat()
def set_weights(self, weights):
self.variables.set_flat(weights)
def train_mnist(config, reporter):
args = config["args"]
+24
View File
@@ -46,3 +46,27 @@ class Model(object):
TensorFlow feed_dict to add to the gradient operation.
"""
return {}
def get_weights(self):
"""Return weights from the model.
Implementing `get_weights` is required for checkpointing and fault
tolerance.
Returns:
Numpy array of weights from the model.
"""
raise NotImplementedError(
"get_weights of %s is not implemented" % self.__class__.__name__)
def set_weights(self, weights):
"""Sets the model weights.
Implementing `set_weights` is required for checkpointing and fault
tolerance.
Args:
weights: numpy array of weights for the model.
"""
raise NotImplementedError(
"set_weights of %s is not implemented" % self.__class__.__name__)
+11
View File
@@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
import logging
import os
import random
import time
@@ -177,6 +178,16 @@ class DistributedSGD(object):
ray.get([w.warmup.remote() for w in self.workers])
logger.info("Warmup complete")
def save_checkpoint(self, path):
w0 = self.for_model(lambda m: m.get_weights())
filename = os.path.join(path, "model.npy")
np.save(filename, w0)
def restore_checkpoint(self, path):
filename = os.path.join(path, "model.npy")
w0 = np.load(filename)
self.foreach_model(lambda m: m.set_weights(w0))
def _average_gradients(grads):
out = []
+77
View File
@@ -0,0 +1,77 @@
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import time
import ray
from ray.experimental.sgd.tfbench.test_model import TFBenchModel
from ray.experimental.sgd.sgd import DistributedSGD
parser = argparse.ArgumentParser()
parser.add_argument("--redis-address", default=None, type=str)
parser.add_argument("--num-iters", default=10, type=int)
parser.add_argument("--batch-size", default=1, type=int)
parser.add_argument("--num-workers", default=2, type=int)
parser.add_argument("--grad-shard-bytes", default=10000000, type=int)
parser.add_argument("--devices-per-worker", default=2, type=int)
parser.add_argument("--all-reduce-alg", default="simple", type=str)
parser.add_argument("--object-store-memory", default=None, type=int)
parser.add_argument("--checkpoint-dir", default="/tmp", type=str)
parser.add_argument(
"--strategy", default="simple", type=str, help="One of 'simple' or 'ps'")
parser.add_argument(
"--gpu", action="store_true", help="Use GPUs for optimization")
if __name__ == "__main__":
args, _ = parser.parse_known_args()
ray.init(
redis_address=args.redis_address,
object_store_memory=args.object_store_memory)
model_creator = (
lambda worker_idx, device_idx: TFBenchModel(
batch=args.batch_size, use_cpus=not args.gpu))
sgd = DistributedSGD(
model_creator,
num_workers=args.num_workers,
devices_per_worker=args.devices_per_worker,
gpu=args.gpu,
strategy=args.strategy,
grad_shard_bytes=args.grad_shard_bytes,
all_reduce_alg=args.all_reduce_alg)
if not os.path.exists(args.checkpoint_dir):
raise ValueError(
"Checkpoint directory does not exist: %s" % args.checkpoint_dir)
def step(i):
start = time.time()
print("== Step {} ==".format(i))
stats = sgd.step(fetch_stats=True)
ips = ((args.batch_size * args.num_workers * args.devices_per_worker) /
(time.time() - start))
print("Iteration time", time.time() - start, "Images per second", ips)
print("Current loss", stats)
i = 0
while i < args.num_iters:
step(i)
i += 1
print("Saving checkpoint...")
sgd.save_checkpoint(args.checkpoint_dir)
print("Done saving checkpoint")
step(i)
print("Restoring checkpoint")
sgd.restore_checkpoint(args.checkpoint_dir)
print("Done restoring checkpoint")
step(i)
@@ -6,6 +6,7 @@ import tensorflow as tf
from tfbench import model_config
from ray.experimental.sgd.model import Model
from ray.experimental.tfutils import TensorFlowVariables
class MockDataset():
@@ -46,6 +47,9 @@ class TFBenchModel(Model):
self.loss = tf.reduce_mean(loss, name='xentropy-loss')
self.optimizer = tf.train.GradientDescentOptimizer(1e-6)
self.variables = TensorFlowVariables(self.loss,
tf.get_default_session())
def get_loss(self):
return self.loss
@@ -54,3 +58,9 @@ class TFBenchModel(Model):
def get_feed_dict(self):
return {}
def get_weights(self):
return self.variables.get_flat()
def set_weights(self, weights):
self.variables.set_flat(weights)