[tune] Cluster Fault Tolerance (#3309)

This PR introduces cluster-level fault tolerance for Tune by checkpointing global experiment state. Checkpointing happens at a relatively high frequency, which allows users to easily resume experiments after the cluster crashes.

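Note that this PR may affect automated workflows because resuming can trigger an interactive prompt, but this is resolvable (the tests in this change set the `TUNE_RESUME_PROMPT_OFF` environment variable, presumably to suppress the prompt).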
This commit is contained in:
Richard Liaw
2018-12-29 11:42:25 +08:00
committed by GitHub
parent 382b138fc7
commit aad3c50e2d
16 changed files with 806 additions and 128 deletions
+8 -2
View File
@@ -9,7 +9,8 @@ import yaml
import ray
from ray.test.cluster_utils import Cluster
from ray.tune.config_parser import make_parser, resources_to_json
from ray.tune.config_parser import make_parser
from ray.tune.trial import resources_to_json
from ray.tune.tune import _make_scheduler, run_experiments
EXAMPLE_USAGE = """
@@ -70,6 +71,10 @@ def create_parser(parser_creator=None):
default="default",
type=str,
help="Name of the subdirectory under `local_dir` to put results in.")
parser.add_argument(
"--resume",
action="store_true",
help="Whether to attempt to resume previous Tune experiments.")
parser.add_argument(
"--env", default=None, type=str, help="The gym environment to use.")
parser.add_argument(
@@ -138,7 +143,8 @@ def run(args, parser):
run_experiments(
experiments,
scheduler=_make_scheduler(args),
queue_trials=args.queue_trials)
queue_trials=args.queue_trials,
resume=args.resume)
if __name__ == "__main__":
+3 -1
View File
@@ -51,7 +51,9 @@ class Cluster(object):
assert not self.connected
redis_password = head_node_args.get("redis_password")
output_info = ray.init(
redis_address=self.redis_address, redis_password=redis_password)
ignore_reinit_error=True,
redis_address=self.redis_address,
redis_password=redis_password)
logger.info(output_info)
self.connected = True
+1 -31
View File
@@ -11,40 +11,10 @@ from six import string_types
from ray.tune import TuneError
from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.trial import Resources, Trial
from ray.tune.trial import Trial, json_to_resources
from ray.tune.logger import _SafeFallbackEncoder
def json_to_resources(data):
if data is None or data == "null":
return None
if isinstance(data, string_types):
data = json.loads(data)
for k in data:
if k in ["driver_cpu_limit", "driver_gpu_limit"]:
raise TuneError(
"The field `{}` is no longer supported. Use `extra_cpu` "
"or `extra_gpu` instead.".format(k))
if k not in Resources._fields:
raise TuneError(
"Unknown resource type {}, must be one of {}".format(
k, Resources._fields))
return Resources(
data.get("cpu", 1), data.get("gpu", 0), data.get("extra_cpu", 0),
data.get("extra_gpu", 0))
def resources_to_json(resources):
if resources is None:
return None
return {
"cpu": resources.cpu,
"gpu": resources.gpu,
"extra_cpu": resources.extra_cpu,
"extra_gpu": resources.extra_gpu,
}
def make_parser(parser_creator=None, **kwargs):
"""Returns a base argument parser for the ray.tune tool.
+6 -6
View File
@@ -4,11 +4,11 @@ from __future__ import print_function
import copy
import logging
import os
import six
import types
from ray.tune.error import TuneError
from ray.tune.log_sync import validate_sync_function
from ray.tune.registry import register_trainable
from ray.tune.result import DEFAULT_RESULTS_DIR
@@ -122,7 +122,6 @@ class Experiment(object):
restore=None,
repeat=None,
trial_resources=None):
validate_sync_function(sync_function)
if sync_function:
assert upload_dir, "Need `upload_dir` if sync_function given."
@@ -134,16 +133,16 @@ class Experiment(object):
resources_per_trial = trial_resources
spec = {
"run": self._register_if_needed(run),
"run": Experiment._register_if_needed(run),
"stop": stop or {},
"config": config or {},
"resources_per_trial": resources_per_trial,
"num_samples": num_samples,
"local_dir": local_dir or DEFAULT_RESULTS_DIR,
"local_dir": os.path.expanduser(local_dir or DEFAULT_RESULTS_DIR),
"upload_dir": upload_dir or "", # argparse converts None to "null"
"trial_name_creator": trial_name_creator,
"custom_loggers": custom_loggers,
"sync_function": sync_function or "", # See `upload_dir`.
"sync_function": sync_function,
"checkpoint_freq": checkpoint_freq,
"checkpoint_at_end": checkpoint_at_end,
"max_failures": max_failures,
@@ -180,7 +179,8 @@ class Experiment(object):
raise TuneError("Improper argument from JSON: {}.".format(spec))
return exp
def _register_if_needed(self, run_object):
@classmethod
def _register_if_needed(cls, run_object):
"""Registers Trainable or Function at runtime.
Assumes already registered if run_object is a string. Does not
+20 -10
View File
@@ -106,19 +106,19 @@ class UnifiedLogger(Logger):
self.logdir, self.uri, sync_function=self._sync_function)
def on_result(self, result):
for logger in self._loggers:
logger.on_result(result)
for _logger in self._loggers:
_logger.on_result(result)
self._log_syncer.set_worker_ip(result.get(NODE_IP))
self._log_syncer.sync_if_needed()
def close(self):
for logger in self._loggers:
logger.close()
for _logger in self._loggers:
_logger.close()
self._log_syncer.sync_now(force=True)
def flush(self):
for logger in self._loggers:
logger.flush()
for _logger in self._loggers:
_logger.flush()
self._log_syncer.sync_now(force=True)
self._log_syncer.wait()
@@ -142,7 +142,7 @@ class _JsonLogger(Logger):
with open(config_pkl, "wb") as f:
cloudpickle.dump(self.config, f)
local_file = os.path.join(self.logdir, "result.json")
self.local_out = open(local_file, "w")
self.local_out = open(local_file, "a")
def on_result(self, result):
json.dump(result, self, cls=_SafeFallbackEncoder)
@@ -152,6 +152,9 @@ class _JsonLogger(Logger):
self.local_out.write(b)
self.local_out.flush()
def flush(self):
self.local_out.flush()
def close(self):
self.local_out.close()
@@ -182,7 +185,8 @@ class _TFLogger(Logger):
for k in [
"config", "pid", "timestamp", TIME_TOTAL_S, TRAINING_ITERATION
]:
del tmp[k] # not useful to tf log these
if k in tmp:
del tmp[k] # not useful to tf log these
values = to_tf_values(tmp, ["ray", "tune"])
train_stats = tf.Summary(value=values)
t = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]
@@ -205,15 +209,21 @@ class _VisKitLogger(Logger):
def _init(self):
"""CSV outputted with Headers as first set of results."""
# Note that we assume params.json was already created by JsonLogger
self._file = open(os.path.join(self.logdir, "progress.csv"), "w")
progress_file = os.path.join(self.logdir, "progress.csv")
self._continuing = os.path.exists(progress_file)
self._file = open(progress_file, "a")
self._csv_out = None
def on_result(self, result):
if self._csv_out is None:
self._csv_out = csv.DictWriter(self._file, result.keys())
self._csv_out.writeheader()
if not self._continuing:
self._csv_out.writeheader()
self._csv_out.writerow(result.copy())
def flush(self):
self._file.flush()
def close(self):
self._file.close()
+11 -9
View File
@@ -38,6 +38,8 @@ class RayTrialExecutor(TrialExecutor):
num_gpus=trial.resources.gpu)(trial._get_trainable_cls())
trial.init_logger()
# We checkpoint metadata here to try mitigating logdir duplication
self.try_checkpoint_metadata(trial)
remote_logdir = trial.logdir
def logger_creator(config):
@@ -60,7 +62,7 @@ class RayTrialExecutor(TrialExecutor):
def _start_trial(self, trial, checkpoint=None):
prior_status = trial.status
trial.status = Trial.RUNNING
self.set_status(trial, Trial.RUNNING)
trial.runner = self._setup_runner(trial)
if not self.restore(trial, checkpoint):
return
@@ -87,10 +89,13 @@ class RayTrialExecutor(TrialExecutor):
stop_logger (bool): Whether to shut down the trial logger.
"""
if stop_logger:
trial.close_logger()
if error:
trial.status = Trial.ERROR
self.set_status(trial, Trial.ERROR)
else:
trial.status = Trial.TERMINATED
self.set_status(trial, Trial.TERMINATED)
try:
trial.write_error_log(error_msg)
@@ -103,13 +108,10 @@ class RayTrialExecutor(TrialExecutor):
stop_tasks, num_returns=2, timeout=250)
except Exception:
logger.exception("Error stopping runner.")
trial.status = Trial.ERROR
self.set_status(trial, Trial.ERROR)
finally:
trial.runner = None
if stop_logger:
trial.close_logger()
def start_trial(self, trial, checkpoint=None):
"""Starts the trial.
@@ -302,7 +304,7 @@ class RayTrialExecutor(TrialExecutor):
return True
if trial.runner is None:
logger.error("Unable to restore - no runner.")
trial.status = Trial.ERROR
self.set_status(trial, Trial.ERROR)
return False
try:
value = checkpoint.value
@@ -316,5 +318,5 @@ class RayTrialExecutor(TrialExecutor):
return True
except Exception:
logger.exception("Error restoring runner.")
trial.status = Trial.ERROR
self.set_status(trial, Trial.ERROR)
return False
+256 -1
View File
@@ -2,7 +2,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import inspect
import json
import time
import os
import pytest
try:
import pytest_timeout
@@ -10,14 +13,37 @@ except ImportError:
pytest_timeout = None
import ray
from ray import tune
from ray.rllib import _register_all
from ray.test.cluster_utils import Cluster
from ray.test.test_utils import run_string_as_driver_nonblocking
from ray.tune.error import TuneError
from ray.tune.experiment import Experiment
from ray.tune.trial import Trial
from ray.tune.trial_runner import TrialRunner
from ray.tune.suggest import BasicVariantGenerator
class _Fail(tune.Trainable):
    """Trainable whose fourth call to ``_train`` raises an AssertionError.

    The only state is a call counter stored under the ``"hi"`` key; that
    same dict is what ``_save`` returns and ``_restore`` installs.
    """

    def _setup(self, config):
        # `config` is not used; the counter always starts at zero.
        self.state = dict(hi=0)

    def _train(self):
        count = self.state["hi"] + 1
        self.state["hi"] = count
        time.sleep(0.5)
        # Fail deliberately once the fourth iteration is reached.
        assert count < 4
        return {}

    def _save(self, path):
        # `path` is not used; the state dict itself is the checkpoint.
        return self.state

    def _restore(self, state):
        self.state = state
def _start_new_cluster():
cluster = Cluster(
initialize_head=True,
@@ -36,6 +62,7 @@ def _start_new_cluster():
@pytest.fixture
def start_connected_cluster():
# Start the Ray processes.
os.environ["TUNE_RESUME_PROMPT_OFF"] = "True"
cluster = _start_new_cluster()
yield cluster
# The code after the yield will run as teardown code.
@@ -47,6 +74,7 @@ def start_connected_cluster():
def start_connected_emptyhead_cluster():
"""Starts head with no resources."""
os.environ["TUNE_RESUME_PROMPT_OFF"] = "True"
cluster = Cluster(
initialize_head=True,
connect=True,
@@ -66,7 +94,6 @@ def start_connected_emptyhead_cluster():
def test_counting_resources(start_connected_cluster):
"""Tests that Tune accounting is consistent with actual cluster."""
cluster = start_connected_cluster
nodes = []
assert ray.global_state.cluster_resources()["CPU"] == 1
@@ -240,3 +267,231 @@ def test_trial_requeue(start_connected_emptyhead_cluster):
with pytest.raises(TuneError):
runner.step()
def test_cluster_down_simple(start_connected_cluster, tmpdir):
    """Tests that TrialRunner save/restore works on cluster shutdown."""
    cluster = start_connected_cluster
    cluster.add_node(resources=dict(CPU=1))
    assert cluster.wait_for_nodes()

    checkpoint_dir = str(tmpdir)
    runner = TrialRunner(
        BasicVariantGenerator(), metadata_checkpoint_dir=checkpoint_dir)
    trial_kwargs = dict(
        stopping_criterion=dict(training_iteration=2),
        checkpoint_freq=1,
        max_failures=1)
    # Both trials are constructed before either one is added.
    pending = (Trial("__fake", **trial_kwargs),
               Trial("__fake", **trial_kwargs))
    for trial in pending:
        runner.add_trial(trial)

    # Two steps start the two trials, the third one is a regular step.
    for _ in range(3):
        runner.step()
    assert all(t.status == Trial.RUNNING for t in runner.get_trials())
    runner.checkpoint()

    # Take everything down, then bring up a brand new cluster.
    cluster.shutdown()
    ray.shutdown()
    cluster = _start_new_cluster()

    runner = TrialRunner.restore(checkpoint_dir)
    # Two steps restart the trials, followed by three more steps.
    for _ in range(5):
        runner.step()

    # One further step must raise a TuneError.
    with pytest.raises(TuneError):
        runner.step()

    assert all(t.status == Trial.TERMINATED for t in runner.get_trials())
    cluster.shutdown()
def test_cluster_down_full(start_connected_cluster, tmpdir):
    """Tests that run_experiment restoring works on cluster shutdown."""
    cluster = start_connected_cluster
    dirpath = str(tmpdir)

    # Four experiments: with/without checkpointing, with/without the
    # `mock_error` config flag. Only exp1 sets an explicit `local_dir`.
    all_experiments = {
        "exp1": {
            "run": "__fake",
            "stop": {"training_iteration": 3},
            "local_dir": dirpath,
            "checkpoint_freq": 1,
        },
        "exp2": {
            "run": "__fake",
            "stop": {"training_iteration": 3},
        },
        "exp3": {
            "run": "__fake",
            "stop": {"training_iteration": 3},
            "config": {"mock_error": True},
        },
        "exp4": {
            "run": "__fake",
            "stop": {"training_iteration": 3},
            "config": {"mock_error": True},
            "checkpoint_freq": 1,
        },
    }

    tune.run_experiments(all_experiments, raise_on_failed_trial=False)

    ray.shutdown()
    cluster.shutdown()
    cluster = _start_new_cluster()

    # Second run against the fresh cluster, this time asking to resume.
    trials = tune.run_experiments(
        all_experiments, resume=True, raise_on_failed_trial=False)
    assert len(trials) == 4
    finished = [Trial.TERMINATED, Trial.ERROR]
    assert all(t.status in finished for t in trials)
    cluster.shutdown()
def test_cluster_rllib_restore(start_connected_cluster, tmpdir):
    """Tests resuming a "PG" experiment that a separate driver started.

    A driver script launches the experiment against the running cluster;
    once its experiment checkpoint shows progress the cluster is replaced
    and the experiment is resumed through ``run_experiments``.
    """
    cluster = start_connected_cluster
    dirpath = str(tmpdir)
    script = """
import time
import ray
from ray import tune
ray.init(redis_address="{redis_address}")
kwargs = dict(
run="PG",
env="CartPole-v1",
stop=dict(training_iteration=10),
local_dir="{checkpoint_dir}",
checkpoint_freq=1,
max_failures=1)
tune.run_experiments(
dict(experiment=kwargs),
raise_on_failed_trial=False)
""".format(
        redis_address=cluster.redis_address, checkpoint_dir=dirpath)
    run_string_as_driver_nonblocking(script)

    metadata_checkpoint_dir = os.path.join(dirpath, "experiment")
    checkpoint_file = os.path.join(metadata_checkpoint_dir,
                                   TrialRunner.CKPT_FILE_NAME)

    # Poll (at most 50 times, 0.2s apart) until the experiment checkpoint
    # exists and its first trial reports a truthy `training_iteration`.
    for _ in range(50):
        if os.path.exists(checkpoint_file):
            # Inspect the internal trialrunner
            probe = TrialRunner.restore(metadata_checkpoint_dir)
            latest = probe.get_trials()[0].last_result
            if latest is not None and latest["training_iteration"]:
                break
        time.sleep(0.2)

    if not os.path.exists(checkpoint_file):
        raise RuntimeError("Checkpoint file didn't appear.")

    ray.shutdown()
    cluster.shutdown()
    cluster = _start_new_cluster()
    cluster.wait_for_nodes()

    # Restore properly from checkpoint
    resume_spec = {
        "experiment": {
            "run": "PG",
            "checkpoint_freq": 1,
            "local_dir": dirpath
        }
    }
    trials2 = tune.run_experiments(resume_spec, resume=True)
    assert all(t.status == Trial.TERMINATED for t in trials2)
    cluster.shutdown()
def test_cluster_interrupt(start_connected_cluster, tmpdir):
    """Tests run_experiment on cluster shutdown even with atypical trial.

    The trial fails on the 4th step, and the checkpointing happens on
    the 3rd step, so restoring should actually launch the trial again.
    """
    cluster = start_connected_cluster
    dirpath = str(tmpdir)
    script = """
import time
import ray
from ray import tune
ray.init(redis_address="{redis_address}")
{fail_class_code}
kwargs = dict(
run={fail_class},
stop=dict(training_iteration=5),
local_dir="{checkpoint_dir}",
checkpoint_freq=1,
max_failures=1)
tune.run_experiments(
dict(experiment=kwargs),
raise_on_failed_trial=False)
""".format(
        redis_address=cluster.redis_address,
        checkpoint_dir=dirpath,
        fail_class_code=inspect.getsource(_Fail),
        fail_class=_Fail.__name__)
    run_string_as_driver_nonblocking(script)

    metadata_checkpoint_dir = os.path.join(dirpath, "experiment")
    checkpoint_file = os.path.join(metadata_checkpoint_dir,
                                   TrialRunner.CKPT_FILE_NAME)

    # Wait until the right checkpoint is saved.
    # The trainable returns every 0.5 seconds, so this should not miss
    # the checkpoint.
    for _ in range(50):
        if os.path.exists(checkpoint_file):
            # Inspect the internal trialrunner
            probe = TrialRunner.restore(metadata_checkpoint_dir)
            latest = probe.get_trials()[0].last_result
            if latest is not None and latest["training_iteration"] == 3:
                break
        time.sleep(0.2)

    if not os.path.exists(checkpoint_file):
        raise RuntimeError("Checkpoint file didn't appear.")

    ray.shutdown()
    cluster.shutdown()
    cluster = _start_new_cluster()
    # The trainable class has to be registered again before restoring.
    Experiment._register_if_needed(_Fail)

    # Inspect the internal trialrunner
    runner = TrialRunner.restore(metadata_checkpoint_dir)
    trials = runner.get_trials()
    assert trials[0].last_result["training_iteration"] == 3
    assert trials[0].status == Trial.PENDING

    # Restore properly from checkpoint
    resume_spec = {
        "experiment": {
            "run": _Fail,
            "local_dir": dirpath,
            "checkpoint_freq": 1
        }
    }
    trials2 = tune.run_experiments(
        resume_spec, resume=True, raise_on_failed_trial=False)
    assert all(t.status == Trial.ERROR for t in trials2)
    # The resumed run must reuse the checkpointed trials, not create new ones.
    assert {t.trial_id for t in trials2} == {t.trial_id for t in trials}
    cluster.shutdown()
+118 -23
View File
@@ -3,7 +3,9 @@ from __future__ import division
from __future__ import print_function
import os
import shutil
import sys
import tempfile
import time
import unittest
@@ -37,6 +39,7 @@ else:
class TrainableFunctionApiTest(unittest.TestCase):
def setUp(self):
os.environ["TUNE_RESUME_PROMPT_OFF"] = "True"
ray.init(num_cpus=4, num_gpus=0)
def tearDown(self):
@@ -542,6 +545,7 @@ class TrainableFunctionApiTest(unittest.TestCase):
class RunExperimentTest(unittest.TestCase):
def setUp(self):
os.environ["TUNE_RESUME_PROMPT_OFF"] = "True"
ray.init()
def tearDown(self):
@@ -614,29 +618,6 @@ class RunExperimentTest(unittest.TestCase):
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)
def testSpecifyAlgorithm(self):
"""Tests run_experiments works without specifying experiment."""
def train(config, reporter):
for i in range(100):
reporter(timesteps_total=i)
register_trainable("f1", train)
alg = BasicVariantGenerator()
alg.add_configurations({
"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0
}
}
})
trials = run_experiments(search_alg=alg)
for trial in trials:
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)
def testAutoregisterTrainable(self):
def train(config, reporter):
for i in range(100):
@@ -778,6 +759,7 @@ class RunExperimentTest(unittest.TestCase):
class VariantGeneratorTest(unittest.TestCase):
def setUp(self):
os.environ["TUNE_RESUME_PROMPT_OFF"] = "True"
ray.init()
def tearDown(self):
@@ -981,6 +963,9 @@ def create_mock_components():
class TrialRunnerTest(unittest.TestCase):
def setUp(self):
os.environ["TUNE_RESUME_PROMPT_OFF"] = "True"
def tearDown(self):
ray.shutdown()
_register_all() # re-register the evicted objects
@@ -1665,6 +1650,116 @@ class TrialRunnerTest(unittest.TestCase):
self.assertTrue(searcher.is_finished())
self.assertRaises(TuneError, runner.step)
def testTrialSaveRestore(self):
    """Creates different trials to test runner.checkpoint/restore.

    Drives one trial to TERMINATED, one to ERROR and leaves a third one
    RUNNING, then restores a second runner from the metadata checkpoint
    directory and checks the restored statuses before stepping it.
    """
    ray.init(num_cpus=3)
    tmpdir = tempfile.mkdtemp()
    # Registered right away so the directory is removed even when one of
    # the assertions below fails.
    self.addCleanup(shutil.rmtree, tmpdir)

    runner = TrialRunner(
        BasicVariantGenerator(), metadata_checkpoint_dir=tmpdir)
    trials = [
        Trial(
            "__fake",
            trial_id="trial_terminate",
            stopping_criterion={"training_iteration": 1},
            checkpoint_freq=1)
    ]
    runner.add_trial(trials[0])
    runner.step()  # start
    runner.step()
    self.assertEqual(trials[0].status, Trial.TERMINATED)

    trials += [
        Trial(
            "__fake",
            trial_id="trial_fail",
            stopping_criterion={"training_iteration": 3},
            checkpoint_freq=1,
            config={"mock_error": True})
    ]
    runner.add_trial(trials[1])
    runner.step()
    runner.step()
    runner.step()
    self.assertEqual(trials[1].status, Trial.ERROR)

    trials += [
        Trial(
            "__fake",
            trial_id="trial_succ",
            stopping_criterion={"training_iteration": 2},
            checkpoint_freq=1)
    ]
    runner.add_trial(trials[2])
    runner.step()
    self.assertEqual(len(runner.trial_executor.get_checkpoints()), 3)
    self.assertEqual(trials[2].status, Trial.RUNNING)

    runner2 = TrialRunner.restore(tmpdir)
    # Finished trials keep their status across the restore ...
    for tid in ["trial_terminate", "trial_fail"]:
        original_trial = runner.get_trial(tid)
        restored_trial = runner2.get_trial(tid)
        self.assertEqual(original_trial.status, restored_trial.status)

    # ... while the trial that was RUNNING comes back as PENDING.
    restored_trial = runner2.get_trial("trial_succ")
    self.assertEqual(Trial.PENDING, restored_trial.status)

    runner2.step()
    runner2.step()
    runner2.step()
    self.assertRaises(TuneError, runner2.step)
def testTrialNoSave(self):
    """Check that non-checkpointing trials are not saved.

    Runs one trial without checkpointing and one with `checkpoint_at_end`
    to completion, starts a third one, then restores a runner from the
    metadata checkpoint directory and inspects the restored trials.
    """
    ray.init(num_cpus=3)
    tmpdir = tempfile.mkdtemp()
    # Registered right away so the directory is removed even when one of
    # the assertions below fails.
    self.addCleanup(shutil.rmtree, tmpdir)

    runner = TrialRunner(
        BasicVariantGenerator(), metadata_checkpoint_dir=tmpdir)

    runner.add_trial(
        Trial(
            "__fake",
            trial_id="non_checkpoint",
            stopping_criterion={"training_iteration": 2}))

    while not all(t.status == Trial.TERMINATED
                  for t in runner.get_trials()):
        runner.step()

    runner.add_trial(
        Trial(
            "__fake",
            trial_id="checkpoint",
            checkpoint_at_end=True,
            stopping_criterion={"training_iteration": 2}))

    while not all(t.status == Trial.TERMINATED
                  for t in runner.get_trials()):
        runner.step()

    runner.add_trial(
        Trial(
            "__fake",
            trial_id="pending",
            stopping_criterion={"training_iteration": 2}))

    runner.step()
    runner.step()

    runner2 = TrialRunner.restore(tmpdir)
    new_trials = runner2.get_trials()
    self.assertEqual(len(new_trials), 3)
    # assertEqual / assertIsNone report the offending value on failure,
    # unlike assertTrue on a pre-computed comparison.
    self.assertEqual(
        runner2.get_trial("non_checkpoint").status, Trial.TERMINATED)
    self.assertEqual(
        runner2.get_trial("checkpoint").status, Trial.TERMINATED)
    self.assertEqual(runner2.get_trial("pending").status, Trial.PENDING)
    self.assertIsNone(runner2.get_trial("pending").last_result)
    runner2.step()
class SearchAlgorithmTest(unittest.TestCase):
def testNestedSuggestion(self):
@@ -578,6 +578,7 @@ class _MockTrial(Trial):
self.logger_running = False
self.restored_checkpoint = None
self.resources = Resources(1, 0)
self.trial_name = None
class PopulationBasedTestingSuite(unittest.TestCase):
+7 -3
View File
@@ -10,6 +10,7 @@ import io
import logging
import os
import pickle
from six import string_types
import shutil
import tempfile
import time
@@ -216,10 +217,11 @@ class Trainable(object):
checkpoint_dir = os.path.join(checkpoint_dir or self.logdir,
"checkpoint_{}".format(self._iteration))
os.makedirs(checkpoint_dir)
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)
checkpoint = self._save(checkpoint_dir)
saved_as_dict = False
if isinstance(checkpoint, str):
if isinstance(checkpoint, string_types):
if (not checkpoint.startswith(checkpoint_dir)
or checkpoint == checkpoint_dir):
raise ValueError(
@@ -237,7 +239,9 @@ class Trainable(object):
with open(checkpoint_path, "wb") as f:
pickle.dump(checkpoint, f)
else:
raise ValueError("Return value from `_save` must be dict or str.")
raise ValueError(
"`_save` must return a dict or string type: {}".format(
str(type(checkpoint))))
pickle.dump({
"experiment_id": self._experiment_id,
"iteration": self._iteration,
+109 -13
View File
@@ -3,15 +3,22 @@ from __future__ import division
from __future__ import print_function
from collections import namedtuple
import ray.cloudpickle as cloudpickle
import copy
from datetime import datetime
import logging
import json
import time
import tempfile
import os
# For compatibility under py2 to consider unicode as str
from six import string_types
from numbers import Number
import ray
from ray.tune import TuneError
from ray.tune.log_sync import validate_sync_function
from ray.tune.logger import pretty_print, UnifiedLogger
# NOTE(rkn): We import ray.tune.registry here instead of importing the names we
# need because there are cyclic imports that may cause specific names to not
@@ -19,7 +26,7 @@ from ray.tune.logger import pretty_print, UnifiedLogger
import ray.tune.registry
from ray.tune.result import (DEFAULT_RESULTS_DIR, DONE, HOSTNAME, PID,
TIME_TOTAL_S, TRAINING_ITERATION, TIMESTEPS_TOTAL)
from ray.utils import random_string, binary_to_hex
from ray.utils import random_string, binary_to_hex, hex_to_binary
DEBUG_PRINT_INTERVAL = 5
MAX_LEN_IDENTIFIER = 130
@@ -66,6 +73,36 @@ class Resources(
return self.gpu + self.extra_gpu
def json_to_resources(data):
    """Builds a ``Resources`` tuple from a JSON string or a mapping.

    Args:
        data: ``None``, the string ``"null"``, a JSON-encoded string, or a
            mapping from resource names to amounts.

    Returns:
        ``None`` when ``data`` is ``None`` or ``"null"``; otherwise a
        ``Resources`` built from the ``cpu``, ``gpu``, ``extra_cpu`` and
        ``extra_gpu`` entries, defaulting to 1, 0, 0 and 0 respectively.

    Raises:
        TuneError: If a key is one of the no-longer-supported
            ``driver_cpu_limit`` / ``driver_gpu_limit`` fields, or is not
            a field of ``Resources``.
    """
    if data is None or data == "null":
        return None
    if isinstance(data, string_types):
        data = json.loads(data)

    removed_fields = ("driver_cpu_limit", "driver_gpu_limit")
    for key in data:
        if key in removed_fields:
            raise TuneError(
                "The field `{}` is no longer supported. Use `extra_cpu` "
                "or `extra_gpu` instead.".format(key))
        if key not in Resources._fields:
            raise TuneError(
                "Unknown resource type {}, must be one of {}".format(
                    key, Resources._fields))

    cpu = data.get("cpu", 1)
    gpu = data.get("gpu", 0)
    extra_cpu = data.get("extra_cpu", 0)
    extra_gpu = data.get("extra_gpu", 0)
    return Resources(cpu, gpu, extra_cpu, extra_gpu)
def resources_to_json(resources):
    """Converts a resources object into a plain, JSON-friendly dict.

    Args:
        resources: ``None`` or an object exposing ``cpu``, ``gpu``,
            ``extra_cpu`` and ``extra_gpu`` attributes.

    Returns:
        ``None`` if ``resources`` is ``None``, else a dict holding those
        four attributes under keys of the same name.
    """
    if resources is None:
        return None
    field_names = ("cpu", "gpu", "extra_cpu", "extra_gpu")
    return {name: getattr(resources, name) for name in field_names}
def has_trainable(trainable_name):
return ray.tune.registry._global_registry.contains(
ray.tune.registry.TRAINABLE_CLASS, trainable_name)
@@ -133,12 +170,8 @@ class Trial(object):
The args here take the same meaning as the command line flags defined
in ray.tune.config_parser.
"""
if not has_trainable(trainable_name):
# Make sure rllib agents are registered
from ray import rllib # noqa: F401
if not has_trainable(trainable_name):
raise TuneError("Unknown trainable: " + trainable_name)
Trial._registration_check(trainable_name)
# Trial config
self.trainable_name = trainable_name
self.config = config or {}
@@ -149,14 +182,15 @@ class Trial(object):
or self._get_trainable_cls().default_resource_request(self.config))
self.stopping_criterion = stopping_criterion or {}
self.upload_dir = upload_dir
self.trial_name_creator = trial_name_creator
self.custom_loggers = custom_loggers
self.sync_function = sync_function
validate_sync_function(sync_function)
self.verbose = True
self.max_failures = max_failures
# Local trial state that is updated during the run
self.last_result = None
self.last_update_time = -float("inf")
self.checkpoint_freq = checkpoint_freq
self.checkpoint_at_end = checkpoint_at_end
self._checkpoint = Checkpoint(
@@ -170,6 +204,18 @@ class Trial(object):
self.error_file = None
self.num_failures = 0
self.trial_name = None
if trial_name_creator:
self.trial_name = trial_name_creator(self)
@classmethod
def _registration_check(cls, trainable_name):
    """Verifies that `trainable_name` refers to a registered trainable.

    Args:
        trainable_name (str): Name to look up in the trainable registry.

    Raises:
        TuneError: If the name is still unknown after importing rllib.
    """
    if has_trainable(trainable_name):
        return
    # Make sure rllib agents are registered before giving up.
    from ray import rllib  # noqa: F401
    if not has_trainable(trainable_name):
        raise TuneError("Unknown trainable: " + trainable_name)
@classmethod
def generate_id(cls):
return binary_to_hex(random_string())[:8]
@@ -180,10 +226,14 @@ class Trial(object):
if not self.result_logger:
if not os.path.exists(self.local_dir):
os.makedirs(self.local_dir)
self.logdir = tempfile.mkdtemp(
prefix="{}_{}".format(
str(self)[:MAX_LEN_IDENTIFIER], date_str()),
dir=self.local_dir)
if not self.logdir:
self.logdir = tempfile.mkdtemp(
prefix="{}_{}".format(
str(self)[:MAX_LEN_IDENTIFIER], date_str()),
dir=self.local_dir)
elif not os.path.exists(self.logdir):
os.makedirs(self.logdir)
self.result_logger = UnifiedLogger(
self.config,
self.logdir,
@@ -307,6 +357,7 @@ class Trial(object):
pretty_print(result).replace("\n", "\n ")))
self.last_debug = time.time()
self.last_result = result
self.last_update_time = time.time()
self.result_logger.on_result(self.last_result)
def _get_trainable_cls(self):
@@ -327,8 +378,8 @@ class Trial(object):
Can be overriden with a custom string creator.
"""
if self.trial_name_creator:
return self.trial_name_creator(self)
if self.trial_name:
return self.trial_name
if "env" in self.config:
env = self.config["env"]
@@ -340,3 +391,48 @@ class Trial(object):
if self.experiment_tag:
identifier += "_" + self.experiment_tag
return identifier.replace("/", "_")
def __getstate__(self):
    """Produces a memento (plain dict) describing this trial.

    The resources are converted with `resources_to_json`; the checkpoint,
    config, custom loggers and sync function are cloudpickled and
    hex-encoded; the runner and result logger are dropped. A RUNNING
    trial is recorded as PENDING, and an active result logger is flushed.

    Raises:
        AssertionError: If the trial's checkpoint is not a DISK checkpoint.
    """
    assert self._checkpoint.storage == Checkpoint.DISK, (
        "Checkpoint must not be in-memory.")
    state = dict(self.__dict__)
    state["resources"] = resources_to_json(self.resources)

    # These fields are stored as hex-encoded cloudpickle blobs.
    for key in ("_checkpoint", "config", "custom_loggers", "sync_function"):
        state[key] = binary_to_hex(cloudpickle.dumps(getattr(self, key)))

    # Live handles are never part of the memento.
    state["runner"] = None
    state["result_logger"] = None
    if self.status == Trial.RUNNING:
        state["status"] = Trial.PENDING

    # Remember whether a logger was active so `__setstate__` can decide
    # whether to start one again.
    logger_started = bool(self.result_logger)
    if logger_started:
        self.result_logger.flush()
    state["__logger_started__"] = logger_started
    return copy.deepcopy(state)
def __setstate__(self, state):
    """Restores the trial from a memento produced by `__getstate__`.

    Args:
        state (dict): Memento holding the trial attributes, with the
            resources in JSON form, the `_checkpoint`, `config`,
            `custom_loggers` and `sync_function` entries hex-encoded and
            cloudpickled, and a `__logger_started__` flag.

    Raises:
        KeyError: If `state` lacks `__logger_started__` or one of the
            encoded entries.
        TuneError: If the trainable name is not registered.
    """
    # Work on a shallow copy so the caller's dict is left untouched and
    # can be passed in again (popping from it directly would make a
    # second call fail with KeyError).
    state = dict(state)
    logger_started = state.pop("__logger_started__")
    state["resources"] = json_to_resources(state["resources"])
    # NOTE: these entries are unpickled, so the memento must come from a
    # trusted source.
    for key in ("_checkpoint", "config", "custom_loggers", "sync_function"):
        state[key] = cloudpickle.loads(hex_to_binary(state[key]))

    self.__dict__.update(state)
    Trial._registration_check(self.trainable_name)
    # Start a logger again only if one was active when the memento was
    # taken.
    if logger_started:
        self.init_logger()
+38 -3
View File
@@ -25,6 +25,41 @@ class TrialExecutor(object):
automatic scale-up.
"""
self._queue_trials = queue_trials
self._cached_trial_state = {}
def set_status(self, trial, status):
    """Updates the trial status, checkpointing metadata when terminal.

    Metadata is only checkpointed for TERMINATED and ERROR. Switches to
    PENDING, PAUSED, and RUNNING have their checkpoints taken care of
    in the TrialRunner.

    Args:
        trial (Trial): Trial to checkpoint.
        status (Trial.status): Status to set trial to.
    """
    trial.status = status
    terminal_statuses = (Trial.TERMINATED, Trial.ERROR)
    if status in terminal_statuses:
        self.try_checkpoint_metadata(trial)
def try_checkpoint_metadata(self, trial):
    """Caches the trial's metadata on a best-effort basis.

    Trials whose checkpoint is held in memory are skipped. Any exception
    raised while producing the metadata is logged and swallowed.

    Args:
        trial (Trial): Trial to checkpoint.
    """
    in_memory = trial._checkpoint.storage == Checkpoint.MEMORY
    if in_memory:
        logger.debug("Not saving data for trial w/ memory checkpoint.")
        return

    try:
        logger.debug("Saving trial metadata.")
        snapshot = trial.__getstate__()
        self._cached_trial_state[trial.trial_id] = snapshot
    except Exception:
        logger.exception("Error checkpointing trial metadata.")
def get_checkpoints(self):
    """Returns a shallow copy of the trial ID -> cached metadata mapping."""
    cached = self._cached_trial_state
    return cached.copy()
def has_resources(self, resources):
"""Returns whether this runner has at least the specified resources."""
@@ -71,15 +106,15 @@ class TrialExecutor(object):
try:
self.save(trial, Checkpoint.MEMORY)
self.stop_trial(trial, stop_logger=False)
trial.status = Trial.PAUSED
self.set_status(trial, Trial.PAUSED)
except Exception:
logger.exception("Error pausing runner.")
trial.status = Trial.ERROR
self.set_status(trial, Trial.ERROR)
def unpause_trial(self, trial):
"""Sets PAUSED trial to pending to allow scheduler to start."""
assert trial.status == Trial.PAUSED, trial.status
trial.status = Trial.PENDING
self.set_status(trial, Trial.PENDING)
def resume_trial(self, trial):
"""Resumes PAUSED trials. This is a blocking call."""
+125 -11
View File
@@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
import collections
import json
import logging
import os
import re
@@ -49,10 +50,13 @@ class TrialRunner(object):
misleading benchmark results.
"""
CKPT_FILE_NAME = "experiment_state.json"
def __init__(self,
search_alg,
scheduler=None,
launch_web_server=False,
metadata_checkpoint_dir=None,
server_port=TuneServer.DEFAULT_PORT,
verbose=True,
queue_trials=False,
@@ -64,6 +68,8 @@ class TrialRunner(object):
Trial objects.
scheduler (TrialScheduler): Defaults to FIFOScheduler.
launch_web_server (bool): Flag for starting TuneServer
metadata_checkpoint_dir (str): Path where
global checkpoints are stored and restored from.
server_port (int): Port number for launching TuneServer
verbose (bool): Flag for verbosity. If False, trial results
will not be output.
@@ -75,7 +81,6 @@ class TrialRunner(object):
"""
self._search_alg = search_alg
self._scheduler_alg = scheduler or FIFOScheduler()
self._trials = []
self.trial_executor = trial_executor or \
RayTrialExecutor(queue_trials=queue_trials)
@@ -84,13 +89,93 @@ class TrialRunner(object):
self._global_time_limit = float(
os.environ.get("TRIALRUNNER_WALLTIME_LIMIT", float('inf')))
self._total_time = 0
self._server = None
if launch_web_server:
self._server = TuneServer(self, server_port)
self._stop_queue = []
self._iteration = 0
self._verbose = verbose
self._queue_trials = queue_trials
self._server = None
self._server_port = server_port
if launch_web_server:
self._server = TuneServer(self, self._server_port)
self._trials = []
self._stop_queue = []
self._metadata_checkpoint_dir = metadata_checkpoint_dir
def checkpoint(self):
    """Write the runner's execution state to the metadata checkpoint dir.

    The saved JSON document has two entries: "checkpoints", the values
    of the trial executor's per-trial checkpoints, and "runner_data",
    the runner fields returned by `__getstate__`. The document is first
    written to a temporary file in the same directory and then renamed
    to `TrialRunner.CKPT_FILE_NAME`.

    Returns:
        The metadata checkpoint directory, or None if no directory was
        configured (in which case nothing is written).
    """
    target_dir = self._metadata_checkpoint_dir
    if not target_dir:
        return
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    trial_checkpoints = self.trial_executor.get_checkpoints()
    payload = {
        "checkpoints": list(trial_checkpoints.values()),
        "runner_data": self.__getstate__()
    }

    # Write to a scratch file first; the rename below moves the finished
    # document into its final place in one step.
    scratch_path = os.path.join(target_dir, ".tmp_checkpoint")
    with open(scratch_path, "w") as out:
        json.dump(payload, out, indent=2)

    final_path = os.path.join(target_dir, TrialRunner.CKPT_FILE_NAME)
    os.rename(scratch_path, final_path)
    return target_dir
@classmethod
def restore(cls,
            metadata_checkpoint_dir,
            search_alg=None,
            scheduler=None,
            trial_executor=None):
    """Rebuild a TrialRunner from a saved experiment checkpoint.

    Loads `TrialRunner.CKPT_FILE_NAME` from `metadata_checkpoint_dir`,
    creates a new runner, applies the saved runner fields to it and
    re-adds one Trial per saved trial checkpoint. Trials are re-created
    from their saved `trainable_name`, so the caller is expected to have
    re-registered any custom objects beforehand.

    Args:
        metadata_checkpoint_dir (str): Path to metadata checkpoints.
        search_alg (SearchAlgorithm): Search Algorithm. Defaults to
            BasicVariantGenerator.
        scheduler (TrialScheduler): Scheduler for executing
            the experiment.
        trial_executor (TrialExecutor): Manage the execution of trials.

    Returns:
        runner (TrialRunner): A TrialRunner to resume experiments from.
    """
    state_path = os.path.join(metadata_checkpoint_dir,
                              TrialRunner.CKPT_FILE_NAME)
    with open(state_path, "r") as state_file:
        saved = json.load(state_file)

    notice = ("Attempting to resume experiment from {}. ".format(
        metadata_checkpoint_dir) + "This feature is experimental, "
              "and may not work with all search algorithms. " +
              "This will ignore any new changes to the specification.")
    logger.warning(notice)

    # Imported here rather than at module level, as in the original code.
    from ray.tune.suggest import BasicVariantGenerator
    if not search_alg:
        search_alg = BasicVariantGenerator()
    runner = TrialRunner(
        search_alg, scheduler=scheduler, trial_executor=trial_executor)
    runner.__setstate__(saved["runner_data"])

    restored = []
    for trial_state in saved["checkpoints"]:
        trial = Trial(trial_state["trainable_name"])
        trial.__setstate__(trial_state)
        restored.append(trial)

    # Most recently updated trials are handed to the runner first.
    restored.sort(key=lambda t: t.last_update_time, reverse=True)
    for trial in restored:
        runner.add_trial(trial)
    return runner
def is_finished(self):
"""Returns whether all trials have finished running."""
@@ -136,6 +221,12 @@ class TrialRunner(object):
"There are paused trials, but no more pending "
"trials with sufficient resources.")
try:
self.checkpoint()
except Exception:
logger.exception("Trial Runner checkpointing failed.")
self._iteration += 1
if self._server:
self._process_requests()
@@ -165,6 +256,7 @@ class TrialRunner(object):
"""
trial.set_verbose(self._verbose)
self._scheduler_alg.on_trial_add(self, trial)
self.trial_executor.try_checkpoint_metadata(trial)
self._trials.append(trial)
def debug_string(self, max_debug=MAX_DEBUG_TRIALS):
@@ -279,14 +371,14 @@ class TrialRunner(object):
result, terminate=(decision == TrialScheduler.STOP))
if decision == TrialScheduler.CONTINUE:
self._checkpoint_if_needed(trial)
self._checkpoint_trial_if_needed(trial)
self.trial_executor.continue_training(trial)
elif decision == TrialScheduler.PAUSE:
self.trial_executor.pause_trial(trial)
elif decision == TrialScheduler.STOP:
# Checkpoint before ending the trial
# if checkpoint_at_end experiment option is set to True
self._checkpoint_if_needed(trial)
self._checkpoint_trial_if_needed(trial)
self.trial_executor.stop_trial(trial)
else:
assert False, "Invalid scheduling decision: {}".format(
@@ -304,12 +396,13 @@ class TrialRunner(object):
self.trial_executor.stop_trial(
trial, error=True, error_msg=error_msg)
def _checkpoint_if_needed(self, trial):
def _checkpoint_trial_if_needed(self, trial):
"""Checkpoints trial based off trial.last_result."""
if trial.should_checkpoint():
# Save trial runtime if possible
if hasattr(trial, "runner") and trial.runner:
self.trial_executor.save(trial, storage=Checkpoint.DISK)
self.trial_executor.try_checkpoint_metadata(trial)
def _try_recover(self, trial, error_msg):
"""Tries to recover trial.
@@ -344,11 +437,11 @@ class TrialRunner(object):
def _requeue_trial(self, trial):
"""Notification to TrialScheduler and requeue trial.
This does not notify the SearchAlgorithm because
the function evaluation is still in progress.
This does not notify the SearchAlgorithm because the function
evaluation is still in progress.
"""
self._scheduler_alg.on_trial_error(self, trial)
trial.status = Trial.PENDING
self.trial_executor.set_status(trial, Trial.PENDING)
self._scheduler_alg.on_trial_add(self, trial)
def _update_trial_queue(self, blocking=False, timeout=600):
@@ -417,3 +510,24 @@ class TrialRunner(object):
error = True
self.trial_executor.stop_trial(trial, error=error, error_msg=error_msg)
def __getstate__(self):
    """Return the runner fields that are saved in a checkpoint.

    The result is a shallow copy of the instance dict without the trial
    list, stop queue, web server, search algorithm, scheduler and trial
    executor, plus a "launch_web_server" flag that records whether a
    server was attached.

    Note that this is not used as a pickling override, as the returned
    state does not have all fields.
    """
    excluded = ("_trials", "_stop_queue", "_server", "_search_alg",
                "_scheduler_alg", "trial_executor")
    snapshot = dict(self.__dict__)
    for field in excluded:
        # pop() without a default raises KeyError for a missing field,
        # just as `del` would.
        snapshot.pop(field)
    snapshot["launch_web_server"] = bool(self._server)
    return snapshot
def __setstate__(self, state):
    """Apply fields produced by `__getstate__` to this runner.

    Every entry of `state` except "launch_web_server" is copied onto
    the instance. If "launch_web_server" is truthy, a new TuneServer is
    started on `self._server_port`.

    Args:
        state (dict): Runner fields, including a "launch_web_server"
            entry. The dict itself is left unmodified.

    Raises:
        KeyError: If `state` has no "launch_web_server" entry.
    """
    # Work on a copy so that popping the flag does not alter the
    # caller's dict (e.g. the checkpoint document loaded by `restore`).
    fields = dict(state)
    launch_web_server = fields.pop("launch_web_server")
    self.__dict__.update(fields)
    if launch_web_server:
        self._server = TuneServer(self, self._server_port)
+77 -15
View File
@@ -2,10 +2,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import click
import logging
import os
import time
from ray.tune.error import TuneError
from ray.tune.experiment import convert_to_experiment_list
from ray.tune.suggest import BasicVariantGenerator
from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL
from ray.tune.log_sync import wait_for_log_sync
@@ -32,12 +35,30 @@ def _make_scheduler(args):
args.scheduler, _SCHEDULERS.keys()))
def run_experiments(experiments=None,
def _find_checkpoint_dir(exp_list):
    """Return the checkpoint directory of the first experiment.

    Only the first entry of `exp_list` is consulted; the directory is
    its `local_dir` spec entry joined with the experiment name.

    Args:
        exp_list (list): Experiments; must not be empty.

    Returns:
        str: `<local_dir>/<experiment name>` of the first experiment.
    """
    assert exp_list, "Experiments must be specified via `run_experiments`"
    first_experiment = exp_list[0]
    # TODO(rliaw): Make sure this is resolved earlier.
    base_dir = first_experiment.spec["local_dir"]
    return os.path.join(base_dir, first_experiment.name)
def try_restore_runner(checkpoint_dir, search_alg, scheduler, trial_executor):
    """Attempt to restore a TrialRunner, returning None on any failure.

    Args:
        checkpoint_dir (str): Directory passed to `TrialRunner.restore`.
        search_alg: Search algorithm forwarded to `TrialRunner.restore`.
        scheduler: Scheduler forwarded to `TrialRunner.restore`.
        trial_executor: Trial executor forwarded to `TrialRunner.restore`.

    Returns:
        The restored TrialRunner, or None if restoring raised an
        exception (the exception is logged, not propagated).
    """
    try:
        return TrialRunner.restore(checkpoint_dir, search_alg, scheduler,
                                   trial_executor)
    except Exception:
        # Best effort: the failure is logged and reported as None.
        logger.exception("Runner restore failed. Restarting experiment.")
    return None
def run_experiments(experiments,
search_alg=None,
scheduler=None,
with_server=False,
server_port=TuneServer.DEFAULT_PORT,
verbose=True,
resume=None,
queue_trials=False,
trial_executor=None,
raise_on_failed_trial=True):
@@ -55,6 +76,9 @@ def run_experiments(experiments=None,
using the Client API.
server_port (int): Port number for launching TuneServer.
verbose (bool): How much output should be printed for each trial.
resume (bool|None): If checkpoint exists, the experiment will
resume from there. If resume is None, Tune will prompt if
checkpoint detected.
queue_trials (bool): Whether to queue trials when the cluster does
not currently have enough resources to launch one. This should
be set to True when running on an autoscaling cluster to enable
@@ -83,26 +107,64 @@ def run_experiments(experiments=None,
List of Trial objects, holding data for each executed trial.
"""
# This is important to do this here
# because it schematize the experiments
# and it conducts the implicit registration.
experiments = convert_to_experiment_list(experiments)
checkpoint_dir = _find_checkpoint_dir(experiments)
if scheduler is None:
scheduler = FIFOScheduler()
runner = None
restore = False
if search_alg is None:
search_alg = BasicVariantGenerator()
# TUNE_RESUME_PROMPT_OFF is for testing purposes and defaults
# `resume=False.`
if os.environ.get("TUNE_RESUME_PROMPT_OFF"):
resume = resume or False
search_alg.add_configurations(experiments)
if os.path.exists(
os.path.join(checkpoint_dir, TrialRunner.CKPT_FILE_NAME)):
if resume:
restore = True
elif resume is None:
msg = ("Found incomplete experiment at {}. "
"Would you like to resume it?".format(checkpoint_dir))
restore = click.confirm(msg, default=True)
if restore:
logger.info("Tip: to always resume, "
"pass resume=True to run_experiments()")
else:
logger.info("Tip: to always start a new experiment, "
"pass resume=False to run_experiments()")
else:
logger.info(
"Did not find checkpoint file in {}.".format(checkpoint_dir))
runner = TrialRunner(
search_alg,
scheduler=scheduler,
launch_web_server=with_server,
server_port=server_port,
verbose=verbose,
queue_trials=queue_trials,
trial_executor=trial_executor)
if restore:
runner = try_restore_runner(checkpoint_dir, search_alg, scheduler,
trial_executor)
else:
logger.info("Starting a new experiment.")
if not runner:
if scheduler is None:
scheduler = FIFOScheduler()
if search_alg is None:
search_alg = BasicVariantGenerator()
search_alg.add_configurations(experiments)
runner = TrialRunner(
search_alg,
scheduler=scheduler,
metadata_checkpoint_dir=checkpoint_dir,
launch_web_server=with_server,
server_port=server_port,
verbose=verbose,
queue_trials=queue_trials,
trial_executor=trial_executor)
logger.info(runner.debug_string(max_debug=99999))
last_debug = 0
while not runner.is_finished():
runner.step()