From 344b4ef0ff080abec340b162ffcce0f49966986c Mon Sep 17 00:00:00 2001 From: eugenevinitsky Date: Tue, 6 Nov 2018 17:09:34 -1000 Subject: [PATCH] [rllib] Fix filter sync for ES and ARS (#2918) --- doc/source/rllib-training.rst | 2 +- python/ray/rllib/agents/ars/ars.py | 30 ++++++++++++++++++- python/ray/rllib/agents/ars/policies.py | 6 ++++ python/ray/rllib/agents/es/es.py | 30 ++++++++++++++++++- python/ray/rllib/agents/es/policies.py | 6 ++++ .../ray/rllib/test/test_checkpoint_restore.py | 10 +++++-- .../ray/rllib/tuned_examples/swimmer-ars.yaml | 1 + 7 files changed, 80 insertions(+), 5 deletions(-) diff --git a/doc/source/rllib-training.rst b/doc/source/rllib-training.rst index 9cd46ea44..c8155e0cf 100644 --- a/doc/source/rllib-training.rst +++ b/doc/source/rllib-training.rst @@ -51,7 +51,7 @@ An example of evaluating a previously trained DQN agent is as follows: python ray/python/ray/rllib/rollout.py \ ~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint-1 \ - --run DQN --env CartPole-v0 + --run DQN --env CartPole-v0 --steps 10000 The ``rollout.py`` helper script reconstructs a DQN agent from the checkpoint located at ``~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint-1`` diff --git a/python/ray/rllib/agents/ars/ars.py b/python/ray/rllib/agents/ars/ars.py index 67e87057f..ab08dc00e 100644 --- a/python/ray/rllib/agents/ars/ars.py +++ b/python/ray/rllib/agents/ars/ars.py @@ -18,6 +18,7 @@ from ray.tune.trial import Resources from ray.rllib.agents.ars import optimizers from ray.rllib.agents.ars import policies from ray.rllib.agents.ars import utils +from ray.rllib.utils import FilterManager logger = logging.getLogger(__name__) @@ -84,6 +85,22 @@ class Worker(object): self.sess, self.env.action_space, self.env.observation_space, self.preprocessor, config["observation_filter"], config["model"]) + @property + def filters(self): + return {"default": self.policy.get_filter()} + + def sync_filters(self, new_filters): + for k in self.filters: + self.filters[k].sync(new_filters[k]) + + def get_filters(self, flush_after=False): + return_filters = {} + for k, f in self.filters.items(): + return_filters[k] = f.as_serializable() + if flush_after: + f.clear_buffer() + return return_filters + def rollout(self, timestep_limit, add_noise=False): rollout_rewards, rollout_length = policies.rollout( self.policy, @@ -201,6 +218,7 @@ class ARSAgent(Agent): num_episodes += sum(len(pair) for pair in result.noisy_lengths) num_timesteps += sum( sum(pair) for pair in result.noisy_lengths) + return results, num_episodes, num_timesteps def _train(self): @@ -276,6 +294,11 @@ class ARSAgent(Agent): if len(all_eval_returns) > 0: self.reward_list.append(eval_returns.mean()) + # Now sync the filters + FilterManager.synchronize({ + "default": self.policy.get_filter() + }, self.workers) + info = { "weights_norm": np.square(theta).sum(), "weights_std": np.std(theta), @@ -301,12 +324,17 @@ class ARSAgent(Agent): def __getstate__(self): return { "weights": self.policy.get_weights(), + "filter": self.policy.get_filter(), "episodes_so_far": self.episodes_so_far, } def __setstate__(self, state): - self.policy.set_weights(state["weights"]) self.episodes_so_far = state["episodes_so_far"] + self.policy.set_weights(state["weights"]) + self.policy.set_filter(state["filter"]) + FilterManager.synchronize({ + "default": self.policy.get_filter() + }, self.workers) def compute_action(self, observation): return self.policy.compute(observation, update=True)[0] diff --git a/python/ray/rllib/agents/ars/policies.py b/python/ray/rllib/agents/ars/policies.py index cd54628bd..27f664655 100644 --- a/python/ray/rllib/agents/ars/policies.py +++ b/python/ray/rllib/agents/ars/policies.py @@ -102,5 +102,11 @@ class GenericPolicy(object): def set_weights(self, x): self.variables.set_flat(x) + def set_filter(self, obs_filter): + self.observation_filter = obs_filter + + def get_filter(self): + return self.observation_filter + def get_weights(self): return self.variables.get_flat() diff --git a/python/ray/rllib/agents/es/es.py b/python/ray/rllib/agents/es/es.py index ed2ed1869..9d9d5e240 100644 --- a/python/ray/rllib/agents/es/es.py +++ b/python/ray/rllib/agents/es/es.py @@ -18,6 +18,7 @@ from ray.rllib.agents.es import optimizers from ray.rllib.agents.es import policies from ray.rllib.agents.es import utils from ray.rllib.utils import merge_dicts +from ray.rllib.utils import FilterManager logger = logging.getLogger(__name__) @@ -89,6 +90,22 @@ class Worker(object): self.preprocessor, config["observation_filter"], config["model"], **policy_params) + @property + def filters(self): + return {"default": self.policy.get_filter()} + + def sync_filters(self, new_filters): + for k in self.filters: + self.filters[k].sync(new_filters[k]) + + def get_filters(self, flush_after=False): + return_filters = {} + for k, f in self.filters.items(): + return_filters[k] = f.as_serializable() + if flush_after: + f.clear_buffer() + return return_filters + def rollout(self, timestep_limit, add_noise=True): rollout_rewards, rollout_length = policies.rollout( self.policy, @@ -207,6 +224,7 @@ class ESAgent(Agent): num_episodes += sum(len(pair) for pair in result.noisy_lengths) num_timesteps += sum( sum(pair) for pair in result.noisy_lengths) + return results, num_episodes, num_timesteps def _train(self): @@ -274,6 +292,11 @@ class ESAgent(Agent): if len(all_eval_returns) > 0: self.reward_list.append(np.mean(eval_returns)) + # Now sync the filters + FilterManager.synchronize({ + "default": self.policy.get_filter() + }, self.workers) + info = { "weights_norm": np.square(theta).sum(), "grad_norm": np.square(g).sum(), @@ -299,12 +322,17 @@ class ESAgent(Agent): def __getstate__(self): return { "weights": self.policy.get_weights(), + "filter": self.policy.get_filter(), "episodes_so_far": self.episodes_so_far, } def __setstate__(self, state): - self.policy.set_weights(state["weights"]) self.episodes_so_far = state["episodes_so_far"] + self.policy.set_weights(state["weights"]) + self.policy.set_filter(state["filter"]) + FilterManager.synchronize({ + "default": self.policy.get_filter() + }, self.workers) def compute_action(self, observation): return self.policy.compute(observation, update=False)[0] diff --git a/python/ray/rllib/agents/es/policies.py b/python/ray/rllib/agents/es/policies.py index 0df5ced30..cf2da630e 100644 --- a/python/ray/rllib/agents/es/policies.py +++ b/python/ray/rllib/agents/es/policies.py @@ -82,3 +82,9 @@ class GenericPolicy(object): def get_weights(self): return self.variables.get_flat() + + def get_filter(self): + return self.observation_filter + + def set_filter(self, observation_filter): + self.observation_filter = observation_filter diff --git a/python/ray/rllib/test/test_checkpoint_restore.py b/python/ray/rllib/test/test_checkpoint_restore.py index cb371c90c..aa8fac280 100644 --- a/python/ray/rllib/test/test_checkpoint_restore.py +++ b/python/ray/rllib/test/test_checkpoint_restore.py @@ -23,7 +23,8 @@ CONFIGS = { "ES": { "episodes_per_batch": 10, "train_batch_size": 100, - "num_workers": 2 + "num_workers": 2, + "observation_filter": "MeanStdFilter" }, "DQN": {}, "APEX_DDPG": { @@ -46,6 +47,11 @@ CONFIGS = { "A3C": { "num_workers": 1 }, + "ARS": { + "num_rollouts": 10, + "num_workers": 2, + "observation_filter": "MeanStdFilter" + } } @@ -83,7 +89,7 @@ def test(use_object_store, alg_name, failures): if __name__ == "__main__": failures = [] for use_object_store in [False, True]: - for name in ["ES", "DQN", "DDPG", "PPO", "A3C", "APEX_DDPG"]: + for name in ["ES", "DQN", "DDPG", "PPO", "A3C", "APEX_DDPG", "ARS"]: test(use_object_store, name, failures) assert not failures, failures diff --git a/python/ray/rllib/tuned_examples/swimmer-ars.yaml b/python/ray/rllib/tuned_examples/swimmer-ars.yaml index 532bb00b0..effb4cfe1 100644 --- a/python/ray/rllib/tuned_examples/swimmer-ars.yaml +++ b/python/ray/rllib/tuned_examples/swimmer-ars.yaml @@ -1,3 +1,4 @@ +# can expect improvement to -140 reward in ~300-500k timesteps swimmer-ars: env: Swimmer-v2 run: ARS