[tune] Better Serialization for Server (#3708)

* Add cloudpickle for serialization

* Fix tests
This commit is contained in:
Richard Liaw
2019-01-09 11:55:32 -08:00
committed by GitHub
parent 04f31db54d
commit edb7aaf7c7
2 changed files with 40 additions and 2 deletions
+22
View File
@@ -6,6 +6,7 @@ import unittest
import socket
import ray
from ray import tune
from ray.rllib import _register_all
from ray.tune.trial import Trial, Resources
from ray.tune.web_server import TuneClient
@@ -87,6 +88,27 @@ class TuneServerSuite(unittest.TestCase):
runner.step()
self.assertEqual(len(all_trials), 2)
def testGetTrialsWithFunction(self):
runner, client = self.basicSetup()
test_trial = Trial(
"__fake",
trial_id="function_trial",
stopping_criterion={"training_iteration": 3},
config={
"callbacks": {
"on_episode_start": tune.function(lambda x: None)
}
})
runner.add_trial(test_trial)
for i in range(3):
runner.step()
all_trials = client.get_all_trials()["trials"]
self.assertEqual(len(all_trials), 3)
client.get_trial("function_trial")
runner.step()
self.assertEqual(len(all_trials), 3)
def testStopTrial(self):
"""Check if Stop Trial works."""
runner, client = self.basicSetup()
+18 -2
View File
@@ -7,8 +7,10 @@ import logging
import sys
import threading
import ray.cloudpickle as cloudpickle
from ray.tune.error import TuneError, TuneManagerError
from ray.tune.suggest import BasicVariantGenerator
from ray.utils import binary_to_hex, hex_to_binary
if sys.version_info[0] == 2:
from SimpleHTTPServer import SimpleHTTPRequestHandler
@@ -26,6 +28,13 @@ except ImportError:
"Be sure to install it on the client side.")
def load_trial_info(trial_info):
trial_info["config"] = cloudpickle.loads(
hex_to_binary(trial_info["config"]))
trial_info["result"] = cloudpickle.loads(
hex_to_binary(trial_info["result"]))
class TuneClient(object):
"""Client to interact with ongoing Tune experiment.
@@ -71,6 +80,13 @@ class TuneClient(object):
payload = json.dumps(data).encode()
response = requests.get(self._path, data=payload)
parsed = response.json()
if "trial_info" in parsed:
load_trial_info(parsed["trial_info"])
elif "trials" in parsed:
for trial_info in parsed["trials"]:
load_trial_info(trial_info)
return parsed
@@ -96,9 +112,9 @@ def RunnerHandler(runner):
info_dict = {
"id": trial.trial_id,
"trainable_name": trial.trainable_name,
"config": trial.config,
"config": binary_to_hex(cloudpickle.dumps(trial.config)),
"status": trial.status,
"result": result
"result": binary_to_hex(cloudpickle.dumps(result))
}
return info_dict