mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 19:32:38 +08:00
[rllib] Part 1 of multiagent support: make sampler path support multiagent envs (#2268)
This refactors the RLlib sampler to support multi-agent environments. The main changes are: (1) AsyncVectorEnv now produces dicts of env_id -> agent_id -> value rather than env_id -> value, which lets it model vectorized envs, multi-agent envs, or both. (2) The sampler class operates over this nested dict structure for all envs; single-agent envs just return a dict with one agent_id=single_agent. (3) When sample() is called on a policy evaluator, the single-agent case returns a SampleBatch; otherwise it returns a MultiAgentBatch (which holds one sample batch per policy id). Left for another PR: exposing multi-agent in the public interfaces, and optimizations such as evaluating multiple policies in one TF run.
This commit is contained in:
@@ -154,7 +154,8 @@ class A3CAgent(Agent):
|
||||
def compute_action(self, observation, state=None):
|
||||
if state is None:
|
||||
state = []
|
||||
obs = self.local_evaluator.obs_filter(observation, update=False)
|
||||
obs = self.local_evaluator.filters["default"](
|
||||
observation, update=False)
|
||||
return self.local_evaluator.for_policy(
|
||||
lambda p: p.compute_single_action(
|
||||
obs, state, is_training=False)[0])
|
||||
|
||||
@@ -18,6 +18,7 @@ class SharedTorchPolicy(PolicyGraph):
|
||||
"""A simple, non-recurrent PyTorch policy example."""
|
||||
|
||||
def __init__(self, obs_space, action_space, config):
|
||||
PolicyGraph.__init__(self, obs_space, action_space, config)
|
||||
self.local_steps = 0
|
||||
self.config = config
|
||||
self.summarize = config.get("summarize")
|
||||
|
||||
@@ -3,7 +3,7 @@ from ray.rllib.optimizers.async_optimizer import AsyncOptimizer
|
||||
from ray.rllib.optimizers.local_sync import LocalSyncOptimizer
|
||||
from ray.rllib.optimizers.local_sync_replay import LocalSyncReplayOptimizer
|
||||
from ray.rllib.optimizers.multi_gpu import LocalMultiGPUOptimizer
|
||||
from ray.rllib.optimizers.sample_batch import SampleBatch
|
||||
from ray.rllib.optimizers.sample_batch import SampleBatch, MultiAgentBatch
|
||||
from ray.rllib.optimizers.policy_evaluator import PolicyEvaluator, \
|
||||
TFMultiGPUSupport
|
||||
|
||||
@@ -11,4 +11,4 @@ from ray.rllib.optimizers.policy_evaluator import PolicyEvaluator, \
|
||||
__all__ = [
|
||||
"ApexOptimizer", "AsyncOptimizer", "LocalSyncOptimizer",
|
||||
"LocalSyncReplayOptimizer", "LocalMultiGPUOptimizer", "SampleBatch",
|
||||
"PolicyEvaluator", "TFMultiGPUSupport"]
|
||||
"PolicyEvaluator", "TFMultiGPUSupport", "MultiAgentBatch"]
|
||||
|
||||
@@ -14,7 +14,6 @@ class SampleBatchBuilder(object):
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.postprocessed = []
|
||||
self.buffers = collections.defaultdict(list)
|
||||
self.count = 0
|
||||
|
||||
@@ -25,29 +24,131 @@ class SampleBatchBuilder(object):
|
||||
self.buffers[k].append(v)
|
||||
self.count += 1
|
||||
|
||||
def postprocess_batch_so_far(self, postprocessor):
|
||||
"""Apply the given postprocessor to any unprocessed rows."""
|
||||
def add_batch(self, batch):
|
||||
"""Add the given batch of values to this batch."""
|
||||
|
||||
batch = postprocessor(self._build_buffers())
|
||||
self.postprocessed.append(batch)
|
||||
for k, column in batch.items():
|
||||
self.buffers[k].extend(column)
|
||||
self.count += batch.count
|
||||
|
||||
def build_and_reset(self, postprocessor):
|
||||
"""Returns a sample batch including all previously added values.
|
||||
def build_and_reset(self):
|
||||
"""Returns a sample batch including all previously added values."""
|
||||
|
||||
Any unprocessed rows will be first postprocessed with the given
|
||||
postprocessor. The internal state of this builder will be reset.
|
||||
"""
|
||||
|
||||
self.postprocess_batch_so_far(postprocessor)
|
||||
batch = SampleBatch.concat_samples(self.postprocessed)
|
||||
self.postprocessed = []
|
||||
batch = SampleBatch({k: np.array(v) for k, v in self.buffers.items()})
|
||||
self.buffers.clear()
|
||||
self.count = 0
|
||||
return batch
|
||||
|
||||
def _build_buffers(self):
|
||||
batch = SampleBatch({k: np.array(v) for k, v in self.buffers.items()})
|
||||
self.buffers.clear()
|
||||
return batch
|
||||
|
||||
class MultiAgentSampleBatchBuilder(object):
    """Collects per-agent rows and emits one batch per policy.

    Rows arrive keyed by agent while the output is keyed by policy; agents
    and policies are related by an M:N mapping. Every agent gets its own
    local SampleBatchBuilder. On postprocessing, each agent's local batch is
    run through its policy's postprocess_trajectory() and appended to the
    builder of that policy.
    """

    def __init__(self, policy_map):
        """Initialize a MultiAgentSampleBatchBuilder.

        Arguments:
            policy_map (dict): Maps policy ids to policy graph instances.
        """
        self.policy_map = policy_map
        self.policy_builders = {}
        for policy_id in policy_map.keys():
            self.policy_builders[policy_id] = SampleBatchBuilder()
        self.agent_builders = {}
        self.agent_to_policy = {}
        self.count = 0  # increment this manually

    def has_pending_data(self):
        """Returns whether any agent still has rows awaiting postprocessing."""
        return bool(self.agent_builders)

    def add_values(self, agent_id, policy_id, **values):
        """Append one row of values to the given agent's local batch.

        The agent -> policy association is recorded the first time an agent
        id is seen (since the last postprocessing pass).

        Arguments:
            agent_id (obj): Unique id for the agent we are adding values for.
            policy_id (obj): Unique id for policy controlling the agent.
            values (dict): Row of values to add for this agent.
        """
        try:
            agent_builder = self.agent_builders[agent_id]
        except KeyError:
            agent_builder = SampleBatchBuilder()
            self.agent_builders[agent_id] = agent_builder
            self.agent_to_policy[agent_id] = policy_id
        agent_builder.add_values(**values)

    def postprocess_batch_so_far(self):
        """Run policy postprocessors over all pending per-agent rows.

        The postprocessed per-agent batches are appended to the per-policy
        builders, and all per-agent state is cleared afterwards.
        """
        # Step 1: materialize each agent's rows, paired with its policy.
        pre_batches = {
            agent_id: (
                self.policy_map[self.agent_to_policy[agent_id]],
                agent_builder.build_and_reset())
            for agent_id, agent_builder in self.agent_builders.items()
        }

        # Step 2: postprocess each agent's batch; the policy also receives
        # the (policy, batch) pairs of every other agent.
        post_batches = {}
        for agent_id, (policy, own_batch) in pre_batches.items():
            others = dict(pre_batches)
            others.pop(agent_id)
            post_batches[agent_id] = policy.postprocess_trajectory(
                own_batch, others)

        # Step 3: hand results to the per-policy builders, then reset.
        for agent_id, post_batch in post_batches.items():
            policy_id = self.agent_to_policy[agent_id]
            self.policy_builders[policy_id].add_batch(post_batch)
        self.agent_builders.clear()
        self.agent_to_policy.clear()

    def build_and_reset(self):
        """Returns the accumulated sample batches for each policy.

        Pending rows are postprocessed first. Each per-policy builder is
        built and reset, self.count is zeroed, and the result is passed
        through MultiAgentBatch.wrap_as_needed().
        """
        self.postprocess_batch_so_far()
        policy_batches = {
            policy_id: policy_builder.build_and_reset()
            for policy_id, policy_builder in self.policy_builders.items()
        }
        self.count = 0
        return MultiAgentBatch.wrap_as_needed(policy_batches)
|
||||
|
||||
|
||||
class MultiAgentBatch(object):
    """Container for a mapping of policy id -> sample batch."""

    def __init__(self, policy_batches):
        # Mapping of policy id to that policy's batch.
        self.policy_batches = policy_batches

    @staticmethod
    def wrap_as_needed(batches):
        """Unwrap a lone "default" batch, else wrap in a MultiAgentBatch.

        Arguments:
            batches (dict): Maps policy ids to batches.

        Returns:
            batches["default"] itself when that is the only entry,
            otherwise a MultiAgentBatch holding the given dict.
        """
        only_default = len(batches) == 1 and "default" in batches
        if not only_default:
            return MultiAgentBatch(batches)
        return batches["default"]

    @staticmethod
    def concat_samples(samples):
        """Merge several MultiAgentBatch objects policy by policy.

        Batches belonging to the same policy id are combined with
        SampleBatch.concat_samples().
        """
        grouped = collections.defaultdict(list)
        for sample in samples:
            assert isinstance(sample, MultiAgentBatch)
            for policy_id, batch in sample.policy_batches.items():
                grouped[policy_id].append(batch)
        merged = {
            policy_id: SampleBatch.concat_samples(batch_list)
            for policy_id, batch_list in grouped.items()
        }
        return MultiAgentBatch(merged)
|
||||
|
||||
|
||||
class SampleBatch(object):
|
||||
|
||||
@@ -88,7 +88,7 @@ class ProximalPolicyGraph(object):
|
||||
feed_dict={self.observations: observations})
|
||||
return action, [], {"vf_preds": vf, "logprobs": logprobs}
|
||||
|
||||
def postprocess_trajectory(self, batch):
|
||||
def postprocess_trajectory(self, batch, other_agent_batches=None):
|
||||
return batch
|
||||
|
||||
def get_initial_state(self):
|
||||
|
||||
@@ -79,8 +79,9 @@ class PPOEvaluator(TFMultiGPUSupport):
|
||||
self.filters = {"obs_filter": self.obs_filter,
|
||||
"rew_filter": self.rew_filter}
|
||||
self.sampler = SyncSampler(
|
||||
self.env, self.common_policy, self.obs_filter,
|
||||
self.config["horizon"], self.config["horizon"])
|
||||
self.env, {"default": self.common_policy}, lambda _: "default",
|
||||
{"default": self.obs_filter}, self.config["horizon"],
|
||||
self.config["horizon"])
|
||||
|
||||
def tf_loss_inputs(self):
|
||||
return self.inputs
|
||||
|
||||
@@ -20,7 +20,7 @@ class MockPolicyGraph(PolicyGraph):
|
||||
def compute_actions(self, obs_batch, state_batches, is_training=False):
|
||||
return [0] * len(obs_batch), [], {}
|
||||
|
||||
def postprocess_trajectory(self, batch):
|
||||
def postprocess_trajectory(self, batch, other_agent_batches=None):
|
||||
return compute_advantages(batch, 100.0, 0.9, use_gae=False)
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ class BadPolicyGraph(PolicyGraph):
|
||||
def compute_actions(self, obs_batch, state_batches, is_training=False):
|
||||
raise Exception("intentional error")
|
||||
|
||||
def postprocess_trajectory(self, batch):
|
||||
def postprocess_trajectory(self, batch, other_agent_batches=None):
|
||||
return compute_advantages(batch, 100.0, 0.9, use_gae=False)
|
||||
|
||||
|
||||
@@ -211,6 +211,9 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
|
||||
batch_mode="complete_episodes")
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 20)
|
||||
self.assertEqual(
|
||||
batch["t"].tolist(),
|
||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
|
||||
|
||||
def testFilterSync(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
@@ -221,7 +224,7 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
|
||||
time.sleep(2)
|
||||
ev.sample()
|
||||
filters = ev.get_filters(flush_after=True)
|
||||
obs_f = filters["obs_filter"]
|
||||
obs_f = filters["default"]
|
||||
self.assertNotEqual(obs_f.rs.n, 0)
|
||||
self.assertNotEqual(obs_f.buffer.n, 0)
|
||||
|
||||
@@ -235,8 +238,8 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
|
||||
filters = ev.get_filters(flush_after=False)
|
||||
time.sleep(2)
|
||||
filters2 = ev.get_filters(flush_after=False)
|
||||
obs_f = filters["obs_filter"]
|
||||
obs_f2 = filters2["obs_filter"]
|
||||
obs_f = filters["default"]
|
||||
obs_f2 = filters2["default"]
|
||||
self.assertGreaterEqual(obs_f2.rs.n, obs_f.rs.n)
|
||||
self.assertGreaterEqual(obs_f2.buffer.n, obs_f.buffer.n)
|
||||
|
||||
@@ -250,15 +253,15 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
|
||||
|
||||
# Current State
|
||||
filters = ev.get_filters(flush_after=False)
|
||||
obs_f = filters["obs_filter"]
|
||||
obs_f = filters["default"]
|
||||
|
||||
self.assertLessEqual(obs_f.buffer.n, 20)
|
||||
|
||||
new_obsf = obs_f.copy()
|
||||
new_obsf.rs._n = 100
|
||||
ev.sync_filters({"obs_filter": new_obsf})
|
||||
ev.sync_filters({"default": new_obsf})
|
||||
filters = ev.get_filters(flush_after=False)
|
||||
obs_f = filters["obs_filter"]
|
||||
obs_f = filters["default"]
|
||||
self.assertGreaterEqual(obs_f.rs.n, 100)
|
||||
self.assertLessEqual(obs_f.buffer.n, 20)
|
||||
|
||||
@@ -266,7 +269,7 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
|
||||
time.sleep(2)
|
||||
ev.sample()
|
||||
filters = ev.get_filters(flush_after=True)
|
||||
obs_f = filters["obs_filter"]
|
||||
obs_f = filters["default"]
|
||||
self.assertNotEqual(obs_f.rs.n, 0)
|
||||
self.assertNotEqual(obs_f.buffer.n, 0)
|
||||
return obs_f
|
||||
|
||||
@@ -0,0 +1,155 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib.test.test_common_policy_evaluator import MockEnv
|
||||
from ray.rllib.utils.async_vector_env import _MultiAgentEnvToAsync
|
||||
from ray.rllib.utils.multi_agent_env import MultiAgentEnv
|
||||
|
||||
|
||||
class BasicMultiAgent(MultiAgentEnv):
    """Env of N independent agents, each of which exits after 25 steps."""

    def __init__(self, num):
        # One MockEnv per agent; the agent id is its index in this list.
        self.agents = [MockEnv(25) for _ in range(num)]
        self.dones = set()

    def reset(self):
        """Reset every agent and return {agent_id: initial_obs}."""
        self.dones = set()
        first_obs = {}
        for agent_id, agent_env in enumerate(self.agents):
            first_obs[agent_id] = agent_env.reset()
        return first_obs

    def step(self, action_dict):
        """Step each agent named in action_dict.

        Returns:
            Tuple of (obs, rew, done, info) dicts keyed by agent id. The
            done dict also carries "__all__", which is True once every
            agent has reported done since the last reset.
        """
        obs, rew, done, info = {}, {}, {}, {}
        for agent_id, action in action_dict.items():
            result = self.agents[agent_id].step(action)
            obs[agent_id], rew[agent_id], done[agent_id], info[agent_id] = \
                result
            if done[agent_id]:
                self.dones.add(agent_id)
        done["__all__"] = len(self.dones) == len(self.agents)
        return obs, rew, done, info
|
||||
|
||||
|
||||
class RoundRobinMultiAgent(MultiAgentEnv):
    """Env of N independent agents, each of which exits after 5 steps.

    On each step() of the env, only one agent takes an action."""

    def __init__(self, num):
        # One MockEnv per agent; the agent id is its index in this list.
        self.agents = [MockEnv(5) for _ in range(num)]
        self.dones = set()
        # Most recent per-agent results, keyed by agent id.
        self.last_obs = {}
        self.last_rew = {}
        self.last_done = {}
        self.last_info = {}
        # Id of the agent whose data the next step() call reports.
        self.i = 0
        self.num = num

    def reset(self):
        """Reset every agent and return {agent_id: initial_obs}.

        The per-agent caches are seeded with the reset observation (reward
        None, done False, empty info) and the turn counter is rewound, so
        step() can report an agent that has not acted yet this episode.
        """
        self.dones = set()
        self.last_obs = {}
        self.last_rew = {}
        self.last_done = {}
        self.last_info = {}
        self.i = 0
        for agent_id, agent_env in enumerate(self.agents):
            self.last_obs[agent_id] = agent_env.reset()
            self.last_rew[agent_id] = None
            self.last_done[agent_id] = False
            self.last_info[agent_id] = {}
        return dict(self.last_obs)

    def step(self, action_dict):
        """Apply the given actions and report only the agent whose turn it is.

        Returns:
            Tuple of (obs, rew, done, info) dicts, each holding one entry
            keyed by the current turn's agent id. The done dict also
            carries "__all__", which is True once every agent is done.
        """
        assert len(self.dones) != len(self.agents)
        for agent_id, action in action_dict.items():
            (self.last_obs[agent_id], self.last_rew[agent_id],
             self.last_done[agent_id], self.last_info[agent_id]) = \
                self.agents[agent_id].step(action)
            if self.last_done[agent_id]:
                self.dones.add(agent_id)

        # Look the cached values up by the turn id. Indexing with the loop
        # variable of the action loop would report whichever agent happened
        # to be iterated last, and fail outright on an empty action_dict.
        turn = self.i
        obs = {turn: self.last_obs[turn]}
        rew = {turn: self.last_rew[turn]}
        done = {turn: self.last_done[turn]}
        info = {turn: self.last_info[turn]}
        self.i = (self.i + 1) % self.num
        done["__all__"] = len(self.dones) == len(self.agents)
        return obs, rew, done, info
|
||||
|
||||
|
||||
class TestMultiAgentEnv(unittest.TestCase):
    """Checks the mock multi-agent envs directly and via the async adapter."""

    def testBasicMock(self):
        env = BasicMultiAgent(4)
        zeros = {0: 0, 1: 0, 2: 0, 3: 0}
        self.assertEqual(env.reset(), zeros)
        for _ in range(24):
            obs, rew, done, info = env.step(dict(zeros))
            self.assertEqual(obs, zeros)
            self.assertEqual(rew, {0: 1, 1: 1, 2: 1, 3: 1})
            self.assertEqual(
                done,
                {0: False, 1: False, 2: False, 3: False, "__all__": False})
        # The 25th step finishes every agent at once.
        obs, rew, done, info = env.step(dict(zeros))
        self.assertEqual(
            done, {0: True, 1: True, 2: True, 3: True, "__all__": True})

    def testRoundRobinMock(self):
        env = RoundRobinMultiAgent(2)
        self.assertEqual(env.reset(), {0: 0, 1: 0})
        obs, rew, done, info = env.step({0: 0, 1: 0})
        self.assertEqual(obs, {0: 0})
        for _ in range(4):
            obs, rew, done, info = env.step({0: 0})
            self.assertEqual(obs, {1: 0})
            self.assertEqual(done["__all__"], False)
            obs, rew, done, info = env.step({1: 0})
            self.assertEqual(obs, {0: 0})
        self.assertEqual(done["__all__"], True)

    def testVectorizeBasic(self):
        def per_env(agent_values):
            # Same per-agent dict for both vectorized envs, freshly copied.
            return {0: dict(agent_values), 1: dict(agent_values)}

        zero_actions = {0: 0, 1: 0}
        not_done = {0: False, 1: False, "__all__": False}

        env = _MultiAgentEnvToAsync(lambda: BasicMultiAgent(2), [], 2)
        obs, rew, dones, _, _ = env.poll()
        self.assertEqual(obs, per_env({0: 0, 1: 0}))
        self.assertEqual(rew, per_env({0: None, 1: None}))
        self.assertEqual(dones, per_env(not_done))
        for _ in range(24):
            env.send_actions(per_env(zero_actions))
            obs, rew, dones, _, _ = env.poll()
            self.assertEqual(obs, per_env({0: 0, 1: 0}))
            self.assertEqual(rew, per_env({0: 1, 1: 1}))
            self.assertEqual(dones, per_env(not_done))
        env.send_actions(per_env(zero_actions))
        obs, rew, dones, _, _ = env.poll()
        self.assertEqual(
            dones, per_env({0: True, 1: True, "__all__": True}))

        # Reset processing
        with self.assertRaises(ValueError):
            env.send_actions(per_env(zero_actions))
        self.assertEqual(env.try_reset(0), {0: 0, 1: 0})
        self.assertEqual(env.try_reset(1), {0: 0, 1: 0})
        env.send_actions(per_env(zero_actions))
        obs, rew, dones, _, _ = env.poll()
        self.assertEqual(obs, per_env({0: 0, 1: 0}))
        self.assertEqual(rew, per_env({0: 1, 1: 1}))
        self.assertEqual(dones, per_env(not_done))

    def testVectorizeRoundRobin(self):
        env = _MultiAgentEnvToAsync(lambda: RoundRobinMultiAgent(2), [], 2)
        obs, rew, dones, _, _ = env.poll()
        self.assertEqual(obs, {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
        self.assertEqual(rew, {0: {0: None, 1: None}, 1: {0: None, 1: None}})
        env.send_actions({0: {0: 0}, 1: {0: 0}})
        obs, rew, dones, _, _ = env.poll()
        self.assertEqual(obs, {0: {0: 0}, 1: {0: 0}})
        env.send_actions({0: {0: 0}, 1: {0: 0}})
        obs, rew, dones, _, _ = env.poll()
        self.assertEqual(obs, {0: {1: 0}, 1: {1: 0}})
|
||||
|
||||
if __name__ == '__main__':
|
||||
ray.init()
|
||||
unittest.main(verbosity=2)
|
||||
@@ -60,16 +60,16 @@ class PartOffPolicyServing(ServingEnv):
|
||||
|
||||
|
||||
class SimpleOffPolicyServing(ServingEnv):
|
||||
def __init__(self, env):
|
||||
def __init__(self, env, fixed_action):
|
||||
ServingEnv.__init__(self, env.action_space, env.observation_space)
|
||||
self.env = env
|
||||
self.fixed_action = fixed_action
|
||||
|
||||
def run(self):
|
||||
eid = self.start_episode()
|
||||
obs = self.env.reset()
|
||||
while True:
|
||||
# Take random actions
|
||||
action = self.env.action_space.sample()
|
||||
action = self.fixed_action
|
||||
self.log_action(eid, obs, action)
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
self.log_returns(eid, reward, info=info)
|
||||
@@ -131,13 +131,15 @@ class TestServingEnv(unittest.TestCase):
|
||||
|
||||
def testServingEnvOffPolicy(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
env_creator=lambda _: SimpleOffPolicyServing(MockEnv(25)),
|
||||
env_creator=lambda _: SimpleOffPolicyServing(MockEnv(25), 42),
|
||||
policy_graph=MockPolicyGraph,
|
||||
batch_steps=40,
|
||||
batch_mode="complete_episodes")
|
||||
for _ in range(3):
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 50)
|
||||
self.assertEqual(batch["actions"][0], 42)
|
||||
self.assertEqual(batch["actions"][-1], 42)
|
||||
|
||||
def testServingEnvBadActions(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
@@ -185,6 +187,16 @@ class TestServingEnv(unittest.TestCase):
|
||||
return
|
||||
raise Exception("failed to improve reward")
|
||||
|
||||
def testServingEnvHorizonNotSupported(self):
    """With episode_horizon set on a serving env, sampling must fail.

    The first sample() call is expected to return; the second one is
    expected to raise.
    """
    evaluator = CommonPolicyEvaluator(
        env_creator=lambda _: SimpleServing(MockEnv(25)),
        policy_graph=MockPolicyGraph,
        episode_horizon=20,
        batch_steps=10,
        batch_mode="complete_episodes")
    evaluator.sample()
    self.assertRaises(Exception, evaluator.sample)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ray.init()
|
||||
|
||||
@@ -2,66 +2,121 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.utils.serving_env import ServingEnv
|
||||
from ray.rllib.utils.vector_env import VectorEnv
|
||||
from ray.rllib.utils.multi_agent_env import MultiAgentEnv
|
||||
|
||||
|
||||
class AsyncVectorEnv(object):
|
||||
"""The lowest-level env interface used by RLlib for sampling.
|
||||
|
||||
AsyncVectorEnv models multiple agents executing asynchronously. A call to
|
||||
poll() returns observations from ready agents, and actions for those agents
|
||||
AsyncVectorEnv models multiple agents executing asynchronously in multiple
|
||||
environments. A call to poll() returns observations from ready agents
|
||||
keyed by their environment and agent ids, and actions for those agents
|
||||
can be sent back via send_actions().
|
||||
|
||||
All other env types can be adapted to AsyncVectorEnv. RLlib handles these
|
||||
conversions internally in CommonPolicyEvaluator, for example:
|
||||
|
||||
gym.Env => rllib.VectorEnv => rllib.AsyncVectorEnv
|
||||
rllib.MultiAgentEnv => rllib.AsyncVectorEnv
|
||||
rllib.ServingEnv => rllib.AsyncVectorEnv
|
||||
|
||||
Examples:
|
||||
>>> env = MyAsyncVectorEnv()
|
||||
>>> obs, rewards, dones, infos, off_policy_actions = env.poll()
|
||||
>>> print(obs)
|
||||
{"car_0": [2.4, 1.6], "car_1": [3.4, -3.2]}
|
||||
>>> env.send_actions({"car_0": 0, "car_1": 1})
|
||||
{
|
||||
"env_0": {
|
||||
"car_0": [2.4, 1.6],
|
||||
"car_1": [3.4, -3.2],
|
||||
}
|
||||
}
|
||||
>>> env.send_actions(
|
||||
actions={
|
||||
"env_0": {
|
||||
"car_0": 0,
|
||||
"car_1": 1,
|
||||
}
|
||||
})
|
||||
>>> obs, rewards, dones, infos, off_policy_actions = env.poll()
|
||||
>>> print(obs)
|
||||
{"car_0": [4.1, 1.7], "car_1": [3.2, -4.2]}
|
||||
{
|
||||
"env_0": {
|
||||
"car_0": [4.1, 1.7],
|
||||
"car_1": [3.2, -4.2],
|
||||
}
|
||||
}
|
||||
>>> print(dones)
|
||||
{
|
||||
"env_0": {
|
||||
"__all__": False,
|
||||
"car_0": False,
|
||||
"car_1": True,
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def wrap_async(env, make_env=None, num_envs=1):
|
||||
"""Wraps any env type as needed to expose the async interface."""
|
||||
if not isinstance(env, AsyncVectorEnv):
|
||||
if isinstance(env, MultiAgentEnv):
|
||||
env = _MultiAgentEnvToAsync(
|
||||
make_env=make_env, existing_envs=[env], num_envs=num_envs)
|
||||
elif isinstance(env, ServingEnv):
|
||||
if num_envs != 1:
|
||||
raise ValueError(
|
||||
"ServingEnv does not currently support num_envs > 1.")
|
||||
env = _ServingEnvToAsync(env)
|
||||
elif isinstance(env, VectorEnv):
|
||||
env = _VectorEnvToAsync(env)
|
||||
else:
|
||||
env = VectorEnv.wrap(
|
||||
make_env=make_env, existing_envs=[env], num_envs=num_envs)
|
||||
env = _VectorEnvToAsync(env)
|
||||
assert isinstance(env, AsyncVectorEnv)
|
||||
return env
|
||||
|
||||
def poll(self):
|
||||
"""Returns observations from ready agents.
|
||||
|
||||
The returns are dicts mapping from agent episode ids to values. The
|
||||
number of agents can vary over time.
|
||||
The returns are two-level dicts mapping from env_id to a dict of
|
||||
agent_id to values. The number of agents and envs can vary over time.
|
||||
|
||||
Returns:
|
||||
obs (dict): New observations for each ready episode.
|
||||
rewards (dict): Reward values for each ready episode. If the
|
||||
obs (dict): New observations for each ready agent.
|
||||
rewards (dict): Reward values for each ready agent. If the
|
||||
episode is just started, the value will be None.
|
||||
dones (dict): Done values for each ready episode. If True, the
|
||||
episode is terminated.
|
||||
infos (dict): Info values for each ready episode.
|
||||
dones (dict): Done values for each ready agent. The special key
|
||||
"__all__" is used to indicate env termination.
|
||||
infos (dict): Info values for each ready agent.
|
||||
off_policy_actions (dict): Agents may take off-policy actions. When
|
||||
that happens, there will be an entry in this dict that contains
|
||||
the taken action.
|
||||
the taken action. There is no need to send_actions() for agents
|
||||
that have already chosen off-policy actions.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def send_actions(self, action_dict):
|
||||
"""Called to send actions back to running agents in this env.
|
||||
|
||||
Actions should be sent for each ready agent that returned observations
|
||||
in the previous poll() call.
|
||||
|
||||
Arguments:
|
||||
action_dict (dict): Actions for each agent to take.
|
||||
action_dict (dict): Actions values keyed by env_id and agent_id.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def try_reset(self, agent_id):
|
||||
"""Attempt to reset the agent with the given id.
|
||||
def try_reset(self, env_id):
|
||||
"""Attempt to reset the env with the given id.
|
||||
|
||||
If the environment does not support synchronous reset, None can be
|
||||
returned here.
|
||||
|
||||
Returns:
|
||||
obs (obj|None): Resetted observation or None if not supported.
|
||||
obs (dict|None): Resetted observation or None if not supported.
|
||||
"""
|
||||
return None
|
||||
|
||||
@@ -74,8 +129,63 @@ class AsyncVectorEnv(object):
|
||||
return None
|
||||
|
||||
|
||||
# Fixed agent identifier when there is only the single agent in the env
_DUMMY_AGENT_ID = "single_agent"


def _with_dummy_agent_id(env_id_to_values, dummy_id=_DUMMY_AGENT_ID):
    """Nest each per-env value under a single placeholder agent key.

    Arguments:
        env_id_to_values (dict): Maps env ids to values.
        dummy_id (obj): Key to nest each value under.

    Returns:
        dict mapping every env id to {dummy_id: value}.
    """
    nested = {}
    for env_id, value in env_id_to_values.items():
        nested[env_id] = {dummy_id: value}
    return nested
|
||||
|
||||
|
||||
class _ServingEnvToAsync(AsyncVectorEnv):
    """Internal adapter of ServingEnv to AsyncVectorEnv."""

    def __init__(self, serving_env):
        """Wrap the given serving env and start it.

        Arguments:
            serving_env (ServingEnv): Env to adapt; start() is called on it.
        """
        self.serving_env = serving_env
        serving_env.start()

    def poll(self):
        """Block until at least one episode has data, then return it.

        Returns:
            Tuple of (obs, rewards, dones, infos, off_policy_actions), each
            a dict of episode id -> {agent key -> value}.

        Raises:
            Exception: If the serving env is no longer alive after a wakeup.
        """
        with self.serving_env._results_avail_condition:
            results = self._poll()
            while len(results[0]) == 0:
                self.serving_env._results_avail_condition.wait()
                results = self._poll()
                # NOTE(review): assumes ServingEnv is a threading.Thread
                # (it is start()ed above). Thread.isAlive() no longer exists
                # on Python 3.9+, so use the is_alive() spelling.
                if not self.serving_env.is_alive():
                    raise Exception("Serving thread has stopped.")
        limit = self.serving_env._max_concurrent_episodes
        assert len(results[0]) < limit, \
            ("Too many concurrent episodes, were some leaked? This ServingEnv "
             "was created with max_concurrent={}".format(limit))
        return results

    def _poll(self):
        """Collect pending data from every tracked episode without blocking.

        Episodes whose cur_done flag is set are dropped from the serving
        env's episode table.
        """
        all_obs, all_rewards, all_dones, all_infos = {}, {}, {}, {}
        off_policy_actions = {}
        # Iterate over a copy since finished episodes are deleted in-loop.
        for eid, episode in self.serving_env._episodes.copy().items():
            data = episode.get_data()
            if episode.cur_done:
                del self.serving_env._episodes[eid]
            if data:
                all_obs[eid] = data["obs"]
                all_rewards[eid] = data["reward"]
                all_dones[eid] = data["done"]
                all_infos[eid] = data["info"]
                if "off_policy_action" in data:
                    off_policy_actions[eid] = data["off_policy_action"]
        # Done flags are nested under "__all__"; everything else goes under
        # the placeholder single-agent id.
        return _with_dummy_agent_id(all_obs), \
            _with_dummy_agent_id(all_rewards), \
            _with_dummy_agent_id(all_dones, "__all__"), \
            _with_dummy_agent_id(all_infos), \
            _with_dummy_agent_id(off_policy_actions)

    def send_actions(self, action_dict):
        """Queue each episode's single-agent action on its action queue.

        Arguments:
            action_dict (dict): Maps episode id to {_DUMMY_AGENT_ID: action}.
        """
        for eid, action in action_dict.items():
            self.serving_env._episodes[eid].action_queue.put(
                action[_DUMMY_AGENT_ID])
|
||||
|
||||
|
||||
class _VectorEnvToAsync(AsyncVectorEnv):
|
||||
"""Wraps VectorEnv to implement AsyncVectorEnv.
|
||||
"""Internal adapter of VectorEnv to AsyncVectorEnv.
|
||||
|
||||
We assume the caller will always send the full vector of actions in each
|
||||
call to send_actions(), and that they call reset_at() on all completed
|
||||
@@ -99,17 +209,104 @@ class _VectorEnvToAsync(AsyncVectorEnv):
|
||||
self.cur_rewards = []
|
||||
self.cur_dones = []
|
||||
self.cur_infos = []
|
||||
return new_obs, rewards, dones, infos, {}
|
||||
return _with_dummy_agent_id(new_obs), \
|
||||
_with_dummy_agent_id(rewards), \
|
||||
_with_dummy_agent_id(dones, "__all__"), \
|
||||
_with_dummy_agent_id(infos), {}
|
||||
|
||||
def send_actions(self, action_dict):
|
||||
action_vector = [None] * self.num_envs
|
||||
for i in range(self.num_envs):
|
||||
action_vector[i] = action_dict[i]
|
||||
action_vector[i] = action_dict[i][_DUMMY_AGENT_ID]
|
||||
self.new_obs, self.cur_rewards, self.cur_dones, self.cur_infos = \
|
||||
self.vector_env.vector_step(action_vector)
|
||||
|
||||
def try_reset(self, agent_id):
|
||||
return self.vector_env.reset_at(agent_id)
|
||||
def try_reset(self, env_id):
|
||||
return {_DUMMY_AGENT_ID: self.vector_env.reset_at(env_id)}
|
||||
|
||||
def get_unwrapped(self):
|
||||
return self.vector_env.get_unwrapped()
|
||||
|
||||
|
||||
class _MultiAgentEnvToAsync(AsyncVectorEnv):
    """Internal adapter of MultiAgentEnv to AsyncVectorEnv.

    This also supports vectorization if num_envs > 1.
    """

    def __init__(self, make_env, existing_envs, num_envs):
        """Wrap existing multi-agent envs.

        Arguments:
            make_env (func|None): Factory that produces a new multiagent env.
                Must be defined if the number of existing envs is less than
                num_envs.
            existing_envs (list): List of existing multiagent envs. Newly
                created envs are appended to this same list.
            num_envs (int): Desired num multiagent envs to keep total.
        """
        self.make_env = make_env
        self.envs = existing_envs
        self.num_envs = num_envs
        # Ids (list indices) of envs that have reported "__all__" done and
        # have not been reset since.
        self.dones = set()
        while len(self.envs) < self.num_envs:
            self.envs.append(self.make_env())
        for env in self.envs:
            assert isinstance(env, MultiAgentEnv)
        self.env_states = [_MultiAgentEnvState(env) for env in self.envs]

    def poll(self):
        """Return the pending results of every env, keyed by env index.

        Returns:
            Tuple of (obs, rewards, dones, infos, off_policy_actions); the
            last element is always an empty dict.
        """
        obs, rewards, dones, infos = {}, {}, {}, {}
        for i, env_state in enumerate(self.env_states):
            obs[i], rewards[i], dones[i], infos[i] = env_state.poll()
        return obs, rewards, dones, infos, {}

    def send_actions(self, action_dict):
        """Step each listed env with its per-agent actions.

        Arguments:
            action_dict (dict): Maps env id to a dict of agent id -> action.

        Raises:
            ValueError: If actions are sent to an env that is already done.
        """
        for env_id, agent_dict in action_dict.items():
            if env_id in self.dones:
                raise ValueError("Env {} is already done".format(env_id))
            env = self.envs[env_id]
            obs, rewards, dones, infos = env.step(agent_dict)
            if dones["__all__"]:
                self.dones.add(env_id)
            self.env_states[env_id].observe(obs, rewards, dones, infos)

    def try_reset(self, env_id):
        """Reset the env with the given id and return its new observations.

        Returns:
            obs (dict|None): Per-agent observations after the reset.
        """
        obs = self.env_states[env_id].reset()
        if obs is not None:
            # discard() rather than remove(): an env may be reset before it
            # ever reported "__all__" done, in which case it is not in the
            # set and remove() would raise KeyError.
            self.dones.discard(env_id)
        return obs
|
||||
|
||||
|
||||
class _MultiAgentEnvState(object):
    """Holds the latest not-yet-polled step results of one MultiAgentEnv."""

    def __init__(self, env):
        assert isinstance(env, MultiAgentEnv)
        self.env = env
        self.reset()

    def poll(self):
        """Return the buffered (obs, rewards, dones, infos) and clear them.

        Raises:
            ValueError: If nothing has been buffered since the last poll().
        """
        if self.last_obs is None:
            raise ValueError("Need to send action after polling")
        pending = (
            self.last_obs, self.last_rewards, self.last_dones,
            self.last_infos)
        # Clear the buffer so that polling twice in a row is detected.
        self.last_obs = self.last_rewards = None
        self.last_dones = self.last_infos = None
        return pending

    def observe(self, obs, rewards, dones, infos):
        """Buffer the results of an env step for the next poll()."""
        self.last_obs, self.last_rewards = obs, rewards
        self.last_dones, self.last_infos = dones, infos

    def reset(self):
        """Reset the env and buffer its initial per-agent values.

        Each agent in the reset observation gets reward None, done False and
        an empty info dict; the "__all__" done flag is set to False.

        Returns:
            The observation dict returned by the env's reset().
        """
        self.last_obs = self.env.reset()
        agent_ids = list(self.last_obs.keys())
        self.last_rewards = dict.fromkeys(agent_ids)
        self.last_dones = dict.fromkeys(agent_ids, False)
        self.last_infos = {agent_id: {} for agent_id in agent_ids}
        self.last_dones["__all__"] = False
        return self.last_obs
|
||||
|
||||
@@ -8,14 +8,15 @@ import tensorflow as tf
|
||||
|
||||
import ray
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.optimizers import SampleBatch
|
||||
from ray.rllib.optimizers import MultiAgentBatch
|
||||
from ray.rllib.optimizers.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.utils.async_vector_env import AsyncVectorEnv, _VectorEnvToAsync
|
||||
from ray.rllib.utils.async_vector_env import AsyncVectorEnv
|
||||
from ray.rllib.utils.atari_wrappers import wrap_deepmind, is_atari
|
||||
from ray.rllib.utils.compression import pack
|
||||
from ray.rllib.utils.filter import get_filter
|
||||
from ray.rllib.utils.multi_agent_env import MultiAgentEnv
|
||||
from ray.rllib.utils.sampler import AsyncSampler, SyncSampler
|
||||
from ray.rllib.utils.serving_env import ServingEnv, _ServingEnvToAsync
|
||||
from ray.rllib.utils.serving_env import ServingEnv
|
||||
from ray.rllib.utils.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.utils.vector_env import VectorEnv
|
||||
from ray.tune.result import TrainingResult
|
||||
@@ -152,6 +153,7 @@ class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
self.env = env_creator(env_config)
|
||||
if isinstance(self.env, VectorEnv) or \
|
||||
isinstance(self.env, ServingEnv) or \
|
||||
isinstance(self.env, MultiAgentEnv) or \
|
||||
isinstance(self.env, AsyncVectorEnv):
|
||||
def wrap(env):
|
||||
return env # we can't auto-wrap these env types
|
||||
@@ -169,8 +171,6 @@ class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
def make_env():
|
||||
return wrap(env_creator(env_config))
|
||||
|
||||
self.policy_map = {}
|
||||
|
||||
if issubclass(policy_graph, TFPolicyGraph):
|
||||
with tf.Graph().as_default():
|
||||
if tf_session_creator:
|
||||
@@ -186,25 +186,21 @@ class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
policy = policy_graph(
|
||||
self.env.observation_space, self.env.action_space,
|
||||
policy_config)
|
||||
|
||||
self.policy_map = {
|
||||
"default": policy
|
||||
}
|
||||
|
||||
self.obs_filter = get_filter(
|
||||
observation_filter, self.env.observation_space.shape)
|
||||
self.filters = {"obs_filter": self.obs_filter}
|
||||
self.filters = {
|
||||
# TODO(ekl) make the obs space dependent on policy
|
||||
policy_id: get_filter(
|
||||
observation_filter, self.env.observation_space.shape)
|
||||
for (policy_id, policy) in self.policy_map.items()
|
||||
}
|
||||
|
||||
# Always use vector env for consistency even if num_envs = 1
|
||||
if not isinstance(self.env, AsyncVectorEnv):
|
||||
if isinstance(self.env, ServingEnv):
|
||||
self.vector_env = _ServingEnvToAsync(self.env)
|
||||
else:
|
||||
if not isinstance(self.env, VectorEnv):
|
||||
self.env = VectorEnv.wrap(
|
||||
make_env, [self.env], num_envs=num_envs)
|
||||
self.vector_env = _VectorEnvToAsync(self.env)
|
||||
else:
|
||||
self.vector_env = self.env
|
||||
self.async_env = AsyncVectorEnv.wrap_async(
|
||||
self.env, make_env=make_env, num_envs=num_envs)
|
||||
|
||||
if self.batch_mode == "truncate_episodes":
|
||||
if batch_steps % num_envs != 0:
|
||||
@@ -222,19 +218,21 @@ class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
"Unsupported batch mode: {}".format(self.batch_mode))
|
||||
if sample_async:
|
||||
self.sampler = AsyncSampler(
|
||||
self.vector_env, self.policy_map["default"], self.obs_filter,
|
||||
batch_steps, horizon=episode_horizon, pack=pack_episodes)
|
||||
self.async_env, self.policy_map, lambda agent_id: "default",
|
||||
self.filters, batch_steps, horizon=episode_horizon,
|
||||
pack=pack_episodes)
|
||||
self.sampler.start()
|
||||
else:
|
||||
self.sampler = SyncSampler(
|
||||
self.vector_env, self.policy_map["default"], self.obs_filter,
|
||||
batch_steps, horizon=episode_horizon, pack=pack_episodes)
|
||||
self.async_env, self.policy_map, lambda agent_id: "default",
|
||||
self.filters, batch_steps, horizon=episode_horizon,
|
||||
pack=pack_episodes)
|
||||
|
||||
def sample(self):
|
||||
"""Evaluate the current policies and return a batch of experiences.
|
||||
|
||||
Return:
|
||||
SampleBatch from evaluating the current policies.
|
||||
SampleBatch|MultiAgentBatch from evaluating the current policies.
|
||||
"""
|
||||
|
||||
batches = [self.sampler.get_data()]
|
||||
@@ -243,11 +241,16 @@ class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
batch = self.sampler.get_data()
|
||||
steps_so_far += batch.count
|
||||
batches.append(batch)
|
||||
batch = SampleBatch.concat_samples(batches)
|
||||
batch = batches[0].concat_samples(batches)
|
||||
|
||||
if self.compress_observations:
|
||||
batch["obs"] = [pack(o) for o in batch["obs"]]
|
||||
batch["new_obs"] = [pack(o) for o in batch["new_obs"]]
|
||||
if isinstance(batch, MultiAgentBatch):
|
||||
for data in batch.policy_batches.values():
|
||||
data["obs"] = [pack(o) for o in data["obs"]]
|
||||
data["new_obs"] = [pack(o) for o in data["new_obs"]]
|
||||
else:
|
||||
batch["obs"] = [pack(o) for o in batch["obs"]]
|
||||
batch["new_obs"] = [pack(o) for o in batch["new_obs"]]
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
class MultiAgentEnv(object):
    """Interface for an environment shared by several independent agents.

    Each agent is addressed by a (string) agent id, and every value
    exchanged with the env is a dict keyed by those ids.

    Examples:
        >>> env = MyMultiAgentEnv()
        >>> obs = env.reset()
        >>> print(obs)
        {
            "car_0": [2.4, 1.6],
            "car_1": [3.4, -3.2],
            "traffic_light_1": [0, 3, 5, 1],
        }
        >>> obs, rewards, dones, infos = env.step(
            action_dict={
                "car_0": 1, "car_1": 0, "traffic_light_1": 2,
            })
        >>> print(rewards)
        {
            "car_0": 3,
            "car_1": -1,
            "traffic_light_1": 0,
        }
        >>> print(dones)
        {
            "car_0": False,
            "car_1": True,
            "__all__": False,
        }
    """

    def reset(self):
        """Reset the env and report what the ready agents observe.

        Subclasses must override this method.

        Returns:
            obs (dict): New observations for each ready agent.
        """
        raise NotImplementedError()

    def step(self, action_dict):
        """Apply the given actions and report results for ready agents.

        Subclasses must override this method. All returned values are
        dicts mapping agent_id strings to values, and the set of agents in
        the env can vary over time.

        Arguments:
            action_dict (dict): Actions keyed by agent id, as in the class
                example.

        Returns:
            obs (dict): New observations for each ready agent.
            rewards (dict): Reward values for each ready agent. If the
                episode is just started, the value will be None.
            dones (dict): Done values for each ready agent. The special key
                "__all__" is used to indicate env termination.
            infos (dict): Info values for each ready agent.
        """
        raise NotImplementedError()
|
||||
@@ -25,7 +25,9 @@ class PolicyGraph(object):
|
||||
action_space (gym.Space): Action space of the env.
|
||||
config (dict): Policy-specific configuration data.
|
||||
"""
|
||||
pass
|
||||
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
|
||||
def compute_actions(self, obs_batch, state_batches, is_training=False):
|
||||
"""Compute actions for the current policy.
|
||||
@@ -70,8 +72,9 @@ class PolicyGraph(object):
|
||||
Arguments:
|
||||
sample_batch (SampleBatch): batch of experiences for the policy,
|
||||
which will contain at most one episode trajectory.
|
||||
other_agent_batches (dict): In a multi-agent env, this contains the
|
||||
experience batches seen by other agents.
|
||||
other_agent_batches (dict): In a multi-agent env, this contains a
|
||||
mapping of agent ids to (policy_graph, agent_batch) tuples
|
||||
containing the policy graph and experiences of the other agent.
|
||||
|
||||
Returns:
|
||||
SampleBatch: postprocessed sample batch.
|
||||
|
||||
+198
-141
@@ -7,13 +7,16 @@ import numpy as np
|
||||
import six.moves.queue as queue
|
||||
import threading
|
||||
|
||||
from ray.rllib.optimizers.sample_batch import SampleBatchBuilder
|
||||
from ray.rllib.utils.vector_env import VectorEnv
|
||||
from ray.rllib.utils.async_vector_env import AsyncVectorEnv, _VectorEnvToAsync
|
||||
from ray.rllib.optimizers.sample_batch import MultiAgentSampleBatchBuilder, \
|
||||
MultiAgentBatch
|
||||
from ray.rllib.utils.async_vector_env import AsyncVectorEnv
|
||||
|
||||
|
||||
CompletedRollout = namedtuple("CompletedRollout",
|
||||
["episode_length", "episode_reward"])
|
||||
RolloutMetrics = namedtuple(
|
||||
"RolloutMetrics", ["episode_length", "episode_reward", "agent_rewards"])
|
||||
|
||||
PolicyEvalData = namedtuple(
|
||||
"PolicyEvalData", ["env_id", "agent_id", "obs", "rnn_state"])
|
||||
|
||||
|
||||
class SyncSampler(object):
|
||||
@@ -26,26 +29,23 @@ class SyncSampler(object):
|
||||
thread."""
|
||||
|
||||
def __init__(
|
||||
self, env, policy, obs_filter, num_local_steps,
|
||||
horizon=None, pack=False):
|
||||
if not isinstance(env, AsyncVectorEnv):
|
||||
if not isinstance(env, VectorEnv):
|
||||
env = VectorEnv.wrap(make_env=None, existing_envs=[env])
|
||||
env = _VectorEnvToAsync(env)
|
||||
self.async_vector_env = env
|
||||
self, env, policies, policy_mapping_fn, obs_filters,
|
||||
num_local_steps, horizon=None, pack=False):
|
||||
self.async_vector_env = AsyncVectorEnv.wrap_async(env)
|
||||
self.num_local_steps = num_local_steps
|
||||
self.horizon = horizon
|
||||
self.policy = policy
|
||||
self._obs_filter = obs_filter
|
||||
self.rollout_provider = _env_runner(self.async_vector_env, self.policy,
|
||||
self.num_local_steps, self.horizon,
|
||||
self._obs_filter, pack)
|
||||
self.policies = policies
|
||||
self.policy_mapping_fn = policy_mapping_fn
|
||||
self._obs_filters = obs_filters
|
||||
self.rollout_provider = _env_runner(
|
||||
self.async_vector_env, self.policies, self.policy_mapping_fn,
|
||||
self.num_local_steps, self.horizon, self._obs_filters, pack)
|
||||
self.metrics_queue = queue.Queue()
|
||||
|
||||
def get_data(self):
|
||||
while True:
|
||||
item = next(self.rollout_provider)
|
||||
if isinstance(item, CompletedRollout):
|
||||
if isinstance(item, RolloutMetrics):
|
||||
self.metrics_queue.put(item)
|
||||
else:
|
||||
return item
|
||||
@@ -67,23 +67,20 @@ class AsyncSampler(threading.Thread):
|
||||
accumulate and the gradient can be calculated on up to 5 batches."""
|
||||
|
||||
def __init__(
|
||||
self, env, policy, obs_filter, num_local_steps,
|
||||
horizon=None, pack=False):
|
||||
assert getattr(
|
||||
obs_filter, "is_concurrent",
|
||||
False), ("Observation Filter must support concurrent updates.")
|
||||
if not isinstance(env, AsyncVectorEnv):
|
||||
if not isinstance(env, VectorEnv):
|
||||
env = VectorEnv.wrap(make_env=None, existing_envs=[env])
|
||||
env = _VectorEnvToAsync(env)
|
||||
self.async_vector_env = env
|
||||
self, env, policies, policy_mapping_fn, obs_filters,
|
||||
num_local_steps, horizon=None, pack=False):
|
||||
for _, f in obs_filters.items():
|
||||
assert getattr(f, "is_concurrent", False), \
|
||||
"Observation Filter must support concurrent updates."
|
||||
self.async_vector_env = AsyncVectorEnv.wrap_async(env)
|
||||
threading.Thread.__init__(self)
|
||||
self.queue = queue.Queue(5)
|
||||
self.metrics_queue = queue.Queue()
|
||||
self.num_local_steps = num_local_steps
|
||||
self.horizon = horizon
|
||||
self.policy = policy
|
||||
self._obs_filter = obs_filter
|
||||
self.policies = policies
|
||||
self.policy_mapping_fn = policy_mapping_fn
|
||||
self._obs_filters = obs_filters
|
||||
self.daemon = True
|
||||
self.pack = pack
|
||||
|
||||
@@ -95,15 +92,15 @@ class AsyncSampler(threading.Thread):
|
||||
raise e
|
||||
|
||||
def _run(self):
|
||||
rollout_provider = _env_runner(self.async_vector_env, self.policy,
|
||||
self.num_local_steps, self.horizon,
|
||||
self._obs_filter, self.pack)
|
||||
rollout_provider = _env_runner(
|
||||
self.async_vector_env, self.policies, self.policy_mapping_fn,
|
||||
self.num_local_steps, self.horizon, self._obs_filters, self.pack)
|
||||
while True:
|
||||
# The timeout variable exists because apparently, if one worker
|
||||
# dies, the other workers won't die with it, unless the timeout is
|
||||
# set to some large number. This is an empirical observation.
|
||||
item = next(rollout_provider)
|
||||
if isinstance(item, CompletedRollout):
|
||||
if isinstance(item, RolloutMetrics):
|
||||
self.metrics_queue.put(item)
|
||||
else:
|
||||
self.queue.put(item, timeout=600.0)
|
||||
@@ -115,8 +112,9 @@ class AsyncSampler(threading.Thread):
|
||||
if isinstance(rollout, BaseException):
|
||||
raise rollout
|
||||
|
||||
# We can't auto-concat rollouts in vector mode
|
||||
if self.async_vector_env.num_envs > 1:
|
||||
# We can't auto-concat rollouts in these modes
|
||||
if self.async_vector_env.num_envs > 1 or \
|
||||
isinstance(rollout, MultiAgentBatch):
|
||||
return rollout
|
||||
|
||||
# Auto-concat rollouts; TODO(ekl) is this important for A3C perf?
|
||||
@@ -141,23 +139,22 @@ class AsyncSampler(threading.Thread):
|
||||
|
||||
|
||||
def _env_runner(
|
||||
async_vector_env, policy, num_local_steps, horizon, obs_filter, pack):
|
||||
"""This implements the logic of the thread runner.
|
||||
|
||||
It continually runs the policy, and as long as the rollout exceeds a
|
||||
certain length, the thread runner appends the policy to the queue. Yields
|
||||
when `timestep_limit` is surpassed, environment terminates, or
|
||||
`num_local_steps` is reached.
|
||||
async_vector_env, policies, policy_mapping_fn, num_local_steps,
|
||||
horizon, obs_filters, pack):
|
||||
"""This implements the common experience collection logic.
|
||||
|
||||
Args:
|
||||
async_vector_env: env implementing AsyncVectorEnv.
|
||||
policy: Policy used to interact with environment. Also sets fields
|
||||
to be included in `SampleBatch`.
|
||||
num_local_steps: Number of steps before `SampleBatch` is yielded. Set
|
||||
to infinity to yield complete episodes.
|
||||
horizon: Horizon of the episode.
|
||||
obs_filter: Filter used to process observations.
|
||||
pack: Whether to pack multiple episodes into each batch. This
|
||||
async_vector_env (AsyncVectorEnv): env implementing AsyncVectorEnv.
|
||||
policies (dict): Map of policy ids to PolicyGraph instances.
|
||||
policy_mapping_fn (func): Function that maps agent ids to policy ids.
|
||||
This is called when an agent first enters the environment. The
|
||||
agent is then "bound" to the returned policy for the episode.
|
||||
num_local_steps (int): Number of episode steps before `SampleBatch` is
|
||||
yielded. Set to infinity to yield complete episodes.
|
||||
horizon (int): Horizon of the episode.
|
||||
obs_filters (dict): Map of policy id to filter used to process
|
||||
observations for the policy.
|
||||
pack (bool): Whether to pack multiple episodes into each batch. This
|
||||
guarantees batches will be exactly `num_local_steps` in size.
|
||||
|
||||
Yields:
|
||||
@@ -181,110 +178,131 @@ def _env_runner(
|
||||
if batch_builder_pool:
|
||||
return batch_builder_pool.pop()
|
||||
else:
|
||||
return SampleBatchBuilder()
|
||||
return MultiAgentSampleBatchBuilder(policies)
|
||||
|
||||
episodes = defaultdict(
|
||||
lambda: _Episode(policy.get_initial_state(), get_batch_builder))
|
||||
active_episodes = defaultdict(
|
||||
lambda: _MultiAgentEpisode(
|
||||
policies, policy_mapping_fn, get_batch_builder))
|
||||
|
||||
while True:
|
||||
# Get observations from ready envs
|
||||
# Get observations from all ready agents
|
||||
unfiltered_obs, rewards, dones, infos, off_policy_actions = \
|
||||
async_vector_env.poll()
|
||||
ready_eids = []
|
||||
ready_obs = []
|
||||
ready_rnn_states = []
|
||||
|
||||
# Process and record the new observations
|
||||
for eid, raw_obs in unfiltered_obs.items():
|
||||
episode = episodes[eid]
|
||||
filtered_obs = obs_filter(raw_obs)
|
||||
ready_eids.append(eid)
|
||||
ready_obs.append(filtered_obs)
|
||||
ready_rnn_states.append(episode.rnn_state)
|
||||
# Map of policy_id to list of PolicyEvalData
|
||||
to_eval = defaultdict(list)
|
||||
|
||||
if episode.last_observation is None:
|
||||
episode.last_observation = filtered_obs
|
||||
continue # This is the initial observation after a reset
|
||||
# For each environment
|
||||
for env_id, agent_obs in unfiltered_obs.items():
|
||||
new_episode = env_id not in active_episodes
|
||||
episode = active_episodes[env_id]
|
||||
if not new_episode:
|
||||
episode.length += 1
|
||||
episode.batch_builder.count += 1
|
||||
episode.add_agent_rewards(rewards[env_id])
|
||||
|
||||
episode.length += 1
|
||||
episode.total_reward += rewards[eid]
|
||||
|
||||
# Handle episode terminations
|
||||
if dones[eid] or episode.length >= horizon:
|
||||
done = True
|
||||
yield CompletedRollout(episode.length, episode.total_reward)
|
||||
# Check episode termination conditions
|
||||
if dones[env_id]["__all__"] or episode.length >= horizon:
|
||||
all_done = True
|
||||
yield RolloutMetrics(
|
||||
episode.length, episode.total_reward,
|
||||
dict(episode.agent_rewards))
|
||||
else:
|
||||
done = False
|
||||
all_done = False
|
||||
|
||||
if infos[eid].get("training_enabled", True):
|
||||
episode.batch_builder.add_values(
|
||||
obs=episode.last_observation,
|
||||
actions=episode.last_action_flat(),
|
||||
rewards=rewards[eid],
|
||||
dones=done,
|
||||
new_obs=filtered_obs,
|
||||
**episode.last_pi_info)
|
||||
# For each agent in the environment
|
||||
for agent_id, raw_obs in agent_obs.items():
|
||||
policy_id = episode.policy_for(agent_id)
|
||||
filtered_obs = obs_filters[policy_id](raw_obs)
|
||||
agent_done = bool(all_done or dones[env_id].get(agent_id))
|
||||
if not agent_done:
|
||||
to_eval[policy_id].append(
|
||||
PolicyEvalData(
|
||||
env_id, agent_id, filtered_obs,
|
||||
episode.rnn_state_for(agent_id)))
|
||||
|
||||
# Cut the batch if we're not packing multiple episodes into
|
||||
# one, or if we've exceeded the requested batch size.
|
||||
if (done and not pack) or \
|
||||
last_observation = episode.last_observation_for(agent_id)
|
||||
episode.set_last_observation(agent_id, filtered_obs)
|
||||
|
||||
# Record transition info if applicable
|
||||
if last_observation is not None and \
|
||||
infos[env_id][agent_id].get("training_enabled", True):
|
||||
episode.batch_builder.add_values(
|
||||
agent_id,
|
||||
policy_id,
|
||||
t=episode.length - 1,
|
||||
obs=last_observation,
|
||||
actions=episode.last_action_for(agent_id),
|
||||
rewards=rewards[env_id][agent_id],
|
||||
dones=agent_done,
|
||||
infos=infos[env_id][agent_id],
|
||||
new_obs=filtered_obs,
|
||||
**episode.last_pi_info_for(agent_id))
|
||||
|
||||
# Cut the batch if we're not packing multiple episodes into one,
|
||||
# or if we've exceeded the requested batch size.
|
||||
if episode.batch_builder.has_pending_data():
|
||||
if (all_done and not pack) or \
|
||||
episode.batch_builder.count >= num_local_steps:
|
||||
yield episode.batch_builder.build_and_reset(
|
||||
policy.postprocess_trajectory)
|
||||
elif done:
|
||||
# Make sure postprocessor never crosses episode boundaries
|
||||
episode.batch_builder.postprocess_batch_so_far(
|
||||
policy.postprocess_trajectory)
|
||||
yield episode.batch_builder.build_and_reset()
|
||||
elif all_done:
|
||||
# Make sure postprocessor stays within one episode
|
||||
episode.batch_builder.postprocess_batch_so_far()
|
||||
|
||||
if done:
|
||||
if all_done:
|
||||
# Handle episode termination
|
||||
batch_builder_pool.append(episode.batch_builder)
|
||||
del episodes[eid]
|
||||
resetted_obs = async_vector_env.try_reset(eid)
|
||||
del active_episodes[env_id]
|
||||
resetted_obs = async_vector_env.try_reset(env_id)
|
||||
if resetted_obs is None:
|
||||
# Reset not supported, drop this env from the ready list
|
||||
assert horizon == float("inf"), \
|
||||
"Setting episode horizon requires reset() support."
|
||||
ready_eids.pop()
|
||||
ready_obs.pop()
|
||||
ready_rnn_states.pop()
|
||||
else:
|
||||
# Reset successful, put in the new obs as ready
|
||||
episode = episodes[eid]
|
||||
episode.last_observation = obs_filter(resetted_obs)
|
||||
ready_obs[-1] = episode.last_observation
|
||||
ready_rnn_states[-1] = episode.rnn_state
|
||||
else:
|
||||
episode.last_observation = filtered_obs
|
||||
# Creates a new episode
|
||||
episode = active_episodes[env_id]
|
||||
for agent_id, raw_obs in resetted_obs.items():
|
||||
policy_id = episode.policy_for(agent_id)
|
||||
filtered_obs = obs_filters[policy_id](raw_obs)
|
||||
episode.set_last_observation(agent_id, filtered_obs)
|
||||
to_eval[policy_id].append(
|
||||
PolicyEvalData(
|
||||
env_id, agent_id, filtered_obs,
|
||||
episode.rnn_state_for(agent_id)))
|
||||
|
||||
if not ready_eids:
|
||||
continue # No actions to take
|
||||
# Map of env_id -> agent_id -> action
|
||||
action_dict = defaultdict(dict)
|
||||
|
||||
# Compute action for ready envs
|
||||
ready_rnn_state_cols = _to_column_format(ready_rnn_states)
|
||||
actions, new_rnn_state_cols, pi_info_cols = policy.compute_actions(
|
||||
ready_obs, ready_rnn_state_cols, is_training=True)
|
||||
|
||||
# Add RNN state info
|
||||
for f_i, column in enumerate(ready_rnn_state_cols):
|
||||
pi_info_cols["state_in_{}".format(f_i)] = column
|
||||
for f_i, column in enumerate(new_rnn_state_cols):
|
||||
pi_info_cols["state_out_{}".format(f_i)] = column
|
||||
# TODO(ekl) fuse all policy evaluation into one TF run
|
||||
for policy_id, eval_data in to_eval.items():
|
||||
rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data])
|
||||
actions, rnn_out_cols, pi_info_cols = \
|
||||
policies[policy_id].compute_actions(
|
||||
[t.obs for t in eval_data], rnn_in_cols, is_training=True)
|
||||
# Add RNN state info
|
||||
for f_i, column in enumerate(rnn_in_cols):
|
||||
pi_info_cols["state_in_{}".format(f_i)] = column
|
||||
for f_i, column in enumerate(rnn_out_cols):
|
||||
pi_info_cols["state_out_{}".format(f_i)] = column
|
||||
# Save output rows
|
||||
for i, action in enumerate(actions):
|
||||
env_id = eval_data[i].env_id
|
||||
agent_id = eval_data[i].agent_id
|
||||
action_dict[env_id][agent_id] = action
|
||||
episode = active_episodes[env_id]
|
||||
episode.set_rnn_state(agent_id, [c[i] for c in rnn_out_cols])
|
||||
episode.set_last_pi_info(
|
||||
agent_id, {k: v[i] for k, v in pi_info_cols.items()})
|
||||
if env_id in off_policy_actions and \
|
||||
agent_id in off_policy_actions[env_id]:
|
||||
episode.set_last_action(
|
||||
agent_id, off_policy_actions[env_id][agent_id])
|
||||
else:
|
||||
episode.set_last_action(agent_id, action)
|
||||
|
||||
# Return computed actions to ready envs. We also send to envs that have
|
||||
# taken off-policy actions; those envs are free to ignore the action.
|
||||
async_vector_env.send_actions(dict(zip(ready_eids, actions)))
|
||||
|
||||
# Store the computed action info
|
||||
for i, eid in enumerate(ready_eids):
|
||||
episode = episodes[eid]
|
||||
if eid in off_policy_actions:
|
||||
episode.last_action = off_policy_actions[eid]
|
||||
else:
|
||||
episode.last_action = actions[i]
|
||||
episode.rnn_state = [column[i] for column in new_rnn_state_cols]
|
||||
episode.last_pi_info = {
|
||||
k: column[i] for k, column in pi_info_cols.items()}
|
||||
async_vector_env.send_actions(dict(action_dict))
|
||||
|
||||
|
||||
def _to_column_format(rnn_state_rows):
|
||||
@@ -293,18 +311,57 @@ def _to_column_format(rnn_state_rows):
|
||||
[row[i] for row in rnn_state_rows] for i in range(num_cols)]
|
||||
|
||||
|
||||
class _Episode(object):
|
||||
def __init__(self, init_rnn_state, batch_builder_factory):
|
||||
self.rnn_state = init_rnn_state
|
||||
class _MultiAgentEpisode(object):
|
||||
def __init__(self, policies, policy_mapping_fn, batch_builder_factory):
|
||||
self.batch_builder = batch_builder_factory()
|
||||
self.last_action = None
|
||||
self.last_observation = None
|
||||
self.last_pi_info = None
|
||||
self.total_reward = 0.0
|
||||
self.length = 0
|
||||
self.agent_rewards = defaultdict(float)
|
||||
self._policies = policies
|
||||
self._policy_mapping_fn = policy_mapping_fn
|
||||
self._agent_to_policy = {}
|
||||
self._agent_to_rnn_state = {}
|
||||
self._agent_to_last_obs = {}
|
||||
self._agent_to_last_action = {}
|
||||
self._agent_to_last_pi_info = {}
|
||||
|
||||
def last_action_flat(self):
|
||||
# Concatenate multiagent actions
|
||||
if isinstance(self.last_action, list):
|
||||
return np.concatenate(self.last_action, axis=0).flatten()
|
||||
return self.last_action
|
||||
def add_agent_rewards(self, reward_dict):
|
||||
for agent_id, reward in reward_dict.items():
|
||||
self.agent_rewards[agent_id] += reward
|
||||
self.total_reward += reward
|
||||
|
||||
def policy_for(self, agent_id):
|
||||
if agent_id not in self._agent_to_policy:
|
||||
self._agent_to_policy[agent_id] = self._policy_mapping_fn(agent_id)
|
||||
return self._agent_to_policy[agent_id]
|
||||
|
||||
def rnn_state_for(self, agent_id):
|
||||
if agent_id not in self._agent_to_rnn_state:
|
||||
policy = self._policies[self.policy_for(agent_id)]
|
||||
self._agent_to_rnn_state[agent_id] = policy.get_initial_state()
|
||||
return self._agent_to_rnn_state[agent_id]
|
||||
|
||||
def last_observation_for(self, agent_id):
|
||||
return self._agent_to_last_obs.get(agent_id)
|
||||
|
||||
def last_action_for(self, agent_id):
|
||||
action = self._agent_to_last_action[agent_id]
|
||||
# Concatenate tuple actions
|
||||
if isinstance(action, list):
|
||||
action = np.concatenate(action, axis=0).flatten()
|
||||
return action
|
||||
|
||||
def last_pi_info_for(self, agent_id):
|
||||
return self._agent_to_last_pi_info[agent_id]
|
||||
|
||||
def set_rnn_state(self, agent_id, rnn_state):
|
||||
self._agent_to_rnn_state[agent_id] = rnn_state
|
||||
|
||||
def set_last_observation(self, agent_id, obs):
|
||||
self._agent_to_last_obs[agent_id] = obs
|
||||
|
||||
def set_last_action(self, agent_id, action):
|
||||
self._agent_to_last_action[agent_id] = action
|
||||
|
||||
def set_last_pi_info(self, agent_id, pi_info):
|
||||
self._agent_to_last_pi_info[agent_id] = pi_info
|
||||
|
||||
@@ -6,11 +6,9 @@ from six.moves import queue
|
||||
import threading
|
||||
import uuid
|
||||
|
||||
from ray.rllib.utils.async_vector_env import AsyncVectorEnv
|
||||
|
||||
|
||||
class ServingEnv(threading.Thread):
|
||||
"""Environment that provides policy serving.
|
||||
"""An environment that provides policy serving.
|
||||
|
||||
Unlike simulator envs, control is inverted. The environment queries the
|
||||
policy to obtain actions and logs observations and rewards for training.
|
||||
@@ -91,7 +89,7 @@ class ServingEnv(threading.Thread):
|
||||
raise ValueError(
|
||||
"Episode {} is already started".format(episode_id))
|
||||
|
||||
self._episodes[episode_id] = _Episode(
|
||||
self._episodes[episode_id] = _ServingEnvEpisode(
|
||||
episode_id, self._results_avail_condition, training_enabled)
|
||||
|
||||
return episode_id
|
||||
@@ -165,49 +163,7 @@ class ServingEnv(threading.Thread):
|
||||
return self._episodes[episode_id]
|
||||
|
||||
|
||||
class _ServingEnvToAsync(AsyncVectorEnv):
|
||||
"""Internal adapter of ServingEnv to AsyncVectorEnv."""
|
||||
|
||||
def __init__(self, serving_env):
|
||||
self.serving_env = serving_env
|
||||
serving_env.start()
|
||||
|
||||
def poll(self):
|
||||
with self.serving_env._results_avail_condition:
|
||||
results = self._poll()
|
||||
while len(results[0]) == 0:
|
||||
self.serving_env._results_avail_condition.wait()
|
||||
results = self._poll()
|
||||
if not self.serving_env.isAlive():
|
||||
raise Exception("Serving thread has stopped.")
|
||||
limit = self.serving_env._max_concurrent_episodes
|
||||
assert len(results[0]) < limit, \
|
||||
("Too many concurrent episodes, were some leaked? This ServingEnv "
|
||||
"was created with max_concurrent={}".format(limit))
|
||||
return results
|
||||
|
||||
def _poll(self):
|
||||
all_obs, all_rewards, all_dones, all_infos = {}, {}, {}, {}
|
||||
off_policy_actions = {}
|
||||
for eid, episode in self.serving_env._episodes.copy().items():
|
||||
data = episode.get_data()
|
||||
if episode.cur_done:
|
||||
del self.serving_env._episodes[eid]
|
||||
if data:
|
||||
all_obs[eid] = data["obs"]
|
||||
all_rewards[eid] = data["reward"]
|
||||
all_dones[eid] = data["done"]
|
||||
all_infos[eid] = data["info"]
|
||||
if "off_policy_action" in data:
|
||||
off_policy_actions[eid] = data["off_policy_action"]
|
||||
return all_obs, all_rewards, all_dones, all_infos, off_policy_actions
|
||||
|
||||
def send_actions(self, action_dict):
|
||||
for eid, action in action_dict.items():
|
||||
self.serving_env._episodes[eid].action_queue.put(action)
|
||||
|
||||
|
||||
class _Episode(object):
|
||||
class _ServingEnvEpisode(object):
|
||||
"""Tracked state for each active episode."""
|
||||
|
||||
def __init__(self, episode_id, results_avail_condition, training_enabled):
|
||||
|
||||
@@ -22,22 +22,44 @@ class VectorEnv(object):
|
||||
return _VectorizedGymEnv(make_env, existing_envs or [], num_envs)
|
||||
|
||||
def vector_reset(self):
|
||||
"""Resets all environments.
|
||||
|
||||
Returns:
|
||||
obs (list): Vector of observations from each environment.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def reset_at(self, index):
|
||||
"""Resets a single environment.
|
||||
|
||||
Returns:
|
||||
obs (obj): Observations from the resetted environment.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def vector_step(self, actions):
|
||||
"""Vectorized step.
|
||||
|
||||
Arguments:
|
||||
actions (list): Actions for each env.
|
||||
|
||||
Returns:
|
||||
obs (list): New observations for each env.
|
||||
rewards (list): Reward values for each env.
|
||||
dones (list): Done values for each env.
|
||||
infos (list): Info values for each env.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_unwrapped(self):
|
||||
"""Returns a single instance of the underlying env."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class _VectorizedGymEnv(VectorEnv):
|
||||
"""Internal wrapper for gym envs to implement VectorEnv.
|
||||
|
||||
Arguents:
|
||||
Arguments:
|
||||
make_env (func|None): Factory that produces a new gym env. Must be
|
||||
defined if the number of existing envs is less than num_envs.
|
||||
existing_envs (list): List of existing gym envs.
|
||||
|
||||
Reference in New Issue
Block a user