From edb7aaf7c7e857c2b5b90a33f8cce2ffe374736e Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Wed, 9 Jan 2019 11:55:32 -0800 Subject: [PATCH] [tune] Better Serialization for Server (#3708) * Add cloudpickle for serialization * Fix tests --- python/ray/tune/test/tune_server_test.py | 22 ++++++++++++++++++++++ python/ray/tune/web_server.py | 20 ++++++++++++++++++-- 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/python/ray/tune/test/tune_server_test.py b/python/ray/tune/test/tune_server_test.py index db99aae2e..e93c7d976 100644 --- a/python/ray/tune/test/tune_server_test.py +++ b/python/ray/tune/test/tune_server_test.py @@ -6,6 +6,7 @@ import unittest import socket import ray +from ray import tune from ray.rllib import _register_all from ray.tune.trial import Trial, Resources from ray.tune.web_server import TuneClient @@ -87,6 +88,27 @@ class TuneServerSuite(unittest.TestCase): runner.step() self.assertEqual(len(all_trials), 2) + def testGetTrialsWithFunction(self): + runner, client = self.basicSetup() + test_trial = Trial( + "__fake", + trial_id="function_trial", + stopping_criterion={"training_iteration": 3}, + config={ + "callbacks": { + "on_episode_start": tune.function(lambda x: None) + } + }) + runner.add_trial(test_trial) + + for i in range(3): + runner.step() + all_trials = client.get_all_trials()["trials"] + self.assertEqual(len(all_trials), 3) + client.get_trial("function_trial") + runner.step() + self.assertEqual(len(all_trials), 3) + def testStopTrial(self): """Check if Stop Trial works.""" runner, client = self.basicSetup() diff --git a/python/ray/tune/web_server.py b/python/ray/tune/web_server.py index 06f66887e..81ca33227 100644 --- a/python/ray/tune/web_server.py +++ b/python/ray/tune/web_server.py @@ -7,8 +7,10 @@ import logging import sys import threading +import ray.cloudpickle as cloudpickle from ray.tune.error import TuneError, TuneManagerError from ray.tune.suggest import BasicVariantGenerator +from ray.utils import binary_to_hex, hex_to_binary if sys.version_info[0] == 2: from SimpleHTTPServer import SimpleHTTPRequestHandler @@ -26,6 +28,13 @@ except ImportError: "Be sure to install it on the client side.") +def load_trial_info(trial_info): + trial_info["config"] = cloudpickle.loads( + hex_to_binary(trial_info["config"])) + trial_info["result"] = cloudpickle.loads( + hex_to_binary(trial_info["result"])) + + class TuneClient(object): """Client to interact with ongoing Tune experiment. @@ -71,6 +80,13 @@ class TuneClient(object): payload = json.dumps(data).encode() response = requests.get(self._path, data=payload) parsed = response.json() + + if "trial_info" in parsed: + load_trial_info(parsed["trial_info"]) + elif "trials" in parsed: + for trial_info in parsed["trials"]: + load_trial_info(trial_info) + return parsed @@ -96,9 +112,9 @@ def RunnerHandler(runner): info_dict = { "id": trial.trial_id, "trainable_name": trial.trainable_name, - "config": trial.config, + "config": binary_to_hex(cloudpickle.dumps(trial.config)), "status": trial.status, - "result": result + "result": binary_to_hex(cloudpickle.dumps(result)) } return info_dict