[tune] API revamp fix (#10518)

This commit is contained in:
Richard Liaw
2020-09-05 15:34:53 -07:00
committed by GitHub
parent 8a891b3c30
commit 551c597312
26 changed files with 349 additions and 269 deletions
+3 -1
View File
@@ -1,5 +1,6 @@
from ray.tune.error import TuneError
from ray.tune.tune import run_experiments, run
from ray.tune.syncer import SyncConfig
from ray.tune.experiment import Experiment
from ray.tune.analysis import ExperimentAnalysis, Analysis
from ray.tune.stopper import Stopper, EarlyStopping
@@ -26,5 +27,6 @@ __all__ = [
"loguniform", "qloguniform", "ExperimentAnalysis", "Analysis",
"CLIReporter", "JupyterNotebookReporter", "ProgressReporter", "report",
"get_trial_dir", "get_trial_name", "get_trial_id", "make_checkpoint_dir",
"save_checkpoint", "checkpoint_dir", "create_searcher", "create_scheduler"
"save_checkpoint", "checkpoint_dir", "SyncConfig", "create_searcher",
"create_scheduler"
]
+9
View File
@@ -2,6 +2,15 @@ import getpass
import os
def is_ray_cluster():
    """Reports whether the Ray bootstrap config file is present.

    The file ``~/ray_bootstrap_config.yaml`` always exists on an
    autoscaling cluster / a cluster started with the ray cluster launcher,
    so its presence is used as the signal.

    Returns:
        bool: True if the bootstrap config file exists in the current
            user's home directory, False otherwise.
    """
    bootstrap_config = os.path.expanduser("~/ray_bootstrap_config.yaml")
    return os.path.exists(bootstrap_config)
def get_ssh_user():
"""Returns ssh username for connecting to cluster workers."""
@@ -102,6 +102,12 @@ if __name__ == "__main__":
address = None if args.local else "auto"
ray.init(address=address)
sync_config = tune.SyncConfig(
sync_to_driver=False,
sync_on_checkpoint=False,
upload_dir="s3://ray-tune-test/exps/",
)
config = {
"seed": None,
"startup_delay": 0.001,
@@ -117,12 +123,9 @@ if __name__ == "__main__":
config=config,
num_samples=4,
verbose=1,
queue_trials=True,
# fault tolerance parameters
sync_config=sync_config,
max_failures=-1,
checkpoint_freq=20,
sync_to_driver=False,
sync_on_checkpoint=False,
upload_dir="s3://ray-tune-test/exps/",
checkpoint_score_attr="training_iteration",
)
+17 -2
View File
@@ -12,6 +12,7 @@ import ray
from ray.exceptions import GetTimeoutError
from ray import ray_constants
from ray.resource_spec import ResourceSpec
from ray.tune.cluster_info import is_ray_cluster
from ray.tune.durable_trainable import DurableTrainable
from ray.tune.error import AbortTrialExecution, TuneError
from ray.tune.logger import NoopLogger
@@ -135,10 +136,24 @@ class RayTrialExecutor(TrialExecutor):
"""An implementation of TrialExecutor based on Ray."""
def __init__(self,
queue_trials=False,
queue_trials=None,
reuse_actors=False,
ray_auto_init=False,
ray_auto_init=None,
refresh_period=RESOURCE_REFRESH_PERIOD):
if queue_trials is None:
if os.environ.get("TUNE_DISABLE_QUEUE_TRIALS") == "1":
logger.info("'TUNE_DISABLE_QUEUE_TRIALS=1' detected.")
queue_trials = False
elif is_ray_cluster():
queue_trials = True
if ray_auto_init is None:
if os.environ.get("TUNE_DISABLE_AUTO_INIT") == "1":
logger.info("'TUNE_DISABLE_AUTO_INIT=1' detected.")
ray_auto_init = False
else:
ray_auto_init = True
super(RayTrialExecutor, self).__init__(queue_trials)
# Check for if we are launching a trial without resources in kick off
# autoscaler.
+53 -3
View File
@@ -1,14 +1,17 @@
from typing import Any
import distutils
import logging
import os
import time
from dataclasses import dataclass
from inspect import isclass
from shlex import quote
from ray import ray_constants
from ray import services
from ray.util.debug import log_once
from ray.tune.utils.util import env_integer
from ray.tune.cluster_info import get_ssh_key, get_ssh_user
from ray.tune.sync_client import (CommandBasedClient, get_sync_client,
get_cloud_sync_client, NOOP)
@@ -17,8 +20,7 @@ logger = logging.getLogger(__name__)
# Syncing period for syncing local checkpoints to cloud.
# If the env variable is not set, sync happens every 300 seconds.
CLOUD_SYNC_PERIOD = ray_constants.env_integer(
key="TUNE_CLOUD_SYNC_S", default=300)
CLOUD_SYNC_PERIOD = 300
# Syncing period for syncing worker logs to driver.
NODE_SYNC_PERIOD = 300
@@ -32,6 +34,18 @@ def wait_for_sync():
syncer.wait()
def set_sync_periods(sync_config):
    """Sets the module-level sync periods from a sync config.

    Updates the ``NODE_SYNC_PERIOD`` and ``CLOUD_SYNC_PERIOD`` globals.

    Args:
        sync_config: Object exposing ``node_sync_period`` and
            ``cloud_sync_period`` (e.g. a ``SyncConfig``). Both values
            are coerced with ``int()``.

    If the deprecated ``TUNE_CLOUD_SYNC_S`` environment variable is set to
    a non-empty value, a deprecation warning is logged and its value
    overrides ``sync_config.cloud_sync_period``.
    """
    global CLOUD_SYNC_PERIOD
    global NODE_SYNC_PERIOD
    NODE_SYNC_PERIOD = int(sync_config.node_sync_period)
    CLOUD_SYNC_PERIOD = int(sync_config.cloud_sync_period)
    # The deprecated env override must be applied *after* the config value.
    # Previously it was assigned first and then unconditionally overwritten,
    # so users were warned about the variable while it was silently ignored.
    if os.environ.get("TUNE_CLOUD_SYNC_S"):
        logger.warning("'TUNE_CLOUD_SYNC_S' is deprecated. Set "
                       "`cloud_sync_period` via tune.SyncConfig instead.")
        CLOUD_SYNC_PERIOD = env_integer(key="TUNE_CLOUD_SYNC_S", default=300)
def log_sync_template(options=""):
"""Template enabling syncs between driver and worker when possible.
Requires ray cluster to be started with the autoscaler. Also requires
@@ -63,6 +77,42 @@ def log_sync_template(options=""):
return template.format(options=options, rsh=quote(rsh))
@dataclass
class SyncConfig:
    """Configuration object for syncing.

    Args:
        upload_dir (str): Optional URI to sync training results and
            checkpoints to (e.g. ``s3://bucket`` or ``gs://bucket``).
        sync_to_cloud (func|str): Function for syncing the local_dir to and
            from upload_dir. If string, then it must be a string template
            that includes `{source}` and `{target}` for the syncer to run.
            If not provided, the sync command defaults to standard S3 or
            gsutil sync commands. By default local_dir is synced to
            remote_dir every 300 seconds. To change this, set
            ``cloud_sync_period``; the ``TUNE_CLOUD_SYNC_S`` environment
            variable is deprecated.
        sync_to_driver (func|str|bool): Function for syncing trial logdir
            from remote node to local. If string, then it must be a string
            template that includes `{source}` and `{target}` for the syncer
            to run. If True or not provided, it defaults to using rsync.
            If False, syncing to driver is disabled.
        sync_on_checkpoint (bool): Force sync-down of trial checkpoint to
            driver. If set to False, checkpoint syncing from worker to
            driver is asynchronous and best-effort. This does not affect
            persistent storage syncing. Defaults to True.
        node_sync_period (int): Syncing period for syncing worker logs to
            driver. Defaults to 300.
        cloud_sync_period (int): Syncing period for syncing local
            checkpoints to cloud. Defaults to 300.
    """
    # None means "no upload location configured".
    upload_dir: str = None
    # None selects the default cloud sync command (see docstring).
    sync_to_cloud: Any = None
    # None behaves like True (default rsync-based sync); False disables it.
    sync_to_driver: Any = None
    sync_on_checkpoint: bool = True
    node_sync_period: int = 300
    cloud_sync_period: int = 300
class Syncer:
def __init__(self, local_dir, remote_dir, sync_client=NOOP):
"""Syncs between two directories with the sync_function.
@@ -1,3 +1,4 @@
import os
import argparse
from ray.tune import run
@@ -44,9 +45,9 @@ if __name__ == "__main__":
algo = ConcurrencyLimiter(algo, max_concurrent=1)
from ray.tune import register_trainable
register_trainable("trainable", MyTrainableClass)
os.environ["TUNE_GLOBAL_CHECKPOINT_S"] = "0"
run("trainable",
search_alg=algo,
global_checkpoint_period=0,
resume=args.resume,
verbose=0,
num_samples=20,
+14 -13
View File
@@ -245,17 +245,19 @@ class TrainableFunctionApiTest(unittest.TestCase):
register_trainable("B", B)
def f(cpus, gpus, queue_trials):
return run_experiments(
{
"foo": {
"run": "B",
"config": {
"cpu": cpus,
"gpu": gpus,
},
}
},
queue_trials=queue_trials)[0]
if not queue_trials:
os.environ["TUNE_DISABLE_QUEUE_TRIALS"] = "1"
else:
os.environ.pop("TUNE_DISABLE_QUEUE_TRIALS", None)
return run_experiments({
"foo": {
"run": "B",
"config": {
"cpu": cpus,
"gpu": gpus,
},
}
})[0]
# Should all succeed
self.assertEqual(f(0, 0, False).status, Trial.TERMINATED)
@@ -639,8 +641,7 @@ class TrainableFunctionApiTest(unittest.TestCase):
loggers=None)
trials = tune.run(test, raise_on_failed_trial=False, **config).trials
self.assertEqual(Counter(t.status for t in trials)["ERROR"], 5)
new_trials = tune.run(
test, resume=True, run_errored_only=True, **config).trials
new_trials = tune.run(test, resume="ERRORED_ONLY", **config).trials
self.assertEqual(Counter(t.status for t in new_trials)["ERROR"], 0)
self.assertTrue(
all(t.last_result.get("hello") == 123 for t in new_trials))
+3 -1
View File
@@ -642,10 +642,13 @@ def test_cluster_interrupt(start_connected_cluster, tmpdir):
for line in inspect.getsource(_Mock).split("\n"))
script = """
import os
import time
import ray
from ray import tune
os.environ["TUNE_GLOBAL_CHECKPOINT_S"] = "0"
ray.init(address="{address}")
{fail_class_code}
@@ -656,7 +659,6 @@ tune.run(
stop=dict(training_iteration=5),
local_dir="{checkpoint_dir}",
checkpoint_freq=1,
global_checkpoint_period=0,
max_failures=1,
raise_on_failed_trial=False)
""".format(
@@ -147,7 +147,6 @@ class ExperimentAnalysisSuite(unittest.TestCase):
MyTrainableClass,
name="test_example",
local_dir=self.test_dir,
return_trials=False,
stop={"training_iteration": 1},
num_samples=1,
config={
@@ -135,7 +135,6 @@ class AnalysisSuite(unittest.TestCase):
run(MyTrainableClass,
name=test_name,
local_dir=self.test_dir,
return_trials=False,
stop={"training_iteration": 1},
num_samples=self.num_samples,
config={
@@ -16,7 +16,7 @@ from ray.cluster_utils import Cluster
class RayTrialExecutorTest(unittest.TestCase):
def setUp(self):
self.trial_executor = RayTrialExecutor(queue_trials=False)
ray.init()
ray.init(ignore_reinit_error=True)
_register_all() # Needed for flaky tests
def tearDown(self):
@@ -182,8 +182,6 @@ class RayTrialExecutorTest(unittest.TestCase):
class RayExecutorQueueTest(unittest.TestCase):
def setUp(self):
self.trial_executor = RayTrialExecutor(
queue_trials=True, refresh_period=0)
self.cluster = Cluster(
initialize_head=True,
connect=True,
@@ -193,6 +191,8 @@ class RayExecutorQueueTest(unittest.TestCase):
"num_heartbeats_timeout": 10
}
})
self.trial_executor = RayTrialExecutor(
queue_trials=True, refresh_period=0)
# Pytest doesn't play nicely with imports
_register_all()
@@ -247,8 +247,8 @@ class RayExecutorQueueTest(unittest.TestCase):
class LocalModeExecutorTest(RayTrialExecutorTest):
def setUp(self):
self.trial_executor = RayTrialExecutor(queue_trials=False)
ray.init(local_mode=True)
self.trial_executor = RayTrialExecutor(queue_trials=False)
def tearDown(self):
ray.shutdown()
+60 -53
View File
@@ -31,12 +31,11 @@ class TestSyncFunctionality(unittest.TestCase):
"__fake",
name="foo",
max_failures=0,
**{
"stop": {
"training_iteration": 1
},
"sync_to_cloud": "echo {source} {target}"
}).trials
stop={
"training_iteration": 1
},
sync_config=tune.SyncConfig(
**{"sync_to_cloud": "echo {source} {target}"})).trials
@patch("ray.tune.sync_client.S3_PREFIX", "test")
def testCloudProperString(self):
@@ -45,26 +44,26 @@ class TestSyncFunctionality(unittest.TestCase):
"__fake",
name="foo",
max_failures=0,
**{
"stop": {
"training_iteration": 1
},
stop={
"training_iteration": 1
},
sync_config=tune.SyncConfig(**{
"upload_dir": "test",
"sync_to_cloud": "ls {target}"
}).trials
})).trials
with self.assertRaises(ValueError):
[trial] = tune.run(
"__fake",
name="foo",
max_failures=0,
**{
"stop": {
"training_iteration": 1
},
stop={
"training_iteration": 1
},
sync_config=tune.SyncConfig(**{
"upload_dir": "test",
"sync_to_cloud": "ls {source}"
}).trials
})).trials
tmpdir = tempfile.mkdtemp()
logfile = os.path.join(tmpdir, "test.log")
@@ -73,13 +72,14 @@ class TestSyncFunctionality(unittest.TestCase):
"__fake",
name="foo",
max_failures=0,
**{
"stop": {
"training_iteration": 1
},
"upload_dir": "test",
"sync_to_cloud": "echo {source} {target} > " + logfile
}).trials
stop={
"training_iteration": 1
},
sync_config=tune.SyncConfig(
**{
"upload_dir": "test",
"sync_to_cloud": "echo {source} {target} > " + logfile
})).trials
with open(logfile) as f:
lines = f.read()
self.assertTrue("test" in lines)
@@ -89,42 +89,41 @@ class TestSyncFunctionality(unittest.TestCase):
"""Tests that invalid commands throw.."""
with self.assertRaises(TuneError):
# This raises TuneError because logger is init in safe zone.
sync_config = tune.SyncConfig(sync_to_driver="ls {target}")
[trial] = tune.run(
"__fake",
name="foo",
max_failures=0,
**{
"stop": {
"training_iteration": 1
},
"sync_to_driver": "ls {target}"
}).trials
stop={
"training_iteration": 1
},
sync_config=sync_config,
).trials
with self.assertRaises(TuneError):
# This raises TuneError because logger is init in safe zone.
sync_config = tune.SyncConfig(sync_to_driver="ls {source}")
[trial] = tune.run(
"__fake",
name="foo",
max_failures=0,
**{
"stop": {
"training_iteration": 1
},
"sync_to_driver": "ls {source}"
sync_config=sync_config,
stop={
"training_iteration": 1
}).trials
with patch.object(CommandBasedClient, "_execute") as mock_fn:
with patch("ray.services.get_node_ip_address") as mock_sync:
sync_config = tune.SyncConfig(
sync_to_driver="echo {source} {target}")
mock_sync.return_value = "0.0.0.0"
[trial] = tune.run(
"__fake",
name="foo",
max_failures=0,
**{
"stop": {
"training_iteration": 1
},
"sync_to_driver": "echo {source} {target}"
sync_config=sync_config,
stop={
"training_iteration": 1
}).trials
self.assertGreater(mock_fn.call_count, 0)
@@ -137,6 +136,8 @@ class TestSyncFunctionality(unittest.TestCase):
for filename in glob.glob(os.path.join(local, "*.json")):
shutil.copy(filename, remote)
sync_config = tune.SyncConfig(
upload_dir=tmpdir2, sync_to_cloud=sync_func)
[trial] = tune.run(
"__fake",
name="foo",
@@ -145,8 +146,7 @@ class TestSyncFunctionality(unittest.TestCase):
stop={
"training_iteration": 1
},
upload_dir=tmpdir2,
sync_to_cloud=sync_func).trials
sync_config=sync_config).trials
test_file_path = glob.glob(os.path.join(tmpdir2, "foo", "*.json"))
self.assertTrue(test_file_path)
shutil.rmtree(tmpdir)
@@ -167,18 +167,21 @@ class TestSyncFunctionality(unittest.TestCase):
def counter(local, remote):
mock()
tune.syncer.CLOUD_SYNC_PERIOD = 1
sync_config = tune.SyncConfig(
upload_dir="test", sync_to_cloud=counter, cloud_sync_period=1)
# This was originally set to 0.5
os.environ["TUNE_GLOBAL_CHECKPOINT_S"] = "0"
self.addCleanup(
lambda: os.environ.pop("TUNE_GLOBAL_CHECKPOINT_S", None))
[trial] = tune.run(
trainable,
name="foo",
max_failures=0,
local_dir=tmpdir,
upload_dir="test",
sync_to_cloud=counter,
stop={
"training_iteration": 10
},
global_checkpoint_period=0.5,
sync_config=sync_config,
).trials
self.assertEqual(mock.call_count, 12)
@@ -192,6 +195,9 @@ class TestSyncFunctionality(unittest.TestCase):
print("writing to", f.name)
f.write(source)
sync_config = tune.SyncConfig(
sync_to_driver=sync_func_driver, node_sync_period=5)
[trial] = tune.run(
"__fake",
name="foo",
@@ -199,12 +205,13 @@ class TestSyncFunctionality(unittest.TestCase):
stop={
"training_iteration": 1
},
sync_to_driver=sync_func_driver).trials
sync_config=sync_config).trials
test_file_path = os.path.join(trial.logdir, "test.log2")
self.assertFalse(os.path.exists(test_file_path))
with patch("ray.services.get_node_ip_address") as mock_sync:
mock_sync.return_value = "0.0.0.0"
sync_config = tune.SyncConfig(sync_to_driver=sync_func_driver)
[trial] = tune.run(
"__fake",
name="foo",
@@ -212,7 +219,7 @@ class TestSyncFunctionality(unittest.TestCase):
stop={
"training_iteration": 1
},
sync_to_driver=sync_func_driver).trials
sync_config=sync_config).trials
test_file_path = os.path.join(trial.logdir, "test.log2")
self.assertTrue(os.path.exists(test_file_path))
os.remove(test_file_path)
@@ -223,17 +230,17 @@ class TestSyncFunctionality(unittest.TestCase):
def sync_func(source, target):
pass
sync_config = tune.SyncConfig(sync_to_driver=sync_func)
with patch.object(CommandBasedClient, "_execute") as mock_sync:
[trial] = tune.run(
"__fake",
name="foo",
max_failures=0,
**{
"stop": {
"training_iteration": 1
},
"sync_to_driver": sync_func
}).trials
stop={
"training_iteration": 1
},
sync_config=sync_config).trials
self.assertEqual(mock_sync.call_count, 0)
+2 -7
View File
@@ -385,10 +385,7 @@ class TrialRunnerTest3(unittest.TestCase):
assert trials[0].status == Trial.ERROR
del runner
new_runner = TrialRunner(
run_errored_only=False,
resume=True,
local_checkpoint_dir=self.tmpdir)
new_runner = TrialRunner(resume=True, local_checkpoint_dir=self.tmpdir)
assert len(new_runner.get_trials()) == 3
assert Trial.ERROR in (t.status for t in new_runner.get_trials())
@@ -418,9 +415,7 @@ class TrialRunnerTest3(unittest.TestCase):
del runner
new_runner = TrialRunner(
run_errored_only=True,
resume=True,
local_checkpoint_dir=self.tmpdir)
resume="ERRORED_ONLY", local_checkpoint_dir=self.tmpdir)
assert len(new_runner.get_trials()) == 3
assert Trial.ERROR not in (t.status for t in new_runner.get_trials())
# The below is just a check for standard behavior.
+1 -1
View File
@@ -29,7 +29,7 @@ class TuneServerSuite(unittest.TestCase):
def basicSetup(self):
ray.init(num_cpus=4, num_gpus=1)
port = get_valid_port()
self.runner = TrialRunner(launch_web_server=True, server_port=port)
self.runner = TrialRunner(server_port=port)
runner = self.runner
kwargs = {
"stopping_criterion": {
+2 -4
View File
@@ -165,10 +165,8 @@ class TrialExecutor:
raise TuneError(
("Insufficient cluster resources to launch trial: "
"trial requested {} but the cluster has only {}. "
"Pass `queue_trials=True` in "
"ray.tune.run() or on the command "
"line to queue trials until the cluster scales "
"up or resources become available. {}").format(
"This error should not occur if running on an "
"autoscaling cluster. {}").format(
trial.resources.summary_string(),
self.resource_string(),
trial.get_trainable_cls().resource_help(
+15 -20
View File
@@ -19,7 +19,7 @@ from ray.tune.syncer import get_cloud_syncer
from ray.tune.trial import Checkpoint, Trial
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
from ray.tune.suggest import BasicVariantGenerator
from ray.tune.utils import warn_if_slow, flatten_dict
from ray.tune.utils import warn_if_slow, flatten_dict, env_integer
from ray.tune.web_server import TuneServer
from ray.utils import binary_to_hex, hex_to_binary
from ray.util.debug import log_once
@@ -95,7 +95,6 @@ class TrialRunner:
search_alg (SearchAlgorithm): SearchAlgorithm for generating
Trial objects.
scheduler (TrialScheduler): Defaults to FIFOScheduler.
launch_web_server (bool): Flag for starting TuneServer
local_checkpoint_dir (str): Path where
global checkpoints are stored and restored from.
remote_checkpoint_dir (str): Remote path where
@@ -110,10 +109,6 @@ class TrialRunner:
If fail_fast='raise' provided, Tune will automatically
raise the exception received by the Trainable. fail_fast='raise'
can easily leak resources and should be used with caution.
run_errored_only (bool): Resets and reruns failed trials, assuming
the provided Trainable is the same. Previous trial artifacts
will be left untouched. Only to be used with
`resume` enabled. Raises ValueError otherwise.
verbose (bool): Flag for verbosity. If False, trial results
will not be output.
checkpoint_period (int): Trial runner checkpoint periodicity in
@@ -122,23 +117,21 @@ class TrialRunner:
"""
CKPT_FILE_TMPL = "experiment_state-{}.json"
VALID_RESUME_TYPES = [True, "LOCAL", "REMOTE", "PROMPT"]
VALID_RESUME_TYPES = [True, "LOCAL", "REMOTE", "PROMPT", "ERRORED_ONLY"]
RAISE = "RAISE"
def __init__(self,
search_alg=None,
scheduler=None,
launch_web_server=False,
local_checkpoint_dir=None,
remote_checkpoint_dir=None,
sync_to_cloud=None,
stopper=None,
resume=False,
server_port=TuneServer.DEFAULT_PORT,
server_port=None,
fail_fast=False,
run_errored_only=False,
verbose=True,
checkpoint_period=10,
checkpoint_period=None,
trial_executor=None):
self._search_alg = search_alg or BasicVariantGenerator()
self._scheduler_alg = scheduler or FIFOScheduler()
@@ -168,7 +161,7 @@ class TrialRunner:
self._server = None
self._server_port = server_port
if launch_web_server:
if server_port is not None:
self._server = TuneServer(self, self._server_port)
self._trials = []
@@ -187,8 +180,11 @@ class TrialRunner:
self._resumed = False
if self._validate_resume(resume_type=resume):
errored_only = False
if isinstance(resume, str):
errored_only = resume.upper() == "ERRORED_ONLY"
try:
self.resume(run_errored_only=run_errored_only)
self.resume(run_errored_only=errored_only)
self._resumed = True
except Exception as e:
if self._verbose:
@@ -198,15 +194,12 @@ class TrialRunner:
raise
logger.info("Restarting experiment.")
else:
if run_errored_only:
raise ValueError(
"'run_errored_only' should only be used with 'resume'. "
f"Got: resume={resume}, "
f"run_errored_only={run_errored_only}")
logger.debug("Starting a new experiment.")
self._start_time = time.time()
self._last_checkpoint_time = -float("inf")
if checkpoint_period is None:
checkpoint_period = env_integer("TUNE_GLOBAL_CHECKPOINT_S", 10)
self._checkpoint_period = checkpoint_period
self._session_str = datetime.fromtimestamp(
self._start_time).strftime("%Y-%m-%d_%H-%M-%S")
@@ -228,8 +221,10 @@ class TrialRunner:
"""Checks whether to resume experiment.
Args:
resume_type: One of True, "REMOTE", "LOCAL", "PROMPT".
resume_type: One of True, "REMOTE", "LOCAL",
"PROMPT", "ERRORED_ONLY".
"""
# TODO: Consider supporting ERRORED_ONLY+REMOTE?
if not resume_type:
return False
assert resume_type in self.VALID_RESUME_TYPES, (
@@ -238,7 +233,7 @@ class TrialRunner:
# Not clear if we need this assertion, since we should always have a
# local checkpoint dir.
assert self._local_checkpoint_dir or self._remote_checkpoint_dir
if resume_type in [True, "LOCAL", "PROMPT"]:
if resume_type in [True, "LOCAL", "PROMPT", "ERRORED_ONLY"]:
if not self.checkpoint_exists(self._local_checkpoint_dir):
raise ValueError("Called resume when no checkpoint exists "
"in local directory.")
+101 -131
View File
@@ -10,12 +10,11 @@ from ray.tune.trial import Trial
from ray.tune.trainable import Trainable
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.registry import get_trainable_cls
from ray.tune.syncer import wait_for_sync
from ray.tune.syncer import wait_for_sync, set_sync_periods, SyncConfig
from ray.tune.trial_runner import TrialRunner
from ray.tune.progress_reporter import CLIReporter, JupyterNotebookReporter
from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler,
FIFOScheduler, MedianStoppingRule)
from ray.tune.web_server import TuneServer
logger = logging.getLogger(__name__)
@@ -66,44 +65,47 @@ def _report_progress(runner, reporter, done=False):
reporter.report(trials, done, sched_debug_str, executor_debug_str)
def run(run_or_experiment,
def run(
run_or_experiment,
name=None,
stop=None,
config=None,
resources_per_trial=None,
num_samples=1,
local_dir=None,
upload_dir=None,
trial_name_creator=None,
trial_dirname_creator=None,
loggers=None,
log_to_file=False,
sync_to_cloud=None,
sync_to_driver=None,
checkpoint_freq=0,
checkpoint_at_end=False,
sync_on_checkpoint=True,
search_alg=None,
scheduler=None,
keep_checkpoints_num=None,
checkpoint_score_attr=None,
global_checkpoint_period=10,
checkpoint_freq=0,
checkpoint_at_end=False,
verbose=2,
progress_reporter=None,
loggers=None,
log_to_file=False,
trial_name_creator=None,
trial_dirname_creator=None,
sync_config=None,
export_formats=None,
max_failures=0,
fail_fast=False,
restore=None,
search_alg=None,
scheduler=None,
with_server=False,
server_port=TuneServer.DEFAULT_PORT,
verbose=2,
progress_reporter=None,
server_port=None,
resume=False,
run_errored_only=False,
queue_trials=False,
reuse_actors=False,
trial_executor=None,
raise_on_failed_trial=True,
return_trials=False,
ray_auto_init=True):
# Deprecated args
ray_auto_init=None,
run_errored_only=None,
queue_trials=None,
global_checkpoint_period=None,
with_server=None,
upload_dir=None,
sync_to_cloud=None,
sync_to_driver=None,
sync_on_checkpoint=None,
):
"""Executes training.
Examples:
@@ -130,7 +132,7 @@ def run(run_or_experiment,
# Rerun ONLY failed trials after an experiment is finished.
tune.run(my_trainable, config=space,
local_dir=<path/to/dir>, resume=True, run_errored_only=True)
local_dir=<path/to/dir>, resume="ERRORED_ONLY")
Args:
run_or_experiment (function | class | str | :class:`Experiment`): If
@@ -166,14 +168,30 @@ def run(run_or_experiment,
`num_samples` of times.
local_dir (str): Local dir to save training results to.
Defaults to ``~/ray_results``.
upload_dir (str): Optional URI to sync training results and checkpoints
to (e.g. ``s3://bucket`` or ``gs://bucket``).
trial_name_creator (Callable[[Trial], str]): Optional function
for generating the trial string representation.
trial_dirname_creator (Callable[[Trial], str]): Function
for generating the trial dirname. This function should take
in a Trial object and return a string representing the
name of the directory. The return value cannot be a path.
search_alg (Searcher): Search algorithm for optimization.
scheduler (TrialScheduler): Scheduler for executing
the experiment. Choose among FIFO (default), MedianStopping,
AsyncHyperBand, HyperBand and PopulationBasedTraining. Refer to
ray.tune.schedulers for more options.
keep_checkpoints_num (int): Number of checkpoints to keep. A value of
`None` keeps all checkpoints. Defaults to `None`. If set, need
to provide `checkpoint_score_attr`.
checkpoint_score_attr (str): Specifies by which attribute to rank the
best checkpoint. Default is increasing order. If attribute starts
with `min-` it will rank attribute in decreasing order, i.e.
`min-validation_loss`.
checkpoint_freq (int): How many training iterations between
checkpoints. A value of 0 (default) disables checkpointing.
This has no effect when using the Functional Training API.
checkpoint_at_end (bool): Whether to checkpoint at the end of the
experiment regardless of the checkpoint_freq. Default is False.
This has no effect when using the Functional Training API.
verbose (int): 0, 1, or 2. Verbosity mode. 0 = silent,
1 = only status updates, 2 = status and trial results.
progress_reporter (ProgressReporter): Progress reporter for reporting
intermediate experiment progress. Defaults to CLIReporter if
running in command-line, or JupyterNotebookReporter if running in
a Jupyter notebook.
loggers (list): List of logger creators to be used with
each Trial. If None, defaults to ray.tune.logger.DEFAULT_LOGGERS.
See `ray/tune/logger.py`.
@@ -185,38 +203,14 @@ def run(run_or_experiment,
both streams are written. If this is a Sequence (e.g. a Tuple),
it has to have length 2 and the elements indicate the files to
which stdout and stderr are written, respectively.
sync_to_cloud (func|str): Function for syncing the local_dir to and
from upload_dir. If string, then it must be a string template that
includes `{source}` and `{target}` for the syncer to run. If not
provided, the sync command defaults to standard S3 or gsutil sync
commands. By default local_dir is synced to remote_dir every 300
seconds. To change this, set the TUNE_CLOUD_SYNC_S
environment variable in the driver machine.
sync_to_driver (func|str|bool): Function for syncing trial logdir from
remote node to local. If string, then it must be a string template
that includes `{source}` and `{target}` for the syncer to run.
If True or not provided, it defaults to using rsync. If False,
syncing to driver is disabled.
checkpoint_freq (int): How many training iterations between
checkpoints. A value of 0 (default) disables checkpointing.
This has no effect when using the Functional Training API.
checkpoint_at_end (bool): Whether to checkpoint at the end of the
experiment regardless of the checkpoint_freq. Default is False.
This has no effect when using the Functional Training API.
sync_on_checkpoint (bool): Force sync-down of trial checkpoint to
driver. If set to False, checkpoint syncing from worker to driver
is asynchronous and best-effort. This does not affect persistent
storage syncing. Defaults to True.
keep_checkpoints_num (int): Number of checkpoints to keep. A value of
`None` keeps all checkpoints. Defaults to `None`. If set, need
to provide `checkpoint_score_attr`.
checkpoint_score_attr (str): Specifies by which attribute to rank the
best checkpoint. Default is increasing order. If attribute starts
with `min-` it will rank attribute in decreasing order, i.e.
`min-validation_loss`.
global_checkpoint_period (int): Seconds between global checkpointing.
This does not affect `checkpoint_freq`, which specifies frequency
for individual trials.
trial_name_creator (Callable[[Trial], str]): Optional function
for generating the trial string representation.
trial_dirname_creator (Callable[[Trial], str]): Function
for generating the trial dirname. This function should take
in a Trial object and return a string representing the
name of the directory. The return value cannot be a path.
sync_config (SyncConfig): Configuration object for syncing. See
tune.SyncConfig.
export_formats (list): List of formats that exported at the end of
the experiment. Default is None.
max_failures (int): Try to recover a trial at least this many times.
@@ -230,35 +224,16 @@ def run(run_or_experiment,
is best used with `ray.init(local_mode=True)`).
restore (str): Path to checkpoint. Only makes sense to set if
running 1 trial. Defaults to None.
search_alg (Searcher): Search algorithm for optimization.
scheduler (TrialScheduler): Scheduler for executing
the experiment. Choose among FIFO (default), MedianStopping,
AsyncHyperBand, HyperBand and PopulationBasedTraining. Refer to
ray.tune.schedulers for more options.
with_server (bool): Starts a background Tune server. Needed for
using the Client API.
server_port (int): Port number for launching TuneServer.
verbose (int): 0, 1, or 2. Verbosity mode. 0 = silent,
1 = only status updates, 2 = status and trial results.
progress_reporter (ProgressReporter): Progress reporter for reporting
intermediate experiment progress. Defaults to CLIReporter if
running in command-line, or JupyterNotebookReporter if running in
a Jupyter notebook.
resume (str|bool): One of "LOCAL", "REMOTE", "PROMPT", or bool.
LOCAL/True restores the checkpoint from the local_checkpoint_dir.
REMOTE restores the checkpoint from remote_checkpoint_dir.
PROMPT provides CLI feedback. False forces a new
experiment. If resume is set but checkpoint does not exist,
resume (str|bool): One of "LOCAL", "REMOTE", "PROMPT", "ERRORED_ONLY",
or bool. LOCAL/True restores the checkpoint from the
local_checkpoint_dir, determined
by `name` and `local_dir`. REMOTE restores the checkpoint
from remote_checkpoint_dir. PROMPT provides CLI feedback.
False forces a new experiment. ERRORED_ONLY resets and reruns
ERRORED trials upon resume - previous trial artifacts will
be left untouched. If resume is set but checkpoint does not exist,
ValueError will be thrown.
run_errored_only (bool): Only to be used with `resume` enabled.
Resets and reruns ERRORED trials upon resume.
Experiment location is determined
by `name` and `local_dir`. Previous trial artifacts will
be left untouched.
queue_trials (bool): Whether to queue trials when the cluster does
not currently have enough resources to launch one. This should
be set to True when running on an autoscaling cluster to enable
automatic scale-up.
reuse_actors (bool): Whether to reuse actors between different trials
when possible. This can drastically speed up experiments that start
and stop actors often (e.g., PBT in time-multiplexing mode). This
@@ -266,9 +241,6 @@ def run(run_or_experiment,
trial_executor (TrialExecutor): Manage the execution of trials.
raise_on_failed_trial (bool): Raise TuneError if there exists failed
trial (of ERROR state) when the experiments complete.
ray_auto_init (bool): Automatically starts a local Ray cluster
if using a RayTrialExecutor (which is the default) and
if Ray is not initialized. Defaults to True.
Returns:
@@ -277,12 +249,35 @@ def run(run_or_experiment,
Raises:
TuneError: Any trials failed and `raise_on_failed_trial` is True.
"""
if global_checkpoint_period:
raise ValueError("global_checkpoint_period is deprecated. Set env var "
"'TUNE_GLOBAL_CHECKPOINT_S' instead.")
if queue_trials:
raise ValueError(
"queue_trials is deprecated. "
"Set env var 'TUNE_DISABLE_QUEUE_TRIALS=1' instead to "
"disable queuing behavior.")
if ray_auto_init:
raise ValueError("ray_auto_init is deprecated. "
"Set env var 'TUNE_DISABLE_AUTO_INIT=1' instead or "
"call 'ray.init' before calling 'tune.run'.")
if with_server:
raise ValueError(
"with_server is deprecated. It is now enabled by default "
"if 'server_port' is not None.")
if sync_on_checkpoint or sync_to_cloud or sync_to_driver or upload_dir:
raise ValueError(
"sync_on_checkpoint / sync_to_cloud / sync_to_driver / "
"upload_dir must now be set via `tune.run("
"sync_config=SyncConfig(...)`. See `ray.tune.SyncConfig` for "
"more details.")
config = config or {}
sync_config = sync_config or SyncConfig()
set_sync_periods(sync_config)
trial_executor = trial_executor or RayTrialExecutor(
queue_trials=queue_trials,
reuse_actors=reuse_actors,
ray_auto_init=ray_auto_init)
reuse_actors=reuse_actors)
if isinstance(run_or_experiment, list):
experiments = run_or_experiment
else:
@@ -298,15 +293,15 @@ def run(run_or_experiment,
resources_per_trial=resources_per_trial,
num_samples=num_samples,
local_dir=local_dir,
upload_dir=upload_dir,
sync_to_driver=sync_to_driver,
upload_dir=sync_config.upload_dir,
sync_to_driver=sync_config.sync_to_driver,
trial_name_creator=trial_name_creator,
trial_dirname_creator=trial_dirname_creator,
loggers=loggers,
log_to_file=log_to_file,
checkpoint_freq=checkpoint_freq,
checkpoint_at_end=checkpoint_at_end,
sync_on_checkpoint=sync_on_checkpoint,
sync_on_checkpoint=sync_config.sync_on_checkpoint,
keep_checkpoints_num=keep_checkpoints_num,
checkpoint_score_attr=checkpoint_score_attr,
export_formats=export_formats,
@@ -315,7 +310,7 @@ def run(run_or_experiment,
else:
logger.debug("Ignoring some parameters passed into tune.run.")
if sync_to_cloud:
if sync_config.sync_to_cloud:
for exp in experiments:
assert exp.remote_checkpoint_dir, (
"Need `upload_dir` if `sync_to_cloud` given.")
@@ -344,12 +339,9 @@ def run(run_or_experiment,
scheduler=scheduler or FIFOScheduler(),
local_checkpoint_dir=experiments[0].checkpoint_dir,
remote_checkpoint_dir=experiments[0].remote_checkpoint_dir,
sync_to_cloud=sync_to_cloud,
sync_to_cloud=sync_config.sync_to_cloud,
stopper=experiments[0].stopper,
checkpoint_period=global_checkpoint_period,
resume=resume,
run_errored_only=run_errored_only,
launch_web_server=with_server,
server_port=server_port,
verbose=bool(verbose > 1),
fail_fast=fail_fast,
@@ -413,8 +405,6 @@ def run(run_or_experiment,
logger.error("Trials did not complete: %s", incomplete_trials)
trials = runner.get_trials()
if return_trials:
return trials
return ExperimentAnalysis(
runner.checkpoint_file,
trials=trials,
@@ -423,14 +413,11 @@ def run(run_or_experiment,
def run_experiments(experiments,
search_alg=None,
scheduler=None,
with_server=False,
server_port=TuneServer.DEFAULT_PORT,
server_port=None,
verbose=2,
progress_reporter=None,
resume=False,
queue_trials=False,
reuse_actors=False,
trial_executor=None,
raise_on_failed_trial=True,
@@ -444,15 +431,6 @@ def run_experiments(experiments,
>>> experiment_spec = {"experiment": {"run": my_func}}
>>> run_experiments(experiments=experiment_spec)
>>> run_experiments(
>>> experiments=experiment_spec,
>>> scheduler=MedianStoppingRule(...))
>>> run_experiments(
>>> experiments=experiment_spec,
>>> search_alg=SearchAlgorithm(),
>>> scheduler=MedianStoppingRule(...))
Returns:
List of Trial objects, holding data for each executed trial.
@@ -465,33 +443,25 @@ def run_experiments(experiments,
if concurrent:
return run(
experiments,
search_alg=search_alg,
scheduler=scheduler,
with_server=with_server,
server_port=server_port,
verbose=verbose,
progress_reporter=progress_reporter,
resume=resume,
queue_trials=queue_trials,
reuse_actors=reuse_actors,
trial_executor=trial_executor,
raise_on_failed_trial=raise_on_failed_trial,
return_trials=True)
scheduler=scheduler).trials
else:
trials = []
for exp in experiments:
trials += run(
exp,
search_alg=search_alg,
scheduler=scheduler,
with_server=with_server,
server_port=server_port,
verbose=verbose,
progress_reporter=progress_reporter,
resume=resume,
queue_trials=queue_trials,
reuse_actors=reuse_actors,
trial_executor=trial_executor,
raise_on_failed_trial=raise_on_failed_trial,
return_trials=True)
scheduler=scheduler).trials
return trials
+3 -2
View File
@@ -1,9 +1,10 @@
from ray.tune.utils.util import deep_update, flatten_dict, get_pinned_object, \
merge_dicts, pin_in_object_store, unflattened_lookup, UtilMonitor, \
validate_save_restore, warn_if_slow, diagnose_serialization
validate_save_restore, warn_if_slow, diagnose_serialization, env_integer
__all__ = [
"deep_update", "flatten_dict", "get_pinned_object", "merge_dicts",
"pin_in_object_store", "unflattened_lookup", "UtilMonitor",
"validate_save_restore", "warn_if_slow", "diagnose_serialization"
"validate_save_restore", "warn_if_slow", "diagnose_serialization",
"env_integer"
]
+12
View File
@@ -1,5 +1,6 @@
import copy
import logging
import os
import inspect
import threading
import time
@@ -152,6 +153,17 @@ class Tee(object):
self.stream2.flush(*args, **kwargs)
def env_integer(key, default):
    """Read an integer setting from the environment.

    Args:
        key (str): Name of the environment variable to look up.
        default: Value returned unchanged when ``key`` is not set.

    Returns:
        The value of ``key`` parsed with ``int()`` if the variable is
        set, otherwise ``default``.

    Raises:
        ValueError: If ``key`` is set but its value is not a valid
            integer literal.
    """
    # TODO(rliaw): move into ray.constants
    value = os.environ.get(key)
    if value is None:
        return default
    try:
        # Let int() decide what is valid: unlike str.isdigit(), it accepts
        # signed values such as "-1" and rejects digit-like characters
        # (e.g. superscripts) that cannot actually be converted.
        return int(value)
    except ValueError:
        raise ValueError(f"Found {key} in environment, but value must "
                         f"be an integer. Got: {value}.") from None
def merge_dicts(d1, d2):
"""
Args: