diff --git a/doc/source/tune-package-ref.rst b/doc/source/tune-package-ref.rst index c2c24f396..7b7221385 100644 --- a/doc/source/tune-package-ref.rst +++ b/doc/source/tune-package-ref.rst @@ -7,12 +7,14 @@ ray.tune .. automodule:: ray.tune :members: :show-inheritance: - :exclude-members: TuneError, Trainable + :exclude-members: TuneError, Trainable, DurableTrainable .. autoclass:: ray.tune.Trainable :members: :private-members: +.. autoclass:: ray.tune.DurableTrainable + .. autoclass:: ray.tune.function_runner.StatusReporter :members: __call__, logdir diff --git a/python/ray/tune/BUILD b/python/ray/tune/BUILD index c5fb221a0..41d547720 100644 --- a/python/ray/tune/BUILD +++ b/python/ray/tune/BUILD @@ -148,7 +148,14 @@ py_test( deps = [":tune_lib"], tags = ["exclusive"], ) - + +py_test( + name = "test_trainable_util", + size = "small", + srcs = ["tests/test_trainable_util.py"], + deps = [":tune_lib"], +) + py_test( name = "test_trial_scheduler", size = "medium", diff --git a/python/ray/tune/__init__.py b/python/ray/tune/__init__.py index b8bfea52e..0488109c5 100644 --- a/python/ray/tune/__init__.py +++ b/python/ray/tune/__init__.py @@ -8,12 +8,14 @@ from ray.tune.experiment import Experiment from ray.tune.analysis import ExperimentAnalysis, Analysis from ray.tune.registry import register_env, register_trainable from ray.tune.trainable import Trainable +from ray.tune.durable_trainable import DurableTrainable from ray.tune.suggest import grid_search from ray.tune.sample import (function, sample_from, uniform, choice, randint, randn, loguniform) __all__ = [ "Trainable", + "DurableTrainable", "TuneError", "grid_search", "register_env", diff --git a/python/ray/tune/checkpoint_manager.py b/python/ray/tune/checkpoint_manager.py index c386d34d3..36ac41d13 100644 --- a/python/ray/tune/checkpoint_manager.py +++ b/python/ray/tune/checkpoint_manager.py @@ -5,8 +5,6 @@ from __future__ import print_function import heapq import logging -import os -import shutil try: 
FileNotFoundError @@ -23,31 +21,18 @@ class Checkpoint: Attributes: storage (str): Storage type. - value (str): If storage==MEMORY, value is a Python object. - If storage==DISK, value is a path points to the checkpoint in disk. + value (str): If storage==MEMORY, it is a Python object. + If storage==PERSISTENT, it is a path to persistent storage. """ MEMORY = "memory" - DISK = "disk" + PERSISTENT = "persistent" def __init__(self, storage, value, result=None): self.storage = storage self.value = value self.result = result or {} - def delete(self): - """Deletes checkpoint data if disk checkpoint.""" - if self.storage == Checkpoint.DISK and self.value: - checkpoint_dir = self.value - if not os.path.exists(checkpoint_dir): - raise FileNotFoundError( - "Attempted to delete checkpoint at {} but " - "path was not found.".format(checkpoint_dir)) - elif os.path.isfile(checkpoint_dir): - shutil.rmtree(os.path.dirname(checkpoint_dir)) - else: - shutil.rmtree(checkpoint_dir) - @staticmethod def from_object(value=None): """Creates a checkpoint from a Python object.""" @@ -72,13 +57,15 @@ class QueueItem: class CheckpointManager: """Manages checkpoints on the driver for a trial.""" - def __init__(self, keep_checkpoints_num, checkpoint_score_attr): + def __init__(self, keep_checkpoints_num, checkpoint_score_attr, delete_fn): """Initializes a new CheckpointManager. Args: keep_checkpoints_num (int): Keep at least this many checkpoints. checkpoint_score_attr (str): Attribute to use to determine which checkpoints to keep. + delete_fn (function): Function that deletes checkpoints. Must be + idempotent. 
""" self.keep_checkpoints_num = keep_checkpoints_num or float("inf") assert self.keep_checkpoints_num > 0, ( @@ -88,7 +75,7 @@ class CheckpointManager: self._checkpoint_score_attr = checkpoint_score_attr[4:] else: self._checkpoint_score_attr = checkpoint_score_attr - + self.delete = delete_fn self.newest_checkpoint = Checkpoint(Checkpoint.MEMORY, None) self._best_checkpoints = [] self._membership = set() @@ -101,9 +88,6 @@ class CheckpointManager: Args: checkpoint (Checkpoint): Trial state checkpoint. - - Raises: - KeyError if checkpoint_score_attr not in result of checkpoint. """ old_checkpoint = self.newest_checkpoint self.newest_checkpoint = checkpoint @@ -112,7 +96,7 @@ class CheckpointManager: queue_item = QueueItem(self._priority(checkpoint), checkpoint) except KeyError: if old_checkpoint not in self._membership: - old_checkpoint.delete() + self.delete(old_checkpoint) logger.error("Result dict has no key: {}. " "checkpoint_score_attr must be set to a key in the " "result dict.".format(self._checkpoint_score_attr)) @@ -126,11 +110,11 @@ class CheckpointManager: self._membership.add(checkpoint) if worst in self._membership: self._membership.remove(worst) - worst.delete() + self.delete(worst) # Remove the old checkpoint if it isn't one of the best ones. - if old_checkpoint not in self._membership: - old_checkpoint.delete() + if old_checkpoint.value and old_checkpoint not in self._membership: + self.delete(old_checkpoint) def best_checkpoints(self): """Returns best checkpoints, sorted by score.""" @@ -140,3 +124,13 @@ class CheckpointManager: def _priority(self, checkpoint): priority = checkpoint.result[self._checkpoint_score_attr] return -priority if self._checkpoint_score_desc else priority + + def __getstate__(self): + state = self.__dict__.copy() + # Avoid serializing lambda since it may capture cyclical dependencies. 
+ state.pop("delete") + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self.delete = None diff --git a/python/ray/tune/config_parser.py b/python/ray/tune/config_parser.py index b4a797385..9ba4889ff 100644 --- a/python/ray/tune/config_parser.py +++ b/python/ray/tune/config_parser.py @@ -183,6 +183,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs): local_dir=os.path.join(spec["local_dir"], output_path), # json.load leads to str -> unicode in py2.7 stopping_criterion=spec.get("stop", {}), + remote_checkpoint_dir=spec.get("remote_checkpoint_dir"), checkpoint_freq=args.checkpoint_freq, checkpoint_at_end=args.checkpoint_at_end, sync_on_checkpoint=not args.no_sync_on_checkpoint, diff --git a/python/ray/tune/durable_trainable.py b/python/ray/tune/durable_trainable.py new file mode 100644 index 000000000..987551728 --- /dev/null +++ b/python/ray/tune/durable_trainable.py @@ -0,0 +1,98 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from ray.tune.trainable import Trainable, TrainableUtil +from ray.tune.syncer import get_cloud_sync_client + + +class DurableTrainable(Trainable): + """Abstract class for a remote-storage backed fault-tolerant Trainable. + + Supports checkpointing to and restoring from remote storage. To use this + class, implement the same private methods as ray.tune.Trainable (`_save`, + `_train`, `_restore`, `reset_config`, `_setup`, `_stop`). + + .. warning:: This class is currently **experimental** and may + be subject to change. + + Run this with Tune as follows. Setting `sync_to_driver=False` disables + syncing to the driver to avoid keeping redundant checkpoints around, as + well as preventing the driver from syncing up the same checkpoint. + + See ``tune/trainable.py``. + + Attributes: + remote_checkpoint_dir (str): Upload directory (S3 or GS path). 
+ storage_client: Tune-internal interface for interacting with external + storage. + + >>> tune.run(MyDurableTrainable, sync_to_driver=False) + """ + + def __init__(self, remote_checkpoint_dir, *args, **kwargs): + """Initializes a DurableTrainable. + + Args: + remote_checkpoint_dir (str): Upload directory (S3 or GS path). + """ + super(DurableTrainable, self).__init__(*args, **kwargs) + self.remote_checkpoint_dir = remote_checkpoint_dir + self.storage_client = self._create_storage_client() + + def save(self, checkpoint_dir=None): + """Saves the current model state to a checkpoint, persisted remotely. + + The storage client must provide durability for + restoration to work. That is, once ``storage_client.wait()`` + returns after a checkpoint `sync up`, the checkpoint is considered + committed and can be used to restore the trainable. + + Args: + checkpoint_dir (Optional[str]): Optional dir to place the + checkpoint. Must be ``logdir`` or a sub-directory. + + Returns: + Checkpoint path or prefix that may be passed to restore(). + """ + if checkpoint_dir: + if not checkpoint_dir.startswith(os.path.abspath(self.logdir)): + raise ValueError("`checkpoint_dir` must be `self.logdir`, or " + "a sub-directory.") + + checkpoint_path = super(DurableTrainable, self).save(checkpoint_dir) + self.storage_client.sync_up(self.logdir, self.remote_checkpoint_dir) + self.storage_client.wait() + return checkpoint_path + + def restore(self, checkpoint_path): + """Restores training state from a given checkpoint persisted remotely. + + These checkpoints are returned from calls to save(). + + Args: + checkpoint_path (str): Local path to checkpoint. + """ + self.storage_client.sync_down(self.remote_checkpoint_dir, self.logdir) + self.storage_client.wait() + super(DurableTrainable, self).restore(checkpoint_path) + + def delete_checkpoint(self, checkpoint_path): + """Deletes checkpoint from both local and remote storage. + + Args: + checkpoint_path (str): Local path to checkpoint. 
+ """ + super(DurableTrainable, self).delete_checkpoint(checkpoint_path) + local_dirpath = TrainableUtil.find_checkpoint_dir(checkpoint_path) + self.storage_client.delete(self._storage_path(local_dirpath)) + + def _create_storage_client(self): + """Returns a storage client.""" + return get_cloud_sync_client(self.remote_checkpoint_dir) + + def _storage_path(self, local_path): + rel_local_path = os.path.relpath(local_path, self.logdir) + return os.path.join(self.remote_checkpoint_dir, rel_local_path) diff --git a/python/ray/tune/examples/durable_trainable_example.py b/python/ray/tune/examples/durable_trainable_example.py new file mode 100644 index 000000000..4e13abf34 --- /dev/null +++ b/python/ray/tune/examples/durable_trainable_example.py @@ -0,0 +1,126 @@ +import argparse +import numpy as np +import time +import logging +import os +import ray +from ray import tune +from ray.tune import DurableTrainable +from ray.tune.sync_client import get_sync_client + +import cloudpickle + +logger = logging.getLogger(__name__) + + +class MockDurableTrainable(DurableTrainable): + """Mocks the storage client on initialization to store data locally.""" + + def __init__(self, remote_checkpoint_dir, *args, **kwargs): + # Mock the path as a local path. + local_dir_suffix = remote_checkpoint_dir.split("://")[1] + remote_checkpoint_dir = os.path.join("/tmp", local_dir_suffix) + # Disallow malformed relative paths for delete safety. 
+ assert os.path.abspath(remote_checkpoint_dir).startswith("/tmp") + logger.info("Using %s as the mocked remote checkpoint directory.", + remote_checkpoint_dir) + super(MockDurableTrainable, self).__init__(remote_checkpoint_dir, + *args, **kwargs) + + def _create_storage_client(self): + sync = "mkdir -p {target} && rsync -avz {source} {target}" + delete = "rm -rf {target}" + return get_sync_client(sync, delete) + + +class OptimusFn(object): + def __init__(self, params, max_t=10000): + self.params = params + self.noise = np.random.normal(size=max_t) * 0.005 + + def eval(self, k, add_noise=True): + b0, b1, b2 = self.params + score = (b0 * k / 100 + 0.1 * b1 + 0.5)**(-1) + b2 * 0.01 + if add_noise: + return score + abs(self.noise[k]) + else: + return score + + +def get_optimus_trainable(parent_cls): + class OptimusTrainable(parent_cls): + def _setup(self, config): + self.iter = 0 + if config.get("seed"): + np.random.seed(config["seed"]) + time.sleep(config.get("startup_delay", 0)) + params = [config["param1"], config["param2"], config["param3"]] + self.func = OptimusFn(params=params) + self.initial_samples_per_step = 500 + self.mock_data = open("/dev/urandom", "rb").read(1024) + + def _train(self): + self.iter += 1 + new_loss = self.func.eval(self.iter) + time.sleep(0.5) + return { + "mean_loss": float(new_loss), + "mean_accuracy": (2 - new_loss) / 2, + "samples": self.initial_samples_per_step + } + + def _save(self, checkpoint_dir): + time.sleep(0.5) + return { + "func": cloudpickle.dumps(self.func), + "seed": np.random.get_state(), + "data": self.mock_data, + "iter": self.iter + } + + def _restore(self, checkpoint): + self.func = cloudpickle.loads(checkpoint["func"]) + self.mock_data = checkpoint["data"] + self.iter = checkpoint["iter"] + np.random.set_state(checkpoint["seed"]) + + return OptimusTrainable + + +def parse(): + parser = argparse.ArgumentParser() + parser.add_argument("--local", action="store_true", default=False) + parser.add_argument("--mock-storage", 
action="store_true", default=False) + parser.add_argument("--remote-dir", type=str) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse() + address = None if args.local else "auto" + ray.init(address=address) + + config = { + "seed": None, + "startup_delay": 0.001, + "param1": tune.sample_from(lambda spec: np.random.exponential(0.1)), + "param2": tune.sample_from(lambda _: np.random.rand()), + "param3": tune.sample_from(lambda _: np.random.rand()), + } + + parent = MockDurableTrainable if args.mock_storage else DurableTrainable + analysis = tune.run( + get_optimus_trainable(parent), + name="durableTrainable" + str(time.time()), + config=config, + num_samples=4, + verbose=1, + queue_trials=True, + # fault tolerance parameters + max_failures=-1, + checkpoint_freq=20, + sync_to_driver=False, + sync_on_checkpoint=False, + upload_dir="s3://ray-tune-test/exps/", + checkpoint_score_attr="training_iteration", + ) diff --git a/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py b/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py index 9a4fe461d..6aa12d2db 100644 --- a/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py +++ b/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py @@ -11,6 +11,7 @@ from ray.tune.trial import ExportFormat import argparse import os +from filelock import FileLock import random import torch import torch.nn as nn @@ -248,7 +249,8 @@ class PytorchTrainable(tune.Trainable): self.netG.parameters(), lr=config.get("lr", 0.01), betas=(beta1, 0.999)) - self.dataloader = get_data_loader() + with FileLock(os.path.expanduser("~/.data.lock")): + self.dataloader = get_data_loader() def _train(self): lossG, lossD, is_score = train( diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py index 422769018..9d65735fa 100644 --- a/python/ray/tune/experiment.py +++ b/python/ray/tune/experiment.py @@ -10,7 +10,7 @@ import six import types from ray.tune.error import TuneError -from 
ray.tune.registry import register_trainable +from ray.tune.registry import register_trainable, get_trainable_cls from ray.tune.result import DEFAULT_RESULTS_DIR from ray.tune.sample import sample_from @@ -34,6 +34,22 @@ def _raise_deprecation_note(deprecated, replacement, soft=False): raise DeprecationWarning(error_msg) +def _raise_on_durable(trainable_name, sync_to_driver, upload_dir): + trainable_cls = get_trainable_cls(trainable_name) + from ray.tune.durable_trainable import DurableTrainable + if issubclass(trainable_cls, DurableTrainable): + if sync_to_driver is not False: + raise ValueError( + "EXPERIMENTAL: DurableTrainable will automatically sync " + "results to the provided upload_dir. " + "Set `sync_to_driver=False` to avoid data inconsistencies.") + if not upload_dir: + raise ValueError( + "EXPERIMENTAL: DurableTrainable will automatically sync " + "results to the provided upload_dir. " + "`upload_dir` must be provided.") + + class Experiment: """Tracks experiment specifications. 
@@ -109,6 +125,14 @@ class Experiment: config = config or {} self._run_identifier = Experiment.register_if_needed(run) + self.name = name or self._run_identifier + if upload_dir: + self.remote_checkpoint_dir = os.path.join(upload_dir, self.name) + else: + self.remote_checkpoint_dir = None + + _raise_on_durable(self._run_identifier, sync_to_driver, upload_dir) + spec = { "run": self._run_identifier, "stop": stop, @@ -118,6 +142,7 @@ class Experiment: "local_dir": os.path.abspath( os.path.expanduser(local_dir or DEFAULT_RESULTS_DIR)), "upload_dir": upload_dir, + "remote_checkpoint_dir": self.remote_checkpoint_dir, "trial_name_creator": trial_name_creator, "loggers": loggers, "sync_to_driver": sync_to_driver, @@ -131,8 +156,6 @@ class Experiment: "restore": os.path.abspath(os.path.expanduser(restore)) if restore else None } - - self.name = name or self._run_identifier self.spec = spec @classmethod @@ -204,11 +227,6 @@ class Experiment: if self.local_dir: return os.path.join(self.local_dir, self.name) - @property - def remote_checkpoint_dir(self): - if self.spec["upload_dir"]: - return os.path.join(self.spec["upload_dir"], self.name) - @property def run_identifier(self): """Returns a string representing the trainable identifier.""" diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index 1fd8dc0e1..4550f5a88 100644 --- a/python/ray/tune/logger.py +++ b/python/ray/tune/logger.py @@ -430,11 +430,13 @@ class UnifiedLogger(Logger): for _logger in self._loggers: _logger.close() - def flush(self): + def flush(self, sync_down=True): for _logger in self._loggers: _logger.flush() - if not self._log_syncer.sync_down(): - logger.warning("Trial %s: Post-flush sync skipped.", self.trial) + if sync_down: + if not self._log_syncer.sync_down(): + logger.warning("Trial %s: Post-flush sync skipped.", + self.trial) def sync_up(self): return self._log_syncer.sync_up() diff --git a/python/ray/tune/progress_reporter.py b/python/ray/tune/progress_reporter.py index 
34a1de5fe..2b3a24d5b 100644 --- a/python/ray/tune/progress_reporter.py +++ b/python/ray/tune/progress_reporter.py @@ -238,7 +238,7 @@ def _get_trial_info(trial, parameters, metrics): metrics (List[str]): Names of metrics to include. """ result = flatten_dict(trial.last_result) - trial_info = [str(trial), trial.status, str(trial.address)] + trial_info = [str(trial), trial.status, str(trial.location)] trial_info += [result.get(CONFIG_PREFIX + param) for param in parameters] trial_info += [result.get(metric) for metric in metrics] return trial_info diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index 1c763f37b..67af35f91 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -13,10 +13,12 @@ import ray from ray.exceptions import RayTimeoutError from ray import ray_constants from ray.resource_spec import ResourceSpec +from ray.tune.durable_trainable import DurableTrainable from ray.tune.error import AbortTrialExecution from ray.tune.logger import NoopLogger -from ray.tune.trial import Trial, Checkpoint, Location from ray.tune.resources import Resources +from ray.tune.trainable import TrainableUtil +from ray.tune.trial import Trial, Checkpoint, Location from ray.tune.trial_executor import TrialExecutor from ray.tune.util import warn_if_slow from ray.tune.error import TuneError @@ -27,7 +29,6 @@ RESOURCE_REFRESH_PERIOD = 0.5 # Refresh resources every 500 ms BOTTLENECK_WARN_PERIOD_S = 60 NONTRIVIAL_WAIT_TIME_THRESHOLD_S = 1e-3 DEFAULT_GET_TIMEOUT = 30.0 # seconds -TRIAL_START_ATTEMPTS = 3 class _LocalWrapper: @@ -86,7 +87,7 @@ class RayTrialExecutor(TrialExecutor): self._cached_actor) existing_runner = self._cached_actor self._cached_actor = None - trial.runner = existing_runner + trial.set_runner(existing_runner) if not self.reset_trial(trial, trial.config, trial.experiment_tag): raise AbortTrialExecution( "Trainable runner reuse requires reset_config() to be " @@ -122,11 +123,16 @@ 
class RayTrialExecutor(TrialExecutor): logger.debug("Trial %s: Setting up new remote runner.", trial) # Logging for trials is handled centrally by TrialRunner, so # configure the remote runner to use a noop-logger. - return cls.remote(config=trial.config, logger_creator=logger_creator) + kwargs = { + "config": trial.config, + "logger_creator": logger_creator, + } + if issubclass(trial.get_trainable_cls(), DurableTrainable): + kwargs["remote_checkpoint_dir"] = trial.remote_checkpoint_dir + return cls.remote(**kwargs) def _train(self, trial): """Start one iteration of training and save remote id.""" - if self._find_item(self._paused, trial): raise TuneError( "Should not call `train` on PAUSED trial {}. " @@ -168,9 +174,11 @@ class RayTrialExecutor(TrialExecutor): """ prior_status = trial.status self.set_status(trial, Trial.RUNNING) - trial.runner = runner or self._setup_remote_runner( - trial, - reuse_allowed=checkpoint is not None or trial.has_checkpoint()) + trial.set_runner( + runner or self._setup_remote_runner( + trial, + reuse_allowed=checkpoint is not None + or trial.has_checkpoint())) self.restore(trial, checkpoint) previous_run = self._find_item(self._paused, trial) @@ -178,7 +186,7 @@ class RayTrialExecutor(TrialExecutor): # If Trial was in flight when paused, self._paused stores result. self._paused.pop(previous_run[0]) self._running[previous_run[0]] = trial - else: + elif not trial.is_restoring: self._train(trial) def _stop_trial(self, trial, error=False, error_msg=None, @@ -194,7 +202,6 @@ class RayTrialExecutor(TrialExecutor): error_msg (str): Optional error message. stop_logger (bool): Whether to shut down the trial logger. 
""" - if stop_logger: trial.close_logger() @@ -206,7 +213,7 @@ class RayTrialExecutor(TrialExecutor): if hasattr(trial, "runner") and trial.runner: if (not error and self._reuse_actors and self._cached_actor is None): - logger.debug("Reusing actor for {}".format(trial.runner)) + logger.debug("Reusing actor for %s", trial.runner) self._cached_actor = trial.runner else: logger.debug("Trial %s: Destroying actor.", trial) @@ -216,7 +223,7 @@ class RayTrialExecutor(TrialExecutor): logger.exception("Trial %s: Error stopping runner.", trial) self.set_status(trial, Trial.ERROR) finally: - trial.runner = None + trial.set_runner(None) def start_trial(self, trial, checkpoint=None): """Starts the trial. @@ -229,49 +236,22 @@ class RayTrialExecutor(TrialExecutor): of trial. """ self._commit_resources(trial.resources) - remote_runner = None - attempts = 0 - while attempts < TRIAL_START_ATTEMPTS: - attempts += 1 - if attempts > 1: - logger.warning("Trial %s: Start attempt #%s...", trial, - attempts) - try: - self._start_trial(trial, checkpoint, remote_runner) - break - except AbortTrialExecution: - logger.exception("Trial %s: Error starting runner, aborting!", - trial) - time.sleep(2) - error_msg = traceback.format_exc() - self._stop_trial(trial, error=True, error_msg=error_msg) - break # don't retry fatal Tune errors - except RayTimeoutError: - # Reuse the existing runner on retries. - remote_runner = trial.runner - warning = ("Runner task timed out. This could be due to " - "slow worker startup.") - if attempts == TRIAL_START_ATTEMPTS: - error_msg = traceback.format_exc() - self._stop_trial(trial, error=True, error_msg=error_msg) - else: - warning += " Reusing the same runner." 
- logger.warning("Trial %s: %s", trial, warning) - except Exception: - logger.exception("Trial %s: Error starting runner.", trial) - time.sleep(2) - error_msg = traceback.format_exc() - self._stop_trial(trial, error=True, error_msg=error_msg) - remote_runner = None - # This forces the trial to not start from checkpoint. - checkpoint = None - trial.clear_checkpoint() - # Note that we don't return the resources, since they may - # have been lost. TODO(ujvl): is this the right thing to do? - else: - logger.exception( - "Trial %s: Aborting trial after %s start " - "attempts!", trial, TRIAL_START_ATTEMPTS) + try: + self._start_trial(trial, checkpoint) + except AbortTrialExecution: + logger.exception("Trial %s: Error starting runner, aborting!", + trial) + time.sleep(2) + error_msg = traceback.format_exc() + self._stop_trial(trial, error=True, error_msg=error_msg) + except Exception: + logger.exception("Trial %s: Unexpected error starting runner.", + trial) + time.sleep(2) + error_msg = traceback.format_exc() + self._stop_trial(trial, error=True, error_msg=error_msg) + # Note that we don't return the resources, since they may + # have been lost. TODO(ujvl): is this the right thing to do? def _find_item(self, dictionary, item): out = [rid for rid, t in dictionary.items() if t is item] @@ -332,7 +312,6 @@ class RayTrialExecutor(TrialExecutor): def get_running_trials(self): """Returns the running trials.""" - return list(self._running.values()) def get_alive_node_ips(self): @@ -387,7 +366,8 @@ class RayTrialExecutor(TrialExecutor): """Fetches one result of the running trials. Returns: - Result of the most recent trial training run.""" + Result of the most recent trial training run. 
+ """ trial_future = self._find_item(self._running, trial) if not trial_future: raise ValueError("Trial was not running.") @@ -437,6 +417,7 @@ class RayTrialExecutor(TrialExecutor): "Resource invalid: {}".format(resources)) def _update_avail_resources(self, num_retries=5): + resources = None for i in range(num_retries): try: resources = ray.cluster_resources() @@ -520,7 +501,6 @@ class RayTrialExecutor(TrialExecutor): def debug_string(self): """Returns a human readable message for printing to the console.""" - if self._resources_initialized: status = ("Resources requested: {}/{} CPUs, {}/{} GPUs, " "{}/{} GiB heap, {}/{} GiB objects".format( @@ -548,7 +528,6 @@ class RayTrialExecutor(TrialExecutor): def resource_string(self): """Returns a string describing the total resources available.""" - if self._resources_initialized: res_str = ("{} CPUs, {} GPUs, " "{} GiB heap, {} GiB objects".format( @@ -570,18 +549,28 @@ class RayTrialExecutor(TrialExecutor): """Before step() called, update the available resources.""" self._update_avail_resources() - def save(self, trial, storage=Checkpoint.DISK, result=None): - """Saves the trial's state to a checkpoint.""" - result = result or trial.last_result + def save(self, trial, storage=Checkpoint.PERSISTENT, result=None): + """Saves the trial's state to a checkpoint. + Args: + trial (Trial): The state of this trial to be saved. + storage (str): Where to store the checkpoint. Defaults to + PERSISTENT. + result (dict): The state of this trial as a dictionary to be saved. + If result is None, the trial's last result will be used. + + Returns: + Checkpoint future, or None if an Exception occurs. + """ + result = result or trial.last_result if storage == Checkpoint.MEMORY: value = trial.runner.save_to_object.remote() checkpoint = Checkpoint(storage, value, result) else: - with warn_if_slow("save_checkpoint_to_disk"): + with warn_if_slow("save_checkpoint_to_storage"): + # TODO(ujvl): Make this asynchronous. 
value = ray.get(trial.runner.save.remote()) checkpoint = Checkpoint(storage, value, result) - with warn_if_slow("on_checkpoint", DEFAULT_GET_TIMEOUT) as profile: try: trial.on_checkpoint(checkpoint) @@ -600,13 +589,10 @@ class RayTrialExecutor(TrialExecutor): def restore(self, trial, checkpoint=None): """Restores training state from a given model checkpoint. - This will also sync the trial results to a new location - if restoring on a different node. - Raises: RuntimeError: This error is raised if no runner is found. - RayTimeoutError: This error is raised if a remote call to the - runner times out. + AbortTrialExecution: This error is raised if the trial is + ineligible for restoration, given the Tune input arguments. """ if checkpoint is None or checkpoint.value is None: checkpoint = trial.checkpoint @@ -617,19 +603,29 @@ class RayTrialExecutor(TrialExecutor): "Trial {}: Unable to restore - no runner found.".format(trial)) value = checkpoint.value if checkpoint.storage == Checkpoint.MEMORY: - assert not isinstance(value, Checkpoint), type(value) + logger.debug("Trial %s: Attempting restore from object", trial) + # Note that we don't store the remote since in-memory checkpoints + # don't guarantee fault tolerance and don't need to be waited on. trial.runner.restore_from_object.remote(value) else: - logger.info("Trial %s: Attempting restore from %s", trial, value) - with warn_if_slow("get_current_ip"): - worker_ip = ray.get(trial.runner.current_ip.remote(), - DEFAULT_GET_TIMEOUT) - with warn_if_slow("sync_to_new_location"): - trial.sync_logger_to_new_location(worker_ip) - with warn_if_slow("restore_from_disk"): - # TODO(ujvl): Take blocking restores out of the control loop. 
- ray.get(trial.runner.restore.remote(value)) - trial.last_result = checkpoint.result + logger.debug("Trial %s: Attempting restore from %s", trial, value) + if issubclass(trial.get_trainable_cls(), DurableTrainable): + remote = trial.runner.restore.remote(value) + elif trial.sync_on_checkpoint: + # This provides FT backwards compatibility in the + # case where a DurableTrainable is not provided. + logger.warning("Trial %s: Reading checkpoint into memory.", + trial) + data_dict = TrainableUtil.pickle_checkpoint(value) + remote = trial.runner.restore_from_object.remote(data_dict) + else: + raise AbortTrialExecution( + "Pass in `sync_on_checkpoint=True` for driver-based trial " + "restoration. Pass in an `upload_dir` and a Trainable " + "extending `DurableTrainable` for remote storage-based " + "restoration") + self._running[remote] = trial + trial.restoring_from = checkpoint def export_trial_if_needed(self, trial): """Exports model of this trial based on trial.export_formats. diff --git a/python/ray/tune/sync_client.py b/python/ray/tune/sync_client.py index a6a1ef144..ece09c496 100644 --- a/python/ray/tune/sync_client.py +++ b/python/ray/tune/sync_client.py @@ -28,25 +28,31 @@ def noop(*args): return -def get_sync_client(sync_function): +def get_sync_client(sync_function, delete_function=None): """Returns a sync client. Args: sync_function (Optional[str|function]): Sync function. + delete_function (Optional[str|function]): Delete function. Must be + the same type as sync_function if it is provided. Raises: - ValueError if sync_function is malformed. + ValueError if sync_function or delete_function are malformed. 
""" if sync_function is None: return None + if delete_function and type(sync_function) != type(delete_function): + raise ValueError("Sync and delete functions must be of same type.") if isinstance(sync_function, types.FunctionType): + delete_function = delete_function or noop client_cls = FunctionBasedClient elif isinstance(sync_function, str): + delete_function = delete_function or noop_template client_cls = CommandBasedClient else: raise ValueError("Sync function {} must be string or function".format( sync_function)) - return client_cls(sync_function, sync_function) + return client_cls(sync_function, sync_function, delete_function) def get_cloud_sync_client(remote_path): @@ -63,17 +69,19 @@ def get_cloud_sync_client(remote_path): raise ValueError( "Upload uri starting with '{}' requires awscli tool" " to be installed".format(S3_PREFIX)) - template = "aws s3 sync {source} {target}" + template = "aws s3 sync {source} {target} --only-show-errors" + delete_template = "aws s3 rm {target} --recursive --only-show-errors" elif remote_path.startswith(GS_PREFIX): if not distutils.spawn.find_executable("gsutil"): raise ValueError( "Upload uri starting with '{}' requires gsutil tool" " to be installed".format(GS_PREFIX)) template = "gsutil rsync -r {source} {target}" + delete_template = "gsutil rm -r {target}" else: raise ValueError("Upload uri must start with one of: {}" "".format(ALLOWED_REMOTE_PREFIXES)) - return CommandBasedClient(template, template) + return CommandBasedClient(template, template, delete_template) class SyncClient: @@ -103,6 +111,17 @@ class SyncClient: """ raise NotImplementedError + def delete(self, target): + """Deletes target. + + Args: + target (str): Target path. + + Returns: + True if delete initiation successful, False otherwise. 
+ """ + raise NotImplementedError + def wait(self): """Waits for current sync to complete, if asynchronously started.""" pass @@ -113,9 +132,10 @@ class SyncClient: class FunctionBasedClient(SyncClient): - def __init__(self, sync_up_func, sync_down_func): + def __init__(self, sync_up_func, sync_down_func, delete_func=None): self.sync_up_func = sync_up_func self.sync_down_func = sync_down_func + self.delete_func = delete_func or noop def sync_up(self, source, target): self.sync_up_func(source, target) @@ -125,12 +145,19 @@ class FunctionBasedClient(SyncClient): self.sync_down_func(source, target) return True + def delete(self, target): + self.delete_func(target) + return True + NOOP = FunctionBasedClient(noop, noop) class CommandBasedClient(SyncClient): - def __init__(self, sync_up_template, sync_down_template): + def __init__(self, + sync_up_template, + sync_down_template, + delete_template=noop_template): """Syncs between two directories with the given command. Arguments: @@ -138,11 +165,14 @@ class CommandBasedClient(SyncClient): include replacement fields '{source}' and '{target}'. sync_down_template (str): A runnable string template; needs to include replacement fields '{source}' and '{target}'. + delete_template (Optional[str]): A runnable string template; needs + to include replacement field '{target}'. Noop by default. """ self._validate_sync_string(sync_up_template) self._validate_sync_string(sync_down_template) self.sync_up_template = sync_up_template self.sync_down_template = sync_down_template + self.delete_template = delete_template self.logfile = None self.cmd_process = None @@ -153,7 +183,7 @@ class CommandBasedClient(SyncClient): logdir (str): Log directory. 
""" self.logfile = tempfile.NamedTemporaryFile( - prefix="log_sync", dir=logdir, suffix=".log", delete=False) + prefix="log_sync_out", dir=logdir, suffix=".log", delete=False) def sync_up(self, source, target): return self._execute(self.sync_up_template, source, target) @@ -161,14 +191,27 @@ class CommandBasedClient(SyncClient): def sync_down(self, source, target): return self._execute(self.sync_down_template, source, target) + def delete(self, target): + if self.is_running: + logger.warning("Last sync client cmd still in progress, skipping.") + return False + final_cmd = self.delete_template.format(target=quote(target)) + logger.debug("Running delete: {}".format(final_cmd)) + self.cmd_process = subprocess.Popen( + final_cmd, shell=True, stderr=subprocess.PIPE, stdout=self.logfile) + return True + def wait(self): if self.cmd_process: _, error_msg = self.cmd_process.communicate() error_msg = error_msg.decode("ascii") code = self.cmd_process.returncode + args = self.cmd_process.args self.cmd_process = None if code != 0: - raise TuneError("Sync error ({}): {}".format(code, error_msg)) + raise TuneError("Sync error. 
Ran command: {}\n" + "Error message ({}): {}".format( + args, code, error_msg)) def reset(self): if self.is_running: @@ -177,7 +220,7 @@ class CommandBasedClient(SyncClient): @property def is_running(self): - """Returns whether a sync process is running.""" + """Returns whether a sync or delete process is running.""" if self.cmd_process: self.cmd_process.poll() return self.cmd_process.returncode is None diff --git a/python/ray/tune/syncer.py b/python/ray/tune/syncer.py index a4c69fc2b..1f2aecc2e 100644 --- a/python/ray/tune/syncer.py +++ b/python/ray/tune/syncer.py @@ -191,6 +191,8 @@ class NodeSyncer(Syncer): def sync_down(self): if not self.has_remote_target(): return True + logger.debug("Syncing from %s to %s", self._remote_path, + self._local_dir) return super(NodeSyncer, self).sync_down() @property @@ -250,23 +252,23 @@ def get_node_syncer(local_dir, remote_dir=None, sync_function=None): Args: local_dir (str): Source directory for syncing. remote_dir (str): Target directory for syncing. If not provided, a - no-op Syncer is returned. - sync_function (func|str): Function for syncing the local_dir to + noop Syncer is returned. + sync_function (func|str|bool): Function for syncing the local_dir to remote_dir. If string, then it must be a string template for - syncer to run. If not provided, it defaults rsync. + syncer to run. If True or not provided, it defaults rsync. If + False, a noop Syncer is returned. 
""" key = (local_dir, remote_dir) if key in _syncers: return _syncers[key] - elif not remote_dir: + elif not remote_dir or sync_function is False: sync_client = NOOP - elif sync_function: + elif sync_function and sync_function is not True: sync_client = get_sync_client(sync_function) else: - sync_up = log_sync_template() - sync_down = log_sync_template(options="--remove-source-files") - if sync_up and sync_down: - sync_client = CommandBasedClient(sync_up, sync_down) + sync = log_sync_template() + if sync: + sync_client = CommandBasedClient(sync, sync) sync_client.set_logdir(local_dir) else: sync_client = NOOP diff --git a/python/ray/tune/tests/test_api.py b/python/ray/tune/tests/test_api.py index 8b0ed184e..6b219b8b3 100644 --- a/python/ray/tune/tests/test_api.py +++ b/python/ray/tune/tests/test_api.py @@ -2,16 +2,19 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import shutil + import copy import os import time import unittest +from unittest.mock import patch import ray from ray.rllib import _register_all from ray import tune -from ray.tune import Trainable, TuneError +from ray.tune import DurableTrainable, Trainable, TuneError from ray.tune import register_env, register_trainable, run_experiments from ray.tune.schedulers import TrialScheduler, FIFOScheduler from ray.tune.trial import Trial @@ -25,6 +28,7 @@ from ray.tune.experiment import Experiment from ray.tune.resources import Resources from ray.tune.suggest import grid_search from ray.tune.suggest.suggestion import _MockSuggestionAlgorithm +from ray.tune.utils.mock import mock_storage_client, MOCK_REMOTE_DIR class TrainableFunctionApiTest(unittest.TestCase): @@ -703,6 +707,34 @@ class TrainableFunctionApiTest(unittest.TestCase): ] self.assertTrue(all(complete_results1)) + def testDurableTrainable(self): + class TestTrain(DurableTrainable): + def _setup(self, config): + self.state = {"hi": 1, "iter": 0} + + def _train(self): + self.state["iter"] 
+= 1 + return {"timesteps_this_iter": 1, "done": True} + + def _save(self, path): + return self.state + + def _restore(self, state): + self.state = state + + sync_client = mock_storage_client() + mock_get_client = "ray.tune.durable_trainable.get_cloud_sync_client" + with patch(mock_get_client) as mock_get_cloud_sync_client: + mock_get_cloud_sync_client.return_value = sync_client + test_trainable = TestTrain(remote_checkpoint_dir=MOCK_REMOTE_DIR) + checkpoint_path = test_trainable.save() + test_trainable.train() + test_trainable.state["hi"] = 2 + test_trainable.restore(checkpoint_path) + self.assertEqual(test_trainable.state["hi"], 1) + + self.addCleanup(shutil.rmtree, MOCK_REMOTE_DIR) + def testCheckpointDict(self): class TestTrain(Trainable): def _setup(self, config): diff --git a/python/ray/tune/tests/test_checkpoint_manager.py b/python/ray/tune/tests/test_checkpoint_manager.py index 312aca3dc..93ea38d37 100644 --- a/python/ray/tune/tests/test_checkpoint_manager.py +++ b/python/ray/tune/tests/test_checkpoint_manager.py @@ -16,23 +16,28 @@ class CheckpointManagerTest(unittest.TestCase): def mock_result(i): return {"i": i} + def checkpoint_manager(self, keep_checkpoints_num): + return CheckpointManager( + keep_checkpoints_num, "i", delete_fn=lambda c: None) + def testOnCheckpointOrdered(self): """ Tests increasing priorities. Also tests that that the worst checkpoints are deleted when necessary. 
""" keep_checkpoints_num = 2 - checkpoint_manager = CheckpointManager(keep_checkpoints_num, "i") + checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num) checkpoints = [ - Checkpoint(Checkpoint.DISK, {i}, self.mock_result(i)) + Checkpoint(Checkpoint.PERSISTENT, {i}, self.mock_result(i)) for i in range(3) ] - with patch("shutil.rmtree") as rmtree_mock, patch("os.path"): + with patch.object(checkpoint_manager, "delete") as \ + delete_mock: for j in range(3): checkpoint_manager.on_checkpoint(checkpoints[j]) expected_deletes = 0 if j != 2 else 1 - self.assertEqual(rmtree_mock.call_count, expected_deletes) + self.assertEqual(delete_mock.call_count, expected_deletes, j) self.assertEqual(checkpoint_manager.newest_checkpoint, checkpoints[j]) @@ -47,17 +52,17 @@ class CheckpointManagerTest(unittest.TestCase): that the worst checkpoints are deleted when necessary. """ keep_checkpoints_num = 2 - checkpoint_manager = CheckpointManager(keep_checkpoints_num, "i") + checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num) checkpoints = [ - Checkpoint(Checkpoint.DISK, {i}, self.mock_result(i)) + Checkpoint(Checkpoint.PERSISTENT, {i}, self.mock_result(i)) for i in range(3, -1, -1) ] - with patch("shutil.rmtree") as rmtree_mock, patch("os.path"): + with patch.object(checkpoint_manager, "delete") as delete_mock: for j in range(0, len(checkpoints)): checkpoint_manager.on_checkpoint(checkpoints[j]) expected_deletes = 0 if j != 3 else 1 - self.assertEqual(rmtree_mock.call_count, expected_deletes) + self.assertEqual(delete_mock.call_count, expected_deletes) self.assertEqual(checkpoint_manager.newest_checkpoint, checkpoints[j]) @@ -71,7 +76,7 @@ class CheckpointManagerTest(unittest.TestCase): Tests that the best checkpoints are tracked and ordered correctly. 
""" keep_checkpoints_num = 4 - checkpoint_manager = CheckpointManager(keep_checkpoints_num, "i") + checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num) checkpoints = [ Checkpoint(Checkpoint.MEMORY, i, self.mock_result(i)) for i in range(16) @@ -92,7 +97,7 @@ class CheckpointManagerTest(unittest.TestCase): checkpoint has no checkpoint score attribute. """ keep_checkpoints_num = 1 - checkpoint_manager = CheckpointManager(keep_checkpoints_num, "i") + checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num) no_attr_checkpoint = Checkpoint(Checkpoint.MEMORY, 0, {}) with patch.object(logger, "error") as log_error_mock: diff --git a/python/ray/tune/tests/test_cluster.py b/python/ray/tune/tests/test_cluster.py index aa6b9c8e1..cd78d964b 100644 --- a/python/ray/tune/tests/test_cluster.py +++ b/python/ray/tune/tests/test_cluster.py @@ -9,20 +9,26 @@ import os import pytest import shutil import sys -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import ray from ray import tune from ray.rllib import _register_all from ray.cluster_utils import Cluster from ray.test_utils import run_string_as_driver_nonblocking +from ray.tune import register_trainable +from ray.tune.experiment import Experiment from ray.tune.error import TuneError from ray.tune.ray_trial_executor import RayTrialExecutor -from ray.tune.experiment import Experiment -from ray.tune.trial import Trial from ray.tune.resources import Resources -from ray.tune.trial_runner import TrialRunner from ray.tune.suggest import BasicVariantGenerator +from ray.tune.syncer import Syncer +from ray.tune.trainable import TrainableUtil +from ray.tune.trial import Trial +from ray.tune.trial_runner import TrialRunner +from ray.tune.utils.mock import (MockDurableTrainer, MockRemoteTrainer, + MockNodeSyncer, mock_storage_client, + MOCK_REMOTE_DIR) def _start_new_cluster(): @@ -36,6 +42,8 @@ def _start_new_cluster(): }) }) # Pytest doesn't play nicely with imports + 
register_trainable("__fake_remote", MockRemoteTrainer) + register_trainable("__fake_durable", MockDurableTrainer) _register_all() return cluster @@ -53,7 +61,6 @@ def start_connected_cluster(): @pytest.fixture def start_connected_emptyhead_cluster(): """Starts head with no resources.""" - cluster = Cluster( initialize_head=True, connect=True, @@ -65,6 +72,8 @@ def start_connected_emptyhead_cluster(): }) # Pytest doesn't play nicely with imports _register_all() + register_trainable("__fake_remote", MockRemoteTrainer) + register_trainable("__fake_durable", MockDurableTrainer) yield cluster # The code after the yield will run as teardown code. ray.shutdown() @@ -208,7 +217,8 @@ def test_queue_trials(start_connected_emptyhead_cluster): assert gpu_trial.status == Trial.TERMINATED -def test_trial_migration(start_connected_emptyhead_cluster): +@pytest.mark.parametrize("trainable_id", ["__fake", "__fake_durable"]) +def test_trial_migration(start_connected_emptyhead_cluster, trainable_id): """Removing a node while cluster has space should migrate trial. The trial state should also be consistent with the checkpoint. @@ -220,14 +230,16 @@ def test_trial_migration(start_connected_emptyhead_cluster): runner = TrialRunner(BasicVariantGenerator()) kwargs = { "stopping_criterion": { - "training_iteration": 3 + "training_iteration": 4 }, "checkpoint_freq": 2, - "max_failures": 2 + "max_failures": 2, + "remote_checkpoint_dir": MOCK_REMOTE_DIR, + "sync_to_driver_fn": trainable_id == "__fake", } # Test recovery of trial that hasn't been checkpointed - t = Trial("__fake", **kwargs) + t = Trial(trainable_id, **kwargs) runner.add_trial(t) runner.step() # start runner.step() # 1 result @@ -241,13 +253,13 @@ def test_trial_migration(start_connected_emptyhead_cluster): # because checkpoint handling is messy and should be refactored # rather than hotfixed. # assert t.last_result is None, "Trial result not restored correctly." 
- for i in range(3): + for i in range(4): runner.step() assert t.status == Trial.TERMINATED # Test recovery of trial that has been checkpointed - t2 = Trial("__fake", **kwargs) + t2 = Trial(trainable_id, **kwargs) runner.add_trial(t2) runner.step() # start runner.step() # 1 result @@ -256,13 +268,23 @@ def test_trial_migration(start_connected_emptyhead_cluster): node3 = cluster.add_node(num_cpus=1) cluster.remove_node(node2) cluster.wait_for_nodes() + runner.step() # 3 result + start and fail 4 result runner.step() # Recovery step + runner.step() # Process recovery + runner.step() # result if t2.status != Trial.TERMINATED: runner.step() assert t2.status == Trial.TERMINATED, runner.debug_string() # Test recovery of trial that won't be checkpointed - t3 = Trial("__fake", **{"stopping_criterion": {"training_iteration": 3}}) + kwargs = { + "stopping_criterion": { + "training_iteration": 3 + }, + "remote_checkpoint_dir": MOCK_REMOTE_DIR, + "sync_to_driver_fn": trainable_id == "__fake", + } + t3 = Trial(trainable_id, **kwargs) runner.add_trial(t3) runner.step() # start runner.step() # 1 result @@ -278,7 +300,8 @@ def test_trial_migration(start_connected_emptyhead_cluster): runner.step() -def test_trial_requeue(start_connected_emptyhead_cluster): +@pytest.mark.parametrize("trainable_id", ["__fake", "__fake_durable"]) +def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id): """Removing a node in full cluster causes Trial to be requeued.""" cluster = start_connected_emptyhead_cluster node = cluster.add_node(num_cpus=1) @@ -290,10 +313,12 @@ def test_trial_requeue(start_connected_emptyhead_cluster): "training_iteration": 5 }, "checkpoint_freq": 1, - "max_failures": 1 + "max_failures": 1, + "remote_checkpoint_dir": MOCK_REMOTE_DIR, + "sync_to_driver_fn": trainable_id == "__fake", } - trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)] + trials = [Trial(trainable_id, **kwargs), Trial(trainable_id, **kwargs)] for t in trials: runner.add_trial(t) @@ 
-309,7 +334,9 @@ def test_trial_requeue(start_connected_emptyhead_cluster): runner.step() -def test_migration_checkpoint_removal(start_connected_emptyhead_cluster): +@pytest.mark.parametrize("trainable_id", ["__fake_remote", "__fake_durable"]) +def test_migration_checkpoint_removal(start_connected_emptyhead_cluster, + trainable_id): """Test checks that trial restarts if checkpoint is lost w/ node fail.""" cluster = start_connected_emptyhead_cluster node = cluster.add_node(num_cpus=1) @@ -318,32 +345,59 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster): runner = TrialRunner(BasicVariantGenerator()) kwargs = { "stopping_criterion": { - "training_iteration": 3 + "training_iteration": 4 }, "checkpoint_freq": 2, - "max_failures": 2 + "max_failures": 2, + "remote_checkpoint_dir": MOCK_REMOTE_DIR, + "sync_to_driver_fn": trainable_id == "__fake_remote", } - # Test recovery of trial that has been checkpointed - t1 = Trial("__fake", **kwargs) - runner.add_trial(t1) - runner.step() # start - runner.step() # 1 result - runner.step() # 2 result and checkpoint - assert t1.has_checkpoint() - cluster.add_node(num_cpus=1) - cluster.remove_node(node) - cluster.wait_for_nodes() - shutil.rmtree(os.path.dirname(t1.checkpoint.value)) + # The following patches only affect __fake_remote. 
+ find_checkpoint_dir = TrainableUtil.find_checkpoint_dir + with patch("ray.tune.logger.get_node_syncer") as mock_get_node_syncer: + trainable_util = "ray.tune.ray_trial_executor.TrainableUtil" + with patch(trainable_util + ".find_checkpoint_dir") as mock_find_dir: - runner.step() # Recovery step - for i in range(3): - if t1.status != Trial.TERMINATED: - runner.step() + def mock_get_syncer_fn(local_dir, remote_dir, sync_function): + client = mock_storage_client() + return MockNodeSyncer(local_dir, remote_dir, client) + + mock_get_node_syncer.side_effect = mock_get_syncer_fn + + def mock_find_dir_fn(checkpoint_path): + """Converts back to local path first.""" + checkpoint_path = checkpoint_path[len(MOCK_REMOTE_DIR):] + checkpoint_path = os.path.join("/", checkpoint_path) + return find_checkpoint_dir(checkpoint_path) + + # __fake_remote trainables save to a separate "remote" directory. + # TrainableUtil will not check this path unless we mock it. + mock_find_dir.side_effect = mock_find_dir_fn + + # Test recovery of trial that has been checkpointed + t1 = Trial(trainable_id, **kwargs) + runner.add_trial(t1) + runner.step() # start + runner.step() # 1 result + runner.step() # 2 result and checkpoint + assert t1.has_checkpoint() + cluster.add_node(num_cpus=1) + cluster.remove_node(node) + cluster.wait_for_nodes() + shutil.rmtree(os.path.dirname(t1.checkpoint.value)) + + runner.step() # collect result 3, kick off + fail result 4 + runner.step() # Recovery step + runner.step() # Process Recovery + step 4 + for i in range(3): + if t1.status != Trial.TERMINATED: + runner.step() assert t1.status == Trial.TERMINATED, runner.debug_string() -def test_cluster_down_simple(start_connected_cluster, tmpdir): +@pytest.mark.parametrize("trainable_id", ["__fake", "__fake_durable"]) +def test_cluster_down_simple(start_connected_cluster, tmpdir, trainable_id): """Tests that TrialRunner save/restore works on cluster shutdown.""" cluster = start_connected_cluster 
cluster.add_node(num_cpus=1) @@ -356,9 +410,11 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir): "training_iteration": 2 }, "checkpoint_freq": 1, - "max_failures": 1 + "max_failures": 1, + "remote_checkpoint_dir": MOCK_REMOTE_DIR, + "sync_to_driver_fn": trainable_id == "__fake", } - trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)] + trials = [Trial(trainable_id, **kwargs), Trial(trainable_id, **kwargs)] for t in trials: runner.add_trial(t) @@ -374,6 +430,7 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir): cluster = _start_new_cluster() runner = TrialRunner(resume="LOCAL", local_checkpoint_dir=dirpath) runner.step() # start + runner.step() # process restore runner.step() # start2 for i in range(3): @@ -387,26 +444,31 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir): cluster.shutdown() -def test_cluster_down_full(start_connected_cluster, tmpdir): +@pytest.mark.parametrize("trainable_id", ["__fake", "__fake_durable"]) +def test_cluster_down_full(start_connected_cluster, tmpdir, trainable_id): """Tests that run_experiment restoring works on cluster shutdown.""" cluster = start_connected_cluster dirpath = str(tmpdir) - exp1_args = dict( - run="__fake", + use_default_sync = trainable_id == "__fake" + from ray.tune.result import DEFAULT_RESULTS_DIR + local_dir = DEFAULT_RESULTS_DIR + upload_dir = None if use_default_sync else MOCK_REMOTE_DIR + + base_dict = dict( + run=trainable_id, stop=dict(training_iteration=3), - local_dir=dirpath, - checkpoint_freq=1) - exp2_args = dict(run="__fake", stop=dict(training_iteration=3)) - exp3_args = dict( - run="__fake", - stop=dict(training_iteration=3), - config=dict(mock_error=True)) + local_dir=local_dir, + upload_dir=upload_dir, + sync_to_driver=use_default_sync, + ) + + exp1_args = base_dict + exp2_args = dict(base_dict.items(), local_dir=dirpath, checkpoint_freq=1) + exp3_args = dict(base_dict.items(), config=dict(mock_error=True)) exp4_args = dict( - run="__fake", - 
stop=dict(training_iteration=3), - config=dict(mock_error=True), - checkpoint_freq=1) + base_dict.items(), config=dict(mock_error=True), checkpoint_freq=1) + all_experiments = { "exp1": exp1_args, "exp2": exp2_args, @@ -414,14 +476,20 @@ def test_cluster_down_full(start_connected_cluster, tmpdir): "exp4": exp4_args } - tune.run_experiments(all_experiments, raise_on_failed_trial=False) + mock_get_client = "ray.tune.trial_runner.get_cloud_syncer" + with patch(mock_get_client) as mock_get_cloud_syncer: + mock_syncer = Syncer(local_dir, upload_dir, mock_storage_client()) + mock_get_cloud_syncer.return_value = mock_syncer - ray.shutdown() - cluster.shutdown() - cluster = _start_new_cluster() + tune.run_experiments(all_experiments, raise_on_failed_trial=False) + + ray.shutdown() + cluster.shutdown() + cluster = _start_new_cluster() + + trials = tune.run_experiments( + all_experiments, resume=True, raise_on_failed_trial=False) - trials = tune.run_experiments( - all_experiments, resume=True, raise_on_failed_trial=False) assert len(trials) == 4 assert all(t.status in [Trial.TERMINATED, Trial.ERROR] for t in trials) ray.shutdown() diff --git a/python/ray/tune/tests/test_experiment.py b/python/ray/tune/tests/test_experiment.py index 355d20253..035ce8391 100644 --- a/python/ray/tune/tests/test_experiment.py +++ b/python/ray/tune/tests/test_experiment.py @@ -4,11 +4,25 @@ from __future__ import print_function import unittest +import ray +from ray.rllib import _register_all +from ray.tune import register_trainable from ray.tune.experiment import Experiment, convert_to_experiment_list from ray.tune.error import TuneError class ExperimentTest(unittest.TestCase): + def tearDown(self): + ray.shutdown() + _register_all() # re-register the evicted objects + + def setUp(self): + def train(config, reporter): + for i in range(100): + reporter(timesteps_total=i) + + register_trainable("f1", train) + def testConvertExperimentFromExperiment(self): exp1 = Experiment(**{ "name": "foo", diff 
--git a/python/ray/tune/tests/test_ray_trial_executor.py b/python/ray/tune/tests/test_ray_trial_executor.py index f3c7223e9..abdb49a9a 100644 --- a/python/ray/tune/tests/test_ray_trial_executor.py +++ b/python/ray/tune/tests/test_ray_trial_executor.py @@ -4,12 +4,9 @@ from __future__ import division from __future__ import print_function import json -import sys import unittest -from unittest.mock import patch import ray -from ray.exceptions import RayTimeoutError from ray.rllib import _register_all from ray.tune import Trainable from ray.tune.ray_trial_executor import RayTrialExecutor @@ -41,33 +38,11 @@ class RayTrialExecutorTest(unittest.TestCase): trial = Trial("__fake") self.trial_executor.start_trial(trial) self.assertEqual(Trial.RUNNING, trial.status) - self.trial_executor.save(trial, Checkpoint.DISK) + self.trial_executor.save(trial, Checkpoint.PERSISTENT) self.trial_executor.restore(trial) self.trial_executor.stop_trial(trial) self.assertEqual(Trial.TERMINATED, trial.status) - def testSaveRestoreTimeout(self): - trial = Trial("__fake") - self.trial_executor.start_trial(trial) - self.assertEqual(Trial.RUNNING, trial.status) - self.trial_executor.save(trial, Checkpoint.DISK) - self.trial_executor.set_status(trial, Trial.PAUSED) - - ray_get = ray.get - start_trial = self.trial_executor._start_trial - - # Timeout on first two attempts, then succeed on subsequent gets. - side_effects = [RayTimeoutError, RayTimeoutError, ray_get, ray_get] - with patch.object(self.trial_executor, "_start_trial") as mock_start: - with patch("ray.get", side_effect=side_effects): - mock_start.side_effect = start_trial - self.trial_executor.start_trial(trial, trial.checkpoint) - - # Trial starts successfully on 3rd attempt. 
- assert mock_start.call_count == 3 - self.assertEqual(Trial.RUNNING, trial.status) - self.trial_executor.stop_trial(trial) - def testPauseResume(self): """Tests that pausing works for trials in flight.""" trial = Trial("__fake") @@ -216,4 +191,5 @@ class LocalModeExecutorTest(RayTrialExecutorTest): if __name__ == "__main__": import pytest + import sys sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tune/tests/test_trainable_util.py b/python/ray/tune/tests/test_trainable_util.py new file mode 100644 index 000000000..d7a0f6689 --- /dev/null +++ b/python/ray/tune/tests/test_trainable_util.py @@ -0,0 +1,47 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import pickle +import shutil +import unittest + +from ray.tune.trainable import TrainableUtil + + +class TrainableUtilTest(unittest.TestCase): + def setUp(self): + self.checkpoint_dir = "/tmp/tune/MyTrainable123" + TrainableUtil.make_checkpoint_dir(self.checkpoint_dir) + + def tearDown(self): + self.addCleanup(shutil.rmtree, self.checkpoint_dir) + + def testFindCheckpointDir(self): + checkpoint_path = os.path.join(self.checkpoint_dir, "my/nested/chkpt") + os.makedirs(checkpoint_path) + found_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path) + self.assertEquals(self.checkpoint_dir, found_dir) + + with self.assertRaises(FileNotFoundError): + parent = os.path.dirname(found_dir) + TrainableUtil.find_checkpoint_dir(parent) + + def testPickleCheckpoint(self): + for i in range(5): + path = os.path.join(self.checkpoint_dir, str(i)) + with open(path, "w") as f: + f.write(str(i)) + + checkpoint_path = os.path.join(self.checkpoint_dir, "0") + + data_dict = TrainableUtil.pickle_checkpoint(checkpoint_path) + loaded = pickle.loads(data_dict) + + checkpoint_name = os.path.basename(checkpoint_path) + self.assertEqual(loaded["checkpoint_name"], checkpoint_name) + + for i in range(5): + path = os.path.join(self.checkpoint_dir, str(i)) 
+ self.assertEquals(loaded["data"][str(i)], open(path, "rb").read()) diff --git a/python/ray/tune/tests/test_trial_runner.py b/python/ray/tune/tests/test_trial_runner.py index 582416bb3..6099a864f 100644 --- a/python/ray/tune/tests/test_trial_runner.py +++ b/python/ray/tune/tests/test_trial_runner.py @@ -19,9 +19,11 @@ from ray.tune.suggest import BasicVariantGenerator class TrialRunnerTest(unittest.TestCase): + def setUp(self): + _register_all() # re-register the evicted objects + def tearDown(self): ray.shutdown() - _register_all() # re-register the evicted objects def testTrialStatus(self): ray.init() diff --git a/python/ray/tune/tests/test_trial_runner_2.py b/python/ray/tune/tests/test_trial_runner_2.py index 60ea44301..3b76d9810 100644 --- a/python/ray/tune/tests/test_trial_runner_2.py +++ b/python/ray/tune/tests/test_trial_runner_2.py @@ -182,9 +182,11 @@ class TrialRunnerTest2(unittest.TestCase): runner.step() self.assertEqual(trials[0].status, Trial.RUNNING) self.assertEqual(trials[0].num_failures, 1) + runner.step() # Restore step runner.step() self.assertEqual(trials[0].status, Trial.RUNNING) self.assertEqual(trials[0].num_failures, 2) + runner.step() # Restore step runner.step() self.assertEqual(trials[0].status, Trial.ERROR) self.assertEqual(trials[0].num_failures, 3) @@ -242,6 +244,7 @@ class TrialRunnerTest2(unittest.TestCase): runner.step() self.assertEqual(trials[0].status, Trial.TERMINATED) self.assertEqual(trials[1].status, Trial.RUNNING) + runner.step() # Restore step runner.step() self.assertEqual(trials[1].last_result["timesteps_since_restore"], 10) self.assertEqual(trials[1].last_result["iterations_since_restore"], 1) diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index 863b8b588..9c73344d9 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -203,7 +203,7 @@ class _MockTrialExecutor(TrialExecutor): def restore(self, trial, 
checkpoint=None): pass - def save(self, trial, type=Checkpoint.DISK, result=None): + def save(self, trial, type=Checkpoint.PERSISTENT, result=None): return trial.trainable_name def reset_trial(self, trial, new_config, new_experiment_tag): diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index ae9c45df6..736a23981 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -29,6 +29,57 @@ logger = logging.getLogger(__name__) SETUP_TIME_THRESHOLD = 10 +class TrainableUtil: + @staticmethod + def pickle_checkpoint(checkpoint_path): + """Pickles checkpoint data.""" + checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path) + data = {} + for basedir, _, file_names in os.walk(checkpoint_dir): + for file_name in file_names: + path = os.path.join(basedir, file_name) + with open(path, "rb") as f: + data[os.path.relpath(path, checkpoint_dir)] = f.read() + # Use normpath so that a directory path isn't mapped to empty string. + name = os.path.basename(os.path.normpath(checkpoint_path)) + name += os.path.sep if os.path.isdir(checkpoint_path) else "" + data_dict = pickle.dumps({ + "checkpoint_name": name, + "data": data, + }) + return data_dict + + @staticmethod + def find_checkpoint_dir(checkpoint_path): + """Returns the directory containing the checkpoint path. + + Raises: + FileNotFoundError if the directory is not found. 
+ """ + if not os.path.exists(checkpoint_path): + raise FileNotFoundError("Path does not exist", checkpoint_path) + if os.path.isdir(checkpoint_path): + checkpoint_dir = checkpoint_path + else: + checkpoint_dir = os.path.dirname(checkpoint_path) + while checkpoint_dir != os.path.dirname(checkpoint_dir): + if os.path.exists(os.path.join(checkpoint_dir, ".is_checkpoint")): + break + checkpoint_dir = os.path.dirname(checkpoint_dir) + else: + raise FileNotFoundError("Checkpoint directory not found for {}" + .format(checkpoint_path)) + return checkpoint_dir + + @staticmethod + def make_checkpoint_dir(checkpoint_dir): + """Creates a checkpoint directory at the provided path.""" + if not os.path.exists(checkpoint_dir): + os.makedirs(checkpoint_dir) + # Drop marker in directory to identify it as a checkpoint dir. + open(os.path.join(checkpoint_dir, ".is_checkpoint"), "a").close() + + class Trainable: """Abstract class for trainable models, functions, etc. @@ -119,17 +170,14 @@ class Trainable: >>> extra_cpu=config["workers"], >>> extra_gpu=int(config["use_gpu"]) * config["workers"]) """ - return None @classmethod def resource_help(cls, config): - """ + """Returns a help string for configuring this trainable's resources. + Args: config (dict): The Trainer's config dict. - - Returns: - str: A help string for configuring this trainable's resources. 
""" return "" @@ -258,9 +306,7 @@ class Trainable: """ checkpoint_dir = os.path.join(checkpoint_dir or self.logdir, "checkpoint_{}".format(self._iteration)) - - if not os.path.exists(checkpoint_dir): - os.makedirs(checkpoint_dir) + TrainableUtil.make_checkpoint_dir(checkpoint_dir) checkpoint = self._save(checkpoint_dir) saved_as_dict = False if isinstance(checkpoint, string_types): @@ -270,6 +316,10 @@ class Trainable: "given checkpoint dir {}: {}".format( checkpoint_dir, checkpoint)) checkpoint_path = checkpoint + if os.path.isdir(checkpoint_path): + # Add trailing slash to prevent tune metadata from + # being written outside the directory. + checkpoint_path = os.path.join(checkpoint_path, "") elif isinstance(checkpoint, dict): saved_as_dict = True checkpoint_path = os.path.join(checkpoint_dir, "checkpoint") @@ -302,19 +352,8 @@ class Trainable: tmpdir = tempfile.mkdtemp("save_to_object", dir=self.logdir) checkpoint_path = self.save(tmpdir) # Save all files in subtree. - data = {} - for basedir, _, file_names in os.walk(tmpdir): - for file_name in file_names: - path = os.path.join(basedir, file_name) - - with open(path, "rb") as f: - data[os.path.relpath(path, tmpdir)] = f.read() - + data_dict = TrainableUtil.pickle_checkpoint(checkpoint_path) out = io.BytesIO() - data_dict = pickle.dumps({ - "checkpoint_name": os.path.relpath(checkpoint_path, tmpdir), - "data": data, - }) if len(data_dict) > 10e6: # getting pretty large logger.info("Checkpoint size is {} bytes".format(len(data_dict))) out.write(data_dict) @@ -348,14 +387,15 @@ class Trainable: self._timesteps_since_restore = 0 self._iterations_since_restore = 0 self._restored = True - logger.info("Restored from checkpoint: %s", checkpoint_path) + logger.info("Restored on %s from checkpoint: %s", self.current_ip(), + checkpoint_path) state = { "_iteration": self._iteration, "_timesteps_total": self._timesteps_total, "_time_total": self._time_total, "_episodes_total": self._episodes_total, } - logger.info("Current 
state after restoring: {}".format(state)) + logger.info("Current state after restoring: %s", state) def restore_from_object(self, obj): """Restores training state from a checkpoint object. @@ -379,6 +419,22 @@ class Trainable: self.restore(checkpoint_path) shutil.rmtree(tmpdir) + def delete_checkpoint(self, checkpoint_path): + """Deletes local copy of checkpoint. + + Args: + checkpoint_path (str): Path to checkpoint. + """ + try: + checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path) + except FileNotFoundError: + # The checkpoint won't exist locally if the + # trial was rescheduled to another worker. + logger.debug("Checkpoint not found during garbage collection.") + return + if os.path.exists(checkpoint_dir): + shutil.rmtree(checkpoint_dir) + def export_model(self, export_formats, export_dir=None): """Exports model based on export_formats. @@ -429,7 +485,7 @@ class Trainable: Note that the current working directory will also be changed to this. """ - return self._logdir + return os.path.join(self._logdir, "") @property def iteration(self): diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 055352f0c..8a692f175 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -6,6 +6,7 @@ import ray.cloudpickle as cloudpickle import copy from datetime import datetime import logging +import shutil import uuid import time import tempfile @@ -13,6 +14,7 @@ import os from numbers import Number from ray.tune import TuneError from ray.tune.checkpoint_manager import Checkpoint, CheckpointManager +from ray.tune.durable_trainable import DurableTrainable from ray.tune.logger import pretty_print, UnifiedLogger from ray.tune.util import flatten_dict # NOTE(rkn): We import ray.tune.registry here instead of importing the names we @@ -73,6 +75,30 @@ class ExportFormat: export_formats[i]) +def checkpoint_deleter(trial_id, runner): + """Returns a checkpoint deleter callback for a runner.""" + if not runner: + return lambda checkpoint: None + 
+ def delete(checkpoint): + """Requests checkpoint deletion asynchronously. + + Args: + checkpoint (Checkpoint): Checkpoint to delete. + """ + if checkpoint.storage == Checkpoint.PERSISTENT and checkpoint.value: + logger.debug("Trial %s: Deleting checkpoint %s", trial_id, + checkpoint.value) + checkpoint_path = checkpoint.value + # Delete local copy, if any exists. + if os.path.exists(checkpoint_path): + shutil.rmtree(checkpoint_path) + # TODO(ujvl): Batch remote deletes. + runner.delete_checkpoint.remote(checkpoint.value) + + return delete + + class Trial: """A trial object holds the state for one model training run. @@ -98,6 +124,7 @@ class Trial: experiment_tag="", resources=None, stopping_criterion=None, + remote_checkpoint_dir=None, checkpoint_freq=0, checkpoint_at_end=False, sync_on_checkpoint=True, @@ -137,7 +164,7 @@ class Trial: "clear the `resources_per_trial` option.".format( trainable_cls, default_resources)) resources = default_resources - self.address = Location() + self.location = Location() self.resources = resources or Resources(cpu=1, gpu=0) self.stopping_criterion = stopping_criterion or {} self.loggers = loggers @@ -148,18 +175,10 @@ class Trial: # Local trial state that is updated during the run self.last_result = {} self.last_update_time = -float("inf") - self.checkpoint_freq = checkpoint_freq - self.checkpoint_at_end = checkpoint_at_end # stores in memory max/min/last result for each metric by trial self.metric_analysis = {} - self.sync_on_checkpoint = sync_on_checkpoint - newest_checkpoint = Checkpoint(Checkpoint.DISK, restore_path) - self.checkpoint_manager = CheckpointManager(keep_checkpoints_num, - checkpoint_score_attr) - self.checkpoint_manager.newest_checkpoint = newest_checkpoint - self.export_formats = export_formats self.status = Trial.PENDING self.start_time = None @@ -169,9 +188,27 @@ class Trial: self.last_debug = 0 self.error_file = None self.error_msg = None - self.num_failures = 0 self.custom_trial_name = None + # 
Checkpointing fields + if remote_checkpoint_dir: + self.remote_checkpoint_dir_prefix = remote_checkpoint_dir + else: + self.remote_checkpoint_dir_prefix = None + self.checkpoint_freq = checkpoint_freq + self.checkpoint_at_end = checkpoint_at_end + self.sync_on_checkpoint = sync_on_checkpoint + newest_checkpoint = Checkpoint(Checkpoint.PERSISTENT, restore_path) + self.checkpoint_manager = CheckpointManager( + keep_checkpoints_num, checkpoint_score_attr, + checkpoint_deleter(str(self), self.runner)) + self.checkpoint_manager.newest_checkpoint = newest_checkpoint + + # Restoration fields + self.restoring_from = None + self.num_failures = 0 + self.num_consecutive_start_attempts = 0 + # AutoML fields self.results = None self.best_result = None @@ -179,7 +216,6 @@ class Trial: self.extra_arg = None self._nonjson_fields = [ - "checkpoint", "loggers", "sync_to_driver_fn", "results", @@ -192,7 +228,7 @@ class Trial: @property def node_ip(self): - return self.address.hostname + return self.location.hostname @property def checkpoint(self): @@ -202,6 +238,14 @@ class Trial: def generate_id(cls): return str(uuid.uuid1().hex)[:8] + @property + def remote_checkpoint_dir(self): + assert self.logdir, "Trial {}: logdir not initialized.".format(self) + if not self.remote_checkpoint_dir_prefix: + return None + logdir_name = os.path.basename(self.logdir) + return os.path.join(self.remote_checkpoint_dir_prefix, logdir_name) + @classmethod def create_logdir(cls, identifier, local_dir): local_dir = os.path.expanduser(local_dir) @@ -213,7 +257,6 @@ class Trial: def init_logger(self): """Init logger.""" - if not self.result_logger: if not self.logdir: self.logdir = Trial.create_logdir(str(self), self.local_dir) @@ -239,24 +282,20 @@ class Trial: raise ValueError("Cannot update resources while Trial is running.") self.resources = Resources(cpu, gpu, **kwargs) - def sync_logger_to_new_location(self, worker_ip): - """Updates the logger location. 
- - Also pushes logdir to worker_ip, allowing for cross-node recovery. - """ - if self.result_logger: - self.result_logger.sync_results_to_new_location(worker_ip) - self.set_location(Location(worker_ip)) + def set_runner(self, runner): + self.runner = runner + self.checkpoint_manager.delete = checkpoint_deleter(str(self), runner) def set_location(self, location): """Sets the location of the trial.""" - self.address = location + self.location = location def set_status(self, status): """Sets the status of the trial.""" - if status == Trial.RUNNING and self.start_time is None: - self.start_time = time.time() self.status = status + if status == Trial.RUNNING: + if self.start_time is None: + self.start_time = time.time() def close_logger(self): """Closes logger.""" @@ -266,7 +305,7 @@ class Trial: def write_error_log(self, error_msg): if error_msg and self.logdir: - self.num_failures += 1 # may be moved to outer scope? + self.num_failures += 1 self.error_file = os.path.join(self.logdir, "error.txt") with open(self.error_file, "a+") as f: f.write("Failure # {} (occurred at {})\n".format( @@ -276,7 +315,6 @@ class Trial: def should_stop(self, result): """Whether the given result meets this trial's stopping criteria.""" - if result.get(DONE): return True @@ -309,6 +347,7 @@ class Trial: def clear_checkpoint(self): self.checkpoint.value = None + self.restoring_from = None def on_checkpoint(self, checkpoint): """Hook for handling checkpoints taken by the Trainable. @@ -316,20 +355,52 @@ class Trial: Args: checkpoint (Checkpoint): Checkpoint taken. """ - if self.sync_on_checkpoint and checkpoint.storage == Checkpoint.DISK: - # Wait for any other syncs to finish. We need to sync again after - # this to handle checkpoints taken mid-sync. - self.result_logger.wait() - # Force sync down and wait before tracking the new checkpoint. This - # prevents attempts to restore from partially synced checkpoints. 
- if self.result_logger.sync_down(): + if checkpoint.storage == Checkpoint.MEMORY: + # TODO(ujvl): Handle this separately to avoid restoration failure. + self.checkpoint_manager.on_checkpoint(checkpoint) + return + if self.sync_on_checkpoint: + try: + # Wait for any other syncs to finish. We need to sync again + # after this to handle checkpoints taken mid-sync. self.result_logger.wait() - else: + except TuneError as e: + # Errors occurring during this wait are not fatal for this + # checkpoint, so it should just be logged. logger.error( - "Trial %s: Checkpoint sync skipped. " - "This should not happen.", self) + "Trial %s: An error occurred during the " + "checkpoint pre-sync wait.", str(e)) + # Force sync down and wait before tracking the new checkpoint. + try: + if self.result_logger.sync_down(): + self.result_logger.wait() + else: + logger.error( + "Trial %s: Checkpoint sync skipped. " + "This should not happen.", self) + except TuneError as e: + if issubclass(self.get_trainable_cls(), DurableTrainable): + # Even though rsync failed the trainable can restore + # from remote durable storage. + logger.error("Trial %s: Sync error - %s", self, str(e)) + else: + # If the trainable didn't have remote storage to upload + # to then this checkpoint may have been lost, so we + # shouldn't track it with the checkpoint_manager. + raise e + if not issubclass(self.get_trainable_cls(), DurableTrainable): + if not os.path.exists(checkpoint.value): + raise TuneError("Trial {}: Checkpoint path {} not " + "found after successful sync down.".format( + self, checkpoint.value)) self.checkpoint_manager.on_checkpoint(checkpoint) + def on_restore(self): + """Handles restoration completion.""" + assert self.is_restoring + self.last_result = self.restoring_from.result + self.restoring_from = None + def should_recover(self): """Returns whether the trial qualifies for retrying. 
@@ -375,7 +446,11 @@ class Trial: self.verbose = verbose def is_finished(self): - return self.status in [Trial.TERMINATED, Trial.ERROR] + return self.status in [Trial.ERROR, Trial.TERMINATED] + + @property + def is_restoring(self): + return self.restoring_from is not None def __repr__(self): return str(self) @@ -383,7 +458,7 @@ class Trial: def __str__(self): """Combines ``env`` with ``trainable_name`` and ``trial_id``. - Can be overriden with a custom string creator. + Can be overridden with a custom string creator. """ if self.custom_trial_name: return self.custom_trial_name @@ -402,9 +477,9 @@ class Trial: """Memento generator for Trial. Sets RUNNING trials to PENDING, and flushes the result logger. - Note this can only occur if the trial holds a DISK checkpoint. + Note this can only occur if the trial holds a PERSISTENT checkpoint. """ - assert self.checkpoint.storage == Checkpoint.DISK, ( + assert self.checkpoint.storage == Checkpoint.PERSISTENT, ( "Checkpoint must not be in-memory.") state = self.__dict__.copy() state["resources"] = resources_to_json(self.resources) @@ -415,7 +490,7 @@ class Trial: state["runner"] = None state["result_logger"] = None if self.result_logger: - self.result_logger.flush() + self.result_logger.flush(sync_down=False) state["__logger_started__"] = True else: state["__logger_started__"] = False @@ -424,6 +499,7 @@ class Trial: def __setstate__(self, state): logger_started = state.pop("__logger_started__") state["resources"] = json_to_resources(state["resources"]) + if state["status"] == Trial.RUNNING: state["status"] = Trial.PENDING for key in self._nonjson_fields: diff --git a/python/ray/tune/trial_executor.py b/python/ray/tune/trial_executor.py index 3230d5089..26c023788 100644 --- a/python/ray/tune/trial_executor.py +++ b/python/ray/tune/trial_executor.py @@ -39,8 +39,11 @@ class TrialExecutor: trial (Trial): Trial to checkpoint. status (Trial.status): Status to set trial to. 
""" - logger.debug("Trial %s: Changing status from %s to %s.", trial, - trial.status, status) + if trial.status == status: + logger.debug("Trial %s: Status %s unchanged.", trial, trial.status) + else: + logger.debug("Trial %s: Changing status from %s to %s.", trial, + trial.status, status) trial.set_status(status) if status in [Trial.TERMINATED, Trial.ERROR]: self.try_checkpoint_metadata(trial) @@ -226,14 +229,15 @@ class TrialExecutor: raise NotImplementedError("Subclasses of TrialExecutor must provide " "restore() method") - def save(self, trial, storage=Checkpoint.DISK, result=None): + def save(self, trial, storage=Checkpoint.PERSISTENT, result=None): """Saves training state of this trial to a checkpoint. If result is None, this trial's last result will be used. Args: trial (Trial): The state of this trial to be saved. - storage (str): Where to store the checkpoint. Defaults to DISK. + storage (str): Where to store the checkpoint. Defaults to + PERSISTENT. result (dict): The state of this trial as a dictionary to be saved. Return: diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index f78bdc8ac..e90ca6197 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -7,7 +7,6 @@ from datetime import datetime import json import logging import os -import re import time import traceback import types @@ -31,12 +30,6 @@ MAX_DEBUG_TRIALS = 20 logger = logging.getLogger(__name__) -def _naturalize(string): - """Provides a natural representation for string for nice sorting.""" - splits = re.split("([0-9]+)", string) - return [int(text) if text.isdigit() else text.lower() for text in splits] - - def _find_newest_ckpt(ckpt_dir): """Returns path to most recently modified checkpoint.""" full_paths = [ @@ -222,6 +215,9 @@ class TrialRunner: # Try syncing down the upload directory. 
logger.info("Downloading from %s", self._remote_checkpoint_dir) + # TODO(ujvl): Note that this syncs down the entire directory, + # which may also contain trial checkpoints. We should selectively + # sync the necessary files instead. self._syncer.sync_down_if_needed() if not self.checkpoint_exists(self._local_checkpoint_dir): @@ -417,11 +413,18 @@ class TrialRunner: with warn_if_slow("process_failed_trial"): self._process_trial_failure(failed_trial, error_msg=error_msg) else: + # TODO(ujvl): Consider combining get_next_available_trial and + # fetch_result functionality so that we don't timeout on fetch. trial = self.trial_executor.get_next_available_trial() # blocking - with warn_if_slow("process_trial"): - self._process_trial(trial) + if trial.is_restoring: + with warn_if_slow("process_trial_restore"): + self._process_trial_restore(trial) + else: + with warn_if_slow("process_trial"): + self._process_trial(trial) def _process_trial(self, trial): + """Processes a trial result.""" try: result = self.trial_executor.fetch_result(trial) @@ -479,7 +482,24 @@ class TrialRunner: assert False, "Invalid scheduling decision: {}".format( decision) except Exception: - logger.exception("Error processing event.") + logger.exception("Trial %s: Error processing event.", trial) + self._process_trial_failure(trial, traceback.format_exc()) + + def _process_trial_restore(self, trial): + """Processes a trial restore. + + Args: + trial: Trial being restored. 
+ """ + logger.debug("Trial %s: Processing trial restore.", trial) + try: + self.trial_executor.fetch_result(trial) + trial.on_restore() + logger.debug("Trial %s: Restore processed successfully", trial) + self.trial_executor.set_status(trial, Trial.RUNNING) + self.trial_executor.continue_training(trial) + except Exception: + logger.exception("Trial %s: Error processing restore.", trial) self._process_trial_failure(trial, traceback.format_exc()) def _process_trial_failure(self, trial, error_msg): @@ -504,8 +524,8 @@ class TrialRunner: """Checkpoints trial based off trial.last_result.""" if trial.should_checkpoint() or force: # Save trial runtime if possible - if hasattr(trial, "runner") and trial.runner: - self.trial_executor.save(trial, storage=Checkpoint.DISK) + if trial.runner: + self.trial_executor.save(trial, storage=Checkpoint.PERSISTENT) self.trial_executor.try_checkpoint_metadata(trial) def _try_recover(self, trial, error_msg): @@ -517,30 +537,32 @@ class TrialRunner: trial (Trial): Trial to recover. error_msg (str): Error message from prior to invoking this method. """ - try: - self.trial_executor.stop_trial( - trial, - error=error_msg is not None, - error_msg=error_msg, - stop_logger=False) - trial.result_logger.flush() - if self.trial_executor.has_resources(trial.resources): - logger.info( - "Trial %s: Attempting to recover " - "trial state from last checkpoint.", trial) - self.trial_executor.start_trial(trial) - if trial.status == Trial.ERROR: - logger.error("Trial %s: Did not start correctly.", trial) - raise RuntimeError("Trial did not start correctly.") - logger.debug("Trial %s: Started correctly.", trial) + if trial.is_restoring: + # Restore was unsuccessful, try again without checkpoint. 
+ trial.clear_checkpoint() + self.trial_executor.stop_trial( + trial, + error=error_msg is not None, + error_msg=error_msg, + stop_logger=False) + trial.result_logger.flush() + if self.trial_executor.has_resources(trial.resources): + logger.info( + "Trial %s: Attempting to restore " + "trial state from last checkpoint.", trial) + self.trial_executor.start_trial(trial) + if trial.status == Trial.ERROR: + logger.exception( + "Trial %s: Error restoring trial from checkpoint, abort.", + trial) + self._scheduler_alg.on_trial_error(self, trial) + self._search_alg.on_trial_complete(trial.trial_id, error=True) else: - logger.debug("Trial %s: Notifying Scheduler and requeueing.", - trial) - self._requeue_trial(trial) - except Exception: - logger.exception("Error recovering trial from checkpoint, abort.") - self._scheduler_alg.on_trial_error(self, trial) - self._search_alg.on_trial_complete(trial.trial_id, error=True) + logger.debug("Trial %s: Restore dispatched correctly.", trial) + else: + logger.debug("Trial %s: Notifying Scheduler and requeueing.", + trial) + self._requeue_trial(trial) def _requeue_trial(self, trial): """Notification to TrialScheduler and requeue trial. diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 30c7495d4..844a212f8 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -119,8 +119,8 @@ def run(run_or_experiment, `num_samples` of times. local_dir (str): Local dir to save training results to. Defaults to ``~/ray_results``. - upload_dir (str): Optional URI to sync training results to - (e.g. ``s3://bucket``). + upload_dir (str): Optional URI to sync training results and checkpoints + to (e.g. ``s3://bucket`` or ``gs://bucket``). trial_name_creator (func): Optional function for generating the trial string representation. loggers (list): List of logger creators to be used with @@ -130,20 +130,20 @@ def run(run_or_experiment, from upload_dir. 
If string, then it must be a string template that includes `{source}` and `{target}` for the syncer to run. If not provided, the sync command defaults to standard S3 or gsutil sync - comamnds. - sync_to_driver (func|str): Function for syncing trial logdir from + commands. + sync_to_driver (func|str|bool): Function for syncing trial logdir from remote node to local. If string, then it must be a string template that includes `{source}` and `{target}` for the syncer to run. - If not provided, defaults to using rsync. + If True or not provided, it defaults to using rsync. If False, + syncing to driver is disabled. checkpoint_freq (int): How many training iterations between checkpoints. A value of 0 (default) disables checkpointing. checkpoint_at_end (bool): Whether to checkpoint at the end of the experiment regardless of the checkpoint_freq. Default is False. - sync_on_checkpoint (bool): Force sync-down of trial checkpoint, to - guarantee recoverability. If set to False, checkpoint syncing from - worker to driver is asynchronous. Set this to False only if - synchronous checkpointing is too slow and trial restoration - failures can be tolerated. Defaults to True. + sync_on_checkpoint (bool): Force sync-down of trial checkpoint to + driver. If set to False, checkpoint syncing from worker to driver + is asynchronous and best-effort. This does not affect persistent + storage syncing. Defaults to True. keep_checkpoints_num (int): Number of checkpoints to keep. A value of `None` keeps all checkpoints. Defaults to `None`. If set, need to provide `checkpoint_score_attr`. 
diff --git a/python/ray/tune/utils/__init__.py b/python/ray/tune/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/ray/tune/utils/mock.py b/python/ray/tune/utils/mock.py new file mode 100644 index 000000000..7f3be4c44 --- /dev/null +++ b/python/ray/tune/utils/mock.py @@ -0,0 +1,59 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from ray.rllib.agents.mock import _MockTrainer +from ray.tune import DurableTrainable +from ray.tune.sync_client import get_sync_client +from ray.tune.syncer import NodeSyncer + +MOCK_REMOTE_DIR = "/tmp/mock-tune-remote/" +# Sync and delete templates that operate on local directories. +LOCAL_SYNC_TEMPLATE = "mkdir -p {target} && rsync -avz {source}/ {target}/" +LOCAL_DELETE_TEMPLATE = "rm -rf {target}" + + +def mock_storage_client(): + """Mocks storage client that treats a local dir as durable storage.""" + return get_sync_client(LOCAL_SYNC_TEMPLATE, LOCAL_DELETE_TEMPLATE) + + +class MockNodeSyncer(NodeSyncer): + """Mock NodeSyncer that syncs to and from /tmp""" + + def has_remote_target(self): + return True + + @property + def _remote_path(self): + if self._remote_dir.startswith("/"): + self._remote_dir = self._remote_dir[1:] + return os.path.join(MOCK_REMOTE_DIR, self._remote_dir) + + +class MockRemoteTrainer(_MockTrainer): + """Mock Trainable that saves at tmp for simulated clusters.""" + + def __init__(self, *args, **kwargs): + super(MockRemoteTrainer, self).__init__(*args, **kwargs) + if self._logdir.startswith("/"): + self._logdir = self._logdir[1:] + self._logdir = os.path.join(MOCK_REMOTE_DIR, self._logdir) + if not os.path.exists(self._logdir): + os.makedirs(self._logdir) + + +class MockDurableTrainer(DurableTrainable, _MockTrainer): + """Mock DurableTrainable that saves at tmp for simulated clusters.""" + + # TODO(ujvl): This class uses multiple inheritance; it should be cleaned + # up once the durable training API 
converges. + + def __init__(self, remote_checkpoint_dir, *args, **kwargs): + _MockTrainer.__init__(self, *args, **kwargs) + DurableTrainable.__init__(self, remote_checkpoint_dir, *args, **kwargs) + + def _create_storage_client(self): + return mock_storage_client()