From 25ffe57a5c83d73ff80b603f909890a7c32140e1 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 3 Sep 2018 20:01:53 -0700 Subject: [PATCH] [rllib] Auto-synchronize filters for all agents (#2791) This makes sure we always update the local filter, and adds an option to synchronize the remote filters as well. In APEX_DDPG we previously didn't do either. The first is needed for checkpoint correctness, the second might help performance. --- python/ray/rllib/agents/a3c/a3c.py | 4 +--- python/ray/rllib/agents/agent.py | 11 ++++++++++- python/ray/rllib/agents/impala/impala.py | 3 --- python/ray/rllib/agents/ppo/ppo.py | 4 +--- python/ray/rllib/evaluation/metrics.py | 13 +++++++++---- python/ray/rllib/optimizers/policy_optimizer.py | 2 +- python/ray/rllib/test/test_checkpoint_restore.py | 14 +++++++++++--- python/ray/rllib/test/test_policy_evaluator.py | 14 +++++++------- python/ray/rllib/utils/filter_manager.py | 10 ++++++---- 9 files changed, 46 insertions(+), 29 deletions(-) diff --git a/python/ray/rllib/agents/a3c/a3c.py b/python/ray/rllib/agents/a3c/a3c.py index 329d30d63..ddb298af7 100644 --- a/python/ray/rllib/agents/a3c/a3c.py +++ b/python/ray/rllib/agents/a3c/a3c.py @@ -10,7 +10,7 @@ import ray from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph from ray.rllib.agents.agent import Agent, with_common_config from ray.rllib.optimizers import AsyncGradientsOptimizer -from ray.rllib.utils import FilterManager, merge_dicts +from ray.rllib.utils import merge_dicts from ray.tune.trial import Resources DEFAULT_CONFIG = with_common_config({ @@ -104,8 +104,6 @@ class A3CAgent(Agent): start = time.time() while time.time() - start < self.config["min_iter_time_s"]: self.optimizer.step() - FilterManager.synchronize(self.local_evaluator.filters, - self.remote_evaluators) result = self.optimizer.collect_metrics() result.update(timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps) diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py index abd3379a3..dacda4818 100644 --- a/python/ray/rllib/agents/agent.py +++ b/python/ray/rllib/agents/agent.py @@ -11,7 +11,7 @@ import pickle import tensorflow as tf from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer -from ray.rllib.utils import deep_update, merge_dicts +from ray.rllib.utils import FilterManager, deep_update, merge_dicts from ray.tune.registry import ENV_CREATOR, _global_registry from ray.tune.trainable import Trainable @@ -32,6 +32,8 @@ COMMON_CONFIG = { "sample_async": False, # Which observation filter to apply to the observation "observation_filter": "NoFilter", + # Whether to synchronize the statistics of remote filters. + "synchronize_filters": True, # Whether to clip rewards prior to experience postprocessing "clip_rewards": True, # Whether to use rllib or deepmind preprocessors @@ -197,6 +199,13 @@ class Agent(Trainable): for ev in self.optimizer.remote_evaluators: ev.set_global_vars.remote(self.global_vars) + if (self.config.get("observation_filter", "NoFilter") != "NoFilter" + and hasattr(self, "local_evaluator")): + FilterManager.synchronize( + self.local_evaluator.filters, + self.remote_evaluators, + update_remote=self.config["synchronize_filters"]) + return Trainable.train(self) def _setup(self): diff --git a/python/ray/rllib/agents/impala/impala.py b/python/ray/rllib/agents/impala/impala.py index 8ad6d67a3..eb16eca60 100644 --- a/python/ray/rllib/agents/impala/impala.py +++ b/python/ray/rllib/agents/impala/impala.py @@ -11,7 +11,6 @@ from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph from ray.rllib.agents.impala.vtrace_policy_graph import VTracePolicyGraph from ray.rllib.agents.agent import Agent, with_common_config from ray.rllib.optimizers import AsyncSamplesOptimizer -from ray.rllib.utils import FilterManager from ray.tune.trial import Resources OPTIMIZER_SHARED_CONFIGS = [ @@ -96,8 +95,6 @@ class ImpalaAgent(Agent): self.optimizer.step() while time.time() - start < self.config["min_iter_time_s"]: self.optimizer.step() - FilterManager.synchronize(self.local_evaluator.filters, - self.remote_evaluators) result = self.optimizer.collect_metrics() result.update(timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps) diff --git a/python/ray/rllib/agents/ppo/ppo.py b/python/ray/rllib/agents/ppo/ppo.py index 5c57d21e2..3f5ce16ef 100644 --- a/python/ray/rllib/agents/ppo/ppo.py +++ b/python/ray/rllib/agents/ppo/ppo.py @@ -8,7 +8,7 @@ import pickle import ray from ray.rllib.agents import Agent, with_common_config from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph -from ray.rllib.utils import FilterManager, merge_dicts +from ray.rllib.utils import merge_dicts from ray.rllib.optimizers import SyncSamplesOptimizer, LocalMultiGPUOptimizer from ray.tune.trial import Resources @@ -113,8 +113,6 @@ class PPOAgent(Agent): # multi-agent self.local_evaluator.foreach_trainable_policy( lambda pi, pi_id: pi.update_kl(fetches[pi_id]["kl"])) - FilterManager.synchronize(self.local_evaluator.filters, - self.remote_evaluators) res = self.optimizer.collect_metrics() res.update( timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps, diff --git a/python/ray/rllib/evaluation/metrics.py b/python/ray/rllib/evaluation/metrics.py index c12818bda..b2762f4f9 100644 --- a/python/ray/rllib/evaluation/metrics.py +++ b/python/ray/rllib/evaluation/metrics.py @@ -13,7 +13,7 @@ def collect_metrics(local_evaluator, remote_evaluators=[]): """Gathers episode metrics from PolicyEvaluator instances.""" episodes = collect_episodes(local_evaluator, remote_evaluators) - return summarize_episodes(episodes) + return summarize_episodes(episodes, episodes) def collect_episodes(local_evaluator, remote_evaluators=[]): @@ -30,8 +30,13 @@ def collect_episodes(local_evaluator, remote_evaluators=[]): return episodes -def summarize_episodes(episodes): - """Summarizes a set of episode metrics tuples.""" +def summarize_episodes(episodes, new_episodes): + """Summarizes a set of episode metrics tuples. + + Arguments: + episodes: smoothed set of episodes including historical ones + new_episodes: just the new episodes in this iteration + """ episode_rewards = [] episode_lengths = [] @@ -59,5 +64,5 @@ def summarize_episodes(episodes): episode_reward_min=min_reward, episode_reward_mean=avg_reward, episode_len_mean=avg_length, - episodes_total=len(episode_lengths), + episodes=len(new_episodes), policy_reward_mean=dict(policy_rewards)) diff --git a/python/ray/rllib/optimizers/policy_optimizer.py b/python/ray/rllib/optimizers/policy_optimizer.py index 0ce3e03d8..21fcf5f0b 100644 --- a/python/ray/rllib/optimizers/policy_optimizer.py +++ b/python/ray/rllib/optimizers/policy_optimizer.py @@ -98,7 +98,7 @@ class PolicyOptimizer(object): assert len(episodes) <= min_history self.episode_history.extend(orig_episodes) self.episode_history = self.episode_history[-min_history:] - res = summarize_episodes(episodes) + res = summarize_episodes(episodes, orig_episodes) res.update(info=self.stats()) return res diff --git a/python/ray/rllib/test/test_checkpoint_restore.py b/python/ray/rllib/test/test_checkpoint_restore.py index 1776ee8a1..6d2f277f9 100644 --- a/python/ray/rllib/test/test_checkpoint_restore.py +++ b/python/ray/rllib/test/test_checkpoint_restore.py @@ -26,6 +26,14 @@ CONFIGS = { "num_workers": 2 }, "DQN": {}, + "APEX_DDPG": { + "observation_filter": "MeanStdFilter", + "num_workers": 2, + "min_iter_time_s": 1, + "optimizer": { + "num_replay_buffer_shards": 1, + }, + }, "DDPG": { "noise_scale": 0.0, "timesteps_per_iteration": 100 @@ -43,7 +51,7 @@ CONFIGS = { def test(use_object_store, alg_name, failures): cls = get_agent_class(alg_name) - if alg_name == "DDPG": + if "DDPG" in alg_name: alg1 = cls(config=CONFIGS[name], env="Pendulum-v0") alg2 = cls(config=CONFIGS[name], env="Pendulum-v0") else: @@ -61,7 +69,7 @@ def test(use_object_store, alg_name, failures): alg2.restore(alg1.save()) for _ in range(10): - if alg_name == "DDPG": + if "DDPG" in alg_name: obs = np.random.uniform(size=3) else: obs = np.random.uniform(size=4) @@ -75,7 +83,7 @@ def test(use_object_store, alg_name, failures): if __name__ == "__main__": failures = [] for use_object_store in [False, True]: - for name in ["ES", "DQN", "DDPG", "PPO", "A3C"]: + for name in ["ES", "DQN", "DDPG", "PPO", "A3C", "APEX_DDPG"]: test(use_object_store, name, failures) assert not failures, failures diff --git a/python/ray/rllib/test/test_policy_evaluator.py b/python/ray/rllib/test/test_policy_evaluator.py index 7bdf87b8d..b454c4461 100644 --- a/python/ray/rllib/test/test_policy_evaluator.py +++ b/python/ray/rllib/test/test_policy_evaluator.py @@ -168,7 +168,7 @@ class TestPolicyEvaluator(unittest.TestCase): ev.sample() ray.get(remote_ev.sample.remote()) result = collect_metrics(ev, [remote_ev]) - self.assertEqual(result["episodes_total"], 20) + self.assertEqual(result["episodes"], 20) self.assertEqual(result["episode_reward_mean"], 10) def testAsync(self): @@ -204,12 +204,12 @@ class TestPolicyEvaluator(unittest.TestCase): batch = ev.sample() self.assertEqual(batch.count, 16) result = collect_metrics(ev, []) - self.assertEqual(result["episodes_total"], 0) + self.assertEqual(result["episodes"], 0) for _ in range(8): batch = ev.sample() self.assertEqual(batch.count, 16) result = collect_metrics(ev, []) - self.assertEqual(result["episodes_total"], 8) + self.assertEqual(result["episodes"], 8) indices = [] for env in ev.async_env.vector_env.envs: self.assertEqual(env.unwrapped.config.worker_index, 0) @@ -235,10 +235,10 @@ class TestPolicyEvaluator(unittest.TestCase): batch = ev.sample() self.assertEqual(batch.count, 16) result = collect_metrics(ev, []) - self.assertEqual(result["episodes_total"], 0) + self.assertEqual(result["episodes"], 0) batch = ev.sample() result = collect_metrics(ev, []) - self.assertEqual(result["episodes_total"], 4) + self.assertEqual(result["episodes"], 4) def testVectorEnvSupport(self): ev = PolicyEvaluator( @@ -250,12 +250,12 @@ class TestPolicyEvaluator(unittest.TestCase): batch = ev.sample() self.assertEqual(batch.count, 10) result = collect_metrics(ev, []) - self.assertEqual(result["episodes_total"], 0) + self.assertEqual(result["episodes"], 0) for _ in range(8): batch = ev.sample() self.assertEqual(batch.count, 10) result = collect_metrics(ev, []) - self.assertEqual(result["episodes_total"], 8) + self.assertEqual(result["episodes"], 8) def testTruncateEpisodes(self): ev = PolicyEvaluator( diff --git a/python/ray/rllib/utils/filter_manager.py b/python/ray/rllib/utils/filter_manager.py index 98b0471e9..d67777f43 100644 --- a/python/ray/rllib/utils/filter_manager.py +++ b/python/ray/rllib/utils/filter_manager.py @@ -11,7 +11,7 @@ class FilterManager(object): """ @staticmethod - def synchronize(local_filters, remotes): + def synchronize(local_filters, remotes, update_remote=True): """Aggregates all filters from remote evaluators. Local copy is updated and then broadcasted to all remote evaluators. @@ -19,12 +19,14 @@ class FilterManager(object): Args: local_filters (dict): Filters to be synchronized. remotes (list): Remote evaluators with filters. + update_remote (bool): Whether to push updates to remote filters. """ remote_filters = ray.get( [r.get_filters.remote(flush_after=True) for r in remotes]) for rf in remote_filters: for k in local_filters: local_filters[k].apply_changes(rf[k], with_buffer=False) - copies = {k: v.as_serializable() for k, v in local_filters.items()} - remote_copy = ray.put(copies) - [r.sync_filters.remote(remote_copy) for r in remotes] + if update_remote: + copies = {k: v.as_serializable() for k, v in local_filters.items()} + remote_copy = ray.put(copies) + [r.sync_filters.remote(remote_copy) for r in remotes]