[rllib] Misc fixes: set lr for PG, better error message for LSTM/PPO, fix multi-agent/APEX (#3697)

* fix

* update test

* better error

* compute

* eps fix

* add get_policy() api

* Update agent.py

* better err msg

* fix

* pass in rew
This commit is contained in:
Eric Liang
2019-01-06 19:37:35 -08:00
committed by GitHub
parent df0733cafb
commit e78562b2e8
12 changed files with 111 additions and 34 deletions
+7 -9
View File
@@ -179,23 +179,21 @@ Accessing Policy State
~~~~~~~~~~~~~~~~~~~~~~
It is common to need to access an agent's internal state, e.g., to set or get internal weights. In RLlib an agent's state is replicated across multiple *policy evaluators* (Ray actors) in the cluster. However, you can easily get and update this state between calls to ``train()`` via ``agent.optimizer.foreach_evaluator()`` or ``agent.optimizer.foreach_evaluator_with_index()``. These functions take a lambda function that is applied with the evaluator as an arg. You can also return values from these functions and those will be returned as a list.
You can also access just the "master" copy of the agent state through ``agent.local_evaluator``, but note that updates here may not be immediately reflected in remote replicas if you have configured ``num_workers > 0``. For example, to access the weights of a local TF policy, you can run ``agent.local_evaluator.policy_map["default"].get_weights()``. This is also equivalent to ``agent.local_evaluator.for_policy(lambda p: p.get_weights())``:
You can also access just the "master" copy of the agent state through ``agent.get_policy()`` or ``agent.local_evaluator``, but note that updates here may not be immediately reflected in remote replicas if you have configured ``num_workers > 0``. For example, to access the weights of a local TF policy, you can run ``agent.get_policy().get_weights()``. This is also equivalent to ``agent.local_evaluator.policy_map["default"].get_weights()``:
.. code-block:: python
# Get weights of the local policy
# Get weights of the default local policy
agent.get_policy().get_weights()
# Same as above
agent.local_evaluator.policy_map["default"].get_weights()
# Same as above
agent.local_evaluator.for_policy(lambda p: p.get_weights())
# Get list of weights of each evaluator, including remote replicas
agent.optimizer.foreach_evaluator(
lambda ev: ev.for_policy(lambda p: p.get_weights()))
agent.optimizer.foreach_evaluator(lambda ev: ev.get_policy().get_weights())
# Same as above
agent.optimizer.foreach_evaluator_with_index(
lambda ev, i: ev.for_policy(lambda p: p.get_weights()))
agent.optimizer.foreach_evaluator_with_index(lambda ev, i: ev.get_policy().get_weights())
Global Coordination
~~~~~~~~~~~~~~~~~~~
+26 -7
View File
@@ -341,9 +341,18 @@ class Agent(Trainable):
raise NotImplementedError
def compute_action(self, observation, state=None, policy_id="default"):
def compute_action(self,
observation,
state=None,
prev_action=None,
prev_reward=None,
info=None,
policy_id="default"):
"""Computes an action for the specified policy.
Note that you can also access the policy object through
self.get_policy(policy_id) and call compute_actions() on it directly.
Arguments:
observation (obj): observation from the environment.
state (list): RNN hidden state, if any. If state is not None,
@@ -351,6 +360,9 @@ class Agent(Trainable):
(computed action, rnn state, logits dictionary).
Otherwise compute_single_action(...)[0] is
returned (computed action).
prev_action (obj): previous action value, if any
prev_reward (int): previous reward, if any
info (dict): info object, if any
policy_id (str): policy to query (only applies to multi-agent).
"""
@@ -361,12 +373,10 @@ class Agent(Trainable):
filtered_obs = self.local_evaluator.filters[policy_id](
preprocessed, update=False)
if state:
return self.local_evaluator.for_policy(
lambda p: p.compute_single_action(filtered_obs, state),
policy_id=policy_id)
return self.local_evaluator.for_policy(
lambda p: p.compute_single_action(filtered_obs, state)[0],
policy_id=policy_id)
return self.get_policy(policy_id).compute_single_action(
filtered_obs, state, prev_action, prev_reward, info)
return self.get_policy(policy_id).compute_single_action(
filtered_obs, state, prev_action, prev_reward, info)[0]
@property
def iteration(self):
@@ -386,6 +396,15 @@ class Agent(Trainable):
raise NotImplementedError
def get_policy(self, policy_id=DEFAULT_POLICY_ID):
"""Return policy graph for the specified id, or None.
Arguments:
policy_id (str): id of policy graph to return.
"""
return self.local_evaluator.get_policy(policy_id)
def get_weights(self, policies=None):
"""Return a dictionary of policy ids to weights.
+2 -1
View File
@@ -48,6 +48,7 @@ class ApexDDPGAgent(DDPGAgent):
# Ape-X updates based on num steps trained, not sampled
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
self.config["target_network_update_freq"]:
self.local_evaluator.for_policy(lambda p: p.update_target())
self.local_evaluator.foreach_trainable_policy(
lambda p, _: p.update_target())
self.last_target_update_ts = self.optimizer.num_steps_trained
self.num_target_updates += 1
+2 -1
View File
@@ -51,6 +51,7 @@ class ApexAgent(DQNAgent):
# Ape-X updates based on num steps trained, not sampled
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
self.config["target_network_update_freq"]:
self.local_evaluator.for_policy(lambda p: p.update_target())
self.local_evaluator.foreach_trainable_policy(
lambda p, _: p.update_target())
self.last_target_update_ts = self.optimizer.num_steps_trained
self.num_target_updates += 1
@@ -89,3 +89,7 @@ class PGPolicyGraph(TFPolicyGraph):
@override(PolicyGraph)
def get_initial_state(self):
return self.model.state_init
@override(TFPolicyGraph)
def optimizer(self):
return tf.train.AdamOptimizer(learning_rate=self.config["lr"])
+2 -1
View File
@@ -50,6 +50,7 @@ class ApexQMixAgent(QMixAgent):
# Ape-X updates based on num steps trained, not sampled
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
self.config["target_network_update_freq"]:
self.local_evaluator.for_policy(lambda p: p.update_target())
self.local_evaluator.foreach_trainable_policy(
lambda p, _: p.update_target())
self.last_target_update_ts = self.optimizer.num_steps_trained
self.num_target_updates += 1
@@ -489,6 +489,15 @@ class PolicyEvaluator(EvaluatorInterface):
self.policy_map[DEFAULT_POLICY_ID].compute_apply(samples))
return grad_fetch
def get_policy(self, policy_id=DEFAULT_POLICY_ID):
"""Return policy graph for the specified id, or None.
Arguments:
policy_id (str): id of policy graph to return.
"""
return self.policy_map.get(policy_id)
def for_policy(self, func, policy_id=DEFAULT_POLICY_ID):
"""Apply the given function to the specified policy graph."""
+23 -7
View File
@@ -71,9 +71,9 @@ class PolicyGraph(object):
def compute_single_action(self,
obs,
state,
prev_action_batch=None,
prev_reward_batch=None,
info_batch=None,
prev_action=None,
prev_reward=None,
info=None,
episode=None,
**kwargs):
"""Unbatched version of compute_actions.
@@ -81,9 +81,9 @@ class PolicyGraph(object):
Arguments:
obs (obj): single observation
state_batches (list): list of RNN state inputs, if any
prev_action_batch (np.ndarray): batch of previous action values
prev_reward_batch (np.ndarray): batch of previous rewards
info_batch (list): batch of info objects
prev_action (obj): previous action value, if any
prev_reward (int): previous reward, if any
info (dict): info object, if any
episode (MultiAgentEpisode): this provides access to all of the
internal episode state, which may be useful for model-based or
multi-agent algorithms.
@@ -95,8 +95,24 @@ class PolicyGraph(object):
info (dict): dictionary of extra features, if any
"""
prev_action_batch = None
prev_reward_batch = None
info_batch = None
episodes = None
if prev_action is not None:
prev_action_batch = [prev_action]
if prev_reward is not None:
prev_reward_batch = [prev_reward]
if info is not None:
info_batch = [info]
if episode is not None:
episodes = [episode]
[action], state_out, info = self.compute_actions(
[obs], [[s] for s in state], episodes=[episode])
[obs], [[s] for s in state],
prev_action_batch=prev_action_batch,
prev_reward_batch=prev_reward_batch,
info_batch=info_batch,
episodes=episodes)
return action, [s[0] for s in state_out], \
{k: v[0] for k, v in info.items()}
+10 -5
View File
@@ -301,8 +301,8 @@ class TFPolicyGraph(PolicyGraph):
tf.saved_model.signature_def_utils.build_signature_def(
input_signature, output_signature,
tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
signature_def_key = \
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY # noqa: E501
signature_def_key = (tf.saved_model.signature_constants.
DEFAULT_SERVING_SIGNATURE_DEF_KEY)
signature_def_map = {signature_def_key: signature_def}
return signature_def_map
@@ -314,8 +314,10 @@ class TFPolicyGraph(PolicyGraph):
prev_reward_batch=None,
episodes=None):
state_batches = state_batches or []
assert len(self._state_inputs) == len(state_batches), \
(self._state_inputs, state_batches)
if len(self._state_inputs) != len(state_batches):
raise ValueError(
"Must pass in RNN state batches for placeholders {}, got {}".
format(self._state_inputs, state_batches))
builder.add_feed_dict(self.extra_compute_action_feed_dict())
builder.add_feed_dict({self._obs_input: obs_batch})
if state_batches:
@@ -339,7 +341,10 @@ class TFPolicyGraph(PolicyGraph):
return fetches[0], fetches[1]
def _build_apply_gradients(self, builder, gradients):
assert len(gradients) == len(self._grads), (gradients, self._grads)
if len(gradients) != len(self._grads):
raise ValueError(
"Unexpected number of gradients to apply, got {} for {}".
format(gradients, self._grads))
builder.add_feed_dict(self.extra_apply_grad_feed_dict())
builder.add_feed_dict({self._is_training: True})
builder.add_feed_dict(dict(zip(self._grads, gradients)))
@@ -292,6 +292,9 @@ class TFMultiGPULearner(LearnerThread):
logger.info("TFMultiGPULearner devices {}".format(self.devices))
assert self.train_batch_size % len(self.devices) == 0
assert self.train_batch_size >= len(self.devices), "batch too small"
if set(self.local_evaluator.policy_map.keys()) != {"default"}:
raise NotImplementedError("Multi-gpu mode for multi-agent")
self.policy = self.local_evaluator.policy_map["default"]
# per-GPU graph copies created below must share vars with the policy
@@ -161,8 +161,10 @@ class LocalSyncParallelOptimizer(object):
seq_batch_size = make_divisible_by(
len(smallest_array), len(self.devices))
if seq_batch_size < len(self.devices):
raise ValueError("Must load at least 1 tuple sequence per device, "
"got only {} total.".format(len(smallest_array)))
raise ValueError(
"Must load at least 1 tuple sequence per device. Try "
"increasing `sgd_minibatch_size` or reducing `max_seq_len` "
"to ensure that at least one sequence fits per device.")
self._loaded_per_device_batch_size = (
seq_batch_size // len(self.devices) * self._loaded_max_seq_len)
+19 -1
View File
@@ -92,7 +92,7 @@ def check_support(alg, config, stats, check_bounds=False):
def check_support_multiagent(alg, config):
register_env("multi_mountaincar", lambda _: MultiMountainCar(2))
register_env("multi_cartpole", lambda _: MultiCartpole(2))
if alg == "DDPG":
if "DDPG" in alg:
a = get_agent_class(alg)(config=config, env="multi_mountaincar")
else:
a = get_agent_class(alg)(config=config, env="multi_cartpole")
@@ -169,6 +169,24 @@ class ModelSupportedSpaces(unittest.TestCase):
self.assertEqual(num_unexpected_errors, 0)
def testMultiAgent(self):
check_support_multiagent(
"APEX", {
"num_workers": 2,
"timesteps_per_iteration": 1000,
"num_gpus": 0,
"min_iter_time_s": 1,
"learning_starts": 1000,
"target_network_update_freq": 100,
})
check_support_multiagent(
"APEX_DDPG", {
"num_workers": 2,
"timesteps_per_iteration": 1000,
"num_gpus": 0,
"min_iter_time_s": 1,
"learning_starts": 1000,
"target_network_update_freq": 100,
})
check_support_multiagent("IMPALA", {"num_gpus": 0})
check_support_multiagent("DQN", {"timesteps_per_iteration": 1})
check_support_multiagent("A3C", {