mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:53:18 +08:00
Move the tune driver into a remote task (#13778)
This commit is contained in:
@@ -163,6 +163,14 @@ py_test(
|
||||
tags = ["exclusive"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_remote",
|
||||
size = "medium",
|
||||
srcs = ["tests/test_remote.py"],
|
||||
deps = [":tune_lib"],
|
||||
tags = ["exclusive"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_sample",
|
||||
size = "medium",
|
||||
|
||||
@@ -154,15 +154,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||
def __init__(self,
|
||||
queue_trials: bool = False,
|
||||
reuse_actors: bool = False,
|
||||
ray_auto_init: Optional[bool] = None,
|
||||
refresh_period: Optional[float] = None):
|
||||
if ray_auto_init is None:
|
||||
if os.environ.get("TUNE_DISABLE_AUTO_INIT") == "1":
|
||||
logger.info("'TUNE_DISABLE_AUTO_INIT=1' detected.")
|
||||
ray_auto_init = False
|
||||
else:
|
||||
ray_auto_init = True
|
||||
|
||||
super(RayTrialExecutor, self).__init__(queue_trials)
|
||||
# Check for if we are launching a trial without resources in kick off
|
||||
# autoscaler.
|
||||
@@ -193,11 +185,6 @@ class RayTrialExecutor(TrialExecutor):
|
||||
self._last_ip_refresh = float("-inf")
|
||||
self._last_ip_addresses = set()
|
||||
self._last_nontrivial_wait = time.time()
|
||||
if not ray.is_initialized() and ray_auto_init:
|
||||
logger.info("Initializing Ray automatically."
|
||||
"For cluster usage or custom Ray initialization, "
|
||||
"call `ray.init(...)` before `tune.run`.")
|
||||
ray.init()
|
||||
|
||||
if ray.is_initialized():
|
||||
self._update_avail_resources()
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
import unittest
|
||||
|
||||
import ray
|
||||
from ray.tune import register_trainable, run_experiments, run
|
||||
from ray.tune.result import TIMESTEPS_TOTAL
|
||||
from ray.tune.experiment import Experiment
|
||||
from ray.tune.trial import Trial
|
||||
from ray.util.client.ray_client_helpers import ray_start_client_server
|
||||
|
||||
|
||||
class RemoteTest(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
ray.shutdown()
|
||||
|
||||
def testRemoteRunExperiments(self):
|
||||
def train(config, reporter):
|
||||
for i in range(100):
|
||||
reporter(timesteps_total=i)
|
||||
|
||||
register_trainable("f1", train)
|
||||
exp1 = Experiment(**{
|
||||
"name": "foo",
|
||||
"run": "f1",
|
||||
})
|
||||
[trial] = run_experiments(exp1, _remote=True)
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)
|
||||
|
||||
def testRemoteRun(self):
|
||||
def train(config, reporter):
|
||||
for i in range(100):
|
||||
reporter(timesteps_total=i)
|
||||
|
||||
analysis = run(train, _remote=True)
|
||||
[trial] = analysis.trials
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)
|
||||
|
||||
def testRemoteRunExperimentsInClient(self):
|
||||
ray.init()
|
||||
assert not ray.util.client.ray.is_connected()
|
||||
with ray_start_client_server():
|
||||
assert ray.util.client.ray.is_connected()
|
||||
|
||||
def train(config, reporter):
|
||||
for i in range(100):
|
||||
reporter(timesteps_total=i)
|
||||
|
||||
register_trainable("f1", train)
|
||||
exp1 = Experiment(**{
|
||||
"name": "foo",
|
||||
"run": "f1",
|
||||
})
|
||||
[trial] = run_experiments(exp1)
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)
|
||||
|
||||
def testRemoteRunInClient(self):
|
||||
ray.init()
|
||||
assert not ray.util.client.ray.is_connected()
|
||||
with ray_start_client_server():
|
||||
assert ray.util.client.ray.is_connected()
|
||||
|
||||
def train(config, reporter):
|
||||
for i in range(100):
|
||||
reporter(timesteps_total=i)
|
||||
|
||||
analysis = run(train)
|
||||
[trial] = analysis.trials
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
@@ -697,6 +697,8 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||
|
||||
@patch("ray.tune.syncer.CLOUD_SYNC_PERIOD", 0)
|
||||
def testCheckpointAutoPeriod(self):
|
||||
ray.init(num_cpus=3)
|
||||
|
||||
# This makes checkpointing take 2 seconds.
|
||||
def sync_up(source, target):
|
||||
time.sleep(2)
|
||||
|
||||
@@ -73,6 +73,7 @@ class _MockTrialExecutor(RayTrialExecutor):
|
||||
|
||||
class TrialRunnerCallbacks(unittest.TestCase):
|
||||
def setUp(self):
|
||||
ray.init()
|
||||
self.tmpdir = tempfile.mkdtemp()
|
||||
self.callback = TestCallback()
|
||||
self.executor = _MockTrialExecutor()
|
||||
|
||||
@@ -166,6 +166,13 @@ class Trial:
|
||||
|
||||
"""
|
||||
|
||||
_nonjson_fields = [
|
||||
"results",
|
||||
"best_result",
|
||||
"param_config",
|
||||
"extra_arg",
|
||||
]
|
||||
|
||||
PENDING = "PENDING"
|
||||
RUNNING = "RUNNING"
|
||||
PAUSED = "PAUSED"
|
||||
@@ -289,12 +296,6 @@ class Trial:
|
||||
self.param_config = None
|
||||
self.extra_arg = None
|
||||
|
||||
self._nonjson_fields = [
|
||||
"results",
|
||||
"best_result",
|
||||
"param_config",
|
||||
"extra_arg",
|
||||
]
|
||||
if trial_name_creator:
|
||||
self.custom_trial_name = trial_name_creator(self)
|
||||
|
||||
|
||||
+102
-1
@@ -8,6 +8,7 @@ import signal
|
||||
import sys
|
||||
import time
|
||||
|
||||
import ray
|
||||
from ray.tune.analysis import ExperimentAnalysis
|
||||
from ray.tune.callback import Callback
|
||||
from ray.tune.error import TuneError
|
||||
@@ -111,6 +112,7 @@ def run(
|
||||
sync_to_cloud: Optional = None,
|
||||
sync_to_driver: Optional = None,
|
||||
sync_on_checkpoint: Optional = None,
|
||||
_remote: bool = None,
|
||||
) -> ExperimentAnalysis:
|
||||
"""Executes training.
|
||||
|
||||
@@ -270,6 +272,9 @@ def run(
|
||||
``ray.tune.callback.Callback`` class. If not passed,
|
||||
`LoggerCallback` and `SyncerCallback` callbacks are automatically
|
||||
added.
|
||||
_remote (bool): Whether to run the Tune driver in a remote function.
|
||||
This is disabled automatically if a custom trial executor is
|
||||
passed in. This is enabled by default in Ray client mode.
|
||||
|
||||
Returns:
|
||||
ExperimentAnalysis: Object for experiment analysis.
|
||||
@@ -277,6 +282,64 @@ def run(
|
||||
Raises:
|
||||
TuneError: Any trials failed and `raise_on_failed_trial` is True.
|
||||
"""
|
||||
|
||||
if _remote is None:
|
||||
_remote = ray.util.client.ray.is_connected()
|
||||
|
||||
if _remote is True and trial_executor:
|
||||
raise ValueError("cannot use custom trial executor")
|
||||
|
||||
if not trial_executor or isinstance(trial_executor, RayTrialExecutor):
|
||||
_ray_auto_init()
|
||||
|
||||
if _remote:
|
||||
return ray.get(
|
||||
ray.remote(num_cpus=0)(run).remote(
|
||||
run_or_experiment,
|
||||
name,
|
||||
metric,
|
||||
mode,
|
||||
stop,
|
||||
time_budget_s,
|
||||
config,
|
||||
resources_per_trial,
|
||||
num_samples,
|
||||
local_dir,
|
||||
search_alg,
|
||||
scheduler,
|
||||
keep_checkpoints_num,
|
||||
checkpoint_score_attr,
|
||||
checkpoint_freq,
|
||||
checkpoint_at_end,
|
||||
verbose,
|
||||
progress_reporter,
|
||||
log_to_file,
|
||||
trial_name_creator,
|
||||
trial_dirname_creator,
|
||||
sync_config,
|
||||
export_formats,
|
||||
max_failures,
|
||||
fail_fast,
|
||||
restore,
|
||||
server_port,
|
||||
resume,
|
||||
queue_trials,
|
||||
reuse_actors,
|
||||
trial_executor,
|
||||
raise_on_failed_trial,
|
||||
callbacks,
|
||||
# Deprecated args
|
||||
loggers,
|
||||
ray_auto_init,
|
||||
run_errored_only,
|
||||
global_checkpoint_period,
|
||||
with_server,
|
||||
upload_dir,
|
||||
sync_to_cloud,
|
||||
sync_to_driver,
|
||||
sync_on_checkpoint,
|
||||
_remote=False))
|
||||
|
||||
all_start = time.time()
|
||||
if global_checkpoint_period:
|
||||
raise ValueError("global_checkpoint_period is deprecated. Set env var "
|
||||
@@ -509,7 +572,8 @@ def run_experiments(
|
||||
trial_executor: Optional[RayTrialExecutor] = None,
|
||||
raise_on_failed_trial: bool = True,
|
||||
concurrent: bool = True,
|
||||
callbacks: Optional[Sequence[Callback]] = None):
|
||||
callbacks: Optional[Sequence[Callback]] = None,
|
||||
_remote: bool = None):
|
||||
"""Runs and blocks until all trials finish.
|
||||
|
||||
Examples:
|
||||
@@ -523,6 +587,32 @@ def run_experiments(
|
||||
List of Trial objects, holding data for each executed trial.
|
||||
|
||||
"""
|
||||
if _remote is None:
|
||||
_remote = ray.util.client.ray.is_connected()
|
||||
|
||||
if _remote is True and trial_executor:
|
||||
raise ValueError("cannot use custom trial executor")
|
||||
|
||||
if not trial_executor or isinstance(trial_executor, RayTrialExecutor):
|
||||
_ray_auto_init()
|
||||
|
||||
if _remote:
|
||||
return ray.get(
|
||||
ray.remote(num_cpus=0)(run_experiments).remote(
|
||||
experiments,
|
||||
scheduler,
|
||||
server_port,
|
||||
verbose,
|
||||
progress_reporter,
|
||||
resume,
|
||||
queue_trials,
|
||||
reuse_actors,
|
||||
trial_executor,
|
||||
raise_on_failed_trial,
|
||||
concurrent,
|
||||
callbacks,
|
||||
_remote=False))
|
||||
|
||||
# This is important to do this here
|
||||
# because it schematize the experiments
|
||||
# and it conducts the implicit registration.
|
||||
@@ -557,3 +647,14 @@ def run_experiments(
|
||||
scheduler=scheduler,
|
||||
callbacks=callbacks).trials
|
||||
return trials
|
||||
|
||||
|
||||
def _ray_auto_init():
|
||||
"""Initialize Ray unless already configured."""
|
||||
if os.environ.get("TUNE_DISABLE_AUTO_INIT") == "1":
|
||||
logger.info("'TUNE_DISABLE_AUTO_INIT=1' detected.")
|
||||
elif not ray.is_initialized():
|
||||
logger.info("Initializing Ray automatically."
|
||||
"For cluster usage or custom Ray initialization, "
|
||||
"call `ray.init(...)` before `tune.run`.")
|
||||
ray.init()
|
||||
|
||||
Reference in New Issue
Block a user