From e78562b2e8d2affc8b0f7fabde37aa192b4385fc Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sun, 6 Jan 2019 19:37:35 -0800 Subject: [PATCH] [rllib] Misc fixes: set lr for PG, better error message for LSTM/PPO, fix multi-agent/APEX (#3697) * fix * update test * better error * compute * eps fix * add get_policy() api * Update agent.py * better err msg * fix * pass in rew --- doc/source/rllib-training.rst | 16 ++++----- python/ray/rllib/agents/agent.py | 33 +++++++++++++++---- python/ray/rllib/agents/ddpg/apex.py | 3 +- python/ray/rllib/agents/dqn/apex.py | 3 +- python/ray/rllib/agents/pg/pg_policy_graph.py | 4 +++ python/ray/rllib/agents/qmix/apex.py | 3 +- .../ray/rllib/evaluation/policy_evaluator.py | 9 +++++ python/ray/rllib/evaluation/policy_graph.py | 30 +++++++++++++---- .../ray/rllib/evaluation/tf_policy_graph.py | 15 ++++++--- .../optimizers/async_samples_optimizer.py | 3 ++ python/ray/rllib/optimizers/multi_gpu_impl.py | 6 ++-- .../ray/rllib/test/test_supported_spaces.py | 20 ++++++++++- 12 files changed, 111 insertions(+), 34 deletions(-) diff --git a/doc/source/rllib-training.rst b/doc/source/rllib-training.rst index cc73ec086..5b0947dae 100644 --- a/doc/source/rllib-training.rst +++ b/doc/source/rllib-training.rst @@ -179,23 +179,21 @@ Accessing Policy State ~~~~~~~~~~~~~~~~~~~~~~ It is common to need to access an agent's internal state, e.g., to set or get internal weights. In RLlib an agent's state is replicated across multiple *policy evaluators* (Ray actors) in the cluster. However, you can easily get and update this state between calls to ``train()`` via ``agent.optimizer.foreach_evaluator()`` or ``agent.optimizer.foreach_evaluator_with_index()``. These functions take a lambda function that is applied with the evaluator as an arg. You can also return values from these functions and those will be returned as a list. -You can also access just the "master" copy of the agent state through ``agent.local_evaluator``, but note that updates here may not be immediately reflected in remote replicas if you have configured ``num_workers > 0``. For example, to access the weights of a local TF policy, you can run ``agent.local_evaluator.policy_map["default"].get_weights()``. This is also equivalent to ``agent.local_evaluator.for_policy(lambda p: p.get_weights())``: +You can also access just the "master" copy of the agent state through ``agent.get_policy()`` or ``agent.local_evaluator``, but note that updates here may not be immediately reflected in remote replicas if you have configured ``num_workers > 0``. For example, to access the weights of a local TF policy, you can run ``agent.get_policy().get_weights()``. This is also equivalent to ``agent.local_evaluator.policy_map["default"].get_weights()``: .. code-block:: python - # Get weights of the local policy + # Get weights of the default local policy + agent.get_policy().get_weights() + + # Same as above agent.local_evaluator.policy_map["default"].get_weights() - # Same as above - agent.local_evaluator.for_policy(lambda p: p.get_weights()) - # Get list of weights of each evaluator, including remote replicas - agent.optimizer.foreach_evaluator( - lambda ev: ev.for_policy(lambda p: p.get_weights())) + agent.optimizer.foreach_evaluator(lambda ev: ev.get_policy().get_weights()) # Same as above - agent.optimizer.foreach_evaluator_with_index( - lambda ev, i: ev.for_policy(lambda p: p.get_weights())) + agent.optimizer.foreach_evaluator_with_index(lambda ev, i: ev.get_policy().get_weights()) Global Coordination ~~~~~~~~~~~~~~~~~~~ diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py index cce3f449d..26f12b29a 100644 --- a/python/ray/rllib/agents/agent.py +++ b/python/ray/rllib/agents/agent.py @@ -341,9 +341,18 @@ class Agent(Trainable): raise NotImplementedError - def compute_action(self, observation, state=None, policy_id="default"): + def compute_action(self, + observation, + state=None, + prev_action=None, + prev_reward=None, + info=None, + policy_id="default"): """Computes an action for the specified policy. + Note that you can also access the policy object through + self.get_policy(policy_id) and call compute_actions() on it directly. + Arguments: observation (obj): observation from the environment. state (list): RNN hidden state, if any. If state is not None, @@ -351,6 +360,9 @@ class Agent(Trainable): (computed action, rnn state, logits dictionary). Otherwise compute_single_action(...)[0] is returned (computed action). + prev_action (obj): previous action value, if any + prev_reward (int): previous reward, if any + info (dict): info object, if any policy_id (str): policy to query (only applies to multi-agent). """ @@ -361,12 +373,10 @@ class Agent(Trainable): filtered_obs = self.local_evaluator.filters[policy_id]( preprocessed, update=False) if state: - return self.local_evaluator.for_policy( - lambda p: p.compute_single_action(filtered_obs, state), - policy_id=policy_id) - return self.local_evaluator.for_policy( - lambda p: p.compute_single_action(filtered_obs, state)[0], - policy_id=policy_id) + return self.get_policy(policy_id).compute_single_action( + filtered_obs, state, prev_action, prev_reward, info) + return self.get_policy(policy_id).compute_single_action( + filtered_obs, state, prev_action, prev_reward, info)[0] @property def iteration(self): @@ -386,6 +396,15 @@ class Agent(Trainable): raise NotImplementedError + def get_policy(self, policy_id=DEFAULT_POLICY_ID): + """Return policy graph for the specified id, or None. + + Arguments: + policy_id (str): id of policy graph to return. + """ + + return self.local_evaluator.get_policy(policy_id) + def get_weights(self, policies=None): """Return a dictionary of policy ids to weights. diff --git a/python/ray/rllib/agents/ddpg/apex.py b/python/ray/rllib/agents/ddpg/apex.py index 6b3465013..9f4395737 100644 --- a/python/ray/rllib/agents/ddpg/apex.py +++ b/python/ray/rllib/agents/ddpg/apex.py @@ -48,6 +48,7 @@ class ApexDDPGAgent(DDPGAgent): # Ape-X updates based on num steps trained, not sampled if self.optimizer.num_steps_trained - self.last_target_update_ts > \ self.config["target_network_update_freq"]: - self.local_evaluator.for_policy(lambda p: p.update_target()) + self.local_evaluator.foreach_trainable_policy( + lambda p, _: p.update_target()) self.last_target_update_ts = self.optimizer.num_steps_trained self.num_target_updates += 1 diff --git a/python/ray/rllib/agents/dqn/apex.py b/python/ray/rllib/agents/dqn/apex.py index c9b15e0ec..fbe130dd3 100644 --- a/python/ray/rllib/agents/dqn/apex.py +++ b/python/ray/rllib/agents/dqn/apex.py @@ -51,6 +51,7 @@ class ApexAgent(DQNAgent): # Ape-X updates based on num steps trained, not sampled if self.optimizer.num_steps_trained - self.last_target_update_ts > \ self.config["target_network_update_freq"]: - self.local_evaluator.for_policy(lambda p: p.update_target()) + self.local_evaluator.foreach_trainable_policy( + lambda p, _: p.update_target()) self.last_target_update_ts = self.optimizer.num_steps_trained self.num_target_updates += 1 diff --git a/python/ray/rllib/agents/pg/pg_policy_graph.py b/python/ray/rllib/agents/pg/pg_policy_graph.py index 59e9a9eff..1594a4934 100644 --- a/python/ray/rllib/agents/pg/pg_policy_graph.py +++ b/python/ray/rllib/agents/pg/pg_policy_graph.py @@ -89,3 +89,7 @@ class PGPolicyGraph(TFPolicyGraph): @override(PolicyGraph) def get_initial_state(self): return self.model.state_init + + @override(TFPolicyGraph) + def optimizer(self): + return tf.train.AdamOptimizer(learning_rate=self.config["lr"]) diff --git a/python/ray/rllib/agents/qmix/apex.py b/python/ray/rllib/agents/qmix/apex.py index 9f471faef..022ea2c45 100644 --- a/python/ray/rllib/agents/qmix/apex.py +++ b/python/ray/rllib/agents/qmix/apex.py @@ -50,6 +50,7 @@ class ApexQMixAgent(QMixAgent): # Ape-X updates based on num steps trained, not sampled if self.optimizer.num_steps_trained - self.last_target_update_ts > \ self.config["target_network_update_freq"]: - self.local_evaluator.for_policy(lambda p: p.update_target()) + self.local_evaluator.foreach_trainable_policy( + lambda p, _: p.update_target()) self.last_target_update_ts = self.optimizer.num_steps_trained self.num_target_updates += 1 diff --git a/python/ray/rllib/evaluation/policy_evaluator.py b/python/ray/rllib/evaluation/policy_evaluator.py index aaaf47c45..5b7e590f4 100644 --- a/python/ray/rllib/evaluation/policy_evaluator.py +++ b/python/ray/rllib/evaluation/policy_evaluator.py @@ -489,6 +489,15 @@ class PolicyEvaluator(EvaluatorInterface): self.policy_map[DEFAULT_POLICY_ID].compute_apply(samples)) return grad_fetch + def get_policy(self, policy_id=DEFAULT_POLICY_ID): + """Return policy graph for the specified id, or None. + + Arguments: + policy_id (str): id of policy graph to return. + """ + + return self.policy_map.get(policy_id) + def for_policy(self, func, policy_id=DEFAULT_POLICY_ID): """Apply the given function to the specified policy graph.""" diff --git a/python/ray/rllib/evaluation/policy_graph.py b/python/ray/rllib/evaluation/policy_graph.py index fc4be5706..3a1b0e116 100644 --- a/python/ray/rllib/evaluation/policy_graph.py +++ b/python/ray/rllib/evaluation/policy_graph.py @@ -71,9 +71,9 @@ class PolicyGraph(object): def compute_single_action(self, obs, state, - prev_action_batch=None, - prev_reward_batch=None, - info_batch=None, + prev_action=None, + prev_reward=None, + info=None, episode=None, **kwargs): """Unbatched version of compute_actions. @@ -81,9 +81,9 @@ class PolicyGraph(object): Arguments: obs (obj): single observation state_batches (list): list of RNN state inputs, if any - prev_action_batch (np.ndarray): batch of previous action values - prev_reward_batch (np.ndarray): batch of previous rewards - info_batch (list): batch of info objects + prev_action (obj): previous action value, if any + prev_reward (int): previous reward, if any + info (dict): info object, if any episode (MultiAgentEpisode): this provides access to all of the internal episode state, which may be useful for model-based or multi-agent algorithms. @@ -95,8 +95,24 @@ class PolicyGraph(object): info (dict): dictionary of extra features, if any """ + prev_action_batch = None + prev_reward_batch = None + info_batch = None + episodes = None + if prev_action is not None: + prev_action_batch = [prev_action] + if prev_reward is not None: + prev_reward_batch = [prev_reward] + if info is not None: + info_batch = [info] + if episode is not None: + episodes = [episode] [action], state_out, info = self.compute_actions( - [obs], [[s] for s in state], episodes=[episode]) + [obs], [[s] for s in state], + prev_action_batch=prev_action_batch, + prev_reward_batch=prev_reward_batch, + info_batch=info_batch, + episodes=episodes) return action, [s[0] for s in state_out], \ {k: v[0] for k, v in info.items()} diff --git a/python/ray/rllib/evaluation/tf_policy_graph.py b/python/ray/rllib/evaluation/tf_policy_graph.py index 7574864c9..2073fca1d 100644 --- a/python/ray/rllib/evaluation/tf_policy_graph.py +++ b/python/ray/rllib/evaluation/tf_policy_graph.py @@ -301,8 +301,8 @@ class TFPolicyGraph(PolicyGraph): tf.saved_model.signature_def_utils.build_signature_def( input_signature, output_signature, tf.saved_model.signature_constants.PREDICT_METHOD_NAME)) - signature_def_key = \ - tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY # noqa: E501 + signature_def_key = (tf.saved_model.signature_constants. + DEFAULT_SERVING_SIGNATURE_DEF_KEY) signature_def_map = {signature_def_key: signature_def} return signature_def_map @@ -314,8 +314,10 @@ class TFPolicyGraph(PolicyGraph): prev_reward_batch=None, episodes=None): state_batches = state_batches or [] - assert len(self._state_inputs) == len(state_batches), \ - (self._state_inputs, state_batches) + if len(self._state_inputs) != len(state_batches): + raise ValueError( + "Must pass in RNN state batches for placeholders {}, got {}". + format(self._state_inputs, state_batches)) builder.add_feed_dict(self.extra_compute_action_feed_dict()) builder.add_feed_dict({self._obs_input: obs_batch}) if state_batches: @@ -339,7 +341,10 @@ class TFPolicyGraph(PolicyGraph): return fetches[0], fetches[1] def _build_apply_gradients(self, builder, gradients): - assert len(gradients) == len(self._grads), (gradients, self._grads) + if len(gradients) != len(self._grads): + raise ValueError( + "Unexpected number of gradients to apply, got {} for {}". + format(gradients, self._grads)) builder.add_feed_dict(self.extra_apply_grad_feed_dict()) builder.add_feed_dict({self._is_training: True}) builder.add_feed_dict(dict(zip(self._grads, gradients))) diff --git a/python/ray/rllib/optimizers/async_samples_optimizer.py b/python/ray/rllib/optimizers/async_samples_optimizer.py index 9322e1fcd..60b4eb691 100644 --- a/python/ray/rllib/optimizers/async_samples_optimizer.py +++ b/python/ray/rllib/optimizers/async_samples_optimizer.py @@ -292,6 +292,9 @@ class TFMultiGPULearner(LearnerThread): logger.info("TFMultiGPULearner devices {}".format(self.devices)) assert self.train_batch_size % len(self.devices) == 0 assert self.train_batch_size >= len(self.devices), "batch too small" + + if set(self.local_evaluator.policy_map.keys()) != {"default"}: + raise NotImplementedError("Multi-gpu mode for multi-agent") self.policy = self.local_evaluator.policy_map["default"] # per-GPU graph copies created below must share vars with the policy diff --git a/python/ray/rllib/optimizers/multi_gpu_impl.py b/python/ray/rllib/optimizers/multi_gpu_impl.py index 0a03df41c..07bb9b886 100644 --- a/python/ray/rllib/optimizers/multi_gpu_impl.py +++ b/python/ray/rllib/optimizers/multi_gpu_impl.py @@ -161,8 +161,10 @@ class LocalSyncParallelOptimizer(object): seq_batch_size = make_divisible_by( len(smallest_array), len(self.devices)) if seq_batch_size < len(self.devices): - raise ValueError("Must load at least 1 tuple sequence per device, " - "got only {} total.".format(len(smallest_array))) + raise ValueError( + "Must load at least 1 tuple sequence per device. Try " + "increasing `sgd_minibatch_size` or reducing `max_seq_len` " + "to ensure that at least one sequence fits per device.") self._loaded_per_device_batch_size = ( seq_batch_size // len(self.devices) * self._loaded_max_seq_len) diff --git a/python/ray/rllib/test/test_supported_spaces.py b/python/ray/rllib/test/test_supported_spaces.py index fbfd1f5ea..03cab83f2 100644 --- a/python/ray/rllib/test/test_supported_spaces.py +++ b/python/ray/rllib/test/test_supported_spaces.py @@ -92,7 +92,7 @@ def check_support(alg, config, stats, check_bounds=False): def check_support_multiagent(alg, config): register_env("multi_mountaincar", lambda _: MultiMountainCar(2)) register_env("multi_cartpole", lambda _: MultiCartpole(2)) - if alg == "DDPG": + if "DDPG" in alg: a = get_agent_class(alg)(config=config, env="multi_mountaincar") else: a = get_agent_class(alg)(config=config, env="multi_cartpole") @@ -169,6 +169,24 @@ class ModelSupportedSpaces(unittest.TestCase): self.assertEqual(num_unexpected_errors, 0) def testMultiAgent(self): + check_support_multiagent( + "APEX", { + "num_workers": 2, + "timesteps_per_iteration": 1000, + "num_gpus": 0, + "min_iter_time_s": 1, + "learning_starts": 1000, + "target_network_update_freq": 100, + }) + check_support_multiagent( + "APEX_DDPG", { + "num_workers": 2, + "timesteps_per_iteration": 1000, + "num_gpus": 0, + "min_iter_time_s": 1, + "learning_starts": 1000, + "target_network_update_freq": 100, + }) check_support_multiagent("IMPALA", {"num_gpus": 0}) check_support_multiagent("DQN", {"timesteps_per_iteration": 1}) check_support_multiagent("A3C", {