mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 13:49:45 +08:00
[rllib] Make batch timeout for remote workers tunable (#4435)
This commit is contained in:
@@ -147,11 +147,14 @@ COMMON_CONFIG = {
|
||||
"metrics_smoothing_episodes": 100,
|
||||
# If using num_envs_per_worker > 1, whether to create those new envs in
|
||||
# remote processes instead of in the same worker. This adds overheads, but
|
||||
# can make sense if your envs are very CPU intensive (e.g., for StarCraft).
|
||||
# can make sense if your envs can take much time to step / reset
|
||||
# (e.g., for StarCraft)
|
||||
"remote_worker_envs": False,
|
||||
# Similar to remote_worker_envs, but runs the envs asynchronously in the
|
||||
# background for greater efficiency. Conflicts with remote_worker_envs.
|
||||
"async_remote_worker_envs": False,
|
||||
# Timeout that remote workers are waiting when polling environments.
|
||||
# 0 (continue when at least one env is ready) is a reasonable default,
|
||||
# but optimal value could be obtained by measuring your environment
|
||||
# step / reset and model inference perf.
|
||||
"remote_env_batch_wait_ms": 0,
|
||||
|
||||
# === Offline Datasets ===
|
||||
# Specify how to generate experiences:
|
||||
@@ -378,10 +381,18 @@ class Agent(Trainable):
|
||||
|
||||
@override(Trainable)
|
||||
def _stop(self):
|
||||
# Call stop on all evaluators to release resources
|
||||
if hasattr(self, "local_evaluator"):
|
||||
self.local_evaluator.stop()
|
||||
if hasattr(self, "remote_evaluators"):
|
||||
for ev in self.remote_evaluators:
|
||||
ev.stop.remote()
|
||||
|
||||
# workaround for https://github.com/ray-project/ray/issues/1516
|
||||
if hasattr(self, "remote_evaluators"):
|
||||
for ev in self.remote_evaluators:
|
||||
ev.__ray_terminate__.remote()
|
||||
|
||||
if hasattr(self, "optimizer"):
|
||||
self.optimizer.stop()
|
||||
|
||||
@@ -657,12 +668,12 @@ class Agent(Trainable):
|
||||
input_creator = (lambda ioctx: ioctx.default_sampler_input())
|
||||
elif isinstance(config["input"], dict):
|
||||
input_creator = (lambda ioctx: ShuffledInput(
|
||||
MixedInput(config["input"], ioctx),
|
||||
config["shuffle_buffer_size"]))
|
||||
MixedInput(config["input"], ioctx), config[
|
||||
"shuffle_buffer_size"]))
|
||||
else:
|
||||
input_creator = (lambda ioctx: ShuffledInput(
|
||||
JsonReader(config["input"], ioctx),
|
||||
config["shuffle_buffer_size"]))
|
||||
JsonReader(config["input"], ioctx), config[
|
||||
"shuffle_buffer_size"]))
|
||||
|
||||
if isinstance(config["output"], FunctionType):
|
||||
output_creator = config["output"]
|
||||
@@ -724,7 +735,7 @@ class Agent(Trainable):
|
||||
input_evaluation=input_evaluation,
|
||||
output_creator=output_creator,
|
||||
remote_worker_envs=config["remote_worker_envs"],
|
||||
async_remote_worker_envs=config["async_remote_worker_envs"])
|
||||
remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"])
|
||||
|
||||
@override(Trainable)
|
||||
def _export_model(self, export_formats, export_dir):
|
||||
|
||||
Vendored
+17
-13
@@ -7,6 +7,8 @@ from ray.rllib.env.vector_env import VectorEnv
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from ray.rllib.utils.annotations import override, PublicAPI
|
||||
|
||||
ASYNC_RESET_RETURN = "async_reset_return"
|
||||
|
||||
|
||||
@PublicAPI
|
||||
class BaseEnv(object):
|
||||
@@ -78,26 +80,23 @@ class BaseEnv(object):
|
||||
make_env=None,
|
||||
num_envs=1,
|
||||
remote_envs=False,
|
||||
async_remote_envs=False):
|
||||
remote_env_batch_wait_ms=0):
|
||||
"""Wraps any env type as needed to expose the async interface."""
|
||||
|
||||
from ray.rllib.env.remote_vector_env import RemoteVectorEnv
|
||||
if (remote_envs or async_remote_envs) and num_envs == 1:
|
||||
if remote_envs and num_envs == 1:
|
||||
raise ValueError(
|
||||
"Remote envs only make sense to use if num_envs > 1 "
|
||||
"(i.e. vectorization is enabled).")
|
||||
if remote_envs and async_remote_envs:
|
||||
raise ValueError("You can only specify one of remote_envs or "
|
||||
"async_remote_envs.")
|
||||
|
||||
if not isinstance(env, BaseEnv):
|
||||
if isinstance(env, MultiAgentEnv):
|
||||
if remote_envs:
|
||||
env = RemoteVectorEnv(
|
||||
make_env, num_envs, multiagent=True, sync=True)
|
||||
elif async_remote_envs:
|
||||
env = RemoteVectorEnv(
|
||||
make_env, num_envs, multiagent=True, sync=False)
|
||||
make_env,
|
||||
num_envs,
|
||||
multiagent=True,
|
||||
remote_env_batch_wait_ms=remote_env_batch_wait_ms)
|
||||
else:
|
||||
env = _MultiAgentEnvToBaseEnv(
|
||||
make_env=make_env,
|
||||
@@ -113,10 +112,10 @@ class BaseEnv(object):
|
||||
else:
|
||||
if remote_envs:
|
||||
env = RemoteVectorEnv(
|
||||
make_env, num_envs, multiagent=False, sync=True)
|
||||
elif async_remote_envs:
|
||||
env = RemoteVectorEnv(
|
||||
make_env, num_envs, multiagent=False, sync=False)
|
||||
make_env,
|
||||
num_envs,
|
||||
multiagent=False,
|
||||
remote_env_batch_wait_ms=remote_env_batch_wait_ms)
|
||||
else:
|
||||
env = VectorEnv.wrap(
|
||||
make_env=make_env,
|
||||
@@ -184,6 +183,11 @@ class BaseEnv(object):
|
||||
"""
|
||||
return []
|
||||
|
||||
@PublicAPI
|
||||
def stop(self):
|
||||
"""Releases all resources used."""
|
||||
pass
|
||||
|
||||
|
||||
# Fixed agent identifier when there is only the single agent in the env
|
||||
_DUMMY_AGENT_ID = "agent0"
|
||||
|
||||
+28
-17
@@ -5,7 +5,7 @@ from __future__ import print_function
|
||||
import logging
|
||||
|
||||
import ray
|
||||
from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID
|
||||
from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -18,24 +18,28 @@ class RemoteVectorEnv(BaseEnv):
|
||||
are supported, and envs can be stepped synchronously or async.
|
||||
"""
|
||||
|
||||
def __init__(self, make_env, num_envs, multiagent, sync):
|
||||
def __init__(self, make_env, num_envs, multiagent,
|
||||
remote_env_batch_wait_ms):
|
||||
self.make_local_env = make_env
|
||||
if sync:
|
||||
self.timeout = 9999999.0 # wait for all envs
|
||||
else:
|
||||
self.timeout = 0.0 # wait for only ready envs
|
||||
self.num_envs = num_envs
|
||||
self.multiagent = multiagent
|
||||
self.poll_timeout = remote_env_batch_wait_ms / 1000
|
||||
|
||||
def make_remote_env(i):
|
||||
logger.info("Launching env {} in remote actor".format(i))
|
||||
if multiagent:
|
||||
return _RemoteMultiAgentEnv.remote(self.make_local_env, i)
|
||||
else:
|
||||
return _RemoteSingleAgentEnv.remote(self.make_local_env, i)
|
||||
|
||||
self.actors = [make_remote_env(i) for i in range(num_envs)]
|
||||
self.actors = None # lazy init
|
||||
self.pending = None # lazy init
|
||||
|
||||
def poll(self):
|
||||
if self.actors is None:
|
||||
|
||||
def make_remote_env(i):
|
||||
logger.info("Launching env {} in remote actor".format(i))
|
||||
if self.multiagent:
|
||||
return _RemoteMultiAgentEnv.remote(self.make_local_env, i)
|
||||
else:
|
||||
return _RemoteSingleAgentEnv.remote(self.make_local_env, i)
|
||||
|
||||
self.actors = [make_remote_env(i) for i in range(self.num_envs)]
|
||||
|
||||
if self.pending is None:
|
||||
self.pending = {a.reset.remote(): a for a in self.actors}
|
||||
|
||||
@@ -48,7 +52,7 @@ class RemoteVectorEnv(BaseEnv):
|
||||
ready, _ = ray.wait(
|
||||
list(self.pending),
|
||||
num_returns=len(self.pending),
|
||||
timeout=self.timeout)
|
||||
timeout=self.poll_timeout)
|
||||
|
||||
# Get and return observations for each of the ready envs
|
||||
env_ids = set()
|
||||
@@ -72,8 +76,15 @@ class RemoteVectorEnv(BaseEnv):
|
||||
self.pending[obj_id] = actor
|
||||
|
||||
def try_reset(self, env_id):
|
||||
obs, _, _, _ = ray.get(self.actors[env_id].reset.remote())
|
||||
return obs
|
||||
actor = self.actors[env_id]
|
||||
obj_id = actor.reset.remote()
|
||||
self.pending[obj_id] = actor
|
||||
return ASYNC_RESET_RETURN
|
||||
|
||||
def stop(self):
|
||||
if self.actors is not None:
|
||||
for actor in self.actors:
|
||||
actor.__ray_terminate__.remote()
|
||||
|
||||
|
||||
@ray.remote(num_cpus=0)
|
||||
|
||||
@@ -125,7 +125,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
input_evaluation=frozenset([]),
|
||||
output_creator=lambda ioctx: NoopOutput(),
|
||||
remote_worker_envs=False,
|
||||
async_remote_worker_envs=False):
|
||||
remote_env_batch_wait_ms=0):
|
||||
"""Initialize a policy evaluator.
|
||||
|
||||
Arguments:
|
||||
@@ -203,9 +203,12 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
remote_worker_envs (bool): If using num_envs > 1, whether to create
|
||||
those new envs in remote processes instead of in the current
|
||||
process. This adds overheads, but can make sense if your envs
|
||||
are very CPU intensive (e.g., for StarCraft).
|
||||
async_remote_worker_envs (bool): Similar to remote_worker_envs,
|
||||
but runs the envs asynchronously in the background.
|
||||
can take much time to step / reset (e.g., for StarCraft)
|
||||
remote_env_batch_wait_ms (float): Timeout that remote workers
|
||||
are waiting when polling environments. 0 (continue when at
|
||||
least one env is ready) is a reasonable default, but optimal
|
||||
value could be obtained by measuring your environment
|
||||
step / reset and model inference perf.
|
||||
"""
|
||||
|
||||
if log_level:
|
||||
@@ -321,7 +324,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
make_env=make_env,
|
||||
num_envs=num_envs,
|
||||
remote_envs=remote_worker_envs,
|
||||
async_remote_envs=async_remote_worker_envs)
|
||||
remote_env_batch_wait_ms=remote_env_batch_wait_ms)
|
||||
self.num_envs = num_envs
|
||||
|
||||
if self.batch_mode == "truncate_episodes":
|
||||
@@ -668,6 +671,10 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
self.policy_map[policy_id].export_checkpoint(export_dir,
|
||||
filename_prefix)
|
||||
|
||||
@DeveloperAPI
|
||||
def stop(self):
|
||||
self.async_env.stop()
|
||||
|
||||
def _build_policy_map(self, policy_dict, policy_config):
|
||||
policy_map = {}
|
||||
preprocessors = {}
|
||||
|
||||
@@ -14,7 +14,7 @@ from ray.rllib.evaluation.episode import MultiAgentEpisode, _flatten_action
|
||||
from ray.rllib.evaluation.sample_batch_builder import \
|
||||
MultiAgentSampleBatchBuilder
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.env.base_env import BaseEnv
|
||||
from ray.rllib.env.base_env import BaseEnv, ASYNC_RESET_RETURN
|
||||
from ray.rllib.env.atari_wrappers import get_wrapper_by_cls, MonitorEnv
|
||||
from ray.rllib.models.action_dist import TupleActions
|
||||
from ray.rllib.offline import InputReader
|
||||
@@ -490,8 +490,9 @@ def _process_observations(base_env, policies, batch_builder_pool,
|
||||
raise ValueError(
|
||||
"Setting episode horizon requires reset() support "
|
||||
"from the environment.")
|
||||
else:
|
||||
# Creates a new episode
|
||||
elif resetted_obs != ASYNC_RESET_RETURN:
|
||||
# Creates a new episode if this is not async return
|
||||
# If reset is async, we will get its result in some future poll
|
||||
episode = active_episodes[env_id]
|
||||
for agent_id, raw_obs in resetted_obs.items():
|
||||
policy_id = episode.policy_for(agent_id)
|
||||
|
||||
@@ -282,8 +282,16 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
|
||||
# Reset processing
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
lambda: env.send_actions({0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}}))
|
||||
ValueError, lambda: env.send_actions({
|
||||
0: {
|
||||
0: 0,
|
||||
1: 0
|
||||
},
|
||||
1: {
|
||||
0: 0,
|
||||
1: 0
|
||||
}
|
||||
}))
|
||||
self.assertEqual(env.try_reset(0), {0: 0, 1: 0})
|
||||
self.assertEqual(env.try_reset(1), {0: 0, 1: 0})
|
||||
env.send_actions({0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
|
||||
@@ -346,7 +354,8 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
|
||||
batch_steps=50,
|
||||
num_envs=4,
|
||||
remote_worker_envs=True)
|
||||
remote_worker_envs=True,
|
||||
remote_env_batch_wait_ms=99999999)
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 200)
|
||||
|
||||
@@ -362,7 +371,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
|
||||
batch_steps=50,
|
||||
num_envs=4,
|
||||
async_remote_worker_envs=True)
|
||||
remote_worker_envs=True)
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 200)
|
||||
|
||||
@@ -599,15 +608,14 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
remote_evs = []
|
||||
optimizer = optimizer_cls(ev, remote_evs, {})
|
||||
for i in range(200):
|
||||
ev.foreach_policy(
|
||||
lambda p, _: p.set_epsilon(max(0.02, 1 - i * .02))
|
||||
if isinstance(p, DQNPolicyGraph) else None)
|
||||
ev.foreach_policy(lambda p, _: p.set_epsilon(
|
||||
max(0.02, 1 - i * .02))
|
||||
if isinstance(p, DQNPolicyGraph) else None)
|
||||
optimizer.step()
|
||||
result = collect_metrics(ev, remote_evs)
|
||||
if i % 20 == 0:
|
||||
ev.foreach_policy(
|
||||
lambda p, _: p.update_target()
|
||||
if isinstance(p, DQNPolicyGraph) else None)
|
||||
ev.foreach_policy(lambda p, _: p.update_target() if isinstance(
|
||||
p, DQNPolicyGraph) else None)
|
||||
print("Iter {}, rew {}".format(i,
|
||||
result["policy_reward_mean"]))
|
||||
print("Total reward", result["episode_reward_mean"])
|
||||
|
||||
Reference in New Issue
Block a user