mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 06:51:30 +08:00
[tune] Better Serialization for Server (#3708)
* Add cloudpickle for serialization * Fix tests
This commit is contained in:
@@ -6,6 +6,7 @@ import unittest
|
||||
import socket
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib import _register_all
|
||||
from ray.tune.trial import Trial, Resources
|
||||
from ray.tune.web_server import TuneClient
|
||||
@@ -87,6 +88,27 @@ class TuneServerSuite(unittest.TestCase):
|
||||
runner.step()
|
||||
self.assertEqual(len(all_trials), 2)
|
||||
|
||||
def testGetTrialsWithFunction(self):
|
||||
runner, client = self.basicSetup()
|
||||
test_trial = Trial(
|
||||
"__fake",
|
||||
trial_id="function_trial",
|
||||
stopping_criterion={"training_iteration": 3},
|
||||
config={
|
||||
"callbacks": {
|
||||
"on_episode_start": tune.function(lambda x: None)
|
||||
}
|
||||
})
|
||||
runner.add_trial(test_trial)
|
||||
|
||||
for i in range(3):
|
||||
runner.step()
|
||||
all_trials = client.get_all_trials()["trials"]
|
||||
self.assertEqual(len(all_trials), 3)
|
||||
client.get_trial("function_trial")
|
||||
runner.step()
|
||||
self.assertEqual(len(all_trials), 3)
|
||||
|
||||
def testStopTrial(self):
|
||||
"""Check if Stop Trial works."""
|
||||
runner, client = self.basicSetup()
|
||||
|
||||
@@ -7,8 +7,10 @@ import logging
|
||||
import sys
|
||||
import threading
|
||||
|
||||
import ray.cloudpickle as cloudpickle
|
||||
from ray.tune.error import TuneError, TuneManagerError
|
||||
from ray.tune.suggest import BasicVariantGenerator
|
||||
from ray.utils import binary_to_hex, hex_to_binary
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
from SimpleHTTPServer import SimpleHTTPRequestHandler
|
||||
@@ -26,6 +28,13 @@ except ImportError:
|
||||
"Be sure to install it on the client side.")
|
||||
|
||||
|
||||
def load_trial_info(trial_info):
|
||||
trial_info["config"] = cloudpickle.loads(
|
||||
hex_to_binary(trial_info["config"]))
|
||||
trial_info["result"] = cloudpickle.loads(
|
||||
hex_to_binary(trial_info["result"]))
|
||||
|
||||
|
||||
class TuneClient(object):
|
||||
"""Client to interact with ongoing Tune experiment.
|
||||
|
||||
@@ -71,6 +80,13 @@ class TuneClient(object):
|
||||
payload = json.dumps(data).encode()
|
||||
response = requests.get(self._path, data=payload)
|
||||
parsed = response.json()
|
||||
|
||||
if "trial_info" in parsed:
|
||||
load_trial_info(parsed["trial_info"])
|
||||
elif "trials" in parsed:
|
||||
for trial_info in parsed["trials"]:
|
||||
load_trial_info(trial_info)
|
||||
|
||||
return parsed
|
||||
|
||||
|
||||
@@ -96,9 +112,9 @@ def RunnerHandler(runner):
|
||||
info_dict = {
|
||||
"id": trial.trial_id,
|
||||
"trainable_name": trial.trainable_name,
|
||||
"config": trial.config,
|
||||
"config": binary_to_hex(cloudpickle.dumps(trial.config)),
|
||||
"status": trial.status,
|
||||
"result": result
|
||||
"result": binary_to_hex(cloudpickle.dumps(result))
|
||||
}
|
||||
return info_dict
|
||||
|
||||
|
||||
Reference in New Issue
Block a user