From a45019d98c7067806295bea22c22a66f062fa50e Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Sat, 16 Mar 2019 13:34:09 -0700
Subject: [PATCH] [rllib] Add option to proceed even if some workers crashed (#4376)

---
 ci/jenkins_tests/run_rllib_tests.sh           |   5 +-
 python/ray/rllib/agents/agent.py              |  78 ++++++++++++-
 python/ray/rllib/agents/impala/impala.py      |   3 +-
 python/ray/rllib/evaluation/sampler.py        |   2 +
 .../optimizers/async_replay_optimizer.py      |   5 +
 .../optimizers/async_samples_optimizer.py     |   5 +
 .../ray/rllib/optimizers/policy_optimizer.py  |   6 +
 .../rllib/tests/test_ignore_worker_failure.py | 109 ++++++++++++++++++
 python/ray/rllib/utils/actors.py              |  12 ++
 9 files changed, 219 insertions(+), 6 deletions(-)
 create mode 100644 python/ray/rllib/tests/test_ignore_worker_failure.py

diff --git a/ci/jenkins_tests/run_rllib_tests.sh b/ci/jenkins_tests/run_rllib_tests.sh
index 91d4acd9c..fda03cfb2 100644
--- a/ci/jenkins_tests/run_rllib_tests.sh
+++ b/ci/jenkins_tests/run_rllib_tests.sh
@@ -412,4 +412,7 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
     --config '{"num_workers": 1, "num_gpus": 0, "num_envs_per_worker": 64, "sample_batch_size": 50, "train_batch_size": 50, "learner_queue_size": 1}'
 
 docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
-    python /ray/python/ray/rllib/agents/impala/vtrace_test.py
+    /ray/python/ray/rllib/tests/run_silent.sh agents/impala/vtrace_test.py
+
+docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
+    /ray/python/ray/rllib/tests/run_silent.sh tests/test_ignore_worker_failure.py
diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py
index 235d0f704..1c60ba8b8 100644
--- a/python/ray/rllib/agents/agent.py
+++ b/python/ray/rllib/agents/agent.py
@@ -13,6 +13,7 @@ import tensorflow as tf
 from types import FunctionType
 
 import ray
+from ray.exceptions import RayError
 from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
     ShuffledInput
 from ray.rllib.models import MODEL_DEFAULTS
@@ -29,6 +30,10 @@ from ray.tune.result import DEFAULT_RESULTS_DIR
 
 logger = logging.getLogger(__name__)
 
+# Max number of times to retry a worker failure. We shouldn't try too many
+# times in a row since that would indicate a persistent cluster issue.
+MAX_WORKER_FAILURE_RETRIES = 3
+
 # yapf: disable
 # __sphinx_doc_begin__
 COMMON_CONFIG = {
@@ -48,6 +53,8 @@ COMMON_CONFIG = {
         "on_sample_end": None,  # arg: {"samples": .., "evaluator": ...}
         "on_train_result": None,  # arg: {"agent": ..., "result": ...}
     },
+    # Whether to attempt to continue training if a worker crashes.
+    "ignore_worker_failures": False,
 
     # === Policy ===
     # Arguments to pass to model. See models/catalog.py for a full list of the
@@ -99,7 +106,7 @@ COMMON_CONFIG = {
     "train_batch_size": 200,
     # Whether to rollout "complete_episodes" or "truncate_episodes"
     "batch_mode": "truncate_episodes",
-    # Whether to use a background thread for sampling (slightly off-policy)
+    # (Deprecated) Use a background thread for sampling (slightly off-policy)
     "sample_async": False,
     # Element-wise observation filter, either "NoFilter" or "MeanStdFilter"
     "observation_filter": "NoFilter",
@@ -285,15 +292,32 @@ class Agent(Trainable):
     def train(self):
         """Overrides super.train to synchronize global vars."""
 
-        if hasattr(self, "optimizer") and isinstance(self.optimizer,
-                                                     PolicyOptimizer):
+        if self._has_policy_optimizer():
             self.global_vars["timestep"] = self.optimizer.num_steps_sampled
             self.optimizer.local_evaluator.set_global_vars(self.global_vars)
             for ev in self.optimizer.remote_evaluators:
                 ev.set_global_vars.remote(self.global_vars)
             logger.debug("updated global vars: {}".format(self.global_vars))
 
-        result = Trainable.train(self)
+        result = None
+        for _ in range(1 + MAX_WORKER_FAILURE_RETRIES):
+            try:
+                result = Trainable.train(self)
+            except RayError:
+                if self.config["ignore_worker_failures"]:
+                    logger.exception(
+                        "Error in train call, attempting to recover")
+                    self._try_recover()
+                else:
+                    logger.info(
+                        "Worker crashed during call to train(). To attempt to "
+                        "continue training without the failed worker, set "
+                        "`'ignore_worker_failures': True`.")
+                    raise
+            else:
+                break
+        if result is None:
+            raise RuntimeError("Failed to recover from worker crash")
 
         if (self.config.get("observation_filter", "NoFilter") != "NoFilter"
                 and hasattr(self, "local_evaluator")):
@@ -304,6 +328,9 @@ class Agent(Trainable):
             logger.debug("synchronized filters: {}".format(
                 self.local_evaluator.filters))
 
+        if self._has_policy_optimizer():
+            result["num_healthy_workers"] = len(
+                self.optimizer.remote_evaluators)
         return result
 
     @override(Trainable)
@@ -558,6 +585,49 @@ class Agent(Trainable):
                 "`input_evaluation` must be a list of strings, got {}".format(
                     config["input_evaluation"]))
 
+    def _try_recover(self):
+        """Try to identify and blacklist any unhealthy workers.
+
+        This method is called after an unexpected remote error is encountered
+        from a worker. It issues check requests to all current workers and
+        blacklists any that respond with error. If no healthy workers remain,
+        an error is raised.
+        """
+
+        if not self._has_policy_optimizer():
+            raise NotImplementedError(
+                "Recovery is not supported for this algorithm")
+
+        logger.info("Health checking all workers...")
+        checks = []
+        for ev in self.optimizer.remote_evaluators:
+            _, obj_id = ev.sample_with_count.remote()
+            checks.append(obj_id)
+
+        healthy_evaluators = []
+        for i, obj_id in enumerate(checks):
+            ev = self.optimizer.remote_evaluators[i]
+            try:
+                ray.get(obj_id)
+                healthy_evaluators.append(ev)
+                logger.info("Worker {} looks healthy".format(i + 1))
+            except RayError:
+                logger.exception("Blacklisting worker {}".format(i + 1))
+                try:
+                    ev.__ray_terminate__.remote()
+                except Exception:
+                    logger.exception("Error terminating unhealthy worker")
+
+        if len(healthy_evaluators) < 1:
+            raise RuntimeError(
+                "Not enough healthy workers remain to continue.")
+
+        self.optimizer.reset(healthy_evaluators)
+
+    def _has_policy_optimizer(self):
+        return hasattr(self, "optimizer") and isinstance(
+            self.optimizer, PolicyOptimizer)
+
     def _make_evaluator(self, cls, env_creator, policy_graph, worker_index,
                         config):
         def session_creator():
diff --git a/python/ray/rllib/agents/impala/impala.py b/python/ray/rllib/agents/impala/impala.py
index 94294a1fc..cf87c773b 100644
--- a/python/ray/rllib/agents/impala/impala.py
+++ b/python/ray/rllib/agents/impala/impala.py
@@ -116,7 +116,8 @@ class ImpalaAgent(Agent):
         prev_steps = self.optimizer.num_steps_sampled
         start = time.time()
         self.optimizer.step()
-        while time.time() - start < self.config["min_iter_time_s"]:
+        while (time.time() - start < self.config["min_iter_time_s"]
+               or self.optimizer.num_steps_sampled == prev_steps):
             self.optimizer.step()
         result = self.optimizer.collect_metrics(
             self.config["collect_metrics_timeout"])
diff --git a/python/ray/rllib/evaluation/sampler.py b/python/ray/rllib/evaluation/sampler.py
index 3d91cc44f..0c3032178 100644
--- a/python/ray/rllib/evaluation/sampler.py
+++ b/python/ray/rllib/evaluation/sampler.py
@@ -171,6 +171,8 @@ class AsyncSampler(threading.Thread, SamplerInput):
             queue_putter(item)
 
     def get_data(self):
+        if not self.is_alive():
+            raise RuntimeError("Sampling thread has died")
         rollout = self.queue.get(timeout=600.0)
 
         # Propagate errors
diff --git a/python/ray/rllib/optimizers/async_replay_optimizer.py b/python/ray/rllib/optimizers/async_replay_optimizer.py
index f1ae1bf71..51ab68a0c 100644
--- a/python/ray/rllib/optimizers/async_replay_optimizer.py
+++ b/python/ray/rllib/optimizers/async_replay_optimizer.py
@@ -133,6 +133,11 @@ class AsyncReplayOptimizer(PolicyOptimizer):
             r.__ray_terminate__.remote()
         self.learner.stopped = True
 
+    @override(PolicyOptimizer)
+    def reset(self, remote_evaluators):
+        self.remote_evaluators = remote_evaluators
+        self.sample_tasks.reset_evaluators(remote_evaluators)
+
     @override(PolicyOptimizer)
     def stats(self):
         replay_stats = ray.get(self.replay_actors[0].stats.remote(self.debug))
diff --git a/python/ray/rllib/optimizers/async_samples_optimizer.py b/python/ray/rllib/optimizers/async_samples_optimizer.py
index 250ec7355..22b7ea18b 100644
--- a/python/ray/rllib/optimizers/async_samples_optimizer.py
+++ b/python/ray/rllib/optimizers/async_samples_optimizer.py
@@ -152,6 +152,11 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
     def stop(self):
         self.learner.stopped = True
 
+    @override(PolicyOptimizer)
+    def reset(self, remote_evaluators):
+        self.remote_evaluators = remote_evaluators
+        self.sample_tasks.reset_evaluators(remote_evaluators)
+
     @override(PolicyOptimizer)
     def stats(self):
         def timer_to_ms(timer):
diff --git a/python/ray/rllib/optimizers/policy_optimizer.py b/python/ray/rllib/optimizers/policy_optimizer.py
index 7f4305508..d6a2bb15c 100644
--- a/python/ray/rllib/optimizers/policy_optimizer.py
+++ b/python/ray/rllib/optimizers/policy_optimizer.py
@@ -142,6 +142,12 @@ class PolicyOptimizer(object):
         res.update(info=self.stats())
         return res
 
+    @DeveloperAPI
+    def reset(self, remote_evaluators):
+        """Called to change the set of remote evaluators being used."""
+
+        self.remote_evaluators = remote_evaluators
+
     @DeveloperAPI
     def foreach_evaluator(self, func):
         """Apply the given function to each evaluator instance."""
diff --git a/python/ray/rllib/tests/test_ignore_worker_failure.py b/python/ray/rllib/tests/test_ignore_worker_failure.py
new file mode 100644
index 000000000..3a9970397
--- /dev/null
+++ b/python/ray/rllib/tests/test_ignore_worker_failure.py
@@ -0,0 +1,109 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gym
+import unittest
+
+import ray
+from ray.rllib import _register_all
+from ray.rllib.agents.registry import get_agent_class
+from ray.tune.registry import register_env
+
+
+class FaultInjectEnv(gym.Env):
+    def __init__(self, config):
+        self.env = gym.make("CartPole-v0")
+        self.action_space = self.env.action_space
+        self.observation_space = self.env.observation_space
+        self.config = config
+
+    def reset(self):
+        return self.env.reset()
+
+    def step(self, action):
+        if self.config.worker_index in self.config["bad_indices"]:
+            raise ValueError("This is a simulated error from {}".format(
+                self.config.worker_index))
+        return self.env.step(action)
+
+
+class IgnoresWorkerFailure(unittest.TestCase):
+    def doTest(self, alg, config, fn=None):
+        fn = fn or self._doTestFaultRecover
+        try:
+            ray.init(num_cpus=6)
+            fn(alg, config)
+        finally:
+            ray.shutdown()
+            _register_all()  # re-register the evicted objects
+
+    def _doTestFaultRecover(self, alg, config):
+        register_env("fault_env", lambda c: FaultInjectEnv(c))
+        agent_cls = get_agent_class(alg)
+
+        # Test fault handling
+        config["num_workers"] = 2
+        config["ignore_worker_failures"] = True
+        config["env_config"] = {"bad_indices": [1]}
+        a = agent_cls(config=config, env="fault_env")
+        result = a.train()
+        self.assertEqual(result["num_healthy_workers"], 1)
+        a.stop()
+
+    def _doTestFaultFatal(self, alg, config):
+        register_env("fault_env", lambda c: FaultInjectEnv(c))
+        agent_cls = get_agent_class(alg)
+
+        # Test raises real error when out of workers
+        config["num_workers"] = 2
+        config["ignore_worker_failures"] = True
+        config["env_config"] = {"bad_indices": [1, 2]}
+        a = agent_cls(config=config, env="fault_env")
+        self.assertRaises(Exception, lambda: a.train())
+        a.stop()
+
+    def testFatal(self):
+        # test the case where all workers fail
+        self.doTest("PG", {"optimizer": {}}, fn=self._doTestFaultFatal)
+
+    def testAsyncGrads(self):
+        self.doTest("A3C", {"optimizer": {"grads_per_step": 1}})
+
+    def testAsyncReplay(self):
+        self.doTest(
+            "APEX", {
+                "timesteps_per_iteration": 1000,
+                "num_gpus": 0,
+                "min_iter_time_s": 1,
+                "learning_starts": 1000,
+                "target_network_update_freq": 100,
+                "optimizer": {
+                    "num_replay_buffer_shards": 1,
+                },
+            })
+
+    def testAsyncSamples(self):
+        self.doTest("IMPALA", {"num_gpus": 0})
+
+    def testSyncReplay(self):
+        self.doTest("DQN", {"timesteps_per_iteration": 1})
+
+    def testMultiGPU(self):
+        self.doTest(
+            "PPO", {
+                "num_sgd_iter": 1,
+                "train_batch_size": 10,
+                "sample_batch_size": 10,
+                "sgd_minibatch_size": 1,
+            })
+
+    def testSyncSamples(self):
+        self.doTest("PG", {"optimizer": {}})
+
+    def testAsyncSamplingOption(self):
+        self.doTest("PG", {"optimizer": {}, "sample_async": True})
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
diff --git a/python/ray/rllib/utils/actors.py b/python/ray/rllib/utils/actors.py
index 44cbc98e2..7b1944eed 100644
--- a/python/ray/rllib/utils/actors.py
+++ b/python/ray/rllib/utils/actors.py
@@ -53,6 +53,18 @@ class TaskPool(object):
                 remaining.append((worker, obj_id))
         self._fetching = remaining
 
+    def reset_evaluators(self, evaluators):
+        """Notify that some evaluators may be removed."""
+        for obj_id, ev in self._tasks.copy().items():
+            if ev not in evaluators:
+                del self._tasks[obj_id]
+                del self._objects[obj_id]
+        ok = []
+        for ev, obj_id in self._fetching:
+            if ev in evaluators:
+                ok.append((ev, obj_id))
+        self._fetching = ok
+
     @property
     def count(self):
         return len(self._tasks)