diff --git a/python/ray/rllib/agent.py b/python/ray/rllib/agent.py
index fd2fac08f..19b3529d1 100644
--- a/python/ray/rllib/agent.py
+++ b/python/ray/rllib/agent.py
@@ -6,8 +6,11 @@ from datetime import datetime
 
 import logging
 import numpy as np
+import io
 import os
+import gzip
 import pickle
+import shutil
 import tempfile
 import time
 import uuid
@@ -147,6 +150,39 @@ class Agent(object):
             open(checkpoint_path + ".rllib_metadata", "wb"))
         return checkpoint_path
 
+    def save_to_object(self):
+        """Saves the current model state to a Python object. It also
+        saves to disk but does not return the checkpoint path.
+
+        Returns:
+            Object holding checkpoint data (gzipped pickle bytes).
+        """
+
+        checkpoint_prefix = self.save()
+
+        data = {}
+        base_dir = os.path.dirname(checkpoint_prefix)
+        for path in os.listdir(base_dir):
+            path = os.path.join(base_dir, path)
+            if path.startswith(checkpoint_prefix):
+                # Close each file promptly instead of leaking the handle.
+                with open(path, "rb") as f:
+                    data[os.path.basename(path)] = f.read()
+
+        # Pickle first; gzip then compresses the bytes as they are written.
+        payload = pickle.dumps({
+            "checkpoint_name": os.path.basename(checkpoint_prefix),
+            "data": data,
+        })
+        print("Saving checkpoint to object store, {} bytes".format(
+            len(payload)))
+
+        out = io.BytesIO()
+        with gzip.GzipFile(fileobj=out, mode="wb") as f:
+            f.write(payload)
+
+        return out.getvalue()
+
     def restore(self, checkpoint_path):
         """Restores training state from a given model checkpoint.
 
@@ -160,6 +196,28 @@ class Agent(object):
         self._timesteps_total = metadata[2]
         self._time_total = metadata[3]
 
+    def restore_from_object(self, obj):
+        """Restores training state from a checkpoint object.
+
+        These checkpoints are returned from calls to save_to_object().
+        """
+
+        with gzip.GzipFile(fileobj=io.BytesIO(obj), mode="rb") as f:
+            info = pickle.loads(f.read())
+        data = info["data"]
+        tmpdir = tempfile.mkdtemp("restore_from_object", dir=self.logdir)
+        try:
+            checkpoint_path = os.path.join(tmpdir, info["checkpoint_name"])
+
+            for file_name, file_contents in data.items():
+                with open(os.path.join(tmpdir, file_name), "wb") as f:
+                    f.write(file_contents)
+
+            self.restore(checkpoint_path)
+        finally:
+            # Remove the scratch directory even if the restore fails.
+            shutil.rmtree(tmpdir)
+
     def stop(self):
         """Releases all resources used by this agent."""
 
diff --git a/python/ray/rllib/test/test_checkpoint_restore.py b/python/ray/rllib/test/test_checkpoint_restore.py
index e8b262d3e..a9a6b9197 100755
--- a/python/ray/rllib/test/test_checkpoint_restore.py
+++ b/python/ray/rllib/test/test_checkpoint_restore.py
@@ -26,8 +26,9 @@ CONFIGS = {
     "A3C": {"use_lstm": False},
 }
 
-for name in ["ES", "DQN", "PPO", "A3C"]:
-    cls = get_agent_class(name)
-    alg1 = cls("CartPole-v0", CONFIGS[name])
-    alg2 = cls("CartPole-v0", CONFIGS[name])
+
+def test(use_object_store, alg_name):
+    cls = get_agent_class(alg_name)
+    alg1 = cls("CartPole-v0", CONFIGS[alg_name])
+    alg2 = cls("CartPole-v0", CONFIGS[alg_name])
 
@@ -36,11 +37,23 @@ for name in ["ES", "DQN", "PPO", "A3C"]:
         print("current status: " + str(res))
 
     # Sync the models
-    alg2.restore(alg1.save())
+    if use_object_store:
+        alg2.restore_from_object(alg1.save_to_object())
+    else:
+        alg2.restore(alg1.save())
 
     for _ in range(10):
         obs = np.random.uniform(size=4)
         a1 = get_mean_action(alg1, obs)
         a2 = get_mean_action(alg2, obs)
         print("Checking computed actions", alg1, obs, a1, a2)
-        assert abs(a1-a2) < .1, (a1, a2)
+        assert abs(a1 - a2) < .1, (a1, a2)
+
+
+if __name__ == "__main__":
+    # https://github.com/ray-project/ray/issues/1062 for enabling ES test too
+    for use_object_store in [False, True]:
+        for name in ["ES", "DQN", "PPO", "A3C"]:
+            test(use_object_store, name)
+
+    print("All checkpoint restore tests passed!")
diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py
index a77f4a0f6..c40981636 100644
--- a/python/ray/tune/trial.py
+++ b/python/ray/tune/trial.py
@@ -91,6 +91,7 @@ class Trial(object):
         # Local trial state that is updated during the run
         self.last_result = None
         self._checkpoint_path = restore_path
+        self._checkpoint_obj = None
         self.agent = None
         self.status = Trial.PENDING
         self.location = None
@@ -106,7 +107,9 @@ class Trial(object):
 
         self._setup_agent()
         if self._checkpoint_path:
-            self.restore_from_path(path=self._checkpoint_path)
+            self.restore_from_path(self._checkpoint_path)
+        elif self._checkpoint_obj:
+            self.restore_from_obj(self._checkpoint_obj)
 
     def stop(self, error=False, stop_logger=True):
         """Stops this trial.
@@ -152,7 +155,7 @@ class Trial(object):
 
         assert self.status == Trial.RUNNING, self.status
         try:
-            self.checkpoint()
+            self.checkpoint(to_object_store=True)
             self.stop(stop_logger=False)
             self.status = Trial.PAUSED
         except Exception:
@@ -226,16 +229,25 @@ class Trial(object):
 
         return ', '.join(pieces)
 
-    def checkpoint(self):
-        """Synchronously checkpoints the state of this trial.
+    def checkpoint(self, to_object_store=False):
+        """Checkpoints the state of this trial.
 
-        TODO(ekl): we should support a PAUSED state based on checkpointing.
+        Args:
+            to_object_store (bool): Whether to save to the Ray object store
+                (async) vs a path on local disk (sync).
         """
 
-        path = ray.get(self.agent.save.remote())
+        obj = None
+        path = None
+        if to_object_store:
+            obj = self.agent.save_to_object.remote()
+        else:
+            path = ray.get(self.agent.save.remote())
         self._checkpoint_path = path
-        print("Saved checkpoint to:", path)
-        return path
+        self._checkpoint_obj = obj
+
+        print("Saved checkpoint to:", path or obj)
+        return path or obj
 
     def restore_from_path(self, path):
         """Restores agent state from specified path.
@@ -253,6 +265,18 @@ class Trial(object):
                 print("Error restoring agent:", traceback.format_exc())
                 self.status = Trial.ERROR
 
+    def restore_from_obj(self, obj):
+        """Restores agent state from the specified object."""
+
+        if self.agent is None:
+            print("Unable to restore - no agent")
+        else:
+            try:
+                ray.get(self.agent.restore_from_object.remote(obj))
+            except Exception:
+                print("Error restoring agent:", traceback.format_exc())
+                self.status = Trial.ERROR
+
     def _setup_agent(self):
         self.status = Trial.RUNNING
         agent_cls = get_agent_class(self.alg)