diff --git a/doc/source/walkthrough.rst b/doc/source/walkthrough.rst index df0e4c2c0..d2da67f4f 100644 --- a/doc/source/walkthrough.rst +++ b/doc/source/walkthrough.rst @@ -204,7 +204,7 @@ You can also set a timeout to return early from a ``get`` that's blocking for to .. code-block:: python - from ray.exceptions import RayTimeoutException + from ray.exceptions import RayTimeoutError @ray.remote def long_running_function() diff --git a/python/ray/tune/__init__.py b/python/ray/tune/__init__.py index e7b464e78..b8bfea52e 100644 --- a/python/ray/tune/__init__.py +++ b/python/ray/tune/__init__.py @@ -13,8 +13,22 @@ from ray.tune.sample import (function, sample_from, uniform, choice, randint, randn, loguniform) __all__ = [ - "Trainable", "TuneError", "grid_search", "register_env", - "register_trainable", "run", "run_experiments", "Experiment", "function", - "sample_from", "track", "uniform", "choice", "randint", "randn", - "loguniform", "progress_reporter", "ExperimentAnalysis", "Analysis" + "Trainable", + "TuneError", + "grid_search", + "register_env", + "register_trainable", + "run", + "run_experiments", + "Experiment", + "function", + "sample_from", + "track", + "uniform", + "choice", + "randint", + "randn", + "loguniform", + "ExperimentAnalysis", + "Analysis", ] diff --git a/python/ray/tune/checkpoint_manager.py b/python/ray/tune/checkpoint_manager.py new file mode 100644 index 000000000..1ef0f6e20 --- /dev/null +++ b/python/ray/tune/checkpoint_manager.py @@ -0,0 +1,142 @@ +# coding: utf-8 +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import heapq +import logging +import os +import shutil + +try: + FileNotFoundError +except NameError: + FileNotFoundError = IOError + +logger = logging.getLogger(__name__) + + +class Checkpoint(object): + """Describes a checkpoint of trial state. + + Checkpoint may be saved in different storage. + + Attributes: + storage (str): Storage type. 
+ value (str): If storage==MEMORY, value is a Python object. + If storage==DISK, value is the path to the checkpoint on disk. + """ + + MEMORY = "memory" + DISK = "disk" + + def __init__(self, storage, value, result=None): + self.storage = storage + self.value = value + self.result = result or {} + + def delete(self): + """Deletes checkpoint data if disk checkpoint.""" + if self.storage == Checkpoint.DISK and self.value: + checkpoint_dir = self.value + if not os.path.exists(checkpoint_dir): + raise FileNotFoundError( + "Attempted to delete checkpoint at {} but " + "path was not found.".format(checkpoint_dir)) + elif os.path.isfile(checkpoint_dir): + shutil.rmtree(os.path.dirname(checkpoint_dir)) + else: + shutil.rmtree(checkpoint_dir) + + @staticmethod + def from_object(value=None): + """Creates a checkpoint from a Python object.""" + return Checkpoint(Checkpoint.MEMORY, value) + + +class QueueItem(object): + def __init__(self, priority, value): + self.priority = priority + self.value = value + + def __cmp__(self, other): + # For python2.7 compatibility. + if self.priority == other.priority: + return 0 + return -1 if self.priority < other.priority else 1 + + def __lt__(self, other): + return self.priority < other.priority + + +class CheckpointManager(object): + """Manages checkpoints on the driver for a trial.""" + + def __init__(self, keep_checkpoints_num, checkpoint_score_attr): + """Initializes a new CheckpointManager. + + Args: + keep_checkpoints_num (int): Number of best checkpoints to keep. + checkpoint_score_attr (str): Attribute to use to determine which + checkpoints to keep. 
+ """ + self.keep_checkpoints_num = keep_checkpoints_num or float("inf") + assert self.keep_checkpoints_num > 0, ( + "keep_checkpoints_num must be greater than 0.") + self._checkpoint_score_desc = checkpoint_score_attr.startswith("min-") + if self._checkpoint_score_desc: + self._checkpoint_score_attr = checkpoint_score_attr[4:] + else: + self._checkpoint_score_attr = checkpoint_score_attr + + self.newest_checkpoint = Checkpoint(Checkpoint.MEMORY, None) + self._best_checkpoints = [] + self._membership = set() + + def on_checkpoint(self, checkpoint): + """Starts tracking checkpoint metadata on checkpoint. + + Sets newest checkpoint. Deletes previous checkpoint as long as it isn't + one of the best ones. Also deletes the worst checkpoint if at capacity. + + Args: + checkpoint (Checkpoint): Trial state checkpoint. + + Note: + Logs an error if checkpoint_score_attr is not in the result. + """ + old_checkpoint = self.newest_checkpoint + self.newest_checkpoint = checkpoint + + try: + queue_item = QueueItem(self._priority(checkpoint), checkpoint) + except KeyError: + if old_checkpoint not in self._membership: + old_checkpoint.delete() + logger.error("Result dict has no key: {}. " + "checkpoint_score_attr must be set to a key in the " + "result dict.".format(self._checkpoint_score_attr)) + return + + if len(self._best_checkpoints) < self.keep_checkpoints_num: + heapq.heappush(self._best_checkpoints, queue_item) + self._membership.add(checkpoint) + elif queue_item.priority > self._best_checkpoints[0].priority: + worst = heapq.heappushpop(self._best_checkpoints, queue_item).value + self._membership.add(checkpoint) + self._membership.discard(worst) + if worst is not old_checkpoint: + worst.delete() + + # Remove the old checkpoint if it isn't one of the best ones. 
+ if old_checkpoint not in self._membership: + old_checkpoint.delete() + + def best_checkpoints(self): + """Returns best checkpoints, sorted by score.""" + checkpoints = sorted(self._best_checkpoints, key=lambda c: c.priority) + return [queue_item.value for queue_item in checkpoints] + + def _priority(self, checkpoint): + priority = checkpoint.result[self._checkpoint_score_attr] + return -priority if self._checkpoint_score_desc else priority diff --git a/python/ray/tune/config_parser.py b/python/ray/tune/config_parser.py index 5a0107714..b4a797385 100644 --- a/python/ray/tune/config_parser.py +++ b/python/ray/tune/config_parser.py @@ -76,11 +76,19 @@ def make_parser(parser_creator=None, **kwargs): action="store_true", help="Whether to checkpoint at the end of the experiment. " "Default is False.") + parser.add_argument( + "--no-sync-on-checkpoint", + action="store_true", + help="Disable sync-down of trial checkpoint, which is enabled by " + "default to guarantee recoverability. If set, checkpoint syncing from " + "worker to driver is asynchronous. Set this only if synchronous " + "checkpointing is too slow and trial restoration failures can be " + "tolerated") parser.add_argument( "--keep-checkpoints-num", default=None, type=int, - help="Number of last checkpoints to keep. Others get " + help="Number of best checkpoints to keep. Others get " "deleted. 
Default (None) keeps all checkpoints.") parser.add_argument( "--checkpoint-score-attr", @@ -177,6 +185,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs): stopping_criterion=spec.get("stop", {}), checkpoint_freq=args.checkpoint_freq, checkpoint_at_end=args.checkpoint_at_end, + sync_on_checkpoint=not args.no_sync_on_checkpoint, keep_checkpoints_num=args.keep_checkpoints_num, checkpoint_score_attr=args.checkpoint_score_attr, export_formats=spec.get("export_formats", []), diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py index 480aade4b..843fc092e 100644 --- a/python/ray/tune/experiment.py +++ b/python/ray/tune/experiment.py @@ -72,6 +72,7 @@ class Experiment(object): sync_to_driver=None, checkpoint_freq=0, checkpoint_at_end=False, + sync_on_checkpoint=True, keep_checkpoints_num=None, checkpoint_score_attr=None, export_formats=None, @@ -80,6 +81,11 @@ class Experiment(object): repeat=None, trial_resources=None, sync_function=None): + """Initialize a new Experiment. + + The args here take the same meaning as the command line flags defined + in `tune.py:run`. 
+ """ if repeat: _raise_deprecation_note("repeat", "num_samples", soft=False) if trial_resources: @@ -102,7 +108,7 @@ class Experiment(object): "criteria must take exactly 2 parameters.".format(stop)) config = config or {} - self._run_identifier = Experiment._register_if_needed(run) + self._run_identifier = Experiment.register_if_needed(run) spec = { "run": self._run_identifier, "stop": stop, @@ -117,6 +123,7 @@ class Experiment(object): "sync_to_driver": sync_to_driver, "checkpoint_freq": checkpoint_freq, "checkpoint_at_end": checkpoint_at_end, + "sync_on_checkpoint": sync_on_checkpoint, "keep_checkpoints_num": keep_checkpoints_num, "checkpoint_score_attr": checkpoint_score_attr, "export_formats": export_formats or [], @@ -156,7 +163,7 @@ class Experiment(object): return exp @classmethod - def _register_if_needed(cls, run_object): + def register_if_needed(cls, run_object): """Registers Trainable or Function at runtime. Assumes already registered if run_object is a string. diff --git a/python/ray/tune/log_sync.py b/python/ray/tune/log_sync.py index 8d8433dfc..a7d2be140 100644 --- a/python/ray/tune/log_sync.py +++ b/python/ray/tune/log_sync.py @@ -17,12 +17,17 @@ logger = logging.getLogger(__name__) _log_sync_warned = False -def log_sync_template(): +def log_sync_template(options=""): """Syncs the local_dir between driver and worker if possible. Requires ray cluster to be started with the autoscaler. Also requires rsync to be installed. + Args: + options (str): Additional rsync options. + + Returns: + Sync template with source and target parameters. 
""" if not distutils.spawn.find_executable("rsync"): logger.error("Log sync requires rsync to be installed.") @@ -36,12 +41,14 @@ def log_sync_template(): _log_sync_warned = True return - return ("""rsync -savz -e "ssh -i {ssh_key} -o ConnectTimeout=120s """ - """-o StrictHostKeyChecking=no" {{source}} {{target}}""" - ).format(ssh_key=quote(ssh_key)) + rsh = "ssh -i {ssh_key} -o ConnectTimeout=120s -o StrictHostKeyChecking=no" + rsh = rsh.format(ssh_key=quote(ssh_key)) + template = """rsync {options} -savz -e "{rsh}" {{source}} {{target}}""" + return template.format(options=options, rsh=rsh) class NodeSyncMixin(object): + # TODO(ujvl): Refactor this code. """Mixin for syncing files to/from a remote dir to a local dir.""" def __init__(self): @@ -53,23 +60,43 @@ class NodeSyncMixin(object): """Set the worker ip to sync logs from.""" self.worker_ip = worker_ip - def _check_valid_worker_ip(self): + def has_remote_target(self): + """Returns whether the Syncer has a remote target.""" if not self.worker_ip: - logger.debug("Worker ip unknown, skipping log sync for {}".format( - self._local_dir)) + logger.debug("Worker IP unknown, skipping log sync for %s", + self._local_dir) return False if self.worker_ip == self.local_ip: - logger.debug( - "Worker ip is local ip, skipping log sync for {}".format( - self._local_dir)) + logger.debug("Worker IP is local IP, skipping log sync for %s", + self._local_dir) return False return True + def sync_up_if_needed(self): + if not self.has_remote_target(): + return True + super(NodeSyncMixin, self).sync_up() + + def sync_down_if_needed(self): + if not self.has_remote_target(): + return True + super(NodeSyncMixin, self).sync_down() + + def sync_down(self): + if not self.has_remote_target(): + return True + return super(NodeSyncMixin, self).sync_down() + + def sync_up(self): + if not self.has_remote_target(): + return True + return super(NodeSyncMixin, self).sync_up() + @property def _remote_path(self): ssh_user = get_ssh_user() global 
_log_sync_warned - if not self._check_valid_worker_ip(): + if not self.has_remote_target(): return if ssh_user is None: if not _log_sync_warned: diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index aa70cf76e..7b61e6f2e 100644 --- a/python/ray/tune/logger.py +++ b/python/ray/tune/logger.py @@ -409,8 +409,8 @@ class UnifiedLogger(Logger): try: self._loggers.append(cls(self.config, self.logdir, self.trial)) except Exception as exc: - logger.warning("Could not instantiate {}: {}.".format( - cls.__name__, str(exc))) + logger.warning("Could not instantiate %s: %s.", cls.__name__, + str(exc)) self._log_syncer = get_log_syncer( self.logdir, remote_dir=self.logdir, @@ -429,12 +429,21 @@ class UnifiedLogger(Logger): def close(self): for _logger in self._loggers: _logger.close() - self._log_syncer.sync_down() def flush(self): for _logger in self._loggers: _logger.flush() - self._log_syncer.sync_down() + if not self._log_syncer.sync_down(): + logger.warning("Trial %s: Post-flush sync skipped.", self.trial) + + def sync_up(self): + return self._log_syncer.sync_up() + + def sync_down(self): + return self._log_syncer.sync_down() + + def wait(self): + self._log_syncer.wait() def sync_results_to_new_location(self, worker_ip): """Sends the current log directory to the remote node. @@ -443,13 +452,19 @@ class UnifiedLogger(Logger): with the Ray autoscaler. """ if worker_ip != self._log_syncer.worker_ip: - logger.info("Syncing (blocking) results to {}".format(worker_ip)) + logger.info("Trial %s: Syncing (blocking) results to %s", + self.trial, worker_ip) self._log_syncer.reset() self._log_syncer.set_worker_ip(worker_ip) - self._log_syncer.sync_up() - # TODO: change this because this is blocking. But failures - # are rare, so maybe this is OK? + if not self._log_syncer.sync_up(): + logger.error( + "Trial %s: Sync up to new location skipped. 
" + "This should not occur.", self.trial) self._log_syncer.wait() + else: + logger.error( + "Trial %s: Sync attempted to same IP %s. This " + "should not occur.", self.trial, worker_ip) class _SafeFallbackEncoder(json.JSONEncoder): diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index 916fa1cbf..4383cc784 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -4,18 +4,18 @@ from __future__ import division from __future__ import print_function import logging -import math import os import random import time import traceback import ray +from ray.exceptions import RayTimeoutError from ray import ray_constants from ray.resource_spec import ResourceSpec from ray.tune.error import AbortTrialExecution from ray.tune.logger import NoopLogger -from ray.tune.trial import Trial, Checkpoint +from ray.tune.trial import Trial, Checkpoint, Location from ray.tune.resources import Resources from ray.tune.trial_executor import TrialExecutor from ray.tune.util import warn_if_slow @@ -25,6 +25,7 @@ logger = logging.getLogger(__name__) RESOURCE_REFRESH_PERIOD = 0.5 # Refresh resources every 500 ms BOTTLENECK_WARN_PERIOD_S = 60 NONTRIVIAL_WAIT_TIME_THRESHOLD_S = 1e-3 +DEFAULT_GET_TIMEOUT = 30.0 # seconds class _LocalWrapper(object): @@ -37,7 +38,7 @@ class _LocalWrapper(object): class RayTrialExecutor(TrialExecutor): - """An implemention of TrialExecutor based on Ray.""" + """An implementation of TrialExecutor based on Ray.""" def __init__(self, queue_trials=False, @@ -71,36 +72,18 @@ class RayTrialExecutor(TrialExecutor): if ray.is_initialized(): self._update_avail_resources() - def _setup_runner(self, trial, reuse_allowed): + def _setup_remote_runner(self, trial, reuse_allowed): + trial.init_logger() + # We checkpoint metadata here to try mitigating logdir duplication + self.try_checkpoint_metadata(trial) + remote_logdir = trial.logdir + if (self._reuse_actors and reuse_allowed and self._cached_actor 
is not None): logger.debug("Reusing cached runner {} for {}".format( self._cached_actor, trial.trial_id)) existing_runner = self._cached_actor self._cached_actor = None - else: - if self._cached_actor: - logger.debug( - "Cannot reuse cached runner {} for new trial".format( - self._cached_actor)) - self._cached_actor.stop.remote() - self._cached_actor.__ray_terminate__.remote() - self._cached_actor = None - existing_runner = None - cls = ray.remote( - num_cpus=trial.resources.cpu, - num_gpus=trial.resources.gpu, - memory=trial.resources.memory, - object_store_memory=trial.resources.object_store_memory, - resources=trial.resources.custom_resources)( - trial.get_trainable_cls()) - - trial.init_logger() - # We checkpoint metadata here to try mitigating logdir duplication - self.try_checkpoint_metadata(trial) - remote_logdir = trial.logdir - - if existing_runner: trial.runner = existing_runner if not self.reset_trial(trial, trial.config, trial.experiment_tag): raise AbortTrialExecution( @@ -108,6 +91,21 @@ class RayTrialExecutor(TrialExecutor): "implemented and return True.") return existing_runner + if self._cached_actor: + logger.debug("Cannot reuse cached runner {} for new trial".format( + self._cached_actor)) + self._cached_actor.stop.remote() + self._cached_actor.__ray_terminate__.remote() + self._cached_actor = None + + cls = ray.remote( + num_cpus=trial.resources.cpu, + num_gpus=trial.resources.gpu, + memory=trial.resources.memory, + object_store_memory=trial.resources.object_store_memory, + resources=trial.resources.custom_resources)( + trial.get_trainable_cls()) + def logger_creator(config): # Set the working dir in the remote process, for user file writes if not os.path.exists(remote_logdir): @@ -116,6 +114,10 @@ class RayTrialExecutor(TrialExecutor): os.chdir(remote_logdir) return NoopLogger(config, remote_logdir) + # Clear the Trial's location (to be updated later on result) + # since we don't know where the remote runner is placed. 
+ trial.set_location(Location()) + logger.info("Trial %s: Setting up new remote runner.", trial) # Logging for trials is handled centrally by TrialRunner, so # configure the remote runner to use a noop-logger. return cls.remote(config=trial.config, logger_creator=logger_creator) @@ -136,22 +138,20 @@ class RayTrialExecutor(TrialExecutor): """Starts trial and restores last result if trial was paused. Raises: - ValueError if restoring from checkpoint fails. + RuntimeError if restoring from checkpoint fails. """ prior_status = trial.status self.set_status(trial, Trial.RUNNING) - trial.runner = self._setup_runner( + trial.runner = self._setup_remote_runner( trial, - reuse_allowed=checkpoint is not None - or trial._checkpoint.value is not None) + reuse_allowed=checkpoint is not None or trial.has_checkpoint()) if not self.restore(trial, checkpoint): if trial.status == Trial.ERROR: raise RuntimeError( - "Restore from checkpoint failed for Trial {}.".format( - str(trial))) + "Trial {}: Restore from checkpoint failed.".format(trial)) previous_run = self._find_item(self._paused, trial) - if (prior_status == Trial.PAUSED and previous_run): + if prior_status == Trial.PAUSED and previous_run: # If Trial was in flight when paused, self._paused stores result. 
self._paused.pop(previous_run[0]) self._running[previous_run[0]] = trial @@ -175,10 +175,8 @@ class RayTrialExecutor(TrialExecutor): if stop_logger: trial.close_logger() - if error: - self.set_status(trial, Trial.ERROR) - else: - self.set_status(trial, Trial.TERMINATED) + self.set_status(trial, Trial.ERROR if error else Trial.TERMINATED) + trial.set_location(Location()) try: trial.write_error_log(error_msg) @@ -188,12 +186,11 @@ class RayTrialExecutor(TrialExecutor): logger.debug("Reusing actor for {}".format(trial.runner)) self._cached_actor = trial.runner else: - logger.debug( - "Destroying actor for trial {}.".format(trial)) + logger.debug("Trial %s: Destroying actor.", trial) trial.runner.stop.remote() trial.runner.__ray_terminate__.remote() except Exception: - logger.exception("Error stopping runner for Trial %s", str(trial)) + logger.exception("Trial %s: Error stopping runner.", trial) self.set_status(trial, Trial.ERROR) finally: trial.runner = None @@ -208,32 +205,35 @@ class RayTrialExecutor(TrialExecutor): checkpoint (Checkpoint): A Python object or path storing the state of trial. """ - self._commit_resources(trial.resources) try: self._start_trial(trial, checkpoint) - except Exception as e: - logger.exception("Error starting runner for Trial %s", str(trial)) - error_msg = traceback.format_exc() + except AbortTrialExecution: + logger.exception("Trial %s: Error starting runner, aborting!", + trial) time.sleep(2) + error_msg = traceback.format_exc() + self._stop_trial(trial, error=True, error_msg=error_msg) + return # don't retry fatal Tune errors + except Exception: + logger.exception( + "Trial %s: Error starting runner. Attempting " + "restart without checkpoint.", trial) + time.sleep(2) + error_msg = traceback.format_exc() self._stop_trial(trial, error=True, error_msg=error_msg) - if isinstance(e, AbortTrialExecution): - return # don't retry fatal Tune errors try: # This forces the trial to not start from checkpoint. 
trial.clear_checkpoint() - logger.info( - "Trying to start runner for Trial %s without checkpoint.", - str(trial)) self._start_trial(trial) except Exception: logger.exception( - "Error starting runner for Trial %s, aborting!", - str(trial)) + "Trial %s: Error starting runner on second " + "attempt, aborting!", trial) error_msg = traceback.format_exc() self._stop_trial(trial, error=True, error_msg=error_msg) - # note that we don't return the resources, since they may - # have been lost + # Note that we don't return the resources, since they may + # have been lost. TODO(ujvl): is this the right thing to do? def _find_item(self, dictionary, item): out = [rid for rid, t in dictionary.items() if t is item] @@ -245,7 +245,7 @@ class RayTrialExecutor(TrialExecutor): self._stop_trial( trial, error=error, error_msg=error_msg, stop_logger=stop_logger) if prior_status == Trial.RUNNING: - logger.debug("Returning resources for Trial %s.", str(trial)) + logger.debug("Trial %s: Returning resources.", trial) self._return_resources(trial.resources) out = self._find_item(self._running, trial) for result_id in out: @@ -253,7 +253,6 @@ class RayTrialExecutor(TrialExecutor): def continue_training(self, trial): """Continues the training of this trial.""" - self._train(trial) def pause_trial(self, trial): @@ -262,7 +261,6 @@ class RayTrialExecutor(TrialExecutor): If trial is in-flight, preserves return value in separate queue before pausing, which is restored when Trial is resumed. 
""" - trial_future = self._find_item(self._running, trial) if trial_future: self._paused[trial_future[0]] = trial @@ -285,7 +283,13 @@ class RayTrialExecutor(TrialExecutor): trial.config = new_config trainable = trial.runner with warn_if_slow("reset_config"): - reset_val = ray.get(trainable.reset_config.remote(new_config)) + try: + reset_val = ray.get( + trainable.reset_config.remote(new_config), + DEFAULT_GET_TIMEOUT) + except RayTimeoutError: + logger.exception("Trial %s: reset_config timed out.") + return False return reset_val def get_running_trials(self): @@ -351,7 +355,7 @@ class RayTrialExecutor(TrialExecutor): raise ValueError("Trial was not running.") self._running.pop(trial_future[0]) with warn_if_slow("fetch_result"): - result = ray.get(trial_future[0]) + result = ray.get(trial_future[0], DEFAULT_GET_TIMEOUT) # For local mode if isinstance(result, _LocalWrapper): @@ -530,48 +534,28 @@ class RayTrialExecutor(TrialExecutor): def save(self, trial, storage=Checkpoint.DISK): """Saves the trial's state to a checkpoint.""" - trial._checkpoint.storage = storage - trial._checkpoint.last_result = trial.last_result if storage == Checkpoint.MEMORY: - trial._checkpoint.value = trial.runner.save_to_object.remote() + value = trial.runner.save_to_object.remote() + checkpoint = Checkpoint(storage, value, trial.last_result) else: - # Keeps only highest performing checkpoints if enabled - if trial.keep_checkpoints_num: - try: - last_attr_val = trial.last_result[ - trial.checkpoint_score_attr] - if (trial.compare_checkpoints(last_attr_val) - and not math.isnan(last_attr_val)): - trial.best_checkpoint_attr_value = last_attr_val - self._checkpoint_and_erase(trial) - except KeyError: - logger.warning( - "Result dict has no key: {}. 
keep" - "_checkpoints_num flag will not work".format( - trial.checkpoint_score_attr)) - else: - with warn_if_slow("save_to_disk"): - trial._checkpoint.value = ray.get( - trial.runner.save.remote()) + with warn_if_slow("save_checkpoint_to_disk"): + value = ray.get(trial.runner.save.remote()) + checkpoint = Checkpoint(storage, value, trial.last_result) - return trial._checkpoint.value - - def _checkpoint_and_erase(self, trial): - """Checkpoints the model and erases old checkpoints - if needed. - Parameters - ---------- - trial : trial to save - """ - - with warn_if_slow("save_to_disk"): - trial._checkpoint.value = ray.get(trial.runner.save.remote()) - - if len(trial.history) >= trial.keep_checkpoints_num: - ray.get(trial.runner.delete_checkpoint.remote(trial.history[-1])) - trial.history.pop() - - trial.history.insert(0, trial._checkpoint.value) + with warn_if_slow("on_checkpoint", DEFAULT_GET_TIMEOUT) as profile: + try: + trial.on_checkpoint(checkpoint) + except Exception: + logger.exception("Trial %s: Error handling checkpoint %s", + trial, checkpoint.value) + return None + if profile.too_slow and trial.sync_on_checkpoint: + logger.warning( + "Consider turning off forced head-worker trial checkpoint " + "syncs by setting sync_on_checkpoint=False. Note that this " + "might result in faulty trial restoration for some worker " + "failure modes.") + return checkpoint.value def restore(self, trial, checkpoint=None): """Restores training state from a given model checkpoint. @@ -580,11 +564,13 @@ class RayTrialExecutor(TrialExecutor): if restoring on a different node. """ if checkpoint is None or checkpoint.value is None: - checkpoint = trial._checkpoint + checkpoint = trial.checkpoint if checkpoint is None or checkpoint.value is None: return True if trial.runner is None: - logger.error("Unable to restore - no runner.") + logger.error( + "Trial %s: Unable to restore - no runner. 
" + "Setting status to ERROR.", trial) self.set_status(trial, Trial.ERROR) return False try: @@ -593,23 +579,31 @@ class RayTrialExecutor(TrialExecutor): assert type(value) != Checkpoint, type(value) trial.runner.restore_from_object.remote(value) else: - # TODO: Somehow, the call to get the current IP on the - # remote actor can be very slow - a better fix would - # be to use an actor table to detect the IP of the Trainable - # and rsync the files there. - # See https://github.com/ray-project/ray/issues/5168 + logger.info("Trial %s: Attempting restoration from %s", trial, + checkpoint.value) with warn_if_slow("get_current_ip"): - worker_ip = ray.get(trial.runner.current_ip.remote()) + worker_ip = ray.get(trial.runner.current_ip.remote(), + DEFAULT_GET_TIMEOUT) with warn_if_slow("sync_to_new_location"): trial.sync_logger_to_new_location(worker_ip) with warn_if_slow("restore_from_disk"): - ray.get(trial.runner.restore.remote(value)) - trial.last_result = checkpoint.last_result - return True - except Exception: - logger.exception("Error restoring runner for Trial %s.", trial) + ray.get( + trial.runner.restore.remote(value), + DEFAULT_GET_TIMEOUT) + except RayTimeoutError: + logger.exception( + "Trial %s: Unable to restore - runner task timed " + "out. Setting status to ERROR", trial) self.set_status(trial, Trial.ERROR) return False + except Exception: + logger.exception( + "Trial %s: Unable to restore. Setting status to ERROR", trial) + self.set_status(trial, Trial.ERROR) + return False + + trial.last_result = checkpoint.result + return True def export_trial_if_needed(self, trial): """Exports model of this trial based on trial.export_formats. 
@@ -619,7 +613,8 @@ class RayTrialExecutor(TrialExecutor): """ if trial.export_formats and len(trial.export_formats) > 0: return ray.get( - trial.runner.export_model.remote(trial.export_formats)) + trial.runner.export_model.remote(trial.export_formats), + DEFAULT_GET_TIMEOUT) return {} def has_gpus(self): diff --git a/python/ray/tune/syncer.py b/python/ray/tune/syncer.py index ba54163e8..fde57b095 100644 --- a/python/ray/tune/syncer.py +++ b/python/ray/tune/syncer.py @@ -28,26 +28,142 @@ SYNC_PERIOD = 300 _syncers = {} -def validate_sync_string(sync_string): - if "{source}" not in sync_string: - raise ValueError("Sync template missing '{source}'.") - if "{target}" not in sync_string: - raise ValueError("Sync template missing '{target}'.") - - def wait_for_sync(): for syncer in _syncers.values(): syncer.wait() -class BaseSyncer(object): - def __init__(self, local_dir, remote_dir, sync_function=None): +class SyncClient(object): + def sync_up(self, source, target): + """Sync up from source to target. + + Args: + source (str): Source path. + target (str): Target path. + + Returns: + True if sync initiation successful, False otherwise. + """ + raise NotImplementedError + + def sync_down(self, source, target): + """Sync down from source to target. + + Args: + source (str): Source path. + target (str): Target path. + + Returns: + True if sync initiation successful, False otherwise. 
+ """ + raise NotImplementedError + + def wait(self): + """Wait for current sync to complete, if asynchronously started.""" + pass + + def reset(self): + """Resets state.""" + pass + + +class FunctionBasedClient(SyncClient): + def __init__(self, sync_up_func, sync_down_func): + self.sync_up_func = sync_up_func + self.sync_down_func = sync_down_func + + def sync_up(self, source, target): + self.sync_up_func(source, target) + return True + + def sync_down(self, source, target): + self.sync_down_func(source, target) + return True + + +class CommandBasedClient(SyncClient): + def __init__(self, sync_up_template, sync_down_template): + """Syncs between two directories with the given command. + + Arguments: + sync_up_template (str): A runnable string template; needs to + include replacement fields '{source}' and '{target}'. + sync_down_template (str): A runnable string template; needs to + include replacement fields '{source}' and '{target}'. + """ + if not isinstance(sync_up_template, str): + raise ValueError("{} is not a string.".format(sync_up_template)) + if not isinstance(sync_down_template, str): + raise ValueError("{} is not a string.".format(sync_down_template)) + self._validate_sync_string(sync_up_template) + self._validate_sync_string(sync_down_template) + self.sync_up_template = sync_up_template + self.sync_down_template = sync_down_template + self.logfile = None + self.sync_process = None + + def set_logdir(self, logdir): + """Sets the directory to log sync execution output in. + + Args: + logdir: Log directory. 
+ """ + self.logfile = tempfile.NamedTemporaryFile( + prefix="log_sync", dir=logdir, suffix=".log", delete=False) + + def sync_up(self, source, target): + return self.execute(self.sync_up_template, source, target) + + def sync_down(self, source, target): + return self.execute(self.sync_down_template, source, target) + + def execute(self, sync_template, source, target): + """Executes sync_template on source and target.""" + if self.sync_process: + self.sync_process.poll() + if self.sync_process.returncode is None: + logger.warning("Last sync is still in progress, skipping.") + return False + final_cmd = sync_template.format( + source=quote(source), target=quote(target)) + logger.debug("Running sync: {}".format(final_cmd)) + self.sync_process = subprocess.Popen( + final_cmd, shell=True, stderr=subprocess.PIPE, stdout=self.logfile) + return True + + def wait(self): + if self.sync_process: + _, error_msg = self.sync_process.communicate() + error_msg = error_msg.decode("ascii") + code = self.sync_process.returncode + self.sync_process = None + if code != 0: + raise TuneError("Sync error ({}): {}".format(code, error_msg)) + + def reset(self): + if self.sync_process: + logger.warning("Sync process still running but resetting anyways.") + self.sync_process = None + + @staticmethod + def _validate_sync_string(sync_string): + if "{source}" not in sync_string: + raise ValueError("Sync template missing '{source}'.") + if "{target}" not in sync_string: + raise ValueError("Sync template missing '{target}'.") + + +NOOP = FunctionBasedClient(lambda s, t: None, lambda s, t: None) + + +class Syncer(object): + def __init__(self, local_dir, remote_dir, sync_client=NOOP): """Syncs between two directories with the sync_function. Arguments: local_dir (str): Directory to sync. Uniquely identifies the syncer. remote_dir (str): Remote directory to sync with. 
- sync_function (func): Function for syncing the local_dir to + sync_client (SyncClient): Client for syncing between local_dir and remote_dir. Defaults to a Noop. """ self._local_dir = (os.path.join(local_dir, "") @@ -55,32 +171,7 @@ class BaseSyncer(object): self._remote_dir = remote_dir self.last_sync_up_time = float("-inf") self.last_sync_down_time = float("-inf") - self._sync_function = sync_function or (lambda source, target: None) - - def sync_function(self, source, target): - """Executes sync between source and target. - - Can be overwritten by subclasses for custom sync procedures. - - Args: - source: Path to source file(s). - target: Path to target file(s). - """ - if self._sync_function: - return self._sync_function(source, target) - - def sync(self, source, target): - if not (source and target): - logger.debug( - "Source or target is empty, skipping log sync for {}".format( - self._local_dir)) - return - - try: - self.sync_function(source, target) - return True - except Exception: - logger.exception("Sync function failed.") + self.sync_client = sync_client def sync_up_if_needed(self): if time.time() - self.last_sync_up_time > SYNC_PERIOD: @@ -90,120 +181,86 @@ class BaseSyncer(object): if time.time() - self.last_sync_down_time > SYNC_PERIOD: self.sync_down() - def sync_down(self, *args, **kwargs): - self.sync(self._remote_path, self._local_dir, *args, **kwargs) - self.last_sync_down_time = time.time() + def sync_up(self): + """Attempts to start the sync-up to the remote path. - def sync_up(self, *args, **kwargs): - self.sync(self._local_dir, self._remote_path, *args, **kwargs) - self.last_sync_up_time = time.time() + Returns: + Whether the sync (if feasible) was successfully started. 
+ """ + result = False + if self.validate_hosts(self._local_dir, self._remote_path): + try: + result = self.sync_client.sync_up(self._local_dir, + self._remote_path) + self.last_sync_up_time = time.time() + except Exception: + logger.exception("Sync execution failed.") + return result + + def sync_down(self): + """Attempts to start the sync-down from the remote path. + + Returns: + Whether the sync (if feasible) was successfully started. + """ + result = False + if self.validate_hosts(self._local_dir, self._remote_path): + try: + result = self.sync_client.sync_down(self._remote_path, + self._local_dir) + self.last_sync_down_time = time.time() + except Exception: + logger.exception("Sync execution failed.") + return result + + def validate_hosts(self, source, target): + if not (source and target): + logger.debug("Source or target is empty, skipping log sync for " + "{}".format(self._local_dir)) + return False + return True + + def wait(self): + """Waits for the sync client to complete the current sync.""" + self.sync_client.wait() def reset(self): self.last_sync_up_time = float("-inf") self.last_sync_down_time = float("-inf") - - def wait(self): - pass + self.sync_client.reset() @property def _remote_path(self): - """Protected method for accessing remote_dir. - - Can be overridden in subclass for custom path. - """ return self._remote_dir -class CommandSyncer(BaseSyncer): - def __init__(self, local_dir, remote_dir, sync_template): - """Syncs between two directories with the given command. - - Arguments: - local_dir (str): Directory to sync. - remote_dir (str): Remote directory to sync with. - sync_template (str): A string template - for syncer to run and needs to include replacement fields - '{source}' and '{target}'. Returned when using - `CommandSyncer.sync_template`, which can be overridden - by subclass. 
- """ - super(CommandSyncer, self).__init__(local_dir, remote_dir) - if not isinstance(sync_template, str): - raise ValueError("{} is not a string.".format(sync_template)) - validate_sync_string(sync_template) - self._sync_template = sync_template - self.logfile = tempfile.NamedTemporaryFile( - prefix="log_sync", - dir=self._local_dir, - suffix=".log", - delete=False) - - self.sync_process = None - - def sync_function(self, source, target): - self.last_sync_time = time.time() - if self.sync_process: - self.sync_process.poll() - if self.sync_process.returncode is None: - logger.warning("Last sync is still in progress, skipping.") - return - final_cmd = self._sync_template.format( - source=quote(source), target=quote(target)) - logger.debug("Running sync: {}".format(final_cmd)) - self.sync_process = subprocess.Popen( - final_cmd, shell=True, stdout=self.logfile) - return True - - def reset(self): - if self.sync_process: - logger.warning("Sync process still running but resetting anyways.") - self.sync_process = None - super(CommandSyncer, self).reset() - - def wait(self): - if self.sync_process: - self.sync_process.wait() - - -def _get_sync_cls(sync_function): - if not sync_function: - return - if isinstance(sync_function, types.FunctionType): - return BaseSyncer - elif isinstance(sync_function, str): - return CommandSyncer - else: - raise ValueError("Sync function {} must be string or function".format( - sync_function)) - - -def get_syncer(local_dir, remote_dir=None, sync_function=None): - """Returns a Syncer depending on given args. +def get_cloud_syncer(local_dir, remote_dir=None, sync_function=None): + """Returns a Syncer. This syncer is in charge of syncing the local_dir with upload_dir. Args: - local_dir: Source directory for syncing. - remote_dir: Target directory for syncing. If None, - returns BaseSyncer with a noop. + local_dir (str): Source directory for syncing. + remote_dir (str): Target directory for syncing. 
If not provided, a + no-op Syncer is returned. sync_function (func | str): Function for syncing the local_dir to remote_dir. If string, then it must be a string template for syncer to run. If not provided, it defaults to standard S3 or gsutil sync commands. - """ + """ key = (local_dir, remote_dir) if key in _syncers: return _syncers[key] if not remote_dir: - _syncers[key] = BaseSyncer(local_dir, remote_dir) + _syncers[key] = Syncer(local_dir, remote_dir, NOOP) return _syncers[key] - sync_cls = _get_sync_cls(sync_function) + client = _get_sync_client(sync_function) - if sync_cls: - _syncers[key] = sync_cls(local_dir, remote_dir, sync_function) + if client: + _syncers[key] = Syncer(local_dir, remote_dir, client) return _syncers[key] if remote_dir.startswith(S3_PREFIX): @@ -211,15 +268,17 @@ def get_syncer(local_dir, remote_dir=None, sync_function=None): raise TuneError( "Upload uri starting with '{}' requires awscli tool" " to be installed".format(S3_PREFIX)) - _syncers[key] = CommandSyncer(local_dir, remote_dir, - "aws s3 sync {source} {target}") + template = "aws s3 sync {source} {target}" + s3_client = CommandBasedClient(template, template) + _syncers[key] = Syncer(local_dir, remote_dir, s3_client) elif remote_dir.startswith(GS_PREFIX): if not distutils.spawn.find_executable("gsutil"): raise TuneError( "Upload uri starting with '{}' requires gsutil tool" " to be installed".format(GS_PREFIX)) - _syncers[key] = CommandSyncer(local_dir, remote_dir, - "gsutil rsync -r {source} {target}") + template = "gsutil rsync -r {source} {target}" + gs_client = CommandBasedClient(template, template) + _syncers[key] = Syncer(local_dir, remote_dir, gs_client) else: raise TuneError("Upload uri must start with one of: {}" "".format(ALLOWED_REMOTE_PREFIXES)) @@ -228,37 +287,48 @@ def get_syncer(local_dir, remote_dir=None, sync_function=None): def get_log_syncer(local_dir, remote_dir=None, sync_function=None): - """Returns a Syncer depending on given args. 
- - This syncer is in charge of syncing the local_dir with remote local_dir. + """Returns a log Syncer. Args: - local_dir: Source directory for syncing. - remote_dir: Target directory for syncing. If None, - returns BaseSyncer with noop. - sync_function (func | str): Function for syncing the local_dir to + local_dir (str): Source directory for syncing. + remote_dir (str): Target directory for syncing. If not provided, a + no-op Syncer is returned. + sync_function (func|str): Function for syncing the local_dir to remote_dir. If string, then it must be a string template for syncer to run. If not provided, it defaults rsync. - """ + """ key = (local_dir, remote_dir) - if key in _syncers: return _syncers[key] - - sync_cls = None - if sync_function: - sync_cls = _get_sync_cls(sync_function) + elif not remote_dir: + sync_client = NOOP + elif sync_function: + sync_client = _get_sync_client(sync_function) else: - sync_cls = CommandSyncer - sync_function = log_sync_template() + sync_up = log_sync_template() + sync_down = log_sync_template(options="--remove-source-files") + if sync_up and sync_down: + sync_client = CommandBasedClient(sync_up, sync_down) + sync_client.set_logdir(local_dir) + else: + sync_client = NOOP - if not remote_dir or sync_function is None: - sync_cls = BaseSyncer - - class MixedSyncer(NodeSyncMixin, sync_cls): + class MixedSyncer(NodeSyncMixin, Syncer): def __init__(self, *args, **kwargs): - sync_cls.__init__(self, *args, **kwargs) + Syncer.__init__(self, *args, **kwargs) NodeSyncMixin.__init__(self) - _syncers[key] = MixedSyncer(local_dir, remote_dir, sync_function) + _syncers[key] = MixedSyncer(local_dir, remote_dir, sync_client) return _syncers[key] + + +def _get_sync_client(sync_function): + if not sync_function: + return None + if isinstance(sync_function, types.FunctionType): + return FunctionBasedClient(sync_function, sync_function) + elif isinstance(sync_function, str): + return CommandBasedClient(sync_function, sync_function) + else: + raise 
ValueError("Sync function {} must be string or function".format( + sync_function)) diff --git a/python/ray/tune/tests/test_checkpoint_manager.py b/python/ray/tune/tests/test_checkpoint_manager.py new file mode 100644 index 000000000..fd1975476 --- /dev/null +++ b/python/ray/tune/tests/test_checkpoint_manager.py @@ -0,0 +1,106 @@ +# coding: utf-8 +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random +import sys +import unittest + +from ray.tune.checkpoint_manager import Checkpoint, CheckpointManager, logger + +if sys.version_info >= (3, 3): + from unittest.mock import patch +else: + from mock import patch + + +class CheckpointManagerTest(unittest.TestCase): + @staticmethod + def mock_result(i): + return {"i": i} + + def testOnCheckpointOrdered(self): + """ + Tests increasing priorities. Also tests that that the worst checkpoints + are deleted when necessary. + """ + keep_checkpoints_num = 2 + checkpoint_manager = CheckpointManager(keep_checkpoints_num, "i") + checkpoints = [ + Checkpoint(Checkpoint.DISK, {i}, self.mock_result(i)) + for i in range(3) + ] + + with patch("shutil.rmtree") as rmtree_mock, patch("os.path"): + for j in range(3): + checkpoint_manager.on_checkpoint(checkpoints[j]) + expected_deletes = 0 if j != 2 else 1 + self.assertEqual(rmtree_mock.call_count, expected_deletes) + self.assertEqual(checkpoint_manager.newest_checkpoint, + checkpoints[j]) + + best_checkpoints = checkpoint_manager.best_checkpoints() + self.assertEqual(len(best_checkpoints), keep_checkpoints_num) + self.assertIn(checkpoints[1], best_checkpoints) + self.assertIn(checkpoints[2], best_checkpoints) + + def testOnCheckpointUnordered(self): + """ + Tests priorities that aren't inserted in ascending order. Also tests + that the worst checkpoints are deleted when necessary. 
+ """ + keep_checkpoints_num = 2 + checkpoint_manager = CheckpointManager(keep_checkpoints_num, "i") + checkpoints = [ + Checkpoint(Checkpoint.DISK, {i}, self.mock_result(i)) + for i in range(3, -1, -1) + ] + + with patch("shutil.rmtree") as rmtree_mock, patch("os.path"): + for j in range(0, len(checkpoints)): + checkpoint_manager.on_checkpoint(checkpoints[j]) + expected_deletes = 0 if j != 3 else 1 + self.assertEqual(rmtree_mock.call_count, expected_deletes) + self.assertEqual(checkpoint_manager.newest_checkpoint, + checkpoints[j]) + + best_checkpoints = checkpoint_manager.best_checkpoints() + self.assertEqual(len(best_checkpoints), keep_checkpoints_num) + self.assertIn(checkpoints[0], best_checkpoints) + self.assertIn(checkpoints[1], best_checkpoints) + + def testBestCheckpoints(self): + """ + Tests that the best checkpoints are tracked and ordered correctly. + """ + keep_checkpoints_num = 4 + checkpoint_manager = CheckpointManager(keep_checkpoints_num, "i") + checkpoints = [ + Checkpoint(Checkpoint.MEMORY, i, self.mock_result(i)) + for i in range(16) + ] + random.shuffle(checkpoints) + + for checkpoint in checkpoints: + checkpoint_manager.on_checkpoint(checkpoint) + + best_checkpoints = checkpoint_manager.best_checkpoints() + self.assertEqual(len(best_checkpoints), keep_checkpoints_num) + for i in range(len(best_checkpoints)): + self.assertEqual(best_checkpoints[i].value, i + 12) + + def testOnCheckpointUnavailableAttribute(self): + """ + Tests that an error is logged when the associated result of the + checkpoint has no checkpoint score attribute. + """ + keep_checkpoints_num = 1 + checkpoint_manager = CheckpointManager(keep_checkpoints_num, "i") + + no_attr_checkpoint = Checkpoint(Checkpoint.MEMORY, 0, {}) + with patch.object(logger, "error") as log_error_mock: + checkpoint_manager.on_checkpoint(no_attr_checkpoint) + log_error_mock.assert_called_once() + # The newest checkpoint should still be set despite this error. 
+ assert checkpoint_manager.newest_checkpoint == no_attr_checkpoint diff --git a/python/ray/tune/tests/test_cluster.py b/python/ray/tune/tests/test_cluster.py index 3dc2fd2b0..bc88f5743 100644 --- a/python/ray/tune/tests/test_cluster.py +++ b/python/ray/tune/tests/test_cluster.py @@ -336,7 +336,7 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster): cluster.add_node(num_cpus=1) cluster.remove_node(node) cluster.wait_for_nodes() - shutil.rmtree(os.path.dirname(t1._checkpoint.value)) + shutil.rmtree(os.path.dirname(t1.checkpoint.value)) runner.step() # Recovery step for i in range(3): @@ -569,7 +569,7 @@ tune.run( ray.shutdown() cluster.shutdown() cluster = _start_new_cluster() - Experiment._register_if_needed(_Mock) + Experiment.register_if_needed(_Mock) # Inspect the internal trialrunner runner = TrialRunner( diff --git a/python/ray/tune/tests/test_trial_runner.py b/python/ray/tune/tests/test_trial_runner.py index 9b208985e..6c89ebce1 100644 --- a/python/ray/tune/tests/test_trial_runner.py +++ b/python/ray/tune/tests/test_trial_runner.py @@ -26,6 +26,7 @@ from ray.tune.result import ( EPISODES_TOTAL, TRAINING_ITERATION, TIMESTEPS_THIS_ITER, TIME_THIS_ITER_S, TIME_TOTAL_S, TRIAL_ID, EXPERIMENT_TAG) from ray.tune.logger import Logger +from ray.tune.syncer import CommandBasedClient from ray.tune.util import pin_in_object_store, get_pinned_object, flatten_dict from ray.tune.experiment import Experiment from ray.tune.trial import Trial, ExportFormat @@ -1124,21 +1125,20 @@ class TestSyncFunctionality(unittest.TestCase): "sync_to_driver": "ls {source}" }).trials - with patch("ray.tune.syncer.CommandSyncer.sync_function" - ) as mock_fn, patch( - "ray.services.get_node_ip_address") as mock_sync: - mock_sync.return_value = "0.0.0.0" - [trial] = tune.run( - "__fake", - name="foo", - max_failures=0, - **{ - "stop": { - "training_iteration": 1 - }, - "sync_to_driver": "echo {source} {target}" - }).trials - self.assertGreater(mock_fn.call_count, 0) + with 
patch.object(CommandBasedClient, "execute") as mock_fn: + with patch("ray.services.get_node_ip_address") as mock_sync: + mock_sync.return_value = "0.0.0.0" + [trial] = tune.run( + "__fake", + name="foo", + max_failures=0, + **{ + "stop": { + "training_iteration": 1 + }, + "sync_to_driver": "echo {source} {target}" + }).trials + self.assertGreater(mock_fn.call_count, 0) def testCloudFunctions(self): tmpdir = tempfile.mkdtemp() @@ -1166,7 +1166,7 @@ class TestSyncFunctionality(unittest.TestCase): def testClusterSyncFunction(self): def sync_func_driver(source, target): - assert ":" in source, "Source not a remote path." + assert ":" in source, "Source {} not a remote path.".format(source) assert ":" not in target, "Target is supposed to be local." with open(os.path.join(target, "test.log2"), "w") as f: print("writing to", f.name) @@ -1203,7 +1203,7 @@ class TestSyncFunctionality(unittest.TestCase): def sync_func(source, target): pass - with patch("ray.tune.syncer.CommandSyncer.sync_function") as mock_sync: + with patch.object(CommandBasedClient, "execute") as mock_sync: [trial] = tune.run( "__fake", name="foo", diff --git a/python/ray/tune/track/session.py b/python/ray/tune/track/session.py index 65301afde..6f02c14cd 100644 --- a/python/ray/tune/track/session.py +++ b/python/ray/tune/track/session.py @@ -103,6 +103,7 @@ class TrackSession(object): self.trial_config["end_time"] = datetime.now().isoformat() # TODO(rliaw): Have Tune support updated configs self._logger.update_config(self.trial_config) + self._logger.flush() self._logger.close() @property diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index 8bd24767c..988b4b6c5 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -70,7 +70,6 @@ class Trainable(object): self._experiment_id = uuid.uuid4().hex self.config = config or {} - log_sys_usage = self.config.get("log_sys_usage", False) if logger_creator: self._result_logger = logger_creator(self.config) @@ -92,6 
+91,7 @@ class Trainable(object): self._timesteps_since_restore = 0 self._iterations_since_restore = 0 self._restored = False + start_time = time.time() self._setup(copy.deepcopy(self.config)) setup_time = time.time() - start_time @@ -101,6 +101,7 @@ class Trainable(object): "reuse_actors=True to reduce actor creation " "overheads.".format(setup_time)) self._local_ip = ray.services.get_node_ip_address() + log_sys_usage = self.config.get("log_sys_usage", False) self._monitor = UtilMonitor(start=log_sys_usage) @classmethod @@ -112,11 +113,11 @@ class Trainable(object): Example: >>> def default_resource_request(cls, config): - return Resources( - cpu=0, - gpu=0, - extra_cpu=config["workers"], - extra_gpu=int(config["use_gpu"]) * config["workers"]) + >>> return Resources( + >>> cpu=0, + >>> gpu=0, + >>> extra_cpu=config["workers"], + >>> extra_gpu=int(config["use_gpu"]) * config["workers"]) """ return None @@ -171,7 +172,6 @@ class Trainable(object): Returns: A dict that describes training progress. """ - start = time.time() result = self._train() assert isinstance(result, dict), "_train() needs to return a dict." @@ -239,17 +239,6 @@ class Trainable(object): return result - def delete_checkpoint(self, checkpoint_dir): - """Removes subdirectory within checkpoint_folder - - Args: - checkpoint_dir : path to checkpoint - """ - if os.path.isfile(checkpoint_dir): - shutil.rmtree(os.path.dirname(checkpoint_dir)) - else: - shutil.rmtree(checkpoint_dir) - def save(self, checkpoint_dir=None): """Saves the current model state to a checkpoint. @@ -262,9 +251,9 @@ class Trainable(object): Returns: Checkpoint path or prefix that may be passed to restore(). 
""" - checkpoint_dir = os.path.join(checkpoint_dir or self.logdir, "checkpoint_{}".format(self._iteration)) + if not os.path.exists(checkpoint_dir): os.makedirs(checkpoint_dir) checkpoint = self._save(checkpoint_dir) @@ -293,7 +282,7 @@ class Trainable(object): "time_total": self._time_total, "episodes_total": self._episodes_total, "saved_as_dict": saved_as_dict, - "ray_version": ray.__version__ + "ray_version": ray.__version__, }, f) return checkpoint_path @@ -305,10 +294,8 @@ class Trainable(object): Returns: Object holding checkpoint data. """ - tmpdir = tempfile.mkdtemp("save_to_object", dir=self.logdir) checkpoint_path = self.save(tmpdir) - # Save all files in subtree. data = {} for basedir, _, file_names in os.walk(tmpdir): @@ -356,7 +343,7 @@ class Trainable(object): self._timesteps_since_restore = 0 self._iterations_since_restore = 0 self._restored = True - logger.info("Restored from checkpoint: {}".format(checkpoint_path)) + logger.info("Restored from checkpoint: %s", checkpoint_path) state = { "_iteration": self._iteration, "_timesteps_total": self._timesteps_total, @@ -398,7 +385,7 @@ class Trainable(object): export_dir (str): Optional dir to place the exported model. Defaults to self.logdir. - Return: + Returns: A dict that maps ExportFormats to successfully exported models. 
""" export_dir = export_dir or self.logdir @@ -422,7 +409,7 @@ class Trainable(object): def stop(self): """Releases all resources used by this trainable.""" - + self._result_logger.flush() self._result_logger.close() self._stop() diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 41900c71b..e5d28e136 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -12,6 +12,7 @@ import tempfile import os from numbers import Number from ray.tune import TuneError +from ray.tune.checkpoint_manager import Checkpoint, CheckpointManager from ray.tune.logger import pretty_print, UnifiedLogger from ray.tune.util import flatten_dict # NOTE(rkn): We import ray.tune.registry here instead of importing the names we @@ -31,29 +32,20 @@ def date_str(): return datetime.today().strftime("%Y-%m-%d_%H-%M-%S") -class Checkpoint(object): - """Describes a checkpoint of trial state. +class Location(object): + """Describes the location at which Trial is placed to run.""" - Checkpoint may be saved in different storage. + def __init__(self, hostname=None, pid=None): + self.hostname = hostname + self.pid = pid - Attributes: - storage (str): Storage type. - value (str): If storage==MEMORY,value is a Python object. - If storage==DISK,value is a path points to the checkpoint in disk. 
- """ - - MEMORY = "memory" - DISK = "disk" - - def __init__(self, storage, value, last_result=None): - self.storage = storage - self.value = value - self.last_result = last_result or {} - - @staticmethod - def from_object(value=None): - """Creates a checkpoint from a Python object.""" - return Checkpoint(Checkpoint.MEMORY, value) + def __str__(self): + if not self.pid: + return "" + elif self.hostname == os.uname()[1]: + return "pid={}".format(self.pid) + else: + return "{}:{}".format(self.hostname, self.pid) class ExportFormat(object): @@ -108,8 +100,9 @@ class Trial(object): stopping_criterion=None, checkpoint_freq=0, checkpoint_at_end=False, + sync_on_checkpoint=True, keep_checkpoints_num=None, - checkpoint_score_attr="", + checkpoint_score_attr=TRAINING_ITERATION, export_formats=None, restore_path=None, trial_name_creator=None, @@ -121,7 +114,6 @@ class Trial(object): The args here take the same meaning as the command line flags defined in ray.tune.config_parser. """ - validate_trainable(trainable_name) # Trial config self.trainable_name = trainable_name @@ -145,6 +137,7 @@ class Trial(object): "clear the `resources_per_trial` option.".format( trainable_cls, default_resources)) resources = default_resources + self.address = Location() self.resources = resources or Resources(cpu=1, gpu=0) self.stopping_criterion = stopping_criterion or {} self.loggers = loggers @@ -161,17 +154,12 @@ class Trial(object): # stores in memory max/min/last result for each metric by trial self.metric_analysis = {} - self.history = [] - self.keep_checkpoints_num = keep_checkpoints_num - self._cmp_greater = not checkpoint_score_attr.startswith("min-") - self.best_checkpoint_attr_value = -float("inf") \ - if self._cmp_greater else float("inf") - # Strip off "min-" from checkpoint attribute - self.checkpoint_score_attr = checkpoint_score_attr \ - if self._cmp_greater else checkpoint_score_attr[4:] + self.sync_on_checkpoint = sync_on_checkpoint + newest_checkpoint = 
Checkpoint(Checkpoint.DISK, restore_path) + self.checkpoint_manager = CheckpointManager(keep_checkpoints_num, + checkpoint_score_attr) + self.checkpoint_manager.newest_checkpoint = newest_checkpoint - self._checkpoint = Checkpoint( - storage=Checkpoint.DISK, value=restore_path) self.export_formats = export_formats self.status = Trial.PENDING self.logdir = None @@ -190,7 +178,7 @@ class Trial(object): self.extra_arg = None self._nonjson_fields = [ - "_checkpoint", + "checkpoint", "loggers", "sync_to_driver_fn", "results", @@ -201,6 +189,14 @@ class Trial(object): if trial_name_creator: self.custom_trial_name = trial_name_creator(self) + @property + def node_ip(self): + return self.address.hostname + + @property + def checkpoint(self): + return self.checkpoint_manager.newest_checkpoint + @classmethod def generate_id(cls): return str(uuid.uuid1().hex)[:8] @@ -249,10 +245,14 @@ class Trial(object): """ if self.result_logger: self.result_logger.sync_results_to_new_location(worker_ip) + self.set_location(Location(worker_ip)) + + def set_location(self, location): + """Sets the location of the trial.""" + self.address = location def close_logger(self): - """Close logger.""" - + """Closes logger.""" if self.result_logger: self.result_logger.close() self.result_logger = None @@ -262,8 +262,9 @@ class Trial(object): self.num_failures += 1 # may be moved to outer scope? 
error_file = os.path.join(self.logdir, "error_{}.txt".format(date_str())) - with open(error_file, "w") as f: - f.write(error_msg) + with open(error_file, "a+") as f: + f.write("Failure # {}".format(self.num_failures) + "\n") + f.write(error_msg + "\n") self.error_file = error_file self.error_msg = error_msg @@ -292,21 +293,36 @@ class Trial(object): def should_checkpoint(self): """Whether this trial is due for checkpointing.""" result = self.last_result or {} - if result.get(DONE) and self.checkpoint_at_end: return True - - if self.checkpoint_freq: - return result.get(TRAINING_ITERATION, - 0) % self.checkpoint_freq == 0 - else: - return False + return (self.checkpoint_freq and + result.get(TRAINING_ITERATION, 0) % self.checkpoint_freq == 0) def has_checkpoint(self): - return self._checkpoint.value is not None + return self.checkpoint.value is not None def clear_checkpoint(self): - self._checkpoint.value = None + self.checkpoint.value = None + + def on_checkpoint(self, checkpoint): + """Hook for handling checkpoints taken by the Trainable. + + Args: + checkpoint (Checkpoint): Checkpoint taken. + """ + if self.sync_on_checkpoint and checkpoint.storage == Checkpoint.DISK: + # Wait for any other syncs to finish. We need to sync again after + # this to handle checkpoints taken mid-sync. + self.result_logger.wait() + # Force sync down and wait before tracking the new checkpoint. This + # prevents attempts to restore from partially synced checkpoints. + if self.result_logger.sync_down(): + self.result_logger.wait() + else: + logger.error( + "Trial %s: Checkpoint sync skipped. " + "This should not happen.", self) + self.checkpoint_manager.on_checkpoint(checkpoint) def should_recover(self): """Returns whether the trial qualifies for retrying. 
@@ -327,6 +343,7 @@ class Trial(object): print("Result for {}:".format(self)) print(" {}".format(pretty_print(result).replace("\n", "\n "))) self.last_debug = time.time() + self.set_location(Location(result.get("node_ip"), result.get("pid"))) self.last_result = result self.last_update_time = time.time() self.result_logger.on_result(self.last_result) @@ -345,28 +362,6 @@ class Trial(object): value, self.metric_analysis[metric]["min"]) self.metric_analysis[metric]["last"] = value - def compare_checkpoints(self, attr_mean): - """Compares two checkpoints based on the attribute attr_mean param. - Greater than is used by default. If command-line parameter - checkpoint_score_attr starts with "min-" less than is used. - - Arguments: - attr_mean: mean of attribute value for the current checkpoint - - Returns: - True: when attr_mean is greater than previous checkpoint attr_mean - and greater than function is selected - when attr_mean is less than previous checkpoint attr_mean and - less than function is selected - False: when attr_mean is not in alignment with selected cmp fn - """ - if self._cmp_greater and attr_mean > self.best_checkpoint_attr_value: - return True - elif (not self._cmp_greater - and attr_mean < self.best_checkpoint_attr_value): - return True - return False - def get_trainable_cls(self): return get_trainable_cls(self.trainable_name) @@ -376,10 +371,6 @@ class Trial(object): def is_finished(self): return self.status in [Trial.TERMINATED, Trial.ERROR] - @property - def node_ip(self): - return self.last_result.get("node_ip") - def __repr__(self): return str(self) @@ -407,7 +398,7 @@ class Trial(object): Sets RUNNING trials to PENDING, and flushes the result logger. Note this can only occur if the trial holds a DISK checkpoint. 
""" - assert self._checkpoint.storage == Checkpoint.DISK, ( + assert self.checkpoint.storage == Checkpoint.DISK, ( "Checkpoint must not be in-memory.") state = self.__dict__.copy() state["resources"] = resources_to_json(self.resources) diff --git a/python/ray/tune/trial_executor.py b/python/ray/tune/trial_executor.py index 27de24f5a..d6bd529a3 100644 --- a/python/ray/tune/trial_executor.py +++ b/python/ray/tune/trial_executor.py @@ -6,6 +6,7 @@ from __future__ import print_function import logging from ray.tune.trial import Trial, Checkpoint +from ray.tune.error import TuneError logger = logging.getLogger(__name__) @@ -38,6 +39,8 @@ class TrialExecutor(object): trial (Trial): Trial to checkpoint. status (Trial.status): Status to set trial to. """ + logger.debug("Trial %s: Changing status from %s to %s.", trial, + trial.status, status) trial.status = status if status in [Trial.TERMINATED, Trial.ERROR]: self.try_checkpoint_metadata(trial) @@ -48,14 +51,16 @@ class TrialExecutor(object): Args: trial (Trial): Trial to checkpoint. """ - if trial._checkpoint.storage == Checkpoint.MEMORY: - logger.debug("Not saving data for trial w/ memory checkpoint.") + if trial.checkpoint.storage == Checkpoint.MEMORY: + logger.debug("Trial %s: Not saving data for memory checkpoint.", + trial) return try: - logger.debug("Saving trial metadata.") + logger.debug("Trial %s: Saving trial metadata.", trial) self._cached_trial_state[trial.trial_id] = trial.__getstate__() except Exception: - logger.exception("Error checkpointing trial metadata.") + logger.exception("Trial %s: Error checkpointing trial metadata.", + trial) def get_checkpoints(self): """Returns a copy of mapping of the trial ID to pickled metadata.""" @@ -118,7 +123,6 @@ class TrialExecutor(object): def resume_trial(self, trial): """Resumes PAUSED trials. 
This is a blocking call.""" - assert trial.status == Trial.PAUSED, trial.status self.start_trial(trial) @@ -150,6 +154,27 @@ class TrialExecutor(object): """A hook called after running one step of the trial event loop.""" pass + def on_no_available_trials(self, trial_runner): + if self._queue_trials: + return + for trial in trial_runner.get_trials(): + if trial.status == Trial.PENDING: + if not self.has_resources(trial.resources): + raise TuneError( + ("Insufficient cluster resources to launch trial: " + "trial requested {} but the cluster has only {}. " + "Pass `queue_trials=True` in " + "ray.tune.run() or on the command " + "line to queue trials until the cluster scales " + "up or resources become available. {}").format( + trial.resources.summary_string(), + self.resource_string(), + trial.get_trainable_cls().resource_help( + trial.config))) + elif trial.status == Trial.PAUSED: + raise TuneError("There are paused trials, but no more pending " + "trials with sufficient resources.") + def get_next_available_trial(self): """Blocking call that waits until one result is ready. @@ -188,7 +213,7 @@ class TrialExecutor(object): def restore(self, trial, checkpoint=None): """Restores training state from a checkpoint. - If checkpoint is None, try to restore from trial._checkpoint. + If checkpoint is None, try to restore from trial.checkpoint. If restoring fails, the trial status will be set to ERROR. 
Args: diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 6bfe77cb6..cf749fb2e 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -17,8 +17,8 @@ from ray.tune import TuneError from ray.tune.ray_trial_executor import RayTrialExecutor from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE, SHOULD_CHECKPOINT) -from ray.tune.syncer import get_syncer -from ray.tune.trial import Trial, Checkpoint +from ray.tune.syncer import get_cloud_syncer +from ray.tune.trial import Checkpoint, Trial from ray.tune.schedulers import FIFOScheduler, TrialScheduler from ray.tune.suggest import BasicVariantGenerator from ray.tune.util import warn_if_slow, flatten_dict @@ -126,7 +126,7 @@ class TrialRunner(object): global checkpoints are stored and restored from. Used if `resume` == REMOTE. resume (str|False): see `tune.py:run`. - sync_to_cloud (func|str): see `tune.py:run`. + sync_to_cloud (func|str): See `tune.py:run`. server_port (int): Port number for launching TuneServer. verbose (bool): Flag for verbosity. If False, trial results will not be output. @@ -158,8 +158,8 @@ class TrialRunner(object): os.makedirs(self._local_checkpoint_dir) self._remote_checkpoint_dir = remote_checkpoint_dir - self._syncer = get_syncer(local_checkpoint_dir, remote_checkpoint_dir, - sync_to_cloud) + self._syncer = get_cloud_syncer(local_checkpoint_dir, + remote_checkpoint_dir, sync_to_cloud) self._resumed = False @@ -216,8 +216,7 @@ class TrialRunner(object): "Called resume from remote without remote directory.") # Try syncing down the upload directory. 
- logger.info("Downloading from {}".format( - self._remote_checkpoint_dir)) + logger.info("Downloading from %s", self._remote_checkpoint_dir) self._syncer.sync_down_if_needed() if not self.checkpoint_exists(self._local_checkpoint_dir): @@ -334,24 +333,7 @@ class TrialRunner(object): elif self.trial_executor.get_running_trials(): self._process_events() # blocking else: - for trial in self._trials: - if trial.status == Trial.PENDING: - if not self.has_resources(trial.resources): - raise TuneError( - ("Insufficient cluster resources to launch trial: " - "trial requested {} but the cluster has only {}. " - "Pass `queue_trials=True` in " - "ray.tune.run() or on the command " - "line to queue trials until the cluster scales " - "up. {}").format( - trial.resources.summary_string(), - self.trial_executor.resource_string(), - trial.get_trainable_cls().resource_help( - trial.config))) - elif trial.status == Trial.PAUSED: - raise TuneError( - "There are paused trials, but no more pending " - "trials with sufficient resources.") + self.trial_executor.on_no_available_trials(self) try: with warn_if_slow("experiment_checkpoint"): @@ -422,12 +404,12 @@ class TrialRunner(object): def _process_events(self): failed_trial = self.trial_executor.get_next_failed_trial() if failed_trial: + error_msg = ( + "{} (IP: {}) detected as stale. This is likely because the " + "node was lost").format(failed_trial, failed_trial.node_ip) + logger.info(error_msg) with warn_if_slow("process_failed_trial"): - self._process_trial_failure( - failed_trial, - error_msg="{} (ip: {}) detected as stale. 
This is likely" - "because the node was lost".format(failed_trial, - failed_trial.node_ip)) + self._process_trial_failure(failed_trial, error_msg=error_msg) else: trial = self.trial_executor.get_next_available_trial() # blocking with warn_if_slow("process_trial"): @@ -537,12 +519,17 @@ class TrialRunner(object): stop_logger=False) trial.result_logger.flush() if self.trial_executor.has_resources(trial.resources): - logger.info("Attempting to recover trial.") + logger.info( + "Trial %s: Attempting to recover " + "trial state from last checkpoint.", trial) self.trial_executor.start_trial(trial) if trial.status == Trial.ERROR: + logger.error("Trial %s: Did not start correctly.", trial) raise RuntimeError("Trial did not start correctly.") + logger.debug("Trial %s: Started correctly.", trial) else: - logger.debug("Notifying Scheduler and requeueing trial.") + logger.debug("Trial %s: Notifying Scheduler and requeueing.", + trial) self._requeue_trial(trial) except Exception: logger.exception("Error recovering trial from checkpoint, abort.") diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 883eef780..08a8f7bb7 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -69,6 +69,7 @@ def run(run_or_experiment, sync_to_driver=None, checkpoint_freq=0, checkpoint_at_end=False, + sync_on_checkpoint=True, keep_checkpoints_num=None, checkpoint_score_attr=None, global_checkpoint_period=10, @@ -118,18 +119,18 @@ def run(run_or_experiment, `num_samples` of times. local_dir (str): Local dir to save training results to. Defaults to ``~/ray_results``. - upload_dir (str): Optional URI to sync training results - to (e.g. ``s3://bucket``). + upload_dir (str): Optional URI to sync training results to + (e.g. ``s3://bucket``). trial_name_creator (func): Optional function for generating the trial string representation. loggers (list): List of logger creators to be used with each Trial. If None, defaults to ray.tune.logger.DEFAULT_LOGGERS. 
See `ray/tune/logger.py`. sync_to_cloud (func|str): Function for syncing the local_dir to and - from upload_dir. If string, then it must be a string template - that includes `{source}` and `{target}` for the syncer to run. - If not provided, the sync command defaults to standard - S3 or gsutil sync comamnds. + from upload_dir. If string, then it must be a string template that + includes `{source}` and `{target}` for the syncer to run. If not + provided, the sync command defaults to standard S3 or gsutil sync + commands. sync_to_driver (func|str): Function for syncing trial logdir from remote node to local. If string, then it must be a string template that includes `{source}` and `{target}` for the syncer to run. @@ -138,6 +139,11 @@ def run(run_or_experiment, checkpoints. A value of 0 (default) disables checkpointing. checkpoint_at_end (bool): Whether to checkpoint at the end of the experiment regardless of the checkpoint_freq. Default is False. + sync_on_checkpoint (bool): Force sync-down of trial checkpoint, to + guarantee recoverability. If set to False, checkpoint syncing from + worker to driver is asynchronous. Set this to False only if + synchronous checkpointing is too slow and trial restoration + failures can be tolerated. Defaults to True. keep_checkpoints_num (int): Number of checkpoints to keep. A value of `None` keeps all checkpoints. Defaults to `None`. If set, need to provide `checkpoint_score_attr`. 
@@ -223,7 +229,7 @@ def run(run_or_experiment, "not work with certain features.") for i, exp in enumerate(experiments): if not isinstance(exp, Experiment): - run_identifier = Experiment._register_if_needed(exp) + run_identifier = Experiment.register_if_needed(exp) experiments[i] = Experiment( name=name, run=run_identifier, @@ -238,6 +244,7 @@ def run(run_or_experiment, loggers=loggers, checkpoint_freq=checkpoint_freq, checkpoint_at_end=checkpoint_at_end, + sync_on_checkpoint=sync_on_checkpoint, keep_checkpoints_num=keep_checkpoints_num, checkpoint_score_attr=checkpoint_score_attr, export_formats=export_formats, diff --git a/python/ray/tune/util.py b/python/ray/tune/util.py index e0ade8b12..99c1eace7 100644 --- a/python/ray/tune/util.py +++ b/python/ray/tune/util.py @@ -119,18 +119,25 @@ class warn_if_slow(object): ... ray.get(something) """ - def __init__(self, name): + DEFAULT_THRESHOLD = 0.5 + + def __init__(self, name, threshold=None): self.name = name + self.threshold = threshold or self.DEFAULT_THRESHOLD + self.too_slow = False def __enter__(self): self.start = time.time() + return self def __exit__(self, type, value, traceback): now = time.time() - if now - self.start > 0.5 and now - START_OF_TIME > 60.0: - logger.warning("The `{}` operation took {} seconds to complete, ". - format(self.name, now - self.start) + - "which may be a performance bottleneck.") + if now - self.start > self.threshold and now - START_OF_TIME > 60.0: + self.too_slow = True + logger.warning( + "The `%s` operation took %s seconds to complete, " + "which may be a performance bottleneck.", self.name, + now - self.start) def merge_dicts(d1, d2): diff --git a/python/ray/worker.py b/python/ray/worker.py index 3809fb2d2..0903c9db1 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -311,10 +311,6 @@ class Worker(object): "which is not an ray.ObjectID.".format(object_id)) if self.mode == LOCAL_MODE: - # TODO(ujvl): Remove check when local mode moved to core worker. 
- if timeout is not None: - raise ValueError( - "`get` must be called with timeout=None in local mode.") return self.local_mode_manager.get_objects(object_ids) timeout_ms = int(timeout * 1000) if timeout else -1 @@ -1407,8 +1403,8 @@ def get(object_ids, timeout=None): Args: object_ids: Object ID of the object to get or a list of object IDs to get. - timeout (float): The maximum amount of time in seconds to wait before - returning. + timeout (Optional[float]): The maximum amount of time in seconds to + wait before returning. Returns: A Python object or a list of Python objects.