[rllib] Document creating an ensemble of envs; also add vector_index attribute to env config (#2513)

This also removes the async resetting code in VectorEnv. While async resetting improves benchmark performance slightly, it substantially complicates env configuration and probably isn't worth it for most envs.

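Exposing vector_index (alongside worker_index) in the env config makes it easy to efficiently support setups that train across an ensemble of different envs, such as Joint PPO: https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/retro-contest/gotta_learn_fast_report.pdf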

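For example, to cover 188 envs, you could do something like num_workers: 10, num_envs_per_worker: 19 (10 x 19 = 190 env slots, enough for all 188), and have the env creator pick the concrete env from worker_index and vector_index.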
This commit is contained in:
Eric Liang
2018-08-01 16:29:27 -07:00
committed by GitHub
parent a630e332f3
commit 9a479b3a63
8 changed files with 81 additions and 59 deletions
+1 -1
View File
@@ -251,7 +251,7 @@ class _MultiAgentEnvToAsync(AsyncVectorEnv):
self.num_envs = num_envs
self.dones = set()
while len(self.envs) < self.num_envs:
self.envs.append(self.make_env())
self.envs.append(self.make_env(len(self.envs)))
for env in self.envs:
assert isinstance(env, MultiAgentEnv)
self.env_states = [_MultiAgentEnvState(env) for env in self.envs]
+8 -1
View File
@@ -15,8 +15,15 @@ class EnvContext(dict):
Attributes:
worker_index (int): When there are multiple workers created, this
uniquely identifies the worker the env is created in.
vector_index (int): When there are multiple envs per worker, this
uniquely identifies the env index within the worker.
"""
def __init__(self, env_config, worker_index):
def __init__(self, env_config, worker_index, vector_index=0):
dict.__init__(self, env_config)
self.worker_index = worker_index
self.vector_index = vector_index
def with_vector_index(self, vector_index):
return EnvContext(
self, worker_index=self.worker_index, vector_index=vector_index)
+2 -50
View File
@@ -2,9 +2,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import queue
import threading
class VectorEnv(object):
"""An environment that supports batch evaluation.
@@ -70,20 +67,14 @@ class _VectorizedGymEnv(VectorEnv):
self.make_env = make_env
self.envs = existing_envs
self.num_envs = num_envs
if make_env and num_envs > 1:
self.resetter = _AsyncResetter(make_env, int(self.num_envs**0.5))
else:
self.resetter = _SimpleResetter(make_env)
while len(self.envs) < self.num_envs:
self.envs.append(self.make_env())
self.envs.append(self.make_env(len(self.envs)))
def vector_reset(self):
return [e.reset() for e in self.envs]
def reset_at(self, index):
new_obs, new_env = self.resetter.trade_for_resetted(self.envs[index])
self.envs[index] = new_env
return new_obs
return self.envs[index].reset()
def vector_step(self, actions):
obs_batch, rew_batch, done_batch, info_batch = [], [], [], []
@@ -97,42 +88,3 @@ class _VectorizedGymEnv(VectorEnv):
def get_unwrapped(self):
return self.envs[0]
class _AsyncResetter(threading.Thread):
"""Does env reset asynchronously in the background.
This is useful since resetting an env can be 100x slower than stepping."""
def __init__(self, make_env, pool_size):
threading.Thread.__init__(self)
self.make_env = make_env
self.pool_size = 0
self.to_reset = queue.Queue()
self.resetted = queue.Queue()
self.daemon = True
self.pool_size = pool_size
while self.resetted.qsize() < self.pool_size:
env = self.make_env()
obs = env.reset()
self.resetted.put((obs, env))
self.start()
def run(self):
while True:
env = self.to_reset.get()
obs = env.reset()
self.resetted.put((obs, env))
def trade_for_resetted(self, env):
self.to_reset.put(env)
new_obs, new_env = self.resetted.get(timeout=30)
return new_obs, new_env
class _SimpleResetter(object):
def __init__(self, make_env):
pass
def trade_for_resetted(self, env):
return env.reset(), env
@@ -190,8 +190,9 @@ class PolicyEvaluator(EvaluatorInterface):
self.env = wrap(self.env)
def make_env():
return wrap(env_creator(env_context))
def make_env(vector_index):
return wrap(
env_creator(env_context.with_vector_index(vector_index)))
self.tf_sess = None
policy_dict = _validate_and_canonicalize(policy_graph, self.env)
@@ -160,7 +160,7 @@ class TestMultiAgentEnv(unittest.TestCase):
self.assertEqual(done["__all__"], True)
def testVectorizeBasic(self):
env = _MultiAgentEnvToAsync(lambda: BasicMultiAgent(2), [], 2)
env = _MultiAgentEnvToAsync(lambda v: BasicMultiAgent(2), [], 2)
obs, rew, dones, _, _ = env.poll()
self.assertEqual(obs, {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
self.assertEqual(rew, {0: {0: None, 1: None}, 1: {0: None, 1: None}})
@@ -236,7 +236,7 @@ class TestMultiAgentEnv(unittest.TestCase):
})
def testVectorizeRoundRobin(self):
env = _MultiAgentEnvToAsync(lambda: RoundRobinMultiAgent(2), [], 2)
env = _MultiAgentEnvToAsync(lambda v: RoundRobinMultiAgent(2), [], 2)
obs, rew, dones, _, _ = env.poll()
self.assertEqual(obs, {0: {0: 0}, 1: {0: 0}})
self.assertEqual(rew, {0: {0: None}, 1: {0: None}})
@@ -33,8 +33,9 @@ class BadPolicyGraph(PolicyGraph):
class MockEnv(gym.Env):
def __init__(self, episode_length):
def __init__(self, episode_length, config=None):
self.episode_length = episode_length
self.config = config
self.i = 0
self.observation_space = gym.spaces.Discrete(1)
self.action_space = gym.spaces.Discrete(2)
@@ -150,7 +151,7 @@ class TestPolicyEvaluator(unittest.TestCase):
def testAutoVectorization(self):
ev = PolicyEvaluator(
env_creator=lambda _: MockEnv(episode_length=20),
env_creator=lambda cfg: MockEnv(episode_length=20, config=cfg),
policy_graph=MockPolicyGraph,
batch_mode="truncate_episodes",
batch_steps=16,
@@ -165,6 +166,11 @@ class TestPolicyEvaluator(unittest.TestCase):
self.assertEqual(batch.count, 16)
result = collect_metrics(ev, [])
self.assertEqual(result.episodes_total, 8)
indices = []
for env in ev.async_env.vector_env.envs:
self.assertEqual(env.unwrapped.config.worker_index, 0)
indices.append(env.unwrapped.config.vector_index)
self.assertEqual(indices, [0, 1, 2, 3, 4, 5, 6, 7])
def testBatchDivisibilityCheck(self):
self.assertRaises(