mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 16:31:16 +08:00
[tune] API revamp fix (#10518)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.tune import run_experiments, run
|
||||
from ray.tune.syncer import SyncConfig
|
||||
from ray.tune.experiment import Experiment
|
||||
from ray.tune.analysis import ExperimentAnalysis, Analysis
|
||||
from ray.tune.stopper import Stopper, EarlyStopping
|
||||
@@ -26,5 +27,6 @@ __all__ = [
|
||||
"loguniform", "qloguniform", "ExperimentAnalysis", "Analysis",
|
||||
"CLIReporter", "JupyterNotebookReporter", "ProgressReporter", "report",
|
||||
"get_trial_dir", "get_trial_name", "get_trial_id", "make_checkpoint_dir",
|
||||
"save_checkpoint", "checkpoint_dir", "create_searcher", "create_scheduler"
|
||||
"save_checkpoint", "checkpoint_dir", "SyncConfig", "create_searcher",
|
||||
"create_scheduler"
|
||||
]
|
||||
|
||||
@@ -2,6 +2,15 @@ import getpass
|
||||
import os
|
||||
|
||||
|
||||
def is_ray_cluster():
    """Report whether the Ray bootstrap config file is present.

    The check is a plain filesystem lookup for
    ``~/ray_bootstrap_config.yaml``. That file will always exist when using
    an autoscaling cluster or one started with the ray cluster launcher.

    Returns:
        bool: True if the bootstrap config file exists in the user's home
            directory, False otherwise.
    """
    bootstrap_config = os.path.expanduser("~/ray_bootstrap_config.yaml")
    return os.path.exists(bootstrap_config)
|
||||
|
||||
|
||||
def get_ssh_user():
|
||||
"""Returns ssh username for connecting to cluster workers."""
|
||||
|
||||
|
||||
@@ -102,6 +102,12 @@ if __name__ == "__main__":
|
||||
address = None if args.local else "auto"
|
||||
ray.init(address=address)
|
||||
|
||||
sync_config = tune.SyncConfig(
|
||||
sync_to_driver=False,
|
||||
sync_on_checkpoint=False,
|
||||
upload_dir="s3://ray-tune-test/exps/",
|
||||
)
|
||||
|
||||
config = {
|
||||
"seed": None,
|
||||
"startup_delay": 0.001,
|
||||
@@ -117,12 +123,9 @@ if __name__ == "__main__":
|
||||
config=config,
|
||||
num_samples=4,
|
||||
verbose=1,
|
||||
queue_trials=True,
|
||||
# fault tolerance parameters
|
||||
sync_config=sync_config,
|
||||
max_failures=-1,
|
||||
checkpoint_freq=20,
|
||||
sync_to_driver=False,
|
||||
sync_on_checkpoint=False,
|
||||
upload_dir="s3://ray-tune-test/exps/",
|
||||
checkpoint_score_attr="training_iteration",
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ import ray
|
||||
from ray.exceptions import GetTimeoutError
|
||||
from ray import ray_constants
|
||||
from ray.resource_spec import ResourceSpec
|
||||
from ray.tune.cluster_info import is_ray_cluster
|
||||
from ray.tune.durable_trainable import DurableTrainable
|
||||
from ray.tune.error import AbortTrialExecution, TuneError
|
||||
from ray.tune.logger import NoopLogger
|
||||
@@ -135,10 +136,24 @@ class RayTrialExecutor(TrialExecutor):
|
||||
"""An implementation of TrialExecutor based on Ray."""
|
||||
|
||||
def __init__(self,
|
||||
queue_trials=False,
|
||||
queue_trials=None,
|
||||
reuse_actors=False,
|
||||
ray_auto_init=False,
|
||||
ray_auto_init=None,
|
||||
refresh_period=RESOURCE_REFRESH_PERIOD):
|
||||
if queue_trials is None:
|
||||
if os.environ.get("TUNE_DISABLE_QUEUE_TRIALS") == "1":
|
||||
logger.info("'TUNE_DISABLE_QUEUE_TRIALS=1' detected.")
|
||||
queue_trials = False
|
||||
elif is_ray_cluster():
|
||||
queue_trials = True
|
||||
|
||||
if ray_auto_init is None:
|
||||
if os.environ.get("TUNE_DISABLE_AUTO_INIT") == "1":
|
||||
logger.info("'TUNE_DISABLE_AUTO_INIT=1' detected.")
|
||||
ray_auto_init = False
|
||||
else:
|
||||
ray_auto_init = True
|
||||
|
||||
super(RayTrialExecutor, self).__init__(queue_trials)
|
||||
# Check for if we are launching a trial without resources in kick off
|
||||
# autoscaler.
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
from typing import Any
|
||||
|
||||
import distutils
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
from inspect import isclass
|
||||
from shlex import quote
|
||||
|
||||
from ray import ray_constants
|
||||
from ray import services
|
||||
from ray.util.debug import log_once
|
||||
from ray.tune.utils.util import env_integer
|
||||
from ray.tune.cluster_info import get_ssh_key, get_ssh_user
|
||||
from ray.tune.sync_client import (CommandBasedClient, get_sync_client,
|
||||
get_cloud_sync_client, NOOP)
|
||||
@@ -17,8 +20,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# Syncing period for syncing local checkpoints to cloud.
|
||||
# If the env variable is not set, sync happens every 300 seconds.
|
||||
CLOUD_SYNC_PERIOD = ray_constants.env_integer(
|
||||
key="TUNE_CLOUD_SYNC_S", default=300)
|
||||
CLOUD_SYNC_PERIOD = 300
|
||||
|
||||
# Syncing period for syncing worker logs to driver.
|
||||
NODE_SYNC_PERIOD = 300
|
||||
@@ -32,6 +34,18 @@ def wait_for_sync():
|
||||
syncer.wait()
|
||||
|
||||
|
||||
def set_sync_periods(sync_config):
    """Sets the module-level sync periods from a sync config.

    Args:
        sync_config: Object exposing ``node_sync_period`` and
            ``cloud_sync_period`` attributes (e.g. ``tune.SyncConfig``).
            Both values are coerced with ``int()``.

    The deprecated ``TUNE_CLOUD_SYNC_S`` environment variable still triggers
    a deprecation warning. It is honored only as a fallback, i.e. when
    ``cloud_sync_period`` is at the default of 300 seconds; any other
    configured period takes precedence over the environment.
    """
    global CLOUD_SYNC_PERIOD
    global NODE_SYNC_PERIOD
    # Same default the environment lookup has always used.
    default_cloud_period = 300

    NODE_SYNC_PERIOD = int(sync_config.node_sync_period)
    CLOUD_SYNC_PERIOD = int(sync_config.cloud_sync_period)

    if os.environ.get("TUNE_CLOUD_SYNC_S"):
        logger.warning("'TUNE_CLOUD_SYNC_S' is deprecated. Set "
                       "`cloud_sync_period` via tune.SyncConfig instead.")
        # Previously the env value was assigned first and then
        # unconditionally overwritten by the config value, so the variable
        # never took effect. Apply it only when the config period is at
        # the default.
        if CLOUD_SYNC_PERIOD == default_cloud_period:
            CLOUD_SYNC_PERIOD = env_integer(
                key="TUNE_CLOUD_SYNC_S", default=default_cloud_period)
|
||||
|
||||
|
||||
def log_sync_template(options=""):
|
||||
"""Template enabling syncs between driver and worker when possible.
|
||||
Requires ray cluster to be started with the autoscaler. Also requires
|
||||
@@ -63,6 +77,42 @@ def log_sync_template(options=""):
|
||||
return template.format(options=options, rsh=quote(rsh))
|
||||
|
||||
|
||||
@dataclass
class SyncConfig:
    """Configuration object for syncing.

    Args:
        upload_dir (str): Optional URI to sync training results and checkpoints
            to (e.g. ``s3://bucket`` or ``gs://bucket``).
        sync_to_cloud (func|str): Function for syncing the local_dir to and
            from upload_dir. If string, then it must be a string template that
            includes `{source}` and `{target}` for the syncer to run. If not
            provided, the sync command defaults to standard S3 or gsutil sync
            commands. By default local_dir is synced to remote_dir every 300
            seconds. To change this, set ``cloud_sync_period`` (the
            ``TUNE_CLOUD_SYNC_S`` environment variable is deprecated).
        sync_to_driver (func|str|bool): Function for syncing trial logdir from
            remote node to local. If string, then it must be a string template
            that includes `{source}` and `{target}` for the syncer to run.
            If True or not provided, it defaults to using rsync. If False,
            syncing to driver is disabled.
        sync_on_checkpoint (bool): Force sync-down of trial checkpoint to
            driver. If set to False, checkpoint syncing from worker to driver
            is asynchronous and best-effort. This does not affect persistent
            storage syncing. Defaults to True.
        node_sync_period (int): Syncing period, in seconds, for syncing
            worker logs to driver. Defaults to 300.
        cloud_sync_period (int): Syncing period, in seconds, for syncing
            local checkpoints to cloud. Defaults to 300.
    """
    # None means no cloud upload target is configured.
    upload_dir: str = None
    # Callable or command template; None selects the default sync command.
    sync_to_cloud: Any = None
    # Callable, command template or bool; None behaves like "not provided".
    sync_to_driver: Any = None
    sync_on_checkpoint: bool = True
    node_sync_period: int = 300
    cloud_sync_period: int = 300
|
||||
|
||||
|
||||
class Syncer:
|
||||
def __init__(self, local_dir, remote_dir, sync_client=NOOP):
|
||||
"""Syncs between two directories with the sync_function.
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import argparse
|
||||
|
||||
from ray.tune import run
|
||||
@@ -44,9 +45,9 @@ if __name__ == "__main__":
|
||||
algo = ConcurrencyLimiter(algo, max_concurrent=1)
|
||||
from ray.tune import register_trainable
|
||||
register_trainable("trainable", MyTrainableClass)
|
||||
os.environ["TUNE_GLOBAL_CHECKPOINT_S"] = "0"
|
||||
run("trainable",
|
||||
search_alg=algo,
|
||||
global_checkpoint_period=0,
|
||||
resume=args.resume,
|
||||
verbose=0,
|
||||
num_samples=20,
|
||||
|
||||
@@ -245,17 +245,19 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
register_trainable("B", B)
|
||||
|
||||
def f(cpus, gpus, queue_trials):
|
||||
return run_experiments(
|
||||
{
|
||||
"foo": {
|
||||
"run": "B",
|
||||
"config": {
|
||||
"cpu": cpus,
|
||||
"gpu": gpus,
|
||||
},
|
||||
}
|
||||
},
|
||||
queue_trials=queue_trials)[0]
|
||||
if not queue_trials:
|
||||
os.environ["TUNE_DISABLE_QUEUE_TRIALS"] = "1"
|
||||
else:
|
||||
os.environ.pop("TUNE_DISABLE_QUEUE_TRIALS", None)
|
||||
return run_experiments({
|
||||
"foo": {
|
||||
"run": "B",
|
||||
"config": {
|
||||
"cpu": cpus,
|
||||
"gpu": gpus,
|
||||
},
|
||||
}
|
||||
})[0]
|
||||
|
||||
# Should all succeed
|
||||
self.assertEqual(f(0, 0, False).status, Trial.TERMINATED)
|
||||
@@ -639,8 +641,7 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
loggers=None)
|
||||
trials = tune.run(test, raise_on_failed_trial=False, **config).trials
|
||||
self.assertEqual(Counter(t.status for t in trials)["ERROR"], 5)
|
||||
new_trials = tune.run(
|
||||
test, resume=True, run_errored_only=True, **config).trials
|
||||
new_trials = tune.run(test, resume="ERRORED_ONLY", **config).trials
|
||||
self.assertEqual(Counter(t.status for t in new_trials)["ERROR"], 0)
|
||||
self.assertTrue(
|
||||
all(t.last_result.get("hello") == 123 for t in new_trials))
|
||||
|
||||
@@ -642,10 +642,13 @@ def test_cluster_interrupt(start_connected_cluster, tmpdir):
|
||||
for line in inspect.getsource(_Mock).split("\n"))
|
||||
|
||||
script = """
|
||||
import os
|
||||
import time
|
||||
import ray
|
||||
from ray import tune
|
||||
|
||||
os.environ["TUNE_GLOBAL_CHECKPOINT_S"] = "0"
|
||||
|
||||
ray.init(address="{address}")
|
||||
|
||||
{fail_class_code}
|
||||
@@ -656,7 +659,6 @@ tune.run(
|
||||
stop=dict(training_iteration=5),
|
||||
local_dir="{checkpoint_dir}",
|
||||
checkpoint_freq=1,
|
||||
global_checkpoint_period=0,
|
||||
max_failures=1,
|
||||
raise_on_failed_trial=False)
|
||||
""".format(
|
||||
|
||||
@@ -147,7 +147,6 @@ class ExperimentAnalysisSuite(unittest.TestCase):
|
||||
MyTrainableClass,
|
||||
name="test_example",
|
||||
local_dir=self.test_dir,
|
||||
return_trials=False,
|
||||
stop={"training_iteration": 1},
|
||||
num_samples=1,
|
||||
config={
|
||||
|
||||
@@ -135,7 +135,6 @@ class AnalysisSuite(unittest.TestCase):
|
||||
run(MyTrainableClass,
|
||||
name=test_name,
|
||||
local_dir=self.test_dir,
|
||||
return_trials=False,
|
||||
stop={"training_iteration": 1},
|
||||
num_samples=self.num_samples,
|
||||
config={
|
||||
|
||||
@@ -16,7 +16,7 @@ from ray.cluster_utils import Cluster
|
||||
class RayTrialExecutorTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.trial_executor = RayTrialExecutor(queue_trials=False)
|
||||
ray.init()
|
||||
ray.init(ignore_reinit_error=True)
|
||||
_register_all() # Needed for flaky tests
|
||||
|
||||
def tearDown(self):
|
||||
@@ -182,8 +182,6 @@ class RayTrialExecutorTest(unittest.TestCase):
|
||||
|
||||
class RayExecutorQueueTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.trial_executor = RayTrialExecutor(
|
||||
queue_trials=True, refresh_period=0)
|
||||
self.cluster = Cluster(
|
||||
initialize_head=True,
|
||||
connect=True,
|
||||
@@ -193,6 +191,8 @@ class RayExecutorQueueTest(unittest.TestCase):
|
||||
"num_heartbeats_timeout": 10
|
||||
}
|
||||
})
|
||||
self.trial_executor = RayTrialExecutor(
|
||||
queue_trials=True, refresh_period=0)
|
||||
# Pytest doesn't play nicely with imports
|
||||
_register_all()
|
||||
|
||||
@@ -247,8 +247,8 @@ class RayExecutorQueueTest(unittest.TestCase):
|
||||
|
||||
class LocalModeExecutorTest(RayTrialExecutorTest):
|
||||
def setUp(self):
|
||||
self.trial_executor = RayTrialExecutor(queue_trials=False)
|
||||
ray.init(local_mode=True)
|
||||
self.trial_executor = RayTrialExecutor(queue_trials=False)
|
||||
|
||||
def tearDown(self):
|
||||
ray.shutdown()
|
||||
|
||||
@@ -31,12 +31,11 @@ class TestSyncFunctionality(unittest.TestCase):
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
**{
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"sync_to_cloud": "echo {source} {target}"
|
||||
}).trials
|
||||
stop={
|
||||
"training_iteration": 1
|
||||
},
|
||||
sync_config=tune.SyncConfig(
|
||||
**{"sync_to_cloud": "echo {source} {target}"})).trials
|
||||
|
||||
@patch("ray.tune.sync_client.S3_PREFIX", "test")
|
||||
def testCloudProperString(self):
|
||||
@@ -45,26 +44,26 @@ class TestSyncFunctionality(unittest.TestCase):
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
**{
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
stop={
|
||||
"training_iteration": 1
|
||||
},
|
||||
sync_config=tune.SyncConfig(**{
|
||||
"upload_dir": "test",
|
||||
"sync_to_cloud": "ls {target}"
|
||||
}).trials
|
||||
})).trials
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
**{
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
stop={
|
||||
"training_iteration": 1
|
||||
},
|
||||
sync_config=tune.SyncConfig(**{
|
||||
"upload_dir": "test",
|
||||
"sync_to_cloud": "ls {source}"
|
||||
}).trials
|
||||
})).trials
|
||||
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
logfile = os.path.join(tmpdir, "test.log")
|
||||
@@ -73,13 +72,14 @@ class TestSyncFunctionality(unittest.TestCase):
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
**{
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"upload_dir": "test",
|
||||
"sync_to_cloud": "echo {source} {target} > " + logfile
|
||||
}).trials
|
||||
stop={
|
||||
"training_iteration": 1
|
||||
},
|
||||
sync_config=tune.SyncConfig(
|
||||
**{
|
||||
"upload_dir": "test",
|
||||
"sync_to_cloud": "echo {source} {target} > " + logfile
|
||||
})).trials
|
||||
with open(logfile) as f:
|
||||
lines = f.read()
|
||||
self.assertTrue("test" in lines)
|
||||
@@ -89,42 +89,41 @@ class TestSyncFunctionality(unittest.TestCase):
|
||||
"""Tests that invalid commands throw.."""
|
||||
with self.assertRaises(TuneError):
|
||||
# This raises TuneError because logger is init in safe zone.
|
||||
sync_config = tune.SyncConfig(sync_to_driver="ls {target}")
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
**{
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"sync_to_driver": "ls {target}"
|
||||
}).trials
|
||||
stop={
|
||||
"training_iteration": 1
|
||||
},
|
||||
sync_config=sync_config,
|
||||
).trials
|
||||
|
||||
with self.assertRaises(TuneError):
|
||||
# This raises TuneError because logger is init in safe zone.
|
||||
sync_config = tune.SyncConfig(sync_to_driver="ls {source}")
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
**{
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"sync_to_driver": "ls {source}"
|
||||
sync_config=sync_config,
|
||||
stop={
|
||||
"training_iteration": 1
|
||||
}).trials
|
||||
|
||||
with patch.object(CommandBasedClient, "_execute") as mock_fn:
|
||||
with patch("ray.services.get_node_ip_address") as mock_sync:
|
||||
sync_config = tune.SyncConfig(
|
||||
sync_to_driver="echo {source} {target}")
|
||||
mock_sync.return_value = "0.0.0.0"
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
**{
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"sync_to_driver": "echo {source} {target}"
|
||||
sync_config=sync_config,
|
||||
stop={
|
||||
"training_iteration": 1
|
||||
}).trials
|
||||
self.assertGreater(mock_fn.call_count, 0)
|
||||
|
||||
@@ -137,6 +136,8 @@ class TestSyncFunctionality(unittest.TestCase):
|
||||
for filename in glob.glob(os.path.join(local, "*.json")):
|
||||
shutil.copy(filename, remote)
|
||||
|
||||
sync_config = tune.SyncConfig(
|
||||
upload_dir=tmpdir2, sync_to_cloud=sync_func)
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
@@ -145,8 +146,7 @@ class TestSyncFunctionality(unittest.TestCase):
|
||||
stop={
|
||||
"training_iteration": 1
|
||||
},
|
||||
upload_dir=tmpdir2,
|
||||
sync_to_cloud=sync_func).trials
|
||||
sync_config=sync_config).trials
|
||||
test_file_path = glob.glob(os.path.join(tmpdir2, "foo", "*.json"))
|
||||
self.assertTrue(test_file_path)
|
||||
shutil.rmtree(tmpdir)
|
||||
@@ -167,18 +167,21 @@ class TestSyncFunctionality(unittest.TestCase):
|
||||
def counter(local, remote):
|
||||
mock()
|
||||
|
||||
tune.syncer.CLOUD_SYNC_PERIOD = 1
|
||||
sync_config = tune.SyncConfig(
|
||||
upload_dir="test", sync_to_cloud=counter, cloud_sync_period=1)
|
||||
# This was originally set to 0.5
|
||||
os.environ["TUNE_GLOBAL_CHECKPOINT_S"] = "0"
|
||||
self.addCleanup(
|
||||
lambda: os.environ.pop("TUNE_GLOBAL_CHECKPOINT_S", None))
|
||||
[trial] = tune.run(
|
||||
trainable,
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
local_dir=tmpdir,
|
||||
upload_dir="test",
|
||||
sync_to_cloud=counter,
|
||||
stop={
|
||||
"training_iteration": 10
|
||||
},
|
||||
global_checkpoint_period=0.5,
|
||||
sync_config=sync_config,
|
||||
).trials
|
||||
|
||||
self.assertEqual(mock.call_count, 12)
|
||||
@@ -192,6 +195,9 @@ class TestSyncFunctionality(unittest.TestCase):
|
||||
print("writing to", f.name)
|
||||
f.write(source)
|
||||
|
||||
sync_config = tune.SyncConfig(
|
||||
sync_to_driver=sync_func_driver, node_sync_period=5)
|
||||
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
@@ -199,12 +205,13 @@ class TestSyncFunctionality(unittest.TestCase):
|
||||
stop={
|
||||
"training_iteration": 1
|
||||
},
|
||||
sync_to_driver=sync_func_driver).trials
|
||||
sync_config=sync_config).trials
|
||||
test_file_path = os.path.join(trial.logdir, "test.log2")
|
||||
self.assertFalse(os.path.exists(test_file_path))
|
||||
|
||||
with patch("ray.services.get_node_ip_address") as mock_sync:
|
||||
mock_sync.return_value = "0.0.0.0"
|
||||
sync_config = tune.SyncConfig(sync_to_driver=sync_func_driver)
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
@@ -212,7 +219,7 @@ class TestSyncFunctionality(unittest.TestCase):
|
||||
stop={
|
||||
"training_iteration": 1
|
||||
},
|
||||
sync_to_driver=sync_func_driver).trials
|
||||
sync_config=sync_config).trials
|
||||
test_file_path = os.path.join(trial.logdir, "test.log2")
|
||||
self.assertTrue(os.path.exists(test_file_path))
|
||||
os.remove(test_file_path)
|
||||
@@ -223,17 +230,17 @@ class TestSyncFunctionality(unittest.TestCase):
|
||||
def sync_func(source, target):
|
||||
pass
|
||||
|
||||
sync_config = tune.SyncConfig(sync_to_driver=sync_func)
|
||||
|
||||
with patch.object(CommandBasedClient, "_execute") as mock_sync:
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
**{
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"sync_to_driver": sync_func
|
||||
}).trials
|
||||
stop={
|
||||
"training_iteration": 1
|
||||
},
|
||||
sync_config=sync_config).trials
|
||||
self.assertEqual(mock_sync.call_count, 0)
|
||||
|
||||
|
||||
|
||||
@@ -385,10 +385,7 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||
assert trials[0].status == Trial.ERROR
|
||||
del runner
|
||||
|
||||
new_runner = TrialRunner(
|
||||
run_errored_only=False,
|
||||
resume=True,
|
||||
local_checkpoint_dir=self.tmpdir)
|
||||
new_runner = TrialRunner(resume=True, local_checkpoint_dir=self.tmpdir)
|
||||
assert len(new_runner.get_trials()) == 3
|
||||
assert Trial.ERROR in (t.status for t in new_runner.get_trials())
|
||||
|
||||
@@ -418,9 +415,7 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||
del runner
|
||||
|
||||
new_runner = TrialRunner(
|
||||
run_errored_only=True,
|
||||
resume=True,
|
||||
local_checkpoint_dir=self.tmpdir)
|
||||
resume="ERRORED_ONLY", local_checkpoint_dir=self.tmpdir)
|
||||
assert len(new_runner.get_trials()) == 3
|
||||
assert Trial.ERROR not in (t.status for t in new_runner.get_trials())
|
||||
# The below is just a check for standard behavior.
|
||||
|
||||
@@ -29,7 +29,7 @@ class TuneServerSuite(unittest.TestCase):
|
||||
def basicSetup(self):
|
||||
ray.init(num_cpus=4, num_gpus=1)
|
||||
port = get_valid_port()
|
||||
self.runner = TrialRunner(launch_web_server=True, server_port=port)
|
||||
self.runner = TrialRunner(server_port=port)
|
||||
runner = self.runner
|
||||
kwargs = {
|
||||
"stopping_criterion": {
|
||||
|
||||
@@ -165,10 +165,8 @@ class TrialExecutor:
|
||||
raise TuneError(
|
||||
("Insufficient cluster resources to launch trial: "
|
||||
"trial requested {} but the cluster has only {}. "
|
||||
"Pass `queue_trials=True` in "
|
||||
"ray.tune.run() or on the command "
|
||||
"line to queue trials until the cluster scales "
|
||||
"up or resources become available. {}").format(
|
||||
"This error should not occur if running on an "
|
||||
"autoscaling cluster. {}").format(
|
||||
trial.resources.summary_string(),
|
||||
self.resource_string(),
|
||||
trial.get_trainable_cls().resource_help(
|
||||
|
||||
@@ -19,7 +19,7 @@ from ray.tune.syncer import get_cloud_syncer
|
||||
from ray.tune.trial import Checkpoint, Trial
|
||||
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.suggest import BasicVariantGenerator
|
||||
from ray.tune.utils import warn_if_slow, flatten_dict
|
||||
from ray.tune.utils import warn_if_slow, flatten_dict, env_integer
|
||||
from ray.tune.web_server import TuneServer
|
||||
from ray.utils import binary_to_hex, hex_to_binary
|
||||
from ray.util.debug import log_once
|
||||
@@ -95,7 +95,6 @@ class TrialRunner:
|
||||
search_alg (SearchAlgorithm): SearchAlgorithm for generating
|
||||
Trial objects.
|
||||
scheduler (TrialScheduler): Defaults to FIFOScheduler.
|
||||
launch_web_server (bool): Flag for starting TuneServer
|
||||
local_checkpoint_dir (str): Path where
|
||||
global checkpoints are stored and restored from.
|
||||
remote_checkpoint_dir (str): Remote path where
|
||||
@@ -110,10 +109,6 @@ class TrialRunner:
|
||||
If fail_fast='raise' provided, Tune will automatically
|
||||
raise the exception received by the Trainable. fail_fast='raise'
|
||||
can easily leak resources and should be used with caution.
|
||||
run_errored_only (bool): Resets and reruns failed trials, assuming
|
||||
the provided Trainable is the same. Previous trial artifacts
|
||||
will be left untouched. Only to be used with
|
||||
`resume` enabled. Raises ValueError otherwise.
|
||||
verbose (bool): Flag for verbosity. If False, trial results
|
||||
will not be output.
|
||||
checkpoint_period (int): Trial runner checkpoint periodicity in
|
||||
@@ -122,23 +117,21 @@ class TrialRunner:
|
||||
"""
|
||||
|
||||
CKPT_FILE_TMPL = "experiment_state-{}.json"
|
||||
VALID_RESUME_TYPES = [True, "LOCAL", "REMOTE", "PROMPT"]
|
||||
VALID_RESUME_TYPES = [True, "LOCAL", "REMOTE", "PROMPT", "ERRORED_ONLY"]
|
||||
RAISE = "RAISE"
|
||||
|
||||
def __init__(self,
|
||||
search_alg=None,
|
||||
scheduler=None,
|
||||
launch_web_server=False,
|
||||
local_checkpoint_dir=None,
|
||||
remote_checkpoint_dir=None,
|
||||
sync_to_cloud=None,
|
||||
stopper=None,
|
||||
resume=False,
|
||||
server_port=TuneServer.DEFAULT_PORT,
|
||||
server_port=None,
|
||||
fail_fast=False,
|
||||
run_errored_only=False,
|
||||
verbose=True,
|
||||
checkpoint_period=10,
|
||||
checkpoint_period=None,
|
||||
trial_executor=None):
|
||||
self._search_alg = search_alg or BasicVariantGenerator()
|
||||
self._scheduler_alg = scheduler or FIFOScheduler()
|
||||
@@ -168,7 +161,7 @@ class TrialRunner:
|
||||
|
||||
self._server = None
|
||||
self._server_port = server_port
|
||||
if launch_web_server:
|
||||
if server_port is not None:
|
||||
self._server = TuneServer(self, self._server_port)
|
||||
|
||||
self._trials = []
|
||||
@@ -187,8 +180,11 @@ class TrialRunner:
|
||||
self._resumed = False
|
||||
|
||||
if self._validate_resume(resume_type=resume):
|
||||
errored_only = False
|
||||
if isinstance(resume, str):
|
||||
errored_only = resume.upper() == "ERRORED_ONLY"
|
||||
try:
|
||||
self.resume(run_errored_only=run_errored_only)
|
||||
self.resume(run_errored_only=errored_only)
|
||||
self._resumed = True
|
||||
except Exception as e:
|
||||
if self._verbose:
|
||||
@@ -198,15 +194,12 @@ class TrialRunner:
|
||||
raise
|
||||
logger.info("Restarting experiment.")
|
||||
else:
|
||||
if run_errored_only:
|
||||
raise ValueError(
|
||||
"'run_errored_only' should only be used with 'resume'. "
|
||||
f"Got: resume={resume}, "
|
||||
f"run_errored_only={run_errored_only}")
|
||||
logger.debug("Starting a new experiment.")
|
||||
|
||||
self._start_time = time.time()
|
||||
self._last_checkpoint_time = -float("inf")
|
||||
if checkpoint_period is None:
|
||||
checkpoint_period = env_integer("TUNE_GLOBAL_CHECKPOINT_S", 10)
|
||||
self._checkpoint_period = checkpoint_period
|
||||
self._session_str = datetime.fromtimestamp(
|
||||
self._start_time).strftime("%Y-%m-%d_%H-%M-%S")
|
||||
@@ -228,8 +221,10 @@ class TrialRunner:
|
||||
"""Checks whether to resume experiment.
|
||||
|
||||
Args:
|
||||
resume_type: One of True, "REMOTE", "LOCAL", "PROMPT".
|
||||
resume_type: One of True, "REMOTE", "LOCAL",
|
||||
"PROMPT", "ERRORED_ONLY".
|
||||
"""
|
||||
# TODO: Consider supporting ERRORED_ONLY+REMOTE?
|
||||
if not resume_type:
|
||||
return False
|
||||
assert resume_type in self.VALID_RESUME_TYPES, (
|
||||
@@ -238,7 +233,7 @@ class TrialRunner:
|
||||
# Not clear if we need this assertion, since we should always have a
|
||||
# local checkpoint dir.
|
||||
assert self._local_checkpoint_dir or self._remote_checkpoint_dir
|
||||
if resume_type in [True, "LOCAL", "PROMPT"]:
|
||||
if resume_type in [True, "LOCAL", "PROMPT", "ERRORED_ONLY"]:
|
||||
if not self.checkpoint_exists(self._local_checkpoint_dir):
|
||||
raise ValueError("Called resume when no checkpoint exists "
|
||||
"in local directory.")
|
||||
|
||||
+101
-131
@@ -10,12 +10,11 @@ from ray.tune.trial import Trial
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.ray_trial_executor import RayTrialExecutor
|
||||
from ray.tune.registry import get_trainable_cls
|
||||
from ray.tune.syncer import wait_for_sync
|
||||
from ray.tune.syncer import wait_for_sync, set_sync_periods, SyncConfig
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.progress_reporter import CLIReporter, JupyterNotebookReporter
|
||||
from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler,
|
||||
FIFOScheduler, MedianStoppingRule)
|
||||
from ray.tune.web_server import TuneServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -66,44 +65,47 @@ def _report_progress(runner, reporter, done=False):
|
||||
reporter.report(trials, done, sched_debug_str, executor_debug_str)
|
||||
|
||||
|
||||
def run(run_or_experiment,
|
||||
def run(
|
||||
run_or_experiment,
|
||||
name=None,
|
||||
stop=None,
|
||||
config=None,
|
||||
resources_per_trial=None,
|
||||
num_samples=1,
|
||||
local_dir=None,
|
||||
upload_dir=None,
|
||||
trial_name_creator=None,
|
||||
trial_dirname_creator=None,
|
||||
loggers=None,
|
||||
log_to_file=False,
|
||||
sync_to_cloud=None,
|
||||
sync_to_driver=None,
|
||||
checkpoint_freq=0,
|
||||
checkpoint_at_end=False,
|
||||
sync_on_checkpoint=True,
|
||||
search_alg=None,
|
||||
scheduler=None,
|
||||
keep_checkpoints_num=None,
|
||||
checkpoint_score_attr=None,
|
||||
global_checkpoint_period=10,
|
||||
checkpoint_freq=0,
|
||||
checkpoint_at_end=False,
|
||||
verbose=2,
|
||||
progress_reporter=None,
|
||||
loggers=None,
|
||||
log_to_file=False,
|
||||
trial_name_creator=None,
|
||||
trial_dirname_creator=None,
|
||||
sync_config=None,
|
||||
export_formats=None,
|
||||
max_failures=0,
|
||||
fail_fast=False,
|
||||
restore=None,
|
||||
search_alg=None,
|
||||
scheduler=None,
|
||||
with_server=False,
|
||||
server_port=TuneServer.DEFAULT_PORT,
|
||||
verbose=2,
|
||||
progress_reporter=None,
|
||||
server_port=None,
|
||||
resume=False,
|
||||
run_errored_only=False,
|
||||
queue_trials=False,
|
||||
reuse_actors=False,
|
||||
trial_executor=None,
|
||||
raise_on_failed_trial=True,
|
||||
return_trials=False,
|
||||
ray_auto_init=True):
|
||||
# Deprecated args
|
||||
ray_auto_init=None,
|
||||
run_errored_only=None,
|
||||
queue_trials=None,
|
||||
global_checkpoint_period=None,
|
||||
with_server=None,
|
||||
upload_dir=None,
|
||||
sync_to_cloud=None,
|
||||
sync_to_driver=None,
|
||||
sync_on_checkpoint=None,
|
||||
):
|
||||
"""Executes training.
|
||||
|
||||
Examples:
|
||||
@@ -130,7 +132,7 @@ def run(run_or_experiment,
|
||||
|
||||
# Rerun ONLY failed trials after an experiment is finished.
|
||||
tune.run(my_trainable, config=space,
|
||||
local_dir=<path/to/dir>, resume=True, run_errored_only=True)
|
||||
local_dir=<path/to/dir>, resume="ERRORED_ONLY")
|
||||
|
||||
Args:
|
||||
run_or_experiment (function | class | str | :class:`Experiment`): If
|
||||
@@ -166,14 +168,30 @@ def run(run_or_experiment,
|
||||
`num_samples` of times.
|
||||
local_dir (str): Local dir to save training results to.
|
||||
Defaults to ``~/ray_results``.
|
||||
upload_dir (str): Optional URI to sync training results and checkpoints
|
||||
to (e.g. ``s3://bucket`` or ``gs://bucket``).
|
||||
trial_name_creator (Callable[[Trial], str]): Optional function
|
||||
for generating the trial string representation.
|
||||
trial_dirname_creator (Callable[[Trial], str]): Function
|
||||
for generating the trial dirname. This function should take
|
||||
in a Trial object and return a string representing the
|
||||
name of the directory. The return value cannot be a path.
|
||||
search_alg (Searcher): Search algorithm for optimization.
|
||||
scheduler (TrialScheduler): Scheduler for executing
|
||||
the experiment. Choose among FIFO (default), MedianStopping,
|
||||
AsyncHyperBand, HyperBand and PopulationBasedTraining. Refer to
|
||||
ray.tune.schedulers for more options.
|
||||
keep_checkpoints_num (int): Number of checkpoints to keep. A value of
|
||||
`None` keeps all checkpoints. Defaults to `None`. If set, need
|
||||
to provide `checkpoint_score_attr`.
|
||||
checkpoint_score_attr (str): Specifies by which attribute to rank the
|
||||
best checkpoint. Default is increasing order. If attribute starts
|
||||
with `min-` it will rank attribute in decreasing order, i.e.
|
||||
`min-validation_loss`.
|
||||
checkpoint_freq (int): How many training iterations between
|
||||
checkpoints. A value of 0 (default) disables checkpointing.
|
||||
This has no effect when using the Functional Training API.
|
||||
checkpoint_at_end (bool): Whether to checkpoint at the end of the
|
||||
experiment regardless of the checkpoint_freq. Default is False.
|
||||
This has no effect when using the Functional Training API.
|
||||
verbose (int): 0, 1, or 2. Verbosity mode. 0 = silent,
|
||||
1 = only status updates, 2 = status and trial results.
|
||||
progress_reporter (ProgressReporter): Progress reporter for reporting
|
||||
intermediate experiment progress. Defaults to CLIReporter if
|
||||
running in command-line, or JupyterNotebookReporter if running in
|
||||
a Jupyter notebook.
|
||||
loggers (list): List of logger creators to be used with
|
||||
each Trial. If None, defaults to ray.tune.logger.DEFAULT_LOGGERS.
|
||||
See `ray/tune/logger.py`.
|
||||
@@ -185,38 +203,14 @@ def run(run_or_experiment,
|
||||
both streams are written. If this is a Sequence (e.g. a Tuple),
|
||||
it has to have length 2 and the elements indicate the files to
|
||||
which stdout and stderr are written, respectively.
|
||||
sync_to_cloud (func|str): Function for syncing the local_dir to and
|
||||
from upload_dir. If string, then it must be a string template that
|
||||
includes `{source}` and `{target}` for the syncer to run. If not
|
||||
provided, the sync command defaults to standard S3 or gsutil sync
|
||||
commands. By default local_dir is synced to remote_dir every 300
|
||||
seconds. To change this, set the TUNE_CLOUD_SYNC_S
|
||||
environment variable in the driver machine.
|
||||
sync_to_driver (func|str|bool): Function for syncing trial logdir from
|
||||
remote node to local. If string, then it must be a string template
|
||||
that includes `{source}` and `{target}` for the syncer to run.
|
||||
If True or not provided, it defaults to using rsync. If False,
|
||||
syncing to driver is disabled.
|
||||
checkpoint_freq (int): How many training iterations between
|
||||
checkpoints. A value of 0 (default) disables checkpointing.
|
||||
This has no effect when using the Functional Training API.
|
||||
checkpoint_at_end (bool): Whether to checkpoint at the end of the
|
||||
experiment regardless of the checkpoint_freq. Default is False.
|
||||
This has no effect when using the Functional Training API.
|
||||
sync_on_checkpoint (bool): Force sync-down of trial checkpoint to
|
||||
driver. If set to False, checkpoint syncing from worker to driver
|
||||
is asynchronous and best-effort. This does not affect persistent
|
||||
storage syncing. Defaults to True.
|
||||
keep_checkpoints_num (int): Number of checkpoints to keep. A value of
|
||||
`None` keeps all checkpoints. Defaults to `None`. If set, need
|
||||
to provide `checkpoint_score_attr`.
|
||||
checkpoint_score_attr (str): Specifies by which attribute to rank the
|
||||
best checkpoint. Default is increasing order. If attribute starts
|
||||
with `min-` it will rank attribute in decreasing order, e.g.
|
||||
`min-validation_loss`.
|
||||
global_checkpoint_period (int): Seconds between global checkpointing.
|
||||
This does not affect `checkpoint_freq`, which specifies frequency
|
||||
for individual trials.
|
||||
trial_name_creator (Callable[[Trial], str]): Optional function
|
||||
for generating the trial string representation.
|
||||
trial_dirname_creator (Callable[[Trial], str]): Function
|
||||
for generating the trial dirname. This function should take
|
||||
in a Trial object and return a string representing the
|
||||
name of the directory. The return value cannot be a path.
|
||||
sync_config (SyncConfig): Configuration object for syncing. See
|
||||
tune.SyncConfig.
|
||||
export_formats (list): List of formats that are exported at the end of
|
||||
the experiment. Default is None.
|
||||
max_failures (int): Try to recover a trial at least this many times.
|
||||
@@ -230,35 +224,16 @@ def run(run_or_experiment,
|
||||
is best used with `ray.init(local_mode=True)`).
|
||||
restore (str): Path to checkpoint. Only makes sense to set if
|
||||
running 1 trial. Defaults to None.
|
||||
search_alg (Searcher): Search algorithm for optimization.
|
||||
scheduler (TrialScheduler): Scheduler for executing
|
||||
the experiment. Choose among FIFO (default), MedianStopping,
|
||||
AsyncHyperBand, HyperBand and PopulationBasedTraining. Refer to
|
||||
ray.tune.schedulers for more options.
|
||||
with_server (bool): Starts a background Tune server. Needed for
|
||||
using the Client API.
|
||||
server_port (int): Port number for launching TuneServer.
|
||||
verbose (int): 0, 1, or 2. Verbosity mode. 0 = silent,
|
||||
1 = only status updates, 2 = status and trial results.
|
||||
progress_reporter (ProgressReporter): Progress reporter for reporting
|
||||
intermediate experiment progress. Defaults to CLIReporter if
|
||||
running in command-line, or JupyterNotebookReporter if running in
|
||||
a Jupyter notebook.
|
||||
resume (str|bool): One of "LOCAL", "REMOTE", "PROMPT", or bool.
|
||||
LOCAL/True restores the checkpoint from the local_checkpoint_dir.
|
||||
REMOTE restores the checkpoint from remote_checkpoint_dir.
|
||||
PROMPT provides CLI feedback. False forces a new
|
||||
experiment. If resume is set but checkpoint does not exist,
|
||||
resume (str|bool): One of "LOCAL", "REMOTE", "PROMPT", "ERRORED_ONLY",
|
||||
or bool. LOCAL/True restores the checkpoint from the
|
||||
local_checkpoint_dir, determined
|
||||
by `name` and `local_dir`. REMOTE restores the checkpoint
|
||||
from remote_checkpoint_dir. PROMPT provides CLI feedback.
|
||||
False forces a new experiment. ERRORED_ONLY resets and reruns
|
||||
ERRORED trials upon resume - previous trial artifacts will
|
||||
be left untouched. If resume is set but checkpoint does not exist,
|
||||
ValueError will be thrown.
|
||||
run_errored_only (bool): Only to be used with `resume` enabled.
|
||||
Resets and reruns ERRORED trials upon resume.
|
||||
Experiment location is determined
|
||||
by `name` and `local_dir`. Previous trial artifacts will
|
||||
be left untouched.
|
||||
queue_trials (bool): Whether to queue trials when the cluster does
|
||||
not currently have enough resources to launch one. This should
|
||||
be set to True when running on an autoscaling cluster to enable
|
||||
automatic scale-up.
|
||||
reuse_actors (bool): Whether to reuse actors between different trials
|
||||
when possible. This can drastically speed up experiments that start
|
||||
and stop actors often (e.g., PBT in time-multiplexing mode). This
|
||||
@@ -266,9 +241,6 @@ def run(run_or_experiment,
|
||||
trial_executor (TrialExecutor): Manage the execution of trials.
|
||||
raise_on_failed_trial (bool): Raise TuneError if there exists failed
|
||||
trial (of ERROR state) when the experiments complete.
|
||||
ray_auto_init (bool): Automatically starts a local Ray cluster
|
||||
if using a RayTrialExecutor (which is the default) and
|
||||
if Ray is not initialized. Defaults to True.
|
||||
|
||||
|
||||
Returns:
|
||||
@@ -277,12 +249,35 @@ def run(run_or_experiment,
|
||||
Raises:
|
||||
TuneError: Any trials failed and `raise_on_failed_trial` is True.
|
||||
"""
|
||||
if global_checkpoint_period:
|
||||
raise ValueError("global_checkpoint_period is deprecated. Set env var "
|
||||
"'TUNE_GLOBAL_CHECKPOINT_S' instead.")
|
||||
if queue_trials:
|
||||
raise ValueError(
|
||||
"queue_trials is deprecated. "
|
||||
"Set env var 'TUNE_DISABLE_QUEUE_TRIALS=1' instead to "
|
||||
"disable queuing behavior.")
|
||||
if ray_auto_init:
|
||||
raise ValueError("ray_auto_init is deprecated. "
|
||||
"Set env var 'TUNE_DISABLE_AUTO_INIT=1' instead or "
|
||||
"call 'ray.init' before calling 'tune.run'.")
|
||||
if with_server:
|
||||
raise ValueError(
|
||||
"with_server is deprecated. It is now enabled by default "
|
||||
"if 'server_port' is not None.")
|
||||
if sync_on_checkpoint or sync_to_cloud or sync_to_driver or upload_dir:
|
||||
raise ValueError(
|
||||
"sync_on_checkpoint / sync_to_cloud / sync_to_driver / "
|
||||
"upload_dir must now be set via `tune.run("
|
||||
"sync_config=SyncConfig(...)`. See `ray.tune.SyncConfig` for "
|
||||
"more details.")
|
||||
|
||||
config = config or {}
|
||||
sync_config = sync_config or SyncConfig()
|
||||
set_sync_periods(sync_config)
|
||||
|
||||
trial_executor = trial_executor or RayTrialExecutor(
|
||||
queue_trials=queue_trials,
|
||||
reuse_actors=reuse_actors,
|
||||
ray_auto_init=ray_auto_init)
|
||||
reuse_actors=reuse_actors)
|
||||
if isinstance(run_or_experiment, list):
|
||||
experiments = run_or_experiment
|
||||
else:
|
||||
@@ -298,15 +293,15 @@ def run(run_or_experiment,
|
||||
resources_per_trial=resources_per_trial,
|
||||
num_samples=num_samples,
|
||||
local_dir=local_dir,
|
||||
upload_dir=upload_dir,
|
||||
sync_to_driver=sync_to_driver,
|
||||
upload_dir=sync_config.upload_dir,
|
||||
sync_to_driver=sync_config.sync_to_driver,
|
||||
trial_name_creator=trial_name_creator,
|
||||
trial_dirname_creator=trial_dirname_creator,
|
||||
loggers=loggers,
|
||||
log_to_file=log_to_file,
|
||||
checkpoint_freq=checkpoint_freq,
|
||||
checkpoint_at_end=checkpoint_at_end,
|
||||
sync_on_checkpoint=sync_on_checkpoint,
|
||||
sync_on_checkpoint=sync_config.sync_on_checkpoint,
|
||||
keep_checkpoints_num=keep_checkpoints_num,
|
||||
checkpoint_score_attr=checkpoint_score_attr,
|
||||
export_formats=export_formats,
|
||||
@@ -315,7 +310,7 @@ def run(run_or_experiment,
|
||||
else:
|
||||
logger.debug("Ignoring some parameters passed into tune.run.")
|
||||
|
||||
if sync_to_cloud:
|
||||
if sync_config.sync_to_cloud:
|
||||
for exp in experiments:
|
||||
assert exp.remote_checkpoint_dir, (
|
||||
"Need `upload_dir` if `sync_to_cloud` given.")
|
||||
@@ -344,12 +339,9 @@ def run(run_or_experiment,
|
||||
scheduler=scheduler or FIFOScheduler(),
|
||||
local_checkpoint_dir=experiments[0].checkpoint_dir,
|
||||
remote_checkpoint_dir=experiments[0].remote_checkpoint_dir,
|
||||
sync_to_cloud=sync_to_cloud,
|
||||
sync_to_cloud=sync_config.sync_to_cloud,
|
||||
stopper=experiments[0].stopper,
|
||||
checkpoint_period=global_checkpoint_period,
|
||||
resume=resume,
|
||||
run_errored_only=run_errored_only,
|
||||
launch_web_server=with_server,
|
||||
server_port=server_port,
|
||||
verbose=bool(verbose > 1),
|
||||
fail_fast=fail_fast,
|
||||
@@ -413,8 +405,6 @@ def run(run_or_experiment,
|
||||
logger.error("Trials did not complete: %s", incomplete_trials)
|
||||
|
||||
trials = runner.get_trials()
|
||||
if return_trials:
|
||||
return trials
|
||||
return ExperimentAnalysis(
|
||||
runner.checkpoint_file,
|
||||
trials=trials,
|
||||
@@ -423,14 +413,11 @@ def run(run_or_experiment,
|
||||
|
||||
|
||||
def run_experiments(experiments,
|
||||
search_alg=None,
|
||||
scheduler=None,
|
||||
with_server=False,
|
||||
server_port=TuneServer.DEFAULT_PORT,
|
||||
server_port=None,
|
||||
verbose=2,
|
||||
progress_reporter=None,
|
||||
resume=False,
|
||||
queue_trials=False,
|
||||
reuse_actors=False,
|
||||
trial_executor=None,
|
||||
raise_on_failed_trial=True,
|
||||
@@ -444,15 +431,6 @@ def run_experiments(experiments,
|
||||
>>> experiment_spec = {"experiment": {"run": my_func}}
|
||||
>>> run_experiments(experiments=experiment_spec)
|
||||
|
||||
>>> run_experiments(
|
||||
>>> experiments=experiment_spec,
|
||||
>>> scheduler=MedianStoppingRule(...))
|
||||
|
||||
>>> run_experiments(
|
||||
>>> experiments=experiment_spec,
|
||||
>>> search_alg=SearchAlgorithm(),
|
||||
>>> scheduler=MedianStoppingRule(...))
|
||||
|
||||
Returns:
|
||||
List of Trial objects, holding data for each executed trial.
|
||||
|
||||
@@ -465,33 +443,25 @@ def run_experiments(experiments,
|
||||
if concurrent:
|
||||
return run(
|
||||
experiments,
|
||||
search_alg=search_alg,
|
||||
scheduler=scheduler,
|
||||
with_server=with_server,
|
||||
server_port=server_port,
|
||||
verbose=verbose,
|
||||
progress_reporter=progress_reporter,
|
||||
resume=resume,
|
||||
queue_trials=queue_trials,
|
||||
reuse_actors=reuse_actors,
|
||||
trial_executor=trial_executor,
|
||||
raise_on_failed_trial=raise_on_failed_trial,
|
||||
return_trials=True)
|
||||
scheduler=scheduler).trials
|
||||
else:
|
||||
trials = []
|
||||
for exp in experiments:
|
||||
trials += run(
|
||||
exp,
|
||||
search_alg=search_alg,
|
||||
scheduler=scheduler,
|
||||
with_server=with_server,
|
||||
server_port=server_port,
|
||||
verbose=verbose,
|
||||
progress_reporter=progress_reporter,
|
||||
resume=resume,
|
||||
queue_trials=queue_trials,
|
||||
reuse_actors=reuse_actors,
|
||||
trial_executor=trial_executor,
|
||||
raise_on_failed_trial=raise_on_failed_trial,
|
||||
return_trials=True)
|
||||
scheduler=scheduler).trials
|
||||
return trials
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from ray.tune.utils.util import deep_update, flatten_dict, get_pinned_object, \
|
||||
merge_dicts, pin_in_object_store, unflattened_lookup, UtilMonitor, \
|
||||
validate_save_restore, warn_if_slow, diagnose_serialization
|
||||
validate_save_restore, warn_if_slow, diagnose_serialization, env_integer
|
||||
|
||||
__all__ = [
|
||||
"deep_update", "flatten_dict", "get_pinned_object", "merge_dicts",
|
||||
"pin_in_object_store", "unflattened_lookup", "UtilMonitor",
|
||||
"validate_save_restore", "warn_if_slow", "diagnose_serialization"
|
||||
"validate_save_restore", "warn_if_slow", "diagnose_serialization",
|
||||
"env_integer"
|
||||
]
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import inspect
|
||||
import threading
|
||||
import time
|
||||
@@ -152,6 +153,17 @@ class Tee(object):
|
||||
self.stream2.flush(*args, **kwargs)
|
||||
|
||||
|
||||
def env_integer(key, default):
    """Read an integer setting from the process environment.

    Args:
        key (str): Name of the environment variable to look up.
        default: Value returned unchanged when ``key`` is not set.

    Returns:
        The variable's value parsed with ``int()`` when ``key`` is present
        in ``os.environ``, otherwise ``default``.

    Raises:
        ValueError: If ``key`` is set but its value cannot be parsed as an
            integer.
    """
    # TODO(rliaw): move into ray.constants
    try:
        value = os.environ[key]
    except KeyError:
        return default

    # Parse with int() rather than pre-checking str.isdigit(): isdigit()
    # rejects signed or whitespace-padded integers ("-1", " 10 ") and lets
    # digit-like characters such as "²" through, which int() cannot parse.
    try:
        return int(value)
    except ValueError as err:
        raise ValueError(f"Found {key} in environment, but value must "
                         f"be an integer. Got: {value}.") from err
|
||||
|
||||
|
||||
def merge_dicts(d1, d2):
|
||||
"""
|
||||
Args:
|
||||
|
||||
Reference in New Issue
Block a user