[tune] [rllib] Allow checkpointing to object store instead of local disk (#1212)

* wip

* use normal pickle

* fix checkpoint test

* comment

* Comment

* fix test

* fix lint

* fix py 3.5

* Update agent.py

* fix lint
This commit is contained in:
Eric Liang
2017-11-19 00:36:43 -08:00
committed by GitHub
parent d986294c2b
commit ae4e1dd396
3 changed files with 100 additions and 12 deletions
+51
View File
@@ -6,8 +6,11 @@ from datetime import datetime
import logging
import numpy as np
import io
import os
import gzip
import pickle
import shutil
import tempfile
import time
import uuid
@@ -147,6 +150,35 @@ class Agent(object):
open(checkpoint_path + ".rllib_metadata", "wb"))
return checkpoint_path
def save_to_object(self):
    """Saves the current model state to a Python object.

    The checkpoint is first written to disk through ``self.save()``. Every
    file in the checkpoint directory whose path starts with the returned
    checkpoint prefix is then read back and bundled into a single
    gzip-compressed pickle. The files stay on disk, but the checkpoint
    path is not returned.

    Returns:
        bytes: Gzip-compressed pickle of a dict with the keys
            ``"checkpoint_name"`` (base name of the checkpoint prefix)
            and ``"data"`` (mapping of file base name to file contents).
    """
    checkpoint_prefix = self.save()
    base_dir = os.path.dirname(checkpoint_prefix)

    data = {}
    for entry in os.listdir(base_dir):
        path = os.path.join(base_dir, entry)
        # A checkpoint may consist of several files sharing one prefix
        # (e.g. the ".rllib_metadata" companion file).
        if path.startswith(checkpoint_prefix):
            # Close each handle promptly instead of leaking it to the GC.
            with open(path, "rb") as f:
                data[os.path.basename(path)] = f.read()

    payload = pickle.dumps({
        "checkpoint_name": os.path.basename(checkpoint_prefix),
        "data": data,
    })
    # The reported size is that of the pickle before gzip compression.
    print("Saving checkpoint to object store, {} bytes".format(
        len(payload)))

    out = io.BytesIO()
    with gzip.GzipFile(fileobj=out, mode="wb") as f:
        f.write(payload)
    return out.getvalue()
def restore(self, checkpoint_path):
"""Restores training state from a given model checkpoint.
@@ -160,6 +192,25 @@ class Agent(object):
self._timesteps_total = metadata[2]
self._time_total = metadata[3]
def restore_from_object(self, obj):
    """Restores training state from a checkpoint object.

    These checkpoints are returned from calls to save_to_object().

    The checkpoint files are unpacked into a temporary directory created
    under ``self.logdir`` and handed to ``self.restore()``. The directory
    is removed afterwards, even when restoring raises.

    Args:
        obj (bytes): Gzip-compressed pickle as produced by
            save_to_object().
    """
    # NOTE: pickle.loads can execute arbitrary code from its input; only
    # pass objects that originate from a trusted save_to_object() call.
    with gzip.GzipFile(fileobj=io.BytesIO(obj), mode="rb") as f:
        info = pickle.loads(f.read())
    data = info["data"]

    tmpdir = tempfile.mkdtemp("restore_from_object", dir=self.logdir)
    try:
        checkpoint_path = os.path.join(tmpdir, info["checkpoint_name"])
        for file_name, file_contents in data.items():
            with open(os.path.join(tmpdir, file_name), "wb") as f:
                f.write(file_contents)
        self.restore(checkpoint_path)
    finally:
        # Always clean up so a failed restore does not leave checkpoint
        # copies behind in the log directory.
        shutil.rmtree(tmpdir)
def stop(self):
"""Releases all resources used by this agent."""
@@ -26,8 +26,9 @@ CONFIGS = {
"A3C": {"use_lstm": False},
}
for name in ["ES", "DQN", "PPO", "A3C"]:
cls = get_agent_class(name)
def test(use_object_store, alg_name):
cls = get_agent_class(alg_name)
alg1 = cls("CartPole-v0", CONFIGS[name])
alg2 = cls("CartPole-v0", CONFIGS[name])
@@ -36,11 +37,23 @@ for name in ["ES", "DQN", "PPO", "A3C"]:
print("current status: " + str(res))
# Sync the models
alg2.restore(alg1.save())
if use_object_store:
alg2.restore_from_object(alg1.save_to_object())
else:
alg2.restore(alg1.save())
for _ in range(10):
obs = np.random.uniform(size=4)
a1 = get_mean_action(alg1, obs)
a2 = get_mean_action(alg2, obs)
print("Checking computed actions", alg1, obs, a1, a2)
assert abs(a1-a2) < .1, (a1, a2)
assert abs(a1 - a2) < .1, (a1, a2)
if __name__ == "__main__":
    # https://github.com/ray-project/ray/issues/1062 for enabling ES test too
    # Exercise every algorithm through both sync paths: plain on-disk
    # checkpoints first, then checkpoints passed as in-memory objects.
    # NOTE: the loop variable must stay named `name`; test() appears to
    # read it as a module-level global (CONFIGS[name]).
    object_store_modes = [False, True]
    algorithm_names = ["ES", "DQN", "PPO", "A3C"]
    for use_object_store in object_store_modes:
        for name in algorithm_names:
            test(use_object_store, name)
    print("All checkpoint restore tests passed!")
+32 -8
View File
@@ -91,6 +91,7 @@ class Trial(object):
# Local trial state that is updated during the run
self.last_result = None
self._checkpoint_path = restore_path
self._checkpoint_obj = None
self.agent = None
self.status = Trial.PENDING
self.location = None
@@ -106,7 +107,9 @@ class Trial(object):
self._setup_agent()
if self._checkpoint_path:
self.restore_from_path(path=self._checkpoint_path)
self.restore_from_path(self._checkpoint_path)
elif self._checkpoint_obj:
self.restore_from_obj(self._checkpoint_obj)
def stop(self, error=False, stop_logger=True):
"""Stops this trial.
@@ -152,7 +155,7 @@ class Trial(object):
assert self.status == Trial.RUNNING, self.status
try:
self.checkpoint()
self.checkpoint(to_object_store=True)
self.stop(stop_logger=False)
self.status = Trial.PAUSED
except Exception:
@@ -226,16 +229,25 @@ class Trial(object):
return ', '.join(pieces)
def checkpoint(self):
"""Synchronously checkpoints the state of this trial.
def checkpoint(self, to_object_store=False):
"""Checkpoints the state of this trial.
TODO(ekl): we should support a PAUSED state based on checkpointing.
Args:
to_object_store (bool): Whether to save to the Ray object store
(async) vs a path on local disk (sync).
"""
path = ray.get(self.agent.save.remote())
obj = None
path = None
if to_object_store:
obj = self.agent.save_to_object.remote()
else:
path = ray.get(self.agent.save.remote())
self._checkpoint_path = path
print("Saved checkpoint to:", path)
return path
self._checkpoint_obj = obj
print("Saved checkpoint to:", path or obj)
return path or obj
def restore_from_path(self, path):
"""Restores agent state from specified path.
@@ -253,6 +265,18 @@ class Trial(object):
print("Error restoring agent:", traceback.format_exc())
self.status = Trial.ERROR
def restore_from_obj(self, obj):
    """Restores agent state from the specified object.

    When no agent is set, a message is printed and nothing else happens.
    Any exception raised while restoring is printed with its traceback
    and the trial status is set to Trial.ERROR instead of propagating.

    Args:
        obj: Checkpoint object forwarded to the agent's
            restore_from_object remote method.
    """
    if self.agent is None:
        print("Unable to restore - no agent")
        return

    try:
        pending = self.agent.restore_from_object.remote(obj)
        # Block until the agent has finished restoring.
        ray.get(pending)
    except Exception:
        print("Error restoring agent:", traceback.format_exc())
        self.status = Trial.ERROR
def _setup_agent(self):
self.status = Trial.RUNNING
agent_cls = get_agent_class(self.alg)