From 574f0b73bc314daf4caf8f81b1c6f0ddd6f4057a Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Thu, 10 Jan 2019 19:26:10 -0800 Subject: [PATCH] [tune] Fix Trial Serialization (#3743) --- python/ray/tune/test/trial_runner_test.py | 24 +++++++++++++++++++++++ python/ray/tune/trial.py | 6 ++++-- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/python/ray/tune/test/trial_runner_test.py b/python/ray/tune/test/trial_runner_test.py index ab63f1286..4e8357989 100644 --- a/python/ray/tune/test/trial_runner_test.py +++ b/python/ray/tune/test/trial_runner_test.py @@ -1772,6 +1772,30 @@ class TrialRunnerTest(unittest.TestCase): runner2.step() shutil.rmtree(tmpdir) + def testCheckpointWithFunction(self): + ray.init() + trial = Trial( + "__fake", + config={ + "callbacks": { + "on_episode_start": tune.function(lambda i: i), + } + }, + checkpoint_freq=1) + tmpdir = tempfile.mkdtemp() + runner = TrialRunner( + BasicVariantGenerator(), metadata_checkpoint_dir=tmpdir) + runner.add_trial(trial) + for i in range(5): + runner.step() + # force checkpoint + runner.checkpoint() + runner2 = TrialRunner.restore(tmpdir) + new_trial = runner2.get_trials()[0] + self.assertTrue("callbacks" in new_trial.config) + self.assertTrue("on_episode_start" in new_trial.config["callbacks"]) + shutil.rmtree(tmpdir) + class SearchAlgorithmTest(unittest.TestCase): def testNestedSuggestion(self): diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 7228c9fbd..777edbb50 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -409,7 +409,8 @@ class Trial(object): "_checkpoint": self._checkpoint, "config": self.config, "custom_loggers": self.custom_loggers, - "sync_function": self.sync_function + "sync_function": self.sync_function, + "last_result": self.last_result } for key, value in pickle_data.items(): @@ -430,7 +431,8 @@ class Trial(object): logger_started = state.pop("__logger_started__") state["resources"] = json_to_resources(state["resources"]) for key in [ - "_checkpoint", "config", "custom_loggers", "sync_function" + "_checkpoint", "config", "custom_loggers", "sync_function", + "last_result" ]: state[key] = cloudpickle.loads(hex_to_binary(state[key]))