mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 17:34:51 +08:00
[rllib] Add async remote workers (#4253)
This commit is contained in:
@@ -134,6 +134,9 @@ COMMON_CONFIG = {
|
||||
# remote processes instead of in the same worker. This adds overheads, but
|
||||
# can make sense if your envs are very CPU intensive (e.g., for StarCraft).
|
||||
"remote_worker_envs": False,
|
||||
# Similar to remote_worker_envs, but runs the envs asynchronously in the
|
||||
# background for greater efficiency. Conflicts with remote_worker_envs.
|
||||
"async_remote_worker_envs": False,
|
||||
|
||||
# === Offline Datasets ===
|
||||
# __sphinx_doc_input_begin__
|
||||
@@ -473,9 +476,7 @@ class Agent(Trainable):
|
||||
"tf_session_args": self.
|
||||
config["local_evaluator_tf_session_args"]
|
||||
}),
|
||||
extra_config or {}),
|
||||
remote_worker_envs=False,
|
||||
)
|
||||
extra_config or {}))
|
||||
|
||||
@DeveloperAPI
|
||||
def make_remote_evaluators(self, env_creator, policy_graph, count):
|
||||
@@ -490,14 +491,8 @@ class Agent(Trainable):
|
||||
cls = PolicyEvaluator.as_remote(**remote_args).remote
|
||||
|
||||
return [
|
||||
self._make_evaluator(
|
||||
cls,
|
||||
env_creator,
|
||||
policy_graph,
|
||||
i + 1,
|
||||
self.config,
|
||||
remote_worker_envs=self.config["remote_worker_envs"])
|
||||
for i in range(count)
|
||||
self._make_evaluator(cls, env_creator, policy_graph, i + 1,
|
||||
self.config) for i in range(count)
|
||||
]
|
||||
|
||||
@DeveloperAPI
|
||||
@@ -563,13 +558,8 @@ class Agent(Trainable):
|
||||
"`input_evaluation` must be a list of strings, got {}".format(
|
||||
config["input_evaluation"]))
|
||||
|
||||
def _make_evaluator(self,
|
||||
cls,
|
||||
env_creator,
|
||||
policy_graph,
|
||||
worker_index,
|
||||
config,
|
||||
remote_worker_envs=False):
|
||||
def _make_evaluator(self, cls, env_creator, policy_graph, worker_index,
|
||||
config):
|
||||
def session_creator():
|
||||
logger.debug("Creating TF session {}".format(
|
||||
config["tf_session_args"]))
|
||||
@@ -639,7 +629,8 @@ class Agent(Trainable):
|
||||
input_creator=input_creator,
|
||||
input_evaluation=input_evaluation,
|
||||
output_creator=output_creator,
|
||||
remote_worker_envs=remote_worker_envs)
|
||||
remote_worker_envs=config["remote_worker_envs"],
|
||||
async_remote_worker_envs=config["async_remote_worker_envs"])
|
||||
|
||||
@override(Trainable)
|
||||
def _export_model(self, export_formats, export_dir):
|
||||
|
||||
Vendored
+49
-20
@@ -38,14 +38,22 @@ class BaseEnv(object):
|
||||
"env_0": {
|
||||
"car_0": [2.4, 1.6],
|
||||
"car_1": [3.4, -3.2],
|
||||
}
|
||||
},
|
||||
"env_1": {
|
||||
"car_0": [8.0, 4.1],
|
||||
},
|
||||
"env_2": {
|
||||
"car_0": [2.3, 3.3],
|
||||
"car_1": [1.4, -0.2],
|
||||
"car_3": [1.2, 0.1],
|
||||
},
|
||||
}
|
||||
>>> env.send_actions(
|
||||
actions={
|
||||
"env_0": {
|
||||
"car_0": 0,
|
||||
"car_1": 1,
|
||||
}
|
||||
}, ...
|
||||
})
|
||||
>>> obs, rewards, dones, infos, off_policy_actions = env.poll()
|
||||
>>> print(obs)
|
||||
@@ -53,7 +61,7 @@ class BaseEnv(object):
|
||||
"env_0": {
|
||||
"car_0": [4.1, 1.7],
|
||||
"car_1": [3.2, -4.2],
|
||||
}
|
||||
}, ...
|
||||
}
|
||||
>>> print(dones)
|
||||
{
|
||||
@@ -61,25 +69,40 @@ class BaseEnv(object):
|
||||
"__all__": False,
|
||||
"car_0": False,
|
||||
"car_1": True,
|
||||
}
|
||||
}, ...
|
||||
}
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def to_base_env(env, make_env=None, num_envs=1, remote_envs=False):
|
||||
def to_base_env(env,
|
||||
make_env=None,
|
||||
num_envs=1,
|
||||
remote_envs=False,
|
||||
async_remote_envs=False):
|
||||
"""Wraps any env type as needed to expose the async interface."""
|
||||
if remote_envs and num_envs == 1:
|
||||
|
||||
from ray.rllib.env.remote_vector_env import RemoteVectorEnv
|
||||
if (remote_envs or async_remote_envs) and num_envs == 1:
|
||||
raise ValueError(
|
||||
"Remote envs only make sense to use if num_envs > 1 "
|
||||
"(i.e. vectorization is enabled).")
|
||||
if remote_envs and async_remote_envs:
|
||||
raise ValueError("You can only specify one of remote_envs or "
|
||||
"async_remote_envs.")
|
||||
|
||||
if not isinstance(env, BaseEnv):
|
||||
if isinstance(env, MultiAgentEnv):
|
||||
if remote_envs:
|
||||
raise NotImplementedError(
|
||||
"Remote multiagent environments are not implemented")
|
||||
|
||||
env = _MultiAgentEnvToBaseEnv(
|
||||
make_env=make_env, existing_envs=[env], num_envs=num_envs)
|
||||
env = RemoteVectorEnv(
|
||||
make_env, num_envs, multiagent=True, sync=True)
|
||||
elif async_remote_envs:
|
||||
env = RemoteVectorEnv(
|
||||
make_env, num_envs, multiagent=True, sync=False)
|
||||
else:
|
||||
env = _MultiAgentEnvToBaseEnv(
|
||||
make_env=make_env,
|
||||
existing_envs=[env],
|
||||
num_envs=num_envs)
|
||||
elif isinstance(env, ExternalEnv):
|
||||
if num_envs != 1:
|
||||
raise ValueError(
|
||||
@@ -88,15 +111,21 @@ class BaseEnv(object):
|
||||
elif isinstance(env, VectorEnv):
|
||||
env = _VectorEnvToBaseEnv(env)
|
||||
else:
|
||||
env = VectorEnv.wrap(
|
||||
make_env=make_env,
|
||||
existing_envs=[env],
|
||||
num_envs=num_envs,
|
||||
remote_envs=remote_envs,
|
||||
action_space=env.action_space,
|
||||
observation_space=env.observation_space)
|
||||
env = _VectorEnvToBaseEnv(env)
|
||||
assert isinstance(env, BaseEnv)
|
||||
if remote_envs:
|
||||
env = RemoteVectorEnv(
|
||||
make_env, num_envs, multiagent=False, sync=True)
|
||||
elif async_remote_envs:
|
||||
env = RemoteVectorEnv(
|
||||
make_env, num_envs, multiagent=False, sync=False)
|
||||
else:
|
||||
env = VectorEnv.wrap(
|
||||
make_env=make_env,
|
||||
existing_envs=[env],
|
||||
num_envs=num_envs,
|
||||
action_space=env.action_space,
|
||||
observation_space=env.observation_space)
|
||||
env = _VectorEnvToBaseEnv(env)
|
||||
assert isinstance(env, BaseEnv), env
|
||||
return env
|
||||
|
||||
@PublicAPI
|
||||
|
||||
+118
@@ -0,0 +1,118 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
|
||||
import ray
|
||||
from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RemoteVectorEnv(BaseEnv):
|
||||
"""Vector env that executes envs in remote workers.
|
||||
|
||||
This provides dynamic batching of inference as observations are returned
|
||||
from the remote simulator actors. Both single and multi-agent child envs
|
||||
are supported, and envs can be stepped synchronously or async.
|
||||
"""
|
||||
|
||||
def __init__(self, make_env, num_envs, multiagent, sync):
|
||||
self.make_local_env = make_env
|
||||
if sync:
|
||||
self.timeout = 9999999.0 # wait for all envs
|
||||
else:
|
||||
self.timeout = 0.0 # wait for only ready envs
|
||||
|
||||
def make_remote_env(i):
|
||||
logger.info("Launching env {} in remote actor".format(i))
|
||||
if multiagent:
|
||||
return _RemoteMultiAgentEnv.remote(self.make_local_env, i)
|
||||
else:
|
||||
return _RemoteSingleAgentEnv.remote(self.make_local_env, i)
|
||||
|
||||
self.actors = [make_remote_env(i) for i in range(num_envs)]
|
||||
self.pending = None # lazy init
|
||||
|
||||
def poll(self):
|
||||
if self.pending is None:
|
||||
self.pending = {a.reset.remote(): a for a in self.actors}
|
||||
|
||||
# each keyed by env_id in [0, num_remote_envs)
|
||||
obs, rewards, dones, infos = {}, {}, {}, {}
|
||||
ready = []
|
||||
|
||||
# Wait for at least 1 env to be ready here
|
||||
while not ready:
|
||||
ready, _ = ray.wait(
|
||||
list(self.pending),
|
||||
num_returns=len(self.pending),
|
||||
timeout=self.timeout)
|
||||
|
||||
# Get and return observations for each of the ready envs
|
||||
env_ids = set()
|
||||
for obj_id in ready:
|
||||
actor = self.pending.pop(obj_id)
|
||||
env_id = self.actors.index(actor)
|
||||
env_ids.add(env_id)
|
||||
ob, rew, done, info = ray.get(obj_id)
|
||||
obs[env_id] = ob
|
||||
rewards[env_id] = rew
|
||||
dones[env_id] = done
|
||||
infos[env_id] = info
|
||||
|
||||
logger.debug("Got obs batch for actors {}".format(env_ids))
|
||||
return obs, rewards, dones, infos, {}
|
||||
|
||||
def send_actions(self, action_dict):
|
||||
for env_id, actions in action_dict.items():
|
||||
actor = self.actors[env_id]
|
||||
obj_id = actor.step.remote(actions)
|
||||
self.pending[obj_id] = actor
|
||||
|
||||
def try_reset(self, env_id):
|
||||
obs, _, _, _ = ray.get(self.actors[env_id].reset.remote())
|
||||
return obs
|
||||
|
||||
|
||||
@ray.remote(num_cpus=0)
|
||||
class _RemoteMultiAgentEnv(object):
|
||||
"""Wrapper class for making a multi-agent env a remote actor."""
|
||||
|
||||
def __init__(self, make_env, i):
|
||||
self.env = make_env(i)
|
||||
|
||||
def reset(self):
|
||||
obs = self.env.reset()
|
||||
# each keyed by agent_id in the env
|
||||
rew = {agent_id: 0 for agent_id in obs.keys()}
|
||||
info = {agent_id: {} for agent_id in obs.keys()}
|
||||
done = {"__all__": False}
|
||||
return obs, rew, done, info
|
||||
|
||||
def step(self, action_dict):
|
||||
return self.env.step(action_dict)
|
||||
|
||||
|
||||
@ray.remote(num_cpus=0)
|
||||
class _RemoteSingleAgentEnv(object):
|
||||
"""Wrapper class for making a gym env a remote actor."""
|
||||
|
||||
def __init__(self, make_env, i):
|
||||
self.env = make_env(i)
|
||||
|
||||
def reset(self):
|
||||
obs = {_DUMMY_AGENT_ID: self.env.reset()}
|
||||
rew = {agent_id: 0 for agent_id in obs.keys()}
|
||||
info = {agent_id: {} for agent_id in obs.keys()}
|
||||
done = {"__all__": False}
|
||||
return obs, rew, done, info
|
||||
|
||||
def step(self, action):
|
||||
obs, rew, done, info = self.env.step(action[_DUMMY_AGENT_ID])
|
||||
obs, rew, done, info = [{
|
||||
_DUMMY_AGENT_ID: x
|
||||
} for x in [obs, rew, done, info]]
|
||||
done["__all__"] = done[_DUMMY_AGENT_ID]
|
||||
return obs, rew, done, info
|
||||
Vendored
-73
@@ -5,7 +5,6 @@ from __future__ import print_function
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray.rllib.utils.annotations import override, PublicAPI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -27,12 +26,8 @@ class VectorEnv(object):
|
||||
def wrap(make_env=None,
|
||||
existing_envs=None,
|
||||
num_envs=1,
|
||||
remote_envs=False,
|
||||
action_space=None,
|
||||
observation_space=None):
|
||||
if remote_envs:
|
||||
return _RemoteVectorizedGymEnv(make_env, num_envs, action_space,
|
||||
observation_space)
|
||||
return _VectorizedGymEnv(make_env, existing_envs or [], num_envs,
|
||||
action_space, observation_space)
|
||||
|
||||
@@ -129,71 +124,3 @@ class _VectorizedGymEnv(VectorEnv):
|
||||
@override(VectorEnv)
|
||||
def get_unwrapped(self):
|
||||
return self.envs
|
||||
|
||||
|
||||
@ray.remote(num_cpus=0)
|
||||
class _RemoteEnv(object):
|
||||
"""Wrapper class for making a gym env a remote actor."""
|
||||
|
||||
def __init__(self, make_env, i):
|
||||
self.env = make_env(i)
|
||||
|
||||
def reset(self):
|
||||
return self.env.reset()
|
||||
|
||||
def step(self, action):
|
||||
return self.env.step(action)
|
||||
|
||||
|
||||
class _RemoteVectorizedGymEnv(_VectorizedGymEnv):
|
||||
"""Internal wrapper for gym envs to implement VectorEnv as remote workers.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
make_env,
|
||||
num_envs,
|
||||
action_space=None,
|
||||
observation_space=None):
|
||||
self.make_local_env = make_env
|
||||
self.num_envs = num_envs
|
||||
self.initialized = False
|
||||
self.action_space = action_space
|
||||
self.observation_space = observation_space
|
||||
|
||||
def _initialize_if_needed(self):
|
||||
if self.initialized:
|
||||
return
|
||||
|
||||
self.initialized = True
|
||||
|
||||
def make_remote_env(i):
|
||||
logger.info("Launching env {} in remote actor".format(i))
|
||||
return _RemoteEnv.remote(self.make_local_env, i)
|
||||
|
||||
_VectorizedGymEnv.__init__(self, make_remote_env, [], self.num_envs,
|
||||
self.action_space, self.observation_space)
|
||||
|
||||
for env in self.envs:
|
||||
assert isinstance(env, ray.actor.ActorHandle), env
|
||||
|
||||
@override(_VectorizedGymEnv)
|
||||
def vector_reset(self):
|
||||
self._initialize_if_needed()
|
||||
return ray.get([env.reset.remote() for env in self.envs])
|
||||
|
||||
@override(_VectorizedGymEnv)
|
||||
def reset_at(self, index):
|
||||
return ray.get(self.envs[index].reset.remote())
|
||||
|
||||
@override(_VectorizedGymEnv)
|
||||
def vector_step(self, actions):
|
||||
step_outs = ray.get(
|
||||
[env.step.remote(act) for env, act in zip(self.envs, actions)])
|
||||
|
||||
obs_batch, rew_batch, done_batch, info_batch = [], [], [], []
|
||||
for obs, rew, done, info in step_outs:
|
||||
obs_batch.append(obs)
|
||||
rew_batch.append(rew)
|
||||
done_batch.append(done)
|
||||
info_batch.append(info)
|
||||
return obs_batch, rew_batch, done_batch, info_batch
|
||||
|
||||
@@ -122,7 +122,8 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
input_creator=lambda ioctx: ioctx.default_sampler_input(),
|
||||
input_evaluation=frozenset([]),
|
||||
output_creator=lambda ioctx: NoopOutput(),
|
||||
remote_worker_envs=False):
|
||||
remote_worker_envs=False,
|
||||
async_remote_worker_envs=False):
|
||||
"""Initialize a policy evaluator.
|
||||
|
||||
Arguments:
|
||||
@@ -201,6 +202,8 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
those new envs in remote processes instead of in the current
|
||||
process. This adds overheads, but can make sense if your envs
|
||||
are very CPU intensive (e.g., for StarCraft).
|
||||
async_remote_worker_envs (bool): Similar to remote_worker_envs,
|
||||
but runs the envs asynchronously in the background.
|
||||
"""
|
||||
|
||||
if log_level:
|
||||
@@ -307,7 +310,8 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
self.env,
|
||||
make_env=make_env,
|
||||
num_envs=num_envs,
|
||||
remote_envs=remote_worker_envs)
|
||||
remote_envs=remote_worker_envs,
|
||||
async_remote_envs=async_remote_worker_envs)
|
||||
self.num_envs = num_envs
|
||||
|
||||
if self.batch_mode == "truncate_episodes":
|
||||
|
||||
@@ -334,6 +334,38 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
self.assertEqual(batch.policy_batches["p0"]["t"].tolist(),
|
||||
list(range(25)) * 6)
|
||||
|
||||
def testMultiAgentSampleSyncRemote(self):
|
||||
act_space = gym.spaces.Discrete(2)
|
||||
obs_space = gym.spaces.Discrete(2)
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: BasicMultiAgent(5),
|
||||
policy_graph={
|
||||
"p0": (MockPolicyGraph, obs_space, act_space, {}),
|
||||
"p1": (MockPolicyGraph, obs_space, act_space, {}),
|
||||
},
|
||||
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
|
||||
batch_steps=50,
|
||||
num_envs=4,
|
||||
remote_worker_envs=True)
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 200)
|
||||
|
||||
def testMultiAgentSampleAsyncRemote(self):
|
||||
act_space = gym.spaces.Discrete(2)
|
||||
obs_space = gym.spaces.Discrete(2)
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: BasicMultiAgent(5),
|
||||
policy_graph={
|
||||
"p0": (MockPolicyGraph, obs_space, act_space, {}),
|
||||
"p1": (MockPolicyGraph, obs_space, act_space, {}),
|
||||
},
|
||||
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
|
||||
batch_steps=50,
|
||||
num_envs=4,
|
||||
async_remote_worker_envs=True)
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 200)
|
||||
|
||||
def testMultiAgentSampleWithHorizon(self):
|
||||
act_space = gym.spaces.Discrete(2)
|
||||
obs_space = gym.spaces.Discrete(2)
|
||||
@@ -621,5 +653,5 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ray.init()
|
||||
ray.init(num_cpus=4)
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user