mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 23:50:20 +08:00
[rllib] Misc fixes: set lr for PG, better error message for LSTM/PPO, fix multi-agent/APEX (#3697)
* fix
* update test
* better error
* compute
* eps fix
* add get_policy() api
* Update agent.py
* better err msg
* fix
* pass in rew
This commit is contained in:
@@ -341,9 +341,18 @@ class Agent(Trainable):
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def compute_action(self, observation, state=None, policy_id="default"):
|
||||
def compute_action(self,
|
||||
observation,
|
||||
state=None,
|
||||
prev_action=None,
|
||||
prev_reward=None,
|
||||
info=None,
|
||||
policy_id="default"):
|
||||
"""Computes an action for the specified policy.
|
||||
|
||||
Note that you can also access the policy object through
|
||||
self.get_policy(policy_id) and call compute_actions() on it directly.
|
||||
|
||||
Arguments:
|
||||
observation (obj): observation from the environment.
|
||||
state (list): RNN hidden state, if any. If state is not None,
|
||||
@@ -351,6 +360,9 @@ class Agent(Trainable):
|
||||
(computed action, rnn state, logits dictionary).
|
||||
Otherwise compute_single_action(...)[0] is
|
||||
returned (computed action).
|
||||
prev_action (obj): previous action value, if any
|
||||
prev_reward (int): previous reward, if any
|
||||
info (dict): info object, if any
|
||||
policy_id (str): policy to query (only applies to multi-agent).
|
||||
"""
|
||||
|
||||
@@ -361,12 +373,10 @@ class Agent(Trainable):
|
||||
filtered_obs = self.local_evaluator.filters[policy_id](
|
||||
preprocessed, update=False)
|
||||
if state:
|
||||
return self.local_evaluator.for_policy(
|
||||
lambda p: p.compute_single_action(filtered_obs, state),
|
||||
policy_id=policy_id)
|
||||
return self.local_evaluator.for_policy(
|
||||
lambda p: p.compute_single_action(filtered_obs, state)[0],
|
||||
policy_id=policy_id)
|
||||
return self.get_policy(policy_id).compute_single_action(
|
||||
filtered_obs, state, prev_action, prev_reward, info)
|
||||
return self.get_policy(policy_id).compute_single_action(
|
||||
filtered_obs, state, prev_action, prev_reward, info)[0]
|
||||
|
||||
@property
|
||||
def iteration(self):
|
||||
@@ -386,6 +396,15 @@ class Agent(Trainable):
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def get_policy(self, policy_id=DEFAULT_POLICY_ID):
|
||||
"""Return policy graph for the specified id, or None.
|
||||
|
||||
Arguments:
|
||||
policy_id (str): id of policy graph to return.
|
||||
"""
|
||||
|
||||
return self.local_evaluator.get_policy(policy_id)
|
||||
|
||||
def get_weights(self, policies=None):
|
||||
"""Return a dictionary of policy ids to weights.
|
||||
|
||||
|
||||
@@ -48,6 +48,7 @@ class ApexDDPGAgent(DDPGAgent):
|
||||
# Ape-X updates based on num steps trained, not sampled
|
||||
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
|
||||
self.config["target_network_update_freq"]:
|
||||
self.local_evaluator.for_policy(lambda p: p.update_target())
|
||||
self.local_evaluator.foreach_trainable_policy(
|
||||
lambda p, _: p.update_target())
|
||||
self.last_target_update_ts = self.optimizer.num_steps_trained
|
||||
self.num_target_updates += 1
|
||||
|
||||
@@ -51,6 +51,7 @@ class ApexAgent(DQNAgent):
|
||||
# Ape-X updates based on num steps trained, not sampled
|
||||
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
|
||||
self.config["target_network_update_freq"]:
|
||||
self.local_evaluator.for_policy(lambda p: p.update_target())
|
||||
self.local_evaluator.foreach_trainable_policy(
|
||||
lambda p, _: p.update_target())
|
||||
self.last_target_update_ts = self.optimizer.num_steps_trained
|
||||
self.num_target_updates += 1
|
||||
|
||||
@@ -89,3 +89,7 @@ class PGPolicyGraph(TFPolicyGraph):
|
||||
@override(PolicyGraph)
|
||||
def get_initial_state(self):
|
||||
return self.model.state_init
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def optimizer(self):
|
||||
return tf.train.AdamOptimizer(learning_rate=self.config["lr"])
|
||||
|
||||
@@ -50,6 +50,7 @@ class ApexQMixAgent(QMixAgent):
|
||||
# Ape-X updates based on num steps trained, not sampled
|
||||
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
|
||||
self.config["target_network_update_freq"]:
|
||||
self.local_evaluator.for_policy(lambda p: p.update_target())
|
||||
self.local_evaluator.foreach_trainable_policy(
|
||||
lambda p, _: p.update_target())
|
||||
self.last_target_update_ts = self.optimizer.num_steps_trained
|
||||
self.num_target_updates += 1
|
||||
|
||||
Reference in New Issue
Block a user