[tune] Experiment stopping API (#6886)

This commit is contained in:
Richard Liaw
2020-01-30 00:34:08 -08:00
committed by GitHub
parent 5bdfc50bf6
commit 5ab395236b
9 changed files with 172 additions and 39 deletions
+2
View File
@@ -2,6 +2,7 @@ from ray.tune.error import TuneError
from ray.tune.tune import run_experiments, run
from ray.tune.experiment import Experiment
from ray.tune.analysis import ExperimentAnalysis, Analysis
from ray.tune.stopper import Stopper
from ray.tune.registry import register_env, register_trainable
from ray.tune.trainable import Trainable
from ray.tune.durable_trainable import DurableTrainable
@@ -20,6 +21,7 @@ __all__ = [
"register_trainable",
"run",
"run_experiments",
"Stopper",
"Experiment",
"function",
"sample_from",
@@ -98,17 +98,20 @@ if __name__ == "__main__":
# __pbt_end__
# __tune_begin__
class Stopper:
class CustomStopper(tune.Stopper):
def __init__(self):
self.should_stop = False
def stop(self, trial_id, result):
def __call__(self, trial_id, result):
max_iter = 5 if args.smoke_test else 100
if not self.should_stop and result["mean_accuracy"] > 0.96:
self.should_stop = True
return self.should_stop or result["training_iteration"] >= max_iter
stopper = Stopper()
def stop_all(self):
return self.should_stop
stopper = CustomStopper()
analysis = tune.run(
PytorchTrainble,
@@ -116,7 +119,7 @@ if __name__ == "__main__":
scheduler=scheduler,
reuse_actors=True,
verbose=1,
stop=stopper.stop,
stop=stopper,
export_formats=[ExportFormat.MODEL],
checkpoint_score_attr="mean_accuracy",
checkpoint_freq=5,
+25 -15
View File
@@ -1,14 +1,13 @@
import copy
import inspect
import logging
import os
import six
import types
from ray.tune.error import TuneError
from ray.tune.registry import register_trainable, get_trainable_cls
from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.sample import sample_from
from ray.tune.stopper import FunctionStopper, Stopper
logger = logging.getLogger(__name__)
@@ -107,18 +106,6 @@ class Experiment:
_raise_deprecation_note(
"sync_function", "sync_to_driver", soft=False)
stop = stop or {}
if not isinstance(stop, dict) and not callable(stop):
raise ValueError("Invalid stop criteria: {}. Must be a callable "
"or dict".format(stop))
if callable(stop):
nargs = len(inspect.getargspec(stop).args)
is_method = isinstance(stop, types.MethodType)
if (is_method and nargs != 3) or (not is_method and nargs != 2):
raise ValueError(
"Invalid stop criteria: {}. Callable "
"criteria must take exactly 2 parameters.".format(stop))
config = config or {}
self._run_identifier = Experiment.register_if_needed(run)
self.name = name or self._run_identifier
@@ -127,11 +114,30 @@ class Experiment:
else:
self.remote_checkpoint_dir = None
self._stopper = None
stopping_criteria = {}
if not stop:
pass
elif isinstance(stop, dict):
stopping_criteria = stop
elif callable(stop):
if FunctionStopper.is_valid_function(stop):
self._stopper = FunctionStopper(stop)
elif issubclass(type(stop), Stopper):
self._stopper = stop
else:
raise ValueError("Provided stop object must be either a dict, "
"a function, or a subclass of "
"`ray.tune.Stopper`.")
else:
raise ValueError("Invalid stop criteria: {}. Must be a "
"callable or dict".format(stop))
_raise_on_durable(self._run_identifier, sync_to_driver, upload_dir)
spec = {
"run": self._run_identifier,
"stop": stop,
"stop": stopping_criteria,
"config": config,
"resources_per_trial": resources_per_trial,
"num_samples": num_samples,
@@ -214,6 +220,10 @@ class Experiment:
else:
raise TuneError("Improper 'run' - not string nor trainable.")
@property
def stopper(self):
return self._stopper
@property
def local_dir(self):
return self.spec.get("local_dir")
+63
View File
@@ -0,0 +1,63 @@
class Stopper:
"""Base class for implementing a Tune experiment stopper.
Allows users to implement experiment-level stopping via ``stop_all``. By
default, this class does not stop any trials. Subclasses need to
implement ``__call__`` and ``stop_all``.
.. code-block:: python
import time
from ray import tune
from ray.tune import Stopper
class TimeStopper(Stopper):
def __init__(self):
self._start = time.time()
self._deadline = 300
def __call__(self, trial_id, result):
return False
def stop_all(self):
return time.time() - self._start > self.deadline
tune.run(Trainable, num_samples=200, stop=TimeStopper())
"""
def __call__(self, trial_id, result):
"""Returns true if the trial should be terminated given the result."""
raise NotImplementedError
def stop_all(self):
"""Returns true if the experiment should be terminated."""
raise NotImplementedError
class NoopStopper(Stopper):
def __call__(self, trial_id, result):
return False
def stop_all(self):
return False
class FunctionStopper(Stopper):
def __init__(self, function):
self._fn = function
def __call__(self, trial_id, result):
return self._fn(trial_id, result)
def stop_all(self):
return False
@classmethod
def is_valid_function(cls, fn):
is_function = callable(fn) and not issubclass(type(fn), Stopper)
if is_function and hasattr(fn, "stop_all"):
raise ValueError(
"Stop object must be ray.tune.Stopper subclass to be detected "
"correctly.")
return is_function
+31 -7
View File
@@ -10,7 +10,7 @@ import ray
from ray.rllib import _register_all
from ray import tune
from ray.tune import DurableTrainable, Trainable, TuneError
from ray.tune import DurableTrainable, Trainable, TuneError, Stopper
from ray.tune import register_env, register_trainable, run_experiments
from ray.tune.schedulers import TrialScheduler, FIFOScheduler
from ray.tune.trial import Trial
@@ -452,28 +452,52 @@ class TrainableFunctionApiTest(unittest.TestCase):
for i in range(10):
reporter(test=i)
class Stopper:
class Stopclass:
def stop(self, trial_id, result):
return result["test"] > 6
[trial] = tune.run(train, stop=Stopper().stop).trials
[trial] = tune.run(train, stop=Stopclass().stop).trials
self.assertEqual(trial.last_result["training_iteration"], 8)
def testStopper(self):
def train(config, reporter):
for i in range(10):
reporter(test=i)
class CustomStopper(Stopper):
def __init__(self):
self._count = 0
def __call__(self, trial_id, result):
print("called")
self._count += 1
return result["test"] > 6
def stop_all(self):
return self._count > 5
trials = tune.run(train, num_samples=5, stop=CustomStopper()).trials
self.assertTrue(all(t.status == Trial.TERMINATED for t in trials))
self.assertTrue(
any(
t.last_result.get("training_iteration") is None
for t in trials))
def testBadStoppingFunction(self):
def train(config, reporter):
for i in range(10):
reporter(test=i)
class Stopper:
class CustomStopper:
def stop(self, result):
return result["test"] > 6
def stop(result):
return result["test"] > 6
with self.assertRaises(ValueError):
tune.run(train, stop=Stopper().stop)
with self.assertRaises(ValueError):
with self.assertRaises(TuneError):
tune.run(train, stop=CustomStopper().stop)
with self.assertRaises(TuneError):
tune.run(train, stop=stop)
def testEarlyReturn(self):
-3
View File
@@ -330,9 +330,6 @@ class Trial:
if result.get(DONE):
return True
if callable(self.stopping_criterion):
return self.stopping_criterion(self.trial_id, result)
for criteria, stop_value in self.stopping_criterion.items():
if criteria not in result:
raise TuneError(
+16 -2
View File
@@ -9,6 +9,7 @@ import types
import ray.cloudpickle as cloudpickle
from ray.tune import TuneError
from ray.tune.stopper import NoopStopper
from ray.tune.progress_reporter import trial_progress_str
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE,
@@ -98,6 +99,7 @@ class TrialRunner:
local_checkpoint_dir=None,
remote_checkpoint_dir=None,
sync_to_cloud=None,
stopper=None,
resume=False,
server_port=TuneServer.DEFAULT_PORT,
verbose=True,
@@ -115,6 +117,8 @@ class TrialRunner:
remote_checkpoint_dir (str): Remote path where
global checkpoints are stored and restored from. Used
if `resume` == REMOTE.
stopper: Custom class for stopping whole experiments. See
``Stopper``.
resume (str|False): see `tune.py:run`.
sync_to_cloud (func|str): See `tune.py:run`.
server_port (int): Port number for launching TuneServer.
@@ -149,7 +153,7 @@ class TrialRunner:
self._remote_checkpoint_dir = remote_checkpoint_dir
self._syncer = get_cloud_syncer(local_checkpoint_dir,
remote_checkpoint_dir, sync_to_cloud)
self._stopper = stopper or NoopStopper()
self._resumed = False
if self._validate_resume(resume_type=resume):
@@ -331,6 +335,8 @@ class TrialRunner:
else:
self.trial_executor.on_no_available_trials(self)
self._stop_experiment_if_needed()
try:
with warn_if_slow("experiment_checkpoint"):
self.checkpoint()
@@ -385,6 +391,13 @@ class TrialRunner:
"""Returns whether this runner has at least the specified resources."""
return self.trial_executor.has_resources(resources)
def _stop_experiment_if_needed(self):
"""Stops all trials if the user condition is satisfied."""
if self._stopper.stop_all():
[self.trial_executor.stop_trial(t) for t in self._trials]
logger.info("All trials stopped due to ``stopper.stop_all``.")
def _get_next_trial(self):
"""Replenishes queue.
@@ -435,7 +448,8 @@ class TrialRunner:
self._total_time += result.get(TIME_THIS_ITER_S, 0)
flat_result = flatten_dict(result)
if trial.should_stop(flat_result):
if self._stopper(trial.trial_id,
result) or trial.should_stop(flat_result):
# Hook into scheduler
self._scheduler_alg.on_trial_complete(self, trial, flat_result)
self._search_alg.on_trial_complete(
+7 -2
View File
@@ -112,11 +112,14 @@ def run(run_or_experiment,
If Experiment, then Tune will execute training based on
Experiment.spec.
name (str): Name of experiment.
stop (dict|func): The stopping criteria. If dict, the keys may be
stop (dict|callable): The stopping criteria. If dict, the keys may be
any field in the return result of 'train()', whichever is
reached first. If function, it must take (trial_id, result) as
arguments and return a boolean (True if trial should be stopped,
False otherwise).
False otherwise). This can also be a subclass of
``ray.tune.Stopper``, which allows users to implement
custom experiment-wide stopping (i.e., stopping an entire Tune
run based on some time constraint).
config (dict): Algorithm-specific configuration for Tune variant
generation (e.g. env, hyperparams). Defaults to empty dict.
Custom search algorithms may ignore this.
@@ -243,6 +246,7 @@ def run(run_or_experiment,
logger.info(
"Running multiple concurrent experiments is experimental and may "
"not work with certain features.")
for i, exp in enumerate(experiments):
if not isinstance(exp, Experiment):
run_identifier = Experiment.register_if_needed(exp)
@@ -281,6 +285,7 @@ def run(run_or_experiment,
local_checkpoint_dir=experiments[0].checkpoint_dir,
remote_checkpoint_dir=experiments[0].remote_checkpoint_dir,
sync_to_cloud=sync_to_cloud,
stopper=experiments[0].stopper,
checkpoint_period=global_checkpoint_period,
resume=resume,
launch_web_server=with_server,