mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 02:46:49 +08:00
[rllib] Fix filter sync for ES and ARS (#2918)
This commit is contained in:
committed by
Eric Liang
parent
725df3a485
commit
344b4ef0ff
@@ -18,6 +18,7 @@ from ray.tune.trial import Resources
|
||||
from ray.rllib.agents.ars import optimizers
|
||||
from ray.rllib.agents.ars import policies
|
||||
from ray.rllib.agents.ars import utils
|
||||
from ray.rllib.utils import FilterManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -84,6 +85,22 @@ class Worker(object):
|
||||
self.sess, self.env.action_space, self.env.observation_space,
|
||||
self.preprocessor, config["observation_filter"], config["model"])
|
||||
|
||||
@property
def filters(self):
    """Return this worker's filters keyed by name.

    Returns:
        dict: a single-entry mapping ``{"default": <filter>}`` where the
        value is whatever ``self.policy.get_filter()`` returns. A new dict
        is built on every access.
    """
    return {"default": self.policy.get_filter()}
|
||||
|
||||
def sync_filters(self, new_filters):
    """Sync every local filter with its counterpart in ``new_filters``.

    Args:
        new_filters: mapping indexed by the same keys as ``self.filters``;
            each local filter's ``sync()`` is called with
            ``new_filters[key]``. Keys present only in ``new_filters`` are
            ignored.
    """
    # Read the filters mapping once instead of once per key.
    for key, local_filter in self.filters.items():
        local_filter.sync(new_filters[key])
|
||||
|
||||
def get_filters(self, flush_after=False):
    """Return serializable snapshots of this worker's filters.

    Args:
        flush_after: when true, call ``clear_buffer()`` on each filter
            right after its snapshot has been taken.

    Returns:
        dict: filter name -> result of that filter's ``as_serializable()``.
    """
    return_filters = {}
    for k, f in self.filters.items():
        # Snapshot first so the returned copy is taken before any flush.
        return_filters[k] = f.as_serializable()
        if flush_after:
            f.clear_buffer()
    return return_filters
|
||||
|
||||
def rollout(self, timestep_limit, add_noise=False):
|
||||
rollout_rewards, rollout_length = policies.rollout(
|
||||
self.policy,
|
||||
@@ -201,6 +218,7 @@ class ARSAgent(Agent):
|
||||
num_episodes += sum(len(pair) for pair in result.noisy_lengths)
|
||||
num_timesteps += sum(
|
||||
sum(pair) for pair in result.noisy_lengths)
|
||||
|
||||
return results, num_episodes, num_timesteps
|
||||
|
||||
def _train(self):
|
||||
@@ -276,6 +294,11 @@ class ARSAgent(Agent):
|
||||
if len(all_eval_returns) > 0:
|
||||
self.reward_list.append(eval_returns.mean())
|
||||
|
||||
# Now sync the filters
|
||||
FilterManager.synchronize({
|
||||
"default": self.policy.get_filter()
|
||||
}, self.workers)
|
||||
|
||||
info = {
|
||||
"weights_norm": np.square(theta).sum(),
|
||||
"weights_std": np.std(theta),
|
||||
@@ -301,12 +324,17 @@ class ARSAgent(Agent):
|
||||
def __getstate__(self):
    """Return the agent state as a plain dict.

    Returns:
        dict: with keys ``"weights"`` (``self.policy.get_weights()``),
        ``"filter"`` (``self.policy.get_filter()``) and
        ``"episodes_so_far"`` (``self.episodes_so_far``).
    """
    return {
        "weights": self.policy.get_weights(),
        "filter": self.policy.get_filter(),
        "episodes_so_far": self.episodes_so_far,
    }
|
||||
|
||||
def __setstate__(self, state):
    """Restore the agent from a state dict.

    Args:
        state: mapping with the keys ``"weights"``, ``"filter"`` and
            ``"episodes_so_far"``.
    """
    self.episodes_so_far = state["episodes_so_far"]
    # Weights are applied exactly once; the diff view showed the removed
    # pre-change line next to its replacement, which read as a second call.
    self.policy.set_weights(state["weights"])
    self.policy.set_filter(state["filter"])
    # NOTE(review): presumably propagates the restored filter to
    # self.workers -- confirm against FilterManager.synchronize.
    FilterManager.synchronize({
        "default": self.policy.get_filter()
    }, self.workers)
|
||||
|
||||
def compute_action(self, observation):
    """Compute an action for a single observation.

    Calls ``self.policy.compute(observation, update=True)`` and returns
    the first element of its result.
    """
    return self.policy.compute(observation, update=True)[0]
|
||||
|
||||
@@ -102,5 +102,11 @@ class GenericPolicy(object):
|
||||
def set_weights(self, x):
    """Set the policy weights by passing ``x`` to ``self.variables.set_flat``."""
    self.variables.set_flat(x)
|
||||
|
||||
def set_filter(self, obs_filter):
    """Replace the policy's observation filter with ``obs_filter``."""
    self.observation_filter = obs_filter
|
||||
|
||||
def get_filter(self):
    """Return the policy's current observation filter object."""
    return self.observation_filter
|
||||
|
||||
def get_weights(self):
    """Return the policy weights as given by ``self.variables.get_flat()``."""
    return self.variables.get_flat()
|
||||
|
||||
@@ -18,6 +18,7 @@ from ray.rllib.agents.es import optimizers
|
||||
from ray.rllib.agents.es import policies
|
||||
from ray.rllib.agents.es import utils
|
||||
from ray.rllib.utils import merge_dicts
|
||||
from ray.rllib.utils import FilterManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -89,6 +90,22 @@ class Worker(object):
|
||||
self.preprocessor, config["observation_filter"], config["model"],
|
||||
**policy_params)
|
||||
|
||||
@property
def filters(self):
    """Return this worker's filters keyed by name.

    Returns:
        dict: a single-entry mapping ``{"default": <filter>}`` where the
        value is whatever ``self.policy.get_filter()`` returns. A new dict
        is built on every access.
    """
    return {"default": self.policy.get_filter()}
|
||||
|
||||
def sync_filters(self, new_filters):
    """Sync every local filter with its counterpart in ``new_filters``.

    Args:
        new_filters: mapping indexed by the same keys as ``self.filters``;
            each local filter's ``sync()`` is called with
            ``new_filters[key]``. Keys present only in ``new_filters`` are
            ignored.
    """
    # Read the filters mapping once instead of once per key.
    for key, local_filter in self.filters.items():
        local_filter.sync(new_filters[key])
|
||||
|
||||
def get_filters(self, flush_after=False):
    """Return serializable snapshots of this worker's filters.

    Args:
        flush_after: when true, call ``clear_buffer()`` on each filter
            right after its snapshot has been taken.

    Returns:
        dict: filter name -> result of that filter's ``as_serializable()``.
    """
    return_filters = {}
    for k, f in self.filters.items():
        # Snapshot first so the returned copy is taken before any flush.
        return_filters[k] = f.as_serializable()
        if flush_after:
            f.clear_buffer()
    return return_filters
|
||||
|
||||
def rollout(self, timestep_limit, add_noise=True):
|
||||
rollout_rewards, rollout_length = policies.rollout(
|
||||
self.policy,
|
||||
@@ -207,6 +224,7 @@ class ESAgent(Agent):
|
||||
num_episodes += sum(len(pair) for pair in result.noisy_lengths)
|
||||
num_timesteps += sum(
|
||||
sum(pair) for pair in result.noisy_lengths)
|
||||
|
||||
return results, num_episodes, num_timesteps
|
||||
|
||||
def _train(self):
|
||||
@@ -274,6 +292,11 @@ class ESAgent(Agent):
|
||||
if len(all_eval_returns) > 0:
|
||||
self.reward_list.append(np.mean(eval_returns))
|
||||
|
||||
# Now sync the filters
|
||||
FilterManager.synchronize({
|
||||
"default": self.policy.get_filter()
|
||||
}, self.workers)
|
||||
|
||||
info = {
|
||||
"weights_norm": np.square(theta).sum(),
|
||||
"grad_norm": np.square(g).sum(),
|
||||
@@ -299,12 +322,17 @@ class ESAgent(Agent):
|
||||
def __getstate__(self):
    """Return the agent state as a plain dict.

    Returns:
        dict: with keys ``"weights"`` (``self.policy.get_weights()``),
        ``"filter"`` (``self.policy.get_filter()``) and
        ``"episodes_so_far"`` (``self.episodes_so_far``).
    """
    return {
        "weights": self.policy.get_weights(),
        "filter": self.policy.get_filter(),
        "episodes_so_far": self.episodes_so_far,
    }
|
||||
|
||||
def __setstate__(self, state):
    """Restore the agent from a state dict.

    Args:
        state: mapping with the keys ``"weights"``, ``"filter"`` and
            ``"episodes_so_far"``.
    """
    self.episodes_so_far = state["episodes_so_far"]
    # Weights are applied exactly once; the diff view showed the removed
    # pre-change line next to its replacement, which read as a second call.
    self.policy.set_weights(state["weights"])
    self.policy.set_filter(state["filter"])
    # NOTE(review): presumably propagates the restored filter to
    # self.workers -- confirm against FilterManager.synchronize.
    FilterManager.synchronize({
        "default": self.policy.get_filter()
    }, self.workers)
|
||||
|
||||
def compute_action(self, observation):
    """Compute an action for a single observation.

    Calls ``self.policy.compute(observation, update=False)`` and returns
    the first element of its result.
    """
    return self.policy.compute(observation, update=False)[0]
|
||||
|
||||
@@ -82,3 +82,9 @@ class GenericPolicy(object):
|
||||
|
||||
def get_weights(self):
    """Return the policy weights as given by ``self.variables.get_flat()``."""
    return self.variables.get_flat()
|
||||
|
||||
def get_filter(self):
    """Return the policy's current observation filter object."""
    return self.observation_filter
|
||||
|
||||
def set_filter(self, observation_filter):
    """Replace the policy's observation filter with ``observation_filter``."""
    self.observation_filter = observation_filter
|
||||
|
||||
@@ -23,7 +23,8 @@ CONFIGS = {
|
||||
"ES": {
|
||||
"episodes_per_batch": 10,
|
||||
"train_batch_size": 100,
|
||||
"num_workers": 2
|
||||
"num_workers": 2,
|
||||
"observation_filter": "MeanStdFilter"
|
||||
},
|
||||
"DQN": {},
|
||||
"APEX_DDPG": {
|
||||
@@ -46,6 +47,11 @@ CONFIGS = {
|
||||
"A3C": {
|
||||
"num_workers": 1
|
||||
},
|
||||
"ARS": {
|
||||
"num_rollouts": 10,
|
||||
"num_workers": 2,
|
||||
"observation_filter": "MeanStdFilter"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -83,7 +89,7 @@ def test(use_object_store, alg_name, failures):
|
||||
if __name__ == "__main__":
|
||||
failures = []
|
||||
for use_object_store in [False, True]:
|
||||
for name in ["ES", "DQN", "DDPG", "PPO", "A3C", "APEX_DDPG"]:
|
||||
for name in ["ES", "DQN", "DDPG", "PPO", "A3C", "APEX_DDPG", "ARS"]:
|
||||
test(use_object_store, name, failures)
|
||||
|
||||
assert not failures, failures
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# can expect improvement to -140 reward in ~300-500k timesteps
|
||||
swimmer-ars:
|
||||
env: Swimmer-v2
|
||||
run: ARS
|
||||
|
||||
Reference in New Issue
Block a user