mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 23:08:32 +08:00
[tune] Cross-Node Recovery (#3725)
Augments trial restore to also check if the runner is at the same location. If not, the checkpoint files are pushed onto the new location.
This commit is contained in:
@@ -6,6 +6,7 @@ import distutils.spawn
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
import types
|
||||
|
||||
@@ -21,6 +22,7 @@ from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
from ray.tune.suggest.variant_generator import function as tune_function
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_log_sync_warned = False
|
||||
|
||||
# Map from (logdir, remote_dir) -> syncer
|
||||
_syncers = {}
|
||||
@@ -97,6 +99,8 @@ class _LogSyncer(object):
|
||||
def __init__(self, local_dir, remote_dir=None, sync_function=None):
|
||||
self.local_dir = local_dir
|
||||
self.remote_dir = remote_dir
|
||||
self.logfile = tempfile.NamedTemporaryFile(
|
||||
prefix="log_sync", dir=self.local_dir, suffix=".log", delete=False)
|
||||
|
||||
# Resolve sync_function into template or function
|
||||
self.sync_func = None
|
||||
@@ -115,13 +119,42 @@ class _LogSyncer(object):
|
||||
|
||||
def set_worker_ip(self, worker_ip):
|
||||
"""Set the worker ip to sync logs from."""
|
||||
|
||||
self.worker_ip = worker_ip
|
||||
|
||||
def sync_if_needed(self):
|
||||
if time.time() - self.last_sync_time > 300:
|
||||
self.sync_now()
|
||||
|
||||
def sync_to_worker_if_possible(self):
|
||||
"""Syncs the local logdir on driver to worker if possible.
|
||||
|
||||
Requires ray cluster to be started with the autoscaler. Also requires
|
||||
rsync to be installed.
|
||||
"""
|
||||
if self.worker_ip == self.local_ip:
|
||||
return
|
||||
ssh_key = get_ssh_key()
|
||||
ssh_user = get_ssh_user()
|
||||
global _log_sync_warned
|
||||
if ssh_key is None or ssh_user is None:
|
||||
if not _log_sync_warned:
|
||||
logger.error("Log sync requires cluster to be setup with "
|
||||
"`ray up`.")
|
||||
_log_sync_warned = True
|
||||
return
|
||||
if not distutils.spawn.find_executable("rsync"):
|
||||
logger.error("Log sync requires rsync to be installed.")
|
||||
return
|
||||
source = '{}/'.format(self.local_dir)
|
||||
target = '{}@{}:{}/'.format(ssh_user, self.worker_ip, self.local_dir)
|
||||
final_cmd = (("""rsync -savz -e "ssh -i {} -o ConnectTimeout=120s """
|
||||
"""-o StrictHostKeyChecking=no" {} {}""").format(
|
||||
quote(ssh_key), quote(source), quote(target)))
|
||||
logger.info("Syncing results to %s", str(self.worker_ip))
|
||||
sync_process = subprocess.Popen(
|
||||
final_cmd, shell=True, stdout=self.logfile)
|
||||
sync_process.wait()
|
||||
|
||||
def sync_now(self, force=False):
|
||||
self.last_sync_time = time.time()
|
||||
if not self.worker_ip:
|
||||
@@ -134,9 +167,12 @@ class _LogSyncer(object):
|
||||
else:
|
||||
ssh_key = get_ssh_key()
|
||||
ssh_user = get_ssh_user()
|
||||
global _log_sync_warned
|
||||
if ssh_key is None or ssh_user is None:
|
||||
logger.error("Log sync requires cluster to be setup with "
|
||||
"`ray create_or_update`.")
|
||||
if not _log_sync_warned:
|
||||
logger.error("Log sync requires cluster to be setup with "
|
||||
"`ray up`.")
|
||||
_log_sync_warned = True
|
||||
return
|
||||
if not distutils.spawn.find_executable("rsync"):
|
||||
logger.error("Log sync requires rsync to be installed.")
|
||||
@@ -179,7 +215,8 @@ class _LogSyncer(object):
|
||||
final_cmd += " && "
|
||||
final_cmd += local_to_remote_sync_cmd
|
||||
logger.debug("Running log sync: {}".format(final_cmd))
|
||||
self.sync_process = subprocess.Popen(final_cmd, shell=True)
|
||||
self.sync_process = subprocess.Popen(
|
||||
final_cmd, shell=True, stdout=self.logfile)
|
||||
|
||||
def wait(self):
|
||||
if self.sync_process:
|
||||
|
||||
@@ -88,6 +88,7 @@ class UnifiedLogger(Logger):
|
||||
sync_function=None):
|
||||
self._logger_list = [_JsonLogger, _TFLogger, _VisKitLogger]
|
||||
self._sync_function = sync_function
|
||||
self._log_syncer = None
|
||||
if custom_loggers:
|
||||
assert isinstance(custom_loggers, list), "Improper custom loggers."
|
||||
self._logger_list += custom_loggers
|
||||
@@ -122,6 +123,16 @@ class UnifiedLogger(Logger):
|
||||
self._log_syncer.sync_now(force=True)
|
||||
self._log_syncer.wait()
|
||||
|
||||
def sync_results_to_new_location(self, worker_ip):
|
||||
"""Sends the current log directory to the remote node.
|
||||
|
||||
Syncing will not occur if the cluster is not started
|
||||
with the Ray autoscaler.
|
||||
"""
|
||||
if worker_ip != self._log_syncer.worker_ip:
|
||||
self._log_syncer.set_worker_ip(worker_ip)
|
||||
self._log_syncer.sync_to_worker_if_possible()
|
||||
|
||||
|
||||
class NoopLogger(Logger):
|
||||
def on_result(self, result):
|
||||
|
||||
@@ -71,7 +71,9 @@ class RayTrialExecutor(TrialExecutor):
|
||||
trial.runner = self._setup_runner(trial)
|
||||
if not self.restore(trial, checkpoint):
|
||||
if trial.status == Trial.ERROR:
|
||||
raise RuntimeError("Restore from checkpoint failed.")
|
||||
raise RuntimeError(
|
||||
"Restore from checkpoint failed for Trial {}.".format(
|
||||
str(trial)))
|
||||
|
||||
previous_run = self._find_item(self._paused, trial)
|
||||
if (prior_status == Trial.PAUSED and previous_run):
|
||||
@@ -113,7 +115,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||
_, unfinished = ray.wait(
|
||||
stop_tasks, num_returns=2, timeout=0.25)
|
||||
except Exception:
|
||||
logger.exception("Error stopping runner.")
|
||||
logger.exception("Error stopping runner for Trial %s", str(trial))
|
||||
self.set_status(trial, Trial.ERROR)
|
||||
finally:
|
||||
trial.runner = None
|
||||
@@ -133,17 +135,21 @@ class RayTrialExecutor(TrialExecutor):
|
||||
try:
|
||||
self._start_trial(trial, checkpoint)
|
||||
except Exception:
|
||||
logger.exception("Error starting runner. "
|
||||
"Trying again without checkpoint.")
|
||||
logger.exception("Error starting runner for Trial %s", str(trial))
|
||||
error_msg = traceback.format_exc()
|
||||
time.sleep(2)
|
||||
self._stop_trial(trial, error=True, error_msg=error_msg)
|
||||
try:
|
||||
# This forces the trial to not start from checkpoint.
|
||||
trial.clear_checkpoint()
|
||||
logger.info(
|
||||
"Trying to start runner for Trial %s without checkpoint.",
|
||||
str(trial))
|
||||
self._start_trial(trial)
|
||||
except Exception:
|
||||
logger.exception("Error starting runner, aborting!")
|
||||
logger.exception(
|
||||
"Error starting runner for Trial %s, aborting!",
|
||||
str(trial))
|
||||
error_msg = traceback.format_exc()
|
||||
self._stop_trial(trial, error=True, error_msg=error_msg)
|
||||
# note that we don't return the resources, since they may
|
||||
@@ -159,7 +165,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||
self._stop_trial(
|
||||
trial, error=error, error_msg=error_msg, stop_logger=stop_logger)
|
||||
if prior_status == Trial.RUNNING:
|
||||
logger.debug("Returning resources for this trial.")
|
||||
logger.debug("Returning resources for Trial %s.", str(trial))
|
||||
self._return_resources(trial.resources)
|
||||
out = self._find_item(self._running, trial)
|
||||
for result_id in out:
|
||||
@@ -249,7 +255,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||
|
||||
def has_resources(self, resources):
|
||||
"""Returns whether this runner has at least the specified resources."""
|
||||
|
||||
self._update_avail_resources()
|
||||
cpu_avail = self._avail_resources.cpu - self._committed_resources.cpu
|
||||
gpu_avail = self._avail_resources.gpu - self._committed_resources.gpu
|
||||
|
||||
@@ -312,7 +318,11 @@ class RayTrialExecutor(TrialExecutor):
|
||||
return trial._checkpoint.value
|
||||
|
||||
def restore(self, trial, checkpoint=None):
|
||||
"""Restores training state from a given model checkpoint."""
|
||||
"""Restores training state from a given model checkpoint.
|
||||
|
||||
This will also sync the trial results to a new location
|
||||
if restoring on a different node.
|
||||
"""
|
||||
if checkpoint is None or checkpoint.value is None:
|
||||
checkpoint = trial._checkpoint
|
||||
if checkpoint is None or checkpoint.value is None:
|
||||
@@ -327,11 +337,12 @@ class RayTrialExecutor(TrialExecutor):
|
||||
assert type(value) != Checkpoint, type(value)
|
||||
ray.get(trial.runner.restore_from_object.remote(value))
|
||||
else:
|
||||
worker_ip = ray.get(trial.runner.current_ip.remote())
|
||||
trial.sync_logger_to_new_location(worker_ip)
|
||||
ray.get(trial.runner.restore.remote(value))
|
||||
trial.last_result = checkpoint.last_result
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Error restoring runner.")
|
||||
logger.exception("Error restoring runner for Trial %s.", trial)
|
||||
self.set_status(trial, Trial.ERROR)
|
||||
return False
|
||||
|
||||
@@ -104,6 +104,10 @@ class Trainable(object):
|
||||
|
||||
return ""
|
||||
|
||||
def current_ip(self):
|
||||
self._local_ip = ray.services.get_node_ip_address()
|
||||
return self._local_ip
|
||||
|
||||
def train(self):
|
||||
"""Runs one logical iteration of training.
|
||||
|
||||
|
||||
@@ -196,8 +196,8 @@ class Trial(object):
|
||||
self._checkpoint = Checkpoint(
|
||||
storage=Checkpoint.DISK, value=restore_path)
|
||||
self.status = Trial.PENDING
|
||||
self.location = None
|
||||
self.logdir = None
|
||||
self.runner = None
|
||||
self.result_logger = None
|
||||
self.last_debug = 0
|
||||
self.trial_id = Trial.generate_id() if trial_id is None else trial_id
|
||||
@@ -241,6 +241,14 @@ class Trial(object):
|
||||
custom_loggers=self.custom_loggers,
|
||||
sync_function=self.sync_function)
|
||||
|
||||
def sync_logger_to_new_location(self, worker_ip):
|
||||
"""Updates the logger location.
|
||||
|
||||
Also pushes logdir to worker_ip, allowing for cross-node recovery.
|
||||
"""
|
||||
if self.result_logger:
|
||||
self.result_logger.sync_results_to_new_location(worker_ip)
|
||||
|
||||
def close_logger(self):
|
||||
"""Close logger."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user