mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 16:13:54 +08:00
[tune] Cluster Fault Tolerance (#3309)
This PR introduces cluster-level fault tolerance for Tune by checkpointing global state. Checkpointing happens at relatively high frequency, so users can easily resume experiments after the cluster crashes. Note that this PR may affect automated workflows because Tune can now prompt interactively about resuming; the tests in this commit disable the prompt by setting the `TUNE_RESUME_PROMPT_OFF` environment variable.
This commit is contained in:
@@ -9,7 +9,8 @@ import yaml
|
||||
|
||||
import ray
|
||||
from ray.test.cluster_utils import Cluster
|
||||
from ray.tune.config_parser import make_parser, resources_to_json
|
||||
from ray.tune.config_parser import make_parser
|
||||
from ray.tune.trial import resources_to_json
|
||||
from ray.tune.tune import _make_scheduler, run_experiments
|
||||
|
||||
EXAMPLE_USAGE = """
|
||||
@@ -70,6 +71,10 @@ def create_parser(parser_creator=None):
|
||||
default="default",
|
||||
type=str,
|
||||
help="Name of the subdirectory under `local_dir` to put results in.")
|
||||
parser.add_argument(
|
||||
"--resume",
|
||||
action="store_true",
|
||||
help="Whether to attempt to resume previous Tune experiments.")
|
||||
parser.add_argument(
|
||||
"--env", default=None, type=str, help="The gym environment to use.")
|
||||
parser.add_argument(
|
||||
@@ -138,7 +143,8 @@ def run(args, parser):
|
||||
run_experiments(
|
||||
experiments,
|
||||
scheduler=_make_scheduler(args),
|
||||
queue_trials=args.queue_trials)
|
||||
queue_trials=args.queue_trials,
|
||||
resume=args.resume)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -51,7 +51,9 @@ class Cluster(object):
|
||||
assert not self.connected
|
||||
redis_password = head_node_args.get("redis_password")
|
||||
output_info = ray.init(
|
||||
redis_address=self.redis_address, redis_password=redis_password)
|
||||
ignore_reinit_error=True,
|
||||
redis_address=self.redis_address,
|
||||
redis_password=redis_password)
|
||||
logger.info(output_info)
|
||||
self.connected = True
|
||||
|
||||
|
||||
@@ -11,40 +11,10 @@ from six import string_types
|
||||
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
from ray.tune.trial import Resources, Trial
|
||||
from ray.tune.trial import Trial, json_to_resources
|
||||
from ray.tune.logger import _SafeFallbackEncoder
|
||||
|
||||
|
||||
def json_to_resources(data):
|
||||
if data is None or data == "null":
|
||||
return None
|
||||
if isinstance(data, string_types):
|
||||
data = json.loads(data)
|
||||
for k in data:
|
||||
if k in ["driver_cpu_limit", "driver_gpu_limit"]:
|
||||
raise TuneError(
|
||||
"The field `{}` is no longer supported. Use `extra_cpu` "
|
||||
"or `extra_gpu` instead.".format(k))
|
||||
if k not in Resources._fields:
|
||||
raise TuneError(
|
||||
"Unknown resource type {}, must be one of {}".format(
|
||||
k, Resources._fields))
|
||||
return Resources(
|
||||
data.get("cpu", 1), data.get("gpu", 0), data.get("extra_cpu", 0),
|
||||
data.get("extra_gpu", 0))
|
||||
|
||||
|
||||
def resources_to_json(resources):
|
||||
if resources is None:
|
||||
return None
|
||||
return {
|
||||
"cpu": resources.cpu,
|
||||
"gpu": resources.gpu,
|
||||
"extra_cpu": resources.extra_cpu,
|
||||
"extra_gpu": resources.extra_gpu,
|
||||
}
|
||||
|
||||
|
||||
def make_parser(parser_creator=None, **kwargs):
|
||||
"""Returns a base argument parser for the ray.tune tool.
|
||||
|
||||
|
||||
@@ -4,11 +4,11 @@ from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import six
|
||||
import types
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.log_sync import validate_sync_function
|
||||
from ray.tune.registry import register_trainable
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
|
||||
@@ -122,7 +122,6 @@ class Experiment(object):
|
||||
restore=None,
|
||||
repeat=None,
|
||||
trial_resources=None):
|
||||
validate_sync_function(sync_function)
|
||||
if sync_function:
|
||||
assert upload_dir, "Need `upload_dir` if sync_function given."
|
||||
|
||||
@@ -134,16 +133,16 @@ class Experiment(object):
|
||||
resources_per_trial = trial_resources
|
||||
|
||||
spec = {
|
||||
"run": self._register_if_needed(run),
|
||||
"run": Experiment._register_if_needed(run),
|
||||
"stop": stop or {},
|
||||
"config": config or {},
|
||||
"resources_per_trial": resources_per_trial,
|
||||
"num_samples": num_samples,
|
||||
"local_dir": local_dir or DEFAULT_RESULTS_DIR,
|
||||
"local_dir": os.path.expanduser(local_dir or DEFAULT_RESULTS_DIR),
|
||||
"upload_dir": upload_dir or "", # argparse converts None to "null"
|
||||
"trial_name_creator": trial_name_creator,
|
||||
"custom_loggers": custom_loggers,
|
||||
"sync_function": sync_function or "", # See `upload_dir`.
|
||||
"sync_function": sync_function,
|
||||
"checkpoint_freq": checkpoint_freq,
|
||||
"checkpoint_at_end": checkpoint_at_end,
|
||||
"max_failures": max_failures,
|
||||
@@ -180,7 +179,8 @@ class Experiment(object):
|
||||
raise TuneError("Improper argument from JSON: {}.".format(spec))
|
||||
return exp
|
||||
|
||||
def _register_if_needed(self, run_object):
|
||||
@classmethod
|
||||
def _register_if_needed(cls, run_object):
|
||||
"""Registers Trainable or Function at runtime.
|
||||
|
||||
Assumes already registered if run_object is a string. Does not
|
||||
|
||||
+20
-10
@@ -106,19 +106,19 @@ class UnifiedLogger(Logger):
|
||||
self.logdir, self.uri, sync_function=self._sync_function)
|
||||
|
||||
def on_result(self, result):
|
||||
for logger in self._loggers:
|
||||
logger.on_result(result)
|
||||
for _logger in self._loggers:
|
||||
_logger.on_result(result)
|
||||
self._log_syncer.set_worker_ip(result.get(NODE_IP))
|
||||
self._log_syncer.sync_if_needed()
|
||||
|
||||
def close(self):
|
||||
for logger in self._loggers:
|
||||
logger.close()
|
||||
for _logger in self._loggers:
|
||||
_logger.close()
|
||||
self._log_syncer.sync_now(force=True)
|
||||
|
||||
def flush(self):
|
||||
for logger in self._loggers:
|
||||
logger.flush()
|
||||
for _logger in self._loggers:
|
||||
_logger.flush()
|
||||
self._log_syncer.sync_now(force=True)
|
||||
self._log_syncer.wait()
|
||||
|
||||
@@ -142,7 +142,7 @@ class _JsonLogger(Logger):
|
||||
with open(config_pkl, "wb") as f:
|
||||
cloudpickle.dump(self.config, f)
|
||||
local_file = os.path.join(self.logdir, "result.json")
|
||||
self.local_out = open(local_file, "w")
|
||||
self.local_out = open(local_file, "a")
|
||||
|
||||
def on_result(self, result):
|
||||
json.dump(result, self, cls=_SafeFallbackEncoder)
|
||||
@@ -152,6 +152,9 @@ class _JsonLogger(Logger):
|
||||
self.local_out.write(b)
|
||||
self.local_out.flush()
|
||||
|
||||
def flush(self):
|
||||
self.local_out.flush()
|
||||
|
||||
def close(self):
|
||||
self.local_out.close()
|
||||
|
||||
@@ -182,7 +185,8 @@ class _TFLogger(Logger):
|
||||
for k in [
|
||||
"config", "pid", "timestamp", TIME_TOTAL_S, TRAINING_ITERATION
|
||||
]:
|
||||
del tmp[k] # not useful to tf log these
|
||||
if k in tmp:
|
||||
del tmp[k] # not useful to tf log these
|
||||
values = to_tf_values(tmp, ["ray", "tune"])
|
||||
train_stats = tf.Summary(value=values)
|
||||
t = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]
|
||||
@@ -205,15 +209,21 @@ class _VisKitLogger(Logger):
|
||||
def _init(self):
|
||||
"""CSV outputted with Headers as first set of results."""
|
||||
# Note that we assume params.json was already created by JsonLogger
|
||||
self._file = open(os.path.join(self.logdir, "progress.csv"), "w")
|
||||
progress_file = os.path.join(self.logdir, "progress.csv")
|
||||
self._continuing = os.path.exists(progress_file)
|
||||
self._file = open(progress_file, "a")
|
||||
self._csv_out = None
|
||||
|
||||
def on_result(self, result):
|
||||
if self._csv_out is None:
|
||||
self._csv_out = csv.DictWriter(self._file, result.keys())
|
||||
self._csv_out.writeheader()
|
||||
if not self._continuing:
|
||||
self._csv_out.writeheader()
|
||||
self._csv_out.writerow(result.copy())
|
||||
|
||||
def flush(self):
|
||||
self._file.flush()
|
||||
|
||||
def close(self):
|
||||
self._file.close()
|
||||
|
||||
|
||||
@@ -38,6 +38,8 @@ class RayTrialExecutor(TrialExecutor):
|
||||
num_gpus=trial.resources.gpu)(trial._get_trainable_cls())
|
||||
|
||||
trial.init_logger()
|
||||
# We checkpoint metadata here to try mitigating logdir duplication
|
||||
self.try_checkpoint_metadata(trial)
|
||||
remote_logdir = trial.logdir
|
||||
|
||||
def logger_creator(config):
|
||||
@@ -60,7 +62,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||
|
||||
def _start_trial(self, trial, checkpoint=None):
|
||||
prior_status = trial.status
|
||||
trial.status = Trial.RUNNING
|
||||
self.set_status(trial, Trial.RUNNING)
|
||||
trial.runner = self._setup_runner(trial)
|
||||
if not self.restore(trial, checkpoint):
|
||||
return
|
||||
@@ -87,10 +89,13 @@ class RayTrialExecutor(TrialExecutor):
|
||||
stop_logger (bool): Whether to shut down the trial logger.
|
||||
"""
|
||||
|
||||
if stop_logger:
|
||||
trial.close_logger()
|
||||
|
||||
if error:
|
||||
trial.status = Trial.ERROR
|
||||
self.set_status(trial, Trial.ERROR)
|
||||
else:
|
||||
trial.status = Trial.TERMINATED
|
||||
self.set_status(trial, Trial.TERMINATED)
|
||||
|
||||
try:
|
||||
trial.write_error_log(error_msg)
|
||||
@@ -103,13 +108,10 @@ class RayTrialExecutor(TrialExecutor):
|
||||
stop_tasks, num_returns=2, timeout=250)
|
||||
except Exception:
|
||||
logger.exception("Error stopping runner.")
|
||||
trial.status = Trial.ERROR
|
||||
self.set_status(trial, Trial.ERROR)
|
||||
finally:
|
||||
trial.runner = None
|
||||
|
||||
if stop_logger:
|
||||
trial.close_logger()
|
||||
|
||||
def start_trial(self, trial, checkpoint=None):
|
||||
"""Starts the trial.
|
||||
|
||||
@@ -302,7 +304,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||
return True
|
||||
if trial.runner is None:
|
||||
logger.error("Unable to restore - no runner.")
|
||||
trial.status = Trial.ERROR
|
||||
self.set_status(trial, Trial.ERROR)
|
||||
return False
|
||||
try:
|
||||
value = checkpoint.value
|
||||
@@ -316,5 +318,5 @@ class RayTrialExecutor(TrialExecutor):
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Error restoring runner.")
|
||||
trial.status = Trial.ERROR
|
||||
self.set_status(trial, Trial.ERROR)
|
||||
return False
|
||||
|
||||
@@ -2,7 +2,10 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import time
|
||||
import os
|
||||
import pytest
|
||||
try:
|
||||
import pytest_timeout
|
||||
@@ -10,14 +13,37 @@ except ImportError:
|
||||
pytest_timeout = None
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib import _register_all
|
||||
from ray.test.cluster_utils import Cluster
|
||||
from ray.test.test_utils import run_string_as_driver_nonblocking
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.experiment import Experiment
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.suggest import BasicVariantGenerator
|
||||
|
||||
|
||||
class _Fail(tune.Trainable):
    """Fails on the 4th iteration."""

    def _setup(self, config):
        # Number of completed `_train` calls; this dict is also what gets
        # checkpointed by `_save` and reinstated by `_restore`.
        self.state = {"hi": 0}

    def _train(self):
        self.state["hi"] += 1
        time.sleep(0.5)
        if self.state["hi"] >= 4:
            # Raise explicitly instead of `assert False`: assert statements
            # are stripped under `python -O`, which would silently stop this
            # trainable from failing at all.
            raise AssertionError("intentional failure")
        return {}

    def _save(self, path):
        # Returning a dict lets the caller persist the state itself.
        return self.state

    def _restore(self, state):
        self.state = state
|
||||
|
||||
|
||||
def _start_new_cluster():
|
||||
cluster = Cluster(
|
||||
initialize_head=True,
|
||||
@@ -36,6 +62,7 @@ def _start_new_cluster():
|
||||
@pytest.fixture
|
||||
def start_connected_cluster():
|
||||
# Start the Ray processes.
|
||||
os.environ["TUNE_RESUME_PROMPT_OFF"] = "True"
|
||||
cluster = _start_new_cluster()
|
||||
yield cluster
|
||||
# The code after the yield will run as teardown code.
|
||||
@@ -47,6 +74,7 @@ def start_connected_cluster():
|
||||
def start_connected_emptyhead_cluster():
|
||||
"""Starts head with no resources."""
|
||||
|
||||
os.environ["TUNE_RESUME_PROMPT_OFF"] = "True"
|
||||
cluster = Cluster(
|
||||
initialize_head=True,
|
||||
connect=True,
|
||||
@@ -66,7 +94,6 @@ def start_connected_emptyhead_cluster():
|
||||
|
||||
def test_counting_resources(start_connected_cluster):
|
||||
"""Tests that Tune accounting is consistent with actual cluster."""
|
||||
|
||||
cluster = start_connected_cluster
|
||||
nodes = []
|
||||
assert ray.global_state.cluster_resources()["CPU"] == 1
|
||||
@@ -240,3 +267,231 @@ def test_trial_requeue(start_connected_emptyhead_cluster):
|
||||
|
||||
with pytest.raises(TuneError):
|
||||
runner.step()
|
||||
|
||||
|
||||
def test_cluster_down_simple(start_connected_cluster, tmpdir):
|
||||
"""Tests that TrialRunner save/restore works on cluster shutdown."""
|
||||
cluster = start_connected_cluster
|
||||
cluster.add_node(resources=dict(CPU=1))
|
||||
assert cluster.wait_for_nodes()
|
||||
|
||||
dirpath = str(tmpdir)
|
||||
runner = TrialRunner(
|
||||
BasicVariantGenerator(), metadata_checkpoint_dir=dirpath)
|
||||
kwargs = {
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 2
|
||||
},
|
||||
"checkpoint_freq": 1,
|
||||
"max_failures": 1
|
||||
}
|
||||
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
|
||||
for t in trials:
|
||||
runner.add_trial(t)
|
||||
|
||||
runner.step() # start
|
||||
runner.step() # start2
|
||||
runner.step() # step
|
||||
assert all(t.status == Trial.RUNNING for t in runner.get_trials())
|
||||
runner.checkpoint()
|
||||
|
||||
cluster.shutdown()
|
||||
ray.shutdown()
|
||||
|
||||
cluster = _start_new_cluster()
|
||||
runner = TrialRunner.restore(dirpath)
|
||||
runner.step() # start
|
||||
runner.step() # start2
|
||||
|
||||
for i in range(3):
|
||||
runner.step()
|
||||
|
||||
with pytest.raises(TuneError):
|
||||
runner.step()
|
||||
|
||||
assert all(t.status == Trial.TERMINATED for t in runner.get_trials())
|
||||
cluster.shutdown()
|
||||
|
||||
|
||||
def test_cluster_down_full(start_connected_cluster, tmpdir):
|
||||
"""Tests that run_experiment restoring works on cluster shutdown."""
|
||||
cluster = start_connected_cluster
|
||||
dirpath = str(tmpdir)
|
||||
|
||||
exp1_args = dict(
|
||||
run="__fake",
|
||||
stop=dict(training_iteration=3),
|
||||
local_dir=dirpath,
|
||||
checkpoint_freq=1)
|
||||
exp2_args = dict(run="__fake", stop=dict(training_iteration=3))
|
||||
exp3_args = dict(
|
||||
run="__fake",
|
||||
stop=dict(training_iteration=3),
|
||||
config=dict(mock_error=True))
|
||||
exp4_args = dict(
|
||||
run="__fake",
|
||||
stop=dict(training_iteration=3),
|
||||
config=dict(mock_error=True),
|
||||
checkpoint_freq=1)
|
||||
all_experiments = {
|
||||
"exp1": exp1_args,
|
||||
"exp2": exp2_args,
|
||||
"exp3": exp3_args,
|
||||
"exp4": exp4_args
|
||||
}
|
||||
|
||||
tune.run_experiments(all_experiments, raise_on_failed_trial=False)
|
||||
|
||||
ray.shutdown()
|
||||
cluster.shutdown()
|
||||
cluster = _start_new_cluster()
|
||||
|
||||
trials = tune.run_experiments(
|
||||
all_experiments, resume=True, raise_on_failed_trial=False)
|
||||
assert len(trials) == 4
|
||||
assert all(t.status in [Trial.TERMINATED, Trial.ERROR] for t in trials)
|
||||
cluster.shutdown()
|
||||
|
||||
|
||||
def test_cluster_rllib_restore(start_connected_cluster, tmpdir):
|
||||
cluster = start_connected_cluster
|
||||
dirpath = str(tmpdir)
|
||||
script = """
|
||||
import time
|
||||
import ray
|
||||
from ray import tune
|
||||
|
||||
ray.init(redis_address="{redis_address}")
|
||||
|
||||
kwargs = dict(
|
||||
run="PG",
|
||||
env="CartPole-v1",
|
||||
stop=dict(training_iteration=10),
|
||||
local_dir="{checkpoint_dir}",
|
||||
checkpoint_freq=1,
|
||||
max_failures=1)
|
||||
|
||||
tune.run_experiments(
|
||||
dict(experiment=kwargs),
|
||||
raise_on_failed_trial=False)
|
||||
""".format(
|
||||
redis_address=cluster.redis_address, checkpoint_dir=dirpath)
|
||||
run_string_as_driver_nonblocking(script)
|
||||
# Wait until the right checkpoint is saved.
|
||||
# The trainable returns every 0.5 seconds, so this should not miss
|
||||
# the checkpoint.
|
||||
metadata_checkpoint_dir = os.path.join(dirpath, "experiment")
|
||||
for i in range(50):
|
||||
if os.path.exists(
|
||||
os.path.join(metadata_checkpoint_dir,
|
||||
TrialRunner.CKPT_FILE_NAME)):
|
||||
# Inspect the internal trialrunner
|
||||
runner = TrialRunner.restore(metadata_checkpoint_dir)
|
||||
trials = runner.get_trials()
|
||||
last_res = trials[0].last_result
|
||||
if last_res is not None and last_res["training_iteration"]:
|
||||
break
|
||||
time.sleep(0.2)
|
||||
|
||||
if not os.path.exists(
|
||||
os.path.join(metadata_checkpoint_dir, TrialRunner.CKPT_FILE_NAME)):
|
||||
raise RuntimeError("Checkpoint file didn't appear.")
|
||||
|
||||
ray.shutdown()
|
||||
cluster.shutdown()
|
||||
cluster = _start_new_cluster()
|
||||
cluster.wait_for_nodes()
|
||||
|
||||
# Restore properly from checkpoint
|
||||
trials2 = tune.run_experiments(
|
||||
{
|
||||
"experiment": {
|
||||
"run": "PG",
|
||||
"checkpoint_freq": 1,
|
||||
"local_dir": dirpath
|
||||
}
|
||||
},
|
||||
resume=True)
|
||||
assert all(t.status == Trial.TERMINATED for t in trials2)
|
||||
cluster.shutdown()
|
||||
|
||||
|
||||
def test_cluster_interrupt(start_connected_cluster, tmpdir):
|
||||
"""Tests run_experiment on cluster shutdown even with atypical trial.
|
||||
|
||||
The trial fails on the 4th step, and the checkpointing happens on
|
||||
the 3rd step, so restoring should actually launch the trial again.
|
||||
"""
|
||||
cluster = start_connected_cluster
|
||||
dirpath = str(tmpdir)
|
||||
script = """
|
||||
import time
|
||||
import ray
|
||||
from ray import tune
|
||||
|
||||
ray.init(redis_address="{redis_address}")
|
||||
|
||||
{fail_class_code}
|
||||
|
||||
kwargs = dict(
|
||||
run={fail_class},
|
||||
stop=dict(training_iteration=5),
|
||||
local_dir="{checkpoint_dir}",
|
||||
checkpoint_freq=1,
|
||||
max_failures=1)
|
||||
|
||||
tune.run_experiments(
|
||||
dict(experiment=kwargs),
|
||||
raise_on_failed_trial=False)
|
||||
""".format(
|
||||
redis_address=cluster.redis_address,
|
||||
checkpoint_dir=dirpath,
|
||||
fail_class_code=inspect.getsource(_Fail),
|
||||
fail_class=_Fail.__name__)
|
||||
run_string_as_driver_nonblocking(script)
|
||||
|
||||
# Wait until the right checkpoint is saved.
|
||||
# The trainable returns every 0.5 seconds, so this should not miss
|
||||
# the checkpoint.
|
||||
metadata_checkpoint_dir = os.path.join(dirpath, "experiment")
|
||||
for i in range(50):
|
||||
if os.path.exists(
|
||||
os.path.join(metadata_checkpoint_dir,
|
||||
TrialRunner.CKPT_FILE_NAME)):
|
||||
# Inspect the internal trialrunner
|
||||
runner = TrialRunner.restore(metadata_checkpoint_dir)
|
||||
trials = runner.get_trials()
|
||||
last_res = trials[0].last_result
|
||||
if last_res is not None and last_res["training_iteration"] == 3:
|
||||
break
|
||||
time.sleep(0.2)
|
||||
|
||||
if not os.path.exists(
|
||||
os.path.join(metadata_checkpoint_dir, TrialRunner.CKPT_FILE_NAME)):
|
||||
raise RuntimeError("Checkpoint file didn't appear.")
|
||||
|
||||
ray.shutdown()
|
||||
cluster.shutdown()
|
||||
cluster = _start_new_cluster()
|
||||
Experiment._register_if_needed(_Fail)
|
||||
|
||||
# Inspect the internal trialrunner
|
||||
runner = TrialRunner.restore(metadata_checkpoint_dir)
|
||||
trials = runner.get_trials()
|
||||
assert trials[0].last_result["training_iteration"] == 3
|
||||
assert trials[0].status == Trial.PENDING
|
||||
|
||||
# Restore properly from checkpoint
|
||||
trials2 = tune.run_experiments(
|
||||
{
|
||||
"experiment": {
|
||||
"run": _Fail,
|
||||
"local_dir": dirpath,
|
||||
"checkpoint_freq": 1
|
||||
}
|
||||
},
|
||||
resume=True,
|
||||
raise_on_failed_trial=False)
|
||||
assert all(t.status == Trial.ERROR for t in trials2)
|
||||
assert {t.trial_id for t in trials2} == {t.trial_id for t in trials}
|
||||
cluster.shutdown()
|
||||
|
||||
@@ -3,7 +3,9 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
|
||||
@@ -37,6 +39,7 @@ else:
|
||||
|
||||
class TrainableFunctionApiTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
os.environ["TUNE_RESUME_PROMPT_OFF"] = "True"
|
||||
ray.init(num_cpus=4, num_gpus=0)
|
||||
|
||||
def tearDown(self):
|
||||
@@ -542,6 +545,7 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
|
||||
class RunExperimentTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
os.environ["TUNE_RESUME_PROMPT_OFF"] = "True"
|
||||
ray.init()
|
||||
|
||||
def tearDown(self):
|
||||
@@ -614,29 +618,6 @@ class RunExperimentTest(unittest.TestCase):
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)
|
||||
|
||||
def testSpecifyAlgorithm(self):
|
||||
"""Tests run_experiments works without specifying experiment."""
|
||||
|
||||
def train(config, reporter):
|
||||
for i in range(100):
|
||||
reporter(timesteps_total=i)
|
||||
|
||||
register_trainable("f1", train)
|
||||
|
||||
alg = BasicVariantGenerator()
|
||||
alg.add_configurations({
|
||||
"foo": {
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0
|
||||
}
|
||||
}
|
||||
})
|
||||
trials = run_experiments(search_alg=alg)
|
||||
for trial in trials:
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)
|
||||
|
||||
def testAutoregisterTrainable(self):
|
||||
def train(config, reporter):
|
||||
for i in range(100):
|
||||
@@ -778,6 +759,7 @@ class RunExperimentTest(unittest.TestCase):
|
||||
|
||||
class VariantGeneratorTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
os.environ["TUNE_RESUME_PROMPT_OFF"] = "True"
|
||||
ray.init()
|
||||
|
||||
def tearDown(self):
|
||||
@@ -981,6 +963,9 @@ def create_mock_components():
|
||||
|
||||
|
||||
class TrialRunnerTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
os.environ["TUNE_RESUME_PROMPT_OFF"] = "True"
|
||||
|
||||
def tearDown(self):
|
||||
ray.shutdown()
|
||||
_register_all() # re-register the evicted objects
|
||||
@@ -1665,6 +1650,116 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
self.assertTrue(searcher.is_finished())
|
||||
self.assertRaises(TuneError, runner.step)
|
||||
|
||||
def testTrialSaveRestore(self):
|
||||
"""Creates different trials to test runner.checkpoint/restore."""
|
||||
ray.init(num_cpus=3)
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
|
||||
runner = TrialRunner(
|
||||
BasicVariantGenerator(), metadata_checkpoint_dir=tmpdir)
|
||||
trials = [
|
||||
Trial(
|
||||
"__fake",
|
||||
trial_id="trial_terminate",
|
||||
stopping_criterion={"training_iteration": 1},
|
||||
checkpoint_freq=1)
|
||||
]
|
||||
runner.add_trial(trials[0])
|
||||
runner.step() # start
|
||||
runner.step()
|
||||
self.assertEquals(trials[0].status, Trial.TERMINATED)
|
||||
|
||||
trials += [
|
||||
Trial(
|
||||
"__fake",
|
||||
trial_id="trial_fail",
|
||||
stopping_criterion={"training_iteration": 3},
|
||||
checkpoint_freq=1,
|
||||
config={"mock_error": True})
|
||||
]
|
||||
runner.add_trial(trials[1])
|
||||
runner.step()
|
||||
runner.step()
|
||||
runner.step()
|
||||
self.assertEquals(trials[1].status, Trial.ERROR)
|
||||
|
||||
trials += [
|
||||
Trial(
|
||||
"__fake",
|
||||
trial_id="trial_succ",
|
||||
stopping_criterion={"training_iteration": 2},
|
||||
checkpoint_freq=1)
|
||||
]
|
||||
runner.add_trial(trials[2])
|
||||
runner.step()
|
||||
self.assertEquals(len(runner.trial_executor.get_checkpoints()), 3)
|
||||
self.assertEquals(trials[2].status, Trial.RUNNING)
|
||||
|
||||
runner2 = TrialRunner.restore(tmpdir)
|
||||
for tid in ["trial_terminate", "trial_fail"]:
|
||||
original_trial = runner.get_trial(tid)
|
||||
restored_trial = runner2.get_trial(tid)
|
||||
self.assertEqual(original_trial.status, restored_trial.status)
|
||||
|
||||
restored_trial = runner2.get_trial("trial_succ")
|
||||
self.assertEqual(Trial.PENDING, restored_trial.status)
|
||||
|
||||
runner2.step()
|
||||
runner2.step()
|
||||
runner2.step()
|
||||
self.assertRaises(TuneError, runner2.step)
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
def testTrialNoSave(self):
|
||||
"""Check that non-checkpointing trials are not saved."""
|
||||
ray.init(num_cpus=3)
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
|
||||
runner = TrialRunner(
|
||||
BasicVariantGenerator(), metadata_checkpoint_dir=tmpdir)
|
||||
|
||||
runner.add_trial(
|
||||
Trial(
|
||||
"__fake",
|
||||
trial_id="non_checkpoint",
|
||||
stopping_criterion={"training_iteration": 2}))
|
||||
|
||||
while not all(t.status == Trial.TERMINATED
|
||||
for t in runner.get_trials()):
|
||||
runner.step()
|
||||
|
||||
runner.add_trial(
|
||||
Trial(
|
||||
"__fake",
|
||||
trial_id="checkpoint",
|
||||
checkpoint_at_end=True,
|
||||
stopping_criterion={"training_iteration": 2}))
|
||||
|
||||
while not all(t.status == Trial.TERMINATED
|
||||
for t in runner.get_trials()):
|
||||
runner.step()
|
||||
|
||||
runner.add_trial(
|
||||
Trial(
|
||||
"__fake",
|
||||
trial_id="pending",
|
||||
stopping_criterion={"training_iteration": 2}))
|
||||
|
||||
runner.step()
|
||||
runner.step()
|
||||
|
||||
runner2 = TrialRunner.restore(tmpdir)
|
||||
new_trials = runner2.get_trials()
|
||||
self.assertEquals(len(new_trials), 3)
|
||||
self.assertTrue(
|
||||
runner2.get_trial("non_checkpoint").status == Trial.TERMINATED)
|
||||
self.assertTrue(
|
||||
runner2.get_trial("checkpoint").status == Trial.TERMINATED)
|
||||
self.assertTrue(runner2.get_trial("pending").status == Trial.PENDING)
|
||||
self.assertTrue(runner2.get_trial("pending").last_result is None)
|
||||
runner2.step()
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
|
||||
class SearchAlgorithmTest(unittest.TestCase):
|
||||
def testNestedSuggestion(self):
|
||||
|
||||
@@ -578,6 +578,7 @@ class _MockTrial(Trial):
|
||||
self.logger_running = False
|
||||
self.restored_checkpoint = None
|
||||
self.resources = Resources(1, 0)
|
||||
self.trial_name = None
|
||||
|
||||
|
||||
class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
|
||||
@@ -10,6 +10,7 @@ import io
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
from six import string_types
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
@@ -216,10 +217,11 @@ class Trainable(object):
|
||||
|
||||
checkpoint_dir = os.path.join(checkpoint_dir or self.logdir,
|
||||
"checkpoint_{}".format(self._iteration))
|
||||
os.makedirs(checkpoint_dir)
|
||||
if not os.path.exists(checkpoint_dir):
|
||||
os.makedirs(checkpoint_dir)
|
||||
checkpoint = self._save(checkpoint_dir)
|
||||
saved_as_dict = False
|
||||
if isinstance(checkpoint, str):
|
||||
if isinstance(checkpoint, string_types):
|
||||
if (not checkpoint.startswith(checkpoint_dir)
|
||||
or checkpoint == checkpoint_dir):
|
||||
raise ValueError(
|
||||
@@ -237,7 +239,9 @@ class Trainable(object):
|
||||
with open(checkpoint_path, "wb") as f:
|
||||
pickle.dump(checkpoint, f)
|
||||
else:
|
||||
raise ValueError("Return value from `_save` must be dict or str.")
|
||||
raise ValueError(
|
||||
"`_save` must return a dict or string type: {}".format(
|
||||
str(type(checkpoint))))
|
||||
pickle.dump({
|
||||
"experiment_id": self._experiment_id,
|
||||
"iteration": self._iteration,
|
||||
|
||||
+109
-13
@@ -3,15 +3,22 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import namedtuple
|
||||
import ray.cloudpickle as cloudpickle
|
||||
import copy
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import json
|
||||
import time
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
# For compatibility under py2 to consider unicode as str
|
||||
from six import string_types
|
||||
from numbers import Number
|
||||
|
||||
import ray
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.log_sync import validate_sync_function
|
||||
from ray.tune.logger import pretty_print, UnifiedLogger
|
||||
# NOTE(rkn): We import ray.tune.registry here instead of importing the names we
|
||||
# need because there are cyclic imports that may cause specific names to not
|
||||
@@ -19,7 +26,7 @@ from ray.tune.logger import pretty_print, UnifiedLogger
|
||||
import ray.tune.registry
|
||||
from ray.tune.result import (DEFAULT_RESULTS_DIR, DONE, HOSTNAME, PID,
|
||||
TIME_TOTAL_S, TRAINING_ITERATION, TIMESTEPS_TOTAL)
|
||||
from ray.utils import random_string, binary_to_hex
|
||||
from ray.utils import random_string, binary_to_hex, hex_to_binary
|
||||
|
||||
DEBUG_PRINT_INTERVAL = 5
|
||||
MAX_LEN_IDENTIFIER = 130
|
||||
@@ -66,6 +73,36 @@ class Resources(
|
||||
return self.gpu + self.extra_gpu
|
||||
|
||||
|
||||
def json_to_resources(data):
    """Convert a JSON resource specification into a ``Resources`` tuple.

    Args:
        data: ``None``, the literal string ``"null"``, a JSON-encoded string,
            or an already-decoded mapping keyed by ``Resources`` field names.

    Returns:
        ``None`` when ``data`` is ``None``, ``"null"`` or decodes to JSON
        null; otherwise a ``Resources`` built from ``cpu`` (default 1),
        ``gpu``, ``extra_cpu`` and ``extra_gpu`` (each default 0).

    Raises:
        TuneError: If a key is one of the retired ``driver_cpu_limit`` /
            ``driver_gpu_limit`` fields, or is not a ``Resources`` field.
    """
    if data is None or data == "null":
        return None
    if isinstance(data, string_types):
        data = json.loads(data)
        # A string such as " null " passes the literal check above but still
        # decodes to None; treat it the same way instead of failing below.
        if data is None:
            return None
    for k in data:
        if k in ["driver_cpu_limit", "driver_gpu_limit"]:
            raise TuneError(
                "The field `{}` is no longer supported. Use `extra_cpu` "
                "or `extra_gpu` instead.".format(k))
        if k not in Resources._fields:
            raise TuneError(
                "Unknown resource type {}, must be one of {}".format(
                    k, Resources._fields))
    return Resources(
        data.get("cpu", 1), data.get("gpu", 0), data.get("extra_cpu", 0),
        data.get("extra_gpu", 0))
|
||||
|
||||
|
||||
def resources_to_json(resources):
    """Serialize a ``Resources``-like object into a JSON-friendly dict.

    Args:
        resources: ``None`` or an object exposing ``cpu``, ``gpu``,
            ``extra_cpu`` and ``extra_gpu`` attributes.

    Returns:
        ``None`` when ``resources`` is ``None``; otherwise a dict with the
        keys ``"cpu"``, ``"gpu"``, ``"extra_cpu"`` and ``"extra_gpu"``.
    """
    if resources is None:
        return None
    return {
        "cpu": resources.cpu,
        "gpu": resources.gpu,
        "extra_cpu": resources.extra_cpu,
        "extra_gpu": resources.extra_gpu,
    }
|
||||
|
||||
|
||||
def has_trainable(trainable_name):
|
||||
return ray.tune.registry._global_registry.contains(
|
||||
ray.tune.registry.TRAINABLE_CLASS, trainable_name)
|
||||
@@ -133,12 +170,8 @@ class Trial(object):
|
||||
The args here take the same meaning as the command line flags defined
|
||||
in ray.tune.config_parser.
|
||||
"""
|
||||
if not has_trainable(trainable_name):
|
||||
# Make sure rllib agents are registered
|
||||
from ray import rllib # noqa: F401
|
||||
if not has_trainable(trainable_name):
|
||||
raise TuneError("Unknown trainable: " + trainable_name)
|
||||
|
||||
Trial._registration_check(trainable_name)
|
||||
# Trial config
|
||||
self.trainable_name = trainable_name
|
||||
self.config = config or {}
|
||||
@@ -149,14 +182,15 @@ class Trial(object):
|
||||
or self._get_trainable_cls().default_resource_request(self.config))
|
||||
self.stopping_criterion = stopping_criterion or {}
|
||||
self.upload_dir = upload_dir
|
||||
self.trial_name_creator = trial_name_creator
|
||||
self.custom_loggers = custom_loggers
|
||||
self.sync_function = sync_function
|
||||
validate_sync_function(sync_function)
|
||||
self.verbose = True
|
||||
self.max_failures = max_failures
|
||||
|
||||
# Local trial state that is updated during the run
|
||||
self.last_result = None
|
||||
self.last_update_time = -float("inf")
|
||||
self.checkpoint_freq = checkpoint_freq
|
||||
self.checkpoint_at_end = checkpoint_at_end
|
||||
self._checkpoint = Checkpoint(
|
||||
@@ -170,6 +204,18 @@ class Trial(object):
|
||||
self.error_file = None
|
||||
self.num_failures = 0
|
||||
|
||||
self.trial_name = None
|
||||
if trial_name_creator:
|
||||
self.trial_name = trial_name_creator(self)
|
||||
|
||||
@classmethod
|
||||
def _registration_check(cls, trainable_name):
|
||||
if not has_trainable(trainable_name):
|
||||
# Make sure rllib agents are registered
|
||||
from ray import rllib # noqa: F401
|
||||
if not has_trainable(trainable_name):
|
||||
raise TuneError("Unknown trainable: " + trainable_name)
|
||||
|
||||
@classmethod
|
||||
def generate_id(cls):
|
||||
return binary_to_hex(random_string())[:8]
|
||||
@@ -180,10 +226,14 @@ class Trial(object):
|
||||
if not self.result_logger:
|
||||
if not os.path.exists(self.local_dir):
|
||||
os.makedirs(self.local_dir)
|
||||
self.logdir = tempfile.mkdtemp(
|
||||
prefix="{}_{}".format(
|
||||
str(self)[:MAX_LEN_IDENTIFIER], date_str()),
|
||||
dir=self.local_dir)
|
||||
if not self.logdir:
|
||||
self.logdir = tempfile.mkdtemp(
|
||||
prefix="{}_{}".format(
|
||||
str(self)[:MAX_LEN_IDENTIFIER], date_str()),
|
||||
dir=self.local_dir)
|
||||
elif not os.path.exists(self.logdir):
|
||||
os.makedirs(self.logdir)
|
||||
|
||||
self.result_logger = UnifiedLogger(
|
||||
self.config,
|
||||
self.logdir,
|
||||
@@ -307,6 +357,7 @@ class Trial(object):
|
||||
pretty_print(result).replace("\n", "\n ")))
|
||||
self.last_debug = time.time()
|
||||
self.last_result = result
|
||||
self.last_update_time = time.time()
|
||||
self.result_logger.on_result(self.last_result)
|
||||
|
||||
def _get_trainable_cls(self):
|
||||
@@ -327,8 +378,8 @@ class Trial(object):
|
||||
|
||||
Can be overridden with a custom string creator.
|
||||
"""
|
||||
if self.trial_name_creator:
|
||||
return self.trial_name_creator(self)
|
||||
if self.trial_name:
|
||||
return self.trial_name
|
||||
|
||||
if "env" in self.config:
|
||||
env = self.config["env"]
|
||||
@@ -340,3 +391,48 @@ class Trial(object):
|
||||
if self.experiment_tag:
|
||||
identifier += "_" + self.experiment_tag
|
||||
return identifier.replace("/", "_")
|
||||
|
||||
def __getstate__(self):
|
||||
"""Memento generator for Trial.
|
||||
|
||||
Sets RUNNING trials to PENDING, and flushes the result logger.
|
||||
Note this can only occur if the trial holds a DISK checkpoint.
|
||||
"""
|
||||
assert self._checkpoint.storage == Checkpoint.DISK, (
|
||||
"Checkpoint must not be in-memory.")
|
||||
state = self.__dict__.copy()
|
||||
state["resources"] = resources_to_json(self.resources)
|
||||
|
||||
pickle_data = {
|
||||
"_checkpoint": self._checkpoint,
|
||||
"config": self.config,
|
||||
"custom_loggers": self.custom_loggers,
|
||||
"sync_function": self.sync_function
|
||||
}
|
||||
|
||||
for key, value in pickle_data.items():
|
||||
state[key] = binary_to_hex(cloudpickle.dumps(value))
|
||||
|
||||
state["runner"] = None
|
||||
state["result_logger"] = None
|
||||
if self.status == Trial.RUNNING:
|
||||
state["status"] = Trial.PENDING
|
||||
if self.result_logger:
|
||||
self.result_logger.flush()
|
||||
state["__logger_started__"] = True
|
||||
else:
|
||||
state["__logger_started__"] = False
|
||||
return copy.deepcopy(state)
|
||||
|
||||
def __setstate__(self, state):
|
||||
logger_started = state.pop("__logger_started__")
|
||||
state["resources"] = json_to_resources(state["resources"])
|
||||
for key in [
|
||||
"_checkpoint", "config", "custom_loggers", "sync_function"
|
||||
]:
|
||||
state[key] = cloudpickle.loads(hex_to_binary(state[key]))
|
||||
|
||||
self.__dict__.update(state)
|
||||
Trial._registration_check(self.trainable_name)
|
||||
if logger_started:
|
||||
self.init_logger()
|
||||
|
||||
@@ -25,6 +25,41 @@ class TrialExecutor(object):
|
||||
automatic scale-up.
|
||||
"""
|
||||
self._queue_trials = queue_trials
|
||||
self._cached_trial_state = {}
|
||||
|
||||
def set_status(self, trial, status):
|
||||
"""Sets status and checkpoints metadata if needed.
|
||||
|
||||
Only checkpoints metadata if trial status is a terminal condition.
|
||||
PENDING, PAUSED, and RUNNING switches have checkpoints taken care of
|
||||
in the TrialRunner.
|
||||
|
||||
Args:
|
||||
trial (Trial): Trial to checkpoint.
|
||||
status (Trial.status): Status to set trial to.
|
||||
"""
|
||||
trial.status = status
|
||||
if status in [Trial.TERMINATED, Trial.ERROR]:
|
||||
self.try_checkpoint_metadata(trial)
|
||||
|
||||
def try_checkpoint_metadata(self, trial):
|
||||
"""Checkpoints metadata.
|
||||
|
||||
Args:
|
||||
trial (Trial): Trial to checkpoint.
|
||||
"""
|
||||
if trial._checkpoint.storage == Checkpoint.MEMORY:
|
||||
logger.debug("Not saving data for trial w/ memory checkpoint.")
|
||||
return
|
||||
try:
|
||||
logger.debug("Saving trial metadata.")
|
||||
self._cached_trial_state[trial.trial_id] = trial.__getstate__()
|
||||
except Exception:
|
||||
logger.exception("Error checkpointing trial metadata.")
|
||||
|
||||
def get_checkpoints(self):
|
||||
"""Returns a copy of mapping of the trial ID to pickled metadata."""
|
||||
return self._cached_trial_state.copy()
|
||||
|
||||
def has_resources(self, resources):
|
||||
"""Returns whether this runner has at least the specified resources."""
|
||||
@@ -71,15 +106,15 @@ class TrialExecutor(object):
|
||||
try:
|
||||
self.save(trial, Checkpoint.MEMORY)
|
||||
self.stop_trial(trial, stop_logger=False)
|
||||
trial.status = Trial.PAUSED
|
||||
self.set_status(trial, Trial.PAUSED)
|
||||
except Exception:
|
||||
logger.exception("Error pausing runner.")
|
||||
trial.status = Trial.ERROR
|
||||
self.set_status(trial, Trial.ERROR)
|
||||
|
||||
def unpause_trial(self, trial):
|
||||
"""Sets PAUSED trial to pending to allow scheduler to start."""
|
||||
assert trial.status == Trial.PAUSED, trial.status
|
||||
trial.status = Trial.PENDING
|
||||
self.set_status(trial, Trial.PENDING)
|
||||
|
||||
def resume_trial(self, trial):
|
||||
"""Resumes PAUSED trials. This is a blocking call."""
|
||||
|
||||
+125
-11
@@ -3,6 +3,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
@@ -49,10 +50,13 @@ class TrialRunner(object):
|
||||
misleading benchmark results.
|
||||
"""
|
||||
|
||||
CKPT_FILE_NAME = "experiment_state.json"
|
||||
|
||||
def __init__(self,
|
||||
search_alg,
|
||||
scheduler=None,
|
||||
launch_web_server=False,
|
||||
metadata_checkpoint_dir=None,
|
||||
server_port=TuneServer.DEFAULT_PORT,
|
||||
verbose=True,
|
||||
queue_trials=False,
|
||||
@@ -64,6 +68,8 @@ class TrialRunner(object):
|
||||
Trial objects.
|
||||
scheduler (TrialScheduler): Defaults to FIFOScheduler.
|
||||
launch_web_server (bool): Flag for starting TuneServer
|
||||
metadata_checkpoint_dir (str): Path where
|
||||
global checkpoints are stored and restored from.
|
||||
server_port (int): Port number for launching TuneServer
|
||||
verbose (bool): Flag for verbosity. If False, trial results
|
||||
will not be output.
|
||||
@@ -75,7 +81,6 @@ class TrialRunner(object):
|
||||
"""
|
||||
self._search_alg = search_alg
|
||||
self._scheduler_alg = scheduler or FIFOScheduler()
|
||||
self._trials = []
|
||||
self.trial_executor = trial_executor or \
|
||||
RayTrialExecutor(queue_trials=queue_trials)
|
||||
|
||||
@@ -84,13 +89,93 @@ class TrialRunner(object):
|
||||
self._global_time_limit = float(
|
||||
os.environ.get("TRIALRUNNER_WALLTIME_LIMIT", float('inf')))
|
||||
self._total_time = 0
|
||||
self._server = None
|
||||
if launch_web_server:
|
||||
self._server = TuneServer(self, server_port)
|
||||
self._stop_queue = []
|
||||
self._iteration = 0
|
||||
self._verbose = verbose
|
||||
self._queue_trials = queue_trials
|
||||
|
||||
self._server = None
|
||||
self._server_port = server_port
|
||||
if launch_web_server:
|
||||
self._server = TuneServer(self, self._server_port)
|
||||
|
||||
self._trials = []
|
||||
self._stop_queue = []
|
||||
self._metadata_checkpoint_dir = metadata_checkpoint_dir
|
||||
|
||||
def checkpoint(self):
|
||||
"""Saves execution state to `self._metadata_checkpoint_dir`."""
|
||||
if not self._metadata_checkpoint_dir:
|
||||
return
|
||||
metadata_checkpoint_dir = self._metadata_checkpoint_dir
|
||||
if not os.path.exists(metadata_checkpoint_dir):
|
||||
os.makedirs(metadata_checkpoint_dir)
|
||||
runner_state = {
|
||||
"checkpoints": list(
|
||||
self.trial_executor.get_checkpoints().values()),
|
||||
"runner_data": self.__getstate__()
|
||||
}
|
||||
tmp_file_name = os.path.join(metadata_checkpoint_dir,
|
||||
".tmp_checkpoint")
|
||||
with open(tmp_file_name, "w") as f:
|
||||
json.dump(runner_state, f, indent=2)
|
||||
|
||||
os.rename(
|
||||
tmp_file_name,
|
||||
os.path.join(metadata_checkpoint_dir, TrialRunner.CKPT_FILE_NAME))
|
||||
return metadata_checkpoint_dir
|
||||
|
||||
@classmethod
|
||||
def restore(cls,
|
||||
metadata_checkpoint_dir,
|
||||
search_alg=None,
|
||||
scheduler=None,
|
||||
trial_executor=None):
|
||||
"""Restores all checkpointed trials from previous run.
|
||||
|
||||
Requires user to manually re-register their objects. Also stops
|
||||
all ongoing trials.
|
||||
|
||||
Args:
|
||||
metadata_checkpoint_dir (str): Path to metadata checkpoints.
|
||||
search_alg (SearchAlgorithm): Search Algorithm. Defaults to
|
||||
BasicVariantGenerator.
|
||||
scheduler (TrialScheduler): Scheduler for executing
|
||||
the experiment.
|
||||
trial_executor (TrialExecutor): Manage the execution of trials.
|
||||
|
||||
Returns:
|
||||
runner (TrialRunner): A TrialRunner to resume experiments from.
|
||||
"""
|
||||
with open(
|
||||
os.path.join(metadata_checkpoint_dir,
|
||||
TrialRunner.CKPT_FILE_NAME), "r") as f:
|
||||
runner_state = json.load(f)
|
||||
|
||||
logger.warning("".join([
|
||||
"Attempting to resume experiment from {}. ".format(
|
||||
metadata_checkpoint_dir), "This feature is experimental, "
|
||||
"and may not work with all search algorithms. ",
|
||||
"This will ignore any new changes to the specification."
|
||||
]))
|
||||
|
||||
from ray.tune.suggest import BasicVariantGenerator
|
||||
runner = TrialRunner(
|
||||
search_alg or BasicVariantGenerator(),
|
||||
scheduler=scheduler,
|
||||
trial_executor=trial_executor)
|
||||
|
||||
runner.__setstate__(runner_state["runner_data"])
|
||||
|
||||
trials = []
|
||||
for trial_cp in runner_state["checkpoints"]:
|
||||
new_trial = Trial(trial_cp["trainable_name"])
|
||||
new_trial.__setstate__(trial_cp)
|
||||
trials += [new_trial]
|
||||
for trial in sorted(
|
||||
trials, key=lambda t: t.last_update_time, reverse=True):
|
||||
runner.add_trial(trial)
|
||||
return runner
|
||||
|
||||
def is_finished(self):
|
||||
"""Returns whether all trials have finished running."""
|
||||
|
||||
@@ -136,6 +221,12 @@ class TrialRunner(object):
|
||||
"There are paused trials, but no more pending "
|
||||
"trials with sufficient resources.")
|
||||
|
||||
try:
|
||||
self.checkpoint()
|
||||
except Exception:
|
||||
logger.exception("Trial Runner checkpointing failed.")
|
||||
self._iteration += 1
|
||||
|
||||
if self._server:
|
||||
self._process_requests()
|
||||
|
||||
@@ -165,6 +256,7 @@ class TrialRunner(object):
|
||||
"""
|
||||
trial.set_verbose(self._verbose)
|
||||
self._scheduler_alg.on_trial_add(self, trial)
|
||||
self.trial_executor.try_checkpoint_metadata(trial)
|
||||
self._trials.append(trial)
|
||||
|
||||
def debug_string(self, max_debug=MAX_DEBUG_TRIALS):
|
||||
@@ -279,14 +371,14 @@ class TrialRunner(object):
|
||||
result, terminate=(decision == TrialScheduler.STOP))
|
||||
|
||||
if decision == TrialScheduler.CONTINUE:
|
||||
self._checkpoint_if_needed(trial)
|
||||
self._checkpoint_trial_if_needed(trial)
|
||||
self.trial_executor.continue_training(trial)
|
||||
elif decision == TrialScheduler.PAUSE:
|
||||
self.trial_executor.pause_trial(trial)
|
||||
elif decision == TrialScheduler.STOP:
|
||||
# Checkpoint before ending the trial
|
||||
# if checkpoint_at_end experiment option is set to True
|
||||
self._checkpoint_if_needed(trial)
|
||||
self._checkpoint_trial_if_needed(trial)
|
||||
self.trial_executor.stop_trial(trial)
|
||||
else:
|
||||
assert False, "Invalid scheduling decision: {}".format(
|
||||
@@ -304,12 +396,13 @@ class TrialRunner(object):
|
||||
self.trial_executor.stop_trial(
|
||||
trial, error=True, error_msg=error_msg)
|
||||
|
||||
def _checkpoint_if_needed(self, trial):
|
||||
def _checkpoint_trial_if_needed(self, trial):
|
||||
"""Checkpoints trial based off trial.last_result."""
|
||||
if trial.should_checkpoint():
|
||||
# Save trial runtime if possible
|
||||
if hasattr(trial, "runner") and trial.runner:
|
||||
self.trial_executor.save(trial, storage=Checkpoint.DISK)
|
||||
self.trial_executor.try_checkpoint_metadata(trial)
|
||||
|
||||
def _try_recover(self, trial, error_msg):
|
||||
"""Tries to recover trial.
|
||||
@@ -344,11 +437,11 @@ class TrialRunner(object):
|
||||
def _requeue_trial(self, trial):
|
||||
"""Notification to TrialScheduler and requeue trial.
|
||||
|
||||
This does not notify the SearchAlgorithm because
|
||||
the function evaluation is still in progress.
|
||||
This does not notify the SearchAlgorithm because the function
|
||||
evaluation is still in progress.
|
||||
"""
|
||||
self._scheduler_alg.on_trial_error(self, trial)
|
||||
trial.status = Trial.PENDING
|
||||
self.trial_executor.set_status(trial, Trial.PENDING)
|
||||
self._scheduler_alg.on_trial_add(self, trial)
|
||||
|
||||
def _update_trial_queue(self, blocking=False, timeout=600):
|
||||
@@ -417,3 +510,24 @@ class TrialRunner(object):
|
||||
error = True
|
||||
|
||||
self.trial_executor.stop_trial(trial, error=error, error_msg=error_msg)
|
||||
|
||||
def __getstate__(self):
|
||||
"""Gets state for trial.
|
||||
|
||||
Note that this is not used as a pickling override as
|
||||
it does not have all fields.
|
||||
"""
|
||||
state = self.__dict__.copy()
|
||||
for k in [
|
||||
"_trials", "_stop_queue", "_server", "_search_alg",
|
||||
"_scheduler_alg", "trial_executor"
|
||||
]:
|
||||
del state[k]
|
||||
state["launch_web_server"] = bool(self._server)
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
launch_web_server = state.pop("launch_web_server")
|
||||
self.__dict__.update(state)
|
||||
if launch_web_server:
|
||||
self._server = TuneServer(self, self._server_port)
|
||||
|
||||
+77
-15
@@ -2,10 +2,13 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import click
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.experiment import convert_to_experiment_list
|
||||
from ray.tune.suggest import BasicVariantGenerator
|
||||
from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL
|
||||
from ray.tune.log_sync import wait_for_log_sync
|
||||
@@ -32,12 +35,30 @@ def _make_scheduler(args):
|
||||
args.scheduler, _SCHEDULERS.keys()))
|
||||
|
||||
|
||||
def run_experiments(experiments=None,
|
||||
def _find_checkpoint_dir(exp_list):
|
||||
assert exp_list, "Experiments must be specified via `run_experiments`"
|
||||
exp = exp_list[0]
|
||||
# TODO(rliaw): Make sure this is resolved earlier.
|
||||
return os.path.join(exp.spec["local_dir"], exp.name)
|
||||
|
||||
|
||||
def try_restore_runner(checkpoint_dir, search_alg, scheduler, trial_executor):
|
||||
new_runner = None
|
||||
try:
|
||||
new_runner = TrialRunner.restore(checkpoint_dir, search_alg, scheduler,
|
||||
trial_executor)
|
||||
except Exception:
|
||||
logger.exception("Runner restore failed. Restarting experiment.")
|
||||
return new_runner
|
||||
|
||||
|
||||
def run_experiments(experiments,
|
||||
search_alg=None,
|
||||
scheduler=None,
|
||||
with_server=False,
|
||||
server_port=TuneServer.DEFAULT_PORT,
|
||||
verbose=True,
|
||||
resume=None,
|
||||
queue_trials=False,
|
||||
trial_executor=None,
|
||||
raise_on_failed_trial=True):
|
||||
@@ -55,6 +76,9 @@ def run_experiments(experiments=None,
|
||||
using the Client API.
|
||||
server_port (int): Port number for launching TuneServer.
|
||||
verbose (bool): How much output should be printed for each trial.
|
||||
resume (bool|None): If a checkpoint exists, the experiment will
|
||||
resume from there. If resume is None, Tune will prompt if
|
||||
a checkpoint is detected.
|
||||
queue_trials (bool): Whether to queue trials when the cluster does
|
||||
not currently have enough resources to launch one. This should
|
||||
be set to True when running on an autoscaling cluster to enable
|
||||
@@ -83,26 +107,64 @@ def run_experiments(experiments=None,
|
||||
List of Trial objects, holding data for each executed trial.
|
||||
|
||||
"""
|
||||
# It is important to do this here
|
||||
# because it schematizes the experiments
|
||||
# and it conducts the implicit registration.
|
||||
experiments = convert_to_experiment_list(experiments)
|
||||
checkpoint_dir = _find_checkpoint_dir(experiments)
|
||||
|
||||
if scheduler is None:
|
||||
scheduler = FIFOScheduler()
|
||||
runner = None
|
||||
restore = False
|
||||
|
||||
if search_alg is None:
|
||||
search_alg = BasicVariantGenerator()
|
||||
# TUNE_RESUME_PROMPT_OFF is for testing purposes and defaults
|
||||
# `resume=False`.
|
||||
if os.environ.get("TUNE_RESUME_PROMPT_OFF"):
|
||||
resume = resume or False
|
||||
|
||||
search_alg.add_configurations(experiments)
|
||||
if os.path.exists(
|
||||
os.path.join(checkpoint_dir, TrialRunner.CKPT_FILE_NAME)):
|
||||
if resume:
|
||||
restore = True
|
||||
elif resume is None:
|
||||
msg = ("Found incomplete experiment at {}. "
|
||||
"Would you like to resume it?".format(checkpoint_dir))
|
||||
restore = click.confirm(msg, default=True)
|
||||
if restore:
|
||||
logger.info("Tip: to always resume, "
|
||||
"pass resume=True to run_experiments()")
|
||||
else:
|
||||
logger.info("Tip: to always start a new experiment, "
|
||||
"pass resume=False to run_experiments()")
|
||||
else:
|
||||
logger.info(
|
||||
"Did not find checkpoint file in {}.".format(checkpoint_dir))
|
||||
|
||||
runner = TrialRunner(
|
||||
search_alg,
|
||||
scheduler=scheduler,
|
||||
launch_web_server=with_server,
|
||||
server_port=server_port,
|
||||
verbose=verbose,
|
||||
queue_trials=queue_trials,
|
||||
trial_executor=trial_executor)
|
||||
if restore:
|
||||
runner = try_restore_runner(checkpoint_dir, search_alg, scheduler,
|
||||
trial_executor)
|
||||
else:
|
||||
logger.info("Starting a new experiment.")
|
||||
|
||||
if not runner:
|
||||
if scheduler is None:
|
||||
scheduler = FIFOScheduler()
|
||||
|
||||
if search_alg is None:
|
||||
search_alg = BasicVariantGenerator()
|
||||
|
||||
search_alg.add_configurations(experiments)
|
||||
|
||||
runner = TrialRunner(
|
||||
search_alg,
|
||||
scheduler=scheduler,
|
||||
metadata_checkpoint_dir=checkpoint_dir,
|
||||
launch_web_server=with_server,
|
||||
server_port=server_port,
|
||||
verbose=verbose,
|
||||
queue_trials=queue_trials,
|
||||
trial_executor=trial_executor)
|
||||
|
||||
logger.info(runner.debug_string(max_debug=99999))
|
||||
|
||||
last_debug = 0
|
||||
while not runner.is_finished():
|
||||
runner.step()
|
||||
|
||||
Reference in New Issue
Block a user