mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 23:08:32 +08:00
[rllib] Misc fixes: set lr for PG, better error message for LSTM/PPO, fix multi-agent/APEX (#3697)
* fix
* update test
* better error
* compute
* eps fix
* add get_policy() api
* Update agent.py
* better err msg
* fix
* pass in rew
This commit is contained in:
@@ -179,23 +179,21 @@ Accessing Policy State
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
It is common to need to access an agent's internal state, e.g., to set or get internal weights. In RLlib an agent's state is replicated across multiple *policy evaluators* (Ray actors) in the cluster. However, you can easily get and update this state between calls to ``train()`` via ``agent.optimizer.foreach_evaluator()`` or ``agent.optimizer.foreach_evaluator_with_index()``. These functions take a function (typically a lambda) that is called once per evaluator with the evaluator as its argument (plus the evaluator's index for the ``_with_index`` variant). Any values returned by that function are collected and returned as a list.
|
||||
|
||||
You can also access just the "master" copy of the agent state through ``agent.local_evaluator``, but note that updates here may not be immediately reflected in remote replicas if you have configured ``num_workers > 0``. For example, to access the weights of a local TF policy, you can run ``agent.local_evaluator.policy_map["default"].get_weights()``. This is also equivalent to ``agent.local_evaluator.for_policy(lambda p: p.get_weights())``:
|
||||
You can also access just the "master" copy of the agent state through ``agent.get_policy()`` or ``agent.local_evaluator``, but note that updates here may not be immediately reflected in remote replicas if you have configured ``num_workers > 0``. For example, to access the weights of a local TF policy, you can run ``agent.get_policy().get_weights()``. This is also equivalent to ``agent.local_evaluator.policy_map["default"].get_weights()``:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Get weights of the local policy
|
||||
# Get weights of the default local policy
|
||||
agent.get_policy().get_weights()
|
||||
|
||||
# Same as above
|
||||
agent.local_evaluator.policy_map["default"].get_weights()
|
||||
|
||||
# Same as above
|
||||
agent.local_evaluator.for_policy(lambda p: p.get_weights())
|
||||
|
||||
# Get list of weights of each evaluator, including remote replicas
|
||||
agent.optimizer.foreach_evaluator(
|
||||
lambda ev: ev.for_policy(lambda p: p.get_weights()))
|
||||
agent.optimizer.foreach_evaluator(lambda ev: ev.get_policy().get_weights())
|
||||
|
||||
# Same as above
|
||||
agent.optimizer.foreach_evaluator_with_index(
|
||||
lambda ev, i: ev.for_policy(lambda p: p.get_weights()))
|
||||
agent.optimizer.foreach_evaluator_with_index(lambda ev, i: ev.get_policy().get_weights())
|
||||
|
||||
Global Coordination
|
||||
~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@@ -341,9 +341,18 @@ class Agent(Trainable):
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def compute_action(self, observation, state=None, policy_id="default"):
|
||||
def compute_action(self,
|
||||
observation,
|
||||
state=None,
|
||||
prev_action=None,
|
||||
prev_reward=None,
|
||||
info=None,
|
||||
policy_id="default"):
|
||||
"""Computes an action for the specified policy.
|
||||
|
||||
Note that you can also access the policy object through
|
||||
self.get_policy(policy_id) and call compute_actions() on it directly.
|
||||
|
||||
Arguments:
|
||||
observation (obj): observation from the environment.
|
||||
state (list): RNN hidden state, if any. If state is not None,
|
||||
@@ -351,6 +360,9 @@ class Agent(Trainable):
|
||||
(computed action, rnn state, logits dictionary).
|
||||
Otherwise compute_single_action(...)[0] is
|
||||
returned (computed action).
|
||||
prev_action (obj): previous action value, if any
|
||||
prev_reward (float): previous reward, if any
|
||||
info (dict): info object, if any
|
||||
policy_id (str): policy to query (only applies to multi-agent).
|
||||
"""
|
||||
|
||||
@@ -361,12 +373,10 @@ class Agent(Trainable):
|
||||
filtered_obs = self.local_evaluator.filters[policy_id](
|
||||
preprocessed, update=False)
|
||||
if state:
|
||||
return self.local_evaluator.for_policy(
|
||||
lambda p: p.compute_single_action(filtered_obs, state),
|
||||
policy_id=policy_id)
|
||||
return self.local_evaluator.for_policy(
|
||||
lambda p: p.compute_single_action(filtered_obs, state)[0],
|
||||
policy_id=policy_id)
|
||||
return self.get_policy(policy_id).compute_single_action(
|
||||
filtered_obs, state, prev_action, prev_reward, info)
|
||||
return self.get_policy(policy_id).compute_single_action(
|
||||
filtered_obs, state, prev_action, prev_reward, info)[0]
|
||||
|
||||
@property
|
||||
def iteration(self):
|
||||
@@ -386,6 +396,15 @@ class Agent(Trainable):
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def get_policy(self, policy_id=DEFAULT_POLICY_ID):
|
||||
"""Return policy graph for the specified id, or None.
|
||||
|
||||
Arguments:
|
||||
policy_id (str): id of policy graph to return.
|
||||
"""
|
||||
|
||||
return self.local_evaluator.get_policy(policy_id)
|
||||
|
||||
def get_weights(self, policies=None):
|
||||
"""Return a dictionary of policy ids to weights.
|
||||
|
||||
|
||||
@@ -48,6 +48,7 @@ class ApexDDPGAgent(DDPGAgent):
|
||||
# Ape-X updates based on num steps trained, not sampled
|
||||
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
|
||||
self.config["target_network_update_freq"]:
|
||||
self.local_evaluator.for_policy(lambda p: p.update_target())
|
||||
self.local_evaluator.foreach_trainable_policy(
|
||||
lambda p, _: p.update_target())
|
||||
self.last_target_update_ts = self.optimizer.num_steps_trained
|
||||
self.num_target_updates += 1
|
||||
|
||||
@@ -51,6 +51,7 @@ class ApexAgent(DQNAgent):
|
||||
# Ape-X updates based on num steps trained, not sampled
|
||||
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
|
||||
self.config["target_network_update_freq"]:
|
||||
self.local_evaluator.for_policy(lambda p: p.update_target())
|
||||
self.local_evaluator.foreach_trainable_policy(
|
||||
lambda p, _: p.update_target())
|
||||
self.last_target_update_ts = self.optimizer.num_steps_trained
|
||||
self.num_target_updates += 1
|
||||
|
||||
@@ -89,3 +89,7 @@ class PGPolicyGraph(TFPolicyGraph):
|
||||
@override(PolicyGraph)
|
||||
def get_initial_state(self):
|
||||
return self.model.state_init
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def optimizer(self):
|
||||
return tf.train.AdamOptimizer(learning_rate=self.config["lr"])
|
||||
|
||||
@@ -50,6 +50,7 @@ class ApexQMixAgent(QMixAgent):
|
||||
# Ape-X updates based on num steps trained, not sampled
|
||||
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
|
||||
self.config["target_network_update_freq"]:
|
||||
self.local_evaluator.for_policy(lambda p: p.update_target())
|
||||
self.local_evaluator.foreach_trainable_policy(
|
||||
lambda p, _: p.update_target())
|
||||
self.last_target_update_ts = self.optimizer.num_steps_trained
|
||||
self.num_target_updates += 1
|
||||
|
||||
@@ -489,6 +489,15 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
self.policy_map[DEFAULT_POLICY_ID].compute_apply(samples))
|
||||
return grad_fetch
|
||||
|
||||
def get_policy(self, policy_id=DEFAULT_POLICY_ID):
|
||||
"""Return policy graph for the specified id, or None.
|
||||
|
||||
Arguments:
|
||||
policy_id (str): id of policy graph to return.
|
||||
"""
|
||||
|
||||
return self.policy_map.get(policy_id)
|
||||
|
||||
def for_policy(self, func, policy_id=DEFAULT_POLICY_ID):
|
||||
"""Apply the given function to the specified policy graph."""
|
||||
|
||||
|
||||
@@ -71,9 +71,9 @@ class PolicyGraph(object):
|
||||
def compute_single_action(self,
|
||||
obs,
|
||||
state,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
info_batch=None,
|
||||
prev_action=None,
|
||||
prev_reward=None,
|
||||
info=None,
|
||||
episode=None,
|
||||
**kwargs):
|
||||
"""Unbatched version of compute_actions.
|
||||
@@ -81,9 +81,9 @@ class PolicyGraph(object):
|
||||
Arguments:
|
||||
obs (obj): single observation
|
||||
state (list): list of RNN state inputs, if any
|
||||
prev_action_batch (np.ndarray): batch of previous action values
|
||||
prev_reward_batch (np.ndarray): batch of previous rewards
|
||||
info_batch (list): batch of info objects
|
||||
prev_action (obj): previous action value, if any
|
||||
prev_reward (float): previous reward, if any
|
||||
info (dict): info object, if any
|
||||
episode (MultiAgentEpisode): this provides access to all of the
|
||||
internal episode state, which may be useful for model-based or
|
||||
multi-agent algorithms.
|
||||
@@ -95,8 +95,24 @@ class PolicyGraph(object):
|
||||
info (dict): dictionary of extra features, if any
|
||||
"""
|
||||
|
||||
prev_action_batch = None
|
||||
prev_reward_batch = None
|
||||
info_batch = None
|
||||
episodes = None
|
||||
if prev_action is not None:
|
||||
prev_action_batch = [prev_action]
|
||||
if prev_reward is not None:
|
||||
prev_reward_batch = [prev_reward]
|
||||
if info is not None:
|
||||
info_batch = [info]
|
||||
if episode is not None:
|
||||
episodes = [episode]
|
||||
[action], state_out, info = self.compute_actions(
|
||||
[obs], [[s] for s in state], episodes=[episode])
|
||||
[obs], [[s] for s in state],
|
||||
prev_action_batch=prev_action_batch,
|
||||
prev_reward_batch=prev_reward_batch,
|
||||
info_batch=info_batch,
|
||||
episodes=episodes)
|
||||
return action, [s[0] for s in state_out], \
|
||||
{k: v[0] for k, v in info.items()}
|
||||
|
||||
|
||||
@@ -301,8 +301,8 @@ class TFPolicyGraph(PolicyGraph):
|
||||
tf.saved_model.signature_def_utils.build_signature_def(
|
||||
input_signature, output_signature,
|
||||
tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
|
||||
signature_def_key = \
|
||||
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY # noqa: E501
|
||||
signature_def_key = (tf.saved_model.signature_constants.
|
||||
DEFAULT_SERVING_SIGNATURE_DEF_KEY)
|
||||
signature_def_map = {signature_def_key: signature_def}
|
||||
return signature_def_map
|
||||
|
||||
@@ -314,8 +314,10 @@ class TFPolicyGraph(PolicyGraph):
|
||||
prev_reward_batch=None,
|
||||
episodes=None):
|
||||
state_batches = state_batches or []
|
||||
assert len(self._state_inputs) == len(state_batches), \
|
||||
(self._state_inputs, state_batches)
|
||||
if len(self._state_inputs) != len(state_batches):
|
||||
raise ValueError(
|
||||
"Must pass in RNN state batches for placeholders {}, got {}".
|
||||
format(self._state_inputs, state_batches))
|
||||
builder.add_feed_dict(self.extra_compute_action_feed_dict())
|
||||
builder.add_feed_dict({self._obs_input: obs_batch})
|
||||
if state_batches:
|
||||
@@ -339,7 +341,10 @@ class TFPolicyGraph(PolicyGraph):
|
||||
return fetches[0], fetches[1]
|
||||
|
||||
def _build_apply_gradients(self, builder, gradients):
|
||||
assert len(gradients) == len(self._grads), (gradients, self._grads)
|
||||
if len(gradients) != len(self._grads):
|
||||
raise ValueError(
|
||||
"Unexpected number of gradients to apply, got {} for {}".
|
||||
format(gradients, self._grads))
|
||||
builder.add_feed_dict(self.extra_apply_grad_feed_dict())
|
||||
builder.add_feed_dict({self._is_training: True})
|
||||
builder.add_feed_dict(dict(zip(self._grads, gradients)))
|
||||
|
||||
@@ -292,6 +292,9 @@ class TFMultiGPULearner(LearnerThread):
|
||||
logger.info("TFMultiGPULearner devices {}".format(self.devices))
|
||||
assert self.train_batch_size % len(self.devices) == 0
|
||||
assert self.train_batch_size >= len(self.devices), "batch too small"
|
||||
|
||||
if set(self.local_evaluator.policy_map.keys()) != {"default"}:
|
||||
raise NotImplementedError("Multi-gpu mode for multi-agent")
|
||||
self.policy = self.local_evaluator.policy_map["default"]
|
||||
|
||||
# per-GPU graph copies created below must share vars with the policy
|
||||
|
||||
@@ -161,8 +161,10 @@ class LocalSyncParallelOptimizer(object):
|
||||
seq_batch_size = make_divisible_by(
|
||||
len(smallest_array), len(self.devices))
|
||||
if seq_batch_size < len(self.devices):
|
||||
raise ValueError("Must load at least 1 tuple sequence per device, "
|
||||
"got only {} total.".format(len(smallest_array)))
|
||||
raise ValueError(
|
||||
"Must load at least 1 tuple sequence per device. Try "
|
||||
"increasing `sgd_minibatch_size` or reducing `max_seq_len` "
|
||||
"to ensure that at least one sequence fits per device.")
|
||||
self._loaded_per_device_batch_size = (
|
||||
seq_batch_size // len(self.devices) * self._loaded_max_seq_len)
|
||||
|
||||
|
||||
@@ -92,7 +92,7 @@ def check_support(alg, config, stats, check_bounds=False):
|
||||
def check_support_multiagent(alg, config):
|
||||
register_env("multi_mountaincar", lambda _: MultiMountainCar(2))
|
||||
register_env("multi_cartpole", lambda _: MultiCartpole(2))
|
||||
if alg == "DDPG":
|
||||
if "DDPG" in alg:
|
||||
a = get_agent_class(alg)(config=config, env="multi_mountaincar")
|
||||
else:
|
||||
a = get_agent_class(alg)(config=config, env="multi_cartpole")
|
||||
@@ -169,6 +169,24 @@ class ModelSupportedSpaces(unittest.TestCase):
|
||||
self.assertEqual(num_unexpected_errors, 0)
|
||||
|
||||
def testMultiAgent(self):
|
||||
check_support_multiagent(
|
||||
"APEX", {
|
||||
"num_workers": 2,
|
||||
"timesteps_per_iteration": 1000,
|
||||
"num_gpus": 0,
|
||||
"min_iter_time_s": 1,
|
||||
"learning_starts": 1000,
|
||||
"target_network_update_freq": 100,
|
||||
})
|
||||
check_support_multiagent(
|
||||
"APEX_DDPG", {
|
||||
"num_workers": 2,
|
||||
"timesteps_per_iteration": 1000,
|
||||
"num_gpus": 0,
|
||||
"min_iter_time_s": 1,
|
||||
"learning_starts": 1000,
|
||||
"target_network_update_freq": 100,
|
||||
})
|
||||
check_support_multiagent("IMPALA", {"num_gpus": 0})
|
||||
check_support_multiagent("DQN", {"timesteps_per_iteration": 1})
|
||||
check_support_multiagent("A3C", {
|
||||
|
||||
Reference in New Issue
Block a user