From 0b6112b726e10a66e63f44f1508e439aecf01faa Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sat, 23 Jun 2018 18:32:16 -0700 Subject: [PATCH] [rllib] Part 1 of multiagent support: make sampler path support multiagent envs (#2268) This refactors the RLlib sampler to support multi-agent environments. The main changes were: AsyncVectorEnv now produces dicts of env_id -> agent_id -> value rather than env_id -> value. This lets it model both vectorized and multi-agent envs (or both). The sampler class operates over the above nested dict structure for all envs. Single agent envs just return a dict with one agent_id=single_agent. When sample() is called on a policy evaluator, in the single agent case we return a SampleBatch, otherwise we return a MultiAgentBatch (which is a list of sample batches per policy). Left for another PR: Exposing multi-agent in the public interfaces. Optimizations such as evaluating multiple policies in one TF run. --- python/ray/rllib/a3c/a3c.py | 3 +- python/ray/rllib/a3c/a3c_torch_policy.py | 1 + python/ray/rllib/optimizers/__init__.py | 4 +- python/ray/rllib/optimizers/sample_batch.py | 137 ++++++- python/ray/rllib/ppo/loss.py | 2 +- python/ray/rllib/ppo/ppo_evaluator.py | 5 +- .../test/test_common_policy_evaluator.py | 21 +- python/ray/rllib/test/test_multi_agent_env.py | 155 ++++++++ python/ray/rllib/test/test_serving_env.py | 20 +- python/ray/rllib/utils/async_vector_env.py | 241 +++++++++++-- .../rllib/utils/common_policy_evaluator.py | 55 +-- python/ray/rllib/utils/multi_agent_env.py | 60 ++++ python/ray/rllib/utils/policy_graph.py | 9 +- python/ray/rllib/utils/sampler.py | 339 ++++++++++-------- python/ray/rllib/utils/serving_env.py | 50 +-- python/ray/rllib/utils/vector_env.py | 24 +- 16 files changed, 849 insertions(+), 277 deletions(-) create mode 100644 python/ray/rllib/test/test_multi_agent_env.py create mode 100644 python/ray/rllib/utils/multi_agent_env.py diff --git a/python/ray/rllib/a3c/a3c.py b/python/ray/rllib/a3c/a3c.py index 6375bc90d..04c9ce4df 100644 --- a/python/ray/rllib/a3c/a3c.py +++ b/python/ray/rllib/a3c/a3c.py @@ -154,7 +154,8 @@ class A3CAgent(Agent): def compute_action(self, observation, state=None): if state is None: state = [] - obs = self.local_evaluator.obs_filter(observation, update=False) + obs = self.local_evaluator.filters["default"]( + observation, update=False) return self.local_evaluator.for_policy( lambda p: p.compute_single_action( obs, state, is_training=False)[0]) diff --git a/python/ray/rllib/a3c/a3c_torch_policy.py b/python/ray/rllib/a3c/a3c_torch_policy.py index c4ff8d98d..a1cb5d866 100644 --- a/python/ray/rllib/a3c/a3c_torch_policy.py +++ b/python/ray/rllib/a3c/a3c_torch_policy.py @@ -18,6 +18,7 @@ class SharedTorchPolicy(PolicyGraph): """A simple, non-recurrent PyTorch policy example.""" def __init__(self, obs_space, action_space, config): + PolicyGraph.__init__(self, obs_space, action_space, config) self.local_steps = 0 self.config = config self.summarize = config.get("summarize") diff --git a/python/ray/rllib/optimizers/__init__.py b/python/ray/rllib/optimizers/__init__.py index 95be536c0..9bcd38899 100644 --- a/python/ray/rllib/optimizers/__init__.py +++ b/python/ray/rllib/optimizers/__init__.py @@ -3,7 +3,7 @@ from ray.rllib.optimizers.async_optimizer import AsyncOptimizer from ray.rllib.optimizers.local_sync import LocalSyncOptimizer from ray.rllib.optimizers.local_sync_replay import LocalSyncReplayOptimizer from ray.rllib.optimizers.multi_gpu import LocalMultiGPUOptimizer -from ray.rllib.optimizers.sample_batch import SampleBatch +from ray.rllib.optimizers.sample_batch import SampleBatch, MultiAgentBatch from ray.rllib.optimizers.policy_evaluator import PolicyEvaluator, \ TFMultiGPUSupport @@ -11,4 +11,4 @@ from ray.rllib.optimizers.policy_evaluator import PolicyEvaluator, \ __all__ = [ "ApexOptimizer", "AsyncOptimizer", "LocalSyncOptimizer", "LocalSyncReplayOptimizer", "LocalMultiGPUOptimizer", "SampleBatch", - "PolicyEvaluator", "TFMultiGPUSupport"] + "PolicyEvaluator", "TFMultiGPUSupport", "MultiAgentBatch"] diff --git a/python/ray/rllib/optimizers/sample_batch.py b/python/ray/rllib/optimizers/sample_batch.py index 1abab5b14..83df66aa2 100644 --- a/python/ray/rllib/optimizers/sample_batch.py +++ b/python/ray/rllib/optimizers/sample_batch.py @@ -14,7 +14,6 @@ class SampleBatchBuilder(object): """ def __init__(self): - self.postprocessed = [] self.buffers = collections.defaultdict(list) self.count = 0 @@ -25,29 +24,131 @@ class SampleBatchBuilder(object): self.buffers[k].append(v) self.count += 1 - def postprocess_batch_so_far(self, postprocessor): - """Apply the given postprocessor to any unprocessed rows.""" + def add_batch(self, batch): + """Add the given batch of values to this batch.""" - batch = postprocessor(self._build_buffers()) - self.postprocessed.append(batch) + for k, column in batch.items(): + self.buffers[k].extend(column) + self.count += batch.count - def build_and_reset(self, postprocessor): - """Returns a sample batch including all previously added values. + def build_and_reset(self): + """Returns a sample batch including all previously added values.""" - Any unprocessed rows will be first postprocessed with the given - postprocessor. The internal state of this builder will be reset. - """ - - self.postprocess_batch_so_far(postprocessor) - batch = SampleBatch.concat_samples(self.postprocessed) - self.postprocessed = [] + batch = SampleBatch({k: np.array(v) for k, v in self.buffers.items()}) + self.buffers.clear() self.count = 0 return batch - def _build_buffers(self): - batch = SampleBatch({k: np.array(v) for k, v in self.buffers.items()}) - self.buffers.clear() - return batch + +class MultiAgentSampleBatchBuilder(object): + """Util to build SampleBatches for each policy in a multi-agent env. + + Input data is per-agent, while output data is per-policy. There is an M:N + mapping between agents and policies. We retain one local batch builder + per agent. When an agent is done, then its local batch is appended into the + corresponding policy batch for the agent's policy. + """ + + def __init__(self, policy_map): + """Initialize a MultiAgentSampleBatchBuilder. + + Arguments: + policy_map (dict): Maps policy ids to policy graph instances. + """ + + self.policy_map = policy_map + self.policy_builders = { + k: SampleBatchBuilder() for k in policy_map.keys()} + self.agent_builders = {} + self.agent_to_policy = {} + self.count = 0 # increment this manually + + def has_pending_data(self): + """Returns whether there is pending unprocessed data.""" + + return len(self.agent_builders) > 0 + + def add_values(self, agent_id, policy_id, **values): + """Add the given dictionary (row) of values to this batch. + + Arguments: + agent_id (obj): Unique id for the agent we are adding values for. + policy_id (obj): Unique id for policy controlling the agent. + values (dict): Row of values to add for this agent. + """ + + if agent_id not in self.agent_builders: + self.agent_builders[agent_id] = SampleBatchBuilder() + self.agent_to_policy[agent_id] = policy_id + builder = self.agent_builders[agent_id] + builder.add_values(**values) + + def postprocess_batch_so_far(self): + """Apply policy postprocessors to any unprocessed rows. + + This pushes the postprocessed per-agent batches onto the per-policy + builders, clearing per-agent state. + """ + + # Materialize the batches so far + pre_batches = {} + for agent_id, builder in self.agent_builders.items(): + pre_batches[agent_id] = ( + self.policy_map[self.agent_to_policy[agent_id]], + builder.build_and_reset()) + + # Apply postprocessor + post_batches = {} + for agent_id, (_, pre_batch) in pre_batches.items(): + other_batches = pre_batches.copy() + del other_batches[agent_id] + policy = self.policy_map[self.agent_to_policy[agent_id]] + post_batches[agent_id] = policy.postprocess_trajectory( + pre_batch, other_batches) + + # Append into policy batches and reset + for agent_id, post_batch in post_batches.items(): + self.policy_builders[self.agent_to_policy[agent_id]].add_batch( + post_batch) + self.agent_builders.clear() + self.agent_to_policy.clear() + + def build_and_reset(self): + """Returns the accumulated sample batches for each policy. + + Any unprocessed rows will be first postprocessed with a policy + postprocessor. The internal state of this builder will be reset. + """ + + self.postprocess_batch_so_far() + policy_batches = {} + for policy_id, policy_batch_builder in self.policy_builders.items(): + policy_batches[policy_id] = policy_batch_builder.build_and_reset() + self.count = 0 + return MultiAgentBatch.wrap_as_needed(policy_batches) + + +class MultiAgentBatch(object): + def __init__(self, policy_batches): + self.policy_batches = policy_batches + + @staticmethod + def wrap_as_needed(batches): + if len(batches) == 1 and "default" in batches: + return batches["default"] + return MultiAgentBatch(batches) + + @staticmethod + def concat_samples(samples): + policy_batches = collections.defaultdict(list) + for s in samples: + assert isinstance(s, MultiAgentBatch) + for policy_id, batch in s.policy_batches.items(): + policy_batches[policy_id].append(batch) + out = {} + for policy_id, batches in policy_batches.items(): + out[policy_id] = SampleBatch.concat_samples(batches) + return MultiAgentBatch(out) class SampleBatch(object): diff --git a/python/ray/rllib/ppo/loss.py b/python/ray/rllib/ppo/loss.py index dd0e03f47..e40e03cb3 100644 --- a/python/ray/rllib/ppo/loss.py +++ b/python/ray/rllib/ppo/loss.py @@ -88,7 +88,7 @@ class ProximalPolicyGraph(object): feed_dict={self.observations: observations}) return action, [], {"vf_preds": vf, "logprobs": logprobs} - def postprocess_trajectory(self, batch): + def postprocess_trajectory(self, batch, other_agent_batches=None): return batch def get_initial_state(self): diff --git a/python/ray/rllib/ppo/ppo_evaluator.py b/python/ray/rllib/ppo/ppo_evaluator.py index 472e79dcd..2cac85cd5 100644 --- a/python/ray/rllib/ppo/ppo_evaluator.py +++ b/python/ray/rllib/ppo/ppo_evaluator.py @@ -79,8 +79,9 @@ class PPOEvaluator(TFMultiGPUSupport): self.filters = {"obs_filter": self.obs_filter, "rew_filter": self.rew_filter} self.sampler = SyncSampler( - self.env, self.common_policy, self.obs_filter, - self.config["horizon"], self.config["horizon"]) + self.env, {"default": self.common_policy}, lambda _: "default", + {"default": self.obs_filter}, self.config["horizon"], + self.config["horizon"]) def tf_loss_inputs(self): return self.inputs diff --git a/python/ray/rllib/test/test_common_policy_evaluator.py b/python/ray/rllib/test/test_common_policy_evaluator.py index 229734e97..10a31b098 100644 --- a/python/ray/rllib/test/test_common_policy_evaluator.py +++ b/python/ray/rllib/test/test_common_policy_evaluator.py @@ -20,7 +20,7 @@ class MockPolicyGraph(PolicyGraph): def compute_actions(self, obs_batch, state_batches, is_training=False): return [0] * len(obs_batch), [], {} - def postprocess_trajectory(self, batch): + def postprocess_trajectory(self, batch, other_agent_batches=None): return compute_advantages(batch, 100.0, 0.9, use_gae=False) @@ -28,7 +28,7 @@ class BadPolicyGraph(PolicyGraph): def compute_actions(self, obs_batch, state_batches, is_training=False): raise Exception("intentional error") - def postprocess_trajectory(self, batch): + def postprocess_trajectory(self, batch, other_agent_batches=None): return compute_advantages(batch, 100.0, 0.9, use_gae=False) @@ -211,6 +211,9 @@ class TestCommonPolicyEvaluator(unittest.TestCase): batch_mode="complete_episodes") batch = ev.sample() self.assertEqual(batch.count, 20) + self.assertEqual( + batch["t"].tolist(), + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) def testFilterSync(self): ev = CommonPolicyEvaluator( @@ -221,7 +224,7 @@ class TestCommonPolicyEvaluator(unittest.TestCase): time.sleep(2) ev.sample() filters = ev.get_filters(flush_after=True) - obs_f = filters["obs_filter"] + obs_f = filters["default"] self.assertNotEqual(obs_f.rs.n, 0) self.assertNotEqual(obs_f.buffer.n, 0) @@ -235,8 +238,8 @@ class TestCommonPolicyEvaluator(unittest.TestCase): filters = ev.get_filters(flush_after=False) time.sleep(2) filters2 = ev.get_filters(flush_after=False) - obs_f = filters["obs_filter"] - obs_f2 = filters2["obs_filter"] + obs_f = filters["default"] + obs_f2 = filters2["default"] self.assertGreaterEqual(obs_f2.rs.n, obs_f.rs.n) self.assertGreaterEqual(obs_f2.buffer.n, obs_f.buffer.n) @@ -250,15 +253,15 @@ class TestCommonPolicyEvaluator(unittest.TestCase): # Current State filters = ev.get_filters(flush_after=False) - obs_f = filters["obs_filter"] + obs_f = filters["default"] self.assertLessEqual(obs_f.buffer.n, 20) new_obsf = obs_f.copy() new_obsf.rs._n = 100 - ev.sync_filters({"obs_filter": new_obsf}) + ev.sync_filters({"default": new_obsf}) filters = ev.get_filters(flush_after=False) - obs_f = filters["obs_filter"] + obs_f = filters["default"] self.assertGreaterEqual(obs_f.rs.n, 100) self.assertLessEqual(obs_f.buffer.n, 20) @@ -266,7 +269,7 @@ class TestCommonPolicyEvaluator(unittest.TestCase): time.sleep(2) ev.sample() filters = ev.get_filters(flush_after=True) - obs_f = filters["obs_filter"] + obs_f = filters["default"] self.assertNotEqual(obs_f.rs.n, 0) self.assertNotEqual(obs_f.buffer.n, 0) return obs_f diff --git a/python/ray/rllib/test/test_multi_agent_env.py b/python/ray/rllib/test/test_multi_agent_env.py new file mode 100644 index 000000000..2e8b8169c --- /dev/null +++ b/python/ray/rllib/test/test_multi_agent_env.py @@ -0,0 +1,155 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest + +import ray +from ray.rllib.test.test_common_policy_evaluator import MockEnv +from ray.rllib.utils.async_vector_env import _MultiAgentEnvToAsync +from ray.rllib.utils.multi_agent_env import MultiAgentEnv + + +class BasicMultiAgent(MultiAgentEnv): + """Env of N independent agents, each of which exits after 25 steps.""" + + def __init__(self, num): + self.agents = [MockEnv(25) for _ in range(num)] + self.dones = set() + + def reset(self): + self.dones = set() + return {i: a.reset() for i, a in enumerate(self.agents)} + + def step(self, action_dict): + obs, rew, done, info = {}, {}, {}, {} + for i, action in action_dict.items(): + obs[i], rew[i], done[i], info[i] = self.agents[i].step(action) + if done[i]: + self.dones.add(i) + done["__all__"] = len(self.dones) == len(self.agents) + return obs, rew, done, info + + +class RoundRobinMultiAgent(MultiAgentEnv): + """Env of N independent agents, each of which exits after 5 steps. + + On each step() of the env, only one agent takes an action.""" + + def __init__(self, num): + self.agents = [MockEnv(5) for _ in range(num)] + self.dones = set() + self.last_obs = {} + self.last_rew = {} + self.last_done = {} + self.last_info = {} + self.i = 0 + self.num = num + + def reset(self): + self.dones = set() + return {i: a.reset() for i, a in enumerate(self.agents)} + + def step(self, action_dict): + assert len(self.dones) != len(self.agents) + for i, action in action_dict.items(): + (self.last_obs[i], self.last_rew[i], self.last_done[i], + self.last_info[i]) = self.agents[i].step(action) + if self.last_done[i]: + self.dones.add(i) + obs = {self.i: self.last_obs[i]} + rew = {self.i: self.last_rew[i]} + done = {self.i: self.last_done[i]} + info = {self.i: self.last_info[i]} + self.i += 1 + self.i %= self.num + done["__all__"] = len(self.dones) == len(self.agents) + return obs, rew, done, info + + +class TestMultiAgentEnv(unittest.TestCase): + def testBasicMock(self): + env = BasicMultiAgent(4) + obs = env.reset() + self.assertEqual(obs, {0: 0, 1: 0, 2: 0, 3: 0}) + for _ in range(24): + obs, rew, done, info = env.step({0: 0, 1: 0, 2: 0, 3: 0}) + self.assertEqual(obs, {0: 0, 1: 0, 2: 0, 3: 0}) + self.assertEqual(rew, {0: 1, 1: 1, 2: 1, 3: 1}) + self.assertEqual( + done, + {0: False, 1: False, 2: False, 3: False, "__all__": False}) + obs, rew, done, info = env.step({0: 0, 1: 0, 2: 0, 3: 0}) + self.assertEqual( + done, {0: True, 1: True, 2: True, 3: True, "__all__": True}) + + def testRoundRobinMock(self): + env = RoundRobinMultiAgent(2) + obs = env.reset() + self.assertEqual(obs, {0: 0, 1: 0}) + obs, rew, done, info = env.step({0: 0, 1: 0}) + self.assertEqual(obs, {0: 0}) + for _ in range(4): + obs, rew, done, info = env.step({0: 0}) + self.assertEqual(obs, {1: 0}) + self.assertEqual(done["__all__"], False) + obs, rew, done, info = env.step({1: 0}) + self.assertEqual(obs, {0: 0}) + self.assertEqual(done["__all__"], True) + + def testVectorizeBasic(self): + env = _MultiAgentEnvToAsync(lambda: BasicMultiAgent(2), [], 2) + obs, rew, dones, _, _ = env.poll() + self.assertEqual(obs, {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}}) + self.assertEqual(rew, {0: {0: None, 1: None}, 1: {0: None, 1: None}}) + self.assertEqual( + dones, + {0: {0: False, 1: False, "__all__": False}, + 1: {0: False, 1: False, "__all__": False}}) + for _ in range(24): + env.send_actions({0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}}) + obs, rew, dones, _, _ = env.poll() + self.assertEqual(obs, {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}}) + self.assertEqual(rew, {0: {0: 1, 1: 1}, 1: {0: 1, 1: 1}}) + self.assertEqual( + dones, + {0: {0: False, 1: False, "__all__": False}, + 1: {0: False, 1: False, "__all__": False}}) + env.send_actions({0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}}) + obs, rew, dones, _, _ = env.poll() + self.assertEqual( + dones, + {0: {0: True, 1: True, "__all__": True}, + 1: {0: True, 1: True, "__all__": True}}) + + # Reset processing + self.assertRaises( + ValueError, + lambda: env.send_actions({0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})) + self.assertEqual(env.try_reset(0), {0: 0, 1: 0}) + self.assertEqual(env.try_reset(1), {0: 0, 1: 0}) + env.send_actions({0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}}) + obs, rew, dones, _, _ = env.poll() + self.assertEqual(obs, {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}}) + self.assertEqual(rew, {0: {0: 1, 1: 1}, 1: {0: 1, 1: 1}}) + self.assertEqual( + dones, + {0: {0: False, 1: False, "__all__": False}, + 1: {0: False, 1: False, "__all__": False}}) + + def testVectorizeRoundRobin(self): + env = _MultiAgentEnvToAsync(lambda: RoundRobinMultiAgent(2), [], 2) + obs, rew, dones, _, _ = env.poll() + self.assertEqual(obs, {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}}) + self.assertEqual(rew, {0: {0: None, 1: None}, 1: {0: None, 1: None}}) + env.send_actions({0: {0: 0}, 1: {0: 0}}) + obs, rew, dones, _, _ = env.poll() + self.assertEqual(obs, {0: {0: 0}, 1: {0: 0}}) + env.send_actions({0: {0: 0}, 1: {0: 0}}) + obs, rew, dones, _, _ = env.poll() + self.assertEqual(obs, {0: {1: 0}, 1: {1: 0}}) + + +if __name__ == '__main__': + ray.init() + unittest.main(verbosity=2) diff --git a/python/ray/rllib/test/test_serving_env.py b/python/ray/rllib/test/test_serving_env.py index 4c4613a6d..94b7f8673 100644 --- a/python/ray/rllib/test/test_serving_env.py +++ b/python/ray/rllib/test/test_serving_env.py @@ -60,16 +60,16 @@ class PartOffPolicyServing(ServingEnv): class SimpleOffPolicyServing(ServingEnv): - def __init__(self, env): + def __init__(self, env, fixed_action): ServingEnv.__init__(self, env.action_space, env.observation_space) self.env = env + self.fixed_action = fixed_action def run(self): eid = self.start_episode() obs = self.env.reset() while True: - # Take random actions - action = self.env.action_space.sample() + action = self.fixed_action self.log_action(eid, obs, action) obs, reward, done, info = self.env.step(action) self.log_returns(eid, reward, info=info) @@ -131,13 +131,15 @@ class TestServingEnv(unittest.TestCase): def testServingEnvOffPolicy(self): ev = CommonPolicyEvaluator( - env_creator=lambda _: SimpleOffPolicyServing(MockEnv(25)), + env_creator=lambda _: SimpleOffPolicyServing(MockEnv(25), 42), policy_graph=MockPolicyGraph, batch_steps=40, batch_mode="complete_episodes") for _ in range(3): batch = ev.sample() self.assertEqual(batch.count, 50) + self.assertEqual(batch["actions"][0], 42) + self.assertEqual(batch["actions"][-1], 42) def testServingEnvBadActions(self): ev = CommonPolicyEvaluator( @@ -185,6 +187,16 @@ class TestServingEnv(unittest.TestCase): return raise Exception("failed to improve reward") + def testServingEnvHorizonNotSupported(self): + ev = CommonPolicyEvaluator( + env_creator=lambda _: SimpleServing(MockEnv(25)), + policy_graph=MockPolicyGraph, + episode_horizon=20, + batch_steps=10, + batch_mode="complete_episodes") + ev.sample() + self.assertRaises(Exception, lambda: ev.sample()) + if __name__ == '__main__': ray.init() diff --git a/python/ray/rllib/utils/async_vector_env.py b/python/ray/rllib/utils/async_vector_env.py index 3ef7c3e8b..266907a3a 100644 --- a/python/ray/rllib/utils/async_vector_env.py +++ b/python/ray/rllib/utils/async_vector_env.py @@ -2,66 +2,121 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from ray.rllib.utils.serving_env import ServingEnv +from ray.rllib.utils.vector_env import VectorEnv +from ray.rllib.utils.multi_agent_env import MultiAgentEnv + class AsyncVectorEnv(object): """The lowest-level env interface used by RLlib for sampling. - AsyncVectorEnv models multiple agents executing asynchronously. A call to - poll() returns observations from ready agents, and actions for those agents + AsyncVectorEnv models multiple agents executing asynchronously in multiple + environments. A call to poll() returns observations from ready agents + keyed by their environment and agent ids, and actions for those agents can be sent back via send_actions(). All other env types can be adapted to AsyncVectorEnv. RLlib handles these conversions internally in CommonPolicyEvaluator, for example: gym.Env => rllib.VectorEnv => rllib.AsyncVectorEnv + rllib.MultiAgentEnv => rllib.AsyncVectorEnv rllib.ServingEnv => rllib.AsyncVectorEnv Examples: >>> env = MyAsyncVectorEnv() >>> obs, rewards, dones, infos, off_policy_actions = env.poll() >>> print(obs) - {"car_0": [2.4, 1.6], "car_1": [3.4, -3.2]} - >>> env.send_actions({"car_0": 0, "car_1": 1}) + { + "env_0": { + "car_0": [2.4, 1.6], + "car_1": [3.4, -3.2], + } + } + >>> env.send_actions( + actions={ + "env_0": { + "car_0": 0, + "car_1": 1, + } + }) >>> obs, rewards, dones, infos, off_policy_actions = env.poll() >>> print(obs) - {"car_0": [4.1, 1.7], "car_1": [3.2, -4.2]} + { + "env_0": { + "car_0": [4.1, 1.7], + "car_1": [3.2, -4.2], + } + } + >>> print(dones) + { + "env_0": { + "__all__": False, + "car_0": False, + "car_1": True, + } + } """ + @staticmethod + def wrap_async(env, make_env=None, num_envs=1): + """Wraps any env type as needed to expose the async interface.""" + if not isinstance(env, AsyncVectorEnv): + if isinstance(env, MultiAgentEnv): + env = _MultiAgentEnvToAsync( + make_env=make_env, existing_envs=[env], num_envs=num_envs) + elif isinstance(env, ServingEnv): + if num_envs != 1: + raise ValueError( + "ServingEnv does not currently support num_envs > 1.") + env = _ServingEnvToAsync(env) + elif isinstance(env, VectorEnv): + env = _VectorEnvToAsync(env) + else: + env = VectorEnv.wrap( + make_env=make_env, existing_envs=[env], num_envs=num_envs) + env = _VectorEnvToAsync(env) + assert isinstance(env, AsyncVectorEnv) + return env + def poll(self): """Returns observations from ready agents. - The returns are dicts mapping from agent episode ids to values. The - number of agents can vary over time. + The returns are two-level dicts mapping from env_id to a dict of + agent_id to values. The number of agents and envs can vary over time. Returns: - obs (dict): New observations for each ready episode. - rewards (dict): Reward values for each ready episode. If the + obs (dict): New observations for each ready agent. + rewards (dict): Reward values for each ready agent. If the episode is just started, the value will be None. - dones (dict): Done values for each ready episode. If True, the - episode is terminated. - infos (dict): Info values for each ready episode. + dones (dict): Done values for each ready agent. The special key + "__all__" is used to indicate env termination. + infos (dict): Info values for each ready agent. off_policy_actions (dict): Agents may take off-policy actions. When that happens, there will be an entry in this dict that contains - the taken action. + the taken action. There is no need to send_actions() for agents + that have already chosen off-policy actions. """ raise NotImplementedError def send_actions(self, action_dict): """Called to send actions back to running agents in this env. + Actions should be sent for each ready agent that returned observations + in the previous poll() call. + Arguments: - action_dict (dict): Actions for each agent to take. + action_dict (dict): Actions values keyed by env_id and agent_id. """ raise NotImplementedError - def try_reset(self, agent_id): - """Attempt to reset the agent with the given id. + def try_reset(self, env_id): + """Attempt to reset the env with the given id. If the environment does not support synchronous reset, None can be returned here. Returns: - obs (obj|None): Resetted observation or None if not supported. + obs (dict|None): Resetted observation or None if not supported. """ return None @@ -74,8 +129,63 @@ class AsyncVectorEnv(object): return None +# Fixed agent identifier when there is only the single agent in the env +_DUMMY_AGENT_ID = "single_agent" + + +def _with_dummy_agent_id(env_id_to_values, dummy_id=_DUMMY_AGENT_ID): + return {k: {dummy_id: v} for (k, v) in env_id_to_values.items()} + + +class _ServingEnvToAsync(AsyncVectorEnv): + """Internal adapter of ServingEnv to AsyncVectorEnv.""" + + def __init__(self, serving_env): + self.serving_env = serving_env + serving_env.start() + + def poll(self): + with self.serving_env._results_avail_condition: + results = self._poll() + while len(results[0]) == 0: + self.serving_env._results_avail_condition.wait() + results = self._poll() + if not self.serving_env.isAlive(): + raise Exception("Serving thread has stopped.") + limit = self.serving_env._max_concurrent_episodes + assert len(results[0]) < limit, \ + ("Too many concurrent episodes, were some leaked? This ServingEnv " + "was created with max_concurrent={}".format(limit)) + return results + + def _poll(self): + all_obs, all_rewards, all_dones, all_infos = {}, {}, {}, {} + off_policy_actions = {} + for eid, episode in self.serving_env._episodes.copy().items(): + data = episode.get_data() + if episode.cur_done: + del self.serving_env._episodes[eid] + if data: + all_obs[eid] = data["obs"] + all_rewards[eid] = data["reward"] + all_dones[eid] = data["done"] + all_infos[eid] = data["info"] + if "off_policy_action" in data: + off_policy_actions[eid] = data["off_policy_action"] + return _with_dummy_agent_id(all_obs), \ + _with_dummy_agent_id(all_rewards), \ + _with_dummy_agent_id(all_dones, "__all__"), \ + _with_dummy_agent_id(all_infos), \ + _with_dummy_agent_id(off_policy_actions) + + def send_actions(self, action_dict): + for eid, action in action_dict.items(): + self.serving_env._episodes[eid].action_queue.put( + action[_DUMMY_AGENT_ID]) + + class _VectorEnvToAsync(AsyncVectorEnv): - """Wraps VectorEnv to implement AsyncVectorEnv. + """Internal adapter of VectorEnv to AsyncVectorEnv. We assume the caller will always send the full vector of actions in each call to send_actions(), and that they call reset_at() on all completed @@ -99,17 +209,104 @@ class _VectorEnvToAsync(AsyncVectorEnv): self.cur_rewards = [] self.cur_dones = [] self.cur_infos = [] - return new_obs, rewards, dones, infos, {} + return _with_dummy_agent_id(new_obs), \ + _with_dummy_agent_id(rewards), \ + _with_dummy_agent_id(dones, "__all__"), \ + _with_dummy_agent_id(infos), {} def send_actions(self, action_dict): action_vector = [None] * self.num_envs for i in range(self.num_envs): - action_vector[i] = action_dict[i] + action_vector[i] = action_dict[i][_DUMMY_AGENT_ID] self.new_obs, self.cur_rewards, self.cur_dones, self.cur_infos = \ self.vector_env.vector_step(action_vector) - def try_reset(self, agent_id): - return self.vector_env.reset_at(agent_id) + def try_reset(self, env_id): + return {_DUMMY_AGENT_ID: self.vector_env.reset_at(env_id)} def get_unwrapped(self): return self.vector_env.get_unwrapped() + + +class _MultiAgentEnvToAsync(AsyncVectorEnv): + """Internal adapter of MultiAgentEnv to AsyncVectorEnv. + + This also supports vectorization if num_envs > 1. + """ + + def __init__(self, make_env, existing_envs, num_envs): + """Wrap existing multi-agent envs. + + Arguments: + make_env (func|None): Factory that produces a new multiagent env. + Must be defined if the number of existing envs is less than + num_envs. + existing_envs (list): List of existing multiagent envs. + num_envs (int): Desired num multiagent envs to keep total. + """ + self.make_env = make_env + self.envs = existing_envs + self.num_envs = num_envs + self.dones = set() + while len(self.envs) < self.num_envs: + self.envs.append(self.make_env()) + for env in self.envs: + assert isinstance(env, MultiAgentEnv) + self.env_states = [_MultiAgentEnvState(env) for env in self.envs] + + def poll(self): + obs, rewards, dones, infos = {}, {}, {}, {} + for i, env_state in enumerate(self.env_states): + obs[i], rewards[i], dones[i], infos[i] = env_state.poll() + return obs, rewards, dones, infos, {} + + def send_actions(self, action_dict): + for env_id, agent_dict in action_dict.items(): + if env_id in self.dones: + raise ValueError("Env {} is already done".format(env_id)) + env = self.envs[env_id] + obs, rewards, dones, infos = env.step(agent_dict) + if dones["__all__"]: + self.dones.add(env_id) + self.env_states[env_id].observe(obs, rewards, dones, infos) + + def try_reset(self, env_id): + obs = self.env_states[env_id].reset() + if obs is not None: + self.dones.remove(env_id) + return obs + + +class _MultiAgentEnvState(object): + def __init__(self, env): + assert isinstance(env, MultiAgentEnv) + self.env = env + self.reset() + + def poll(self): + if self.last_obs is None: + raise ValueError("Need to send action after polling") + obs, rew, dones, info = ( + self.last_obs, self.last_rewards, self.last_dones, self.last_infos) + self.last_obs = None + self.last_rewards = None + self.last_dones = None + self.last_infos = None + return obs, rew, dones, info + + def observe(self, obs, rewards, dones, infos): + self.last_obs = obs + self.last_rewards = rewards + self.last_dones = dones + self.last_infos = infos + + def reset(self): + self.last_obs = self.env.reset() + self.last_rewards = { + agent_id: None for agent_id in self.last_obs.keys()} + self.last_dones = { + agent_id: False for agent_id in self.last_obs.keys()} + self.last_infos = { + agent_id: {} for agent_id in self.last_obs.keys()} + self.last_dones["__all__"] = False + return self.last_obs diff --git a/python/ray/rllib/utils/common_policy_evaluator.py b/python/ray/rllib/utils/common_policy_evaluator.py index 95be18fce..c5b0e1e03 100644 --- a/python/ray/rllib/utils/common_policy_evaluator.py +++ b/python/ray/rllib/utils/common_policy_evaluator.py @@ -8,14 +8,15 @@ import tensorflow as tf import ray from ray.rllib.models import ModelCatalog -from ray.rllib.optimizers import SampleBatch +from ray.rllib.optimizers import MultiAgentBatch from ray.rllib.optimizers.policy_evaluator import PolicyEvaluator -from ray.rllib.utils.async_vector_env import AsyncVectorEnv, _VectorEnvToAsync +from ray.rllib.utils.async_vector_env import AsyncVectorEnv from ray.rllib.utils.atari_wrappers import wrap_deepmind, is_atari from ray.rllib.utils.compression import pack from ray.rllib.utils.filter import get_filter +from ray.rllib.utils.multi_agent_env import MultiAgentEnv from ray.rllib.utils.sampler import AsyncSampler, SyncSampler -from ray.rllib.utils.serving_env import ServingEnv, _ServingEnvToAsync +from ray.rllib.utils.serving_env import ServingEnv from ray.rllib.utils.tf_policy_graph import TFPolicyGraph from ray.rllib.utils.vector_env import VectorEnv from ray.tune.result import TrainingResult @@ -152,6 +153,7 @@ class CommonPolicyEvaluator(PolicyEvaluator): self.env = env_creator(env_config) if isinstance(self.env, VectorEnv) or \ isinstance(self.env, ServingEnv) or \ + isinstance(self.env, MultiAgentEnv) or \ isinstance(self.env, AsyncVectorEnv): def wrap(env): return env # we can't auto-wrap these env types @@ -169,8 +171,6 @@ class CommonPolicyEvaluator(PolicyEvaluator): def make_env(): return wrap(env_creator(env_config)) - self.policy_map = {} - if issubclass(policy_graph, TFPolicyGraph): with tf.Graph().as_default(): if tf_session_creator: @@ -186,25 +186,21 @@ class CommonPolicyEvaluator(PolicyEvaluator): policy = policy_graph( self.env.observation_space, self.env.action_space, policy_config) + self.policy_map = { "default": policy } - self.obs_filter = get_filter( - observation_filter, self.env.observation_space.shape) - self.filters = {"obs_filter": self.obs_filter} + self.filters = { + # TODO(ekl) make the obs space dependent on policy + policy_id: get_filter( + observation_filter, self.env.observation_space.shape) + for (policy_id, policy) in self.policy_map.items() + } # Always use vector env for consistency even if num_envs = 1 - if not isinstance(self.env, AsyncVectorEnv): - if isinstance(self.env, ServingEnv): - self.vector_env = _ServingEnvToAsync(self.env) - else: - if not isinstance(self.env, VectorEnv): - self.env = VectorEnv.wrap( - make_env, [self.env], num_envs=num_envs) - self.vector_env = _VectorEnvToAsync(self.env) - else: - self.vector_env = self.env + self.async_env = AsyncVectorEnv.wrap_async( + self.env, make_env=make_env, num_envs=num_envs) if self.batch_mode == "truncate_episodes": if batch_steps % num_envs != 0: @@ -222,19 +218,21 @@ class CommonPolicyEvaluator(PolicyEvaluator): "Unsupported batch mode: {}".format(self.batch_mode)) if sample_async: self.sampler = AsyncSampler( - self.vector_env, self.policy_map["default"], self.obs_filter, - batch_steps, horizon=episode_horizon, pack=pack_episodes) + self.async_env, self.policy_map, lambda agent_id: "default", + self.filters, batch_steps, horizon=episode_horizon, + pack=pack_episodes) self.sampler.start() else: self.sampler = SyncSampler( - self.vector_env, self.policy_map["default"], self.obs_filter, - batch_steps, horizon=episode_horizon, pack=pack_episodes) + self.async_env, self.policy_map, lambda agent_id: "default", + self.filters, batch_steps, horizon=episode_horizon, + pack=pack_episodes) def sample(self): """Evaluate the current policies and return a batch of experiences. Return: - SampleBatch from evaluating the current policies. + SampleBatch|MultiAgentBatch from evaluating the current policies. """ batches = [self.sampler.get_data()] @@ -243,11 +241,16 @@ class CommonPolicyEvaluator(PolicyEvaluator): batch = self.sampler.get_data() steps_so_far += batch.count batches.append(batch) - batch = SampleBatch.concat_samples(batches) + batch = batches[0].concat_samples(batches) if self.compress_observations: - batch["obs"] = [pack(o) for o in batch["obs"]] - batch["new_obs"] = [pack(o) for o in batch["new_obs"]] + if isinstance(batch, MultiAgentBatch): + for data in batch.policy_batches.values(): + data["obs"] = [pack(o) for o in data["obs"]] + data["new_obs"] = [pack(o) for o in data["new_obs"]] + else: + batch["obs"] = [pack(o) for o in batch["obs"]] + batch["new_obs"] = [pack(o) for o in batch["new_obs"]] return batch diff --git a/python/ray/rllib/utils/multi_agent_env.py b/python/ray/rllib/utils/multi_agent_env.py new file mode 100644 index 000000000..9a3015fff --- /dev/null +++ b/python/ray/rllib/utils/multi_agent_env.py @@ -0,0 +1,60 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +class MultiAgentEnv(object): + """An environment that hosts multiple independent agents. + + Agents are identified by (string) agent ids. + + Examples: + >>> env = MyMultiAgentEnv() + >>> obs = env.reset() + >>> print(obs) + { + "car_0": [2.4, 1.6], + "car_1": [3.4, -3.2], + "traffic_light_1": [0, 3, 5, 1], + } + >>> obs, rewards, dones, infos = env.step( + action_dict={ + "car_0": 1, "car_1": 0, "traffic_light_1": 2, + }) + >>> print(rewards) + { + "car_0": 3, + "car_1": -1, + "traffic_light_1": 0, + } + >>> print(dones) + { + "car_0": False, + "car_1": True, + "__all__": False, + } + """ + + def reset(self): + """Resets the env and returns observations from ready agents. + + Returns: + obs (dict): New observations for each ready agent. + """ + raise NotImplementedError + + def step(self, action_dict): + """Returns observations from ready agents. + + The returns are dicts mapping from agent_id strings to values. The + number of agents in the env can vary over time. + + Returns: + obs (dict): New observations for each ready agent. + rewards (dict): Reward values for each ready agent. If the + episode is just started, the value will be None. + dones (dict): Done values for each ready agent. The special key + "__all__" is used to indicate env termination. + infos (dict): Info values for each ready agent. + """ + raise NotImplementedError diff --git a/python/ray/rllib/utils/policy_graph.py b/python/ray/rllib/utils/policy_graph.py index 45f48684c..91272a75a 100644 --- a/python/ray/rllib/utils/policy_graph.py +++ b/python/ray/rllib/utils/policy_graph.py @@ -25,7 +25,9 @@ class PolicyGraph(object): action_space (gym.Space): Action space of the env. config (dict): Policy-specific configuration data. """ - pass + + self.observation_space = observation_space + self.action_space = action_space def compute_actions(self, obs_batch, state_batches, is_training=False): """Compute actions for the current policy. @@ -70,8 +72,9 @@ class PolicyGraph(object): Arguments: sample_batch (SampleBatch): batch of experiences for the policy, which will contain at most one episode trajectory. - other_agent_batches (dict): In a multi-agent env, this contains the - experience batches seen by other agents. + other_agent_batches (dict): In a multi-agent env, this contains a + mapping of agent ids to (policy_graph, agent_batch) tuples + containing the policy graph and experiences of the other agent. Returns: SampleBatch: postprocessed sample batch. diff --git a/python/ray/rllib/utils/sampler.py b/python/ray/rllib/utils/sampler.py index 3a3364f8e..0a0aa36a2 100644 --- a/python/ray/rllib/utils/sampler.py +++ b/python/ray/rllib/utils/sampler.py @@ -7,13 +7,16 @@ import numpy as np import six.moves.queue as queue import threading -from ray.rllib.optimizers.sample_batch import SampleBatchBuilder -from ray.rllib.utils.vector_env import VectorEnv -from ray.rllib.utils.async_vector_env import AsyncVectorEnv, _VectorEnvToAsync +from ray.rllib.optimizers.sample_batch import MultiAgentSampleBatchBuilder, \ + MultiAgentBatch +from ray.rllib.utils.async_vector_env import AsyncVectorEnv -CompletedRollout = namedtuple("CompletedRollout", - ["episode_length", "episode_reward"]) +RolloutMetrics = namedtuple( + "RolloutMetrics", ["episode_length", "episode_reward", "agent_rewards"]) + +PolicyEvalData = namedtuple( + "PolicyEvalData", ["env_id", "agent_id", "obs", "rnn_state"]) class SyncSampler(object): @@ -26,26 +29,23 @@ class SyncSampler(object): thread.""" def __init__( - self, env, policy, obs_filter, num_local_steps, - horizon=None, pack=False): - if not isinstance(env, AsyncVectorEnv): - if not isinstance(env, VectorEnv): - env = VectorEnv.wrap(make_env=None, existing_envs=[env]) - env = _VectorEnvToAsync(env) - self.async_vector_env = env + self, env, policies, policy_mapping_fn, obs_filters, + num_local_steps, horizon=None, pack=False): + self.async_vector_env = AsyncVectorEnv.wrap_async(env) self.num_local_steps = num_local_steps self.horizon = horizon - self.policy = policy - self._obs_filter = obs_filter - self.rollout_provider = _env_runner(self.async_vector_env, self.policy, - self.num_local_steps, self.horizon, - self._obs_filter, pack) + self.policies = policies + self.policy_mapping_fn = policy_mapping_fn + self._obs_filters = obs_filters + self.rollout_provider = _env_runner( + self.async_vector_env, self.policies, self.policy_mapping_fn, + self.num_local_steps, self.horizon, self._obs_filters, pack) self.metrics_queue = queue.Queue() def get_data(self): while True: item = next(self.rollout_provider) - if isinstance(item, CompletedRollout): + if isinstance(item, RolloutMetrics): self.metrics_queue.put(item) else: return item @@ -67,23 +67,20 @@ class AsyncSampler(threading.Thread): accumulate and the gradient can be calculated on up to 5 batches.""" def __init__( - self, env, policy, obs_filter, num_local_steps, - horizon=None, pack=False): - assert getattr( - obs_filter, "is_concurrent", - False), ("Observation Filter must support concurrent updates.") - if not isinstance(env, AsyncVectorEnv): - if not isinstance(env, VectorEnv): - env = VectorEnv.wrap(make_env=None, existing_envs=[env]) - env = _VectorEnvToAsync(env) - self.async_vector_env = env + self, env, policies, policy_mapping_fn, obs_filters, + num_local_steps, horizon=None, pack=False): + for _, f in obs_filters.items(): + assert getattr(f, "is_concurrent", False), \ + "Observation Filter must support concurrent updates." + self.async_vector_env = AsyncVectorEnv.wrap_async(env) threading.Thread.__init__(self) self.queue = queue.Queue(5) self.metrics_queue = queue.Queue() self.num_local_steps = num_local_steps self.horizon = horizon - self.policy = policy - self._obs_filter = obs_filter + self.policies = policies + self.policy_mapping_fn = policy_mapping_fn + self._obs_filters = obs_filters self.daemon = True self.pack = pack @@ -95,15 +92,15 @@ class AsyncSampler(threading.Thread): raise e def _run(self): - rollout_provider = _env_runner(self.async_vector_env, self.policy, - self.num_local_steps, self.horizon, - self._obs_filter, self.pack) + rollout_provider = _env_runner( + self.async_vector_env, self.policies, self.policy_mapping_fn, + self.num_local_steps, self.horizon, self._obs_filters, self.pack) while True: # The timeout variable exists because apparently, if one worker # dies, the other workers won't die with it, unless the timeout is # set to some large number. This is an empirical observation. item = next(rollout_provider) - if isinstance(item, CompletedRollout): + if isinstance(item, RolloutMetrics): self.metrics_queue.put(item) else: self.queue.put(item, timeout=600.0) @@ -115,8 +112,9 @@ class AsyncSampler(threading.Thread): if isinstance(rollout, BaseException): raise rollout - # We can't auto-concat rollouts in vector mode - if self.async_vector_env.num_envs > 1: + # We can't auto-concat rollouts in these modes + if self.async_vector_env.num_envs > 1 or \ + isinstance(rollout, MultiAgentBatch): return rollout # Auto-concat rollouts; TODO(ekl) is this important for A3C perf? @@ -141,23 +139,22 @@ class AsyncSampler(threading.Thread): def _env_runner( - async_vector_env, policy, num_local_steps, horizon, obs_filter, pack): - """This implements the logic of the thread runner. - - It continually runs the policy, and as long as the rollout exceeds a - certain length, the thread runner appends the policy to the queue. Yields - when `timestep_limit` is surpassed, environment terminates, or - `num_local_steps` is reached. + async_vector_env, policies, policy_mapping_fn, num_local_steps, + horizon, obs_filters, pack): + """This implements the common experience collection logic. Args: - async_vector_env: env implementing AsyncVectorEnv. - policy: Policy used to interact with environment. Also sets fields - to be included in `SampleBatch`. - num_local_steps: Number of steps before `SampleBatch` is yielded. Set - to infinity to yield complete episodes. - horizon: Horizon of the episode. - obs_filter: Filter used to process observations. - pack: Whether to pack multiple episodes into each batch. This + async_vector_env (AsyncVectorEnv): env implementing AsyncVectorEnv. + policies (dict): Map of policy ids to PolicyGraph instances. + policy_mapping_fn (func): Function that maps agent ids to policy ids. + This is called when an agent first enters the environment. The + agent is then "bound" to the returned policy for the episode. + num_local_steps (int): Number of episode steps before `SampleBatch` is + yielded. Set to infinity to yield complete episodes. + horizon (int): Horizon of the episode. + obs_filters (dict): Map of policy id to filter used to process + observations for the policy. + pack (bool): Whether to pack multiple episodes into each batch. This guarantees batches will be exactly `num_local_steps` in size. Yields: @@ -181,110 +178,131 @@ def _env_runner( if batch_builder_pool: return batch_builder_pool.pop() else: - return SampleBatchBuilder() + return MultiAgentSampleBatchBuilder(policies) - episodes = defaultdict( - lambda: _Episode(policy.get_initial_state(), get_batch_builder)) + active_episodes = defaultdict( + lambda: _MultiAgentEpisode( + policies, policy_mapping_fn, get_batch_builder)) while True: - # Get observations from ready envs + # Get observations from all ready agents unfiltered_obs, rewards, dones, infos, off_policy_actions = \ async_vector_env.poll() - ready_eids = [] - ready_obs = [] - ready_rnn_states = [] - # Process and record the new observations - for eid, raw_obs in unfiltered_obs.items(): - episode = episodes[eid] - filtered_obs = obs_filter(raw_obs) - ready_eids.append(eid) - ready_obs.append(filtered_obs) - ready_rnn_states.append(episode.rnn_state) + # Map of policy_id to list of PolicyEvalData + to_eval = defaultdict(list) - if episode.last_observation is None: - episode.last_observation = filtered_obs - continue # This is the initial observation after a reset + # For each environment + for env_id, agent_obs in unfiltered_obs.items(): + new_episode = env_id not in active_episodes + episode = active_episodes[env_id] + if not new_episode: + episode.length += 1 + episode.batch_builder.count += 1 + episode.add_agent_rewards(rewards[env_id]) - episode.length += 1 - episode.total_reward += rewards[eid] - - # Handle episode terminations - if dones[eid] or episode.length >= horizon: - done = True - yield CompletedRollout(episode.length, episode.total_reward) + # Check episode termination conditions + if dones[env_id]["__all__"] or episode.length >= horizon: + all_done = True + yield RolloutMetrics( + episode.length, episode.total_reward, + dict(episode.agent_rewards)) else: - done = False + all_done = False - if infos[eid].get("training_enabled", True): - episode.batch_builder.add_values( - obs=episode.last_observation, - actions=episode.last_action_flat(), - rewards=rewards[eid], - dones=done, - new_obs=filtered_obs, - **episode.last_pi_info) + # For each agent in the environment + for agent_id, raw_obs in agent_obs.items(): + policy_id = episode.policy_for(agent_id) + filtered_obs = obs_filters[policy_id](raw_obs) + agent_done = bool(all_done or dones[env_id].get(agent_id)) + if not agent_done: + to_eval[policy_id].append( + PolicyEvalData( + env_id, agent_id, filtered_obs, + episode.rnn_state_for(agent_id))) - # Cut the batch if we're not packing multiple episodes into - # one, or if we've exceeded the requested batch size. - if (done and not pack) or \ + last_observation = episode.last_observation_for(agent_id) + episode.set_last_observation(agent_id, filtered_obs) + + # Record transition info if applicable + if last_observation is not None and \ + infos[env_id][agent_id].get("training_enabled", True): + episode.batch_builder.add_values( + agent_id, + policy_id, + t=episode.length - 1, + obs=last_observation, + actions=episode.last_action_for(agent_id), + rewards=rewards[env_id][agent_id], + dones=agent_done, + infos=infos[env_id][agent_id], + new_obs=filtered_obs, + **episode.last_pi_info_for(agent_id)) + + # Cut the batch if we're not packing multiple episodes into one, + # or if we've exceeded the requested batch size. + if episode.batch_builder.has_pending_data(): + if (all_done and not pack) or \ episode.batch_builder.count >= num_local_steps: - yield episode.batch_builder.build_and_reset( - policy.postprocess_trajectory) - elif done: - # Make sure postprocessor never crosses episode boundaries - episode.batch_builder.postprocess_batch_so_far( - policy.postprocess_trajectory) + yield episode.batch_builder.build_and_reset() + elif all_done: + # Make sure postprocessor stays within one episode + episode.batch_builder.postprocess_batch_so_far() - if done: + if all_done: # Handle episode termination batch_builder_pool.append(episode.batch_builder) - del episodes[eid] - resetted_obs = async_vector_env.try_reset(eid) + del active_episodes[env_id] + resetted_obs = async_vector_env.try_reset(env_id) if resetted_obs is None: # Reset not supported, drop this env from the ready list assert horizon == float("inf"), \ "Setting episode horizon requires reset() support." - ready_eids.pop() - ready_obs.pop() - ready_rnn_states.pop() else: - # Reset successful, put in the new obs as ready - episode = episodes[eid] - episode.last_observation = obs_filter(resetted_obs) - ready_obs[-1] = episode.last_observation - ready_rnn_states[-1] = episode.rnn_state - else: - episode.last_observation = filtered_obs + # Creates a new episode + episode = active_episodes[env_id] + for agent_id, raw_obs in resetted_obs.items(): + policy_id = episode.policy_for(agent_id) + filtered_obs = obs_filters[policy_id](raw_obs) + episode.set_last_observation(agent_id, filtered_obs) + to_eval[policy_id].append( + PolicyEvalData( + env_id, agent_id, filtered_obs, + episode.rnn_state_for(agent_id))) - if not ready_eids: - continue # No actions to take + # Map of env_id -> agent_id -> action + action_dict = defaultdict(dict) - # Compute action for ready envs - ready_rnn_state_cols = _to_column_format(ready_rnn_states) - actions, new_rnn_state_cols, pi_info_cols = policy.compute_actions( - ready_obs, ready_rnn_state_cols, is_training=True) - - # Add RNN state info - for f_i, column in enumerate(ready_rnn_state_cols): - pi_info_cols["state_in_{}".format(f_i)] = column - for f_i, column in enumerate(new_rnn_state_cols): - pi_info_cols["state_out_{}".format(f_i)] = column + # TODO(ekl) fuse all policy evaluation into one TF run + for policy_id, eval_data in to_eval.items(): + rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data]) + actions, rnn_out_cols, pi_info_cols = \ + policies[policy_id].compute_actions( + [t.obs for t in eval_data], rnn_in_cols, is_training=True) + # Add RNN state info + for f_i, column in enumerate(rnn_in_cols): + pi_info_cols["state_in_{}".format(f_i)] = column + for f_i, column in enumerate(rnn_out_cols): + pi_info_cols["state_out_{}".format(f_i)] = column + # Save output rows + for i, action in enumerate(actions): + env_id = eval_data[i].env_id + agent_id = eval_data[i].agent_id + action_dict[env_id][agent_id] = action + episode = active_episodes[env_id] + episode.set_rnn_state(agent_id, [c[i] for c in rnn_out_cols]) + episode.set_last_pi_info( + agent_id, {k: v[i] for k, v in pi_info_cols.items()}) + if env_id in off_policy_actions and \ + agent_id in off_policy_actions[env_id]: + episode.set_last_action( + agent_id, off_policy_actions[env_id][agent_id]) + else: + episode.set_last_action(agent_id, action) # Return computed actions to ready envs. We also send to envs that have # taken off-policy actions; those envs are free to ignore the action. - async_vector_env.send_actions(dict(zip(ready_eids, actions))) - - # Store the computed action info - for i, eid in enumerate(ready_eids): - episode = episodes[eid] - if eid in off_policy_actions: - episode.last_action = off_policy_actions[eid] - else: - episode.last_action = actions[i] - episode.rnn_state = [column[i] for column in new_rnn_state_cols] - episode.last_pi_info = { - k: column[i] for k, column in pi_info_cols.items()} + async_vector_env.send_actions(dict(action_dict)) def _to_column_format(rnn_state_rows): @@ -293,18 +311,57 @@ def _to_column_format(rnn_state_rows): [row[i] for row in rnn_state_rows] for i in range(num_cols)] -class _Episode(object): - def __init__(self, init_rnn_state, batch_builder_factory): - self.rnn_state = init_rnn_state +class _MultiAgentEpisode(object): + def __init__(self, policies, policy_mapping_fn, batch_builder_factory): self.batch_builder = batch_builder_factory() - self.last_action = None - self.last_observation = None - self.last_pi_info = None self.total_reward = 0.0 self.length = 0 + self.agent_rewards = defaultdict(float) + self._policies = policies + self._policy_mapping_fn = policy_mapping_fn + self._agent_to_policy = {} + self._agent_to_rnn_state = {} + self._agent_to_last_obs = {} + self._agent_to_last_action = {} + self._agent_to_last_pi_info = {} - def last_action_flat(self): - # Concatenate multiagent actions - if isinstance(self.last_action, list): - return np.concatenate(self.last_action, axis=0).flatten() - return self.last_action + def add_agent_rewards(self, reward_dict): + for agent_id, reward in reward_dict.items(): + self.agent_rewards[agent_id] += reward + self.total_reward += reward + + def policy_for(self, agent_id): + if agent_id not in self._agent_to_policy: + self._agent_to_policy[agent_id] = self._policy_mapping_fn(agent_id) + return self._agent_to_policy[agent_id] + + def rnn_state_for(self, agent_id): + if agent_id not in self._agent_to_rnn_state: + policy = self._policies[self.policy_for(agent_id)] + self._agent_to_rnn_state[agent_id] = policy.get_initial_state() + return self._agent_to_rnn_state[agent_id] + + def last_observation_for(self, agent_id): + return self._agent_to_last_obs.get(agent_id) + + def last_action_for(self, agent_id): + action = self._agent_to_last_action[agent_id] + # Concatenate tuple actions + if isinstance(action, list): + action = np.concatenate(action, axis=0).flatten() + return action + + def last_pi_info_for(self, agent_id): + return self._agent_to_last_pi_info[agent_id] + + def set_rnn_state(self, agent_id, rnn_state): + self._agent_to_rnn_state[agent_id] = rnn_state + + def set_last_observation(self, agent_id, obs): + self._agent_to_last_obs[agent_id] = obs + + def set_last_action(self, agent_id, action): + self._agent_to_last_action[agent_id] = action + + def set_last_pi_info(self, agent_id, pi_info): + self._agent_to_last_pi_info[agent_id] = pi_info diff --git a/python/ray/rllib/utils/serving_env.py b/python/ray/rllib/utils/serving_env.py index 827a725b3..d3928536b 100644 --- a/python/ray/rllib/utils/serving_env.py +++ b/python/ray/rllib/utils/serving_env.py @@ -6,11 +6,9 @@ from six.moves import queue import threading import uuid -from ray.rllib.utils.async_vector_env import AsyncVectorEnv - class ServingEnv(threading.Thread): - """Environment that provides policy serving. + """An environment that provides policy serving. Unlike simulator envs, control is inverted. The environment queries the policy to obtain actions and logs observations and rewards for training. @@ -91,7 +89,7 @@ class ServingEnv(threading.Thread): raise ValueError( "Episode {} is already started".format(episode_id)) - self._episodes[episode_id] = _Episode( + self._episodes[episode_id] = _ServingEnvEpisode( episode_id, self._results_avail_condition, training_enabled) return episode_id @@ -165,49 +163,7 @@ class ServingEnv(threading.Thread): return self._episodes[episode_id] -class _ServingEnvToAsync(AsyncVectorEnv): - """Internal adapter of ServingEnv to AsyncVectorEnv.""" - - def __init__(self, serving_env): - self.serving_env = serving_env - serving_env.start() - - def poll(self): - with self.serving_env._results_avail_condition: - results = self._poll() - while len(results[0]) == 0: - self.serving_env._results_avail_condition.wait() - results = self._poll() - if not self.serving_env.isAlive(): - raise Exception("Serving thread has stopped.") - limit = self.serving_env._max_concurrent_episodes - assert len(results[0]) < limit, \ - ("Too many concurrent episodes, were some leaked? This ServingEnv " - "was created with max_concurrent={}".format(limit)) - return results - - def _poll(self): - all_obs, all_rewards, all_dones, all_infos = {}, {}, {}, {} - off_policy_actions = {} - for eid, episode in self.serving_env._episodes.copy().items(): - data = episode.get_data() - if episode.cur_done: - del self.serving_env._episodes[eid] - if data: - all_obs[eid] = data["obs"] - all_rewards[eid] = data["reward"] - all_dones[eid] = data["done"] - all_infos[eid] = data["info"] - if "off_policy_action" in data: - off_policy_actions[eid] = data["off_policy_action"] - return all_obs, all_rewards, all_dones, all_infos, off_policy_actions - - def send_actions(self, action_dict): - for eid, action in action_dict.items(): - self.serving_env._episodes[eid].action_queue.put(action) - - -class _Episode(object): +class _ServingEnvEpisode(object): """Tracked state for each active episode.""" def __init__(self, episode_id, results_avail_condition, training_enabled): diff --git a/python/ray/rllib/utils/vector_env.py b/python/ray/rllib/utils/vector_env.py index e9d655ba9..926048c48 100644 --- a/python/ray/rllib/utils/vector_env.py +++ b/python/ray/rllib/utils/vector_env.py @@ -22,22 +22,44 @@ class VectorEnv(object): return _VectorizedGymEnv(make_env, existing_envs or [], num_envs) def vector_reset(self): + """Resets all environments. + + Returns: + obs (list): Vector of observations from each environment. + """ raise NotImplementedError def reset_at(self, index): + """Resets a single environment. + + Returns: + obs (obj): Observations from the resetted environment. + """ raise NotImplementedError def vector_step(self, actions): + """Vectorized step. + + Arguments: + actions (list): Actions for each env. + + Returns: + obs (list): New observations for each env. + rewards (list): Reward values for each env. + dones (list): Done values for each env. + infos (list): Info values for each env. + """ raise NotImplementedError def get_unwrapped(self): + """Returns a single instance of the underlying env.""" raise NotImplementedError class _VectorizedGymEnv(VectorEnv): """Internal wrapper for gym envs to implement VectorEnv. - Arguents: + Arguments: make_env (func|None): Factory that produces a new gym env. Must be defined if the number of existing envs is less than num_envs. existing_envs (list): List of existing gym envs.