mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 06:19:29 +08:00
[rllib] Document creating an ensemble of envs; also add vector_index attribute to env config (#2513)
This also removes the async resetting code in VectorEnv. While async resetting improves benchmark performance slightly, it substantially complicates env configuration and probably isn't worth it for most envs. This change makes it easy to efficiently support setups like Joint PPO (https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/retro-contest/gotta_learn_fast_report.pdf). For example, for 188 envs, you could do something like num_envs: 10, num_envs_per_worker: 19.
This commit is contained in:
+1
-1
@@ -251,7 +251,7 @@ class _MultiAgentEnvToAsync(AsyncVectorEnv):
|
||||
self.num_envs = num_envs
|
||||
self.dones = set()
|
||||
while len(self.envs) < self.num_envs:
|
||||
self.envs.append(self.make_env())
|
||||
self.envs.append(self.make_env(len(self.envs)))
|
||||
for env in self.envs:
|
||||
assert isinstance(env, MultiAgentEnv)
|
||||
self.env_states = [_MultiAgentEnvState(env) for env in self.envs]
|
||||
|
||||
Vendored
+8
-1
@@ -15,8 +15,15 @@ class EnvContext(dict):
|
||||
Attributes:
|
||||
worker_index (int): When there are multiple workers created, this
|
||||
uniquely identifies the worker the env is created in.
|
||||
vector_index (int): When there are multiple envs per worker, this
|
||||
uniquely identifies the env index within the worker.
|
||||
"""
|
||||
|
||||
def __init__(self, env_config, worker_index):
|
||||
def __init__(self, env_config, worker_index, vector_index=0):
|
||||
dict.__init__(self, env_config)
|
||||
self.worker_index = worker_index
|
||||
self.vector_index = vector_index
|
||||
|
||||
def with_vector_index(self, vector_index):
|
||||
return EnvContext(
|
||||
self, worker_index=self.worker_index, vector_index=vector_index)
|
||||
|
||||
Vendored
+2
-50
@@ -2,9 +2,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import queue
|
||||
import threading
|
||||
|
||||
|
||||
class VectorEnv(object):
|
||||
"""An environment that supports batch evaluation.
|
||||
@@ -70,20 +67,14 @@ class _VectorizedGymEnv(VectorEnv):
|
||||
self.make_env = make_env
|
||||
self.envs = existing_envs
|
||||
self.num_envs = num_envs
|
||||
if make_env and num_envs > 1:
|
||||
self.resetter = _AsyncResetter(make_env, int(self.num_envs**0.5))
|
||||
else:
|
||||
self.resetter = _SimpleResetter(make_env)
|
||||
while len(self.envs) < self.num_envs:
|
||||
self.envs.append(self.make_env())
|
||||
self.envs.append(self.make_env(len(self.envs)))
|
||||
|
||||
def vector_reset(self):
|
||||
return [e.reset() for e in self.envs]
|
||||
|
||||
def reset_at(self, index):
|
||||
new_obs, new_env = self.resetter.trade_for_resetted(self.envs[index])
|
||||
self.envs[index] = new_env
|
||||
return new_obs
|
||||
return self.envs[index].reset()
|
||||
|
||||
def vector_step(self, actions):
|
||||
obs_batch, rew_batch, done_batch, info_batch = [], [], [], []
|
||||
@@ -97,42 +88,3 @@ class _VectorizedGymEnv(VectorEnv):
|
||||
|
||||
def get_unwrapped(self):
|
||||
return self.envs[0]
|
||||
|
||||
|
||||
class _AsyncResetter(threading.Thread):
|
||||
"""Does env reset asynchronously in the background.
|
||||
|
||||
This is useful since resetting an env can be 100x slower than stepping."""
|
||||
|
||||
def __init__(self, make_env, pool_size):
|
||||
threading.Thread.__init__(self)
|
||||
self.make_env = make_env
|
||||
self.pool_size = 0
|
||||
self.to_reset = queue.Queue()
|
||||
self.resetted = queue.Queue()
|
||||
self.daemon = True
|
||||
self.pool_size = pool_size
|
||||
while self.resetted.qsize() < self.pool_size:
|
||||
env = self.make_env()
|
||||
obs = env.reset()
|
||||
self.resetted.put((obs, env))
|
||||
self.start()
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
env = self.to_reset.get()
|
||||
obs = env.reset()
|
||||
self.resetted.put((obs, env))
|
||||
|
||||
def trade_for_resetted(self, env):
|
||||
self.to_reset.put(env)
|
||||
new_obs, new_env = self.resetted.get(timeout=30)
|
||||
return new_obs, new_env
|
||||
|
||||
|
||||
class _SimpleResetter(object):
|
||||
def __init__(self, make_env):
|
||||
pass
|
||||
|
||||
def trade_for_resetted(self, env):
|
||||
return env.reset(), env
|
||||
|
||||
@@ -190,8 +190,9 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
|
||||
self.env = wrap(self.env)
|
||||
|
||||
def make_env():
|
||||
return wrap(env_creator(env_context))
|
||||
def make_env(vector_index):
|
||||
return wrap(
|
||||
env_creator(env_context.with_vector_index(vector_index)))
|
||||
|
||||
self.tf_sess = None
|
||||
policy_dict = _validate_and_canonicalize(policy_graph, self.env)
|
||||
|
||||
@@ -160,7 +160,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
self.assertEqual(done["__all__"], True)
|
||||
|
||||
def testVectorizeBasic(self):
|
||||
env = _MultiAgentEnvToAsync(lambda: BasicMultiAgent(2), [], 2)
|
||||
env = _MultiAgentEnvToAsync(lambda v: BasicMultiAgent(2), [], 2)
|
||||
obs, rew, dones, _, _ = env.poll()
|
||||
self.assertEqual(obs, {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
|
||||
self.assertEqual(rew, {0: {0: None, 1: None}, 1: {0: None, 1: None}})
|
||||
@@ -236,7 +236,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
})
|
||||
|
||||
def testVectorizeRoundRobin(self):
|
||||
env = _MultiAgentEnvToAsync(lambda: RoundRobinMultiAgent(2), [], 2)
|
||||
env = _MultiAgentEnvToAsync(lambda v: RoundRobinMultiAgent(2), [], 2)
|
||||
obs, rew, dones, _, _ = env.poll()
|
||||
self.assertEqual(obs, {0: {0: 0}, 1: {0: 0}})
|
||||
self.assertEqual(rew, {0: {0: None}, 1: {0: None}})
|
||||
|
||||
@@ -33,8 +33,9 @@ class BadPolicyGraph(PolicyGraph):
|
||||
|
||||
|
||||
class MockEnv(gym.Env):
|
||||
def __init__(self, episode_length):
|
||||
def __init__(self, episode_length, config=None):
|
||||
self.episode_length = episode_length
|
||||
self.config = config
|
||||
self.i = 0
|
||||
self.observation_space = gym.spaces.Discrete(1)
|
||||
self.action_space = gym.spaces.Discrete(2)
|
||||
@@ -150,7 +151,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
|
||||
def testAutoVectorization(self):
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: MockEnv(episode_length=20),
|
||||
env_creator=lambda cfg: MockEnv(episode_length=20, config=cfg),
|
||||
policy_graph=MockPolicyGraph,
|
||||
batch_mode="truncate_episodes",
|
||||
batch_steps=16,
|
||||
@@ -165,6 +166,11 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
self.assertEqual(batch.count, 16)
|
||||
result = collect_metrics(ev, [])
|
||||
self.assertEqual(result.episodes_total, 8)
|
||||
indices = []
|
||||
for env in ev.async_env.vector_env.envs:
|
||||
self.assertEqual(env.unwrapped.config.worker_index, 0)
|
||||
indices.append(env.unwrapped.config.vector_index)
|
||||
self.assertEqual(indices, [0, 1, 2, 3, 4, 5, 6, 7])
|
||||
|
||||
def testBatchDivisibilityCheck(self):
|
||||
self.assertRaises(
|
||||
|
||||
Reference in New Issue
Block a user