mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 02:00:12 +08:00
[sgd] Add checkpointing (#3638)
This commit is contained in:
committed by
Robert Nishihara
parent
5e76d52868
commit
5945b92fd3
@@ -101,6 +101,12 @@ class MNISTModel(Model):
|
||||
})
|
||||
return {"accuracy": accuracy}
|
||||
|
||||
def get_weights(self):
|
||||
return self.variables.get_flat()
|
||||
|
||||
def set_weights(self, weights):
|
||||
self.variables.set_flat(weights)
|
||||
|
||||
|
||||
def train_mnist(config, reporter):
|
||||
args = config["args"]
|
||||
|
||||
@@ -46,3 +46,27 @@ class Model(object):
|
||||
TensorFlow feed_dict to add to the gradient operation.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def get_weights(self):
|
||||
"""Return weights from the model.
|
||||
|
||||
Implementing `get_weights` is required for checkpointing and fault
|
||||
tolerance.
|
||||
|
||||
Returns:
|
||||
Numpy array of weights from the model.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"get_weights of %s is not implemented" % self.__class__.__name__)
|
||||
|
||||
def set_weights(self, weights):
|
||||
"""Sets the model weights.
|
||||
|
||||
Implementing `set_weights` is required for checkpointing and fault
|
||||
tolerance.
|
||||
|
||||
Args:
|
||||
weights: numpy array of weights for the model.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"set_weights of %s is not implemented" % self.__class__.__name__)
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
|
||||
@@ -177,6 +178,16 @@ class DistributedSGD(object):
|
||||
ray.get([w.warmup.remote() for w in self.workers])
|
||||
logger.info("Warmup complete")
|
||||
|
||||
def save_checkpoint(self, path):
|
||||
w0 = self.for_model(lambda m: m.get_weights())
|
||||
filename = os.path.join(path, "model.npy")
|
||||
np.save(filename, w0)
|
||||
|
||||
def restore_checkpoint(self, path):
|
||||
filename = os.path.join(path, "model.npy")
|
||||
w0 = np.load(filename)
|
||||
self.foreach_model(lambda m: m.set_weights(w0))
|
||||
|
||||
|
||||
def _average_gradients(grads):
|
||||
out = []
|
||||
|
||||
+77
@@ -0,0 +1,77 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
|
||||
import ray
|
||||
from ray.experimental.sgd.tfbench.test_model import TFBenchModel
|
||||
from ray.experimental.sgd.sgd import DistributedSGD
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--redis-address", default=None, type=str)
|
||||
parser.add_argument("--num-iters", default=10, type=int)
|
||||
parser.add_argument("--batch-size", default=1, type=int)
|
||||
parser.add_argument("--num-workers", default=2, type=int)
|
||||
parser.add_argument("--grad-shard-bytes", default=10000000, type=int)
|
||||
parser.add_argument("--devices-per-worker", default=2, type=int)
|
||||
parser.add_argument("--all-reduce-alg", default="simple", type=str)
|
||||
parser.add_argument("--object-store-memory", default=None, type=int)
|
||||
parser.add_argument("--checkpoint-dir", default="/tmp", type=str)
|
||||
parser.add_argument(
|
||||
"--strategy", default="simple", type=str, help="One of 'simple' or 'ps'")
|
||||
parser.add_argument(
|
||||
"--gpu", action="store_true", help="Use GPUs for optimization")
|
||||
|
||||
if __name__ == "__main__":
|
||||
args, _ = parser.parse_known_args()
|
||||
ray.init(
|
||||
redis_address=args.redis_address,
|
||||
object_store_memory=args.object_store_memory)
|
||||
|
||||
model_creator = (
|
||||
lambda worker_idx, device_idx: TFBenchModel(
|
||||
batch=args.batch_size, use_cpus=not args.gpu))
|
||||
|
||||
sgd = DistributedSGD(
|
||||
model_creator,
|
||||
num_workers=args.num_workers,
|
||||
devices_per_worker=args.devices_per_worker,
|
||||
gpu=args.gpu,
|
||||
strategy=args.strategy,
|
||||
grad_shard_bytes=args.grad_shard_bytes,
|
||||
all_reduce_alg=args.all_reduce_alg)
|
||||
|
||||
if not os.path.exists(args.checkpoint_dir):
|
||||
raise ValueError(
|
||||
"Checkpoint directory does not exist: %s" % args.checkpoint_dir)
|
||||
|
||||
def step(i):
|
||||
start = time.time()
|
||||
print("== Step {} ==".format(i))
|
||||
stats = sgd.step(fetch_stats=True)
|
||||
ips = ((args.batch_size * args.num_workers * args.devices_per_worker) /
|
||||
(time.time() - start))
|
||||
print("Iteration time", time.time() - start, "Images per second", ips)
|
||||
print("Current loss", stats)
|
||||
|
||||
i = 0
|
||||
while i < args.num_iters:
|
||||
step(i)
|
||||
i += 1
|
||||
|
||||
print("Saving checkpoint...")
|
||||
sgd.save_checkpoint(args.checkpoint_dir)
|
||||
print("Done saving checkpoint")
|
||||
|
||||
step(i)
|
||||
|
||||
print("Restoring checkpoint")
|
||||
sgd.restore_checkpoint(args.checkpoint_dir)
|
||||
print("Done restoring checkpoint")
|
||||
|
||||
step(i)
|
||||
@@ -6,6 +6,7 @@ import tensorflow as tf
|
||||
|
||||
from tfbench import model_config
|
||||
from ray.experimental.sgd.model import Model
|
||||
from ray.experimental.tfutils import TensorFlowVariables
|
||||
|
||||
|
||||
class MockDataset():
|
||||
@@ -46,6 +47,9 @@ class TFBenchModel(Model):
|
||||
self.loss = tf.reduce_mean(loss, name='xentropy-loss')
|
||||
self.optimizer = tf.train.GradientDescentOptimizer(1e-6)
|
||||
|
||||
self.variables = TensorFlowVariables(self.loss,
|
||||
tf.get_default_session())
|
||||
|
||||
def get_loss(self):
|
||||
return self.loss
|
||||
|
||||
@@ -54,3 +58,9 @@ class TFBenchModel(Model):
|
||||
|
||||
def get_feed_dict(self):
|
||||
return {}
|
||||
|
||||
def get_weights(self):
|
||||
return self.variables.get_flat()
|
||||
|
||||
def set_weights(self, weights):
|
||||
self.variables.set_flat(weights)
|
||||
|
||||
Reference in New Issue
Block a user