[rllib] Add option to proceed even if some workers crashed (#4376)

This commit is contained in:
Eric Liang
2019-03-16 13:34:09 -07:00
committed by GitHub
parent db9fe6619d
commit a45019d98c
9 changed files with 219 additions and 6 deletions
+74 -4
View File
@@ -13,6 +13,7 @@ import tensorflow as tf
from types import FunctionType
import ray
from ray.exceptions import RayError
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
ShuffledInput
from ray.rllib.models import MODEL_DEFAULTS
@@ -29,6 +30,10 @@ from ray.tune.result import DEFAULT_RESULTS_DIR
logger = logging.getLogger(__name__)
# Max number of times to retry a worker failure. We shouldn't try too many
# times in a row since that would indicate a persistent cluster issue.
MAX_WORKER_FAILURE_RETRIES = 3
# yapf: disable
# __sphinx_doc_begin__
COMMON_CONFIG = {
@@ -48,6 +53,8 @@ COMMON_CONFIG = {
"on_sample_end": None, # arg: {"samples": .., "evaluator": ...}
"on_train_result": None, # arg: {"agent": ..., "result": ...}
},
# Whether to attempt to continue training if a worker crashes.
"ignore_worker_failures": False,
# === Policy ===
# Arguments to pass to model. See models/catalog.py for a full list of the
@@ -99,7 +106,7 @@ COMMON_CONFIG = {
"train_batch_size": 200,
# Whether to rollout "complete_episodes" or "truncate_episodes"
"batch_mode": "truncate_episodes",
# Whether to use a background thread for sampling (slightly off-policy)
# (Deprecated) Use a background thread for sampling (slightly off-policy)
"sample_async": False,
# Element-wise observation filter, either "NoFilter" or "MeanStdFilter"
"observation_filter": "NoFilter",
@@ -285,15 +292,32 @@ class Agent(Trainable):
def train(self):
"""Overrides super.train to synchronize global vars."""
if hasattr(self, "optimizer") and isinstance(self.optimizer,
PolicyOptimizer):
if self._has_policy_optimizer():
self.global_vars["timestep"] = self.optimizer.num_steps_sampled
self.optimizer.local_evaluator.set_global_vars(self.global_vars)
for ev in self.optimizer.remote_evaluators:
ev.set_global_vars.remote(self.global_vars)
logger.debug("updated global vars: {}".format(self.global_vars))
result = Trainable.train(self)
result = None
for _ in range(1 + MAX_WORKER_FAILURE_RETRIES):
try:
result = Trainable.train(self)
except RayError as e:
if self.config["ignore_worker_failures"]:
logger.exception(
"Error in train call, attempting to recover")
self._try_recover()
else:
logger.info(
"Worker crashed during call to train(). To attempt to "
"continue training without the failed worker, set "
"`'ignore_worker_failures': True`.")
raise e
else:
break
if result is None:
raise RuntimeError("Failed to recover from worker crash")
if (self.config.get("observation_filter", "NoFilter") != "NoFilter"
and hasattr(self, "local_evaluator")):
@@ -304,6 +328,9 @@ class Agent(Trainable):
logger.debug("synchronized filters: {}".format(
self.local_evaluator.filters))
if self._has_policy_optimizer():
result["num_healthy_workers"] = len(
self.optimizer.remote_evaluators)
return result
@override(Trainable)
@@ -558,6 +585,49 @@ class Agent(Trainable):
"`input_evaluation` must be a list of strings, got {}".format(
config["input_evaluation"]))
def _try_recover(self):
    """Probe every remote worker and drop the ones that fail.

    Called after a remote error surfaces from a worker. A sample request
    is issued to each of the optimizer's remote evaluators; any evaluator
    whose request raises a RayError is logged, asked to terminate, and
    left out of the list handed to ``self.optimizer.reset``.

    Raises:
        NotImplementedError: if this agent has no policy optimizer.
        RuntimeError: if no evaluator passes the check.
    """
    if not self._has_policy_optimizer():
        raise NotImplementedError(
            "Recovery is not supported for this algorithm")

    logger.info("Health checking all workers...")

    # Issue every probe before waiting on any single one of them.
    probes = []
    for worker in self.optimizer.remote_evaluators:
        _, probe_id = worker.sample_with_count.remote()
        probes.append((worker, probe_id))

    survivors = []
    # Worker numbers in the log messages are 1-based.
    for worker_num, (worker, probe_id) in enumerate(probes, 1):
        try:
            ray.get(probe_id)
        except RayError:
            logger.exception("Blacklisting worker {}".format(worker_num))
            # Termination is best-effort: a failure here is only logged.
            try:
                worker.__ray_terminate__.remote()
            except Exception:
                logger.exception("Error terminating unhealthy worker")
        else:
            survivors.append(worker)
            logger.info("Worker {} looks healthy".format(worker_num))

    if not survivors:
        raise RuntimeError(
            "Not enough healthy workers remain to continue.")

    self.optimizer.reset(survivors)
def _has_policy_optimizer(self):
return hasattr(self, "optimizer") and isinstance(
self.optimizer, PolicyOptimizer)
def _make_evaluator(self, cls, env_creator, policy_graph, worker_index,
config):
def session_creator():
+2 -1
View File
@@ -116,7 +116,8 @@ class ImpalaAgent(Agent):
prev_steps = self.optimizer.num_steps_sampled
start = time.time()
self.optimizer.step()
while time.time() - start < self.config["min_iter_time_s"]:
while (time.time() - start < self.config["min_iter_time_s"]
or self.optimizer.num_steps_sampled == prev_steps):
self.optimizer.step()
result = self.optimizer.collect_metrics(
self.config["collect_metrics_timeout"])
+2
View File
@@ -171,6 +171,8 @@ class AsyncSampler(threading.Thread, SamplerInput):
queue_putter(item)
def get_data(self):
if not self.is_alive():
raise RuntimeError("Sampling thread has died")
rollout = self.queue.get(timeout=600.0)
# Propagate errors
@@ -133,6 +133,11 @@ class AsyncReplayOptimizer(PolicyOptimizer):
r.__ray_terminate__.remote()
self.learner.stopped = True
@override(PolicyOptimizer)
def reset(self, remote_evaluators):
    """Switch this optimizer over to a new list of remote evaluators.

    Args:
        remote_evaluators: the evaluators to use from now on.
    """
    self.remote_evaluators = remote_evaluators
    # The sample task pool is told about the new set as well, so it can
    # drop whatever it tracks for evaluators that are no longer listed.
    self.sample_tasks.reset_evaluators(remote_evaluators)
@override(PolicyOptimizer)
def stats(self):
replay_stats = ray.get(self.replay_actors[0].stats.remote(self.debug))
@@ -152,6 +152,11 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
def stop(self):
self.learner.stopped = True
@override(PolicyOptimizer)
def reset(self, remote_evaluators):
    """Switch this optimizer over to a new list of remote evaluators.

    Args:
        remote_evaluators: the evaluators to use from now on.
    """
    self.remote_evaluators = remote_evaluators
    # The sample task pool is told about the new set as well, so it can
    # drop whatever it tracks for evaluators that are no longer listed.
    self.sample_tasks.reset_evaluators(remote_evaluators)
@override(PolicyOptimizer)
def stats(self):
def timer_to_ms(timer):
@@ -142,6 +142,12 @@ class PolicyOptimizer(object):
res.update(info=self.stats())
return res
@DeveloperAPI
def reset(self, remote_evaluators):
    """Called to change the set of remote evaluators being used.

    Args:
        remote_evaluators: the evaluators that replace the current
            ``self.remote_evaluators``.
    """
    self.remote_evaluators = remote_evaluators
@DeveloperAPI
def foreach_evaluator(self, func):
"""Apply the given function to each evaluator instance."""
@@ -0,0 +1,109 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gym
import unittest
import ray
from ray.rllib import _register_all
from ray.rllib.agents.registry import get_agent_class
from ray.tune.registry import register_env
class FaultInjectEnv(gym.Env):
    """CartPole-v0 wrapper whose ``step`` raises on selected workers.

    ``config`` must expose a ``worker_index`` attribute and a
    ``"bad_indices"`` entry; when the former is contained in the latter,
    every call to ``step`` raises ValueError.
    """

    def __init__(self, config):
        self.config = config
        self.env = gym.make("CartPole-v0")
        # Mirror the wrapped environment's spaces.
        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space

    def reset(self):
        """Reset the wrapped environment and return its result."""
        return self.env.reset()

    def step(self, action):
        """Forward ``action`` to the wrapped env, or raise if faulty."""
        worker_index = self.config.worker_index
        if worker_index not in self.config["bad_indices"]:
            return self.env.step(action)
        raise ValueError(
            "This is a simulated error from {}".format(worker_index))
class IgnoresWorkerFailure(unittest.TestCase):
    """Exercises ``ignore_worker_failures`` across several algorithms."""

    def doTest(self, alg, config, fn=None):
        """Run one scenario inside a fresh Ray session.

        Args:
            alg: algorithm name passed to ``get_agent_class``.
            config: agent config dict; the scenario adds keys to it.
            fn: scenario callable taking ``(alg, config)``. Defaults to
                ``_doTestFaultRecover``.
        """
        fn = fn or self._doTestFaultRecover
        try:
            ray.init(num_cpus=6)
            fn(alg, config)
        finally:
            ray.shutdown()
            _register_all()  # re-register the evicted objects

    def _doTestFaultRecover(self, alg, config):
        """Train once with one of two workers faulty; expect one survivor."""
        register_env("fault_env", lambda c: FaultInjectEnv(c))
        agent_cls = get_agent_class(alg)

        # Test fault handling
        config["num_workers"] = 2
        config["ignore_worker_failures"] = True
        config["env_config"] = {"bad_indices": [1]}
        a = agent_cls(config=config, env="fault_env")
        result = a.train()
        # assertEqual, not assertTrue: assertTrue(x, 1) would treat 1 as
        # the failure message and accept any truthy worker count.
        self.assertEqual(result["num_healthy_workers"], 1)
        a.stop()

    def _doTestFaultFatal(self, alg, config):
        """Train with both workers faulty; expect ``train`` to raise."""
        register_env("fault_env", lambda c: FaultInjectEnv(c))
        agent_cls = get_agent_class(alg)

        # Test raises real error when out of workers
        config["num_workers"] = 2
        config["ignore_worker_failures"] = True
        config["env_config"] = {"bad_indices": [1, 2]}
        a = agent_cls(config=config, env="fault_env")
        self.assertRaises(Exception, a.train)
        a.stop()

    def testFatal(self):
        # test the case where all workers fail
        self.doTest("PG", {"optimizer": {}}, fn=self._doTestFaultFatal)

    def testAsyncGrads(self):
        self.doTest("A3C", {"optimizer": {"grads_per_step": 1}})

    def testAsyncReplay(self):
        self.doTest(
            "APEX", {
                "timesteps_per_iteration": 1000,
                "num_gpus": 0,
                "min_iter_time_s": 1,
                "learning_starts": 1000,
                "target_network_update_freq": 100,
                "optimizer": {
                    "num_replay_buffer_shards": 1,
                },
            })

    def testAsyncSamples(self):
        self.doTest("IMPALA", {"num_gpus": 0})

    def testSyncReplay(self):
        self.doTest("DQN", {"timesteps_per_iteration": 1})

    def testMultiGPU(self):
        self.doTest(
            "PPO", {
                "num_sgd_iter": 1,
                "train_batch_size": 10,
                "sample_batch_size": 10,
                "sgd_minibatch_size": 1,
            })

    def testSyncSamples(self):
        self.doTest("PG", {"optimizer": {}})

    def testAsyncSamplingOption(self):
        self.doTest("PG", {"optimizer": {}, "sample_async": True})
if __name__ == "__main__":
unittest.main(verbosity=2)
+12
View File
@@ -53,6 +53,18 @@ class TaskPool(object):
remaining.append((worker, obj_id))
self._fetching = remaining
def reset_evaluators(self, evaluators):
    """Forget all tracked work whose evaluator is not in ``evaluators``.

    Entries of ``self._tasks`` whose value is not in ``evaluators`` are
    deleted, together with the ``self._objects`` entry under the same
    key. ``self._fetching`` is rebuilt from the ``(evaluator, obj_id)``
    pairs whose evaluator is still listed.

    Args:
        evaluators: the evaluators that remain in use.
    """
    # Collect the keys first so the dict is not mutated while iterating.
    stale_ids = [
        task_id for task_id, worker in self._tasks.items()
        if worker not in evaluators
    ]
    for task_id in stale_ids:
        del self._tasks[task_id]
        del self._objects[task_id]

    self._fetching = [
        (worker, task_id) for worker, task_id in self._fetching
        if worker in evaluators
    ]
@property
def count(self):
    """Number of entries currently held in ``self._tasks``."""
    return len(self._tasks)