[rllib] Make batch timeout for remote workers tunable (#4435)

This commit is contained in:
bjg2
2019-03-29 21:19:42 +01:00
committed by Eric Liang
parent 2ffe67c5c3
commit 77005d1814
9 changed files with 105 additions and 60 deletions
+20 -9
View File
@@ -147,11 +147,14 @@ COMMON_CONFIG = {
"metrics_smoothing_episodes": 100,
# If using num_envs_per_worker > 1, whether to create those new envs in
# remote processes instead of in the same worker. This adds overheads, but
# can make sense if your envs are very CPU intensive (e.g., for StarCraft).
# can make sense if your envs take a long time to step / reset
# (e.g., for StarCraft).
"remote_worker_envs": False,
# Similar to remote_worker_envs, but runs the envs asynchronously in the
# background for greater efficiency. Conflicts with remote_worker_envs.
"async_remote_worker_envs": False,
# Timeout, in milliseconds, that remote workers wait when polling
# environments. 0 (continue as soon as at least one env is ready) is a
# reasonable default, but the optimal value can be found by measuring
# your environment's step / reset time and your model's inference time.
"remote_env_batch_wait_ms": 0,
# === Offline Datasets ===
# Specify how to generate experiences:
@@ -378,10 +381,18 @@ class Agent(Trainable):
@override(Trainable)
def _stop(self):
    """Stop the evaluators and the optimizer that this agent has created.

    Each attribute is checked with ``hasattr`` first, so this can be called
    on an agent where some of them were never set.
    """
    # Call stop on all evaluators to release resources
    if hasattr(self, "local_evaluator"):
        self.local_evaluator.stop()

    if hasattr(self, "remote_evaluators"):
        remote_evs = self.remote_evaluators
    else:
        remote_evs = []

    for remote_ev in remote_evs:
        remote_ev.stop.remote()

    # workaround for https://github.com/ray-project/ray/issues/1516
    for remote_ev in remote_evs:
        remote_ev.__ray_terminate__.remote()

    if hasattr(self, "optimizer"):
        self.optimizer.stop()
@@ -657,12 +668,12 @@ class Agent(Trainable):
input_creator = (lambda ioctx: ioctx.default_sampler_input())
elif isinstance(config["input"], dict):
input_creator = (lambda ioctx: ShuffledInput(
MixedInput(config["input"], ioctx),
config["shuffle_buffer_size"]))
MixedInput(config["input"], ioctx), config[
"shuffle_buffer_size"]))
else:
input_creator = (lambda ioctx: ShuffledInput(
JsonReader(config["input"], ioctx),
config["shuffle_buffer_size"]))
JsonReader(config["input"], ioctx), config[
"shuffle_buffer_size"]))
if isinstance(config["output"], FunctionType):
output_creator = config["output"]
@@ -724,7 +735,7 @@ class Agent(Trainable):
input_evaluation=input_evaluation,
output_creator=output_creator,
remote_worker_envs=config["remote_worker_envs"],
async_remote_worker_envs=config["async_remote_worker_envs"])
remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"])
@override(Trainable)
def _export_model(self, export_formats, export_dir):
+17 -13
View File
@@ -7,6 +7,8 @@ from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.utils.annotations import override, PublicAPI
ASYNC_RESET_RETURN = "async_reset_return"
@PublicAPI
class BaseEnv(object):
@@ -78,26 +80,23 @@ class BaseEnv(object):
make_env=None,
num_envs=1,
remote_envs=False,
async_remote_envs=False):
remote_env_batch_wait_ms=0):
"""Wraps any env type as needed to expose the async interface."""
from ray.rllib.env.remote_vector_env import RemoteVectorEnv
if (remote_envs or async_remote_envs) and num_envs == 1:
if remote_envs and num_envs == 1:
raise ValueError(
"Remote envs only make sense to use if num_envs > 1 "
"(i.e. vectorization is enabled).")
if remote_envs and async_remote_envs:
raise ValueError("You can only specify one of remote_envs or "
"async_remote_envs.")
if not isinstance(env, BaseEnv):
if isinstance(env, MultiAgentEnv):
if remote_envs:
env = RemoteVectorEnv(
make_env, num_envs, multiagent=True, sync=True)
elif async_remote_envs:
env = RemoteVectorEnv(
make_env, num_envs, multiagent=True, sync=False)
make_env,
num_envs,
multiagent=True,
remote_env_batch_wait_ms=remote_env_batch_wait_ms)
else:
env = _MultiAgentEnvToBaseEnv(
make_env=make_env,
@@ -113,10 +112,10 @@ class BaseEnv(object):
else:
if remote_envs:
env = RemoteVectorEnv(
make_env, num_envs, multiagent=False, sync=True)
elif async_remote_envs:
env = RemoteVectorEnv(
make_env, num_envs, multiagent=False, sync=False)
make_env,
num_envs,
multiagent=False,
remote_env_batch_wait_ms=remote_env_batch_wait_ms)
else:
env = VectorEnv.wrap(
make_env=make_env,
@@ -184,6 +183,11 @@ class BaseEnv(object):
"""
return []
@PublicAPI
def stop(self):
    """Releases all resources used.

    This base implementation does nothing and returns None.
    """
    return None
# Fixed agent identifier when there is only the single agent in the env
_DUMMY_AGENT_ID = "agent0"
+28 -17
View File
@@ -5,7 +5,7 @@ from __future__ import print_function
import logging
import ray
from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID
from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN
logger = logging.getLogger(__name__)
@@ -18,24 +18,28 @@ class RemoteVectorEnv(BaseEnv):
are supported, and envs can be stepped synchronously or async.
"""
def __init__(self, make_env, num_envs, multiagent, sync):
def __init__(self, make_env, num_envs, multiagent,
remote_env_batch_wait_ms):
self.make_local_env = make_env
if sync:
self.timeout = 9999999.0 # wait for all envs
else:
self.timeout = 0.0 # wait for only ready envs
self.num_envs = num_envs
self.multiagent = multiagent
self.poll_timeout = remote_env_batch_wait_ms / 1000
def make_remote_env(i):
logger.info("Launching env {} in remote actor".format(i))
if multiagent:
return _RemoteMultiAgentEnv.remote(self.make_local_env, i)
else:
return _RemoteSingleAgentEnv.remote(self.make_local_env, i)
self.actors = [make_remote_env(i) for i in range(num_envs)]
self.actors = None # lazy init
self.pending = None # lazy init
def poll(self):
if self.actors is None:
def make_remote_env(i):
logger.info("Launching env {} in remote actor".format(i))
if self.multiagent:
return _RemoteMultiAgentEnv.remote(self.make_local_env, i)
else:
return _RemoteSingleAgentEnv.remote(self.make_local_env, i)
self.actors = [make_remote_env(i) for i in range(self.num_envs)]
if self.pending is None:
self.pending = {a.reset.remote(): a for a in self.actors}
@@ -48,7 +52,7 @@ class RemoteVectorEnv(BaseEnv):
ready, _ = ray.wait(
list(self.pending),
num_returns=len(self.pending),
timeout=self.timeout)
timeout=self.poll_timeout)
# Get and return observations for each of the ready envs
env_ids = set()
@@ -72,8 +76,15 @@ class RemoteVectorEnv(BaseEnv):
self.pending[obj_id] = actor
def try_reset(self, env_id):
obs, _, _, _ = ray.get(self.actors[env_id].reset.remote())
return obs
actor = self.actors[env_id]
obj_id = actor.reset.remote()
self.pending[obj_id] = actor
return ASYNC_RESET_RETURN
def stop(self):
    """Request termination of every remote env actor that was launched.

    Does nothing when ``self.actors`` is still ``None``, i.e. the actors
    were never created.
    """
    if self.actors is None:
        return
    for env_actor in self.actors:
        env_actor.__ray_terminate__.remote()
@ray.remote(num_cpus=0)
@@ -125,7 +125,7 @@ class PolicyEvaluator(EvaluatorInterface):
input_evaluation=frozenset([]),
output_creator=lambda ioctx: NoopOutput(),
remote_worker_envs=False,
async_remote_worker_envs=False):
remote_env_batch_wait_ms=0):
"""Initialize a policy evaluator.
Arguments:
@@ -203,9 +203,12 @@ class PolicyEvaluator(EvaluatorInterface):
remote_worker_envs (bool): If using num_envs > 1, whether to create
those new envs in remote processes instead of in the current
process. This adds overheads, but can make sense if your envs
are very CPU intensive (e.g., for StarCraft).
async_remote_worker_envs (bool): Similar to remote_worker_envs,
but runs the envs asynchronously in the background.
take a long time to step / reset (e.g., for StarCraft).
remote_env_batch_wait_ms (float): Timeout, in milliseconds, that
remote workers wait when polling environments. 0 (continue as
soon as at least one env is ready) is a reasonable default, but
the optimal value can be found by measuring your environment's
step / reset time and your model's inference time.
"""
if log_level:
@@ -321,7 +324,7 @@ class PolicyEvaluator(EvaluatorInterface):
make_env=make_env,
num_envs=num_envs,
remote_envs=remote_worker_envs,
async_remote_envs=async_remote_worker_envs)
remote_env_batch_wait_ms=remote_env_batch_wait_ms)
self.num_envs = num_envs
if self.batch_mode == "truncate_episodes":
@@ -668,6 +671,10 @@ class PolicyEvaluator(EvaluatorInterface):
self.policy_map[policy_id].export_checkpoint(export_dir,
filename_prefix)
@DeveloperAPI
def stop(self):
    """Stop this evaluator by calling ``stop()`` on ``self.async_env``."""
    env = self.async_env
    env.stop()
def _build_policy_map(self, policy_dict, policy_config):
policy_map = {}
preprocessors = {}
+4 -3
View File
@@ -14,7 +14,7 @@ from ray.rllib.evaluation.episode import MultiAgentEpisode, _flatten_action
from ray.rllib.evaluation.sample_batch_builder import \
MultiAgentSampleBatchBuilder
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.base_env import BaseEnv, ASYNC_RESET_RETURN
from ray.rllib.env.atari_wrappers import get_wrapper_by_cls, MonitorEnv
from ray.rllib.models.action_dist import TupleActions
from ray.rllib.offline import InputReader
@@ -490,8 +490,9 @@ def _process_observations(base_env, policies, batch_builder_pool,
raise ValueError(
"Setting episode horizon requires reset() support "
"from the environment.")
else:
# Creates a new episode
elif resetted_obs != ASYNC_RESET_RETURN:
# Create a new episode unless the reset is asynchronous.
# If the reset is async, its result will arrive in a later poll.
episode = active_episodes[env_id]
for agent_id, raw_obs in resetted_obs.items():
policy_id = episode.policy_for(agent_id)
+18 -10
View File
@@ -282,8 +282,16 @@ class TestMultiAgentEnv(unittest.TestCase):
# Reset processing
self.assertRaises(
ValueError,
lambda: env.send_actions({0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}}))
ValueError, lambda: env.send_actions({
0: {
0: 0,
1: 0
},
1: {
0: 0,
1: 0
}
}))
self.assertEqual(env.try_reset(0), {0: 0, 1: 0})
self.assertEqual(env.try_reset(1), {0: 0, 1: 0})
env.send_actions({0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
@@ -346,7 +354,8 @@ class TestMultiAgentEnv(unittest.TestCase):
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
batch_steps=50,
num_envs=4,
remote_worker_envs=True)
remote_worker_envs=True,
remote_env_batch_wait_ms=99999999)
batch = ev.sample()
self.assertEqual(batch.count, 200)
@@ -362,7 +371,7 @@ class TestMultiAgentEnv(unittest.TestCase):
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
batch_steps=50,
num_envs=4,
async_remote_worker_envs=True)
remote_worker_envs=True)
batch = ev.sample()
self.assertEqual(batch.count, 200)
@@ -599,15 +608,14 @@ class TestMultiAgentEnv(unittest.TestCase):
remote_evs = []
optimizer = optimizer_cls(ev, remote_evs, {})
for i in range(200):
ev.foreach_policy(
lambda p, _: p.set_epsilon(max(0.02, 1 - i * .02))
if isinstance(p, DQNPolicyGraph) else None)
ev.foreach_policy(lambda p, _: p.set_epsilon(
max(0.02, 1 - i * .02))
if isinstance(p, DQNPolicyGraph) else None)
optimizer.step()
result = collect_metrics(ev, remote_evs)
if i % 20 == 0:
ev.foreach_policy(
lambda p, _: p.update_target()
if isinstance(p, DQNPolicyGraph) else None)
ev.foreach_policy(lambda p, _: p.update_target() if isinstance(
p, DQNPolicyGraph) else None)
print("Iter {}, rew {}".format(i,
result["policy_reward_mean"]))
print("Total reward", result["episode_reward_mean"])