[rllib] Add async remote workers (#4253)

This commit is contained in:
Eric Liang
2019-03-08 15:39:48 -08:00
committed by GitHub
parent fd2d8c2c06
commit c7f74dbdc7
9 changed files with 230 additions and 124 deletions
+10 -19
View File
@@ -134,6 +134,9 @@ COMMON_CONFIG = {
# remote processes instead of in the same worker. This adds overheads, but
# can make sense if your envs are very CPU intensive (e.g., for StarCraft).
"remote_worker_envs": False,
# Similar to remote_worker_envs, but runs the envs asynchronously in the
# background for greater efficiency. Conflicts with remote_worker_envs.
"async_remote_worker_envs": False,
# === Offline Datasets ===
# __sphinx_doc_input_begin__
@@ -473,9 +476,7 @@ class Agent(Trainable):
"tf_session_args": self.
config["local_evaluator_tf_session_args"]
}),
extra_config or {}),
remote_worker_envs=False,
)
extra_config or {}))
@DeveloperAPI
def make_remote_evaluators(self, env_creator, policy_graph, count):
@@ -490,14 +491,8 @@ class Agent(Trainable):
cls = PolicyEvaluator.as_remote(**remote_args).remote
return [
self._make_evaluator(
cls,
env_creator,
policy_graph,
i + 1,
self.config,
remote_worker_envs=self.config["remote_worker_envs"])
for i in range(count)
self._make_evaluator(cls, env_creator, policy_graph, i + 1,
self.config) for i in range(count)
]
@DeveloperAPI
@@ -563,13 +558,8 @@ class Agent(Trainable):
"`input_evaluation` must be a list of strings, got {}".format(
config["input_evaluation"]))
def _make_evaluator(self,
cls,
env_creator,
policy_graph,
worker_index,
config,
remote_worker_envs=False):
def _make_evaluator(self, cls, env_creator, policy_graph, worker_index,
config):
def session_creator():
logger.debug("Creating TF session {}".format(
config["tf_session_args"]))
@@ -639,7 +629,8 @@ class Agent(Trainable):
input_creator=input_creator,
input_evaluation=input_evaluation,
output_creator=output_creator,
remote_worker_envs=remote_worker_envs)
remote_worker_envs=config["remote_worker_envs"],
async_remote_worker_envs=config["async_remote_worker_envs"])
@override(Trainable)
def _export_model(self, export_formats, export_dir):
+49 -20
View File
@@ -38,14 +38,22 @@ class BaseEnv(object):
"env_0": {
"car_0": [2.4, 1.6],
"car_1": [3.4, -3.2],
}
},
"env_1": {
"car_0": [8.0, 4.1],
},
"env_2": {
"car_0": [2.3, 3.3],
"car_1": [1.4, -0.2],
"car_3": [1.2, 0.1],
},
}
>>> env.send_actions(
actions={
"env_0": {
"car_0": 0,
"car_1": 1,
}
}, ...
})
>>> obs, rewards, dones, infos, off_policy_actions = env.poll()
>>> print(obs)
@@ -53,7 +61,7 @@ class BaseEnv(object):
"env_0": {
"car_0": [4.1, 1.7],
"car_1": [3.2, -4.2],
}
}, ...
}
>>> print(dones)
{
@@ -61,25 +69,40 @@ class BaseEnv(object):
"__all__": False,
"car_0": False,
"car_1": True,
}
}, ...
}
"""
@staticmethod
def to_base_env(env, make_env=None, num_envs=1, remote_envs=False):
def to_base_env(env,
make_env=None,
num_envs=1,
remote_envs=False,
async_remote_envs=False):
"""Wraps any env type as needed to expose the async interface."""
if remote_envs and num_envs == 1:
from ray.rllib.env.remote_vector_env import RemoteVectorEnv
if (remote_envs or async_remote_envs) and num_envs == 1:
raise ValueError(
"Remote envs only make sense to use if num_envs > 1 "
"(i.e. vectorization is enabled).")
if remote_envs and async_remote_envs:
raise ValueError("You can only specify one of remote_envs or "
"async_remote_envs.")
if not isinstance(env, BaseEnv):
if isinstance(env, MultiAgentEnv):
if remote_envs:
raise NotImplementedError(
"Remote multiagent environments are not implemented")
env = _MultiAgentEnvToBaseEnv(
make_env=make_env, existing_envs=[env], num_envs=num_envs)
env = RemoteVectorEnv(
make_env, num_envs, multiagent=True, sync=True)
elif async_remote_envs:
env = RemoteVectorEnv(
make_env, num_envs, multiagent=True, sync=False)
else:
env = _MultiAgentEnvToBaseEnv(
make_env=make_env,
existing_envs=[env],
num_envs=num_envs)
elif isinstance(env, ExternalEnv):
if num_envs != 1:
raise ValueError(
@@ -88,15 +111,21 @@ class BaseEnv(object):
elif isinstance(env, VectorEnv):
env = _VectorEnvToBaseEnv(env)
else:
env = VectorEnv.wrap(
make_env=make_env,
existing_envs=[env],
num_envs=num_envs,
remote_envs=remote_envs,
action_space=env.action_space,
observation_space=env.observation_space)
env = _VectorEnvToBaseEnv(env)
assert isinstance(env, BaseEnv)
if remote_envs:
env = RemoteVectorEnv(
make_env, num_envs, multiagent=False, sync=True)
elif async_remote_envs:
env = RemoteVectorEnv(
make_env, num_envs, multiagent=False, sync=False)
else:
env = VectorEnv.wrap(
make_env=make_env,
existing_envs=[env],
num_envs=num_envs,
action_space=env.action_space,
observation_space=env.observation_space)
env = _VectorEnvToBaseEnv(env)
assert isinstance(env, BaseEnv), env
return env
@PublicAPI
+118
View File
@@ -0,0 +1,118 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import ray
from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID
logger = logging.getLogger(__name__)
class RemoteVectorEnv(BaseEnv):
    """BaseEnv whose child environments each run inside a remote actor.

    Results are collected as the remote simulator actors return them, which
    gives dynamic batching of inference. Child envs may be single-agent or
    multi-agent, and they may be stepped synchronously or asynchronously.
    """

    def __init__(self, make_env, num_envs, multiagent, sync):
        """Launch one remote actor per child environment.

        Arguments:
            make_env (func): Callable that receives the env index and
                returns a new env; it is passed to each remote actor.
            num_envs (int): Number of remote env actors to launch.
            multiagent (bool): If True, actors wrap multi-agent envs,
                otherwise they wrap single-agent envs.
            sync (bool): If True, poll() waits for all pending envs;
                otherwise it only collects the envs that are ready.
        """
        self.make_local_env = make_env
        # sync -> wait for all envs; async -> wait for only ready envs
        self.timeout = 9999999.0 if sync else 0.0

        if multiagent:
            actor_cls = _RemoteMultiAgentEnv
        else:
            actor_cls = _RemoteSingleAgentEnv

        launched = []
        for idx in range(num_envs):
            logger.info("Launching env {} in remote actor".format(idx))
            launched.append(actor_cls.remote(self.make_local_env, idx))
        self.actors = launched

        # Maps outstanding object id -> actor that will produce it.
        # Populated lazily on the first poll().
        self.pending = None

    def poll(self):
        """Collect results from the remote envs that have finished.

        On the first call a reset is issued to every actor. The call keeps
        waiting until at least one pending result is available.

        Returns:
            Tuple (obs, rewards, dones, infos, off_policy_actions). The
            first four are dicts keyed by env index (position of the actor
            in self.actors); the last is always an empty dict.
        """
        if self.pending is None:
            self.pending = {
                env_actor.reset.remote(): env_actor
                for env_actor in self.actors
            }

        # Keep asking until at least one env has a result for us.
        finished = []
        while not finished:
            finished, _ = ray.wait(
                list(self.pending),
                num_returns=len(self.pending),
                timeout=self.timeout)

        # Each result dict is keyed by env_id in [0, num_remote_envs).
        obs, rewards, dones, infos = {}, {}, {}, {}
        ready_ids = set()
        for object_id in finished:
            env_actor = self.pending.pop(object_id)
            idx = self.actors.index(env_actor)
            ready_ids.add(idx)
            obs[idx], rewards[idx], dones[idx], infos[idx] = ray.get(
                object_id)

        logger.debug("Got obs batch for actors {}".format(ready_ids))
        return obs, rewards, dones, infos, {}

    def send_actions(self, action_dict):
        """Dispatch a step to each env that has an entry in action_dict.

        Arguments:
            action_dict (dict): Maps env index to the actions for that env.
                The resulting step is tracked as pending until poll()
                collects it.
        """
        for env_id, env_actions in action_dict.items():
            env_actor = self.actors[env_id]
            step_id = env_actor.step.remote(env_actions)
            self.pending[step_id] = env_actor

    def try_reset(self, env_id):
        """Reset the env at env_id, block for the result, return its obs."""
        reset_id = self.actors[env_id].reset.remote()
        obs, _rew, _done, _info = ray.get(reset_id)
        return obs
@ray.remote(num_cpus=0)
class _RemoteMultiAgentEnv(object):
    """Remote-actor wrapper around a multi-agent env."""

    def __init__(self, make_env, i):
        """Build the wrapped env by calling make_env with the env index."""
        self.env = make_env(i)

    def reset(self):
        """Reset the env and return a step-shaped (obs, rew, done, info).

        Rewards are zero and infos are empty dicts for every agent id that
        appears in the reset observation; done only carries "__all__".
        """
        obs = self.env.reset()
        # Both dicts are keyed by the agent ids present in the observation.
        rew = dict.fromkeys(obs.keys(), 0)
        info = {agent_id: {} for agent_id in obs.keys()}
        done = {"__all__": False}
        return obs, rew, done, info

    def step(self, action_dict):
        """Forward the action dict to the wrapped env's step()."""
        return self.env.step(action_dict)
@ray.remote(num_cpus=0)
class _RemoteSingleAgentEnv(object):
    """Remote-actor wrapper that exposes a gym env in multi-agent form.

    Every returned dict is keyed by _DUMMY_AGENT_ID so that callers can
    treat the single agent like one agent of a multi-agent env.
    """

    def __init__(self, make_env, i):
        """Build the wrapped env by calling make_env with the env index."""
        self.env = make_env(i)

    def reset(self):
        """Reset the env and return a step-shaped (obs, rew, done, info)."""
        first_obs = self.env.reset()
        obs = {_DUMMY_AGENT_ID: first_obs}
        rew = {_DUMMY_AGENT_ID: 0}
        info = {_DUMMY_AGENT_ID: {}}
        done = {"__all__": False}
        return obs, rew, done, info

    def step(self, action):
        """Step the env with the dummy agent's action.

        Arguments:
            action (dict): Action dict; only the _DUMMY_AGENT_ID entry is
                passed to the wrapped env.

        Returns:
            Tuple (obs, rew, done, info) of dicts keyed by _DUMMY_AGENT_ID;
            done additionally carries "__all__" with the same value.
        """
        raw_obs, raw_rew, raw_done, raw_info = self.env.step(
            action[_DUMMY_AGENT_ID])
        obs = {_DUMMY_AGENT_ID: raw_obs}
        rew = {_DUMMY_AGENT_ID: raw_rew}
        done = {_DUMMY_AGENT_ID: raw_done}
        info = {_DUMMY_AGENT_ID: raw_info}
        # With a single agent, the episode is over when that agent is done.
        done["__all__"] = done[_DUMMY_AGENT_ID]
        return obs, rew, done, info
-73
View File
@@ -5,7 +5,6 @@ from __future__ import print_function
import logging
import numpy as np
import ray
from ray.rllib.utils.annotations import override, PublicAPI
logger = logging.getLogger(__name__)
@@ -27,12 +26,8 @@ class VectorEnv(object):
def wrap(make_env=None,
existing_envs=None,
num_envs=1,
remote_envs=False,
action_space=None,
observation_space=None):
if remote_envs:
return _RemoteVectorizedGymEnv(make_env, num_envs, action_space,
observation_space)
return _VectorizedGymEnv(make_env, existing_envs or [], num_envs,
action_space, observation_space)
@@ -129,71 +124,3 @@ class _VectorizedGymEnv(VectorEnv):
@override(VectorEnv)
def get_unwrapped(self):
return self.envs
@ray.remote(num_cpus=0)
class _RemoteEnv(object):
    """Thin remote-actor wrapper around a gym env."""

    def __init__(self, make_env, i):
        """Build the wrapped env by calling make_env with the env index."""
        wrapped = make_env(i)
        self.env = wrapped

    def reset(self):
        """Reset the wrapped env and return whatever it returns."""
        first_obs = self.env.reset()
        return first_obs

    def step(self, action):
        """Step the wrapped env and return its result unchanged."""
        step_out = self.env.step(action)
        return step_out
class _RemoteVectorizedGymEnv(_VectorizedGymEnv):
    """Internal VectorEnv implementation backed by remote gym env actors.

    Actor creation is deferred until the first vector_reset() call.
    """

    def __init__(self,
                 make_env,
                 num_envs,
                 action_space=None,
                 observation_space=None):
        """Record the configuration; no remote actors are created yet."""
        self.initialized = False
        self.make_local_env = make_env
        self.num_envs = num_envs
        self.action_space = action_space
        self.observation_space = observation_space

    def _initialize_if_needed(self):
        """Create the remote env actors the first time this is called."""
        if not self.initialized:
            self.initialized = True

            def make_remote_env(i):
                logger.info("Launching env {} in remote actor".format(i))
                return _RemoteEnv.remote(self.make_local_env, i)

            # The parent initializer is given the remote factory and no
            # pre-existing envs.
            _VectorizedGymEnv.__init__(self, make_remote_env, [],
                                       self.num_envs, self.action_space,
                                       self.observation_space)

            for remote_env in self.envs:
                assert isinstance(remote_env,
                                  ray.actor.ActorHandle), remote_env

    @override(_VectorizedGymEnv)
    def vector_reset(self):
        """Reset every remote env and return the list of results."""
        self._initialize_if_needed()
        reset_ids = [remote_env.reset.remote() for remote_env in self.envs]
        return ray.get(reset_ids)

    @override(_VectorizedGymEnv)
    def reset_at(self, index):
        """Reset only the remote env at the given index."""
        reset_id = self.envs[index].reset.remote()
        return ray.get(reset_id)

    @override(_VectorizedGymEnv)
    def vector_step(self, actions):
        """Step each remote env with its action and regroup the results.

        Returns:
            Tuple of four lists (obs, rewards, dones, infos), each holding
            one entry per env/action pair, in env order.
        """
        step_ids = [
            remote_env.step.remote(act)
            for remote_env, act in zip(self.envs, actions)
        ]
        step_outs = ray.get(step_ids)

        # Transpose the per-env 4-tuples into four per-field lists.
        obs_batch, rew_batch, done_batch, info_batch = [], [], [], []
        batches = (obs_batch, rew_batch, done_batch, info_batch)
        for step_out in step_outs:
            obs, rew, done, info = step_out
            for batch, value in zip(batches, (obs, rew, done, info)):
                batch.append(value)
        return obs_batch, rew_batch, done_batch, info_batch
@@ -122,7 +122,8 @@ class PolicyEvaluator(EvaluatorInterface):
input_creator=lambda ioctx: ioctx.default_sampler_input(),
input_evaluation=frozenset([]),
output_creator=lambda ioctx: NoopOutput(),
remote_worker_envs=False):
remote_worker_envs=False,
async_remote_worker_envs=False):
"""Initialize a policy evaluator.
Arguments:
@@ -201,6 +202,8 @@ class PolicyEvaluator(EvaluatorInterface):
those new envs in remote processes instead of in the current
process. This adds overheads, but can make sense if your envs
are very CPU intensive (e.g., for StarCraft).
async_remote_worker_envs (bool): Similar to remote_worker_envs,
but runs the envs asynchronously in the background.
"""
if log_level:
@@ -307,7 +310,8 @@ class PolicyEvaluator(EvaluatorInterface):
self.env,
make_env=make_env,
num_envs=num_envs,
remote_envs=remote_worker_envs)
remote_envs=remote_worker_envs,
async_remote_envs=async_remote_worker_envs)
self.num_envs = num_envs
if self.batch_mode == "truncate_episodes":
+33 -1
View File
@@ -334,6 +334,38 @@ class TestMultiAgentEnv(unittest.TestCase):
self.assertEqual(batch.policy_batches["p0"]["t"].tolist(),
list(range(25)) * 6)
def testMultiAgentSampleSyncRemote(self):
    """Sampling with remote_worker_envs=True yields the expected count."""
    act_space = gym.spaces.Discrete(2)
    obs_space = gym.spaces.Discrete(2)
    # Two policies, p0 and p1, sharing the same spaces.
    policies = {
        "p{}".format(idx): (MockPolicyGraph, obs_space, act_space, {})
        for idx in range(2)
    }
    ev = PolicyEvaluator(
        env_creator=lambda _: BasicMultiAgent(5),
        policy_graph=policies,
        policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
        batch_steps=50,
        num_envs=4,
        remote_worker_envs=True)
    sampled = ev.sample()
    self.assertEqual(sampled.count, 200)
def testMultiAgentSampleAsyncRemote(self):
    """Sampling with async_remote_worker_envs=True yields the same count."""
    act_space = gym.spaces.Discrete(2)
    obs_space = gym.spaces.Discrete(2)
    # Two policies, p0 and p1, sharing the same spaces.
    policies = {
        "p{}".format(idx): (MockPolicyGraph, obs_space, act_space, {})
        for idx in range(2)
    }
    ev = PolicyEvaluator(
        env_creator=lambda _: BasicMultiAgent(5),
        policy_graph=policies,
        policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
        batch_steps=50,
        num_envs=4,
        async_remote_worker_envs=True)
    sampled = ev.sample()
    self.assertEqual(sampled.count, 200)
def testMultiAgentSampleWithHorizon(self):
act_space = gym.spaces.Discrete(2)
obs_space = gym.spaces.Discrete(2)
@@ -621,5 +653,5 @@ class TestMultiAgentEnv(unittest.TestCase):
if __name__ == "__main__":
ray.init()
ray.init(num_cpus=4)
unittest.main(verbosity=2)