Move the tune driver into a remote task (#13778)

This commit is contained in:
Eric Liang
2021-02-02 18:41:45 -08:00
committed by GitHub
parent b4684cf37a
commit d335ce2aab
7 changed files with 197 additions and 20 deletions
+8
View File
@@ -163,6 +163,14 @@ py_test(
tags = ["exclusive"],
)
py_test(
name = "test_remote",
size = "medium",
srcs = ["tests/test_remote.py"],
deps = [":tune_lib"],
tags = ["exclusive"],
)
py_test(
name = "test_sample",
size = "medium",
-13
View File
@@ -154,15 +154,7 @@ class RayTrialExecutor(TrialExecutor):
def __init__(self,
queue_trials: bool = False,
reuse_actors: bool = False,
ray_auto_init: Optional[bool] = None,
refresh_period: Optional[float] = None):
if ray_auto_init is None:
if os.environ.get("TUNE_DISABLE_AUTO_INIT") == "1":
logger.info("'TUNE_DISABLE_AUTO_INIT=1' detected.")
ray_auto_init = False
else:
ray_auto_init = True
super(RayTrialExecutor, self).__init__(queue_trials)
# Check for if we are launching a trial without resources in kick off
# autoscaler.
@@ -193,11 +185,6 @@ class RayTrialExecutor(TrialExecutor):
self._last_ip_refresh = float("-inf")
self._last_ip_addresses = set()
self._last_nontrivial_wait = time.time()
if not ray.is_initialized() and ray_auto_init:
logger.info("Initializing Ray automatically."
"For cluster usage or custom Ray initialization, "
"call `ray.init(...)` before `tune.run`.")
ray.init()
if ray.is_initialized():
self._update_avail_resources()
+77
View File
@@ -0,0 +1,77 @@
import unittest
import ray
from ray.tune import register_trainable, run_experiments, run
from ray.tune.result import TIMESTEPS_TOTAL
from ray.tune.experiment import Experiment
from ray.tune.trial import Trial
from ray.util.client.ray_client_helpers import ray_start_client_server
class RemoteTest(unittest.TestCase):
def tearDown(self):
ray.shutdown()
def testRemoteRunExperiments(self):
def train(config, reporter):
for i in range(100):
reporter(timesteps_total=i)
register_trainable("f1", train)
exp1 = Experiment(**{
"name": "foo",
"run": "f1",
})
[trial] = run_experiments(exp1, _remote=True)
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)
def testRemoteRun(self):
def train(config, reporter):
for i in range(100):
reporter(timesteps_total=i)
analysis = run(train, _remote=True)
[trial] = analysis.trials
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)
def testRemoteRunExperimentsInClient(self):
ray.init()
assert not ray.util.client.ray.is_connected()
with ray_start_client_server():
assert ray.util.client.ray.is_connected()
def train(config, reporter):
for i in range(100):
reporter(timesteps_total=i)
register_trainable("f1", train)
exp1 = Experiment(**{
"name": "foo",
"run": "f1",
})
[trial] = run_experiments(exp1)
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)
def testRemoteRunInClient(self):
ray.init()
assert not ray.util.client.ray.is_connected()
with ray_start_client_server():
assert ray.util.client.ray.is_connected()
def train(config, reporter):
for i in range(100):
reporter(timesteps_total=i)
analysis = run(train)
[trial] = analysis.trials
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))
@@ -697,6 +697,8 @@ class TrialRunnerTest3(unittest.TestCase):
@patch("ray.tune.syncer.CLOUD_SYNC_PERIOD", 0)
def testCheckpointAutoPeriod(self):
ray.init(num_cpus=3)
# This makes checkpointing take 2 seconds.
def sync_up(source, target):
time.sleep(2)
@@ -73,6 +73,7 @@ class _MockTrialExecutor(RayTrialExecutor):
class TrialRunnerCallbacks(unittest.TestCase):
def setUp(self):
ray.init()
self.tmpdir = tempfile.mkdtemp()
self.callback = TestCallback()
self.executor = _MockTrialExecutor()
+7 -6
View File
@@ -166,6 +166,13 @@ class Trial:
"""
_nonjson_fields = [
"results",
"best_result",
"param_config",
"extra_arg",
]
PENDING = "PENDING"
RUNNING = "RUNNING"
PAUSED = "PAUSED"
@@ -289,12 +296,6 @@ class Trial:
self.param_config = None
self.extra_arg = None
self._nonjson_fields = [
"results",
"best_result",
"param_config",
"extra_arg",
]
if trial_name_creator:
self.custom_trial_name = trial_name_creator(self)
+102 -1
View File
@@ -8,6 +8,7 @@ import signal
import sys
import time
import ray
from ray.tune.analysis import ExperimentAnalysis
from ray.tune.callback import Callback
from ray.tune.error import TuneError
@@ -111,6 +112,7 @@ def run(
sync_to_cloud: Optional = None,
sync_to_driver: Optional = None,
sync_on_checkpoint: Optional = None,
_remote: bool = None,
) -> ExperimentAnalysis:
"""Executes training.
@@ -270,6 +272,9 @@ def run(
``ray.tune.callback.Callback`` class. If not passed,
`LoggerCallback` and `SyncerCallback` callbacks are automatically
added.
_remote (bool): Whether to run the Tune driver in a remote function.
This is disabled automatically if a custom trial executor is
passed in. This is enabled by default in Ray client mode.
Returns:
ExperimentAnalysis: Object for experiment analysis.
@@ -277,6 +282,64 @@ def run(
Raises:
TuneError: Any trials failed and `raise_on_failed_trial` is True.
"""
if _remote is None:
_remote = ray.util.client.ray.is_connected()
if _remote is True and trial_executor:
raise ValueError("cannot use custom trial executor")
if not trial_executor or isinstance(trial_executor, RayTrialExecutor):
_ray_auto_init()
if _remote:
return ray.get(
ray.remote(num_cpus=0)(run).remote(
run_or_experiment,
name,
metric,
mode,
stop,
time_budget_s,
config,
resources_per_trial,
num_samples,
local_dir,
search_alg,
scheduler,
keep_checkpoints_num,
checkpoint_score_attr,
checkpoint_freq,
checkpoint_at_end,
verbose,
progress_reporter,
log_to_file,
trial_name_creator,
trial_dirname_creator,
sync_config,
export_formats,
max_failures,
fail_fast,
restore,
server_port,
resume,
queue_trials,
reuse_actors,
trial_executor,
raise_on_failed_trial,
callbacks,
# Deprecated args
loggers,
ray_auto_init,
run_errored_only,
global_checkpoint_period,
with_server,
upload_dir,
sync_to_cloud,
sync_to_driver,
sync_on_checkpoint,
_remote=False))
all_start = time.time()
if global_checkpoint_period:
raise ValueError("global_checkpoint_period is deprecated. Set env var "
@@ -509,7 +572,8 @@ def run_experiments(
trial_executor: Optional[RayTrialExecutor] = None,
raise_on_failed_trial: bool = True,
concurrent: bool = True,
callbacks: Optional[Sequence[Callback]] = None):
callbacks: Optional[Sequence[Callback]] = None,
_remote: bool = None):
"""Runs and blocks until all trials finish.
Examples:
@@ -523,6 +587,32 @@ def run_experiments(
List of Trial objects, holding data for each executed trial.
"""
if _remote is None:
_remote = ray.util.client.ray.is_connected()
if _remote is True and trial_executor:
raise ValueError("cannot use custom trial executor")
if not trial_executor or isinstance(trial_executor, RayTrialExecutor):
_ray_auto_init()
if _remote:
return ray.get(
ray.remote(num_cpus=0)(run_experiments).remote(
experiments,
scheduler,
server_port,
verbose,
progress_reporter,
resume,
queue_trials,
reuse_actors,
trial_executor,
raise_on_failed_trial,
concurrent,
callbacks,
_remote=False))
# This is important to do this here
# because it schematize the experiments
# and it conducts the implicit registration.
@@ -557,3 +647,14 @@ def run_experiments(
scheduler=scheduler,
callbacks=callbacks).trials
return trials
def _ray_auto_init():
"""Initialize Ray unless already configured."""
if os.environ.get("TUNE_DISABLE_AUTO_INIT") == "1":
logger.info("'TUNE_DISABLE_AUTO_INIT=1' detected.")
elif not ray.is_initialized():
logger.info("Initializing Ray automatically."
"For cluster usage or custom Ray initialization, "
"call `ray.init(...)` before `tune.run`.")
ray.init()