[rllib] Refactor save() / restore() code of agents and avoid O(n_workers) save size (#2982)

This commit is contained in:
Eric Liang
2018-09-30 01:15:13 -07:00
committed by GitHub
parent 747253e0f6
commit 65dcafdc3f
12 changed files with 184 additions and 255 deletions
-28
View File
@@ -2,11 +2,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pickle
import os
import time
import ray
from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph
from ray.rllib.agents.agent import Agent, with_common_config
from ray.rllib.optimizers import AsyncGradientsOptimizer
@@ -108,28 +105,3 @@ class A3CAgent(Agent):
result.update(timesteps_this_iter=self.optimizer.num_steps_sampled -
prev_steps)
return result
def _stop(self):
# workaround for https://github.com/ray-project/ray/issues/1516
for ev in self.remote_evaluators:
ev.__ray_terminate__.remote()
def _save(self, checkpoint_dir):
checkpoint_path = os.path.join(checkpoint_dir,
"checkpoint-{}".format(self.iteration))
agent_state = ray.get(
[a.save.remote() for a in self.remote_evaluators])
extra_data = {
"remote_state": agent_state,
"local_state": self.local_evaluator.save()
}
pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
return checkpoint_path
def _restore(self, checkpoint_path):
extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
ray.get([
a.restore.remote(o)
for a, o in zip(self.remote_evaluators, extra_data["remote_state"])
])
self.local_evaluator.restore(extra_data["local_state"])
+32 -85
View File
@@ -4,13 +4,13 @@ from __future__ import print_function
import copy
import json
import numpy as np
import os
import pickle
import tempfile
from datetime import datetime
import tensorflow as tf
import ray
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
@@ -324,95 +324,39 @@ class Agent(Trainable):
"""
self.local_evaluator.set_weights(weights)
# Ask every remote evaluator held by this agent to terminate.
# The hasattr guard lets agents that never set `remote_evaluators` stop
# without raising AttributeError.
def _stop(self):
# workaround for https://github.com/ray-project/ray/issues/1516
if hasattr(self, "remote_evaluators"):
for ev in self.remote_evaluators:
# Launched without waiting on the result (no ray.get here).
ev.__ray_terminate__.remote()
class _MockAgent(Agent):
"""Mock agent for use in tests"""
# Build the agent state dict used for checkpointing.
# - "evaluator": result of local_evaluator.save(), if the agent has one.
# - "optimizer": result of optimizer.save(), if the optimizer exposes save().
# Only the local evaluator is saved here; remote evaluators are not queried.
def __getstate__(self):
state = {}
if hasattr(self, "local_evaluator"):
state["evaluator"] = self.local_evaluator.save()
if hasattr(self, "optimizer") and hasattr(self.optimizer, "save"):
state["optimizer"] = self.optimizer.save()
return state
_agent_name = "MockAgent"
_default_config = {
"mock_error": False,
"persistent_error": False,
"test_variable": 1
}
def _init(self):
self.info = None
self.restored = False
def _train(self):
if self.config["mock_error"] and self.iteration == 1 \
and (self.config["persistent_error"] or not self.restored):
raise Exception("mock error")
return dict(
episode_reward_mean=10,
episode_len_mean=10,
timesteps_this_iter=10,
info={})
# Restore agent state from a dict with the optional keys "evaluator" and
# "optimizer"; keys that are absent are skipped.
# NOTE(review): __getstate__ guards its attribute access with hasattr, but
# this method reads self.remote_evaluators / self.optimizer unguarded --
# confirm every agent that saves an "evaluator" entry also defines
# remote_evaluators.
def __setstate__(self, state):
if "evaluator" in state:
self.local_evaluator.restore(state["evaluator"])
# Put the evaluator state into the object store once and hand the same
# reference to each remote evaluator.
remote_state = ray.put(state["evaluator"])
for r in self.remote_evaluators:
# Restore calls are launched without waiting on their results.
r.restore.remote(remote_state)
if "optimizer" in state:
self.optimizer.restore(state["optimizer"])
def _save(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "mock_agent.pkl")
with open(path, 'wb') as f:
pickle.dump(self.info, f)
return path
checkpoint_path = os.path.join(checkpoint_dir,
"checkpoint-{}".format(self.iteration))
pickle.dump(self.__getstate__(),
open(checkpoint_path + ".agent_state", "wb"))
return checkpoint_path
def _restore(self, checkpoint_path):
with open(checkpoint_path, 'rb') as f:
info = pickle.load(f)
self.info = info
self.restored = True
def set_info(self, info):
self.info = info
return info
def get_info(self):
return self.info
class _SigmoidFakeData(_MockAgent):
"""Agent that returns sigmoid learning curves.
This can be helpful for evaluating early stopping algorithms."""
_agent_name = "SigmoidFakeData"
_default_config = {
"width": 100,
"height": 100,
"offset": 0,
"iter_time": 10,
"iter_timesteps": 1,
}
def _train(self):
i = max(0, self.iteration - self.config["offset"])
v = np.tanh(float(i) / self.config["width"])
v *= self.config["height"]
return dict(
episode_reward_mean=v,
episode_len_mean=v,
timesteps_this_iter=self.config["iter_timesteps"],
time_this_iter_s=self.config["iter_time"],
info={})
class _ParameterTuningAgent(_MockAgent):
_agent_name = "ParameterTuningAgent"
_default_config = {
"reward_amt": 10,
"dummy_param": 10,
"dummy_param2": 15,
"iter_time": 10,
"iter_timesteps": 1
}
def _train(self):
return dict(
episode_reward_mean=self.config["reward_amt"] * self.iteration,
episode_len_mean=self.config["reward_amt"],
timesteps_this_iter=self.config["iter_timesteps"],
time_this_iter_s=self.config["iter_time"],
info={})
extra_data = pickle.load(open(checkpoint_path + ".agent_state", "rb"))
self.__setstate__(extra_data)
def get_agent_class(alg):
@@ -458,10 +402,13 @@ def get_agent_class(alg):
from ray.tune import script_runner
return script_runner.ScriptRunner
elif alg == "__fake":
from ray.rllib.agents.mock import _MockAgent
return _MockAgent
elif alg == "__sigmoid_fake_data":
from ray.rllib.agents.mock import _SigmoidFakeData
return _SigmoidFakeData
elif alg == "__parameter_tuning":
from ray.rllib.agents.mock import _ParameterTuningAgent
return _ParameterTuningAgent
else:
raise Exception(("Unknown algorithm {}.").format(alg))
+8 -22
View File
@@ -8,8 +8,6 @@ from __future__ import print_function
from collections import namedtuple
import numpy as np
import os
import pickle
import time
import ray
@@ -201,7 +199,6 @@ class ARSAgent(Agent):
]
self.episodes_so_far = 0
self.timesteps_so_far = 0
self.reward_list = []
self.tstart = time.time()
@@ -228,7 +225,6 @@ class ARSAgent(Agent):
def _train(self):
config = self.config
step_tstart = time.time()
theta = self.policy.get_weights()
assert theta.dtype == np.float32
@@ -259,7 +255,6 @@ class ARSAgent(Agent):
len(all_training_lengths))
self.episodes_so_far += num_episodes
self.timesteps_so_far += num_timesteps
# Assemble the results.
eval_returns = np.array(all_eval_returns)
@@ -301,8 +296,6 @@ class ARSAgent(Agent):
if len(all_eval_returns) > 0:
self.reward_list.append(eval_returns.mean())
step_tend = time.time()
tlogger.record_tabular("NoisyEpRewMean", noisy_returns.mean())
tlogger.record_tabular("NoisyEpRewStd", noisy_returns.std())
tlogger.record_tabular("NoisyEpLenMean", noisy_lengths.mean())
@@ -319,9 +312,6 @@ class ARSAgent(Agent):
"update_ratio": update_ratio,
"episodes_this_iter": noisy_lengths.size,
"episodes_so_far": self.episodes_so_far,
"timesteps_so_far": self.timesteps_so_far,
"time_elapsed_this_iter": step_tend - step_tstart,
"time_elapsed": step_tend - self.tstart
}
result = dict(
episode_reward_mean=np.mean(
@@ -337,19 +327,15 @@ class ARSAgent(Agent):
for w in self.workers:
w.__ray_terminate__.remote()
def _save(self, checkpoint_dir):
checkpoint_path = os.path.join(checkpoint_dir,
"checkpoint-{}".format(self.iteration))
weights = self.policy.get_weights()
objects = [weights, self.episodes_so_far, self.timesteps_so_far]
pickle.dump(objects, open(checkpoint_path, "wb"))
return checkpoint_path
# Return the state to checkpoint for this agent: the current policy
# weights and the running episode counter.
def __getstate__(self):
return {
"weights": self.policy.get_weights(),
"episodes_so_far": self.episodes_so_far,
}
def _restore(self, checkpoint_path):
objects = pickle.load(open(checkpoint_path, "rb"))
self.policy.set_weights(objects[0])
self.episodes_so_far = objects[1]
self.timesteps_so_far = objects[2]
# Load the policy weights and episode counter from `state`.
# Both the "weights" and "episodes_so_far" entries are read unconditionally.
def __setstate__(self, state):
self.policy.set_weights(state["weights"])
self.episodes_so_far = state["episodes_so_far"]
# Compute an action for one observation: calls policy.compute with
# update=True and returns element 0 of its result.
# NOTE(review): update=True presumably lets the policy update internal
# statistics (e.g. an observation filter) -- confirm against the policy class.
def compute_action(self, observation):
return self.policy.compute(observation, update=True)[0]
+11 -29
View File
@@ -2,11 +2,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pickle
import os
import time
import ray
from ray.rllib import optimizers
from ray.rllib.agents.agent import Agent, with_common_config
from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
@@ -249,30 +246,15 @@ class DQNAgent(Agent):
}, **self.optimizer.stats()))
return result
def _stop(self):
# workaround for https://github.com/ray-project/ray/issues/1516
for ev in self.remote_evaluators:
ev.__ray_terminate__.remote()
# Extend the base Agent state with the two target-update counters kept on
# this agent, so they are saved with the rest of the state.
def __getstate__(self):
state = Agent.__getstate__(self)
state.update({
"num_target_updates": self.num_target_updates,
"last_target_update_ts": self.last_target_update_ts,
})
return state
def _save(self, checkpoint_dir):
checkpoint_path = os.path.join(checkpoint_dir,
"checkpoint-{}".format(self.iteration))
extra_data = [
self.local_evaluator.save(),
ray.get([e.save.remote() for e in self.remote_evaluators]),
self.optimizer.save(), self.num_target_updates,
self.last_target_update_ts
]
pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
return checkpoint_path
def _restore(self, checkpoint_path):
extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
self.local_evaluator.restore(extra_data[0])
ray.get([
e.restore.remote(d)
for (d, e) in zip(extra_data[1], self.remote_evaluators)
])
self.optimizer.restore(extra_data[2])
self.num_target_updates = extra_data[3]
self.last_target_update_ts = extra_data[4]
# Restore the base Agent state first, then the two target-update counters.
# Both counter entries are read from `state` unconditionally.
def __setstate__(self, state):
Agent.__setstate__(self, state)
self.num_target_updates = state["num_target_updates"]
self.last_target_update_ts = state["last_target_update_ts"]
+8 -27
View File
@@ -7,8 +7,6 @@ from __future__ import print_function
from collections import namedtuple
import numpy as np
import os
import pickle
import time
import ray
@@ -180,7 +178,6 @@ class ESAgent(Agent):
]
self.episodes_so_far = 0
self.timesteps_so_far = 0
self.reward_list = []
self.tstart = time.time()
@@ -207,7 +204,6 @@ class ESAgent(Agent):
def _train(self):
config = self.config
step_tstart = time.time()
theta = self.policy.get_weights()
assert theta.dtype == np.float32
@@ -238,7 +234,6 @@ class ESAgent(Agent):
len(all_training_lengths))
self.episodes_so_far += num_episodes
self.timesteps_so_far += num_timesteps
# Assemble the results.
eval_returns = np.array(all_eval_returns)
@@ -271,7 +266,6 @@ class ESAgent(Agent):
if len(all_eval_returns) > 0:
self.reward_list.append(np.mean(eval_returns))
step_tend = time.time()
tlogger.record_tabular("EvalEpRewStd", eval_returns.std())
tlogger.record_tabular("EvalEpLenMean", eval_lengths.mean())
@@ -285,11 +279,6 @@ class ESAgent(Agent):
tlogger.record_tabular("EpisodesThisIter", noisy_lengths.size)
tlogger.record_tabular("EpisodesSoFar", self.episodes_so_far)
tlogger.record_tabular("TimestepsThisIter", noisy_lengths.sum())
tlogger.record_tabular("TimestepsSoFar", self.timesteps_so_far)
tlogger.record_tabular("TimeElapsedThisIter", step_tend - step_tstart)
tlogger.record_tabular("TimeElapsed", step_tend - self.tstart)
tlogger.dump_tabular()
info = {
@@ -298,10 +287,6 @@ class ESAgent(Agent):
"update_ratio": update_ratio,
"episodes_this_iter": noisy_lengths.size,
"episodes_so_far": self.episodes_so_far,
"timesteps_this_iter": noisy_lengths.sum(),
"timesteps_so_far": self.timesteps_so_far,
"time_elapsed_this_iter": step_tend - step_tstart,
"time_elapsed": step_tend - self.tstart
}
reward_mean = np.mean(self.reward_list[-self.report_length:])
@@ -318,19 +303,15 @@ class ESAgent(Agent):
for w in self.workers:
w.__ray_terminate__.remote()
def _save(self, checkpoint_dir):
checkpoint_path = os.path.join(checkpoint_dir,
"checkpoint-{}".format(self.iteration))
weights = self.policy.get_weights()
objects = [weights, self.episodes_so_far, self.timesteps_so_far]
pickle.dump(objects, open(checkpoint_path, "wb"))
return checkpoint_path
# Return the state to checkpoint for this agent: the current policy
# weights and the running episode counter.
def __getstate__(self):
return {
"weights": self.policy.get_weights(),
"episodes_so_far": self.episodes_so_far,
}
def _restore(self, checkpoint_path):
objects = pickle.load(open(checkpoint_path, "rb"))
self.policy.set_weights(objects[0])
self.episodes_so_far = objects[1]
self.timesteps_so_far = objects[2]
# Load the policy weights and episode counter from `state`.
# Both the "weights" and "episodes_so_far" entries are read unconditionally.
def __setstate__(self, state):
self.policy.set_weights(state["weights"])
self.episodes_so_far = state["episodes_so_far"]
# Compute an action for one observation: calls policy.compute with
# update=False and returns element 0 of its result.
# NOTE(review): update=False presumably leaves the policy's internal
# statistics untouched during inference -- confirm against the policy class.
def compute_action(self, observation):
return self.policy.compute(observation, update=False)[0]
-28
View File
@@ -2,11 +2,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pickle
import os
import time
import ray
from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph
from ray.rllib.agents.impala.vtrace_policy_graph import VTracePolicyGraph
from ray.rllib.agents.agent import Agent, with_common_config
@@ -99,28 +96,3 @@ class ImpalaAgent(Agent):
result.update(timesteps_this_iter=self.optimizer.num_steps_sampled -
prev_steps)
return result
def _stop(self):
# workaround for https://github.com/ray-project/ray/issues/1516
for ev in self.remote_evaluators:
ev.__ray_terminate__.remote()
def _save(self, checkpoint_dir):
checkpoint_path = os.path.join(checkpoint_dir,
"checkpoint-{}".format(self.iteration))
agent_state = ray.get(
[a.save.remote() for a in self.remote_evaluators])
extra_data = {
"remote_state": agent_state,
"local_state": self.local_evaluator.save()
}
pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
return checkpoint_path
def _restore(self, checkpoint_path):
extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
ray.get([
a.restore.remote(o)
for a, o in zip(self.remote_evaluators, extra_data["remote_state"])
])
self.local_evaluator.restore(extra_data["local_state"])
+99
View File
@@ -0,0 +1,99 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import pickle
import numpy as np
from ray.rllib.agents.agent import Agent
class _MockAgent(Agent):
    """Minimal Agent stand-in for tests.

    Every training step reports the same constant metrics. Setting
    ``mock_error`` in the config makes iteration 1 raise; with
    ``persistent_error`` the failure also fires after a restore. The only
    checkpointed state is the opaque ``info`` value.
    """

    _agent_name = "MockAgent"
    _default_config = dict(
        mock_error=False,
        persistent_error=False,
        test_variable=1,
    )

    def _init(self):
        """Start with no stored info and the restored flag cleared."""
        self.restored = False
        self.info = None

    def _train(self):
        """Return fixed metrics, or raise the configured mock failure."""
        cfg = self.config
        if cfg["mock_error"] and self.iteration == 1:
            # A persistent error always fires; a one-off error fires only
            # while the agent has not been restored from a checkpoint.
            if cfg["persistent_error"] or not self.restored:
                raise Exception("mock error")
        return {
            "episode_reward_mean": 10,
            "episode_len_mean": 10,
            "timesteps_this_iter": 10,
            "info": {},
        }

    def _save(self, checkpoint_dir):
        """Pickle ``self.info`` into ``checkpoint_dir`` and return the path."""
        target = os.path.join(checkpoint_dir, "mock_agent.pkl")
        with open(target, "wb") as out:
            pickle.dump(self.info, out)
        return target

    def _restore(self, checkpoint_path):
        """Load ``info`` from a file written by ``_save`` and mark restored."""
        with open(checkpoint_path, "rb") as src:
            loaded = pickle.load(src)
        self.info = loaded
        self.restored = True

    def set_info(self, info):
        """Store ``info`` on the agent and return it unchanged."""
        self.info = info
        return info

    def get_info(self):
        """Return the currently stored ``info`` value."""
        return self.info
class _SigmoidFakeData(_MockAgent):
    """Mock agent whose reported metrics follow a saturating curve.

    Useful for evaluating early-stopping algorithms. The reported value is
    ``height * tanh(steps / width)``, where ``steps`` is the iteration count
    minus ``offset``, floored at zero.
    """

    _agent_name = "SigmoidFakeData"
    _default_config = dict(
        width=100,
        height=100,
        offset=0,
        iter_time=10,
        iter_timesteps=1,
    )

    def _train(self):
        """Report the curve value for the current iteration."""
        steps = max(0, self.iteration - self.config["offset"])
        level = np.tanh(float(steps) / self.config["width"])
        level = level * self.config["height"]
        return {
            "episode_reward_mean": level,
            "episode_len_mean": level,
            "timesteps_this_iter": self.config["iter_timesteps"],
            "time_this_iter_s": self.config["iter_time"],
            "info": {},
        }
class _ParameterTuningAgent(_MockAgent):
    """Mock agent whose mean reward grows linearly with the iteration.

    The reported reward is ``reward_amt * iteration``; the remaining metrics
    are taken straight from the config.
    """

    _agent_name = "ParameterTuningAgent"
    _default_config = dict(
        reward_amt=10,
        dummy_param=10,
        dummy_param2=15,
        iter_time=10,
        iter_timesteps=1,
    )

    def _train(self):
        """Report metrics derived from the config and current iteration."""
        cfg = self.config
        return {
            "episode_reward_mean": cfg["reward_amt"] * self.iteration,
            "episode_len_mean": cfg["reward_amt"],
            "timesteps_this_iter": cfg["iter_timesteps"],
            "time_this_iter_s": cfg["iter_time"],
            "info": {},
        }
-26
View File
@@ -2,10 +2,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import pickle
import ray
from ray.rllib.agents import Agent, with_common_config
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
from ray.rllib.utils import merge_dicts
@@ -134,25 +130,3 @@ class PPOAgent(Agent):
timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps,
info=dict(fetches, **res.get("info", {})))
return res
def _stop(self):
# workaround for https://github.com/ray-project/ray/issues/1516
for ev in self.remote_evaluators:
ev.__ray_terminate__.remote()
def _save(self, checkpoint_dir):
checkpoint_path = os.path.join(checkpoint_dir,
"checkpoint-{}".format(self.iteration))
agent_state = ray.get(
[a.save.remote() for a in self.remote_evaluators])
extra_data = [self.local_evaluator.save(), agent_state]
pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
return checkpoint_path
def _restore(self, checkpoint_path):
extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
self.local_evaluator.restore(extra_data[0])
ray.get([
a.restore.remote(o)
for (a, o) in zip(self.remote_evaluators, extra_data[1])
])
+1 -1
View File
@@ -64,5 +64,5 @@ def summarize_episodes(episodes, new_episodes):
episode_reward_min=min_reward,
episode_reward_mean=avg_reward,
episode_len_mean=avg_length,
episodes=len(new_episodes),
episodes_this_iter=len(new_episodes),
policy_reward_mean=dict(policy_rewards))
@@ -168,7 +168,7 @@ class TestPolicyEvaluator(unittest.TestCase):
ev.sample()
ray.get(remote_ev.sample.remote())
result = collect_metrics(ev, [remote_ev])
self.assertEqual(result["episodes"], 20)
self.assertEqual(result["episodes_this_iter"], 20)
self.assertEqual(result["episode_reward_mean"], 10)
def testAsync(self):
@@ -204,12 +204,12 @@ class TestPolicyEvaluator(unittest.TestCase):
batch = ev.sample()
self.assertEqual(batch.count, 16)
result = collect_metrics(ev, [])
self.assertEqual(result["episodes"], 0)
self.assertEqual(result["episodes_this_iter"], 0)
for _ in range(8):
batch = ev.sample()
self.assertEqual(batch.count, 16)
result = collect_metrics(ev, [])
self.assertEqual(result["episodes"], 8)
self.assertEqual(result["episodes_this_iter"], 8)
indices = []
for env in ev.async_env.vector_env.envs:
self.assertEqual(env.unwrapped.config.worker_index, 0)
@@ -235,10 +235,10 @@ class TestPolicyEvaluator(unittest.TestCase):
batch = ev.sample()
self.assertEqual(batch.count, 16)
result = collect_metrics(ev, [])
self.assertEqual(result["episodes"], 0)
self.assertEqual(result["episodes_this_iter"], 0)
batch = ev.sample()
result = collect_metrics(ev, [])
self.assertEqual(result["episodes"], 4)
self.assertEqual(result["episodes_this_iter"], 4)
def testVectorEnvSupport(self):
ev = PolicyEvaluator(
@@ -250,12 +250,12 @@ class TestPolicyEvaluator(unittest.TestCase):
batch = ev.sample()
self.assertEqual(batch.count, 10)
result = collect_metrics(ev, [])
self.assertEqual(result["episodes"], 0)
self.assertEqual(result["episodes_this_iter"], 0)
for _ in range(8):
batch = ev.sample()
self.assertEqual(batch.count, 10)
result = collect_metrics(ev, [])
self.assertEqual(result["episodes"], 8)
self.assertEqual(result["episodes_this_iter"], 8)
def testTruncateEpisodes(self):
ev = PolicyEvaluator(
+6
View File
@@ -16,6 +16,12 @@ NODE_IP = "node_ip"
# (Auto-filled) The pid of the training process.
PID = "pid"
# Number of episodes in this iteration.
EPISODES_THIS_ITER = "episodes_this_iter"
# (Optional/Auto-filled) Accumulated number of episodes for this experiment.
EPISODES_TOTAL = "episodes_total"
# Number of timesteps in this iteration.
TIMESTEPS_THIS_ITER = "timesteps_this_iter"
+12 -2
View File
@@ -17,7 +17,8 @@ import uuid
import ray
from ray.tune.logger import UnifiedLogger
from ray.tune.result import (DEFAULT_RESULTS_DIR, TIME_THIS_ITER_S,
TIMESTEPS_THIS_ITER, DONE, TIMESTEPS_TOTAL)
TIMESTEPS_THIS_ITER, DONE, TIMESTEPS_TOTAL,
EPISODES_THIS_ITER, EPISODES_TOTAL)
from ray.tune.trial import Resources
logger = logging.getLogger(__name__)
@@ -77,6 +78,7 @@ class Trainable(object):
self._iteration = 0
self._time_total = 0.0
self._timesteps_total = None
self._episodes_total = None
self._time_since_restore = 0.0
self._timesteps_since_restore = 0
self._iterations_since_restore = 0
@@ -162,8 +164,15 @@ class Trainable(object):
self._timesteps_total += result[TIMESTEPS_THIS_ITER]
self._timesteps_since_restore += result[TIMESTEPS_THIS_ITER]
# self._episodes_total should only be tracked if increments provided
if result.get(EPISODES_THIS_ITER):
if self._episodes_total is None:
self._episodes_total = 0
self._episodes_total += result[EPISODES_THIS_ITER]
# Auto-tracked totals should not override user-provided totals
result.setdefault(TIMESTEPS_TOTAL, self._timesteps_total)
result.setdefault(EPISODES_TOTAL, self._episodes_total)
# Provides auto-filled neg_mean_loss for avoiding regressions
if result.get("mean_loss"):
@@ -205,7 +214,7 @@ class Trainable(object):
checkpoint_path = self._save(checkpoint_dir or self.logdir)
pickle.dump([
self._experiment_id, self._iteration, self._timesteps_total,
self._time_total
self._time_total, self._episodes_total
], open(checkpoint_path + ".tune_metadata", "wb"))
return checkpoint_path
@@ -256,6 +265,7 @@ class Trainable(object):
self._iteration = metadata[1]
self._timesteps_total = metadata[2]
self._time_total = metadata[3]
self._episodes_total = metadata[4]
self._restored = True
def restore_from_object(self, obj):