diff --git a/python/ray/function_manager.py b/python/ray/function_manager.py index c97d851ea..ee58948d1 100644 --- a/python/ray/function_manager.py +++ b/python/ray/function_manager.py @@ -267,8 +267,7 @@ class FunctionActorManager: )) self._num_task_executions[job_id][function_id] = 0 except Exception: - logger.exception( - "Failed to load function {}.".format(function_name)) + logger.exception("Failed to load function %s.", function_name) raise Exception( "Function {} failed to be loaded from local code.".format( function_descriptor)) @@ -428,8 +427,7 @@ class FunctionActorManager: else: return actor_class except Exception: - logger.exception( - "Failed to load actor_class %s.".format(class_name)) + logger.exception("Failed to load actor_class %s.", class_name) raise Exception( "Actor {} failed to be imported from local code.".format( class_name)) @@ -475,8 +473,7 @@ class FunctionActorManager: with self.lock: actor_class = pickle.loads(pickled_class) except Exception: - logger.exception( - "Failed to load actor class %s.".format(class_name)) + logger.exception("Failed to load actor class %s.", class_name) # The actor class failed to be unpickled, create a fake actor # class instead (just to produce error messages and to prevent # the driver from hanging). diff --git a/python/ray/tune/checkpoint_manager.py b/python/ray/tune/checkpoint_manager.py index 27318355c..c6d10460c 100644 --- a/python/ray/tune/checkpoint_manager.py +++ b/python/ray/tune/checkpoint_manager.py @@ -13,7 +13,8 @@ class Checkpoint: Attributes: storage (str): Storage type. value (str): If storage==MEMORY, it is a Python object. - If storage==PERSISTENT, it is a path to persistent storage. + If storage==PERSISTENT, it is a path to persistent storage, + or a future that will be resolved to such a path. 
""" MEMORY = "memory" @@ -29,6 +30,18 @@ class Checkpoint: """Creates a checkpoint from a Python object.""" return Checkpoint(Checkpoint.MEMORY, value) + @property + def is_ready(self): + """Returns whether the checkpoint is ready to be used for restoration. + + A PERSISTENT checkpoint is considered ready once its value is resolved + to an actual path. MEMORY checkpoints are always considered ready since + they are transient. + """ + if self.storage == Checkpoint.PERSISTENT: + return isinstance(self.value, str) + return self.storage == Checkpoint.MEMORY + class QueueItem: def __init__(self, priority, value): diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index 48b1bfc3e..df98ec980 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -554,7 +554,7 @@ class RayTrialExecutor(TrialExecutor): self._update_avail_resources() def save(self, trial, storage=Checkpoint.PERSISTENT, result=None): - """Saves the trial's state to a checkpoint. + """Saves the trial's state to a checkpoint asynchronously. Args: trial (Trial): The trial to be saved. @@ -567,29 +567,16 @@ class RayTrialExecutor(TrialExecutor): Checkpoint object, or None if an Exception occurs. """ result = result or trial.last_result - with self._change_working_directory(trial): if storage == Checkpoint.MEMORY: value = trial.runner.save_to_object.remote() checkpoint = Checkpoint(storage, value, result) - else: - with warn_if_slow("save_checkpoint_to_storage"): - # TODO(ujvl): Make this asynchronous. 
- value = ray.get(trial.runner.save.remote()) - checkpoint = Checkpoint(storage, value, result) - with warn_if_slow("on_checkpoint", DEFAULT_GET_TIMEOUT) as profile: - try: trial.on_checkpoint(checkpoint) - except Exception: - logger.exception("Trial %s: Error handling checkpoint %s", - trial, checkpoint.value) - return None - if profile.too_slow and trial.sync_on_checkpoint: - logger.warning( - "Consider turning off forced head-worker trial checkpoint " - "syncs by setting sync_on_checkpoint=False. Note that this " - "might result in faulty trial restoration for some worker " - "failure modes.") + else: + value = trial.runner.save.remote() + checkpoint = Checkpoint(storage, value, result) + trial.saving_to = checkpoint + self._running[value] = trial return checkpoint def restore(self, trial, checkpoint=None): diff --git a/python/ray/tune/tests/test_cluster.py b/python/ray/tune/tests/test_cluster.py index a4d49e3e2..1753d23de 100644 --- a/python/ray/tune/tests/test_cluster.py +++ b/python/ray/tune/tests/test_cluster.py @@ -146,14 +146,15 @@ def test_remove_node_before_result(start_connected_emptyhead_cluster): trial = Trial("__fake", **kwargs) runner.add_trial(trial) - runner.step() # run 1 + runner.step() # Start trial assert trial.status == Trial.RUNNING cluster.remove_node(node) cluster.add_node(num_cpus=1) cluster.wait_for_nodes() assert ray.cluster_resources()["CPU"] == 1 - for i in range(3): + # Process result (x2), process save, process result. 
+ for _ in range(4): runner.step() assert trial.status == Trial.TERMINATED @@ -237,39 +238,45 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id): # Test recovery of trial that hasn't been checkpointed t = Trial(trainable_id, **kwargs) runner.add_trial(t) - runner.step() # start - runner.step() # 1 result + runner.step() # Start trial + runner.step() # Process result assert t.last_result node2 = cluster.add_node(num_cpus=1) cluster.remove_node(node) cluster.wait_for_nodes() + # TODO(ujvl): Node failure does not propagate until a step after it + # actually should. This is possibly a problem with `Cluster`. + runner.step() runner.step() # Recovery step # TODO(rliaw): This assertion is not critical but will not pass # because checkpoint handling is messy and should be refactored # rather than hotfixed. # assert t.last_result is None, "Trial result not restored correctly." - for i in range(4): + + # Process result (x2), process save, process result (x2), process save + for _ in range(6): runner.step() - assert t.status == Trial.TERMINATED + assert t.status == Trial.TERMINATED, runner.debug_string() # Test recovery of trial that has been checkpointed t2 = Trial(trainable_id, **kwargs) runner.add_trial(t2) - runner.step() # start - runner.step() # 1 result - runner.step() # 2 result and checkpoint + # Start trial, process result (x2), process save + for _ in range(4): + runner.step() assert t2.has_checkpoint() node3 = cluster.add_node(num_cpus=1) cluster.remove_node(node2) cluster.wait_for_nodes() - runner.step() # 3 result + start and fail 4 result - runner.step() # Recovery step - runner.step() # Process recovery - runner.step() # result + runner.step() # Process result 3 + start and fail 4 result + runner.step() # Dispatch restore + runner.step() # Process restore + runner.step() # Process result 5 if t2.status != Trial.TERMINATED: - runner.step() + runner.step() # Process result 6, dispatch save + runner.step() # Process save assert t2.status 
== Trial.TERMINATED, runner.debug_string() # Test recovery of trial that won't be checkpointed @@ -282,8 +289,8 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id): } t3 = Trial(trainable_id, **kwargs) runner.add_trial(t3) - runner.step() # start - runner.step() # 1 result + runner.step() # Start trial + runner.step() # Process result 1 cluster.add_node(num_cpus=1) cluster.remove_node(node3) cluster.wait_for_nodes() @@ -318,13 +325,16 @@ def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id): for t in trials: runner.add_trial(t) - runner.step() # start - runner.step() # 1 result + runner.step() # Start trial + runner.step() # Process result, dispatch save + runner.step() # Process save cluster.remove_node(node) cluster.wait_for_nodes() - runner.step() - assert all(t.status == Trial.PENDING for t in trials) + runner.step() # Process result, dispatch save + runner.step() # Process save (detect error), requeue trial + assert all( + t.status == Trial.PENDING for t in trials), runner.debug_string() with pytest.raises(TuneError): runner.step() @@ -374,19 +384,21 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster, # Test recovery of trial that has been checkpointed t1 = Trial(trainable_id, **kwargs) runner.add_trial(t1) - runner.step() # start - runner.step() # 1 result - runner.step() # 2 result and checkpoint + + # Start trial, process result (x2), process save + for _ in range(4): + runner.step() assert t1.has_checkpoint() + cluster.add_node(num_cpus=1) cluster.remove_node(node) cluster.wait_for_nodes() shutil.rmtree(os.path.dirname(t1.checkpoint.value)) - runner.step() # collect result 3, kick off + fail result 4 - runner.step() # Recovery step - runner.step() # Process Recovery + step 4 - for i in range(3): + runner.step() # Collect result 3, kick off + fail result 4 + runner.step() # Dispatch restore + runner.step() # Process restore + step 4 + for _ in range(3): if t1.status != Trial.TERMINATED: 
runner.step() assert t1.status == Trial.TERMINATED, runner.debug_string() @@ -414,9 +426,9 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir, trainable_id): for t in trials: runner.add_trial(t) - runner.step() # start - runner.step() # start2 - runner.step() # step + # Start trial (x2), process result, process save + for _ in range(4): + runner.step() assert all(t.status == Trial.RUNNING for t in runner.get_trials()) runner.checkpoint() @@ -425,11 +437,12 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir, trainable_id): cluster = _start_new_cluster() runner = TrialRunner(resume="LOCAL", local_checkpoint_dir=dirpath) - runner.step() # start - runner.step() # process restore - runner.step() # start2 + # Start trial, process restore, process result, process save + for _ in range(4): + runner.step() - for i in range(3): + # Start trial 2, process result, process save, process result, process save + for i in range(5): runner.step() with pytest.raises(TuneError): diff --git a/python/ray/tune/tests/test_ray_trial_executor.py b/python/ray/tune/tests/test_ray_trial_executor.py index 6c9fd12d9..8c1db221b 100644 --- a/python/ray/tune/tests/test_ray_trial_executor.py +++ b/python/ray/tune/tests/test_ray_trial_executor.py @@ -30,11 +30,25 @@ class RayTrialExecutorTest(unittest.TestCase): self.assertEqual(1, len(running)) self.trial_executor.stop_trial(trial) + def testAsyncSave(self): + """Tests that saved checkpoint value not immediately set.""" + trial = Trial("__fake") + self.trial_executor.start_trial(trial) + self.assertEqual(Trial.RUNNING, trial.status) + checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT) + self.assertEqual(checkpoint, trial.saving_to) + self.assertEqual(trial.checkpoint.value, None) + self.process_trial_save(trial) + self.assertEqual(checkpoint, trial.checkpoint) + self.trial_executor.stop_trial(trial) + self.assertEqual(Trial.TERMINATED, trial.status) + def testSaveRestore(self): trial = Trial("__fake") 
self.trial_executor.start_trial(trial) self.assertEqual(Trial.RUNNING, trial.status) self.trial_executor.save(trial, Checkpoint.PERSISTENT) + self.process_trial_save(trial) self.trial_executor.restore(trial) self.trial_executor.stop_trial(trial) self.assertEqual(Trial.TERMINATED, trial.status) @@ -59,6 +73,8 @@ class RayTrialExecutorTest(unittest.TestCase): checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT) self.assertEqual(Trial.RUNNING, trial.status) self.assertEqual(checkpoint.storage, Checkpoint.PERSISTENT) + # Process save result (simulates trial runner) + self.process_trial_save(trial) # Pause self.trial_executor.pause_trial(trial) self.assertEqual(Trial.PAUSED, trial.status) @@ -125,11 +141,20 @@ class RayTrialExecutorTest(unittest.TestCase): self.assertEqual(trial.experiment_tag, "modified_mock") self.assertEqual(Trial.RUNNING, trial.status) - def generate_trials(self, spec, name): + @staticmethod + def generate_trials(spec, name): suggester = BasicVariantGenerator() suggester.add_configurations({name: spec}) return suggester.next_trials() + @staticmethod + def process_trial_save(trial): + """Simulates trial runner save.""" + checkpoint = trial.saving_to + checkpoint_value = ray.get(checkpoint.value) + checkpoint.value = checkpoint_value + trial.on_checkpoint(checkpoint) + class RayExecutorQueueTest(unittest.TestCase): def setUp(self): diff --git a/python/ray/tune/tests/test_trial_runner_2.py b/python/ray/tune/tests/test_trial_runner_2.py index 7e1d76abf..e02b20eea 100644 --- a/python/ray/tune/tests/test_trial_runner_2.py +++ b/python/ray/tune/tests/test_trial_runner_2.py @@ -84,11 +84,12 @@ class TrialRunnerTest2(unittest.TestCase): runner.add_trial(Trial("__fake", **kwargs)) trials = runner.get_trials() - runner.step() + runner.step() # Start trial self.assertEqual(trials[0].status, Trial.RUNNING) - runner.step() + runner.step() # Process result, dispatch save self.assertEqual(trials[0].status, Trial.RUNNING) - runner.step() + 
runner.step() # Process save + runner.step() # Error self.assertEqual(trials[0].status, Trial.ERROR) self.assertEqual(trials[0].num_failures, 1) self.assertEqual(len(searchalg.errored_trials), 1) @@ -111,14 +112,15 @@ class TrialRunnerTest2(unittest.TestCase): runner.add_trial(Trial("__fake", **kwargs)) trials = runner.get_trials() - runner.step() + runner.step() # Start trial self.assertEqual(trials[0].status, Trial.RUNNING) - runner.step() + runner.step() # Process result, dispatch save self.assertEqual(trials[0].status, Trial.RUNNING) - runner.step() + runner.step() # Process save + runner.step() # Error (transient), dispatch restore self.assertEqual(trials[0].status, Trial.RUNNING) self.assertEqual(trials[0].num_failures, 1) - runner.step() + runner.step() # Process restore self.assertEqual(trials[0].status, Trial.RUNNING) self.assertEqual(len(searchalg.errored_trials), 0) self.assertEqual(len(scheduler.errored_trials), 0) @@ -142,15 +144,16 @@ class TrialRunnerTest2(unittest.TestCase): with patch("ray.cluster_resources") as resource_mock: resource_mock.return_value = {"CPU": 1, "GPU": 1} - runner.step() + runner.step() # Start trial self.assertEqual(trials[0].status, Trial.RUNNING) - runner.step() + runner.step() # Process result, dispatch save + runner.step() # Process save self.assertEqual(trials[0].status, Trial.RUNNING) # Mimic a node failure resource_mock.return_value = {"CPU": 0, "GPU": 0} - runner.step() + runner.step() # Detect node failure self.assertEqual(trials[0].status, Trial.PENDING) self.assertEqual(trials[0].num_failures, 1) self.assertEqual(len(searchalg.errored_trials), 0) @@ -171,19 +174,20 @@ class TrialRunnerTest2(unittest.TestCase): runner.add_trial(Trial("__fake", **kwargs)) trials = runner.get_trials() - runner.step() + runner.step() # Start trial self.assertEqual(trials[0].status, Trial.RUNNING) - runner.step() + runner.step() # Process result, dispatch save self.assertEqual(trials[0].status, Trial.RUNNING) - runner.step() + 
runner.step() # Process save + runner.step() # Error (transient), dispatch restore self.assertEqual(trials[0].status, Trial.RUNNING) self.assertEqual(trials[0].num_failures, 1) - runner.step() # Restore step - runner.step() + runner.step() # Process restore + runner.step() # Error (transient), dispatch restore self.assertEqual(trials[0].status, Trial.RUNNING) self.assertEqual(trials[0].num_failures, 2) - runner.step() # Restore step - runner.step() + runner.step() # Process restore + runner.step() # Error (terminal) self.assertEqual(trials[0].status, Trial.ERROR) self.assertEqual(trials[0].num_failures, 3) @@ -195,61 +199,69 @@ class TrialRunnerTest2(unittest.TestCase): "training_iteration": 1 }, "resources": Resources(cpu=1, gpu=1), + "checkpoint_freq": 1, } runner.add_trial(Trial("__fake", **kwargs)) trials = runner.get_trials() - runner.step() + runner.step() # Start trial self.assertEqual(trials[0].status, Trial.RUNNING) self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1) - checkpoint = runner.trial_executor.save(trials[0]) - kwargs["restore_path"] = checkpoint.value + runner.step() # Process result, dispatch save + runner.step() # Process save, stop trial + kwargs["restore_path"] = trials[0].checkpoint.value + self.assertEqual(trials[0].status, Trial.TERMINATED) runner.add_trial(Trial("__fake", **kwargs)) trials = runner.get_trials() - runner.step() - self.assertEqual(trials[0].status, Trial.TERMINATED) self.assertEqual(trials[1].status, Trial.PENDING) - runner.step() + runner.step() # Start trial, dispatch restore + self.assertEqual(trials[1].status, Trial.RUNNING) + + runner.step() # Process restore self.assertEqual(trials[0].status, Trial.TERMINATED) self.assertEqual(trials[1].status, Trial.RUNNING) self.assertEqual(ray.get(trials[1].runner.get_info.remote()), 1) - self.addCleanup(os.remove, checkpoint.value) + self.addCleanup(os.remove, trials[0].checkpoint.value) def testRestoreMetricsAfterCheckpointing(self): ray.init(num_cpus=1, 
num_gpus=1) runner = TrialRunner() kwargs = { "resources": Resources(cpu=1, gpu=1), + "checkpoint_freq": 1, } runner.add_trial(Trial("__fake", **kwargs)) trials = runner.get_trials() - runner.step() + runner.step() # Start trial self.assertEqual(trials[0].status, Trial.RUNNING) self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1) - checkpoint = runner.trial_executor.save(trials[0]) + # checkpoint = runner.trial_executor.save(trials[0]) + runner.step() # Process result, dispatch save + runner.step() # Process save runner.trial_executor.stop_trial(trials[0]) - kwargs["restore_path"] = checkpoint.value + kwargs["restore_path"] = trials[0].checkpoint.value + kwargs.pop("checkpoint_freq") # No checkpointing for next trial runner.add_trial(Trial("__fake", **kwargs)) trials = runner.get_trials() - runner.step() + runner.step() # Start trial, dispatch restore self.assertEqual(trials[0].status, Trial.TERMINATED) self.assertEqual(trials[1].status, Trial.RUNNING) - runner.step() # Restore step - runner.step() + runner.step() # Process restore + runner.step() # Process result self.assertEqual(trials[1].last_result["timesteps_since_restore"], 10) self.assertEqual(trials[1].last_result["iterations_since_restore"], 1) self.assertGreater(trials[1].last_result["time_since_restore"], 0) - runner.step() + runner.step() # Process result self.assertEqual(trials[1].last_result["timesteps_since_restore"], 20) self.assertEqual(trials[1].last_result["iterations_since_restore"], 2) self.assertGreater(trials[1].last_result["time_since_restore"], 0) - self.addCleanup(os.remove, checkpoint.value) + self.addCleanup(os.remove, trials[0].checkpoint.value) def testCheckpointingAtEnd(self): ray.init(num_cpus=1, num_gpus=1) @@ -264,11 +276,12 @@ class TrialRunnerTest2(unittest.TestCase): runner.add_trial(Trial("__fake", **kwargs)) trials = runner.get_trials() - runner.step() + runner.step() # Start trial self.assertEqual(trials[0].status, Trial.RUNNING) - runner.step() - runner.step() 
+ runner.step() # Process result + runner.step() # Process result, dispatch save self.assertEqual(trials[0].last_result[DONE], True) + runner.step() # Process save self.assertEqual(trials[0].has_checkpoint(), True) def testResultDone(self): diff --git a/python/ray/tune/tests/test_trial_runner_3.py b/python/ray/tune/tests/test_trial_runner_3.py index e46706251..1732b5886 100644 --- a/python/ray/tune/tests/test_trial_runner_3.py +++ b/python/ray/tune/tests/test_trial_runner_3.py @@ -297,8 +297,9 @@ class TrialRunnerTest3(unittest.TestCase): checkpoint_freq=1) ] runner.add_trial(trials[0]) - runner.step() # start - runner.step() + runner.step() # Start trial + runner.step() # Process result, dispatch save + runner.step() # Process save self.assertEquals(trials[0].status, Trial.TERMINATED) trials += [ @@ -310,9 +311,10 @@ class TrialRunnerTest3(unittest.TestCase): config={"mock_error": True}) ] runner.add_trial(trials[1]) - runner.step() - runner.step() - runner.step() + runner.step() # Start trial + runner.step() # Process result, dispatch save + runner.step() # Process save + runner.step() # Error self.assertEquals(trials[1].status, Trial.ERROR) trials += [ @@ -323,7 +325,7 @@ class TrialRunnerTest3(unittest.TestCase): checkpoint_freq=1) ] runner.add_trial(trials[2]) - runner.step() + runner.step() # Start trial self.assertEquals(len(runner.trial_executor.get_checkpoints()), 3) self.assertEquals(trials[2].status, Trial.RUNNING) @@ -336,9 +338,11 @@ class TrialRunnerTest3(unittest.TestCase): restored_trial = runner2.get_trial("trial_succ") self.assertEqual(Trial.PENDING, restored_trial.status) - runner2.step() - runner2.step() - runner2.step() + runner2.step() # Start trial + runner2.step() # Process result, dispatch save + runner2.step() # Process save + runner2.step() # Process result, dispatch save + runner2.step() # Process save self.assertRaises(TuneError, runner2.step) shutil.rmtree(tmpdir) @@ -444,18 +448,19 @@ class TrialRunnerTest3(unittest.TestCase): 
runner.add_trial(Trial("__fake", config={"user_checkpoint_freq": 2})) trials = runner.get_trials() - runner.step() + runner.step() # Start trial self.assertEqual(trials[0].status, Trial.RUNNING) self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1) - runner.step() # 0 + runner.step() # Process result self.assertFalse(trials[0].has_checkpoint()) - runner.step() # 1 + runner.step() # Process result self.assertFalse(trials[0].has_checkpoint()) - runner.step() # 2 + runner.step() # Process result, dispatch save + runner.step() # Process save self.assertTrue(trials[0].has_checkpoint()) runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=tmpdir) - runner2.step() + runner2.step() # 5: Start trial and dispatch restore trials2 = runner2.get_trials() self.assertEqual(ray.get(trials2[0].runner.get_info.remote()), 1) shutil.rmtree(tmpdir) diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index f02d082d3..98bfaaf55 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -194,6 +194,7 @@ class Trial: self.custom_trial_name = None # Checkpointing fields + self.saving_to = None if remote_checkpoint_dir: self.remote_checkpoint_dir_prefix = remote_checkpoint_dir else: @@ -210,7 +211,6 @@ class Trial: # Restoration fields self.restoring_from = None self.num_failures = 0 - self.num_consecutive_start_attempts = 0 # AutoML fields self.results = None @@ -460,6 +460,10 @@ class Trial: def is_restoring(self): return self.restoring_from is not None + @property + def is_saving(self): + return self.saving_to is not None + def __repr__(self): return str(self) @@ -497,6 +501,9 @@ class Trial: state["runner"] = None state["result_logger"] = None + # Avoid waiting for events that will never occur on resume. 
+ state["restoring_from"] = None + state["saving_to"] = None if self.result_logger: self.result_logger.flush(sync_down=False) state["__logger_started__"] = True diff --git a/python/ray/tune/trial_executor.py b/python/ray/tune/trial_executor.py index c4393160e..6cff1ec24 100644 --- a/python/ray/tune/trial_executor.py +++ b/python/ray/tune/trial_executor.py @@ -45,7 +45,7 @@ class TrialExecutor: self.try_checkpoint_metadata(trial) def try_checkpoint_metadata(self, trial): - """Checkpoints metadata. + """Checkpoints trial metadata. Args: trial (Trial): Trial to checkpoint. diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 8e2e3421f..81b25147b 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -124,6 +124,8 @@ class TrialRunner: server_port (int): Port number for launching TuneServer. verbose (bool): Flag for verbosity. If False, trial results will not be output. + checkpoint_period (int): Trial runner checkpoint periodicity in + seconds. Defaults to 10. trial_executor (TrialExecutor): Defaults to RayTrialExecutor. """ self._search_alg = search_alg or BasicVariantGenerator() @@ -144,6 +146,7 @@ class TrialRunner: self._server = TuneServer(self, self._server_port) self._trials = [] + self._cached_trial_decisions = {} self._stop_queue = [] self._local_checkpoint_dir = local_checkpoint_dir @@ -281,7 +284,6 @@ class TrialRunner: Requires user to manually re-register their objects. Also stops all ongoing trials. 
""" - newest_ckpt_path = _find_newest_ckpt(self._local_checkpoint_dir) with open(newest_ckpt_path, "r") as f: runner_state = json.load(f, cls=_TuneFunctionDecoder) @@ -307,7 +309,6 @@ class TrialRunner: def is_finished(self): """Returns whether all trials have finished running.""" - if self._total_time > self._global_time_limit: logger.warning("Exceeded global time limit {} / {}".format( self._total_time, self._global_time_limit)) @@ -362,7 +363,6 @@ class TrialRunner: Note that the caller usually should not mutate trial state directly. """ - return self._trials def add_trial(self, trial): @@ -427,12 +427,34 @@ class TrialRunner: if trial.is_restoring: with warn_if_slow("process_trial_restore"): self._process_trial_restore(trial) + elif trial.is_saving: + with warn_if_slow("process_trial_save") as profile: + self._process_trial_save(trial) + if profile.too_slow and trial.sync_on_checkpoint: + # TODO(ujvl): Suggest using DurableTrainable once + # API has converged. + logger.warning( + "Consider turning off forced head-worker trial " + "checkpoint syncs by setting sync_on_checkpoint=False" + ". Note that this may result in faulty trial " + "restoration if a failure occurs while the checkpoint " + "is being synced from the worker to the head node.") else: with warn_if_slow("process_trial"): self._process_trial(trial) def _process_trial(self, trial): - """Processes a trial result.""" + """Processes a trial result. + + Fetches the trial's latest result and makes a scheduling decision + regarding its next action. If a checkpoint is taken, the decided + action is cached and acted on only after the checkpoint is later + processed (see `_process_trial_save`). Otherwise the decision is + acted on immediately. + + Args: + trial (Trial): Trial with a result ready to be processed. 
+ """ try: result = self.trial_executor.fetch_result(trial) @@ -480,25 +502,53 @@ class TrialRunner: self._checkpoint_trial_if_needed( trial, force=result.get(SHOULD_CHECKPOINT, False)) - if decision == TrialScheduler.CONTINUE: - self.trial_executor.continue_training(trial) - elif decision == TrialScheduler.PAUSE: - self.trial_executor.pause_trial(trial) - elif decision == TrialScheduler.STOP: - self.trial_executor.export_trial_if_needed(trial) - self.trial_executor.stop_trial(trial) + if trial.is_saving: + # Cache decision to execute on after the save is processed. + # This prevents changing the trial's state or kicking off + # another training step prematurely. + self._cached_trial_decisions[trial.trial_id] = decision else: - assert False, "Invalid scheduling decision: {}".format( - decision) + self._execute_action(trial, decision) except Exception: logger.exception("Trial %s: Error processing event.", trial) self._process_trial_failure(trial, traceback.format_exc()) + def _process_trial_save(self, trial): + """Processes a trial save. + + Acts on the decision cached during the last `_process_trial` call. + + Args: + trial (Trial): Trial being saved. + """ + logger.debug("Trial %s: Processing trial save.", trial) + checkpoint_value = None + + try: + checkpoint_value = self.trial_executor.fetch_result(trial) + except Exception: + logger.exception("Trial %s: Error processing result.", trial) + self._process_trial_failure(trial, traceback.format_exc()) + + if checkpoint_value: + try: + trial.saving_to.value = checkpoint_value + trial.on_checkpoint(trial.saving_to) + self.trial_executor.try_checkpoint_metadata(trial) + except Exception: + logger.exception("Trial %s: Error handling checkpoint %s", + trial, checkpoint_value) + + trial.saving_to = None + decision = self._cached_trial_decisions.pop(trial.trial_id, None) + if decision and checkpoint_value: + self._execute_action(trial, decision) + def _process_trial_restore(self, trial): """Processes a trial restore. 
Args: - trial: Trial being restored. + trial (Trial): Trial being restored. """ logger.debug("Trial %s: Processing trial restore.", trial) try: @@ -529,13 +579,29 @@ class TrialRunner: self.trial_executor.stop_trial( trial, error=True, error_msg=error_msg) + def _execute_action(self, trial, decision): + """Executes action based on decision. + + Args: + trial (Trial): Trial to act on. + decision (str): Scheduling decision to undertake. + """ + if decision == TrialScheduler.CONTINUE: + self.trial_executor.continue_training(trial) + elif decision == TrialScheduler.PAUSE: + self.trial_executor.pause_trial(trial) + elif decision == TrialScheduler.STOP: + self.trial_executor.export_trial_if_needed(trial) + self.trial_executor.stop_trial(trial) + else: + raise ValueError("Invalid decision: {}".format(decision)) + def _checkpoint_trial_if_needed(self, trial, force=False): """Checkpoints trial based off trial.last_result.""" if trial.should_checkpoint() or force: - # Save trial runtime if possible + # Save trial runtime if possible. if trial.runner: self.trial_executor.save(trial, storage=Checkpoint.PERSISTENT) - self.trial_executor.try_checkpoint_metadata(trial) def _try_recover(self, trial, error_msg): """Tries to recover trial.