From 5ab395236ba97b66627e86de78ed10e24f9c41b8 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Thu, 30 Jan 2020 00:34:08 -0800 Subject: [PATCH] [tune] Experiment stopping API (#6886) --- doc/source/tune-usage.rst | 27 ++++++-- python/ray/tune/__init__.py | 2 + .../ray/tune/examples/pbt_convnet_example.py | 11 ++-- python/ray/tune/experiment.py | 40 +++++++----- python/ray/tune/stopper.py | 63 +++++++++++++++++++ python/ray/tune/tests/test_api.py | 38 ++++++++--- python/ray/tune/trial.py | 3 - python/ray/tune/trial_runner.py | 18 +++++- python/ray/tune/tune.py | 9 ++- 9 files changed, 172 insertions(+), 39 deletions(-) create mode 100644 python/ray/tune/stopper.py diff --git a/doc/source/tune-usage.rst b/doc/source/tune-usage.rst index 516435749..64990c263 100644 --- a/doc/source/tune-usage.rst +++ b/doc/source/tune-usage.rst @@ -492,23 +492,38 @@ In the example below, each trial will be stopped either when it completes 10 ite For more flexibility, you can pass in a function instead. If a function is passed in, it must take ``(trial_id, result)`` as arguments and return a boolean (``True`` if trial should be stopped and ``False`` otherwise). -You can use this to stop all trials after the criteria is fulfilled by any individual trial: +.. code-block:: python + + + def stopper(trial_id, result): + return result["mean_accuracy"] / result["training_iteration"] > 5 + + tune.run(my_trainable, stop=stopper) + +Finally, you can implement the ``Stopper`` abstract class for stopping entire experiments. For example, the following example stops all trials after the criteria is fulfilled by any individual trial, and prevents new ones from starting: .. 
code-block:: python - class Stopper: + from ray.tune import Stopper + + class CustomStopper(Stopper): def __init__(self): self.should_stop = False - def stop(self, trial_id, result): + def __call__(self, trial_id, result): if not self.should_stop and result['foo'] > 10: self.should_stop = True return self.should_stop - stopper = Stopper() - tune.run(my_trainable, stop=stopper.stop) + def stop_all(self): + """Returns whether to stop trials and prevent new ones from starting.""" + return self.should_stop -Note that in the above example all trials will not stop immediately, but will do so once their current iterations are complete. + stopper = CustomStopper() + tune.run(my_trainable, stop=stopper) + + +Note that in the above example the currently running trials will not stop immediately but will do so once their current iterations are complete. Auto-Filled Results ------------------- diff --git a/python/ray/tune/__init__.py b/python/ray/tune/__init__.py index 132a6edd7..37e89b0df 100644 --- a/python/ray/tune/__init__.py +++ b/python/ray/tune/__init__.py @@ -2,6 +2,7 @@ from ray.tune.error import TuneError from ray.tune.tune import run_experiments, run from ray.tune.experiment import Experiment from ray.tune.analysis import ExperimentAnalysis, Analysis +from ray.tune.stopper import Stopper from ray.tune.registry import register_env, register_trainable from ray.tune.trainable import Trainable from ray.tune.durable_trainable import DurableTrainable @@ -20,6 +21,7 @@ __all__ = [ "register_trainable", "run", "run_experiments", + "Stopper", "Experiment", "function", "sample_from", diff --git a/python/ray/tune/examples/pbt_convnet_example.py b/python/ray/tune/examples/pbt_convnet_example.py index 28bf407af..387fcb4ac 100644 --- a/python/ray/tune/examples/pbt_convnet_example.py +++ b/python/ray/tune/examples/pbt_convnet_example.py @@ -98,17 +98,20 @@ if __name__ == "__main__": # __pbt_end__ # __tune_begin__ - class Stopper: + class CustomStopper(tune.Stopper): def 
__init__(self): self.should_stop = False - def stop(self, trial_id, result): + def __call__(self, trial_id, result): max_iter = 5 if args.smoke_test else 100 if not self.should_stop and result["mean_accuracy"] > 0.96: self.should_stop = True return self.should_stop or result["training_iteration"] >= max_iter - stopper = Stopper() + def stop_all(self): + return self.should_stop + + stopper = CustomStopper() analysis = tune.run( PytorchTrainble, @@ -116,7 +119,7 @@ if __name__ == "__main__": scheduler=scheduler, reuse_actors=True, verbose=1, - stop=stopper.stop, + stop=stopper, export_formats=[ExportFormat.MODEL], checkpoint_score_attr="mean_accuracy", checkpoint_freq=5, diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py index 0bd5a7fe4..7c4415b53 100644 --- a/python/ray/tune/experiment.py +++ b/python/ray/tune/experiment.py @@ -1,14 +1,13 @@ import copy -import inspect import logging import os import six -import types from ray.tune.error import TuneError from ray.tune.registry import register_trainable, get_trainable_cls from ray.tune.result import DEFAULT_RESULTS_DIR from ray.tune.sample import sample_from +from ray.tune.stopper import FunctionStopper, Stopper logger = logging.getLogger(__name__) @@ -107,18 +106,6 @@ class Experiment: _raise_deprecation_note( "sync_function", "sync_to_driver", soft=False) - stop = stop or {} - if not isinstance(stop, dict) and not callable(stop): - raise ValueError("Invalid stop criteria: {}. Must be a callable " - "or dict".format(stop)) - if callable(stop): - nargs = len(inspect.getargspec(stop).args) - is_method = isinstance(stop, types.MethodType) - if (is_method and nargs != 3) or (not is_method and nargs != 2): - raise ValueError( - "Invalid stop criteria: {}. 
Callable " - "criteria must take exactly 2 parameters.".format(stop)) - config = config or {} self._run_identifier = Experiment.register_if_needed(run) self.name = name or self._run_identifier @@ -127,11 +114,30 @@ class Experiment: else: self.remote_checkpoint_dir = None + self._stopper = None + stopping_criteria = {} + if not stop: + pass + elif isinstance(stop, dict): + stopping_criteria = stop + elif callable(stop): + if FunctionStopper.is_valid_function(stop): + self._stopper = FunctionStopper(stop) + elif issubclass(type(stop), Stopper): + self._stopper = stop + else: + raise ValueError("Provided stop object must be either a dict, " + "a function, or a subclass of " + "`ray.tune.Stopper`.") + else: + raise ValueError("Invalid stop criteria: {}. Must be a " + "callable or dict".format(stop)) + _raise_on_durable(self._run_identifier, sync_to_driver, upload_dir) spec = { "run": self._run_identifier, - "stop": stop, + "stop": stopping_criteria, "config": config, "resources_per_trial": resources_per_trial, "num_samples": num_samples, @@ -214,6 +220,10 @@ class Experiment: else: raise TuneError("Improper 'run' - not string nor trainable.") + @property + def stopper(self): + return self._stopper + @property def local_dir(self): return self.spec.get("local_dir") diff --git a/python/ray/tune/stopper.py b/python/ray/tune/stopper.py new file mode 100644 index 000000000..984239105 --- /dev/null +++ b/python/ray/tune/stopper.py @@ -0,0 +1,63 @@ +class Stopper: + """Base class for implementing a Tune experiment stopper. + + Allows users to implement experiment-level stopping via ``stop_all``. The + base methods raise ``NotImplementedError``, so subclasses need to + implement both ``__call__`` and ``stop_all``. + + ..
code-block:: python + + import time + from ray import tune + from ray.tune import Stopper + + class TimeStopper(Stopper): + def __init__(self): + self._start = time.time() + self._deadline = 300 + + def __call__(self, trial_id, result): + return False + + def stop_all(self): + return time.time() - self._start > self._deadline + + tune.run(Trainable, num_samples=200, stop=TimeStopper()) + + """ + + def __call__(self, trial_id, result): + """Returns true if the trial should be terminated given the result.""" + raise NotImplementedError + + def stop_all(self): + """Returns true if the experiment should be terminated.""" + raise NotImplementedError + + +class NoopStopper(Stopper): + def __call__(self, trial_id, result): + return False + + def stop_all(self): + return False + + +class FunctionStopper(Stopper): + def __init__(self, function): + self._fn = function + + def __call__(self, trial_id, result): + return self._fn(trial_id, result) + + def stop_all(self): + return False + + @classmethod + def is_valid_function(cls, fn): + is_function = callable(fn) and not issubclass(type(fn), Stopper) + if is_function and hasattr(fn, "stop_all"): + raise ValueError( + "Stop object must be ray.tune.Stopper subclass to be detected " + "correctly.") + return is_function diff --git a/python/ray/tune/tests/test_api.py b/python/ray/tune/tests/test_api.py index 22d51e614..c578b26bd 100644 --- a/python/ray/tune/tests/test_api.py +++ b/python/ray/tune/tests/test_api.py @@ -10,7 +10,7 @@ import ray from ray.rllib import _register_all from ray import tune -from ray.tune import DurableTrainable, Trainable, TuneError +from ray.tune import DurableTrainable, Trainable, TuneError, Stopper from ray.tune import register_env, register_trainable, run_experiments from ray.tune.schedulers import TrialScheduler, FIFOScheduler from ray.tune.trial import Trial @@ -452,28 +452,52 @@ class TrainableFunctionApiTest(unittest.TestCase): for i in range(10): reporter(test=i) - class Stopper: + class Stopclass:
def stop(self, trial_id, result): return result["test"] > 6 - [trial] = tune.run(train, stop=Stopper().stop).trials + [trial] = tune.run(train, stop=Stopclass().stop).trials self.assertEqual(trial.last_result["training_iteration"], 8) + def testStopper(self): + def train(config, reporter): + for i in range(10): + reporter(test=i) + + class CustomStopper(Stopper): + def __init__(self): + self._count = 0 + + def __call__(self, trial_id, result): + print("called") + self._count += 1 + return result["test"] > 6 + + def stop_all(self): + return self._count > 5 + + trials = tune.run(train, num_samples=5, stop=CustomStopper()).trials + self.assertTrue(all(t.status == Trial.TERMINATED for t in trials)) + self.assertTrue( + any( + t.last_result.get("training_iteration") is None + for t in trials)) + def testBadStoppingFunction(self): def train(config, reporter): for i in range(10): reporter(test=i) - class Stopper: + class CustomStopper: def stop(self, result): return result["test"] > 6 def stop(result): return result["test"] > 6 - with self.assertRaises(ValueError): - tune.run(train, stop=Stopper().stop) - with self.assertRaises(ValueError): + with self.assertRaises(TuneError): + tune.run(train, stop=CustomStopper().stop) + with self.assertRaises(TuneError): tune.run(train, stop=stop) def testEarlyReturn(self): diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index a7bfbbd4e..f02d082d3 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -330,9 +330,6 @@ class Trial: if result.get(DONE): return True - if callable(self.stopping_criterion): - return self.stopping_criterion(self.trial_id, result) - for criteria, stop_value in self.stopping_criterion.items(): if criteria not in result: raise TuneError( diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 2b171e250..8e2e3421f 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -9,6 +9,7 @@ import types import ray.cloudpickle as 
cloudpickle from ray.tune import TuneError +from ray.tune.stopper import NoopStopper from ray.tune.progress_reporter import trial_progress_str from ray.tune.ray_trial_executor import RayTrialExecutor from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE, @@ -98,6 +99,7 @@ class TrialRunner: local_checkpoint_dir=None, remote_checkpoint_dir=None, sync_to_cloud=None, + stopper=None, resume=False, server_port=TuneServer.DEFAULT_PORT, verbose=True, @@ -115,6 +117,8 @@ class TrialRunner: remote_checkpoint_dir (str): Remote path where global checkpoints are stored and restored from. Used if `resume` == REMOTE. + stopper: Custom class for stopping whole experiments. See + ``Stopper``. resume (str|False): see `tune.py:run`. sync_to_cloud (func|str): See `tune.py:run`. server_port (int): Port number for launching TuneServer. @@ -149,7 +153,7 @@ class TrialRunner: self._remote_checkpoint_dir = remote_checkpoint_dir self._syncer = get_cloud_syncer(local_checkpoint_dir, remote_checkpoint_dir, sync_to_cloud) - + self._stopper = stopper or NoopStopper() self._resumed = False if self._validate_resume(resume_type=resume): @@ -331,6 +335,8 @@ class TrialRunner: else: self.trial_executor.on_no_available_trials(self) + self._stop_experiment_if_needed() + try: with warn_if_slow("experiment_checkpoint"): self.checkpoint() @@ -385,6 +391,13 @@ class TrialRunner: """Returns whether this runner has at least the specified resources.""" return self.trial_executor.has_resources(resources) + def _stop_experiment_if_needed(self): + """Stops all trials if the user condition is satisfied.""" + + if self._stopper.stop_all(): + [self.trial_executor.stop_trial(t) for t in self._trials] + logger.info("All trials stopped due to ``stopper.stop_all``.") + def _get_next_trial(self): """Replenishes queue. 
@@ -435,7 +448,8 @@ class TrialRunner: self._total_time += result.get(TIME_THIS_ITER_S, 0) flat_result = flatten_dict(result) - if trial.should_stop(flat_result): + if self._stopper(trial.trial_id, + result) or trial.should_stop(flat_result): # Hook into scheduler self._scheduler_alg.on_trial_complete(self, trial, flat_result) self._search_alg.on_trial_complete( diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 305ff832a..f17e85f1a 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -112,11 +112,14 @@ def run(run_or_experiment, If Experiment, then Tune will execute training based on Experiment.spec. name (str): Name of experiment. - stop (dict|func): The stopping criteria. If dict, the keys may be + stop (dict|callable): The stopping criteria. If dict, the keys may be any field in the return result of 'train()', whichever is reached first. If function, it must take (trial_id, result) as arguments and return a boolean (True if trial should be stopped, - False otherwise). + False otherwise). This can also be an instance of a subclass of + ``ray.tune.Stopper``, which allows users to implement + custom experiment-wide stopping (e.g., stopping an entire Tune + run based on some time constraint). config (dict): Algorithm-specific configuration for Tune variant generation (e.g. env, hyperparams). Defaults to empty dict. Custom search algorithms may ignore this. @@ -243,6 +246,7 @@ def run(run_or_experiment, logger.info( "Running multiple concurrent experiments is experimental and may " "not work with certain features.") + for i, exp in enumerate(experiments): if not isinstance(exp, Experiment): run_identifier = Experiment.register_if_needed(exp) @@ -281,6 +285,7 @@ def run(run_or_experiment, local_checkpoint_dir=experiments[0].checkpoint_dir, remote_checkpoint_dir=experiments[0].remote_checkpoint_dir, sync_to_cloud=sync_to_cloud, + stopper=experiments[0].stopper, checkpoint_period=global_checkpoint_period, resume=resume, launch_web_server=with_server,