[rllib] Auto-synchronize filters for all agents (#2791)

This makes sure we always update the local filter, and adds an option to synchronize the remote filters as well. In APEX_DDPG we previously didn't do either. The first is needed for checkpoint correctness, the second might help performance.
This commit is contained in:
Eric Liang
2018-09-03 20:01:53 -07:00
committed by GitHub
parent a34a7172b4
commit 25ffe57a5c
9 changed files with 46 additions and 29 deletions
+1 -3
View File
@@ -10,7 +10,7 @@ import ray
from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph
from ray.rllib.agents.agent import Agent, with_common_config
from ray.rllib.optimizers import AsyncGradientsOptimizer
from ray.rllib.utils import FilterManager, merge_dicts
from ray.rllib.utils import merge_dicts
from ray.tune.trial import Resources
DEFAULT_CONFIG = with_common_config({
@@ -104,8 +104,6 @@ class A3CAgent(Agent):
start = time.time()
while time.time() - start < self.config["min_iter_time_s"]:
self.optimizer.step()
FilterManager.synchronize(self.local_evaluator.filters,
self.remote_evaluators)
result = self.optimizer.collect_metrics()
result.update(timesteps_this_iter=self.optimizer.num_steps_sampled -
prev_steps)
+10 -1
View File
@@ -11,7 +11,7 @@ import pickle
import tensorflow as tf
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.utils import deep_update, merge_dicts
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
from ray.tune.registry import ENV_CREATOR, _global_registry
from ray.tune.trainable import Trainable
@@ -32,6 +32,8 @@ COMMON_CONFIG = {
"sample_async": False,
# Which observation filter to apply to the observation
"observation_filter": "NoFilter",
# Whether to synchronize the statistics of remote filters.
"synchronize_filters": True,
# Whether to clip rewards prior to experience postprocessing
"clip_rewards": True,
# Whether to use rllib or deepmind preprocessors
@@ -197,6 +199,13 @@ class Agent(Trainable):
for ev in self.optimizer.remote_evaluators:
ev.set_global_vars.remote(self.global_vars)
if (self.config.get("observation_filter", "NoFilter") != "NoFilter"
and hasattr(self, "local_evaluator")):
FilterManager.synchronize(
self.local_evaluator.filters,
self.remote_evaluators,
update_remote=self.config["synchronize_filters"])
return Trainable.train(self)
def _setup(self):
-3
View File
@@ -11,7 +11,6 @@ from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph
from ray.rllib.agents.impala.vtrace_policy_graph import VTracePolicyGraph
from ray.rllib.agents.agent import Agent, with_common_config
from ray.rllib.optimizers import AsyncSamplesOptimizer
from ray.rllib.utils import FilterManager
from ray.tune.trial import Resources
OPTIMIZER_SHARED_CONFIGS = [
@@ -96,8 +95,6 @@ class ImpalaAgent(Agent):
self.optimizer.step()
while time.time() - start < self.config["min_iter_time_s"]:
self.optimizer.step()
FilterManager.synchronize(self.local_evaluator.filters,
self.remote_evaluators)
result = self.optimizer.collect_metrics()
result.update(timesteps_this_iter=self.optimizer.num_steps_sampled -
prev_steps)
+1 -3
View File
@@ -8,7 +8,7 @@ import pickle
import ray
from ray.rllib.agents import Agent, with_common_config
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
from ray.rllib.utils import FilterManager, merge_dicts
from ray.rllib.utils import merge_dicts
from ray.rllib.optimizers import SyncSamplesOptimizer, LocalMultiGPUOptimizer
from ray.tune.trial import Resources
@@ -113,8 +113,6 @@ class PPOAgent(Agent):
# multi-agent
self.local_evaluator.foreach_trainable_policy(
lambda pi, pi_id: pi.update_kl(fetches[pi_id]["kl"]))
FilterManager.synchronize(self.local_evaluator.filters,
self.remote_evaluators)
res = self.optimizer.collect_metrics()
res.update(
timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps,
+9 -4
View File
@@ -13,7 +13,7 @@ def collect_metrics(local_evaluator, remote_evaluators=[]):
"""Gathers episode metrics from PolicyEvaluator instances."""
episodes = collect_episodes(local_evaluator, remote_evaluators)
return summarize_episodes(episodes)
return summarize_episodes(episodes, episodes)
def collect_episodes(local_evaluator, remote_evaluators=[]):
@@ -30,8 +30,13 @@ def collect_episodes(local_evaluator, remote_evaluators=[]):
return episodes
def summarize_episodes(episodes):
"""Summarizes a set of episode metrics tuples."""
def summarize_episodes(episodes, new_episodes):
"""Summarizes a set of episode metrics tuples.
Arguments:
episodes: smoothed set of episodes including historical ones
new_episodes: just the new episodes in this iteration
"""
episode_rewards = []
episode_lengths = []
@@ -59,5 +64,5 @@ def summarize_episodes(episodes):
episode_reward_min=min_reward,
episode_reward_mean=avg_reward,
episode_len_mean=avg_length,
episodes_total=len(episode_lengths),
episodes=len(new_episodes),
policy_reward_mean=dict(policy_rewards))
@@ -98,7 +98,7 @@ class PolicyOptimizer(object):
assert len(episodes) <= min_history
self.episode_history.extend(orig_episodes)
self.episode_history = self.episode_history[-min_history:]
res = summarize_episodes(episodes)
res = summarize_episodes(episodes, orig_episodes)
res.update(info=self.stats())
return res
@@ -26,6 +26,14 @@ CONFIGS = {
"num_workers": 2
},
"DQN": {},
"APEX_DDPG": {
"observation_filter": "MeanStdFilter",
"num_workers": 2,
"min_iter_time_s": 1,
"optimizer": {
"num_replay_buffer_shards": 1,
},
},
"DDPG": {
"noise_scale": 0.0,
"timesteps_per_iteration": 100
@@ -43,7 +51,7 @@ CONFIGS = {
def test(use_object_store, alg_name, failures):
cls = get_agent_class(alg_name)
if alg_name == "DDPG":
if "DDPG" in alg_name:
alg1 = cls(config=CONFIGS[name], env="Pendulum-v0")
alg2 = cls(config=CONFIGS[name], env="Pendulum-v0")
else:
@@ -61,7 +69,7 @@ def test(use_object_store, alg_name, failures):
alg2.restore(alg1.save())
for _ in range(10):
if alg_name == "DDPG":
if "DDPG" in alg_name:
obs = np.random.uniform(size=3)
else:
obs = np.random.uniform(size=4)
@@ -75,7 +83,7 @@ def test(use_object_store, alg_name, failures):
if __name__ == "__main__":
failures = []
for use_object_store in [False, True]:
for name in ["ES", "DQN", "DDPG", "PPO", "A3C"]:
for name in ["ES", "DQN", "DDPG", "PPO", "A3C", "APEX_DDPG"]:
test(use_object_store, name, failures)
assert not failures, failures
@@ -168,7 +168,7 @@ class TestPolicyEvaluator(unittest.TestCase):
ev.sample()
ray.get(remote_ev.sample.remote())
result = collect_metrics(ev, [remote_ev])
self.assertEqual(result["episodes_total"], 20)
self.assertEqual(result["episodes"], 20)
self.assertEqual(result["episode_reward_mean"], 10)
def testAsync(self):
@@ -204,12 +204,12 @@ class TestPolicyEvaluator(unittest.TestCase):
batch = ev.sample()
self.assertEqual(batch.count, 16)
result = collect_metrics(ev, [])
self.assertEqual(result["episodes_total"], 0)
self.assertEqual(result["episodes"], 0)
for _ in range(8):
batch = ev.sample()
self.assertEqual(batch.count, 16)
result = collect_metrics(ev, [])
self.assertEqual(result["episodes_total"], 8)
self.assertEqual(result["episodes"], 8)
indices = []
for env in ev.async_env.vector_env.envs:
self.assertEqual(env.unwrapped.config.worker_index, 0)
@@ -235,10 +235,10 @@ class TestPolicyEvaluator(unittest.TestCase):
batch = ev.sample()
self.assertEqual(batch.count, 16)
result = collect_metrics(ev, [])
self.assertEqual(result["episodes_total"], 0)
self.assertEqual(result["episodes"], 0)
batch = ev.sample()
result = collect_metrics(ev, [])
self.assertEqual(result["episodes_total"], 4)
self.assertEqual(result["episodes"], 4)
def testVectorEnvSupport(self):
ev = PolicyEvaluator(
@@ -250,12 +250,12 @@ class TestPolicyEvaluator(unittest.TestCase):
batch = ev.sample()
self.assertEqual(batch.count, 10)
result = collect_metrics(ev, [])
self.assertEqual(result["episodes_total"], 0)
self.assertEqual(result["episodes"], 0)
for _ in range(8):
batch = ev.sample()
self.assertEqual(batch.count, 10)
result = collect_metrics(ev, [])
self.assertEqual(result["episodes_total"], 8)
self.assertEqual(result["episodes"], 8)
def testTruncateEpisodes(self):
ev = PolicyEvaluator(
+6 -4
View File
@@ -11,7 +11,7 @@ class FilterManager(object):
"""
@staticmethod
def synchronize(local_filters, remotes):
def synchronize(local_filters, remotes, update_remote=True):
"""Aggregates all filters from remote evaluators.
Local copy is updated and then broadcasted to all remote evaluators.
@@ -19,12 +19,14 @@ class FilterManager(object):
Args:
local_filters (dict): Filters to be synchronized.
remotes (list): Remote evaluators with filters.
update_remote (bool): Whether to push updates to remote filters.
"""
remote_filters = ray.get(
[r.get_filters.remote(flush_after=True) for r in remotes])
for rf in remote_filters:
for k in local_filters:
local_filters[k].apply_changes(rf[k], with_buffer=False)
copies = {k: v.as_serializable() for k, v in local_filters.items()}
remote_copy = ray.put(copies)
[r.sync_filters.remote(remote_copy) for r in remotes]
if update_remote:
copies = {k: v.as_serializable() for k, v in local_filters.items()}
remote_copy = ray.put(copies)
[r.sync_filters.remote(remote_copy) for r in remotes]