From 5f430da18075878fbefd7b9c33cc22bb65710d9d Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 16 Aug 2018 14:37:21 -0700 Subject: [PATCH] [rllib] Provide internal access to episode state in compute_actions() and allow returning extra batches (#2559) The goal of this PR is to allow custom policies to perform model-based rollouts. In the multi-agent setting, this requires access to not only policies of other agents, but also their current observations. Also, you might want to return the model-based trajectories as part of the rollout for efficiency. compute_actions() now takes a new keyword arg episodes pull out internal episode class into a top-level file add function to return extra trajectories from an episode that will be appended to the sample batch documentation --- doc/source/rllib-models.rst | 21 +++ doc/source/rllib.rst | 1 + python/ray/rllib/evaluation/__init__.py | 4 +- python/ray/rllib/evaluation/episode.py | 119 +++++++++++++++ .../ray/rllib/evaluation/policy_evaluator.py | 1 + python/ray/rllib/evaluation/policy_graph.py | 23 ++- python/ray/rllib/evaluation/sample_batch.py | 5 + python/ray/rllib/evaluation/sampler.py | 135 +++++++----------- .../ray/rllib/evaluation/tf_policy_graph.py | 10 +- .../rllib/evaluation/torch_policy_graph.py | 7 +- python/ray/rllib/test/test_multi_agent_env.py | 45 ++++++ .../ray/rllib/test/test_policy_evaluator.py | 12 +- 12 files changed, 286 insertions(+), 97 deletions(-) create mode 100644 python/ray/rllib/evaluation/episode.py diff --git a/doc/source/rllib-models.rst b/doc/source/rllib-models.rst index 098cfb6ec..70b5ed756 100644 --- a/doc/source/rllib-models.rst +++ b/doc/source/rllib-models.rst @@ -125,3 +125,24 @@ Then, you can create an agent with your custom policy graph by: agent = DDPGAgent(...) That's it. In this example we overrode existing methods of the existing DDPG policy graph, i.e., `_build_q_network`, `_build_p_network`, `_build_action_network`, `_build_actor_critic_loss`, but you can also replace the entire graph class entirely. + +Model-Based Rollouts +-------------------- + +With a custom policy graph, you can also perform model-based rollouts and optionally incorporate the results of those rollouts as training data. For example, suppose you wanted to extend PGPolicyGraph for model-based rollouts. This involves overriding the ``compute_actions`` method of that policy graph: + +.. code-block:: python + + class ModelBasedPolicyGraph(PGPolicyGraph): + def compute_actions(self, + obs_batch, + state_batches, + is_training=False, + episodes=None): + # compute a batch of actions based on the current obs_batch + # and state of each episode (i.e., for multiagent). You can do + # whatever is needed here, e.g., MCTS rollouts. + return action_batch + + +If you want take this rollouts data and append it to the sample batch, use the ``add_extra_batch()`` method of the `episode objects `__ passed in. For an example of this, see the ``testReturningModelBasedRolloutsData`` `unit test `__. diff --git a/doc/source/rllib.rst b/doc/source/rllib.rst index 19b883746..29f8acf5c 100644 --- a/doc/source/rllib.rst +++ b/doc/source/rllib.rst @@ -58,6 +58,7 @@ Models and Preprocessors * `Custom Models `__ * `Custom Preprocessors `__ * `Customizing Policy Graphs `__ +* `Model-Based Rollouts `__ RLlib Concepts -------------- diff --git a/python/ray/rllib/evaluation/__init__.py b/python/ray/rllib/evaluation/__init__.py index fdc8cfbff..b3d7b4d5d 100644 --- a/python/ray/rllib/evaluation/__init__.py +++ b/python/ray/rllib/evaluation/__init__.py @@ -1,3 +1,4 @@ +from ray.rllib.evaluation.episode import MultiAgentEpisode from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator from ray.rllib.evaluation.interface import EvaluatorInterface from ray.rllib.evaluation.policy_graph import PolicyGraph @@ -15,5 +16,6 @@ __all__ = [ "EvaluatorInterface", "PolicyEvaluator", "PolicyGraph", "TFPolicyGraph", "TorchPolicyGraph", "SampleBatch", "MultiAgentBatch", "SampleBatchBuilder", "MultiAgentSampleBatchBuilder", "SyncSampler", "AsyncSampler", - "compute_advantages", "compute_targets", "collect_metrics" + "compute_advantages", "compute_targets", "collect_metrics", + "MultiAgentEpisode" ] diff --git a/python/ray/rllib/evaluation/episode.py b/python/ray/rllib/evaluation/episode.py new file mode 100644 index 000000000..fc99d79fb --- /dev/null +++ b/python/ray/rllib/evaluation/episode.py @@ -0,0 +1,119 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import defaultdict +import random + +import numpy as np + + +class MultiAgentEpisode(object): + """Tracks the current state of a (possibly multi-agent) episode. + + The APIs in this class should be considered experimental, but we should + avoid changing things for the sake of changing them since users may + depend on them for advanced algorithms. + + Attributes: + new_batch_builder (func): Create a new MultiAgentSampleBatchBuilder. + add_extra_batch (func): Return a built MultiAgentBatch to the sampler. + batch_builder (obj): Batch builder for the current episode. + total_reward (float): Summed reward across all agents in this episode. + length (int): Length of this episode. + episode_id (int): Unique id identifying this trajectory. + agent_rewards (dict): Summed rewards broken down by agent. + + Use case 1: Model-based rollouts in multi-agent: + A custom compute_actions() function in a policy graph can inspect the + current episode state and perform a number of rollouts based on the + policies and state of other agents in the environment. + + Use case 2: Returning extra rollouts data. + The model rollouts can be returned back to the sampler by calling: + + >>> batch = episode.new_batch_builder() + >>> for each transition: + batch.add_values(...) # see sampler for usage + >>> episode.extra_batches.add(batch.build_and_reset()) + """ + + def __init__(self, policies, policy_mapping_fn, batch_builder_factory, + extra_batch_callback): + self.new_batch_builder = batch_builder_factory + self.add_extra_batch = extra_batch_callback + self.batch_builder = batch_builder_factory() + self.total_reward = 0.0 + self.length = 0 + self.episode_id = random.randrange(2e9) + self.agent_rewards = defaultdict(float) + self._policies = policies + self._policy_mapping_fn = policy_mapping_fn + self._agent_to_policy = {} + self._agent_to_rnn_state = {} + self._agent_to_last_obs = {} + self._agent_to_last_action = {} + self._agent_to_last_pi_info = {} + + def policy_for(self, agent_id): + """Returns the policy graph for the specified agent. + + If the agent is new, the policy mapping fn will be called to bind the + agent to a policy for the duration of the episode. + """ + + if agent_id not in self._agent_to_policy: + self._agent_to_policy[agent_id] = self._policy_mapping_fn(agent_id) + return self._agent_to_policy[agent_id] + + def last_observation_for(self, agent_id): + """Returns the last observation for the specified agent.""" + + return self._agent_to_last_obs.get(agent_id) + + def last_action_for(self, agent_id): + """Returns the last action for the specified agent.""" + + action = self._agent_to_last_action[agent_id] + # Concatenate tuple actions + if isinstance(action, list): + expanded = [] + for a in action: + if len(a.shape) == 1: + expanded.append(np.expand_dims(a, 1)) + else: + expanded.append(a) + action = np.concatenate(expanded, axis=1).flatten() + return action + + def rnn_state_for(self, agent_id): + """Returns the last RNN state for the specified agent.""" + + if agent_id not in self._agent_to_rnn_state: + policy = self._policies[self.policy_for(agent_id)] + self._agent_to_rnn_state[agent_id] = policy.get_initial_state() + return self._agent_to_rnn_state[agent_id] + + def last_pi_info_for(self, agent_id): + """Returns the last info object for the specified agent.""" + + return self._agent_to_last_pi_info[agent_id] + + def _add_agent_rewards(self, reward_dict): + for agent_id, reward in reward_dict.items(): + if reward is not None: + self.agent_rewards[agent_id, + self.policy_for(agent_id)] += reward + self.total_reward += reward + + def _set_rnn_state(self, agent_id, rnn_state): + self._agent_to_rnn_state[agent_id] = rnn_state + + def _set_last_observation(self, agent_id, obs): + self._agent_to_last_obs[agent_id] = obs + + def _set_last_action(self, agent_id, action): + self._agent_to_last_action[agent_id] = action + + def _set_last_pi_info(self, agent_id, pi_info): + self._agent_to_last_pi_info[agent_id] = pi_info diff --git a/python/ray/rllib/evaluation/policy_evaluator.py b/python/ray/rllib/evaluation/policy_evaluator.py index 82b9569e0..3f789a4a1 100644 --- a/python/ray/rllib/evaluation/policy_evaluator.py +++ b/python/ray/rllib/evaluation/policy_evaluator.py @@ -292,6 +292,7 @@ class PolicyEvaluator(EvaluatorInterface): batch = self.sampler.get_data() steps_so_far += batch.count batches.append(batch) + batches.extend(self.sampler.get_extra_batches()) batch = batches[0].concat_samples(batches) if self.compress_observations: diff --git a/python/ray/rllib/evaluation/policy_graph.py b/python/ray/rllib/evaluation/policy_graph.py index 32534d7c7..4b49a9b4b 100644 --- a/python/ray/rllib/evaluation/policy_graph.py +++ b/python/ray/rllib/evaluation/policy_graph.py @@ -37,13 +37,20 @@ class PolicyGraph(object): self.observation_space = observation_space self.action_space = action_space - def compute_actions(self, obs_batch, state_batches, is_training=False): + def compute_actions(self, + obs_batch, + state_batches, + is_training=False, + episodes=None): """Compute actions for the current policy. Arguments: obs_batch (np.ndarray): batch of observations state_batches (list): list of RNN state input batches, if any is_training (bool): whether we are training the policy + episodes (list): MultiAgentEpisode for each obs in obs_batch. + This provides access to all of the internal episode state, + which may be useful for model-based or multiagent algorithms. Returns: actions (np.ndarray): batch of output actions, with shape like @@ -55,13 +62,20 @@ class PolicyGraph(object): """ raise NotImplementedError - def compute_single_action(self, obs, state, is_training=False): + def compute_single_action(self, + obs, + state, + is_training=False, + episode=None): """Unbatched version of compute_actions. Arguments: obs (obj): single observation state_batches (list): list of RNN state inputs, if any is_training (bool): whether we are training the policy + episode (MultiAgentEpisode): this provides access to all of the + internal episode state, which may be useful for model-based or + multiagent algorithms. Returns: actions (obj): single action @@ -70,13 +84,16 @@ class PolicyGraph(object): """ [action], state_out, info = self.compute_actions( - [obs], [[s] for s in state], is_training) + [obs], [[s] for s in state], is_training, episodes=[episode]) return action, [s[0] for s in state_out], \ {k: v[0] for k, v in info.items()} def postprocess_trajectory(self, sample_batch, other_agent_batches=None): """Implements algorithm-specific trajectory postprocessing. + This will be called on each trajectory fragment computed during policy + evaluation. Each fragment is guaranteed to be only from one episode. + Arguments: sample_batch (SampleBatch): batch of experiences for the policy, which will contain at most one episode trajectory. diff --git a/python/ray/rllib/evaluation/sample_batch.py b/python/ray/rllib/evaluation/sample_batch.py index 109db4d3f..0ecc56609 100644 --- a/python/ray/rllib/evaluation/sample_batch.py +++ b/python/ray/rllib/evaluation/sample_batch.py @@ -117,6 +117,11 @@ class MultiAgentSampleBatchBuilder(object): other_batches = pre_batches.copy() del other_batches[agent_id] policy = self.policy_map[self.agent_to_policy[agent_id]] + if any(pre_batch["dones"][:-1]) or len(set( + pre_batch["eps_id"])) > 1: + raise ValueError( + "Batches sent to postprocessing must only contain steps " + "from a single trajectory.", pre_batch) post_batches[agent_id] = policy.postprocess_trajectory( pre_batch, other_batches) diff --git a/python/ray/rllib/evaluation/sampler.py b/python/ray/rllib/evaluation/sampler.py index 53966cb9c..f88e1cdae 100644 --- a/python/ray/rllib/evaluation/sampler.py +++ b/python/ray/rllib/evaluation/sampler.py @@ -3,13 +3,13 @@ from __future__ import division from __future__ import print_function from collections import defaultdict, namedtuple -import numpy as np -import random import six.moves.queue as queue import threading +from ray.rllib.evaluation.episode import MultiAgentEpisode from ray.rllib.evaluation.sample_batch import MultiAgentSampleBatchBuilder, \ MultiAgentBatch +from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph from ray.rllib.env.async_vector_env import AsyncVectorEnv from ray.rllib.utils.tf_run_builder import TFRunBuilder @@ -44,10 +44,11 @@ class SyncSampler(object): self.policies = policies self.policy_mapping_fn = policy_mapping_fn self._obs_filters = obs_filters + self.extra_batches = queue.Queue() self.rollout_provider = _env_runner( - self.async_vector_env, self.policies, self.policy_mapping_fn, - self.num_local_steps, self.horizon, self._obs_filters, pack, - tf_sess) + self.async_vector_env, self.extra_batches.put, self.policies, + self.policy_mapping_fn, self.num_local_steps, self.horizon, + self._obs_filters, pack, tf_sess) self.metrics_queue = queue.Queue() def get_data(self): @@ -67,6 +68,15 @@ class SyncSampler(object): break return completed + def get_extra_batches(self): + extra = [] + while True: + try: + extra.append(self.extra_batches.get_nowait()) + except queue.Empty: + break + return extra + class AsyncSampler(threading.Thread): """This class interacts with the environment and tells it what to do. @@ -89,6 +99,7 @@ class AsyncSampler(threading.Thread): self.async_vector_env = AsyncVectorEnv.wrap_async(env) threading.Thread.__init__(self) self.queue = queue.Queue(5) + self.extra_batches = queue.Queue() self.metrics_queue = queue.Queue() self.num_local_steps = num_local_steps self.horizon = horizon @@ -108,9 +119,9 @@ class AsyncSampler(threading.Thread): def _run(self): rollout_provider = _env_runner( - self.async_vector_env, self.policies, self.policy_mapping_fn, - self.num_local_steps, self.horizon, self._obs_filters, self.pack, - self.tf_sess) + self.async_vector_env, self.extra_batches.put, self.policies, + self.policy_mapping_fn, self.num_local_steps, self.horizon, + self._obs_filters, self.pack, self.tf_sess) while True: # The timeout variable exists because apparently, if one worker # dies, the other workers won't die with it, unless the timeout is @@ -153,8 +164,18 @@ class AsyncSampler(threading.Thread): break return completed + def get_extra_batches(self): + extra = [] + while True: + try: + extra.append(self.extra_batches.get_nowait()) + except queue.Empty: + break + return extra + def _env_runner(async_vector_env, + extra_batch_callback, policies, policy_mapping_fn, num_local_steps, @@ -166,6 +187,7 @@ def _env_runner(async_vector_env, Args: async_vector_env (AsyncVectorEnv): env implementing AsyncVectorEnv. + extra_batch_callback (fn): function to send extra batch data to. policies (dict): Map of policy ids to PolicyGraph instances. policy_mapping_fn (func): Function that maps agent ids to policy ids. This is called when an agent first enters the environment. The @@ -204,8 +226,8 @@ def _env_runner(async_vector_env, return MultiAgentSampleBatchBuilder(policies) def new_episode(): - return _MultiAgentEpisode(policies, policy_mapping_fn, - get_batch_builder) + return MultiAgentEpisode(policies, policy_mapping_fn, + get_batch_builder, extra_batch_callback) active_episodes = defaultdict(new_episode) @@ -227,7 +249,7 @@ def _env_runner(async_vector_env, if not new_episode: episode.length += 1 episode.batch_builder.count += 1 - episode.add_agent_rewards(rewards[env_id]) + episode._add_agent_rewards(rewards[env_id]) # Check episode termination conditions if dones[env_id]["__all__"] or episode.length >= horizon: @@ -250,7 +272,7 @@ def _env_runner(async_vector_env, episode.rnn_state_for(agent_id))) last_observation = episode.last_observation_for(agent_id) - episode.set_last_observation(agent_id, filtered_obs) + episode._set_last_observation(agent_id, filtered_obs) # Record transition info if applicable if last_observation is not None and \ @@ -294,7 +316,7 @@ def _env_runner(async_vector_env, policy_id = episode.policy_for(agent_id) filtered_obs = _get_or_raise(obs_filters, policy_id)(raw_obs) - episode.set_last_observation(agent_id, filtered_obs) + episode._set_last_observation(agent_id, filtered_obs) to_eval[policy_id].append( PolicyEvalData(env_id, agent_id, filtered_obs, episode.rnn_state_for(agent_id))) @@ -302,6 +324,7 @@ def _env_runner(async_vector_env, # Batch eval policy actions if possible if tf_sess: builder = TFRunBuilder(tf_sess, "policy_eval") + pending_fetches = {} else: builder = None eval_results = {} @@ -310,16 +333,21 @@ def _env_runner(async_vector_env, rnn_in = _to_column_format([t.rnn_state for t in eval_data]) rnn_in_cols[policy_id] = rnn_in policy = _get_or_raise(policies, policy_id) - if builder: - eval_results[policy_id] = policy.build_compute_actions( + if builder and (policy.compute_actions.__code__ is + TFPolicyGraph.compute_actions.__code__): + pending_fetches[policy_id] = policy.build_compute_actions( builder, [t.obs for t in eval_data], rnn_in, is_training=True) else: eval_results[policy_id] = policy.compute_actions( - [t.obs for t in eval_data], rnn_in, is_training=True) + [t.obs for t in eval_data], + rnn_in, + is_training=True, + episodes=[active_episodes[t.env_id] for t in eval_data]) if builder: - eval_results = {k: builder.get(v) for k, v in eval_results.items()} + for k, v in pending_fetches.items(): + eval_results[k] = builder.get(v) # Record the policy eval results for policy_id, eval_data in to_eval.items(): @@ -335,16 +363,16 @@ def _env_runner(async_vector_env, agent_id = eval_data[i].agent_id actions_to_send[env_id][agent_id] = action episode = active_episodes[env_id] - episode.set_rnn_state(agent_id, [c[i] for c in rnn_out_cols]) - episode.set_last_pi_info( + episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols]) + episode._set_last_pi_info( agent_id, {k: v[i] for k, v in pi_info_cols.items()}) if env_id in off_policy_actions and \ agent_id in off_policy_actions[env_id]: - episode.set_last_action( + episode._set_last_action( agent_id, off_policy_actions[env_id][agent_id]) else: - episode.set_last_action(agent_id, action) + episode._set_last_action(agent_id, action) # Return computed actions to ready envs. We also send to envs that have # taken off-policy actions; those envs are free to ignore the action. @@ -362,68 +390,3 @@ def _get_or_raise(mapping, policy_id): "Could not find policy for agent: agent policy id `{}` not " "in policy map keys {}.".format(policy_id, mapping.keys())) return mapping[policy_id] - - -class _MultiAgentEpisode(object): - def __init__(self, policies, policy_mapping_fn, batch_builder_factory): - self.batch_builder = batch_builder_factory() - self.total_reward = 0.0 - self.length = 0 - self.episode_id = random.randrange(2e9) - self.agent_rewards = defaultdict(float) - self._policies = policies - self._policy_mapping_fn = policy_mapping_fn - self._agent_to_policy = {} - self._agent_to_rnn_state = {} - self._agent_to_last_obs = {} - self._agent_to_last_action = {} - self._agent_to_last_pi_info = {} - - def add_agent_rewards(self, reward_dict): - for agent_id, reward in reward_dict.items(): - if reward is not None: - self.agent_rewards[agent_id, - self.policy_for(agent_id)] += reward - self.total_reward += reward - - def policy_for(self, agent_id): - if agent_id not in self._agent_to_policy: - self._agent_to_policy[agent_id] = self._policy_mapping_fn(agent_id) - return self._agent_to_policy[agent_id] - - def rnn_state_for(self, agent_id): - if agent_id not in self._agent_to_rnn_state: - policy = self._policies[self.policy_for(agent_id)] - self._agent_to_rnn_state[agent_id] = policy.get_initial_state() - return self._agent_to_rnn_state[agent_id] - - def last_observation_for(self, agent_id): - return self._agent_to_last_obs.get(agent_id) - - def last_action_for(self, agent_id): - action = self._agent_to_last_action[agent_id] - # Concatenate tuple actions - if isinstance(action, list): - expanded = [] - for a in action: - if len(a.shape) == 1: - expanded.append(np.expand_dims(a, 1)) - else: - expanded.append(a) - action = np.concatenate(expanded, axis=1).flatten() - return action - - def last_pi_info_for(self, agent_id): - return self._agent_to_last_pi_info[agent_id] - - def set_rnn_state(self, agent_id, rnn_state): - self._agent_to_rnn_state[agent_id] = rnn_state - - def set_last_observation(self, agent_id, obs): - self._agent_to_last_obs[agent_id] = obs - - def set_last_action(self, agent_id, action): - self._agent_to_last_action[agent_id] = action - - def set_last_pi_info(self, agent_id, pi_info): - self._agent_to_last_pi_info[agent_id] = pi_info diff --git a/python/ray/rllib/evaluation/tf_policy_graph.py b/python/ray/rllib/evaluation/tf_policy_graph.py index ce9e0803b..0c085c231 100644 --- a/python/ray/rllib/evaluation/tf_policy_graph.py +++ b/python/ray/rllib/evaluation/tf_policy_graph.py @@ -104,7 +104,8 @@ class TFPolicyGraph(PolicyGraph): builder, obs_batch, state_batches=None, - is_training=False): + is_training=False, + episodes=None): state_batches = state_batches or [] assert len(self._state_inputs) == len(state_batches), \ (self._state_inputs, state_batches) @@ -118,8 +119,11 @@ class TFPolicyGraph(PolicyGraph): [self.extra_compute_action_fetches()]) return fetches[0], fetches[1:-1], fetches[-1] - def compute_actions(self, obs_batch, state_batches=None, - is_training=False): + def compute_actions(self, + obs_batch, + state_batches=None, + is_training=False, + episodes=None): builder = TFRunBuilder(self._sess, "compute_actions") fetches = self.build_compute_actions(builder, obs_batch, state_batches, is_training) diff --git a/python/ray/rllib/evaluation/torch_policy_graph.py b/python/ray/rllib/evaluation/torch_policy_graph.py index 069ca2244..741357f3a 100644 --- a/python/ray/rllib/evaluation/torch_policy_graph.py +++ b/python/ray/rllib/evaluation/torch_policy_graph.py @@ -67,8 +67,11 @@ class TorchPolicyGraph(PolicyGraph): """Custom PyTorch optimizer to use.""" return torch.optim.Adam(self._model.parameters()) - def compute_actions(self, obs_batch, state_batches=None, - is_training=False): + def compute_actions(self, + obs_batch, + state_batches=None, + is_training=False, + episodes=None): if state_batches: raise NotImplementedError("Torch RNN support") with self.lock: diff --git a/python/ray/rllib/test/test_multi_agent_env.py b/python/ray/rllib/test/test_multi_agent_env.py index 42a6de1c1..96eaabaf1 100644 --- a/python/ray/rllib/test/test_multi_agent_env.py +++ b/python/ray/rllib/test/test_multi_agent_env.py @@ -306,6 +306,51 @@ class TestMultiAgentEnv(unittest.TestCase): self.assertEqual(batch.policy_batches["p0"]["t"].tolist()[:10], [4, 9, 14, 19, 24, 5, 10, 15, 20, 25]) + def testReturningModelBasedRolloutsData(self): + class ModelBasedPolicyGraph(PGPolicyGraph): + def compute_actions(self, + obs_batch, + state_batches, + is_training=False, + episodes=None): + # Pretend we did a model-based rollout and want to return + # the extra trajectory. + builder = episodes[0].new_batch_builder() + rollout_id = random.randint(0, 10000) + for t in range(5): + builder.add_values( + agent_id="extra_0", + policy_id="p1", # use p1 so we can easily check it + t=t, + eps_id=rollout_id, # new id for each rollout + obs=obs_batch[0], + actions=0, + rewards=0, + dones=t == 4, + infos={}, + new_obs=obs_batch[0]) + batch = builder.build_and_reset() + episodes[0].add_extra_batch(batch) + + # Just return zeros for actions + return [0] * len(obs_batch), [], {} + + single_env = gym.make("CartPole-v0") + obs_space = single_env.observation_space + act_space = single_env.action_space + ev = PolicyEvaluator( + env_creator=lambda _: MultiCartpole(2), + policy_graph={ + "p0": (ModelBasedPolicyGraph, obs_space, act_space, {}), + "p1": (ModelBasedPolicyGraph, obs_space, act_space, {}), + }, + policy_mapping_fn=lambda agent_id: "p0", + batch_steps=5) + batch = ev.sample() + self.assertEqual(batch.count, 5) + self.assertEqual(batch.policy_batches["p0"].count, 10) + self.assertEqual(batch.policy_batches["p1"].count, 25) + def testTrainMultiCartpoleSinglePolicy(self): n = 10 register_env("multi_cartpole", lambda _: MultiCartpole(n)) diff --git a/python/ray/rllib/test/test_policy_evaluator.py b/python/ray/rllib/test/test_policy_evaluator.py index 1e70e8291..1aa559df6 100644 --- a/python/ray/rllib/test/test_policy_evaluator.py +++ b/python/ray/rllib/test/test_policy_evaluator.py @@ -17,7 +17,11 @@ from ray.tune.registry import register_env class MockPolicyGraph(PolicyGraph): - def compute_actions(self, obs_batch, state_batches, is_training=False): + def compute_actions(self, + obs_batch, + state_batches, + is_training=False, + episodes=None): return [0] * len(obs_batch), [], {} def postprocess_trajectory(self, batch, other_agent_batches=None): @@ -25,7 +29,11 @@ class MockPolicyGraph(PolicyGraph): class BadPolicyGraph(PolicyGraph): - def compute_actions(self, obs_batch, state_batches, is_training=False): + def compute_actions(self, + obs_batch, + state_batches, + is_training=False, + episodes=None): raise Exception("intentional error") def postprocess_trajectory(self, batch, other_agent_batches=None):