diff --git a/.travis.yml b/.travis.yml index 4a1985666..5c14718c1 100644 --- a/.travis.yml +++ b/.travis.yml @@ -123,6 +123,7 @@ script: - python test/monitor_test.py - python test/trial_runner_test.py - python test/trial_scheduler_test.py + - python test/tune_server_test.py - python test/cython_test.py - python -m pytest python/ray/dataframe/test/test_dataframe.py - python -m pytest python/ray/dataframe/test/test_series.py diff --git a/doc/source/tune.rst b/doc/source/tune.rst index 2dcfe7498..01fe95338 100644 --- a/doc/source/tune.rst +++ b/doc/source/tune.rst @@ -159,3 +159,26 @@ Running in a large cluster -------------------------- The ``run_experiments`` also takes any arguments that ``ray.init()`` does. This can be used to pass in the redis address of a multi-node Ray cluster. For more details, check out the `tune.py script `__. + +Client API +---------- + +You can modify an ongoing experiment by adding or deleting trials using the Tune Client API. To do this, start your experiment with a flag, either from the command-line, e.g.: + +:: + + cd ray/python/tune + ./tune.py -f examples/tune_mnist_ray.yaml --server=True --server-port=4321 + +Or within the Python API, e.g.: +:: + + run_experiments({...}, with_server=True, server_port=4321) + +Then, on the client side, you can use the following class. The server address defaults to ``localhost:4321``. If on a cluster, you may want to forward this port so that you can use the Client on your local machine. + +.. autoclass:: ray.tune.web_server.TuneClient + :members: + + +For an example notebook for using the Client API, see the `Client API Example `__. diff --git a/python/ray/tune/TuneClient.ipynb b/python/ray/tune/TuneClient.ipynb new file mode 100644 index 000000000..b62932bac --- /dev/null +++ b/python/ray/tune/TuneClient.ipynb @@ -0,0 +1,88 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from ray.tune.web_server import TuneClient\n", + "\n", + "manager = TuneClient(tune_address=\"localhost:4321\")\n", + "\n", + "x = manager.get_all_trials()\n", + "\n", + "[((y[\"id\"]), y[\"status\"]) for y in x[\"trials\"]]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true, + "scrolled": false + }, + "outputs": [], + "source": [ + "for y in x[\"trials\"][-10:]:\n", + " manager.stop_trial(y[\"id\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "from ray.tune.variant_generator import generate_trials\n", + "import yaml\n", + "\n", + "with open(\"../rllib/tuned_examples/hyperband-cartpole.yaml\") as f:\n", + " d = yaml.load(f)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "name, spec = [x for x in d.items()][0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "manager.add_trial(name, spec)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/python/ray/tune/error.py b/python/ray/tune/error.py index badf60a08..b23d62a08 100644 --- a/python/ray/tune/error.py +++ b/python/ray/tune/error.py @@ -6,3 +6,8 @@ from __future__ import print_function class TuneError(Exception): """General error class raised by ray.tune.""" pass + + +class TuneManagerError(TuneError): + """Error raised in operating the Tune Manager.""" + pass diff --git a/python/ray/tune/hyperband.py b/python/ray/tune/hyperband.py index e56923890..2ee51c0c0 100644 --- a/python/ray/tune/hyperband.py +++ b/python/ray/tune/hyperband.py @@ -87,7 +87,9 @@ class HyperBandScheduler(FIFOScheduler): self._time_attr = time_attr def on_trial_add(self, trial_runner, trial): - """On a new trial add, if current bracket is not filled, + """Adds new trial. + + On a new trial add, if current bracket is not filled, add to current bracket. Else, if current band is not filled, create new bracket, add to current bracket. Else, create new iteration, create new bracket, add to bracket.""" @@ -136,9 +138,8 @@ class HyperBandScheduler(FIFOScheduler): This scheduler will not start trials but will stop trials. The current running trial will not be handled, - as the trialrunner will be given control to handle it. + as the trialrunner will be given control to handle it.""" - # TODO(rliaw) should be only called if trial has not errored""" bracket, _ = self._trial_info[trial] bracket.update_trial_stats(trial, result) @@ -160,16 +161,17 @@ class HyperBandScheduler(FIFOScheduler): action = TrialScheduler.PAUSE if bracket.cur_iter_done(): if bracket.finished(): - self._cleanup_bracket(trial_runner, bracket) + bracket.cleanup_full(trial_runner) return TrialScheduler.CONTINUE good, bad = bracket.successive_halving(self._reward_attr) # kill bad trials + self._num_stopped += len(bad) for t in bad: if t.status == Trial.PAUSED: - self._cleanup_trial(trial_runner, t, bracket, hard=True) + trial_runner.stop_trial(t) elif t.status == Trial.RUNNING: - self._cleanup_trial(trial_runner, t, bracket, hard=False) + bracket.cleanup_trial(t) action = TrialScheduler.STOP else: raise Exception("Trial with unexpected status encountered") @@ -185,47 +187,30 @@ class HyperBandScheduler(FIFOScheduler): action = TrialScheduler.CONTINUE return action - def _cleanup_trial(self, trial_runner, t, bracket, hard=False): - """Bookkeeping for trials finished. If `hard=True`, then - this scheduler will force the trial_runner to release resources. + def on_trial_remove(self, trial_runner, trial): + """Notification when trial terminates. - Otherwise, only clean up trial information locally.""" - self._num_stopped += 1 - if hard: - trial_runner._stop_trial(t) - bracket.cleanup_trial(t) - - def _cleanup_bracket(self, trial_runner, bracket): - """Cleans up bracket after bracket is completely finished. - Lets the last trial continue to run until termination condition - kicks in.""" - for trial in bracket.current_trials(): - if (trial.status == Trial.PAUSED): - self._cleanup_trial( - trial_runner, trial, bracket, - hard=True) + Trial info is removed from bracket. Triggers halving if bracket is + not finished.""" + bracket, _ = self._trial_info[trial] + bracket.cleanup_trial(trial) + if not bracket.finished(): + self._process_bracket(trial_runner, bracket, trial) def on_trial_complete(self, trial_runner, trial, result): """Cleans up trial info from bracket if trial completed early.""" - - bracket, _ = self._trial_info[trial] - self._cleanup_trial(trial_runner, trial, bracket, hard=False) - self._process_bracket(trial_runner, bracket, trial) + self.on_trial_remove(trial_runner, trial) def on_trial_error(self, trial_runner, trial): """Cleans up trial info from bracket if trial errored early.""" - - bracket, _ = self._trial_info[trial] - self._cleanup_trial(trial_runner, trial, bracket, hard=False) - self._process_bracket(trial_runner, bracket, trial) + self.on_trial_remove(trial_runner, trial) def choose_trial_to_run(self, trial_runner, *args): """Fair scheduling within iteration by completion percentage. - List of trials not used since all trials are tracked as state - of scheduler. - If iteration is occupied (ie, no trials to run), then look into - next iteration.""" + List of trials not used since all trials are tracked as state + of scheduler. If iteration is occupied (ie, no trials to run), + then look into next iteration.""" for hyperband in self._hyperbands: for bracket in sorted(hyperband, @@ -237,6 +222,7 @@ class HyperBandScheduler(FIFOScheduler): return None def debug_string(self): + # TODO(rliaw): This debug string needs work brackets = [ "({0}/{1})".format( len(bracket._live_trials), len(bracket._all_trials)) @@ -301,8 +287,11 @@ class Bracket(): return False def filled(self): - """We will only let new trials be added at current level, - minimizing the need to backtrack and bookkeep previous medians""" + """Checks if bracket is filled. + + Only let new trials be added at current level minimizing the need + to backtrack and bookkeep previous medians.""" + return len(self._live_trials) == self._n def successive_halving(self, reward_attr): @@ -346,6 +335,15 @@ class Bracket(): assert trial in self._live_trials del self._live_trials[trial] + def cleanup_full(self, trial_runner): + """Cleans up bracket after bracket is completely finished. + + Lets the last trial continue to run until termination condition + kicks in.""" + for trial in self.current_trials(): + if (trial.status == Trial.PAUSED): + trial_runner.stop_trial(trial) + def completion_percentage(self): """Returns a progress metric. @@ -374,5 +372,8 @@ class Bracket(): "r={}".format(self._r), "progress={}".format(self.completion_percentage()) ]) + return "Bracket({})".format(status) + + def debug_string(self): trials = ", ".join([t.status for t in self._live_trials]) - return "Bracket({})[{}]".format(status, trials) + return "{}[{}]".format(self, trials) diff --git a/python/ray/tune/median_stopping_rule.py b/python/ray/tune/median_stopping_rule.py index 1f3ab6aa0..8306ea3aa 100644 --- a/python/ray/tune/median_stopping_rule.py +++ b/python/ray/tune/median_stopping_rule.py @@ -5,6 +5,7 @@ from __future__ import print_function import collections import numpy as np +from ray.tune.trial import Trial from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler @@ -74,6 +75,11 @@ class MedianStoppingRule(FIFOScheduler): self._results[trial].append(result) self._completed_trials.add(trial) + def on_trial_remove(self, trial_runner, trial): + """Marks trial as completed if it is paused and has previously ran.""" + if trial.status is Trial.PAUSED and trial in self._results: + self._completed_trials.add(trial) + def debug_string(self): return "Using MedianStoppingRule: num_stopped={}.".format( len(self._stopped_trials)) diff --git a/python/ray/tune/result.py b/python/ray/tune/result.py index 3234ae031..6f76f4067 100644 --- a/python/ray/tune/result.py +++ b/python/ray/tune/result.py @@ -29,7 +29,7 @@ TrainingResult = namedtuple("TrainingResult", [ # (Required) Accumulated timesteps for this entire experiment. "timesteps_total", - # (Optional) If training is finished. + # (Optional) If training is terminated. "done", # (Optional) Custom metadata to report for this iteration. diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 8285c7b4b..76b8f8ab6 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -9,6 +9,7 @@ import ray import os from collections import namedtuple +from ray.utils import random_string, binary_to_hex from ray.tune import TuneError from ray.tune.logger import NoopLogger, UnifiedLogger from ray.tune.result import TrainingResult, DEFAULT_RESULTS_DIR, pretty_print @@ -105,6 +106,7 @@ class Trial(object): self.location = None self.logdir = None self.result_logger = None + self.trial_id = binary_to_hex(random_string())[:8] def start(self): """Starts this trial. diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 64729a476..a948e3358 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -8,6 +8,7 @@ import time import traceback from ray.tune import TuneError +from ray.tune.web_server import TuneServer from ray.tune.trial import Trial, Resources from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler @@ -34,8 +35,14 @@ class TrialRunner(object): misleading benchmark results. """ - def __init__(self, scheduler=None): - """Initializes a new TrialRunner.""" + def __init__(self, scheduler=None, launch_web_server=False, + server_port=TuneServer.DEFAULT_PORT): + """Initializes a new TrialRunner. + + Args: + scheduler (TrialScheduler): Defaults to FIFOScheduler. + launch_web_server (bool): Flag for starting TuneServer + server_port (int): Port number for launching TuneServer""" self._scheduler_alg = scheduler or FIFOScheduler() self._trials = [] @@ -49,6 +56,10 @@ class TrialRunner(object): self._global_time_limit = float( os.environ.get("TRIALRUNNER_WALLTIME_LIMIT", float('inf'))) self._total_time = 0 + self._server = None + if launch_web_server: + self._server = TuneServer(self, server_port) + self._stop_queue = [] def is_finished(self): """Returns whether all trials have finished running.""" @@ -70,7 +81,6 @@ class TrialRunner(object): Callers should typically run this method repeatedly in a loop. They may inspect or modify the runner's state in between calls to step(). """ - if self._can_launch_more(): self._launch_trial() elif self._running: @@ -91,6 +101,16 @@ class TrialRunner(object): "trials with sufficient resources.") raise TuneError("Called step when all trials finished?") + if self._server: + self._process_requests() + + if self.is_finished(): + self._server.shutdown() + + def get_trial(self, tid): + trial = [t for t in self._trials if t.trial_id == tid] + return trial[0] if trial else None + def get_trials(self): """Returns the list of trials managed by this TrialRunner. @@ -207,6 +227,43 @@ class TrialRunner(object): assert self._committed_resources.cpu >= 0 assert self._committed_resources.gpu >= 0 + def request_stop_trial(self, trial): + self._stop_queue.append(trial) + + def _process_requests(self): + while self._stop_queue: + t = self._stop_queue.pop() + self.stop_trial(t) + + def stop_trial(self, trial): + """Stops trial. + + Trials may be stopped at any time. If trial is in state PENDING + or PAUSED, calls `scheduler.on_trial_remove`. Otherwise waits for + result for the trial and calls `scheduler.on_trial_complete` + if RUNNING.""" + error = False + + if trial.status in [Trial.ERROR, Trial.TERMINATED]: + return + elif trial.status in [Trial.PENDING, Trial.PAUSED]: + self._scheduler_alg.on_trial_remove(self, trial) + elif trial.status is Trial.RUNNING: + # NOTE: There should only be one... + result_id = [rid for rid, t in self._running.items() + if t is trial][0] + self._running.pop(result_id) + try: + result = ray.get(result_id) + trial.update_last_result(result, terminate=True) + self._scheduler_alg.on_trial_complete(self, trial, result) + except Exception: + print("Error processing event:", traceback.format_exc()) + self._scheduler_alg.on_trial_error(self, trial) + error = True + + self._stop_trial(trial, error=error) + def _stop_trial(self, trial, error=False): """Only returns resources if resources allocated.""" prior_status = trial.status diff --git a/python/ray/tune/trial_scheduler.py b/python/ray/tune/trial_scheduler.py index 3076c2854..5aa5238fc 100644 --- a/python/ray/tune/trial_scheduler.py +++ b/python/ray/tune/trial_scheduler.py @@ -26,14 +26,24 @@ class TrialScheduler(object): """Called on each intermediate result returned by a trial. At this point, the trial scheduler can make a decision by returning - one of CONTINUE, PAUSE, and STOP.""" + one of CONTINUE, PAUSE, and STOP. This will only be called when the + trial is in the RUNNING state.""" raise NotImplementedError def on_trial_complete(self, trial_runner, trial, result): """Notification for the completion of trial. - This will only be called when the trial completes naturally.""" + This will only be called when the trial is in the RUNNING state and + either completes naturally or by manual termination.""" + + raise NotImplementedError + + def on_trial_remove(self, trial_runner, trial): + """Called to remove trial. + + This is called when the trial is in PAUSED or PENDING state. Otherwise, + call `on_trial_complete`.""" raise NotImplementedError @@ -66,6 +76,9 @@ class FIFOScheduler(TrialScheduler): def on_trial_complete(self, trial_runner, trial, result): pass + def on_trial_remove(self, trial_runner, trial): + pass + def choose_trial_to_run(self, trial_runner): for trial in trial_runner.get_trials(): if (trial.status == Trial.PENDING and diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index acff5a264..a4fa9497e 100755 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -16,6 +16,7 @@ from ray.tune.median_stopping_rule import MedianStoppingRule from ray.tune.trial import Trial from ray.tune.trial_runner import TrialRunner from ray.tune.trial_scheduler import FIFOScheduler +from ray.tune.web_server import TuneServer from ray.tune.variant_generator import generate_trials @@ -41,6 +42,10 @@ parser.add_argument("--scheduler", default="FIFO", type=str, help="FIFO, MedianStopping, or HyperBand") parser.add_argument("--scheduler-config", default="{}", type=json.loads, help="Config options to pass to the scheduler.") +parser.add_argument("--server", default=False, type=bool, + help="Option to launch Tune Server") +parser.add_argument("--server-port", default=TuneServer.DEFAULT_PORT, + type=int, help="Option to launch Tune Server") parser.add_argument("-f", "--config-file", required=True, type=str, help="Read experiment options from this JSON/YAML file.") @@ -61,10 +66,13 @@ def _make_scheduler(args): args.scheduler, _SCHEDULERS.keys())) -def run_experiments(experiments, scheduler=None, **ray_args): +def run_experiments(experiments, scheduler=None, with_server=False, + server_port=TuneServer.DEFAULT_PORT, **ray_args): if scheduler is None: scheduler = FIFOScheduler() - runner = TrialRunner(scheduler) + + runner = TrialRunner( + scheduler, launch_web_server=with_server, server_port=server_port) for name, spec in experiments.items(): for trial in generate_trials(spec, name): @@ -78,6 +86,7 @@ def run_experiments(experiments, scheduler=None, **ray_args): print(runner.debug_string()) for trial in runner.get_trials(): + # TODO(rliaw): What about errored? if trial.status != Trial.TERMINATED: raise TuneError("Trial did not complete", trial) @@ -90,5 +99,6 @@ if __name__ == "__main__": with open(args.config_file) as f: experiments = yaml.load(f) run_experiments( - experiments, _make_scheduler(args), redis_address=args.redis_address, + experiments, _make_scheduler(args), with_server=args.server, + server_port=args.server_port, redis_address=args.redis_address, num_cpus=args.num_cpus, num_gpus=args.num_gpus) diff --git a/python/ray/tune/web_server.py b/python/ray/tune/web_server.py new file mode 100644 index 000000000..c4d3a4a7f --- /dev/null +++ b/python/ray/tune/web_server.py @@ -0,0 +1,146 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import requests +import json +import threading +from http.server import HTTPServer, BaseHTTPRequestHandler + +from ray.tune.error import TuneError, TuneManagerError +from ray.tune.variant_generator import generate_trials + + +class TuneClient(object): + """Client to interact with ongoing Tune experiment. + + Requires server to have started running.""" + STOP = "STOP" + ADD = "ADD" + GET_LIST = "GET_LIST" + GET_TRIAL = "GET_TRIAL" + + def __init__(self, tune_address): + # TODO(rliaw): Better to specify address and port forward + self._tune_address = tune_address + self._path = "http://{}".format(tune_address) + + def get_all_trials(self): + """Returns a list of all trials (trial_id, config, status).""" + return self._get_response( + {"command": TuneClient.GET_LIST}) + + def get_trial(self, trial_id): + """Returns the last result for queried trial.""" + return self._get_response( + {"command": TuneClient.GET_TRIAL, + "trial_id": trial_id}) + + def add_trial(self, name, trial_spec): + """Adds a trial of `name` with configurations.""" + # TODO(rliaw): have better way of specifying a new trial + return self._get_response( + {"command": TuneClient.ADD, + "name": name, + "spec": trial_spec}) + + def stop_trial(self, trial_id): + """Requests to stop trial.""" + return self._get_response( + {"command": TuneClient.STOP, + "trial_id": trial_id}) + + def _get_response(self, data): + payload = json.dumps(data).encode() + response = requests.get(self._path, data=payload) + parsed = response.json() + return parsed + + +def RunnerHandler(runner): + class Handler(BaseHTTPRequestHandler): + + def do_GET(self): + content_len = int(self.headers.get('Content-Length'), 0) + raw_body = self.rfile.read(content_len) + parsed_input = json.loads(raw_body.decode()) + status, response = self.execute_command(parsed_input) + if status: + self.send_response(200) + else: + self.send_response(400) + self.end_headers() + self.wfile.write(json.dumps( + response).encode()) + + def trial_info(self, trial): + if trial.last_result: + result = trial.last_result._asdict() + else: + result = None + info_dict = { + "id": trial.trial_id, + "trainable_name": trial.trainable_name, + "config": trial.config, + "status": trial.status, + "result": result + } + return info_dict + + def execute_command(self, args): + def get_trial(): + trial = runner.get_trial(args["trial_id"]) + if trial is None: + error = "Trial ({}) not found.".format(args["trial_id"]) + raise TuneManagerError(error) + else: + return trial + + command = args["command"] + response = {} + try: + if command == TuneClient.GET_LIST: + response["trials"] = [self.trial_info(t) + for t in runner.get_trials()] + elif command == TuneClient.GET_TRIAL: + trial = get_trial() + response["trial_info"] = self.trial_info(trial) + elif command == TuneClient.STOP: + trial = get_trial() + runner.request_stop_trial(trial) + elif command == TuneClient.ADD: + name = args["name"] + spec = args["spec"] + for trial in generate_trials(spec, name): + runner.add_trial(trial) + else: + raise TuneManagerError("Unknown command.") + status = True + except TuneError as e: + status = False + response["message"] = str(e) + + return status, response + + return Handler + + +class TuneServer(threading.Thread): + + DEFAULT_PORT = 4321 + + def __init__(self, runner, port=None): + + threading.Thread.__init__(self) + self._port = port if port else self.DEFAULT_PORT + address = ('localhost', self._port) + print("Starting Tune Server...") + self._server = HTTPServer( + address, RunnerHandler(runner)) + self.start() + + def run(self): + self._server.serve_forever() + + def shutdown(self): + self._server.shutdown() diff --git a/test/trial_runner_test.py b/test/trial_runner_test.py index 5cd58d9c2..6aa73ddd1 100644 --- a/test/trial_runner_test.py +++ b/test/trial_runner_test.py @@ -497,6 +497,46 @@ class TrialRunnerTest(unittest.TestCase): runner.step() self.assertEqual(trials[0].status, Trial.TERMINATED) + def testStopTrial(self): + ray.init(num_cpus=4, num_gpus=2) + runner = TrialRunner() + kwargs = { + "stopping_criterion": {"training_iteration": 5}, + "resources": Resources(cpu=1, gpu=1), + } + trials = [ + Trial("__fake", **kwargs), + Trial("__fake", **kwargs), + Trial("__fake", **kwargs), + Trial("__fake", **kwargs)] + for t in trials: + runner.add_trial(t) + runner.step() + self.assertEqual(trials[0].status, Trial.RUNNING) + self.assertEqual(trials[1].status, Trial.PENDING) + + # Stop trial while running + runner.stop_trial(trials[0]) + self.assertEqual(trials[0].status, Trial.TERMINATED) + self.assertEqual(trials[1].status, Trial.PENDING) + + runner.step() + self.assertEqual(trials[0].status, Trial.TERMINATED) + self.assertEqual(trials[1].status, Trial.RUNNING) + self.assertEqual(trials[-1].status, Trial.PENDING) + + # Stop trial while pending + runner.stop_trial(trials[-1]) + self.assertEqual(trials[0].status, Trial.TERMINATED) + self.assertEqual(trials[1].status, Trial.RUNNING) + self.assertEqual(trials[-1].status, Trial.TERMINATED) + + runner.step() + self.assertEqual(trials[0].status, Trial.TERMINATED) + self.assertEqual(trials[1].status, Trial.RUNNING) + self.assertEqual(trials[2].status, Trial.RUNNING) + self.assertEqual(trials[-1].status, Trial.TERMINATED) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/test/trial_scheduler_test.py b/test/trial_scheduler_test.py index a31b0959b..48ec1603e 100644 --- a/test/trial_scheduler_test.py +++ b/test/trial_scheduler_test.py @@ -140,8 +140,25 @@ class EarlyStoppingSuite(unittest.TestCase): class _MockTrialRunner(): - def _stop_trial(self, trial): - trial.stop() + def __init__(self, scheduler): + self._scheduler_alg = scheduler + + def process_action(self, trial, action): + if action == TrialScheduler.CONTINUE: + pass + elif action == TrialScheduler.PAUSE: + self._pause_trial(trial) + elif action == TrialScheduler.STOP: + trial.stop() + + def stop_trial(self, trial): + if trial.status in [Trial.ERROR, Trial.TERMINATED]: + return + elif trial.status in [Trial.PENDING, Trial.PAUSED]: + self._scheduler_alg.on_trial_remove(self, trial) + else: + + self._scheduler_alg.on_trial_complete(self, trial, result(100, 10)) def has_resources(self, resources): return True @@ -168,7 +185,7 @@ class HyperbandSuite(unittest.TestCase): for i in range(num_trials): t = Trial("__fake") sched.on_trial_add(None, t) - runner = _MockTrialRunner() + runner = _MockTrialRunner(sched) return sched, runner def default_statistics(self): @@ -186,14 +203,6 @@ class HyperbandSuite(unittest.TestCase): def downscale(self, n, sched): return int(np.ceil(n / sched._eta)) - def process(self, trl, mock_runner, action): - if action == TrialScheduler.CONTINUE: - pass - elif action == TrialScheduler.PAUSE: - mock_runner._pause_trial(trl) - elif action == TrialScheduler.STOP: - self.stopTrial(trl, mock_runner) - def basicSetup(self): """Setup and verify full band. """ @@ -224,10 +233,6 @@ class HyperbandSuite(unittest.TestCase): return sched - def stopTrial(self, trial, mock_runner): - self.assertNotEqual(trial.status, Trial.TERMINATED) - mock_runner._stop_trial(trial) - def testConfigSameEta(self): sched = HyperBandScheduler() i = 0 @@ -283,7 +288,7 @@ class HyperbandSuite(unittest.TestCase): mock_runner, trl, result(cur_units, i)) if i < current_length - 1: self.assertEqual(action, TrialScheduler.PAUSE) - self.process(trl, mock_runner, action) + mock_runner.process_action(trl, action) self.assertEqual(action, TrialScheduler.CONTINUE) new_length = len(big_bracket.current_trials()) @@ -304,7 +309,7 @@ class HyperbandSuite(unittest.TestCase): for i, trl in reversed(list(enumerate(big_bracket.current_trials()))): action = sched.on_trial_result( mock_runner, trl, result(cur_units, i)) - self.process(trl, mock_runner, action) + mock_runner.process_action(trl, action) self.assertEqual(action, TrialScheduler.STOP) @@ -321,7 +326,7 @@ class HyperbandSuite(unittest.TestCase): for i, trl in enumerate(big_bracket.current_trials()): action = sched.on_trial_result( mock_runner, trl, result(cur_units, i)) - self.process(trl, mock_runner, action) + mock_runner.process_action(trl, action) self.assertEqual(action, TrialScheduler.CONTINUE) @@ -412,9 +417,9 @@ class HyperbandSuite(unittest.TestCase): mock_runner._launch_trial(t) for i, t in enumerate(bracket_trials): - status = sched.on_trial_result( + action = sched.on_trial_result( mock_runner, t, result(init_units, i)) - self.assertEqual(status, TrialScheduler.CONTINUE) + self.assertEqual(action, TrialScheduler.CONTINUE) t = Trial("__fake") sched.on_trial_add(None, t) mock_runner._launch_trial(t) @@ -442,7 +447,7 @@ class HyperbandSuite(unittest.TestCase): for i in range(stats["max_trials"]): t = Trial("__fake") sched.on_trial_add(None, t) - runner = _MockTrialRunner() + runner = _MockTrialRunner(sched) big_bracket = sched._hyperbands[0][-1] @@ -452,17 +457,11 @@ class HyperbandSuite(unittest.TestCase): # Provides results from 0 to 8 in order, keeping the last one running for i, trl in enumerate(big_bracket.current_trials()): - status = sched.on_trial_result(runner, trl, result2(1, i)) - if status == TrialScheduler.CONTINUE: - continue - elif status == TrialScheduler.PAUSE: - runner._pause_trial(trl) - elif status == TrialScheduler.STOP: - self.assertNotEqual(trl.status, Trial.TERMINATED) - self.stopTrial(trl, runner) + action = sched.on_trial_result(runner, trl, result2(1, i)) + runner.process_action(trl, action) new_length = len(big_bracket.current_trials()) - self.assertEqual(status, TrialScheduler.CONTINUE) + self.assertEqual(action, TrialScheduler.CONTINUE) self.assertEqual(new_length, self.downscale(current_length, sched)) def testJumpingTime(self): @@ -476,21 +475,38 @@ class HyperbandSuite(unittest.TestCase): main_trials = big_bracket.current_trials()[:-1] jump = big_bracket.current_trials()[-1] for i, trl in enumerate(main_trials): - status = sched.on_trial_result(mock_runner, trl, result(1, i)) - if status == TrialScheduler.CONTINUE: - continue - elif status == TrialScheduler.PAUSE: - mock_runner._pause_trial(trl) - elif status == TrialScheduler.STOP: - self.assertNotEqual(trl.status, Trial.TERMINATED) - self.stopTrial(trl, mock_runner) + action = sched.on_trial_result(mock_runner, trl, result(1, i)) + mock_runner.process_action(trl, action) - status = sched.on_trial_result(mock_runner, jump, result(4, i)) - self.assertEqual(status, TrialScheduler.PAUSE) + action = sched.on_trial_result(mock_runner, jump, result(4, i)) + self.assertEqual(action, TrialScheduler.PAUSE) current_length = len(big_bracket.current_trials()) self.assertLess(current_length, 27) + def testRemove(self): + """Test with 4: start 1, remove 1 pending, add 2, remove 1 pending""" + sched, runner = self.schedulerSetup(4) + trials = sorted(list(sched._trial_info), key=lambda t: t.trial_id) + runner._launch_trial(trials[0]) + sched.on_trial_result(runner, trials[0], result(1, 5)) + self.assertEqual(trials[0].status, Trial.RUNNING) + self.assertEqual(trials[1].status, Trial.PENDING) + + bracket, _ = sched._trial_info[trials[1]] + self.assertTrue(trials[1] in bracket._live_trials) + sched.on_trial_remove(runner, trials[1]) + self.assertFalse(trials[1] in bracket._live_trials) + + for i in range(2): + trial = Trial("__fake") + sched.on_trial_add(None, trial) + + bracket, _ = sched._trial_info[trial] + self.assertTrue(trial in bracket._live_trials) + sched.on_trial_remove(runner, trial) # where trial is not running + self.assertFalse(trial in bracket._live_trials) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/test/tune_server_test.py b/test/tune_server_test.py new file mode 100644 index 000000000..d0c0fe0ad --- /dev/null +++ b/test/tune_server_test.py @@ -0,0 +1,103 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest +import socket + +import ray +from ray.rllib import _register_all +from ray.tune.trial import Trial, Resources +from ray.tune.web_server import TuneClient +from ray.tune.trial_runner import TrialRunner + + +def get_valid_port(): + port = 4321 + while True: + try: + print("Trying port", port) + port_test_socket = socket.socket() + port_test_socket.bind(("127.0.0.1", port)) + port_test_socket.close() + break + except socket.error: + port += 1 + return port + + +class TuneServerSuite(unittest.TestCase): + def basicSetup(self): + ray.init(num_cpus=4, num_gpus=1) + port = get_valid_port() + self.runner = TrialRunner( + launch_web_server=True, server_port=port) + runner = self.runner + kwargs = { + "stopping_criterion": {"training_iteration": 3}, + "resources": Resources(cpu=1, gpu=1), + } + trials = [ + Trial("__fake", **kwargs), + Trial("__fake", **kwargs)] + for t in trials: + runner.add_trial(t) + client = TuneClient("localhost:{}".format(port)) + return runner, client + + def tearDown(self): + print("Tearing down....") + try: + self.runner._server.shutdown() + self.runner = None + except Exception as e: + print(e) + ray.worker.cleanup() + _register_all() + + def testAddTrial(self): + runner, client = self.basicSetup() + for i in range(3): + runner.step() + spec = { + "run": "__fake", + "stop": {"training_iteration": 3}, + "resources": dict(cpu=1, gpu=1), + } + client.add_trial("test", spec) + runner.step() + all_trials = client.get_all_trials()["trials"] + runner.step() + self.assertEqual(len(all_trials), 3) + + def testGetTrials(self): + runner, client = self.basicSetup() + for i in range(3): + runner.step() + all_trials = client.get_all_trials()["trials"] + self.assertEqual(len(all_trials), 2) + tid = all_trials[0]["id"] + client.get_trial(tid) + runner.step() + self.assertEqual(len(all_trials), 2) + + def testStopTrial(self): + """Check if Stop Trial works""" + runner, client = self.basicSetup() + for i in range(2): + runner.step() + all_trials = client.get_all_trials()["trials"] + self.assertEqual( + len([t for t in all_trials if t["status"] == Trial.RUNNING]), 1) + + tid = [t for t in all_trials if t["status"] == Trial.RUNNING][0]["id"] + client.stop_trial(tid) + runner.step() + + all_trials = client.get_all_trials()["trials"] + self.assertEqual( + len([t for t in all_trials if t["status"] == Trial.RUNNING]), 0) + + +if __name__ == "__main__": + unittest.main(verbosity=2)