diff --git a/doc/source/tune.rst b/doc/source/tune.rst index bcbd880e6..b38160652 100644 --- a/doc/source/tune.rst +++ b/doc/source/tune.rst @@ -199,6 +199,20 @@ Trial Checkpointing To enable checkpoint / resume, you must subclass ``Trainable`` and implement its ``_train``, ``_save``, and ``_restore`` abstract methods `(example) `__: Implementing this interface is required to support resource multiplexing in schedulers such as HyperBand and PBT. +Additionally, checkpointing can be used to provide fault-tolerance for experiments. This can be enabled by setting ``checkpoint_freq: N`` and ``max_failures: M`` to checkpoint trials every *N* iterations and recover from up to *M* crashes per trial, e.g.: + +.. code-block:: python + + run_experiments({ + "my_experiment": { + ... + "checkpoint_freq": 10, + "max_failures": 5, + }, + }) + +The class interface that must be implemented to enable checkpointing is as follows: + .. autoclass:: ray.tune.trainable.Trainable Resource Allocation diff --git a/python/ray/rllib/agent.py b/python/ray/rllib/agent.py index 47c270299..a7db45c95 100644 --- a/python/ray/rllib/agent.py +++ b/python/ray/rllib/agent.py @@ -139,12 +139,19 @@ class _MockAgent(Agent): """Mock agent for use in tests""" _agent_name = "MockAgent" - _default_config = {} + _default_config = { + "mock_error": False, + "persistent_error": False, + } def _init(self): self.info = None + self.restored = False def _train(self): + if self.config["mock_error"] and self.iteration == 1 \ + and (self.config["persistent_error"] or not self.restored): + raise Exception("mock error") return TrainingResult( episode_reward_mean=10, episode_len_mean=10, timesteps_this_iter=10, info={}) @@ -159,6 +166,7 @@ class _MockAgent(Agent): with open(checkpoint_path, 'rb') as f: info = pickle.load(f) self.info = info + self.restored = True def set_info(self, info): self.info = info diff --git a/python/ray/rllib/dqn/dqn_evaluator.py b/python/ray/rllib/dqn/dqn_evaluator.py index 5ed9befdd..20a269cbf 
100644 --- a/python/ray/rllib/dqn/dqn_evaluator.py +++ b/python/ray/rllib/dqn/dqn_evaluator.py @@ -174,7 +174,9 @@ class DQNEvaluator(TFMultiGPUSupport): self.episode_rewards, self.episode_lengths, self.saved_mean_reward, - self.obs] + self.obs, + self.global_timestep, + self.local_timestep] def restore(self, data): self.exploration = data[0] @@ -182,3 +184,5 @@ class DQNEvaluator(TFMultiGPUSupport): self.episode_lengths = data[2] self.saved_mean_reward = data[3] self.obs = data[4] + self.global_timestep = data[5] + self.local_timestep = data[6] diff --git a/python/ray/tune/config_parser.py b/python/ray/tune/config_parser.py index 942602540..f8f086527 100644 --- a/python/ray/tune/config_parser.py +++ b/python/ray/tune/config_parser.py @@ -74,6 +74,10 @@ def make_parser(**kwargs): "--checkpoint-freq", default=0, type=int, help="How many training iterations between checkpoints. " "A value of 0 (default) disables checkpointing.") + parser.add_argument( + "--max-failures", default=3, type=int, + help="Try to recover a trial from its last checkpoint at most this " + "many times. 
Only applies if checkpointing is enabled.") parser.add_argument( "--scheduler", default="FIFO", type=str, help="FIFO (default), MedianStopping, or HyperBand.") diff --git a/python/ray/tune/test/trial_runner_test.py b/python/ray/tune/test/trial_runner_test.py index 60d707639..5b97b8409 100644 --- a/python/ray/tune/test/trial_runner_test.py +++ b/python/ray/tune/test/trial_runner_test.py @@ -208,6 +208,7 @@ class VariantGeneratorTest(unittest.TestCase): trials = generate_trials({ "run": "PPO", "repeat": 2, + "max_failures": 5, "config": { "env": "Pong-v0", "foo": "bar" @@ -219,6 +220,7 @@ class VariantGeneratorTest(unittest.TestCase): self.assertEqual(trials[0].config, {"foo": "bar", "env": "Pong-v0"}) self.assertEqual(trials[0].trainable_name, "PPO") self.assertEqual(trials[0].experiment_tag, "0") + self.assertEqual(trials[0].max_failures, 5) self.assertEqual( trials[0].local_dir, os.path.join(DEFAULT_RESULTS_DIR, "tune-pong")) @@ -457,6 +459,81 @@ class TrialRunnerTest(unittest.TestCase): self.assertEqual(trials[0].status, Trial.ERROR) self.assertEqual(trials[1].status, Trial.RUNNING) + def testFailureRecoveryDisabled(self): + ray.init(num_cpus=1, num_gpus=1) + runner = TrialRunner() + kwargs = { + "resources": Resources(cpu=1, gpu=1), + "checkpoint_freq": 1, + "max_failures": 0, + "config": { + "mock_error": True, + }, + } + runner.add_trial(Trial("__fake", **kwargs)) + trials = runner.get_trials() + + runner.step() + self.assertEqual(trials[0].status, Trial.RUNNING) + runner.step() + self.assertEqual(trials[0].status, Trial.RUNNING) + runner.step() + self.assertEqual(trials[0].status, Trial.ERROR) + self.assertEqual(trials[0].num_failures, 1) + + def testFailureRecoveryEnabled(self): + ray.init(num_cpus=1, num_gpus=1) + runner = TrialRunner() + kwargs = { + "resources": Resources(cpu=1, gpu=1), + "checkpoint_freq": 1, + "max_failures": 1, + "config": { + "mock_error": True, + }, + } + runner.add_trial(Trial("__fake", **kwargs)) + trials = runner.get_trials() + + 
runner.step() + self.assertEqual(trials[0].status, Trial.RUNNING) + runner.step() + self.assertEqual(trials[0].status, Trial.RUNNING) + runner.step() + self.assertEqual(trials[0].status, Trial.RUNNING) + self.assertEqual(trials[0].num_failures, 1) + runner.step() + self.assertEqual(trials[0].status, Trial.RUNNING) + + def testFailureRecoveryMaxFailures(self): + ray.init(num_cpus=1, num_gpus=1) + runner = TrialRunner() + kwargs = { + "resources": Resources(cpu=1, gpu=1), + "checkpoint_freq": 1, + "max_failures": 2, + "config": { + "mock_error": True, + "persistent_error": True, + }, + } + runner.add_trial(Trial("__fake", **kwargs)) + trials = runner.get_trials() + + runner.step() + self.assertEqual(trials[0].status, Trial.RUNNING) + runner.step() + self.assertEqual(trials[0].status, Trial.RUNNING) + runner.step() + self.assertEqual(trials[0].status, Trial.RUNNING) + self.assertEqual(trials[0].num_failures, 1) + runner.step() + self.assertEqual(trials[0].status, Trial.RUNNING) + self.assertEqual(trials[0].num_failures, 2) + runner.step() + self.assertEqual(trials[0].status, Trial.ERROR) + self.assertEqual(trials[0].num_failures, 3) + def testCheckpointing(self): ray.init(num_cpus=1, num_gpus=1) runner = TrialRunner() diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 9acf66493..f6166eaaf 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -78,7 +78,7 @@ class Trial(object): self, trainable_name, config=None, local_dir=DEFAULT_RESULTS_DIR, experiment_tag=None, resources=Resources(cpu=1, gpu=0), stopping_criterion=None, checkpoint_freq=0, - restore_path=None, upload_dir=None): + restore_path=None, upload_dir=None, max_failures=0): """Initialize a new trial. 
The args here take the same meaning as the command line flags defined @@ -106,6 +106,7 @@ class Trial(object): self.checkpoint_freq = checkpoint_freq self.upload_dir = upload_dir self.verbose = True + self.max_failures = max_failures # Local trial state that is updated during the run self.last_result = None @@ -119,6 +120,7 @@ class Trial(object): self.last_debug = 0 self.trial_id = binary_to_hex(random_string())[:8] self.error_file = None + self.num_failures = 0 def start(self, checkpoint_obj=None): """Starts this trial. @@ -158,6 +160,7 @@ class Trial(object): try: if error_msg and self.logdir: + self.num_failures += 1 error_file = os.path.join( self.logdir, "error_{}.txt".format(date_str())) with open(error_file, "w") as f: @@ -268,7 +271,12 @@ class Trial(object): def _status_string(self): return "{}{}".format( self.status, - " => {}".format(self.error_file) if self.error_file else "") + ", {} failures: {}".format(self.num_failures, self.error_file) + if self.error_file else "") + + def has_checkpoint(self): + return self._checkpoint_path is not None or \ + self._checkpoint_obj is not None def checkpoint(self, to_object_store=False): """Checkpoints the state of this trial. 
diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index c1ffc135e..934562170 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -241,8 +241,23 @@ class TrialRunner(object): error_msg = traceback.format_exc() print("Error processing event:", error_msg) if trial.status == Trial.RUNNING: - self._scheduler_alg.on_trial_error(self, trial) - self._stop_trial(trial, error=True, error_msg=error_msg) + if trial.has_checkpoint() and \ + trial.num_failures < trial.max_failures: + self._try_recover(trial, error_msg) + else: + self._scheduler_alg.on_trial_error(self, trial) + self._stop_trial(trial, error=True, error_msg=error_msg) + + def _try_recover(self, trial, error_msg): + try: + print("Attempting to recover trial state from last checkpoint") + trial.stop(error=True, error_msg=error_msg, stop_logger=False) + trial.start() + self._running[trial.train_remote()] = trial + except Exception: + error_msg = traceback.format_exc() + print("Error recovering trial from checkpoint, abort:", error_msg) + self._stop_trial(trial, error=True, error_msg=error_msg) def _get_runnable(self): return self._scheduler_alg.choose_trial_to_run(self) diff --git a/python/ray/tune/variant_generator.py b/python/ray/tune/variant_generator.py index 7ea8b334c..4224073c3 100644 --- a/python/ray/tune/variant_generator.py +++ b/python/ray/tune/variant_generator.py @@ -62,7 +62,8 @@ def generate_trials(unresolved_spec, output_path=''): stopping_criterion=spec.get("stop", {}), checkpoint_freq=args.checkpoint_freq, restore_path=spec.get("restore"), - upload_dir=args.upload_dir) + upload_dir=args.upload_dir, + max_failures=args.max_failures) def generate_variants(unresolved_spec):