diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index bc81b5025..a1fd4a8f3 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -16,7 +16,6 @@ from ray import ray_constants from ray.resource_spec import ResourceSpec from ray.tune.durable_trainable import DurableTrainable from ray.tune.error import AbortTrialExecution, TuneError -from ray.tune.function_runner import FunctionRunner from ray.tune.logger import NoopLogger from ray.tune.result import TRIAL_INFO, STDOUT_FILE, STDERR_FILE from ray.tune.resources import Resources @@ -246,19 +245,19 @@ class RayTrialExecutor(TrialExecutor): return None - def _setup_remote_runner(self, trial, reuse_allowed): + def _setup_remote_runner(self, trial): trial.init_logdir() # We checkpoint metadata here to try mitigating logdir duplication self.try_checkpoint_metadata(trial) logger_creator = partial(noop_logger_creator, logdir=trial.logdir) - if (self._reuse_actors and reuse_allowed - and self._cached_actor is not None): + if (self._reuse_actors and self._cached_actor is not None): logger.debug("Trial %s: Reusing cached runner %s", trial, self._cached_actor) existing_runner = self._cached_actor self._cached_actor = None trial.set_runner(existing_runner) + if not self.reset_trial(trial, trial.config, trial.experiment_tag, logger_creator): raise AbortTrialExecution( @@ -378,14 +377,7 @@ class RayTrialExecutor(TrialExecutor): """ prior_status = trial.status if runner is None: - # We reuse actors when there is previously instantiated state on - # the actor. Function API calls are also supported when there is - # no checkpoint to continue from. - # TODO: Check preconditions - why is previous state needed? - reuse_allowed = checkpoint is not None or trial.has_checkpoint() \ - or issubclass(trial.get_trainable_cls(), - FunctionRunner) - runner = self._setup_remote_runner(trial, reuse_allowed) + runner = self._setup_remote_runner(trial) if not runner: return False trial.set_runner(runner) @@ -520,11 +512,20 @@ class RayTrialExecutor(TrialExecutor): trial.set_experiment_tag(new_experiment_tag) trial.set_config(new_config) trainable = trial.runner + + # Pass magic variables + extra_config = copy.deepcopy(new_config) + extra_config[TRIAL_INFO] = TrialInfo(trial) + + stdout_file, stderr_file = trial.log_to_file + extra_config[STDOUT_FILE] = stdout_file + extra_config[STDERR_FILE] = stderr_file + with self._change_working_directory(trial): with warn_if_slow("reset"): try: reset_val = ray.get( - trainable.reset.remote(new_config, logger_creator), + trainable.reset.remote(extra_config, logger_creator), timeout=DEFAULT_GET_TIMEOUT) except GetTimeoutError: logger.exception("Trial %s: reset timed out.", trial) diff --git a/python/ray/tune/tests/test_actor_reuse.py b/python/ray/tune/tests/test_actor_reuse.py index 8683df423..4dfac4a4f 100644 --- a/python/ray/tune/tests/test_actor_reuse.py +++ b/python/ray/tune/tests/test_actor_reuse.py @@ -49,6 +49,7 @@ def create_resettable_class(): if "fake_reset_not_supported" in self.config: return False self.num_resets += 1 + self.iter = 0 self.msg = new_config.get("message", "No message") return True @@ -131,7 +132,7 @@ class ActorReuseTest(unittest.TestCase): self.assertEqual([t.last_result["id"] for t in trials], [0, 1, 2, 3]) self.assertEqual([t.last_result["iter"] for t in trials], [2, 2, 2, 2]) self.assertEqual([t.last_result["num_resets"] for t in trials], - [1, 2, 3, 4]) + [4, 5, 6, 7]) def testTrialReuseEnabledFunction(self): num_resets = defaultdict(lambda: 0) @@ -176,7 +177,7 @@ class ActorReuseTest(unittest.TestCase): reuse_actors=True).trials # Check trial 1 - self.assertEqual(trial1.last_result["num_resets"], 1) + self.assertEqual(trial1.last_result["num_resets"], 2) self.assertTrue(os.path.exists(os.path.join(trial1.logdir, "stdout"))) self.assertTrue(os.path.exists(os.path.join(trial1.logdir, "stderr"))) with open(os.path.join(trial1.logdir, "stdout"), "rt") as fp: @@ -191,7 +192,7 @@ class ActorReuseTest(unittest.TestCase): self.assertNotIn("LOG_STDERR: Second", content) # Check trial 2 - self.assertEqual(trial2.last_result["num_resets"], 2) + self.assertEqual(trial2.last_result["num_resets"], 3) self.assertTrue(os.path.exists(os.path.join(trial2.logdir, "stdout"))) self.assertTrue(os.path.exists(os.path.join(trial2.logdir, "stderr"))) with open(os.path.join(trial2.logdir, "stdout"), "rt") as fp: diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index 16bcd9f2a..22d70af31 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -431,6 +431,10 @@ class Trainable: reset actor behavior for the new config.""" self.config = new_config + trial_info = new_config.pop(TRIAL_INFO, None) + if trial_info: + self._trial_info = trial_info + self._result_logger.flush() self._result_logger.close()