mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 05:22:26 +08:00
[rllib] Auto-synchronize filters for all agents (#2791)
This makes sure we always update the local filter from the remote evaluators, and adds an option (`synchronize_filters`) to also push the merged filter state back to the remote evaluators. Previously APEX_DDPG did neither. Updating the local filter is needed for checkpoint correctness; synchronizing the remote filters might help performance.
This commit is contained in:
@@ -10,7 +10,7 @@ import ray
|
||||
from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph
|
||||
from ray.rllib.agents.agent import Agent, with_common_config
|
||||
from ray.rllib.optimizers import AsyncGradientsOptimizer
|
||||
from ray.rllib.utils import FilterManager, merge_dicts
|
||||
from ray.rllib.utils import merge_dicts
|
||||
from ray.tune.trial import Resources
|
||||
|
||||
DEFAULT_CONFIG = with_common_config({
|
||||
@@ -104,8 +104,6 @@ class A3CAgent(Agent):
|
||||
start = time.time()
|
||||
while time.time() - start < self.config["min_iter_time_s"]:
|
||||
self.optimizer.step()
|
||||
FilterManager.synchronize(self.local_evaluator.filters,
|
||||
self.remote_evaluators)
|
||||
result = self.optimizer.collect_metrics()
|
||||
result.update(timesteps_this_iter=self.optimizer.num_steps_sampled -
|
||||
prev_steps)
|
||||
|
||||
@@ -11,7 +11,7 @@ import pickle
|
||||
import tensorflow as tf
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.utils import deep_update, merge_dicts
|
||||
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
|
||||
from ray.tune.registry import ENV_CREATOR, _global_registry
|
||||
from ray.tune.trainable import Trainable
|
||||
|
||||
@@ -32,6 +32,8 @@ COMMON_CONFIG = {
|
||||
"sample_async": False,
|
||||
# Which observation filter to apply to the observation
|
||||
"observation_filter": "NoFilter",
|
||||
# Whether to synchronize the statistics of remote filters.
|
||||
"synchronize_filters": True,
|
||||
# Whether to clip rewards prior to experience postprocessing
|
||||
"clip_rewards": True,
|
||||
# Whether to use rllib or deepmind preprocessors
|
||||
@@ -197,6 +199,13 @@ class Agent(Trainable):
|
||||
for ev in self.optimizer.remote_evaluators:
|
||||
ev.set_global_vars.remote(self.global_vars)
|
||||
|
||||
if (self.config.get("observation_filter", "NoFilter") != "NoFilter"
|
||||
and hasattr(self, "local_evaluator")):
|
||||
FilterManager.synchronize(
|
||||
self.local_evaluator.filters,
|
||||
self.remote_evaluators,
|
||||
update_remote=self.config["synchronize_filters"])
|
||||
|
||||
return Trainable.train(self)
|
||||
|
||||
def _setup(self):
|
||||
|
||||
@@ -11,7 +11,6 @@ from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph
|
||||
from ray.rllib.agents.impala.vtrace_policy_graph import VTracePolicyGraph
|
||||
from ray.rllib.agents.agent import Agent, with_common_config
|
||||
from ray.rllib.optimizers import AsyncSamplesOptimizer
|
||||
from ray.rllib.utils import FilterManager
|
||||
from ray.tune.trial import Resources
|
||||
|
||||
OPTIMIZER_SHARED_CONFIGS = [
|
||||
@@ -96,8 +95,6 @@ class ImpalaAgent(Agent):
|
||||
self.optimizer.step()
|
||||
while time.time() - start < self.config["min_iter_time_s"]:
|
||||
self.optimizer.step()
|
||||
FilterManager.synchronize(self.local_evaluator.filters,
|
||||
self.remote_evaluators)
|
||||
result = self.optimizer.collect_metrics()
|
||||
result.update(timesteps_this_iter=self.optimizer.num_steps_sampled -
|
||||
prev_steps)
|
||||
|
||||
@@ -8,7 +8,7 @@ import pickle
|
||||
import ray
|
||||
from ray.rllib.agents import Agent, with_common_config
|
||||
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
|
||||
from ray.rllib.utils import FilterManager, merge_dicts
|
||||
from ray.rllib.utils import merge_dicts
|
||||
from ray.rllib.optimizers import SyncSamplesOptimizer, LocalMultiGPUOptimizer
|
||||
from ray.tune.trial import Resources
|
||||
|
||||
@@ -113,8 +113,6 @@ class PPOAgent(Agent):
|
||||
# multi-agent
|
||||
self.local_evaluator.foreach_trainable_policy(
|
||||
lambda pi, pi_id: pi.update_kl(fetches[pi_id]["kl"]))
|
||||
FilterManager.synchronize(self.local_evaluator.filters,
|
||||
self.remote_evaluators)
|
||||
res = self.optimizer.collect_metrics()
|
||||
res.update(
|
||||
timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps,
|
||||
|
||||
@@ -13,7 +13,7 @@ def collect_metrics(local_evaluator, remote_evaluators=[]):
|
||||
"""Gathers episode metrics from PolicyEvaluator instances."""
|
||||
|
||||
episodes = collect_episodes(local_evaluator, remote_evaluators)
|
||||
return summarize_episodes(episodes)
|
||||
return summarize_episodes(episodes, episodes)
|
||||
|
||||
|
||||
def collect_episodes(local_evaluator, remote_evaluators=[]):
|
||||
@@ -30,8 +30,13 @@ def collect_episodes(local_evaluator, remote_evaluators=[]):
|
||||
return episodes
|
||||
|
||||
|
||||
def summarize_episodes(episodes):
|
||||
"""Summarizes a set of episode metrics tuples."""
|
||||
def summarize_episodes(episodes, new_episodes):
|
||||
"""Summarizes a set of episode metrics tuples.
|
||||
|
||||
Arguments:
|
||||
episodes: smoothed set of episodes including historical ones
|
||||
new_episodes: just the new episodes in this iteration
|
||||
"""
|
||||
|
||||
episode_rewards = []
|
||||
episode_lengths = []
|
||||
@@ -59,5 +64,5 @@ def summarize_episodes(episodes):
|
||||
episode_reward_min=min_reward,
|
||||
episode_reward_mean=avg_reward,
|
||||
episode_len_mean=avg_length,
|
||||
episodes_total=len(episode_lengths),
|
||||
episodes=len(new_episodes),
|
||||
policy_reward_mean=dict(policy_rewards))
|
||||
|
||||
@@ -98,7 +98,7 @@ class PolicyOptimizer(object):
|
||||
assert len(episodes) <= min_history
|
||||
self.episode_history.extend(orig_episodes)
|
||||
self.episode_history = self.episode_history[-min_history:]
|
||||
res = summarize_episodes(episodes)
|
||||
res = summarize_episodes(episodes, orig_episodes)
|
||||
res.update(info=self.stats())
|
||||
return res
|
||||
|
||||
|
||||
@@ -26,6 +26,14 @@ CONFIGS = {
|
||||
"num_workers": 2
|
||||
},
|
||||
"DQN": {},
|
||||
"APEX_DDPG": {
|
||||
"observation_filter": "MeanStdFilter",
|
||||
"num_workers": 2,
|
||||
"min_iter_time_s": 1,
|
||||
"optimizer": {
|
||||
"num_replay_buffer_shards": 1,
|
||||
},
|
||||
},
|
||||
"DDPG": {
|
||||
"noise_scale": 0.0,
|
||||
"timesteps_per_iteration": 100
|
||||
@@ -43,7 +51,7 @@ CONFIGS = {
|
||||
|
||||
def test(use_object_store, alg_name, failures):
|
||||
cls = get_agent_class(alg_name)
|
||||
if alg_name == "DDPG":
|
||||
if "DDPG" in alg_name:
|
||||
alg1 = cls(config=CONFIGS[name], env="Pendulum-v0")
|
||||
alg2 = cls(config=CONFIGS[name], env="Pendulum-v0")
|
||||
else:
|
||||
@@ -61,7 +69,7 @@ def test(use_object_store, alg_name, failures):
|
||||
alg2.restore(alg1.save())
|
||||
|
||||
for _ in range(10):
|
||||
if alg_name == "DDPG":
|
||||
if "DDPG" in alg_name:
|
||||
obs = np.random.uniform(size=3)
|
||||
else:
|
||||
obs = np.random.uniform(size=4)
|
||||
@@ -75,7 +83,7 @@ def test(use_object_store, alg_name, failures):
|
||||
if __name__ == "__main__":
|
||||
failures = []
|
||||
for use_object_store in [False, True]:
|
||||
for name in ["ES", "DQN", "DDPG", "PPO", "A3C"]:
|
||||
for name in ["ES", "DQN", "DDPG", "PPO", "A3C", "APEX_DDPG"]:
|
||||
test(use_object_store, name, failures)
|
||||
|
||||
assert not failures, failures
|
||||
|
||||
@@ -168,7 +168,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
ev.sample()
|
||||
ray.get(remote_ev.sample.remote())
|
||||
result = collect_metrics(ev, [remote_ev])
|
||||
self.assertEqual(result["episodes_total"], 20)
|
||||
self.assertEqual(result["episodes"], 20)
|
||||
self.assertEqual(result["episode_reward_mean"], 10)
|
||||
|
||||
def testAsync(self):
|
||||
@@ -204,12 +204,12 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 16)
|
||||
result = collect_metrics(ev, [])
|
||||
self.assertEqual(result["episodes_total"], 0)
|
||||
self.assertEqual(result["episodes"], 0)
|
||||
for _ in range(8):
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 16)
|
||||
result = collect_metrics(ev, [])
|
||||
self.assertEqual(result["episodes_total"], 8)
|
||||
self.assertEqual(result["episodes"], 8)
|
||||
indices = []
|
||||
for env in ev.async_env.vector_env.envs:
|
||||
self.assertEqual(env.unwrapped.config.worker_index, 0)
|
||||
@@ -235,10 +235,10 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 16)
|
||||
result = collect_metrics(ev, [])
|
||||
self.assertEqual(result["episodes_total"], 0)
|
||||
self.assertEqual(result["episodes"], 0)
|
||||
batch = ev.sample()
|
||||
result = collect_metrics(ev, [])
|
||||
self.assertEqual(result["episodes_total"], 4)
|
||||
self.assertEqual(result["episodes"], 4)
|
||||
|
||||
def testVectorEnvSupport(self):
|
||||
ev = PolicyEvaluator(
|
||||
@@ -250,12 +250,12 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 10)
|
||||
result = collect_metrics(ev, [])
|
||||
self.assertEqual(result["episodes_total"], 0)
|
||||
self.assertEqual(result["episodes"], 0)
|
||||
for _ in range(8):
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 10)
|
||||
result = collect_metrics(ev, [])
|
||||
self.assertEqual(result["episodes_total"], 8)
|
||||
self.assertEqual(result["episodes"], 8)
|
||||
|
||||
def testTruncateEpisodes(self):
|
||||
ev = PolicyEvaluator(
|
||||
|
||||
@@ -11,7 +11,7 @@ class FilterManager(object):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def synchronize(local_filters, remotes):
|
||||
def synchronize(local_filters, remotes, update_remote=True):
|
||||
"""Aggregates all filters from remote evaluators.
|
||||
|
||||
Local copy is updated and then broadcasted to all remote evaluators.
|
||||
@@ -19,12 +19,14 @@ class FilterManager(object):
|
||||
Args:
|
||||
local_filters (dict): Filters to be synchronized.
|
||||
remotes (list): Remote evaluators with filters.
|
||||
update_remote (bool): Whether to push updates to remote filters.
|
||||
"""
|
||||
remote_filters = ray.get(
|
||||
[r.get_filters.remote(flush_after=True) for r in remotes])
|
||||
for rf in remote_filters:
|
||||
for k in local_filters:
|
||||
local_filters[k].apply_changes(rf[k], with_buffer=False)
|
||||
copies = {k: v.as_serializable() for k, v in local_filters.items()}
|
||||
remote_copy = ray.put(copies)
|
||||
[r.sync_filters.remote(remote_copy) for r in remotes]
|
||||
if update_remote:
|
||||
copies = {k: v.as_serializable() for k, v in local_filters.items()}
|
||||
remote_copy = ray.put(copies)
|
||||
[r.sync_filters.remote(remote_copy) for r in remotes]
|
||||
|
||||
Reference in New Issue
Block a user