mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 16:02:03 +08:00
163 lines
5.1 KiB
Python
163 lines
5.1 KiB
Python
import requests
|
|
import socket
|
|
import subprocess
|
|
import unittest
|
|
import json
|
|
|
|
import ray
|
|
from ray.rllib import _register_all
|
|
from ray.tune.trial import Trial, Resources
|
|
from ray.tune.web_server import TuneClient
|
|
from ray.tune.trial_runner import TrialRunner
|
|
|
|
|
|
def get_valid_port():
|
|
port = 4321
|
|
while True:
|
|
try:
|
|
print("Trying port", port)
|
|
port_test_socket = socket.socket()
|
|
port_test_socket.bind(("127.0.0.1", port))
|
|
port_test_socket.close()
|
|
break
|
|
except socket.error:
|
|
port += 1
|
|
return port
|
|
|
|
|
|
class TuneServerSuite(unittest.TestCase):
|
|
def basicSetup(self):
|
|
ray.init(num_cpus=4, num_gpus=1)
|
|
port = get_valid_port()
|
|
self.runner = TrialRunner(server_port=port)
|
|
runner = self.runner
|
|
kwargs = {
|
|
"stopping_criterion": {
|
|
"training_iteration": 3
|
|
},
|
|
"resources": Resources(cpu=1, gpu=1),
|
|
}
|
|
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
|
|
for t in trials:
|
|
runner.add_trial(t)
|
|
client = TuneClient("localhost", port)
|
|
return runner, client
|
|
|
|
def tearDown(self):
|
|
print("Tearing down....")
|
|
try:
|
|
self.runner._server.shutdown()
|
|
self.runner = None
|
|
except Exception as e:
|
|
print(e)
|
|
ray.shutdown()
|
|
_register_all()
|
|
|
|
def testAddTrial(self):
|
|
runner, client = self.basicSetup()
|
|
for i in range(3):
|
|
runner.step()
|
|
spec = {
|
|
"run": "__fake",
|
|
"stop": {
|
|
"training_iteration": 3
|
|
},
|
|
"resources_per_trial": {
|
|
"cpu": 1,
|
|
"gpu": 1
|
|
},
|
|
}
|
|
client.add_trial("test", spec)
|
|
runner.step()
|
|
all_trials = client.get_all_trials()["trials"]
|
|
runner.step()
|
|
self.assertEqual(len(all_trials), 3)
|
|
|
|
def testGetTrials(self):
|
|
runner, client = self.basicSetup()
|
|
for i in range(3):
|
|
runner.step()
|
|
all_trials = client.get_all_trials()["trials"]
|
|
self.assertEqual(len(all_trials), 2)
|
|
tid = all_trials[0]["id"]
|
|
client.get_trial(tid)
|
|
runner.step()
|
|
self.assertEqual(len(all_trials), 2)
|
|
|
|
def testGetTrialsWithFunction(self):
|
|
runner, client = self.basicSetup()
|
|
test_trial = Trial(
|
|
"__fake",
|
|
trial_id="function_trial",
|
|
stopping_criterion={"training_iteration": 3},
|
|
config={"callbacks": {
|
|
"on_episode_start": lambda x: None
|
|
}})
|
|
runner.add_trial(test_trial)
|
|
|
|
for i in range(3):
|
|
runner.step()
|
|
all_trials = client.get_all_trials()["trials"]
|
|
self.assertEqual(len(all_trials), 3)
|
|
client.get_trial("function_trial")
|
|
runner.step()
|
|
self.assertEqual(len(all_trials), 3)
|
|
|
|
def testStopTrial(self):
|
|
"""Check if Stop Trial works."""
|
|
runner, client = self.basicSetup()
|
|
for i in range(2):
|
|
runner.step()
|
|
all_trials = client.get_all_trials()["trials"]
|
|
self.assertEqual(
|
|
len([t for t in all_trials if t["status"] == Trial.RUNNING]), 1)
|
|
|
|
tid = [t for t in all_trials if t["status"] == Trial.RUNNING][0]["id"]
|
|
client.stop_trial(tid)
|
|
runner.step()
|
|
|
|
all_trials = client.get_all_trials()["trials"]
|
|
self.assertEqual(
|
|
len([t for t in all_trials if t["status"] == Trial.RUNNING]), 0)
|
|
|
|
def testStopExperiment(self):
|
|
"""Check if stop_experiment works."""
|
|
runner, client = self.basicSetup()
|
|
for i in range(2):
|
|
runner.step()
|
|
all_trials = client.get_all_trials()["trials"]
|
|
self.assertEqual(
|
|
len([t for t in all_trials if t["status"] == Trial.RUNNING]), 1)
|
|
|
|
client.stop_experiment()
|
|
runner.step()
|
|
self.assertTrue(runner.is_finished())
|
|
self.assertRaises(
|
|
requests.exceptions.ReadTimeout,
|
|
lambda: client.get_all_trials(timeout=1))
|
|
|
|
def testCurlCommand(self):
|
|
"""Check if Stop Trial works."""
|
|
runner, client = self.basicSetup()
|
|
for i in range(2):
|
|
runner.step()
|
|
stdout = subprocess.check_output(
|
|
"curl \"http://{}:{}/trials\"".format(client.server_address,
|
|
client.server_port),
|
|
shell=True)
|
|
self.assertNotEqual(stdout, None)
|
|
curl_trials = json.loads(stdout.decode())["trials"]
|
|
client_trials = client.get_all_trials()["trials"]
|
|
for curl_trial, client_trial in zip(curl_trials, client_trials):
|
|
self.assertEqual(curl_trial.keys(), client_trial.keys())
|
|
self.assertEqual(curl_trial["id"], client_trial["id"])
|
|
self.assertEqual(curl_trial["trainable_name"],
|
|
client_trial["trainable_name"])
|
|
self.assertEqual(curl_trial["status"], client_trial["status"])
|
|
|
|
|
|
if __name__ == "__main__":
|
|
import pytest
|
|
import sys
|
|
sys.exit(pytest.main(["-v", __file__]))
|