[rllib] Misc fixes: set lr for PG, better error message for LSTM/PPO, fix multi-agent/APEX (#3697)

* fix

* update test

* better error

* compute

* eps fix

* add get_policy() api

* Update agent.py

* better err msg

* fix

* pass in rew
This commit is contained in:
Eric Liang
2019-01-06 19:37:35 -08:00
committed by GitHub
parent df0733cafb
commit e78562b2e8
12 changed files with 111 additions and 34 deletions
+26 -7
View File
@@ -341,9 +341,18 @@ class Agent(Trainable):
raise NotImplementedError
def compute_action(self, observation, state=None, policy_id="default"):
def compute_action(self,
observation,
state=None,
prev_action=None,
prev_reward=None,
info=None,
policy_id="default"):
"""Computes an action for the specified policy.
Note that you can also access the policy object through
self.get_policy(policy_id) and call compute_actions() on it directly.
Arguments:
observation (obj): observation from the environment.
state (list): RNN hidden state, if any. If state is not None,
@@ -351,6 +360,9 @@ class Agent(Trainable):
(computed action, rnn state, logits dictionary).
Otherwise compute_single_action(...)[0] is
returned (computed action).
prev_action (obj): previous action value, if any
prev_reward (int): previous reward, if any
info (dict): info object, if any
policy_id (str): policy to query (only applies to multi-agent).
"""
@@ -361,12 +373,10 @@ class Agent(Trainable):
filtered_obs = self.local_evaluator.filters[policy_id](
preprocessed, update=False)
if state:
return self.local_evaluator.for_policy(
lambda p: p.compute_single_action(filtered_obs, state),
policy_id=policy_id)
return self.local_evaluator.for_policy(
lambda p: p.compute_single_action(filtered_obs, state)[0],
policy_id=policy_id)
return self.get_policy(policy_id).compute_single_action(
filtered_obs, state, prev_action, prev_reward, info)
return self.get_policy(policy_id).compute_single_action(
filtered_obs, state, prev_action, prev_reward, info)[0]
@property
def iteration(self):
@@ -386,6 +396,15 @@ class Agent(Trainable):
raise NotImplementedError
def get_policy(self, policy_id=DEFAULT_POLICY_ID):
    """Look up a policy graph on the local evaluator by its id.

    Arguments:
        policy_id (str): id of the policy graph to fetch. Defaults to
            DEFAULT_POLICY_ID.

    Returns:
        The result of ``self.local_evaluator.get_policy(policy_id)``:
        the policy graph for that id, or None.
    """
    # The lookup is delegated entirely to the local evaluator.
    evaluator = self.local_evaluator
    return evaluator.get_policy(policy_id)
def get_weights(self, policies=None):
"""Return a dictionary of policy ids to weights.
+2 -1
View File
@@ -48,6 +48,7 @@ class ApexDDPGAgent(DDPGAgent):
# Ape-X updates based on num steps trained, not sampled
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
self.config["target_network_update_freq"]:
self.local_evaluator.for_policy(lambda p: p.update_target())
self.local_evaluator.foreach_trainable_policy(
lambda p, _: p.update_target())
self.last_target_update_ts = self.optimizer.num_steps_trained
self.num_target_updates += 1
+2 -1
View File
@@ -51,6 +51,7 @@ class ApexAgent(DQNAgent):
# Ape-X updates based on num steps trained, not sampled
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
self.config["target_network_update_freq"]:
self.local_evaluator.for_policy(lambda p: p.update_target())
self.local_evaluator.foreach_trainable_policy(
lambda p, _: p.update_target())
self.last_target_update_ts = self.optimizer.num_steps_trained
self.num_target_updates += 1
@@ -89,3 +89,7 @@ class PGPolicyGraph(TFPolicyGraph):
@override(PolicyGraph)
def get_initial_state(self):
    """Return the ``state_init`` value stored on this graph's model."""
    # NOTE(review): presumably the initial RNN hidden state -- confirm
    # against the model class, which is not visible here.
    model = self.model
    return model.state_init
@override(TFPolicyGraph)
def optimizer(self):
    """Build the Adam optimizer used by this policy graph.

    The learning rate is read from ``self.config["lr"]`` so that the
    configured value is the one passed to the optimizer.
    """
    learning_rate = self.config["lr"]
    return tf.train.AdamOptimizer(learning_rate=learning_rate)
+2 -1
View File
@@ -50,6 +50,7 @@ class ApexQMixAgent(QMixAgent):
# Ape-X updates based on num steps trained, not sampled
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
self.config["target_network_update_freq"]:
self.local_evaluator.for_policy(lambda p: p.update_target())
self.local_evaluator.foreach_trainable_policy(
lambda p, _: p.update_target())
self.last_target_update_ts = self.optimizer.num_steps_trained
self.num_target_updates += 1