[rllib] Part 1 of multiagent support: make sampler path support multiagent envs (#2268)

This refactors the RLlib sampler to support multi-agent environments. The main changes were:

AsyncVectorEnv now produces dicts of env_id -> agent_id -> value rather than env_id -> value. This lets it model both vectorized and multi-agent envs (or both).
The sampler class operates over the above nested dict structure for all envs. Single agent envs just return a dict with one agent_id=single_agent.
When sample() is called on a policy evaluator, in the single agent case we return a SampleBatch, otherwise we return a MultiAgentBatch (which holds one sample batch per policy, keyed by policy id).
Left for another PR:

Exposing multi-agent in the public interfaces.
Optimizations such as evaluating multiple policies in one TF run.
This commit is contained in:
Eric Liang
2018-06-23 18:32:16 -07:00
committed by GitHub
parent 9c3bab5c42
commit 0b6112b726
16 changed files with 849 additions and 277 deletions
+2 -1
View File
@@ -154,7 +154,8 @@ class A3CAgent(Agent):
def compute_action(self, observation, state=None):
if state is None:
state = []
obs = self.local_evaluator.obs_filter(observation, update=False)
obs = self.local_evaluator.filters["default"](
observation, update=False)
return self.local_evaluator.for_policy(
lambda p: p.compute_single_action(
obs, state, is_training=False)[0])
+1
View File
@@ -18,6 +18,7 @@ class SharedTorchPolicy(PolicyGraph):
"""A simple, non-recurrent PyTorch policy example."""
def __init__(self, obs_space, action_space, config):
PolicyGraph.__init__(self, obs_space, action_space, config)
self.local_steps = 0
self.config = config
self.summarize = config.get("summarize")
+2 -2
View File
@@ -3,7 +3,7 @@ from ray.rllib.optimizers.async_optimizer import AsyncOptimizer
from ray.rllib.optimizers.local_sync import LocalSyncOptimizer
from ray.rllib.optimizers.local_sync_replay import LocalSyncReplayOptimizer
from ray.rllib.optimizers.multi_gpu import LocalMultiGPUOptimizer
from ray.rllib.optimizers.sample_batch import SampleBatch
from ray.rllib.optimizers.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.optimizers.policy_evaluator import PolicyEvaluator, \
TFMultiGPUSupport
@@ -11,4 +11,4 @@ from ray.rllib.optimizers.policy_evaluator import PolicyEvaluator, \
__all__ = [
"ApexOptimizer", "AsyncOptimizer", "LocalSyncOptimizer",
"LocalSyncReplayOptimizer", "LocalMultiGPUOptimizer", "SampleBatch",
"PolicyEvaluator", "TFMultiGPUSupport"]
"PolicyEvaluator", "TFMultiGPUSupport", "MultiAgentBatch"]
+119 -18
View File
@@ -14,7 +14,6 @@ class SampleBatchBuilder(object):
"""
def __init__(self):
self.postprocessed = []
self.buffers = collections.defaultdict(list)
self.count = 0
@@ -25,29 +24,131 @@ class SampleBatchBuilder(object):
self.buffers[k].append(v)
self.count += 1
def postprocess_batch_so_far(self, postprocessor):
"""Apply the given postprocessor to any unprocessed rows."""
def add_batch(self, batch):
"""Add the given batch of values to this batch."""
batch = postprocessor(self._build_buffers())
self.postprocessed.append(batch)
for k, column in batch.items():
self.buffers[k].extend(column)
self.count += batch.count
def build_and_reset(self, postprocessor):
"""Returns a sample batch including all previously added values.
def build_and_reset(self):
"""Returns a sample batch including all previously added values."""
Any unprocessed rows will be first postprocessed with the given
postprocessor. The internal state of this builder will be reset.
"""
self.postprocess_batch_so_far(postprocessor)
batch = SampleBatch.concat_samples(self.postprocessed)
self.postprocessed = []
batch = SampleBatch({k: np.array(v) for k, v in self.buffers.items()})
self.buffers.clear()
self.count = 0
return batch
def _build_buffers(self):
batch = SampleBatch({k: np.array(v) for k, v in self.buffers.items()})
self.buffers.clear()
return batch
class MultiAgentSampleBatchBuilder(object):
    """Accumulates per-agent rows and emits per-policy SampleBatches.

    Rows arrive keyed by agent, but the result is keyed by policy; several
    agents may share one policy. Each agent gets its own SampleBatchBuilder.
    When postprocessing runs, every agent's rows are postprocessed by that
    agent's policy and appended to the builder of the policy.
    """

    def __init__(self, policy_map):
        """Create an empty builder.

        Arguments:
            policy_map (dict): Maps policy ids to policy graph instances.
        """
        self.policy_map = policy_map
        self.policy_builders = {}
        for policy_id in policy_map:
            self.policy_builders[policy_id] = SampleBatchBuilder()
        self.agent_builders = {}
        self.agent_to_policy = {}
        self.count = 0  # increment this manually

    def has_pending_data(self):
        """Return True if some agent rows have not been postprocessed yet."""
        return bool(self.agent_builders)

    def add_values(self, agent_id, policy_id, **values):
        """Append one row of values for the given agent.

        Arguments:
            agent_id (obj): Unique id for the agent we are adding values for.
            policy_id (obj): Unique id for policy controlling the agent.
            values (dict): Row of values to add for this agent.
        """
        if agent_id in self.agent_builders:
            row_builder = self.agent_builders[agent_id]
        else:
            # First row for this agent: remember which policy controls it.
            row_builder = SampleBatchBuilder()
            self.agent_builders[agent_id] = row_builder
            self.agent_to_policy[agent_id] = policy_id
        row_builder.add_values(**values)

    def postprocess_batch_so_far(self):
        """Run the policy postprocessors over all pending agent rows.

        The postprocessed per-agent batches are moved onto the per-policy
        builders, and all per-agent state is cleared afterwards.
        """
        # Materialize the pending rows as agent_id -> (policy, batch)
        pre_batches = {
            agent_id: (
                self.policy_map[self.agent_to_policy[agent_id]],
                agent_builder.build_and_reset())
            for agent_id, agent_builder in self.agent_builders.items()
        }

        # Postprocess each agent's batch, exposing the other agents' batches
        post_batches = {}
        for agent_id, (policy, pre_batch) in pre_batches.items():
            others = {
                other_id: entry
                for other_id, entry in pre_batches.items()
                if other_id != agent_id
            }
            post_batches[agent_id] = policy.postprocess_trajectory(
                pre_batch, others)

        # Hand the results over to the per-policy builders, then reset
        for agent_id, post_batch in post_batches.items():
            policy_id = self.agent_to_policy[agent_id]
            self.policy_builders[policy_id].add_batch(post_batch)

        self.agent_builders.clear()
        self.agent_to_policy.clear()

    def build_and_reset(self):
        """Return everything accumulated so far, grouped by policy.

        Pending agent rows are postprocessed first. The builder is empty
        again once this returns.
        """
        self.postprocess_batch_so_far()
        policy_batches = {
            policy_id: policy_builder.build_and_reset()
            for policy_id, policy_builder in self.policy_builders.items()
        }
        self.count = 0
        return MultiAgentBatch.wrap_as_needed(policy_batches)
class MultiAgentBatch(object):
    """A set of sample batches, one per policy, keyed by policy id."""

    def __init__(self, policy_batches):
        self.policy_batches = policy_batches

    @staticmethod
    def wrap_as_needed(batches):
        """Return the bare batch for the single "default" policy case.

        Any other mapping is wrapped in a MultiAgentBatch.
        """
        only_default = len(batches) == 1 and "default" in batches
        if not only_default:
            return MultiAgentBatch(batches)
        return batches["default"]

    @staticmethod
    def concat_samples(samples):
        """Concatenate several MultiAgentBatches policy by policy."""
        grouped = collections.defaultdict(list)
        for sample in samples:
            assert isinstance(sample, MultiAgentBatch)
            for policy_id, batch in sample.policy_batches.items():
                grouped[policy_id].append(batch)
        return MultiAgentBatch({
            policy_id: SampleBatch.concat_samples(parts)
            for policy_id, parts in grouped.items()
        })
class SampleBatch(object):
+1 -1
View File
@@ -88,7 +88,7 @@ class ProximalPolicyGraph(object):
feed_dict={self.observations: observations})
return action, [], {"vf_preds": vf, "logprobs": logprobs}
def postprocess_trajectory(self, batch):
def postprocess_trajectory(self, batch, other_agent_batches=None):
return batch
def get_initial_state(self):
+3 -2
View File
@@ -79,8 +79,9 @@ class PPOEvaluator(TFMultiGPUSupport):
self.filters = {"obs_filter": self.obs_filter,
"rew_filter": self.rew_filter}
self.sampler = SyncSampler(
self.env, self.common_policy, self.obs_filter,
self.config["horizon"], self.config["horizon"])
self.env, {"default": self.common_policy}, lambda _: "default",
{"default": self.obs_filter}, self.config["horizon"],
self.config["horizon"])
def tf_loss_inputs(self):
return self.inputs
@@ -20,7 +20,7 @@ class MockPolicyGraph(PolicyGraph):
def compute_actions(self, obs_batch, state_batches, is_training=False):
return [0] * len(obs_batch), [], {}
def postprocess_trajectory(self, batch):
def postprocess_trajectory(self, batch, other_agent_batches=None):
return compute_advantages(batch, 100.0, 0.9, use_gae=False)
@@ -28,7 +28,7 @@ class BadPolicyGraph(PolicyGraph):
def compute_actions(self, obs_batch, state_batches, is_training=False):
raise Exception("intentional error")
def postprocess_trajectory(self, batch):
def postprocess_trajectory(self, batch, other_agent_batches=None):
return compute_advantages(batch, 100.0, 0.9, use_gae=False)
@@ -211,6 +211,9 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
batch_mode="complete_episodes")
batch = ev.sample()
self.assertEqual(batch.count, 20)
self.assertEqual(
batch["t"].tolist(),
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
def testFilterSync(self):
ev = CommonPolicyEvaluator(
@@ -221,7 +224,7 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
time.sleep(2)
ev.sample()
filters = ev.get_filters(flush_after=True)
obs_f = filters["obs_filter"]
obs_f = filters["default"]
self.assertNotEqual(obs_f.rs.n, 0)
self.assertNotEqual(obs_f.buffer.n, 0)
@@ -235,8 +238,8 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
filters = ev.get_filters(flush_after=False)
time.sleep(2)
filters2 = ev.get_filters(flush_after=False)
obs_f = filters["obs_filter"]
obs_f2 = filters2["obs_filter"]
obs_f = filters["default"]
obs_f2 = filters2["default"]
self.assertGreaterEqual(obs_f2.rs.n, obs_f.rs.n)
self.assertGreaterEqual(obs_f2.buffer.n, obs_f.buffer.n)
@@ -250,15 +253,15 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
# Current State
filters = ev.get_filters(flush_after=False)
obs_f = filters["obs_filter"]
obs_f = filters["default"]
self.assertLessEqual(obs_f.buffer.n, 20)
new_obsf = obs_f.copy()
new_obsf.rs._n = 100
ev.sync_filters({"obs_filter": new_obsf})
ev.sync_filters({"default": new_obsf})
filters = ev.get_filters(flush_after=False)
obs_f = filters["obs_filter"]
obs_f = filters["default"]
self.assertGreaterEqual(obs_f.rs.n, 100)
self.assertLessEqual(obs_f.buffer.n, 20)
@@ -266,7 +269,7 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
time.sleep(2)
ev.sample()
filters = ev.get_filters(flush_after=True)
obs_f = filters["obs_filter"]
obs_f = filters["default"]
self.assertNotEqual(obs_f.rs.n, 0)
self.assertNotEqual(obs_f.buffer.n, 0)
return obs_f
@@ -0,0 +1,155 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import ray
from ray.rllib.test.test_common_policy_evaluator import MockEnv
from ray.rllib.utils.async_vector_env import _MultiAgentEnvToAsync
from ray.rllib.utils.multi_agent_env import MultiAgentEnv
class BasicMultiAgent(MultiAgentEnv):
    """Env of N independent agents, each of which exits after 25 steps."""

    def __init__(self, num):
        self.agents = []
        for _ in range(num):
            self.agents.append(MockEnv(25))
        self.dones = set()

    def reset(self):
        """Reset every agent and return their initial observations."""
        self.dones = set()
        initial_obs = {}
        for agent_id, agent in enumerate(self.agents):
            initial_obs[agent_id] = agent.reset()
        return initial_obs

    def step(self, action_dict):
        """Step each agent named in action_dict and collect the results."""
        obs, rew, done, info = {}, {}, {}, {}
        for agent_id, action in action_dict.items():
            result = self.agents[agent_id].step(action)
            (obs[agent_id], rew[agent_id],
             done[agent_id], info[agent_id]) = result
            if done[agent_id]:
                self.dones.add(agent_id)
        # The env as a whole is finished once every agent has finished.
        done["__all__"] = len(self.dones) == len(self.agents)
        return obs, rew, done, info
class RoundRobinMultiAgent(MultiAgentEnv):
    """Env of N independent agents, each of which exits after 5 steps.

    On each step() of the env, only one agent takes an action."""

    def __init__(self, num):
        self.agents = [MockEnv(5) for _ in range(num)]
        self.dones = set()
        # Most recent result seen for each agent, keyed by agent index.
        self.last_obs = {}
        self.last_rew = {}
        self.last_done = {}
        self.last_info = {}
        # Index of the agent whose result the next step() will report.
        self.i = 0
        self.num = num

    def reset(self):
        """Reset every agent and return the initial observations of all.

        The per-agent caches are seeded here so that step() can report on
        an agent that has not acted yet since the reset.
        """
        self.dones = set()
        obs = {i: a.reset() for i, a in enumerate(self.agents)}
        self.last_obs = dict(obs)
        self.last_rew = {i: None for i in obs}
        self.last_done = {i: False for i in obs}
        self.last_info = {i: {} for i in obs}
        return obs

    def step(self, action_dict):
        """Apply the given actions, then report on the agent whose turn it is.

        Arguments:
            action_dict (dict): Maps agent index to the action to apply.

        Returns:
            obs, rew, done, info dicts, each holding a single entry for the
            agent at the round-robin position. done also carries "__all__".
        """
        assert len(self.dones) != len(self.agents)
        for i, action in action_dict.items():
            (self.last_obs[i], self.last_rew[i], self.last_done[i],
             self.last_info[i]) = self.agents[i].step(action)
            if self.last_done[i]:
                self.dones.add(i)
        # Index the caches with the round-robin position, not with the loop
        # variable left over from the loop above: that one names whichever
        # agent acted last and is undefined when action_dict is empty.
        turn = self.i
        obs = {turn: self.last_obs[turn]}
        rew = {turn: self.last_rew[turn]}
        done = {turn: self.last_done[turn]}
        info = {turn: self.last_info[turn]}
        self.i = (self.i + 1) % self.num
        done["__all__"] = len(self.dones) == len(self.agents)
        return obs, rew, done, info
# Tests for the mock multi-agent envs and for their _MultiAgentEnvToAsync
# adapter.
# NOTE(review): indentation is lost in this view of the file; the loop nesting
# described in the comments below is inferred from the expected values and
# should be confirmed against the repository.
class TestMultiAgentEnv(unittest.TestCase):
# BasicMultiAgent with 4 agents: nothing is done during the first 24 steps,
# and the following step reports every agent and "__all__" as done.
def testBasicMock(self):
env = BasicMultiAgent(4)
obs = env.reset()
self.assertEqual(obs, {0: 0, 1: 0, 2: 0, 3: 0})
for _ in range(24):
obs, rew, done, info = env.step({0: 0, 1: 0, 2: 0, 3: 0})
self.assertEqual(obs, {0: 0, 1: 0, 2: 0, 3: 0})
self.assertEqual(rew, {0: 1, 1: 1, 2: 1, 3: 1})
self.assertEqual(
done,
{0: False, 1: False, 2: False, 3: False, "__all__": False})
obs, rew, done, info = env.step({0: 0, 1: 0, 2: 0, 3: 0})
self.assertEqual(
done, {0: True, 1: True, 2: True, 3: True, "__all__": True})
# RoundRobinMultiAgent with 2 agents: each step() returns a single agent's
# observation, and the reported agent id alternates between 0 and 1.
# NOTE(review): the expected values appear to require both step() calls and
# their three assertions to sit inside the loop, with the final "__all__"
# check after it -- confirm.
def testRoundRobinMock(self):
env = RoundRobinMultiAgent(2)
obs = env.reset()
self.assertEqual(obs, {0: 0, 1: 0})
obs, rew, done, info = env.step({0: 0, 1: 0})
self.assertEqual(obs, {0: 0})
for _ in range(4):
obs, rew, done, info = env.step({0: 0})
self.assertEqual(obs, {1: 0})
self.assertEqual(done["__all__"], False)
obs, rew, done, info = env.step({1: 0})
self.assertEqual(obs, {0: 0})
self.assertEqual(done["__all__"], True)
# Two BasicMultiAgent envs of two agents each behind the async adapter:
# results are nested dicts of env_id -> agent_id -> value.
def testVectorizeBasic(self):
env = _MultiAgentEnvToAsync(lambda: BasicMultiAgent(2), [], 2)
# The first poll returns the reset observations; rewards are None.
obs, rew, dones, _, _ = env.poll()
self.assertEqual(obs, {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
self.assertEqual(rew, {0: {0: None, 1: None}, 1: {0: None, 1: None}})
self.assertEqual(
dones,
{0: {0: False, 1: False, "__all__": False},
1: {0: False, 1: False, "__all__": False}})
for _ in range(24):
env.send_actions({0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
obs, rew, dones, _, _ = env.poll()
self.assertEqual(obs, {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
self.assertEqual(rew, {0: {0: 1, 1: 1}, 1: {0: 1, 1: 1}})
self.assertEqual(
dones,
{0: {0: False, 1: False, "__all__": False},
1: {0: False, 1: False, "__all__": False}})
# The next step finishes every agent in both envs.
env.send_actions({0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
obs, rew, dones, _, _ = env.poll()
self.assertEqual(
dones,
{0: {0: True, 1: True, "__all__": True},
1: {0: True, 1: True, "__all__": True}})
# Reset processing
# Sending actions to a finished env raises ValueError until it is reset.
self.assertRaises(
ValueError,
lambda: env.send_actions({0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}}))
self.assertEqual(env.try_reset(0), {0: 0, 1: 0})
self.assertEqual(env.try_reset(1), {0: 0, 1: 0})
env.send_actions({0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
obs, rew, dones, _, _ = env.poll()
self.assertEqual(obs, {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
self.assertEqual(rew, {0: {0: 1, 1: 1}, 1: {0: 1, 1: 1}})
self.assertEqual(
dones,
{0: {0: False, 1: False, "__all__": False},
1: {0: False, 1: False, "__all__": False}})
# Two RoundRobinMultiAgent envs behind the async adapter: after the initial
# poll, each poll yields one agent per env.
def testVectorizeRoundRobin(self):
env = _MultiAgentEnvToAsync(lambda: RoundRobinMultiAgent(2), [], 2)
obs, rew, dones, _, _ = env.poll()
self.assertEqual(obs, {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
self.assertEqual(rew, {0: {0: None, 1: None}, 1: {0: None, 1: None}})
env.send_actions({0: {0: 0}, 1: {0: 0}})
obs, rew, dones, _, _ = env.poll()
self.assertEqual(obs, {0: {0: 0}, 1: {0: 0}})
env.send_actions({0: {0: 0}, 1: {0: 0}})
obs, rew, dones, _, _ = env.poll()
self.assertEqual(obs, {0: {1: 0}, 1: {1: 0}})
# Script entry point: start Ray, then run this module's tests verbosely.
if __name__ == '__main__':
ray.init()
unittest.main(verbosity=2)
+16 -4
View File
@@ -60,16 +60,16 @@ class PartOffPolicyServing(ServingEnv):
class SimpleOffPolicyServing(ServingEnv):
def __init__(self, env):
def __init__(self, env, fixed_action):
ServingEnv.__init__(self, env.action_space, env.observation_space)
self.env = env
self.fixed_action = fixed_action
def run(self):
eid = self.start_episode()
obs = self.env.reset()
while True:
# Take random actions
action = self.env.action_space.sample()
action = self.fixed_action
self.log_action(eid, obs, action)
obs, reward, done, info = self.env.step(action)
self.log_returns(eid, reward, info=info)
@@ -131,13 +131,15 @@ class TestServingEnv(unittest.TestCase):
def testServingEnvOffPolicy(self):
ev = CommonPolicyEvaluator(
env_creator=lambda _: SimpleOffPolicyServing(MockEnv(25)),
env_creator=lambda _: SimpleOffPolicyServing(MockEnv(25), 42),
policy_graph=MockPolicyGraph,
batch_steps=40,
batch_mode="complete_episodes")
for _ in range(3):
batch = ev.sample()
self.assertEqual(batch.count, 50)
self.assertEqual(batch["actions"][0], 42)
self.assertEqual(batch["actions"][-1], 42)
def testServingEnvBadActions(self):
ev = CommonPolicyEvaluator(
@@ -185,6 +187,16 @@ class TestServingEnv(unittest.TestCase):
return
raise Exception("failed to improve reward")
def testServingEnvHorizonNotSupported(self):
ev = CommonPolicyEvaluator(
env_creator=lambda _: SimpleServing(MockEnv(25)),
policy_graph=MockPolicyGraph,
episode_horizon=20,
batch_steps=10,
batch_mode="complete_episodes")
ev.sample()
self.assertRaises(Exception, lambda: ev.sample())
if __name__ == '__main__':
ray.init()
+219 -22
View File
@@ -2,66 +2,121 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.utils.serving_env import ServingEnv
from ray.rllib.utils.vector_env import VectorEnv
from ray.rllib.utils.multi_agent_env import MultiAgentEnv
class AsyncVectorEnv(object):
"""The lowest-level env interface used by RLlib for sampling.
AsyncVectorEnv models multiple agents executing asynchronously. A call to
poll() returns observations from ready agents, and actions for those agents
AsyncVectorEnv models multiple agents executing asynchronously in multiple
environments. A call to poll() returns observations from ready agents
keyed by their environment and agent ids, and actions for those agents
can be sent back via send_actions().
All other env types can be adapted to AsyncVectorEnv. RLlib handles these
conversions internally in CommonPolicyEvaluator, for example:
gym.Env => rllib.VectorEnv => rllib.AsyncVectorEnv
rllib.MultiAgentEnv => rllib.AsyncVectorEnv
rllib.ServingEnv => rllib.AsyncVectorEnv
Examples:
>>> env = MyAsyncVectorEnv()
>>> obs, rewards, dones, infos, off_policy_actions = env.poll()
>>> print(obs)
{"car_0": [2.4, 1.6], "car_1": [3.4, -3.2]}
>>> env.send_actions({"car_0": 0, "car_1": 1})
{
"env_0": {
"car_0": [2.4, 1.6],
"car_1": [3.4, -3.2],
}
}
>>> env.send_actions(
actions={
"env_0": {
"car_0": 0,
"car_1": 1,
}
})
>>> obs, rewards, dones, infos, off_policy_actions = env.poll()
>>> print(obs)
{"car_0": [4.1, 1.7], "car_1": [3.2, -4.2]}
{
"env_0": {
"car_0": [4.1, 1.7],
"car_1": [3.2, -4.2],
}
}
>>> print(dones)
{
"env_0": {
"__all__": False,
"car_0": False,
"car_1": True,
}
}
"""
@staticmethod
def wrap_async(env, make_env=None, num_envs=1):
"""Wraps any env type as needed to expose the async interface."""
if not isinstance(env, AsyncVectorEnv):
if isinstance(env, MultiAgentEnv):
env = _MultiAgentEnvToAsync(
make_env=make_env, existing_envs=[env], num_envs=num_envs)
elif isinstance(env, ServingEnv):
if num_envs != 1:
raise ValueError(
"ServingEnv does not currently support num_envs > 1.")
env = _ServingEnvToAsync(env)
elif isinstance(env, VectorEnv):
env = _VectorEnvToAsync(env)
else:
env = VectorEnv.wrap(
make_env=make_env, existing_envs=[env], num_envs=num_envs)
env = _VectorEnvToAsync(env)
assert isinstance(env, AsyncVectorEnv)
return env
def poll(self):
"""Returns observations from ready agents.
The returns are dicts mapping from agent episode ids to values. The
number of agents can vary over time.
The returns are two-level dicts mapping from env_id to a dict of
agent_id to values. The number of agents and envs can vary over time.
Returns:
obs (dict): New observations for each ready episode.
rewards (dict): Reward values for each ready episode. If the
obs (dict): New observations for each ready agent.
rewards (dict): Reward values for each ready agent. If the
episode is just started, the value will be None.
dones (dict): Done values for each ready episode. If True, the
episode is terminated.
infos (dict): Info values for each ready episode.
dones (dict): Done values for each ready agent. The special key
"__all__" is used to indicate env termination.
infos (dict): Info values for each ready agent.
off_policy_actions (dict): Agents may take off-policy actions. When
that happens, there will be an entry in this dict that contains
the taken action.
the taken action. There is no need to send_actions() for agents
that have already chosen off-policy actions.
"""
raise NotImplementedError
def send_actions(self, action_dict):
"""Called to send actions back to running agents in this env.
Actions should be sent for each ready agent that returned observations
in the previous poll() call.
Arguments:
action_dict (dict): Actions for each agent to take.
action_dict (dict): Actions values keyed by env_id and agent_id.
"""
raise NotImplementedError
def try_reset(self, agent_id):
"""Attempt to reset the agent with the given id.
def try_reset(self, env_id):
"""Attempt to reset the env with the given id.
If the environment does not support synchronous reset, None can be
returned here.
Returns:
obs (obj|None): Resetted observation or None if not supported.
obs (dict|None): Resetted observation or None if not supported.
"""
return None
@@ -74,8 +129,63 @@ class AsyncVectorEnv(object):
return None
# Agent id used for envs that host exactly one agent
_DUMMY_AGENT_ID = "single_agent"


def _with_dummy_agent_id(env_id_to_values, dummy_id=_DUMMY_AGENT_ID):
    """Nest each per-env value under a single agent key.

    Turns {env_id: value} into {env_id: {dummy_id: value}}.
    """
    nested = {}
    for env_id, value in env_id_to_values.items():
        nested[env_id] = {dummy_id: value}
    return nested
class _ServingEnvToAsync(AsyncVectorEnv):
    """Internal adapter of ServingEnv to AsyncVectorEnv.

    Each serving episode id is used as the env id, and every value is
    nested under the fixed single-agent key.
    """

    def __init__(self, serving_env):
        """Wrap the given serving env and start it.

        Arguments:
            serving_env (ServingEnv): Env to adapt; start() is called on it.
        """
        self.serving_env = serving_env
        serving_env.start()

    def poll(self):
        """Block until at least one episode has data, then return it.

        Returns:
            Tuple of (obs, rewards, dones, infos, off_policy_actions), each
            a dict keyed by episode id and then by agent id.

        Raises:
            Exception: if the serving env is no longer alive after a wait.
        """
        with self.serving_env._results_avail_condition:
            results = self._poll()
            while len(results[0]) == 0:
                self.serving_env._results_avail_condition.wait()
                results = self._poll()
                # NOTE(review): assumes ServingEnv is a threading.Thread.
                # is_alive() replaces isAlive(), which Python 3.9 removed;
                # is_alive() is also available on Python 2.6+.
                if not self.serving_env.is_alive():
                    raise Exception("Serving thread has stopped.")
        limit = self.serving_env._max_concurrent_episodes
        assert len(results[0]) < limit, \
            ("Too many concurrent episodes, were some leaked? This ServingEnv "
             "was created with max_concurrent={}".format(limit))
        return results

    def _poll(self):
        """Collect pending data from every tracked episode without waiting."""
        all_obs, all_rewards, all_dones, all_infos = {}, {}, {}, {}
        off_policy_actions = {}
        # Iterate over a copy since finished episodes are deleted below.
        for eid, episode in self.serving_env._episodes.copy().items():
            data = episode.get_data()
            if episode.cur_done:
                del self.serving_env._episodes[eid]
            if data:
                all_obs[eid] = data["obs"]
                all_rewards[eid] = data["reward"]
                all_dones[eid] = data["done"]
                all_infos[eid] = data["info"]
                if "off_policy_action" in data:
                    off_policy_actions[eid] = data["off_policy_action"]
        # Dones are keyed by "__all__" so they signal env termination.
        return _with_dummy_agent_id(all_obs), \
            _with_dummy_agent_id(all_rewards), \
            _with_dummy_agent_id(all_dones, "__all__"), \
            _with_dummy_agent_id(all_infos), \
            _with_dummy_agent_id(off_policy_actions)

    def send_actions(self, action_dict):
        """Queue the single agent's action for each episode in action_dict."""
        for eid, action in action_dict.items():
            self.serving_env._episodes[eid].action_queue.put(
                action[_DUMMY_AGENT_ID])
class _VectorEnvToAsync(AsyncVectorEnv):
"""Wraps VectorEnv to implement AsyncVectorEnv.
"""Internal adapter of VectorEnv to AsyncVectorEnv.
We assume the caller will always send the full vector of actions in each
call to send_actions(), and that they call reset_at() on all completed
@@ -99,17 +209,104 @@ class _VectorEnvToAsync(AsyncVectorEnv):
self.cur_rewards = []
self.cur_dones = []
self.cur_infos = []
return new_obs, rewards, dones, infos, {}
return _with_dummy_agent_id(new_obs), \
_with_dummy_agent_id(rewards), \
_with_dummy_agent_id(dones, "__all__"), \
_with_dummy_agent_id(infos), {}
def send_actions(self, action_dict):
action_vector = [None] * self.num_envs
for i in range(self.num_envs):
action_vector[i] = action_dict[i]
action_vector[i] = action_dict[i][_DUMMY_AGENT_ID]
self.new_obs, self.cur_rewards, self.cur_dones, self.cur_infos = \
self.vector_env.vector_step(action_vector)
def try_reset(self, agent_id):
return self.vector_env.reset_at(agent_id)
def try_reset(self, env_id):
return {_DUMMY_AGENT_ID: self.vector_env.reset_at(env_id)}
def get_unwrapped(self):
return self.vector_env.get_unwrapped()
class _MultiAgentEnvToAsync(AsyncVectorEnv):
    """Internal adapter of MultiAgentEnv to AsyncVectorEnv.

    This also supports vectorization if num_envs > 1.
    """

    def __init__(self, make_env, existing_envs, num_envs):
        """Wrap existing multi-agent envs.

        Arguments:
            make_env (func|None): Factory that produces a new multiagent env.
                Must be defined if the number of existing envs is less than
                num_envs.
            existing_envs (list): List of existing multiagent envs. New envs
                are appended to this list until it holds num_envs entries.
            num_envs (int): Desired num multiagent envs to keep total.
        """
        self.make_env = make_env
        self.envs = existing_envs
        self.num_envs = num_envs
        # Ids of envs that reported "__all__" done and were not reset since.
        self.dones = set()
        while len(self.envs) < self.num_envs:
            self.envs.append(self.make_env())
        for env in self.envs:
            assert isinstance(env, MultiAgentEnv)
        self.env_states = [_MultiAgentEnvState(env) for env in self.envs]

    def poll(self):
        """Return the pending results of every env, keyed by env index.

        The fifth element (off-policy actions) is always an empty dict.
        """
        obs, rewards, dones, infos = {}, {}, {}, {}
        for i, env_state in enumerate(self.env_states):
            obs[i], rewards[i], dones[i], infos[i] = env_state.poll()
        return obs, rewards, dones, infos, {}

    def send_actions(self, action_dict):
        """Step each env in action_dict with its per-agent actions.

        Raises:
            ValueError: if an env in action_dict is done and was not reset.
        """
        for env_id, agent_dict in action_dict.items():
            if env_id in self.dones:
                raise ValueError("Env {} is already done".format(env_id))
            env = self.envs[env_id]
            obs, rewards, dones, infos = env.step(agent_dict)
            if dones["__all__"]:
                self.dones.add(env_id)
            self.env_states[env_id].observe(obs, rewards, dones, infos)

    def try_reset(self, env_id):
        """Reset the env with the given id and return its new observations."""
        obs = self.env_states[env_id].reset()
        if obs is not None:
            # discard() rather than remove(): resetting an env that has not
            # finished must not fail with KeyError.
            self.dones.discard(env_id)
        return obs
class _MultiAgentEnvState(object):
    """Holds the latest results of one MultiAgentEnv until they are polled."""

    def __init__(self, env):
        assert isinstance(env, MultiAgentEnv)
        self.env = env
        self.reset()

    def poll(self):
        """Hand out the stored results once, then clear them.

        Raises:
            ValueError: if nothing was stored since the previous poll().
        """
        if self.last_obs is None:
            raise ValueError("Need to send action after polling")
        pending = (
            self.last_obs, self.last_rewards, self.last_dones,
            self.last_infos)
        self.last_obs = self.last_rewards = None
        self.last_dones = self.last_infos = None
        return pending

    def observe(self, obs, rewards, dones, infos):
        """Store the results of an env step for the next poll()."""
        (self.last_obs, self.last_rewards,
         self.last_dones, self.last_infos) = (obs, rewards, dones, infos)

    def reset(self):
        """Reset the env and store its initial results.

        Rewards start as None, dones as False (including "__all__") and
        infos as empty dicts for every agent in the reset observations.
        """
        self.last_obs = self.env.reset()
        agent_ids = list(self.last_obs.keys())
        self.last_rewards = dict.fromkeys(agent_ids)
        # Build the infos one by one so each agent gets its own dict.
        self.last_infos = {}
        for agent_id in agent_ids:
            self.last_infos[agent_id] = {}
        self.last_dones = dict.fromkeys(agent_ids, False)
        self.last_dones["__all__"] = False
        return self.last_obs
@@ -8,14 +8,15 @@ import tensorflow as tf
import ray
from ray.rllib.models import ModelCatalog
from ray.rllib.optimizers import SampleBatch
from ray.rllib.optimizers import MultiAgentBatch
from ray.rllib.optimizers.policy_evaluator import PolicyEvaluator
from ray.rllib.utils.async_vector_env import AsyncVectorEnv, _VectorEnvToAsync
from ray.rllib.utils.async_vector_env import AsyncVectorEnv
from ray.rllib.utils.atari_wrappers import wrap_deepmind, is_atari
from ray.rllib.utils.compression import pack
from ray.rllib.utils.filter import get_filter
from ray.rllib.utils.multi_agent_env import MultiAgentEnv
from ray.rllib.utils.sampler import AsyncSampler, SyncSampler
from ray.rllib.utils.serving_env import ServingEnv, _ServingEnvToAsync
from ray.rllib.utils.serving_env import ServingEnv
from ray.rllib.utils.tf_policy_graph import TFPolicyGraph
from ray.rllib.utils.vector_env import VectorEnv
from ray.tune.result import TrainingResult
@@ -152,6 +153,7 @@ class CommonPolicyEvaluator(PolicyEvaluator):
self.env = env_creator(env_config)
if isinstance(self.env, VectorEnv) or \
isinstance(self.env, ServingEnv) or \
isinstance(self.env, MultiAgentEnv) or \
isinstance(self.env, AsyncVectorEnv):
def wrap(env):
return env # we can't auto-wrap these env types
@@ -169,8 +171,6 @@ class CommonPolicyEvaluator(PolicyEvaluator):
def make_env():
return wrap(env_creator(env_config))
self.policy_map = {}
if issubclass(policy_graph, TFPolicyGraph):
with tf.Graph().as_default():
if tf_session_creator:
@@ -186,25 +186,21 @@ class CommonPolicyEvaluator(PolicyEvaluator):
policy = policy_graph(
self.env.observation_space, self.env.action_space,
policy_config)
self.policy_map = {
"default": policy
}
self.obs_filter = get_filter(
observation_filter, self.env.observation_space.shape)
self.filters = {"obs_filter": self.obs_filter}
self.filters = {
# TODO(ekl) make the obs space dependent on policy
policy_id: get_filter(
observation_filter, self.env.observation_space.shape)
for (policy_id, policy) in self.policy_map.items()
}
# Always use vector env for consistency even if num_envs = 1
if not isinstance(self.env, AsyncVectorEnv):
if isinstance(self.env, ServingEnv):
self.vector_env = _ServingEnvToAsync(self.env)
else:
if not isinstance(self.env, VectorEnv):
self.env = VectorEnv.wrap(
make_env, [self.env], num_envs=num_envs)
self.vector_env = _VectorEnvToAsync(self.env)
else:
self.vector_env = self.env
self.async_env = AsyncVectorEnv.wrap_async(
self.env, make_env=make_env, num_envs=num_envs)
if self.batch_mode == "truncate_episodes":
if batch_steps % num_envs != 0:
@@ -222,19 +218,21 @@ class CommonPolicyEvaluator(PolicyEvaluator):
"Unsupported batch mode: {}".format(self.batch_mode))
if sample_async:
self.sampler = AsyncSampler(
self.vector_env, self.policy_map["default"], self.obs_filter,
batch_steps, horizon=episode_horizon, pack=pack_episodes)
self.async_env, self.policy_map, lambda agent_id: "default",
self.filters, batch_steps, horizon=episode_horizon,
pack=pack_episodes)
self.sampler.start()
else:
self.sampler = SyncSampler(
self.vector_env, self.policy_map["default"], self.obs_filter,
batch_steps, horizon=episode_horizon, pack=pack_episodes)
self.async_env, self.policy_map, lambda agent_id: "default",
self.filters, batch_steps, horizon=episode_horizon,
pack=pack_episodes)
def sample(self):
"""Evaluate the current policies and return a batch of experiences.
Return:
SampleBatch from evaluating the current policies.
SampleBatch|MultiAgentBatch from evaluating the current policies.
"""
batches = [self.sampler.get_data()]
@@ -243,11 +241,16 @@ class CommonPolicyEvaluator(PolicyEvaluator):
batch = self.sampler.get_data()
steps_so_far += batch.count
batches.append(batch)
batch = SampleBatch.concat_samples(batches)
batch = batches[0].concat_samples(batches)
if self.compress_observations:
batch["obs"] = [pack(o) for o in batch["obs"]]
batch["new_obs"] = [pack(o) for o in batch["new_obs"]]
if isinstance(batch, MultiAgentBatch):
for data in batch.policy_batches.values():
data["obs"] = [pack(o) for o in data["obs"]]
data["new_obs"] = [pack(o) for o in data["new_obs"]]
else:
batch["obs"] = [pack(o) for o in batch["obs"]]
batch["new_obs"] = [pack(o) for o in batch["new_obs"]]
return batch
+60
View File
@@ -0,0 +1,60 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class MultiAgentEnv(object):
    """Interface for an environment shared by several independent agents.

    Every value exchanged with the env is a dict keyed by agent id
    (strings in the examples below).

    Examples:
        >>> env = MyMultiAgentEnv()
        >>> obs = env.reset()
        >>> print(obs)
        {
            "car_0": [2.4, 1.6],
            "car_1": [3.4, -3.2],
            "traffic_light_1": [0, 3, 5, 1],
        }
        >>> obs, rewards, dones, infos = env.step(
                action_dict={
                    "car_0": 1, "car_1": 0, "traffic_light_1": 2,
                })
        >>> print(rewards)
        {
            "car_0": 3,
            "car_1": -1,
            "traffic_light_1": 0,
        }
        >>> print(dones)
        {
            "car_0": False,
            "car_1": True,
            "__all__": False,
        }
    """

    def reset(self):
        """Start a new episode.

        Returns:
            obs (dict): New observations for each ready agent.
        """
        raise NotImplementedError

    def step(self, action_dict):
        """Apply the given per-agent actions and advance the env.

        All returned dicts map agent ids to values. The set of agents
        present in the env may change from one step to the next.

        Arguments:
            action_dict (dict): Maps agent ids to the actions to take.

        Returns:
            obs (dict): New observations for each ready agent.
            rewards (dict): Reward values for each ready agent. If the
                episode is just started, the value will be None.
            dones (dict): Done values for each ready agent. The special key
                "__all__" is used to indicate env termination.
            infos (dict): Info values for each ready agent.
        """
        raise NotImplementedError
+6 -3
View File
@@ -25,7 +25,9 @@ class PolicyGraph(object):
action_space (gym.Space): Action space of the env.
config (dict): Policy-specific configuration data.
"""
pass
self.observation_space = observation_space
self.action_space = action_space
def compute_actions(self, obs_batch, state_batches, is_training=False):
"""Compute actions for the current policy.
@@ -70,8 +72,9 @@ class PolicyGraph(object):
Arguments:
sample_batch (SampleBatch): batch of experiences for the policy,
which will contain at most one episode trajectory.
other_agent_batches (dict): In a multi-agent env, this contains the
experience batches seen by other agents.
other_agent_batches (dict): In a multi-agent env, this contains a
mapping of agent ids to (policy_graph, agent_batch) tuples
containing the policy graph and experiences of the other agent.
Returns:
SampleBatch: postprocessed sample batch.
+198 -141
View File
@@ -7,13 +7,16 @@ import numpy as np
import six.moves.queue as queue
import threading
from ray.rllib.optimizers.sample_batch import SampleBatchBuilder
from ray.rllib.utils.vector_env import VectorEnv
from ray.rllib.utils.async_vector_env import AsyncVectorEnv, _VectorEnvToAsync
from ray.rllib.optimizers.sample_batch import MultiAgentSampleBatchBuilder, \
MultiAgentBatch
from ray.rllib.utils.async_vector_env import AsyncVectorEnv
CompletedRollout = namedtuple("CompletedRollout",
["episode_length", "episode_reward"])
RolloutMetrics = namedtuple(
"RolloutMetrics", ["episode_length", "episode_reward", "agent_rewards"])
PolicyEvalData = namedtuple(
"PolicyEvalData", ["env_id", "agent_id", "obs", "rnn_state"])
class SyncSampler(object):
@@ -26,26 +29,23 @@ class SyncSampler(object):
thread."""
def __init__(
self, env, policy, obs_filter, num_local_steps,
horizon=None, pack=False):
if not isinstance(env, AsyncVectorEnv):
if not isinstance(env, VectorEnv):
env = VectorEnv.wrap(make_env=None, existing_envs=[env])
env = _VectorEnvToAsync(env)
self.async_vector_env = env
self, env, policies, policy_mapping_fn, obs_filters,
num_local_steps, horizon=None, pack=False):
self.async_vector_env = AsyncVectorEnv.wrap_async(env)
self.num_local_steps = num_local_steps
self.horizon = horizon
self.policy = policy
self._obs_filter = obs_filter
self.rollout_provider = _env_runner(self.async_vector_env, self.policy,
self.num_local_steps, self.horizon,
self._obs_filter, pack)
self.policies = policies
self.policy_mapping_fn = policy_mapping_fn
self._obs_filters = obs_filters
self.rollout_provider = _env_runner(
self.async_vector_env, self.policies, self.policy_mapping_fn,
self.num_local_steps, self.horizon, self._obs_filters, pack)
self.metrics_queue = queue.Queue()
def get_data(self):
while True:
item = next(self.rollout_provider)
if isinstance(item, CompletedRollout):
if isinstance(item, RolloutMetrics):
self.metrics_queue.put(item)
else:
return item
@@ -67,23 +67,20 @@ class AsyncSampler(threading.Thread):
accumulate and the gradient can be calculated on up to 5 batches."""
def __init__(
self, env, policy, obs_filter, num_local_steps,
horizon=None, pack=False):
assert getattr(
obs_filter, "is_concurrent",
False), ("Observation Filter must support concurrent updates.")
if not isinstance(env, AsyncVectorEnv):
if not isinstance(env, VectorEnv):
env = VectorEnv.wrap(make_env=None, existing_envs=[env])
env = _VectorEnvToAsync(env)
self.async_vector_env = env
self, env, policies, policy_mapping_fn, obs_filters,
num_local_steps, horizon=None, pack=False):
for _, f in obs_filters.items():
assert getattr(f, "is_concurrent", False), \
"Observation Filter must support concurrent updates."
self.async_vector_env = AsyncVectorEnv.wrap_async(env)
threading.Thread.__init__(self)
self.queue = queue.Queue(5)
self.metrics_queue = queue.Queue()
self.num_local_steps = num_local_steps
self.horizon = horizon
self.policy = policy
self._obs_filter = obs_filter
self.policies = policies
self.policy_mapping_fn = policy_mapping_fn
self._obs_filters = obs_filters
self.daemon = True
self.pack = pack
@@ -95,15 +92,15 @@ class AsyncSampler(threading.Thread):
raise e
def _run(self):
rollout_provider = _env_runner(self.async_vector_env, self.policy,
self.num_local_steps, self.horizon,
self._obs_filter, self.pack)
rollout_provider = _env_runner(
self.async_vector_env, self.policies, self.policy_mapping_fn,
self.num_local_steps, self.horizon, self._obs_filters, self.pack)
while True:
# The timeout variable exists because apparently, if one worker
# dies, the other workers won't die with it, unless the timeout is
# set to some large number. This is an empirical observation.
item = next(rollout_provider)
if isinstance(item, CompletedRollout):
if isinstance(item, RolloutMetrics):
self.metrics_queue.put(item)
else:
self.queue.put(item, timeout=600.0)
@@ -115,8 +112,9 @@ class AsyncSampler(threading.Thread):
if isinstance(rollout, BaseException):
raise rollout
# We can't auto-concat rollouts in vector mode
if self.async_vector_env.num_envs > 1:
# We can't auto-concat rollouts in these modes
if self.async_vector_env.num_envs > 1 or \
isinstance(rollout, MultiAgentBatch):
return rollout
# Auto-concat rollouts; TODO(ekl) is this important for A3C perf?
@@ -141,23 +139,22 @@ class AsyncSampler(threading.Thread):
def _env_runner(
async_vector_env, policy, num_local_steps, horizon, obs_filter, pack):
"""This implements the logic of the thread runner.
It continually runs the policy, and as long as the rollout exceeds a
certain length, the thread runner appends the policy to the queue. Yields
when `timestep_limit` is surpassed, environment terminates, or
`num_local_steps` is reached.
async_vector_env, policies, policy_mapping_fn, num_local_steps,
horizon, obs_filters, pack):
"""This implements the common experience collection logic.
Args:
async_vector_env: env implementing AsyncVectorEnv.
policy: Policy used to interact with environment. Also sets fields
to be included in `SampleBatch`.
num_local_steps: Number of steps before `SampleBatch` is yielded. Set
to infinity to yield complete episodes.
horizon: Horizon of the episode.
obs_filter: Filter used to process observations.
pack: Whether to pack multiple episodes into each batch. This
async_vector_env (AsyncVectorEnv): env implementing AsyncVectorEnv.
policies (dict): Map of policy ids to PolicyGraph instances.
policy_mapping_fn (func): Function that maps agent ids to policy ids.
This is called when an agent first enters the environment. The
agent is then "bound" to the returned policy for the episode.
num_local_steps (int): Number of episode steps before `SampleBatch` is
yielded. Set to infinity to yield complete episodes.
horizon (int): Horizon of the episode.
obs_filters (dict): Map of policy id to filter used to process
observations for the policy.
pack (bool): Whether to pack multiple episodes into each batch. This
guarantees batches will be exactly `num_local_steps` in size.
Yields:
@@ -181,110 +178,131 @@ def _env_runner(
if batch_builder_pool:
return batch_builder_pool.pop()
else:
return SampleBatchBuilder()
return MultiAgentSampleBatchBuilder(policies)
episodes = defaultdict(
lambda: _Episode(policy.get_initial_state(), get_batch_builder))
active_episodes = defaultdict(
lambda: _MultiAgentEpisode(
policies, policy_mapping_fn, get_batch_builder))
while True:
# Get observations from ready envs
# Get observations from all ready agents
unfiltered_obs, rewards, dones, infos, off_policy_actions = \
async_vector_env.poll()
ready_eids = []
ready_obs = []
ready_rnn_states = []
# Process and record the new observations
for eid, raw_obs in unfiltered_obs.items():
episode = episodes[eid]
filtered_obs = obs_filter(raw_obs)
ready_eids.append(eid)
ready_obs.append(filtered_obs)
ready_rnn_states.append(episode.rnn_state)
# Map of policy_id to list of PolicyEvalData
to_eval = defaultdict(list)
if episode.last_observation is None:
episode.last_observation = filtered_obs
continue # This is the initial observation after a reset
# For each environment
for env_id, agent_obs in unfiltered_obs.items():
new_episode = env_id not in active_episodes
episode = active_episodes[env_id]
if not new_episode:
episode.length += 1
episode.batch_builder.count += 1
episode.add_agent_rewards(rewards[env_id])
episode.length += 1
episode.total_reward += rewards[eid]
# Handle episode terminations
if dones[eid] or episode.length >= horizon:
done = True
yield CompletedRollout(episode.length, episode.total_reward)
# Check episode termination conditions
if dones[env_id]["__all__"] or episode.length >= horizon:
all_done = True
yield RolloutMetrics(
episode.length, episode.total_reward,
dict(episode.agent_rewards))
else:
done = False
all_done = False
if infos[eid].get("training_enabled", True):
episode.batch_builder.add_values(
obs=episode.last_observation,
actions=episode.last_action_flat(),
rewards=rewards[eid],
dones=done,
new_obs=filtered_obs,
**episode.last_pi_info)
# For each agent in the environment
for agent_id, raw_obs in agent_obs.items():
policy_id = episode.policy_for(agent_id)
filtered_obs = obs_filters[policy_id](raw_obs)
agent_done = bool(all_done or dones[env_id].get(agent_id))
if not agent_done:
to_eval[policy_id].append(
PolicyEvalData(
env_id, agent_id, filtered_obs,
episode.rnn_state_for(agent_id)))
# Cut the batch if we're not packing multiple episodes into
# one, or if we've exceeded the requested batch size.
if (done and not pack) or \
last_observation = episode.last_observation_for(agent_id)
episode.set_last_observation(agent_id, filtered_obs)
# Record transition info if applicable
if last_observation is not None and \
infos[env_id][agent_id].get("training_enabled", True):
episode.batch_builder.add_values(
agent_id,
policy_id,
t=episode.length - 1,
obs=last_observation,
actions=episode.last_action_for(agent_id),
rewards=rewards[env_id][agent_id],
dones=agent_done,
infos=infos[env_id][agent_id],
new_obs=filtered_obs,
**episode.last_pi_info_for(agent_id))
# Cut the batch if we're not packing multiple episodes into one,
# or if we've exceeded the requested batch size.
if episode.batch_builder.has_pending_data():
if (all_done and not pack) or \
episode.batch_builder.count >= num_local_steps:
yield episode.batch_builder.build_and_reset(
policy.postprocess_trajectory)
elif done:
# Make sure postprocessor never crosses episode boundaries
episode.batch_builder.postprocess_batch_so_far(
policy.postprocess_trajectory)
yield episode.batch_builder.build_and_reset()
elif all_done:
# Make sure postprocessor stays within one episode
episode.batch_builder.postprocess_batch_so_far()
if done:
if all_done:
# Handle episode termination
batch_builder_pool.append(episode.batch_builder)
del episodes[eid]
resetted_obs = async_vector_env.try_reset(eid)
del active_episodes[env_id]
resetted_obs = async_vector_env.try_reset(env_id)
if resetted_obs is None:
# Reset not supported, drop this env from the ready list
assert horizon == float("inf"), \
"Setting episode horizon requires reset() support."
ready_eids.pop()
ready_obs.pop()
ready_rnn_states.pop()
else:
# Reset successful, put in the new obs as ready
episode = episodes[eid]
episode.last_observation = obs_filter(resetted_obs)
ready_obs[-1] = episode.last_observation
ready_rnn_states[-1] = episode.rnn_state
else:
episode.last_observation = filtered_obs
# Creates a new episode
episode = active_episodes[env_id]
for agent_id, raw_obs in resetted_obs.items():
policy_id = episode.policy_for(agent_id)
filtered_obs = obs_filters[policy_id](raw_obs)
episode.set_last_observation(agent_id, filtered_obs)
to_eval[policy_id].append(
PolicyEvalData(
env_id, agent_id, filtered_obs,
episode.rnn_state_for(agent_id)))
if not ready_eids:
continue # No actions to take
# Map of env_id -> agent_id -> action
action_dict = defaultdict(dict)
# Compute action for ready envs
ready_rnn_state_cols = _to_column_format(ready_rnn_states)
actions, new_rnn_state_cols, pi_info_cols = policy.compute_actions(
ready_obs, ready_rnn_state_cols, is_training=True)
# Add RNN state info
for f_i, column in enumerate(ready_rnn_state_cols):
pi_info_cols["state_in_{}".format(f_i)] = column
for f_i, column in enumerate(new_rnn_state_cols):
pi_info_cols["state_out_{}".format(f_i)] = column
# TODO(ekl) fuse all policy evaluation into one TF run
for policy_id, eval_data in to_eval.items():
rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data])
actions, rnn_out_cols, pi_info_cols = \
policies[policy_id].compute_actions(
[t.obs for t in eval_data], rnn_in_cols, is_training=True)
# Add RNN state info
for f_i, column in enumerate(rnn_in_cols):
pi_info_cols["state_in_{}".format(f_i)] = column
for f_i, column in enumerate(rnn_out_cols):
pi_info_cols["state_out_{}".format(f_i)] = column
# Save output rows
for i, action in enumerate(actions):
env_id = eval_data[i].env_id
agent_id = eval_data[i].agent_id
action_dict[env_id][agent_id] = action
episode = active_episodes[env_id]
episode.set_rnn_state(agent_id, [c[i] for c in rnn_out_cols])
episode.set_last_pi_info(
agent_id, {k: v[i] for k, v in pi_info_cols.items()})
if env_id in off_policy_actions and \
agent_id in off_policy_actions[env_id]:
episode.set_last_action(
agent_id, off_policy_actions[env_id][agent_id])
else:
episode.set_last_action(agent_id, action)
# Return computed actions to ready envs. We also send to envs that have
# taken off-policy actions; those envs are free to ignore the action.
async_vector_env.send_actions(dict(zip(ready_eids, actions)))
# Store the computed action info
for i, eid in enumerate(ready_eids):
episode = episodes[eid]
if eid in off_policy_actions:
episode.last_action = off_policy_actions[eid]
else:
episode.last_action = actions[i]
episode.rnn_state = [column[i] for column in new_rnn_state_cols]
episode.last_pi_info = {
k: column[i] for k, column in pi_info_cols.items()}
async_vector_env.send_actions(dict(action_dict))
def _to_column_format(rnn_state_rows):
@@ -293,18 +311,57 @@ def _to_column_format(rnn_state_rows):
[row[i] for row in rnn_state_rows] for i in range(num_cols)]
class _Episode(object):
def __init__(self, init_rnn_state, batch_builder_factory):
self.rnn_state = init_rnn_state
class _MultiAgentEpisode(object):
def __init__(self, policies, policy_mapping_fn, batch_builder_factory):
self.batch_builder = batch_builder_factory()
self.last_action = None
self.last_observation = None
self.last_pi_info = None
self.total_reward = 0.0
self.length = 0
self.agent_rewards = defaultdict(float)
self._policies = policies
self._policy_mapping_fn = policy_mapping_fn
self._agent_to_policy = {}
self._agent_to_rnn_state = {}
self._agent_to_last_obs = {}
self._agent_to_last_action = {}
self._agent_to_last_pi_info = {}
def last_action_flat(self):
# Concatenate multiagent actions
if isinstance(self.last_action, list):
return np.concatenate(self.last_action, axis=0).flatten()
return self.last_action
def add_agent_rewards(self, reward_dict):
    """Fold one step of per-agent rewards into the episode's running totals.

    Updates both the per-agent sums in `self.agent_rewards` and the
    episode-wide `self.total_reward`.

    Arguments:
        reward_dict (dict): Mapping of agent id to the reward that agent
            received on this step. Entries whose value is None are skipped.
    """
    for agent_id, reward in reward_dict.items():
        # MultiAgentEnv.step() is documented to report None as the reward
        # of an agent whose episode has just started; adding None to a
        # float would raise TypeError, so there is nothing to accumulate.
        if reward is None:
            continue
        self.agent_rewards[agent_id] += reward
        self.total_reward += reward
def policy_for(self, agent_id):
    """Return the policy id bound to `agent_id`, binding it on first use.

    The mapping function is consulted only the first time an agent id is
    seen; the result is cached in `self._agent_to_policy` and reused for
    every later lookup on this episode object.

    Arguments:
        agent_id: Id of the agent to look up.

    Returns:
        Whatever `self._policy_mapping_fn(agent_id)` returned the first time.
    """
    bindings = self._agent_to_policy
    if agent_id in bindings:
        return bindings[agent_id]
    bindings[agent_id] = self._policy_mapping_fn(agent_id)
    return bindings[agent_id]
def rnn_state_for(self, agent_id):
    """Return the current RNN state stored for `agent_id`.

    If nothing has been stored for the agent yet, the state is seeded from
    `get_initial_state()` of the policy the agent is bound to, cached in
    `self._agent_to_rnn_state`, and returned.

    Arguments:
        agent_id: Id of the agent to look up.
    """
    states = self._agent_to_rnn_state
    if agent_id in states:
        return states[agent_id]
    bound_policy = self._policies[self.policy_for(agent_id)]
    states[agent_id] = bound_policy.get_initial_state()
    return states[agent_id]
def last_observation_for(self, agent_id):
    """Return the observation last stored for `agent_id`.

    Returns:
        The stored observation, or None if none has been stored for this
        agent.
    """
    last_obs_by_agent = self._agent_to_last_obs
    return last_obs_by_agent.get(agent_id, None)
def last_action_for(self, agent_id):
    """Return the action last stored for `agent_id`.

    An action stored as a list is concatenated along axis 0 and flattened
    into a single array; any other value is returned unchanged.

    Raises:
        KeyError: If no action has been stored for this agent.
    """
    stored = self._agent_to_last_action[agent_id]
    if not isinstance(stored, list):
        return stored
    # List-valued (tuple-style) actions are merged into one flat array.
    return np.concatenate(stored, axis=0).flatten()
def last_pi_info_for(self, agent_id):
    """Return the policy info last stored for `agent_id`.

    Raises:
        KeyError: If no policy info has been stored for this agent.
    """
    pi_info_by_agent = self._agent_to_last_pi_info
    return pi_info_by_agent[agent_id]
def set_rnn_state(self, agent_id, rnn_state):
    """Store `rnn_state` as the current RNN state of `agent_id`.

    Any state previously stored for the agent is overwritten.
    """
    rnn_states = self._agent_to_rnn_state
    rnn_states[agent_id] = rnn_state
def set_last_observation(self, agent_id, obs):
    """Store `obs` as the most recent observation of `agent_id`.

    Any observation previously stored for the agent is overwritten.
    """
    last_obs_by_agent = self._agent_to_last_obs
    last_obs_by_agent[agent_id] = obs
def set_last_action(self, agent_id, action):
    """Store `action` as the most recent action of `agent_id`.

    Any action previously stored for the agent is overwritten.
    """
    last_actions = self._agent_to_last_action
    last_actions[agent_id] = action
def set_last_pi_info(self, agent_id, pi_info):
    """Store `pi_info` as the most recent policy info of `agent_id`.

    Any policy info previously stored for the agent is overwritten.
    """
    pi_info_by_agent = self._agent_to_last_pi_info
    pi_info_by_agent[agent_id] = pi_info
+3 -47
View File
@@ -6,11 +6,9 @@ from six.moves import queue
import threading
import uuid
from ray.rllib.utils.async_vector_env import AsyncVectorEnv
class ServingEnv(threading.Thread):
"""Environment that provides policy serving.
"""An environment that provides policy serving.
Unlike simulator envs, control is inverted. The environment queries the
policy to obtain actions and logs observations and rewards for training.
@@ -91,7 +89,7 @@ class ServingEnv(threading.Thread):
raise ValueError(
"Episode {} is already started".format(episode_id))
self._episodes[episode_id] = _Episode(
self._episodes[episode_id] = _ServingEnvEpisode(
episode_id, self._results_avail_condition, training_enabled)
return episode_id
@@ -165,49 +163,7 @@ class ServingEnv(threading.Thread):
return self._episodes[episode_id]
class _ServingEnvToAsync(AsyncVectorEnv):
"""Internal adapter of ServingEnv to AsyncVectorEnv."""
def __init__(self, serving_env):
self.serving_env = serving_env
serving_env.start()
def poll(self):
with self.serving_env._results_avail_condition:
results = self._poll()
while len(results[0]) == 0:
self.serving_env._results_avail_condition.wait()
results = self._poll()
if not self.serving_env.isAlive():
raise Exception("Serving thread has stopped.")
limit = self.serving_env._max_concurrent_episodes
assert len(results[0]) < limit, \
("Too many concurrent episodes, were some leaked? This ServingEnv "
"was created with max_concurrent={}".format(limit))
return results
def _poll(self):
all_obs, all_rewards, all_dones, all_infos = {}, {}, {}, {}
off_policy_actions = {}
for eid, episode in self.serving_env._episodes.copy().items():
data = episode.get_data()
if episode.cur_done:
del self.serving_env._episodes[eid]
if data:
all_obs[eid] = data["obs"]
all_rewards[eid] = data["reward"]
all_dones[eid] = data["done"]
all_infos[eid] = data["info"]
if "off_policy_action" in data:
off_policy_actions[eid] = data["off_policy_action"]
return all_obs, all_rewards, all_dones, all_infos, off_policy_actions
def send_actions(self, action_dict):
for eid, action in action_dict.items():
self.serving_env._episodes[eid].action_queue.put(action)
class _Episode(object):
class _ServingEnvEpisode(object):
"""Tracked state for each active episode."""
def __init__(self, episode_id, results_avail_condition, training_enabled):
+23 -1
View File
@@ -22,22 +22,44 @@ class VectorEnv(object):
return _VectorizedGymEnv(make_env, existing_envs or [], num_envs)
def vector_reset(self):
    """Reset every environment in the vector.

    Abstract here: the base implementation only raises.

    Returns:
        obs (list): Vector of observations from each environment.

    Raises:
        NotImplementedError: Always, until overridden.
    """
    raise NotImplementedError
def reset_at(self, index):
    """Reset one environment of the vector.

    Abstract here: the base implementation only raises.

    Arguments:
        index: Selects the environment to reset (presumably its position
            in the vector -- confirm against the concrete subclass).

    Returns:
        obs (obj): Observations from the resetted environment.

    Raises:
        NotImplementedError: Always, until overridden.
    """
    raise NotImplementedError
def vector_step(self, actions):
    """Step all environments at once.

    Abstract here: the base implementation only raises.

    Arguments:
        actions (list): Actions for each env.

    Returns:
        obs (list): New observations for each env.
        rewards (list): Reward values for each env.
        dones (list): Done values for each env.
        infos (list): Info values for each env.

    Raises:
        NotImplementedError: Always, until overridden.
    """
    raise NotImplementedError
def get_unwrapped(self):
    """Return a single instance of the underlying env.

    Abstract here: the base implementation only raises.

    Raises:
        NotImplementedError: Always, until overridden.
    """
    raise NotImplementedError
class _VectorizedGymEnv(VectorEnv):
"""Internal wrapper for gym envs to implement VectorEnv.
Arguents:
Arguments:
make_env (func|None): Factory that produces a new gym env. Must be
defined if the number of existing envs is less than num_envs.
existing_envs (list): List of existing gym envs.