mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 01:03:15 +08:00
[tune] Experiment stopping API (#6886)
This commit is contained in:
@@ -2,6 +2,7 @@ from ray.tune.error import TuneError
|
||||
from ray.tune.tune import run_experiments, run
|
||||
from ray.tune.experiment import Experiment
|
||||
from ray.tune.analysis import ExperimentAnalysis, Analysis
|
||||
from ray.tune.stopper import Stopper
|
||||
from ray.tune.registry import register_env, register_trainable
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.durable_trainable import DurableTrainable
|
||||
@@ -20,6 +21,7 @@ __all__ = [
|
||||
"register_trainable",
|
||||
"run",
|
||||
"run_experiments",
|
||||
"Stopper",
|
||||
"Experiment",
|
||||
"function",
|
||||
"sample_from",
|
||||
|
||||
@@ -98,17 +98,20 @@ if __name__ == "__main__":
|
||||
# __pbt_end__
|
||||
|
||||
# __tune_begin__
|
||||
class Stopper:
|
||||
class CustomStopper(tune.Stopper):
|
||||
def __init__(self):
|
||||
self.should_stop = False
|
||||
|
||||
def stop(self, trial_id, result):
|
||||
def __call__(self, trial_id, result):
|
||||
max_iter = 5 if args.smoke_test else 100
|
||||
if not self.should_stop and result["mean_accuracy"] > 0.96:
|
||||
self.should_stop = True
|
||||
return self.should_stop or result["training_iteration"] >= max_iter
|
||||
|
||||
stopper = Stopper()
|
||||
def stop_all(self):
|
||||
return self.should_stop
|
||||
|
||||
stopper = CustomStopper()
|
||||
|
||||
analysis = tune.run(
|
||||
PytorchTrainble,
|
||||
@@ -116,7 +119,7 @@ if __name__ == "__main__":
|
||||
scheduler=scheduler,
|
||||
reuse_actors=True,
|
||||
verbose=1,
|
||||
stop=stopper.stop,
|
||||
stop=stopper,
|
||||
export_formats=[ExportFormat.MODEL],
|
||||
checkpoint_score_attr="mean_accuracy",
|
||||
checkpoint_freq=5,
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
import copy
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import six
|
||||
import types
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.registry import register_trainable, get_trainable_cls
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
from ray.tune.sample import sample_from
|
||||
from ray.tune.stopper import FunctionStopper, Stopper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -107,18 +106,6 @@ class Experiment:
|
||||
_raise_deprecation_note(
|
||||
"sync_function", "sync_to_driver", soft=False)
|
||||
|
||||
stop = stop or {}
|
||||
if not isinstance(stop, dict) and not callable(stop):
|
||||
raise ValueError("Invalid stop criteria: {}. Must be a callable "
|
||||
"or dict".format(stop))
|
||||
if callable(stop):
|
||||
nargs = len(inspect.getargspec(stop).args)
|
||||
is_method = isinstance(stop, types.MethodType)
|
||||
if (is_method and nargs != 3) or (not is_method and nargs != 2):
|
||||
raise ValueError(
|
||||
"Invalid stop criteria: {}. Callable "
|
||||
"criteria must take exactly 2 parameters.".format(stop))
|
||||
|
||||
config = config or {}
|
||||
self._run_identifier = Experiment.register_if_needed(run)
|
||||
self.name = name or self._run_identifier
|
||||
@@ -127,11 +114,30 @@ class Experiment:
|
||||
else:
|
||||
self.remote_checkpoint_dir = None
|
||||
|
||||
self._stopper = None
|
||||
stopping_criteria = {}
|
||||
if not stop:
|
||||
pass
|
||||
elif isinstance(stop, dict):
|
||||
stopping_criteria = stop
|
||||
elif callable(stop):
|
||||
if FunctionStopper.is_valid_function(stop):
|
||||
self._stopper = FunctionStopper(stop)
|
||||
elif issubclass(type(stop), Stopper):
|
||||
self._stopper = stop
|
||||
else:
|
||||
raise ValueError("Provided stop object must be either a dict, "
|
||||
"a function, or a subclass of "
|
||||
"`ray.tune.Stopper`.")
|
||||
else:
|
||||
raise ValueError("Invalid stop criteria: {}. Must be a "
|
||||
"callable or dict".format(stop))
|
||||
|
||||
_raise_on_durable(self._run_identifier, sync_to_driver, upload_dir)
|
||||
|
||||
spec = {
|
||||
"run": self._run_identifier,
|
||||
"stop": stop,
|
||||
"stop": stopping_criteria,
|
||||
"config": config,
|
||||
"resources_per_trial": resources_per_trial,
|
||||
"num_samples": num_samples,
|
||||
@@ -214,6 +220,10 @@ class Experiment:
|
||||
else:
|
||||
raise TuneError("Improper 'run' - not string nor trainable.")
|
||||
|
||||
@property
|
||||
def stopper(self):
|
||||
return self._stopper
|
||||
|
||||
@property
|
||||
def local_dir(self):
|
||||
return self.spec.get("local_dir")
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
class Stopper:
|
||||
"""Base class for implementing a Tune experiment stopper.
|
||||
|
||||
Allows users to implement experiment-level stopping via ``stop_all``. By
|
||||
default, this class does not stop any trials. Subclasses need to
|
||||
implement ``__call__`` and ``stop_all``.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import time
|
||||
from ray import tune
|
||||
from ray.tune import Stopper
|
||||
|
||||
class TimeStopper(Stopper):
|
||||
def __init__(self):
|
||||
self._start = time.time()
|
||||
self._deadline = 300
|
||||
|
||||
def __call__(self, trial_id, result):
|
||||
return False
|
||||
|
||||
def stop_all(self):
|
||||
return time.time() - self._start > self.deadline
|
||||
|
||||
tune.run(Trainable, num_samples=200, stop=TimeStopper())
|
||||
|
||||
"""
|
||||
|
||||
def __call__(self, trial_id, result):
|
||||
"""Returns true if the trial should be terminated given the result."""
|
||||
raise NotImplementedError
|
||||
|
||||
def stop_all(self):
|
||||
"""Returns true if the experiment should be terminated."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class NoopStopper(Stopper):
|
||||
def __call__(self, trial_id, result):
|
||||
return False
|
||||
|
||||
def stop_all(self):
|
||||
return False
|
||||
|
||||
|
||||
class FunctionStopper(Stopper):
|
||||
def __init__(self, function):
|
||||
self._fn = function
|
||||
|
||||
def __call__(self, trial_id, result):
|
||||
return self._fn(trial_id, result)
|
||||
|
||||
def stop_all(self):
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def is_valid_function(cls, fn):
|
||||
is_function = callable(fn) and not issubclass(type(fn), Stopper)
|
||||
if is_function and hasattr(fn, "stop_all"):
|
||||
raise ValueError(
|
||||
"Stop object must be ray.tune.Stopper subclass to be detected "
|
||||
"correctly.")
|
||||
return is_function
|
||||
@@ -10,7 +10,7 @@ import ray
|
||||
from ray.rllib import _register_all
|
||||
|
||||
from ray import tune
|
||||
from ray.tune import DurableTrainable, Trainable, TuneError
|
||||
from ray.tune import DurableTrainable, Trainable, TuneError, Stopper
|
||||
from ray.tune import register_env, register_trainable, run_experiments
|
||||
from ray.tune.schedulers import TrialScheduler, FIFOScheduler
|
||||
from ray.tune.trial import Trial
|
||||
@@ -452,28 +452,52 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
for i in range(10):
|
||||
reporter(test=i)
|
||||
|
||||
class Stopper:
|
||||
class Stopclass:
|
||||
def stop(self, trial_id, result):
|
||||
return result["test"] > 6
|
||||
|
||||
[trial] = tune.run(train, stop=Stopper().stop).trials
|
||||
[trial] = tune.run(train, stop=Stopclass().stop).trials
|
||||
self.assertEqual(trial.last_result["training_iteration"], 8)
|
||||
|
||||
def testStopper(self):
|
||||
def train(config, reporter):
|
||||
for i in range(10):
|
||||
reporter(test=i)
|
||||
|
||||
class CustomStopper(Stopper):
|
||||
def __init__(self):
|
||||
self._count = 0
|
||||
|
||||
def __call__(self, trial_id, result):
|
||||
print("called")
|
||||
self._count += 1
|
||||
return result["test"] > 6
|
||||
|
||||
def stop_all(self):
|
||||
return self._count > 5
|
||||
|
||||
trials = tune.run(train, num_samples=5, stop=CustomStopper()).trials
|
||||
self.assertTrue(all(t.status == Trial.TERMINATED for t in trials))
|
||||
self.assertTrue(
|
||||
any(
|
||||
t.last_result.get("training_iteration") is None
|
||||
for t in trials))
|
||||
|
||||
def testBadStoppingFunction(self):
|
||||
def train(config, reporter):
|
||||
for i in range(10):
|
||||
reporter(test=i)
|
||||
|
||||
class Stopper:
|
||||
class CustomStopper:
|
||||
def stop(self, result):
|
||||
return result["test"] > 6
|
||||
|
||||
def stop(result):
|
||||
return result["test"] > 6
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
tune.run(train, stop=Stopper().stop)
|
||||
with self.assertRaises(ValueError):
|
||||
with self.assertRaises(TuneError):
|
||||
tune.run(train, stop=CustomStopper().stop)
|
||||
with self.assertRaises(TuneError):
|
||||
tune.run(train, stop=stop)
|
||||
|
||||
def testEarlyReturn(self):
|
||||
|
||||
@@ -330,9 +330,6 @@ class Trial:
|
||||
if result.get(DONE):
|
||||
return True
|
||||
|
||||
if callable(self.stopping_criterion):
|
||||
return self.stopping_criterion(self.trial_id, result)
|
||||
|
||||
for criteria, stop_value in self.stopping_criterion.items():
|
||||
if criteria not in result:
|
||||
raise TuneError(
|
||||
|
||||
@@ -9,6 +9,7 @@ import types
|
||||
|
||||
import ray.cloudpickle as cloudpickle
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.stopper import NoopStopper
|
||||
from ray.tune.progress_reporter import trial_progress_str
|
||||
from ray.tune.ray_trial_executor import RayTrialExecutor
|
||||
from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE,
|
||||
@@ -98,6 +99,7 @@ class TrialRunner:
|
||||
local_checkpoint_dir=None,
|
||||
remote_checkpoint_dir=None,
|
||||
sync_to_cloud=None,
|
||||
stopper=None,
|
||||
resume=False,
|
||||
server_port=TuneServer.DEFAULT_PORT,
|
||||
verbose=True,
|
||||
@@ -115,6 +117,8 @@ class TrialRunner:
|
||||
remote_checkpoint_dir (str): Remote path where
|
||||
global checkpoints are stored and restored from. Used
|
||||
if `resume` == REMOTE.
|
||||
stopper: Custom class for stopping whole experiments. See
|
||||
``Stopper``.
|
||||
resume (str|False): see `tune.py:run`.
|
||||
sync_to_cloud (func|str): See `tune.py:run`.
|
||||
server_port (int): Port number for launching TuneServer.
|
||||
@@ -149,7 +153,7 @@ class TrialRunner:
|
||||
self._remote_checkpoint_dir = remote_checkpoint_dir
|
||||
self._syncer = get_cloud_syncer(local_checkpoint_dir,
|
||||
remote_checkpoint_dir, sync_to_cloud)
|
||||
|
||||
self._stopper = stopper or NoopStopper()
|
||||
self._resumed = False
|
||||
|
||||
if self._validate_resume(resume_type=resume):
|
||||
@@ -331,6 +335,8 @@ class TrialRunner:
|
||||
else:
|
||||
self.trial_executor.on_no_available_trials(self)
|
||||
|
||||
self._stop_experiment_if_needed()
|
||||
|
||||
try:
|
||||
with warn_if_slow("experiment_checkpoint"):
|
||||
self.checkpoint()
|
||||
@@ -385,6 +391,13 @@ class TrialRunner:
|
||||
"""Returns whether this runner has at least the specified resources."""
|
||||
return self.trial_executor.has_resources(resources)
|
||||
|
||||
def _stop_experiment_if_needed(self):
|
||||
"""Stops all trials if the user condition is satisfied."""
|
||||
|
||||
if self._stopper.stop_all():
|
||||
[self.trial_executor.stop_trial(t) for t in self._trials]
|
||||
logger.info("All trials stopped due to ``stopper.stop_all``.")
|
||||
|
||||
def _get_next_trial(self):
|
||||
"""Replenishes queue.
|
||||
|
||||
@@ -435,7 +448,8 @@ class TrialRunner:
|
||||
self._total_time += result.get(TIME_THIS_ITER_S, 0)
|
||||
|
||||
flat_result = flatten_dict(result)
|
||||
if trial.should_stop(flat_result):
|
||||
if self._stopper(trial.trial_id,
|
||||
result) or trial.should_stop(flat_result):
|
||||
# Hook into scheduler
|
||||
self._scheduler_alg.on_trial_complete(self, trial, flat_result)
|
||||
self._search_alg.on_trial_complete(
|
||||
|
||||
@@ -112,11 +112,14 @@ def run(run_or_experiment,
|
||||
If Experiment, then Tune will execute training based on
|
||||
Experiment.spec.
|
||||
name (str): Name of experiment.
|
||||
stop (dict|func): The stopping criteria. If dict, the keys may be
|
||||
stop (dict|callable): The stopping criteria. If dict, the keys may be
|
||||
any field in the return result of 'train()', whichever is
|
||||
reached first. If function, it must take (trial_id, result) as
|
||||
arguments and return a boolean (True if trial should be stopped,
|
||||
False otherwise).
|
||||
False otherwise). This can also be a subclass of
|
||||
``ray.tune.Stopper``, which allows users to implement
|
||||
custom experiment-wide stopping (i.e., stopping an entire Tune
|
||||
run based on some time constraint).
|
||||
config (dict): Algorithm-specific configuration for Tune variant
|
||||
generation (e.g. env, hyperparams). Defaults to empty dict.
|
||||
Custom search algorithms may ignore this.
|
||||
@@ -243,6 +246,7 @@ def run(run_or_experiment,
|
||||
logger.info(
|
||||
"Running multiple concurrent experiments is experimental and may "
|
||||
"not work with certain features.")
|
||||
|
||||
for i, exp in enumerate(experiments):
|
||||
if not isinstance(exp, Experiment):
|
||||
run_identifier = Experiment.register_if_needed(exp)
|
||||
@@ -281,6 +285,7 @@ def run(run_or_experiment,
|
||||
local_checkpoint_dir=experiments[0].checkpoint_dir,
|
||||
remote_checkpoint_dir=experiments[0].remote_checkpoint_dir,
|
||||
sync_to_cloud=sync_to_cloud,
|
||||
stopper=experiments[0].stopper,
|
||||
checkpoint_period=global_checkpoint_period,
|
||||
resume=resume,
|
||||
launch_web_server=with_server,
|
||||
|
||||
Reference in New Issue
Block a user