[tune] Cross-Node Recovery (#3725)

Augments trial restore to also check if the runner is at the same
location. If not, the checkpoint files are pushed onto the new location.
This commit is contained in:
Richard Liaw
2019-01-15 10:37:28 -08:00
committed by GitHub
parent a5df8e3532
commit 3918934dfd
5 changed files with 86 additions and 15 deletions
+41 -4
View File
@@ -6,6 +6,7 @@ import distutils.spawn
import logging
import os
import subprocess
import tempfile
import time
import types
@@ -21,6 +22,7 @@ from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.suggest.variant_generator import function as tune_function
logger = logging.getLogger(__name__)
_log_sync_warned = False
# Map from (logdir, remote_dir) -> syncer
_syncers = {}
@@ -97,6 +99,8 @@ class _LogSyncer(object):
def __init__(self, local_dir, remote_dir=None, sync_function=None):
self.local_dir = local_dir
self.remote_dir = remote_dir
self.logfile = tempfile.NamedTemporaryFile(
prefix="log_sync", dir=self.local_dir, suffix=".log", delete=False)
# Resolve sync_function into template or function
self.sync_func = None
@@ -115,13 +119,42 @@ class _LogSyncer(object):
def set_worker_ip(self, worker_ip):
    """Record the IP address of the worker whose logs should be synced.

    Args:
        worker_ip: Address stored on ``self.worker_ip``.
    """
    self.worker_ip = worker_ip
def sync_if_needed(self):
    """Run ``sync_now`` when the last sync is more than 300 seconds old."""
    seconds_since_sync = time.time() - self.last_sync_time
    if seconds_since_sync > 300:
        self.sync_now()
def sync_to_worker_if_possible(self):
    """Push the local logdir from the driver to the worker, if possible.

    Requires the Ray cluster to be started with the autoscaler (so that an
    SSH key and user are available) and requires rsync to be installed.
    Nothing is done when no worker IP is set or when the worker IP equals
    ``self.local_ip``. The call blocks until rsync exits; a non-zero exit
    status is logged as an error rather than raised.
    """
    global _log_sync_warned

    # Nothing to push without a destination, or when the destination is
    # this very node.
    if not self.worker_ip or self.worker_ip == self.local_ip:
        return

    ssh_key = get_ssh_key()
    ssh_user = get_ssh_user()
    if ssh_key is None or ssh_user is None:
        # Only complain once per process to avoid flooding the log.
        if not _log_sync_warned:
            logger.error("Log sync requires cluster to be setup with "
                         "`ray up`.")
            _log_sync_warned = True
        return

    if not distutils.spawn.find_executable("rsync"):
        logger.error("Log sync requires rsync to be installed.")
        return

    # Trailing slashes make rsync copy directory contents into the same
    # path on the worker.
    source = '{}/'.format(self.local_dir)
    target = '{}@{}:{}/'.format(ssh_user, self.worker_ip, self.local_dir)
    final_cmd = (("""rsync -savz -e "ssh -i {} -o ConnectTimeout=120s """
                  """-o StrictHostKeyChecking=no" {} {}""").format(
                      quote(ssh_key), quote(source), quote(target)))
    logger.info("Syncing results to %s", str(self.worker_ip))
    sync_process = subprocess.Popen(
        final_cmd, shell=True, stdout=self.logfile)
    return_code = sync_process.wait()
    if return_code != 0:
        # Surface the failure here; otherwise the subsequent restore on the
        # worker fails with no indication that the push never happened.
        logger.error("Syncing results to %s failed: rsync exited with "
                     "code %s.", str(self.worker_ip), return_code)
def sync_now(self, force=False):
self.last_sync_time = time.time()
if not self.worker_ip:
@@ -134,9 +167,12 @@ class _LogSyncer(object):
else:
ssh_key = get_ssh_key()
ssh_user = get_ssh_user()
global _log_sync_warned
if ssh_key is None or ssh_user is None:
logger.error("Log sync requires cluster to be setup with "
"`ray create_or_update`.")
if not _log_sync_warned:
logger.error("Log sync requires cluster to be setup with "
"`ray up`.")
_log_sync_warned = True
return
if not distutils.spawn.find_executable("rsync"):
logger.error("Log sync requires rsync to be installed.")
@@ -179,7 +215,8 @@ class _LogSyncer(object):
final_cmd += " && "
final_cmd += local_to_remote_sync_cmd
logger.debug("Running log sync: {}".format(final_cmd))
self.sync_process = subprocess.Popen(final_cmd, shell=True)
self.sync_process = subprocess.Popen(
final_cmd, shell=True, stdout=self.logfile)
def wait(self):
if self.sync_process:
+11
View File
@@ -88,6 +88,7 @@ class UnifiedLogger(Logger):
sync_function=None):
self._logger_list = [_JsonLogger, _TFLogger, _VisKitLogger]
self._sync_function = sync_function
self._log_syncer = None
if custom_loggers:
assert isinstance(custom_loggers, list), "Improper custom loggers."
self._logger_list += custom_loggers
@@ -122,6 +123,16 @@ class UnifiedLogger(Logger):
self._log_syncer.sync_now(force=True)
self._log_syncer.wait()
def sync_results_to_new_location(self, worker_ip):
    """Sends the current log directory to the remote node.

    Syncing will not occur if the cluster is not started
    with the Ray autoscaler. It is also skipped when no log syncer has
    been set up or when the syncer already targets ``worker_ip``.

    Args:
        worker_ip: IP address of the node the trial now runs on.
    """
    # The syncer attribute starts out as None; without one there is
    # nothing to redirect or push.
    if self._log_syncer is None:
        return
    if worker_ip != self._log_syncer.worker_ip:
        self._log_syncer.set_worker_ip(worker_ip)
        self._log_syncer.sync_to_worker_if_possible()
class NoopLogger(Logger):
def on_result(self, result):
+21 -10
View File
@@ -71,7 +71,9 @@ class RayTrialExecutor(TrialExecutor):
trial.runner = self._setup_runner(trial)
if not self.restore(trial, checkpoint):
if trial.status == Trial.ERROR:
raise RuntimeError("Restore from checkpoint failed.")
raise RuntimeError(
"Restore from checkpoint failed for Trial {}.".format(
str(trial)))
previous_run = self._find_item(self._paused, trial)
if (prior_status == Trial.PAUSED and previous_run):
@@ -113,7 +115,7 @@ class RayTrialExecutor(TrialExecutor):
_, unfinished = ray.wait(
stop_tasks, num_returns=2, timeout=0.25)
except Exception:
logger.exception("Error stopping runner.")
logger.exception("Error stopping runner for Trial %s", str(trial))
self.set_status(trial, Trial.ERROR)
finally:
trial.runner = None
@@ -133,17 +135,21 @@ class RayTrialExecutor(TrialExecutor):
try:
self._start_trial(trial, checkpoint)
except Exception:
logger.exception("Error starting runner. "
"Trying again without checkpoint.")
logger.exception("Error starting runner for Trial %s", str(trial))
error_msg = traceback.format_exc()
time.sleep(2)
self._stop_trial(trial, error=True, error_msg=error_msg)
try:
# This forces the trial to not start from checkpoint.
trial.clear_checkpoint()
logger.info(
"Trying to start runner for Trial %s without checkpoint.",
str(trial))
self._start_trial(trial)
except Exception:
logger.exception("Error starting runner, aborting!")
logger.exception(
"Error starting runner for Trial %s, aborting!",
str(trial))
error_msg = traceback.format_exc()
self._stop_trial(trial, error=True, error_msg=error_msg)
# note that we don't return the resources, since they may
@@ -159,7 +165,7 @@ class RayTrialExecutor(TrialExecutor):
self._stop_trial(
trial, error=error, error_msg=error_msg, stop_logger=stop_logger)
if prior_status == Trial.RUNNING:
logger.debug("Returning resources for this trial.")
logger.debug("Returning resources for Trial %s.", str(trial))
self._return_resources(trial.resources)
out = self._find_item(self._running, trial)
for result_id in out:
@@ -249,7 +255,7 @@ class RayTrialExecutor(TrialExecutor):
def has_resources(self, resources):
"""Returns whether this runner has at least the specified resources."""
self._update_avail_resources()
cpu_avail = self._avail_resources.cpu - self._committed_resources.cpu
gpu_avail = self._avail_resources.gpu - self._committed_resources.gpu
@@ -312,7 +318,11 @@ class RayTrialExecutor(TrialExecutor):
return trial._checkpoint.value
def restore(self, trial, checkpoint=None):
"""Restores training state from a given model checkpoint."""
"""Restores training state from a given model checkpoint.
This will also sync the trial results to a new location
if restoring on a different node.
"""
if checkpoint is None or checkpoint.value is None:
checkpoint = trial._checkpoint
if checkpoint is None or checkpoint.value is None:
@@ -327,11 +337,12 @@ class RayTrialExecutor(TrialExecutor):
assert type(value) != Checkpoint, type(value)
ray.get(trial.runner.restore_from_object.remote(value))
else:
worker_ip = ray.get(trial.runner.current_ip.remote())
trial.sync_logger_to_new_location(worker_ip)
ray.get(trial.runner.restore.remote(value))
trial.last_result = checkpoint.last_result
return True
except Exception:
logger.exception("Error restoring runner.")
logger.exception("Error restoring runner for Trial %s.", trial)
self.set_status(trial, Trial.ERROR)
return False
+4
View File
@@ -104,6 +104,10 @@ class Trainable(object):
return ""
def current_ip(self):
    """Refresh and return this object's node IP address.

    The address is looked up via ``ray.services.get_node_ip_address()``
    on every call and cached on ``self._local_ip``.
    """
    node_ip = ray.services.get_node_ip_address()
    self._local_ip = node_ip
    return node_ip
def train(self):
"""Runs one logical iteration of training.
+9 -1
View File
@@ -196,8 +196,8 @@ class Trial(object):
self._checkpoint = Checkpoint(
storage=Checkpoint.DISK, value=restore_path)
self.status = Trial.PENDING
self.location = None
self.logdir = None
self.runner = None
self.result_logger = None
self.last_debug = 0
self.trial_id = Trial.generate_id() if trial_id is None else trial_id
@@ -241,6 +241,14 @@ class Trial(object):
custom_loggers=self.custom_loggers,
sync_function=self.sync_function)
def sync_logger_to_new_location(self, worker_ip):
    """Point the result logger at ``worker_ip``.

    Delegates to the result logger's ``sync_results_to_new_location``,
    which pushes the logdir to the worker for cross-node recovery.
    Does nothing when no result logger is set.
    """
    result_logger = self.result_logger
    if not result_logger:
        return
    result_logger.sync_results_to_new_location(worker_ip)
def close_logger(self):
"""Close logger."""