mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 11:21:15 +08:00
[tune] [rllib] Allow checkpointing to object store instead of local disk (#1212)
* wip * use normal pickle * fix checkpoint test * comment * Comment * fix test * fix lint * fix py 3.5 * Update agent.py * fix lint
This commit is contained in:
@@ -6,8 +6,11 @@ from datetime import datetime
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
import io
|
||||
import os
|
||||
import gzip
|
||||
import pickle
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
@@ -147,6 +150,35 @@ class Agent(object):
|
||||
open(checkpoint_path + ".rllib_metadata", "wb"))
|
||||
return checkpoint_path
|
||||
|
||||
def save_to_object(self):
    """Saves the current model state to a Python object. It also
    saves to disk but does not return the checkpoint path.

    The checkpoint is first written with ``self.save()``. Every entry in
    the checkpoint's directory whose full path starts with the returned
    prefix is then read into memory. The file contents and the checkpoint
    name are pickled and gzip-compressed.

    Returns:
        Object holding checkpoint data (gzip-compressed pickle bytes),
        suitable for passing to ``restore_from_object()``.
    """

    checkpoint_prefix = self.save()
    base_dir = os.path.dirname(checkpoint_prefix)

    data = {}
    for file_name in os.listdir(base_dir):
        path = os.path.join(base_dir, file_name)
        # NOTE(review): this is a plain string-prefix match, so a prefix
        # such as "checkpoint-1" would also pick up "checkpoint-10.*"
        # files if they live in the same directory - confirm whether
        # that can happen with the checkpoint naming scheme in use.
        if path.startswith(checkpoint_prefix):
            # Use a context manager so each file handle is closed right
            # away instead of waiting for garbage collection.
            with open(path, "rb") as f:
                data[file_name] = f.read()

    payload = pickle.dumps({
        "checkpoint_name": os.path.basename(checkpoint_prefix),
        "data": data,
    })
    # The reported size is that of the pickled payload before compression.
    print("Saving checkpoint to object store, {} bytes".format(
        len(payload)))

    out = io.BytesIO()
    with gzip.GzipFile(fileobj=out, mode="wb") as f:
        f.write(payload)

    # The GzipFile must be closed (end of the ``with``) before reading the
    # buffer, otherwise the gzip trailer would be missing.
    return out.getvalue()
|
||||
|
||||
def restore(self, checkpoint_path):
|
||||
"""Restores training state from a given model checkpoint.
|
||||
|
||||
@@ -160,6 +192,25 @@ class Agent(object):
|
||||
self._timesteps_total = metadata[2]
|
||||
self._time_total = metadata[3]
|
||||
|
||||
def restore_from_object(self, obj):
    """Restores training state from a checkpoint object.

    These checkpoints are returned from calls to save_to_object().

    The object is decompressed and unpickled, its files are written into
    a temporary directory created under ``self.logdir``, and
    ``self.restore()`` is called on the reconstructed checkpoint path.
    The temporary directory is removed afterwards, even if restoring
    fails.

    Args:
        obj (bytes): Gzip-compressed pickle produced by save_to_object().
    """

    # Close the gzip reader deterministically rather than leaving it to
    # garbage collection.
    with gzip.GzipFile(fileobj=io.BytesIO(obj), mode="rb") as f:
        raw = f.read()
    # SECURITY: pickle.loads can execute arbitrary code. Only pass objects
    # that were produced by save_to_object() from a trusted source.
    info = pickle.loads(raw)
    data = info["data"]

    tmpdir = tempfile.mkdtemp("restore_from_object", dir=self.logdir)
    try:
        checkpoint_path = os.path.join(tmpdir, info["checkpoint_name"])

        for file_name, file_contents in data.items():
            with open(os.path.join(tmpdir, file_name), "wb") as f:
                f.write(file_contents)

        self.restore(checkpoint_path)
    finally:
        # Always clean up, so a failed restore does not leave extracted
        # checkpoint files behind in the log directory.
        shutil.rmtree(tmpdir)
|
||||
|
||||
def stop(self):
|
||||
"""Releases all resources used by this agent."""
|
||||
|
||||
|
||||
@@ -26,8 +26,9 @@ CONFIGS = {
|
||||
"A3C": {"use_lstm": False},
|
||||
}
|
||||
|
||||
for name in ["ES", "DQN", "PPO", "A3C"]:
|
||||
cls = get_agent_class(name)
|
||||
|
||||
def test(use_object_store, alg_name):
|
||||
cls = get_agent_class(alg_name)
|
||||
alg1 = cls("CartPole-v0", CONFIGS[name])
|
||||
alg2 = cls("CartPole-v0", CONFIGS[name])
|
||||
|
||||
@@ -36,11 +37,23 @@ for name in ["ES", "DQN", "PPO", "A3C"]:
|
||||
print("current status: " + str(res))
|
||||
|
||||
# Sync the models
|
||||
alg2.restore(alg1.save())
|
||||
if use_object_store:
|
||||
alg2.restore_from_object(alg1.save_to_object())
|
||||
else:
|
||||
alg2.restore(alg1.save())
|
||||
|
||||
for _ in range(10):
|
||||
obs = np.random.uniform(size=4)
|
||||
a1 = get_mean_action(alg1, obs)
|
||||
a2 = get_mean_action(alg2, obs)
|
||||
print("Checking computed actions", alg1, obs, a1, a2)
|
||||
assert abs(a1-a2) < .1, (a1, a2)
|
||||
assert abs(a1 - a2) < .1, (a1, a2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # https://github.com/ray-project/ray/issues/1062 for enabling ES test too
    # Run every algorithm twice: once restoring from a checkpoint path on
    # disk and once restoring from an in-memory checkpoint object.
    # NOTE: the loop variable must stay named ``name`` - test() looks up
    # CONFIGS[name] through this module-level variable.
    for use_object_store in (False, True):
        for name in ("ES", "DQN", "PPO", "A3C"):
            test(use_object_store, name)

    print("All checkpoint restore tests passed!")
|
||||
|
||||
@@ -91,6 +91,7 @@ class Trial(object):
|
||||
# Local trial state that is updated during the run
|
||||
self.last_result = None
|
||||
self._checkpoint_path = restore_path
|
||||
self._checkpoint_obj = None
|
||||
self.agent = None
|
||||
self.status = Trial.PENDING
|
||||
self.location = None
|
||||
@@ -106,7 +107,9 @@ class Trial(object):
|
||||
|
||||
self._setup_agent()
|
||||
if self._checkpoint_path:
|
||||
self.restore_from_path(path=self._checkpoint_path)
|
||||
self.restore_from_path(self._checkpoint_path)
|
||||
elif self._checkpoint_obj:
|
||||
self.restore_from_obj(self._checkpoint_obj)
|
||||
|
||||
def stop(self, error=False, stop_logger=True):
|
||||
"""Stops this trial.
|
||||
@@ -152,7 +155,7 @@ class Trial(object):
|
||||
|
||||
assert self.status == Trial.RUNNING, self.status
|
||||
try:
|
||||
self.checkpoint()
|
||||
self.checkpoint(to_object_store=True)
|
||||
self.stop(stop_logger=False)
|
||||
self.status = Trial.PAUSED
|
||||
except Exception:
|
||||
@@ -226,16 +229,25 @@ class Trial(object):
|
||||
|
||||
return ', '.join(pieces)
|
||||
|
||||
def checkpoint(self):
|
||||
"""Synchronously checkpoints the state of this trial.
|
||||
def checkpoint(self, to_object_store=False):
|
||||
"""Checkpoints the state of this trial.
|
||||
|
||||
TODO(ekl): we should support a PAUSED state based on checkpointing.
|
||||
Args:
|
||||
to_object_store (bool): Whether to save to the Ray object store
|
||||
(async) vs a path on local disk (sync).
|
||||
"""
|
||||
|
||||
path = ray.get(self.agent.save.remote())
|
||||
obj = None
|
||||
path = None
|
||||
if to_object_store:
|
||||
obj = self.agent.save_to_object.remote()
|
||||
else:
|
||||
path = ray.get(self.agent.save.remote())
|
||||
self._checkpoint_path = path
|
||||
print("Saved checkpoint to:", path)
|
||||
return path
|
||||
self._checkpoint_obj = obj
|
||||
|
||||
print("Saved checkpoint to:", path or obj)
|
||||
return path or obj
|
||||
|
||||
def restore_from_path(self, path):
|
||||
"""Restores agent state from specified path.
|
||||
@@ -253,6 +265,18 @@ class Trial(object):
|
||||
print("Error restoring agent:", traceback.format_exc())
|
||||
self.status = Trial.ERROR
|
||||
|
||||
def restore_from_obj(self, obj):
    """Restores agent state from the specified object.

    Does nothing beyond printing a message when no agent is set. Any
    exception raised while restoring is printed with its traceback and
    the trial status is set to Trial.ERROR instead of propagating.

    Args:
        obj: Checkpoint object forwarded to the agent's
            restore_from_object remote method.
    """

    # Guard clause: nothing to restore into without an agent.
    if self.agent is None:
        print("Unable to restore - no agent")
        return

    try:
        # Block until the remote restore has finished.
        ray.get(
            self.agent.restore_from_object.remote(obj))
    except Exception:
        print("Error restoring agent:", traceback.format_exc())
        self.status = Trial.ERROR
|
||||
|
||||
def _setup_agent(self):
|
||||
self.status = Trial.RUNNING
|
||||
agent_cls = get_agent_class(self.alg)
|
||||
|
||||
Reference in New Issue
Block a user