[rllib] Fix filter sync for ES and ARS (#2918)

This commit is contained in:
eugenevinitsky
2018-11-06 17:09:34 -10:00
committed by Eric Liang
parent 725df3a485
commit 344b4ef0ff
7 changed files with 80 additions and 5 deletions
+29 -1
View File
@@ -18,6 +18,7 @@ from ray.tune.trial import Resources
from ray.rllib.agents.ars import optimizers
from ray.rllib.agents.ars import policies
from ray.rllib.agents.ars import utils
from ray.rllib.utils import FilterManager
logger = logging.getLogger(__name__)
@@ -84,6 +85,22 @@ class Worker(object):
self.sess, self.env.action_space, self.env.observation_space,
self.preprocessor, config["observation_filter"], config["model"])
@property
def filters(self):
return {"default": self.policy.get_filter()}
def sync_filters(self, new_filters):
for k in self.filters:
self.filters[k].sync(new_filters[k])
def get_filters(self, flush_after=False):
return_filters = {}
for k, f in self.filters.items():
return_filters[k] = f.as_serializable()
if flush_after:
f.clear_buffer()
return return_filters
def rollout(self, timestep_limit, add_noise=False):
rollout_rewards, rollout_length = policies.rollout(
self.policy,
@@ -201,6 +218,7 @@ class ARSAgent(Agent):
num_episodes += sum(len(pair) for pair in result.noisy_lengths)
num_timesteps += sum(
sum(pair) for pair in result.noisy_lengths)
return results, num_episodes, num_timesteps
def _train(self):
@@ -276,6 +294,11 @@ class ARSAgent(Agent):
if len(all_eval_returns) > 0:
self.reward_list.append(eval_returns.mean())
# Now sync the filters
FilterManager.synchronize({
"default": self.policy.get_filter()
}, self.workers)
info = {
"weights_norm": np.square(theta).sum(),
"weights_std": np.std(theta),
@@ -301,12 +324,17 @@ class ARSAgent(Agent):
def __getstate__(self):
return {
"weights": self.policy.get_weights(),
"filter": self.policy.get_filter(),
"episodes_so_far": self.episodes_so_far,
}
def __setstate__(self, state):
self.policy.set_weights(state["weights"])
self.episodes_so_far = state["episodes_so_far"]
self.policy.set_weights(state["weights"])
self.policy.set_filter(state["filter"])
FilterManager.synchronize({
"default": self.policy.get_filter()
}, self.workers)
def compute_action(self, observation):
return self.policy.compute(observation, update=True)[0]
+6
View File
@@ -102,5 +102,11 @@ class GenericPolicy(object):
def set_weights(self, x):
self.variables.set_flat(x)
def set_filter(self, obs_filter):
self.observation_filter = obs_filter
def get_filter(self):
return self.observation_filter
def get_weights(self):
return self.variables.get_flat()
+29 -1
View File
@@ -18,6 +18,7 @@ from ray.rllib.agents.es import optimizers
from ray.rllib.agents.es import policies
from ray.rllib.agents.es import utils
from ray.rllib.utils import merge_dicts
from ray.rllib.utils import FilterManager
logger = logging.getLogger(__name__)
@@ -89,6 +90,22 @@ class Worker(object):
self.preprocessor, config["observation_filter"], config["model"],
**policy_params)
@property
def filters(self):
return {"default": self.policy.get_filter()}
def sync_filters(self, new_filters):
for k in self.filters:
self.filters[k].sync(new_filters[k])
def get_filters(self, flush_after=False):
return_filters = {}
for k, f in self.filters.items():
return_filters[k] = f.as_serializable()
if flush_after:
f.clear_buffer()
return return_filters
def rollout(self, timestep_limit, add_noise=True):
rollout_rewards, rollout_length = policies.rollout(
self.policy,
@@ -207,6 +224,7 @@ class ESAgent(Agent):
num_episodes += sum(len(pair) for pair in result.noisy_lengths)
num_timesteps += sum(
sum(pair) for pair in result.noisy_lengths)
return results, num_episodes, num_timesteps
def _train(self):
@@ -274,6 +292,11 @@ class ESAgent(Agent):
if len(all_eval_returns) > 0:
self.reward_list.append(np.mean(eval_returns))
# Now sync the filters
FilterManager.synchronize({
"default": self.policy.get_filter()
}, self.workers)
info = {
"weights_norm": np.square(theta).sum(),
"grad_norm": np.square(g).sum(),
@@ -299,12 +322,17 @@ class ESAgent(Agent):
def __getstate__(self):
return {
"weights": self.policy.get_weights(),
"filter": self.policy.get_filter(),
"episodes_so_far": self.episodes_so_far,
}
def __setstate__(self, state):
self.policy.set_weights(state["weights"])
self.episodes_so_far = state["episodes_so_far"]
self.policy.set_weights(state["weights"])
self.policy.set_filter(state["filter"])
FilterManager.synchronize({
"default": self.policy.get_filter()
}, self.workers)
def compute_action(self, observation):
return self.policy.compute(observation, update=False)[0]
+6
View File
@@ -82,3 +82,9 @@ class GenericPolicy(object):
def get_weights(self):
return self.variables.get_flat()
def get_filter(self):
return self.observation_filter
def set_filter(self, observation_filter):
self.observation_filter = observation_filter
@@ -23,7 +23,8 @@ CONFIGS = {
"ES": {
"episodes_per_batch": 10,
"train_batch_size": 100,
"num_workers": 2
"num_workers": 2,
"observation_filter": "MeanStdFilter"
},
"DQN": {},
"APEX_DDPG": {
@@ -46,6 +47,11 @@ CONFIGS = {
"A3C": {
"num_workers": 1
},
"ARS": {
"num_rollouts": 10,
"num_workers": 2,
"observation_filter": "MeanStdFilter"
}
}
@@ -83,7 +89,7 @@ def test(use_object_store, alg_name, failures):
if __name__ == "__main__":
failures = []
for use_object_store in [False, True]:
for name in ["ES", "DQN", "DDPG", "PPO", "A3C", "APEX_DDPG"]:
for name in ["ES", "DQN", "DDPG", "PPO", "A3C", "APEX_DDPG", "ARS"]:
test(use_object_store, name, failures)
assert not failures, failures
@@ -1,3 +1,4 @@
# can expect improvement to -140 reward in ~300-500k timesteps
swimmer-ars:
env: Swimmer-v2
run: ARS