[tune] Async restores and S3/GCP-capable trial FT (#6376)

* Initial commit for asynchronous save/restore

* Set stage for cloud checkpointable trainable.

* Refactor log_sync and sync_client.

* Add durable trainable impl.

* Support delete in cmd based client

* Fix some tests and such

* Cleanup, comments.

* Use upload_dir instead.

* Revert files belonging to other PR in split.

* Pass upload_dir into trainable init.

* Pickle checkpoint at driver, more robust checkpoint_dir discovery.

* Cleanup trainable helper functions, fix tests.

* Addressed comments.

* Fix bugs from cluster testing, add parameterized cluster tests.

* Add trainable util test

* package_ref

* pbt_address

* Fix bug after running pbt example (_save returning dir).

* get cluster tests running, other bug fixes.

* raise_errors

* Fix deleter bug, add durable trainable example.

* Fix cluster test bugs.

* filelock

* save/restore bug fixes

* .

* Working cluster tests.

* Lint, revert to tracking memory checkpoints.

* Documentation, cleanup

* fixinitialsync

* fix_one_test

* Fix cluster test bug

* nit

* lint

* Revert tune md change

* Fix basename bug for directories.

* lint

* fix_tests

* nit_fix

* Add __init__ file.

* Move to utils package

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
Ujval Misra
2020-01-02 23:40:53 -05:00
committed by Richard Liaw
parent 57061a15cf
commit ca651af1d7
30 changed files with 1006 additions and 349 deletions
+8 -1
View File
@@ -148,7 +148,14 @@ py_test(
deps = [":tune_lib"],
tags = ["exclusive"],
)
py_test(
name = "test_trainable_util",
size = "small",
srcs = ["tests/test_trainable_util.py"],
deps = [":tune_lib"],
)
py_test(
name = "test_trial_scheduler",
size = "medium",
+2
View File
@@ -8,12 +8,14 @@ from ray.tune.experiment import Experiment
from ray.tune.analysis import ExperimentAnalysis, Analysis
from ray.tune.registry import register_env, register_trainable
from ray.tune.trainable import Trainable
from ray.tune.durable_trainable import DurableTrainable
from ray.tune.suggest import grid_search
from ray.tune.sample import (function, sample_from, uniform, choice, randint,
randn, loguniform)
__all__ = [
"Trainable",
"DurableTrainable",
"TuneError",
"grid_search",
"register_env",
+21 -27
View File
@@ -5,8 +5,6 @@ from __future__ import print_function
import heapq
import logging
import os
import shutil
try:
FileNotFoundError
@@ -23,31 +21,18 @@ class Checkpoint:
Attributes:
storage (str): Storage type.
value (str): If storage==MEMORY, value is a Python object.
If storage==DISK, value is a path points to the checkpoint in disk.
value (str): If storage==MEMORY, it is a Python object.
If storage==PERSISTENT, it is a path to persistent storage.
"""
MEMORY = "memory"
DISK = "disk"
PERSISTENT = "persistent"
def __init__(self, storage, value, result=None):
self.storage = storage
self.value = value
self.result = result or {}
def delete(self):
"""Deletes checkpoint data if disk checkpoint."""
if self.storage == Checkpoint.DISK and self.value:
checkpoint_dir = self.value
if not os.path.exists(checkpoint_dir):
raise FileNotFoundError(
"Attempted to delete checkpoint at {} but "
"path was not found.".format(checkpoint_dir))
elif os.path.isfile(checkpoint_dir):
shutil.rmtree(os.path.dirname(checkpoint_dir))
else:
shutil.rmtree(checkpoint_dir)
@staticmethod
def from_object(value=None):
"""Creates a checkpoint from a Python object."""
@@ -72,13 +57,15 @@ class QueueItem:
class CheckpointManager:
"""Manages checkpoints on the driver for a trial."""
def __init__(self, keep_checkpoints_num, checkpoint_score_attr):
def __init__(self, keep_checkpoints_num, checkpoint_score_attr, delete_fn):
"""Initializes a new CheckpointManager.
Args:
keep_checkpoints_num (int): Keep at least this many checkpoints.
checkpoint_score_attr (str): Attribute to use to determine which
checkpoints to keep.
delete_fn (function): Function that deletes checkpoints. Must be
idempotent.
"""
self.keep_checkpoints_num = keep_checkpoints_num or float("inf")
assert self.keep_checkpoints_num > 0, (
@@ -88,7 +75,7 @@ class CheckpointManager:
self._checkpoint_score_attr = checkpoint_score_attr[4:]
else:
self._checkpoint_score_attr = checkpoint_score_attr
self.delete = delete_fn
self.newest_checkpoint = Checkpoint(Checkpoint.MEMORY, None)
self._best_checkpoints = []
self._membership = set()
@@ -101,9 +88,6 @@ class CheckpointManager:
Args:
checkpoint (Checkpoint): Trial state checkpoint.
Raises:
KeyError if checkpoint_score_attr not in result of checkpoint.
"""
old_checkpoint = self.newest_checkpoint
self.newest_checkpoint = checkpoint
@@ -112,7 +96,7 @@ class CheckpointManager:
queue_item = QueueItem(self._priority(checkpoint), checkpoint)
except KeyError:
if old_checkpoint not in self._membership:
old_checkpoint.delete()
self.delete(old_checkpoint)
logger.error("Result dict has no key: {}. "
"checkpoint_score_attr must be set to a key in the "
"result dict.".format(self._checkpoint_score_attr))
@@ -126,11 +110,11 @@ class CheckpointManager:
self._membership.add(checkpoint)
if worst in self._membership:
self._membership.remove(worst)
worst.delete()
self.delete(worst)
# Remove the old checkpoint if it isn't one of the best ones.
if old_checkpoint not in self._membership:
old_checkpoint.delete()
if old_checkpoint.value and old_checkpoint not in self._membership:
self.delete(old_checkpoint)
def best_checkpoints(self):
"""Returns best checkpoints, sorted by score."""
@@ -140,3 +124,13 @@ class CheckpointManager:
def _priority(self, checkpoint):
priority = checkpoint.result[self._checkpoint_score_attr]
return -priority if self._checkpoint_score_desc else priority
def __getstate__(self):
state = self.__dict__.copy()
# Avoid serializing lambda since it may capture cyclical dependencies.
state.pop("delete")
return state
def __setstate__(self, state):
self.__dict__.update(state)
self.delete = None
+1
View File
@@ -183,6 +183,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
local_dir=os.path.join(spec["local_dir"], output_path),
# json.load leads to str -> unicode in py2.7
stopping_criterion=spec.get("stop", {}),
remote_checkpoint_dir=spec.get("remote_checkpoint_dir"),
checkpoint_freq=args.checkpoint_freq,
checkpoint_at_end=args.checkpoint_at_end,
sync_on_checkpoint=not args.no_sync_on_checkpoint,
+98
View File
@@ -0,0 +1,98 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from ray.tune.trainable import Trainable, TrainableUtil
from ray.tune.syncer import get_cloud_sync_client
class DurableTrainable(Trainable):
"""Abstract class for a remote-storage backed fault-tolerant Trainable.
Supports checkpointing to and restoring from remote storage. To use this
class, implement the same private methods as ray.tune.Trainable (`_save`,
`_train`, `_restore`, `reset_config`, `_setup`, `_stop`).
.. warning:: This class is currently **experimental** and may
be subject to change.
Run this with Tune as follows. Setting `sync_to_driver=False` disables
syncing to the driver to avoid keeping redundant checkpoints around, as
well as preventing the driver from syncing up the same checkpoint.
See ``tune/trainable.py``.
Attributes:
remote_checkpoint_dir (str): Upload directory (S3 or GS path).
storage_client: Tune-internal interface for interacting with external
storage.
>>> tune.run(MyDurableTrainable, sync_to_driver=False)
"""
def __init__(self, remote_checkpoint_dir, *args, **kwargs):
"""Initializes a DurableTrainable.
Args:
remote_checkpoint_dir (str): Upload directory (S3 or GS path).
"""
super(DurableTrainable, self).__init__(*args, **kwargs)
self.remote_checkpoint_dir = remote_checkpoint_dir
self.storage_client = self._create_storage_client()
def save(self, checkpoint_dir=None):
"""Saves the current model state to a checkpoint, persisted remotely.
The storage client must provide durability for
restoration to work. That is, once ``storage.client.wait()``
returns after a checkpoint `sync up`, the checkpoint is considered
committed and can be used to restore the trainable.
Args:
checkpoint_dir (Optional[str]): Optional dir to place the
checkpoint. Must be ``logdir`` or a sub-directory.
Returns:
Checkpoint path or prefix that may be passed to restore().
"""
if checkpoint_dir:
if checkpoint_dir.starts_with(os.path.abspath(self.logdir)):
raise ValueError("`checkpoint_dir` must be `self.logdir`, or "
"a sub-directory.")
checkpoint_path = super(DurableTrainable, self).save(checkpoint_dir)
self.storage_client.sync_up(self.logdir, self.remote_checkpoint_dir)
self.storage_client.wait()
return checkpoint_path
def restore(self, checkpoint_path):
"""Restores training state from a given checkpoint persisted remotely.
These checkpoints are returned from calls to save().
Args:
checkpoint_path (str): Local path to checkpoint.
"""
self.storage_client.sync_down(self.remote_checkpoint_dir, self.logdir)
self.storage_client.wait()
super(DurableTrainable, self).restore(checkpoint_path)
def delete_checkpoint(self, checkpoint_path):
"""Deletes checkpoint from both local and remote storage.
Args:
checkpoint_path (str): Local path to checkpoint.
"""
super(DurableTrainable, self).delete_checkpoint(checkpoint_path)
local_dirpath = TrainableUtil.find_checkpoint_dir(checkpoint_path)
self.storage_client.delete(self._storage_path(local_dirpath))
def _create_storage_client(self):
"""Returns a storage client."""
return get_cloud_sync_client(self.remote_checkpoint_dir)
def _storage_path(self, local_path):
rel_local_path = os.path.relpath(local_path, self.logdir)
return os.path.join(self.remote_checkpoint_dir, rel_local_path)
@@ -0,0 +1,126 @@
import argparse
import numpy as np
import time
import logging
import os
import ray
from ray import tune
from ray.tune import DurableTrainable
from ray.tune.sync_client import get_sync_client
import cloudpickle
logger = logging.getLogger(__name__)
class MockDurableTrainable(DurableTrainable):
"""Mocks the storage client on initialization to store data locally."""
def __init__(self, remote_checkpoint_dir, *args, **kwargs):
# Mock the path as a local path.
local_dir_suffix = remote_checkpoint_dir.split("://")[1]
remote_checkpoint_dir = os.path.join("/tmp", local_dir_suffix)
# Disallow malformed relative paths for delete safety.
assert os.path.abspath(remote_checkpoint_dir).startswith("/tmp")
logger.info("Using %s as the mocked remote checkpoint directory.",
self.remote_checkpoint_dir)
super(MockDurableTrainable, self).__init__(remote_checkpoint_dir,
*args, **kwargs)
def _create_storage_client(self):
sync = "mkdir -p {target} && rsync -avz {source} {target}"
delete = "rm -rf {target}"
return get_sync_client(sync, delete)
class OptimusFn(object):
def __init__(self, params, max_t=10000):
self.params = params
self.noise = np.random.normal(size=max_t) * 0.005
def eval(self, k, add_noise=True):
b0, b1, b2 = self.params
score = (b0 * k / 100 + 0.1 * b1 + 0.5)**(-1) + b2 * 0.01
if add_noise:
return score + abs(self.noise[k])
else:
return score
def get_optimus_trainable(parent_cls):
class OptimusTrainable(parent_cls):
def _setup(self, config):
self.iter = 0
if config.get("seed"):
np.random.seed(config["seed"])
time.sleep(config.get("startup_delay", 0))
params = [config["param1"], config["param2"], config["param3"]]
self.func = OptimusFn(params=params)
self.initial_samples_per_step = 500
self.mock_data = open("/dev/urandom", "rb").read(1024)
def _train(self):
self.iter += 1
new_loss = self.func.eval(self.iter)
time.sleep(0.5)
return {
"mean_loss": float(new_loss),
"mean_accuracy": (2 - new_loss) / 2,
"samples": self.initial_samples_per_step
}
def _save(self, checkpoint_dir):
time.sleep(0.5)
return {
"func": cloudpickle.dumps(self.func),
"seed": np.random.get_state(),
"data": self.mock_data,
"iter": self.iter
}
def _restore(self, checkpoint):
self.func = cloudpickle.loads(checkpoint["func"])
self.data = checkpoint["data"]
self.iter = checkpoint["iter"]
np.random.set_state(checkpoint["seed"])
return OptimusTrainable
def parse():
parser = argparse.ArgumentParser()
parser.add_argument("--local", action="store_true", default=False)
parser.add_argument("--mock-storage", action="store_true", default=False)
parser.add_argument("--remote-dir", type=str)
return parser.parse_args()
if __name__ == "__main__":
args = parse()
address = None if args.local else "auto"
ray.init(address=address)
config = {
"seed": None,
"startup_delay": 0.001,
"param1": tune.sample_from(lambda spec: np.random.exponential(0.1)),
"param2": tune.sample_from(lambda _: np.random.rand()),
"param3": tune.sample_from(lambda _: np.random.rand()),
}
parent = MockDurableTrainable if args.mock_storage else DurableTrainable
analysis = tune.run(
get_optimus_trainable(parent),
name="durableTrainable" + str(time.time()),
config=config,
num_samples=4,
verbose=1,
queue_trials=True,
# fault tolerance parameters
max_failures=-1,
checkpoint_freq=20,
sync_to_driver=False,
sync_on_checkpoint=False,
upload_dir="s3://ray-tune-test/exps/",
checkpoint_score_attr="training_iteration",
)
@@ -11,6 +11,7 @@ from ray.tune.trial import ExportFormat
import argparse
import os
from filelock import FileLock
import random
import torch
import torch.nn as nn
@@ -248,7 +249,8 @@ class PytorchTrainable(tune.Trainable):
self.netG.parameters(),
lr=config.get("lr", 0.01),
betas=(beta1, 0.999))
self.dataloader = get_data_loader()
with FileLock(os.path.expanduser("~/.data.lock")):
self.dataloader = get_data_loader()
def _train(self):
lossG, lossD, is_score = train(
+26 -8
View File
@@ -10,7 +10,7 @@ import six
import types
from ray.tune.error import TuneError
from ray.tune.registry import register_trainable
from ray.tune.registry import register_trainable, get_trainable_cls
from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.sample import sample_from
@@ -34,6 +34,22 @@ def _raise_deprecation_note(deprecated, replacement, soft=False):
raise DeprecationWarning(error_msg)
def _raise_on_durable(trainable_name, sync_to_driver, upload_dir):
trainable_cls = get_trainable_cls(trainable_name)
from ray.tune.durable_trainable import DurableTrainable
if issubclass(trainable_cls, DurableTrainable):
if sync_to_driver is not False:
raise ValueError(
"EXPERIMENTAL: DurableTrainable will automatically sync "
"results to the provided upload_dir. "
"Set `sync_to_driver=False` to avoid data inconsistencies.")
if not upload_dir:
raise ValueError(
"EXPERIMENTAL: DurableTrainable will automatically sync "
"results to the provided upload_dir. "
"`upload_dir` must be provided.")
class Experiment:
"""Tracks experiment specifications.
@@ -109,6 +125,14 @@ class Experiment:
config = config or {}
self._run_identifier = Experiment.register_if_needed(run)
self.name = name or self._run_identifier
if upload_dir:
self.remote_checkpoint_dir = os.path.join(upload_dir, self.name)
else:
self.remote_checkpoint_dir = None
_raise_on_durable(self._run_identifier, sync_to_driver, upload_dir)
spec = {
"run": self._run_identifier,
"stop": stop,
@@ -118,6 +142,7 @@ class Experiment:
"local_dir": os.path.abspath(
os.path.expanduser(local_dir or DEFAULT_RESULTS_DIR)),
"upload_dir": upload_dir,
"remote_checkpoint_dir": self.remote_checkpoint_dir,
"trial_name_creator": trial_name_creator,
"loggers": loggers,
"sync_to_driver": sync_to_driver,
@@ -131,8 +156,6 @@ class Experiment:
"restore": os.path.abspath(os.path.expanduser(restore))
if restore else None
}
self.name = name or self._run_identifier
self.spec = spec
@classmethod
@@ -204,11 +227,6 @@ class Experiment:
if self.local_dir:
return os.path.join(self.local_dir, self.name)
@property
def remote_checkpoint_dir(self):
if self.spec["upload_dir"]:
return os.path.join(self.spec["upload_dir"], self.name)
@property
def run_identifier(self):
"""Returns a string representing the trainable identifier."""
+5 -3
View File
@@ -430,11 +430,13 @@ class UnifiedLogger(Logger):
for _logger in self._loggers:
_logger.close()
def flush(self):
def flush(self, sync_down=True):
for _logger in self._loggers:
_logger.flush()
if not self._log_syncer.sync_down():
logger.warning("Trial %s: Post-flush sync skipped.", self.trial)
if sync_down:
if not self._log_syncer.sync_down():
logger.warning("Trial %s: Post-flush sync skipped.",
self.trial)
def sync_up(self):
return self._log_syncer.sync_up()
+1 -1
View File
@@ -238,7 +238,7 @@ def _get_trial_info(trial, parameters, metrics):
metrics (List[str]): Names of metrics to include.
"""
result = flatten_dict(trial.last_result)
trial_info = [str(trial), trial.status, str(trial.address)]
trial_info = [str(trial), trial.status, str(trial.location)]
trial_info += [result.get(CONFIG_PREFIX + param) for param in parameters]
trial_info += [result.get(metric) for metric in metrics]
return trial_info
+76 -80
View File
@@ -13,10 +13,12 @@ import ray
from ray.exceptions import RayTimeoutError
from ray import ray_constants
from ray.resource_spec import ResourceSpec
from ray.tune.durable_trainable import DurableTrainable
from ray.tune.error import AbortTrialExecution
from ray.tune.logger import NoopLogger
from ray.tune.trial import Trial, Checkpoint, Location
from ray.tune.resources import Resources
from ray.tune.trainable import TrainableUtil
from ray.tune.trial import Trial, Checkpoint, Location
from ray.tune.trial_executor import TrialExecutor
from ray.tune.util import warn_if_slow
from ray.tune.error import TuneError
@@ -27,7 +29,6 @@ RESOURCE_REFRESH_PERIOD = 0.5 # Refresh resources every 500 ms
BOTTLENECK_WARN_PERIOD_S = 60
NONTRIVIAL_WAIT_TIME_THRESHOLD_S = 1e-3
DEFAULT_GET_TIMEOUT = 30.0 # seconds
TRIAL_START_ATTEMPTS = 3
class _LocalWrapper:
@@ -86,7 +87,7 @@ class RayTrialExecutor(TrialExecutor):
self._cached_actor)
existing_runner = self._cached_actor
self._cached_actor = None
trial.runner = existing_runner
trial.set_runner(existing_runner)
if not self.reset_trial(trial, trial.config, trial.experiment_tag):
raise AbortTrialExecution(
"Trainable runner reuse requires reset_config() to be "
@@ -122,11 +123,16 @@ class RayTrialExecutor(TrialExecutor):
logger.debug("Trial %s: Setting up new remote runner.", trial)
# Logging for trials is handled centrally by TrialRunner, so
# configure the remote runner to use a noop-logger.
return cls.remote(config=trial.config, logger_creator=logger_creator)
kwargs = {
"config": trial.config,
"logger_creator": logger_creator,
}
if issubclass(trial.get_trainable_cls(), DurableTrainable):
kwargs["remote_checkpoint_dir"] = trial.remote_checkpoint_dir
return cls.remote(**kwargs)
def _train(self, trial):
"""Start one iteration of training and save remote id."""
if self._find_item(self._paused, trial):
raise TuneError(
"Should not call `train` on PAUSED trial {}. "
@@ -168,9 +174,11 @@ class RayTrialExecutor(TrialExecutor):
"""
prior_status = trial.status
self.set_status(trial, Trial.RUNNING)
trial.runner = runner or self._setup_remote_runner(
trial,
reuse_allowed=checkpoint is not None or trial.has_checkpoint())
trial.set_runner(
runner or self._setup_remote_runner(
trial,
reuse_allowed=checkpoint is not None
or trial.has_checkpoint()))
self.restore(trial, checkpoint)
previous_run = self._find_item(self._paused, trial)
@@ -178,7 +186,7 @@ class RayTrialExecutor(TrialExecutor):
# If Trial was in flight when paused, self._paused stores result.
self._paused.pop(previous_run[0])
self._running[previous_run[0]] = trial
else:
elif not trial.is_restoring:
self._train(trial)
def _stop_trial(self, trial, error=False, error_msg=None,
@@ -194,7 +202,6 @@ class RayTrialExecutor(TrialExecutor):
error_msg (str): Optional error message.
stop_logger (bool): Whether to shut down the trial logger.
"""
if stop_logger:
trial.close_logger()
@@ -206,7 +213,7 @@ class RayTrialExecutor(TrialExecutor):
if hasattr(trial, "runner") and trial.runner:
if (not error and self._reuse_actors
and self._cached_actor is None):
logger.debug("Reusing actor for {}".format(trial.runner))
logger.debug("Reusing actor for %s", trial.runner)
self._cached_actor = trial.runner
else:
logger.debug("Trial %s: Destroying actor.", trial)
@@ -216,7 +223,7 @@ class RayTrialExecutor(TrialExecutor):
logger.exception("Trial %s: Error stopping runner.", trial)
self.set_status(trial, Trial.ERROR)
finally:
trial.runner = None
trial.set_runner(None)
def start_trial(self, trial, checkpoint=None):
"""Starts the trial.
@@ -229,49 +236,22 @@ class RayTrialExecutor(TrialExecutor):
of trial.
"""
self._commit_resources(trial.resources)
remote_runner = None
attempts = 0
while attempts < TRIAL_START_ATTEMPTS:
attempts += 1
if attempts > 1:
logger.warning("Trial %s: Start attempt #%s...", trial,
attempts)
try:
self._start_trial(trial, checkpoint, remote_runner)
break
except AbortTrialExecution:
logger.exception("Trial %s: Error starting runner, aborting!",
trial)
time.sleep(2)
error_msg = traceback.format_exc()
self._stop_trial(trial, error=True, error_msg=error_msg)
break # don't retry fatal Tune errors
except RayTimeoutError:
# Reuse the existing runner on retries.
remote_runner = trial.runner
warning = ("Runner task timed out. This could be due to "
"slow worker startup.")
if attempts == TRIAL_START_ATTEMPTS:
error_msg = traceback.format_exc()
self._stop_trial(trial, error=True, error_msg=error_msg)
else:
warning += " Reusing the same runner."
logger.warning("Trial %s: %s", trial, warning)
except Exception:
logger.exception("Trial %s: Error starting runner.", trial)
time.sleep(2)
error_msg = traceback.format_exc()
self._stop_trial(trial, error=True, error_msg=error_msg)
remote_runner = None
# This forces the trial to not start from checkpoint.
checkpoint = None
trial.clear_checkpoint()
# Note that we don't return the resources, since they may
# have been lost. TODO(ujvl): is this the right thing to do?
else:
logger.exception(
"Trial %s: Aborting trial after %s start "
"attempts!", trial, TRIAL_START_ATTEMPTS)
try:
self._start_trial(trial, checkpoint)
except AbortTrialExecution:
logger.exception("Trial %s: Error starting runner, aborting!",
trial)
time.sleep(2)
error_msg = traceback.format_exc()
self._stop_trial(trial, error=True, error_msg=error_msg)
except Exception:
logger.exception("Trial %s: Unexpected error starting runner.",
trial)
time.sleep(2)
error_msg = traceback.format_exc()
self._stop_trial(trial, error=True, error_msg=error_msg)
# Note that we don't return the resources, since they may
# have been lost. TODO(ujvl): is this the right thing to do?
def _find_item(self, dictionary, item):
out = [rid for rid, t in dictionary.items() if t is item]
@@ -332,7 +312,6 @@ class RayTrialExecutor(TrialExecutor):
def get_running_trials(self):
"""Returns the running trials."""
return list(self._running.values())
def get_alive_node_ips(self):
@@ -387,7 +366,8 @@ class RayTrialExecutor(TrialExecutor):
"""Fetches one result of the running trials.
Returns:
Result of the most recent trial training run."""
Result of the most recent trial training run.
"""
trial_future = self._find_item(self._running, trial)
if not trial_future:
raise ValueError("Trial was not running.")
@@ -437,6 +417,7 @@ class RayTrialExecutor(TrialExecutor):
"Resource invalid: {}".format(resources))
def _update_avail_resources(self, num_retries=5):
resources = None
for i in range(num_retries):
try:
resources = ray.cluster_resources()
@@ -520,7 +501,6 @@ class RayTrialExecutor(TrialExecutor):
def debug_string(self):
"""Returns a human readable message for printing to the console."""
if self._resources_initialized:
status = ("Resources requested: {}/{} CPUs, {}/{} GPUs, "
"{}/{} GiB heap, {}/{} GiB objects".format(
@@ -548,7 +528,6 @@ class RayTrialExecutor(TrialExecutor):
def resource_string(self):
"""Returns a string describing the total resources available."""
if self._resources_initialized:
res_str = ("{} CPUs, {} GPUs, "
"{} GiB heap, {} GiB objects".format(
@@ -570,18 +549,28 @@ class RayTrialExecutor(TrialExecutor):
"""Before step() called, update the available resources."""
self._update_avail_resources()
def save(self, trial, storage=Checkpoint.DISK, result=None):
"""Saves the trial's state to a checkpoint."""
result = result or trial.last_result
def save(self, trial, storage=Checkpoint.PERSISTENT, result=None):
"""Saves the trial's state to a checkpoint.
Args:
trial (Trial): The state of this trial to be saved.
storage (str): Where to store the checkpoint. Defaults to
PERSISTENT.
result (dict): The state of this trial as a dictionary to be saved.
If result is None, the trial's last result will be used.
Returns:
Checkpoint future, or None if an Exception occurs.
"""
result = result or trial.last_result
if storage == Checkpoint.MEMORY:
value = trial.runner.save_to_object.remote()
checkpoint = Checkpoint(storage, value, result)
else:
with warn_if_slow("save_checkpoint_to_disk"):
with warn_if_slow("save_checkpoint_to_storage"):
# TODO(ujvl): Make this asynchronous.
value = ray.get(trial.runner.save.remote())
checkpoint = Checkpoint(storage, value, result)
with warn_if_slow("on_checkpoint", DEFAULT_GET_TIMEOUT) as profile:
try:
trial.on_checkpoint(checkpoint)
@@ -600,13 +589,10 @@ class RayTrialExecutor(TrialExecutor):
def restore(self, trial, checkpoint=None):
"""Restores training state from a given model checkpoint.
This will also sync the trial results to a new location
if restoring on a different node.
Raises:
RuntimeError: This error is raised if no runner is found.
RayTimeoutError: This error is raised if a remote call to the
runner times out.
AbortTrialExecution: This error is raised if the trial is
ineligible for restoration, given the Tune input arguments.
"""
if checkpoint is None or checkpoint.value is None:
checkpoint = trial.checkpoint
@@ -617,19 +603,29 @@ class RayTrialExecutor(TrialExecutor):
"Trial {}: Unable to restore - no runner found.".format(trial))
value = checkpoint.value
if checkpoint.storage == Checkpoint.MEMORY:
assert not isinstance(value, Checkpoint), type(value)
logger.debug("Trial %s: Attempting restore from object", trial)
# Note that we don't store the remote since in-memory checkpoints
# don't guarantee fault tolerance and don't need to be waited on.
trial.runner.restore_from_object.remote(value)
else:
logger.info("Trial %s: Attempting restore from %s", trial, value)
with warn_if_slow("get_current_ip"):
worker_ip = ray.get(trial.runner.current_ip.remote(),
DEFAULT_GET_TIMEOUT)
with warn_if_slow("sync_to_new_location"):
trial.sync_logger_to_new_location(worker_ip)
with warn_if_slow("restore_from_disk"):
# TODO(ujvl): Take blocking restores out of the control loop.
ray.get(trial.runner.restore.remote(value))
trial.last_result = checkpoint.result
logger.debug("Trial %s: Attempting restore from %s", trial, value)
if issubclass(trial.get_trainable_cls(), DurableTrainable):
remote = trial.runner.restore.remote(value)
elif trial.sync_on_checkpoint:
# This provides FT backwards compatibility in the
# case where a DurableTrainable is not provided.
logger.warning("Trial %s: Reading checkpoint into memory.",
trial)
data_dict = TrainableUtil.pickle_checkpoint(value)
remote = trial.runner.restore_from_object.remote(data_dict)
else:
raise AbortTrialExecution(
"Pass in `sync_on_checkpoint=True` for driver-based trial"
"restoration. Pass in an `upload_dir` and a Trainable "
"extending `DurableTrainable` for remote storage-based "
"restoration")
self._running[remote] = trial
trial.restoring_from = checkpoint
def export_trial_if_needed(self, trial):
"""Exports model of this trial based on trial.export_formats.
+53 -10
View File
@@ -28,25 +28,31 @@ def noop(*args):
return
def get_sync_client(sync_function):
def get_sync_client(sync_function, delete_function=None):
"""Returns a sync client.
Args:
sync_function (Optional[str|function]): Sync function.
delete_function (Optional[str|function]): Delete function. Must be
the same type as sync_function if it is provided.
Raises:
ValueError if sync_function is malformed.
ValueError if sync_function or delete_function are malformed.
"""
if sync_function is None:
return None
if delete_function and type(sync_function) != type(delete_function):
raise ValueError("Sync and delete functions must be of same type.")
if isinstance(sync_function, types.FunctionType):
delete_function = delete_function or noop
client_cls = FunctionBasedClient
elif isinstance(sync_function, str):
delete_function = delete_function or noop_template
client_cls = CommandBasedClient
else:
raise ValueError("Sync function {} must be string or function".format(
sync_function))
return client_cls(sync_function, sync_function)
return client_cls(sync_function, sync_function, delete_function)
def get_cloud_sync_client(remote_path):
@@ -63,17 +69,19 @@ def get_cloud_sync_client(remote_path):
raise ValueError(
"Upload uri starting with '{}' requires awscli tool"
" to be installed".format(S3_PREFIX))
template = "aws s3 sync {source} {target}"
template = "aws s3 sync {source} {target} --only-show-errors"
delete_template = "aws s3 rm {target} --recursive --only-show-errors"
elif remote_path.startswith(GS_PREFIX):
if not distutils.spawn.find_executable("gsutil"):
raise ValueError(
"Upload uri starting with '{}' requires gsutil tool"
" to be installed".format(GS_PREFIX))
template = "gsutil rsync -r {source} {target}"
delete_template = "gsutil rm -r {target}"
else:
raise ValueError("Upload uri must start with one of: {}"
"".format(ALLOWED_REMOTE_PREFIXES))
return CommandBasedClient(template, template)
return CommandBasedClient(template, template, delete_template)
class SyncClient:
@@ -103,6 +111,17 @@ class SyncClient:
"""
raise NotImplementedError
def delete(self, target):
"""Deletes target.
Args:
target (str): Target path.
Returns:
True if delete initiation successful, False otherwise.
"""
raise NotImplementedError
def wait(self):
"""Waits for current sync to complete, if asynchronously started."""
pass
@@ -113,9 +132,10 @@ class SyncClient:
class FunctionBasedClient(SyncClient):
def __init__(self, sync_up_func, sync_down_func):
def __init__(self, sync_up_func, sync_down_func, delete_func=None):
self.sync_up_func = sync_up_func
self.sync_down_func = sync_down_func
self.delete_func = delete_func or noop
def sync_up(self, source, target):
self.sync_up_func(source, target)
@@ -125,12 +145,19 @@ class FunctionBasedClient(SyncClient):
self.sync_down_func(source, target)
return True
def delete(self, target):
self.delete_func(target)
return True
NOOP = FunctionBasedClient(noop, noop)
class CommandBasedClient(SyncClient):
def __init__(self, sync_up_template, sync_down_template):
def __init__(self,
sync_up_template,
sync_down_template,
delete_template=noop_template):
"""Syncs between two directories with the given command.
Arguments:
@@ -138,11 +165,14 @@ class CommandBasedClient(SyncClient):
include replacement fields '{source}' and '{target}'.
sync_down_template (str): A runnable string template; needs to
include replacement fields '{source}' and '{target}'.
delete_template (Optional[str]): A runnable string template; needs
to include replacement field '{target}'. Noop by default.
"""
self._validate_sync_string(sync_up_template)
self._validate_sync_string(sync_down_template)
self.sync_up_template = sync_up_template
self.sync_down_template = sync_down_template
self.delete_template = delete_template
self.logfile = None
self.cmd_process = None
@@ -153,7 +183,7 @@ class CommandBasedClient(SyncClient):
logdir (str): Log directory.
"""
self.logfile = tempfile.NamedTemporaryFile(
prefix="log_sync", dir=logdir, suffix=".log", delete=False)
prefix="log_sync_out", dir=logdir, suffix=".log", delete=False)
def sync_up(self, source, target):
return self._execute(self.sync_up_template, source, target)
@@ -161,14 +191,27 @@ class CommandBasedClient(SyncClient):
def sync_down(self, source, target):
return self._execute(self.sync_down_template, source, target)
def delete(self, target):
if self.is_running:
logger.warning("Last sync client cmd still in progress, skipping.")
return False
final_cmd = self.delete_template.format(target=quote(target))
logger.debug("Running delete: {}".format(final_cmd))
self.cmd_process = subprocess.Popen(
final_cmd, shell=True, stderr=subprocess.PIPE, stdout=self.logfile)
return True
def wait(self):
if self.cmd_process:
_, error_msg = self.cmd_process.communicate()
error_msg = error_msg.decode("ascii")
code = self.cmd_process.returncode
args = self.cmd_process.args
self.cmd_process = None
if code != 0:
raise TuneError("Sync error ({}): {}".format(code, error_msg))
raise TuneError("Sync error. Ran command: {}\n"
"Error message ({}): {}".format(
args, code, error_msg))
def reset(self):
if self.is_running:
@@ -177,7 +220,7 @@ class CommandBasedClient(SyncClient):
@property
def is_running(self):
"""Returns whether a sync process is running."""
"""Returns whether a sync or delete process is running."""
if self.cmd_process:
self.cmd_process.poll()
return self.cmd_process.returncode is None
+11 -9
View File
@@ -191,6 +191,8 @@ class NodeSyncer(Syncer):
def sync_down(self):
if not self.has_remote_target():
return True
logger.debug("Syncing from %s to %s", self._remote_path,
self._local_dir)
return super(NodeSyncer, self).sync_down()
@property
@@ -250,23 +252,23 @@ def get_node_syncer(local_dir, remote_dir=None, sync_function=None):
Args:
local_dir (str): Source directory for syncing.
remote_dir (str): Target directory for syncing. If not provided, a
no-op Syncer is returned.
sync_function (func|str): Function for syncing the local_dir to
noop Syncer is returned.
sync_function (func|str|bool): Function for syncing the local_dir to
remote_dir. If string, then it must be a string template for
syncer to run. If not provided, it defaults rsync.
syncer to run. If True or not provided, it defaults rsync. If
False, a noop Syncer is returned.
"""
key = (local_dir, remote_dir)
if key in _syncers:
return _syncers[key]
elif not remote_dir:
elif not remote_dir or sync_function is False:
sync_client = NOOP
elif sync_function:
elif sync_function and sync_function is not True:
sync_client = get_sync_client(sync_function)
else:
sync_up = log_sync_template()
sync_down = log_sync_template(options="--remove-source-files")
if sync_up and sync_down:
sync_client = CommandBasedClient(sync_up, sync_down)
sync = log_sync_template()
if sync:
sync_client = CommandBasedClient(sync, sync)
sync_client.set_logdir(local_dir)
else:
sync_client = NOOP
+33 -1
View File
@@ -2,16 +2,19 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import shutil
import copy
import os
import time
import unittest
from unittest.mock import patch
import ray
from ray.rllib import _register_all
from ray import tune
from ray.tune import Trainable, TuneError
from ray.tune import DurableTrainable, Trainable, TuneError
from ray.tune import register_env, register_trainable, run_experiments
from ray.tune.schedulers import TrialScheduler, FIFOScheduler
from ray.tune.trial import Trial
@@ -25,6 +28,7 @@ from ray.tune.experiment import Experiment
from ray.tune.resources import Resources
from ray.tune.suggest import grid_search
from ray.tune.suggest.suggestion import _MockSuggestionAlgorithm
from ray.tune.utils.mock import mock_storage_client, MOCK_REMOTE_DIR
class TrainableFunctionApiTest(unittest.TestCase):
@@ -703,6 +707,34 @@ class TrainableFunctionApiTest(unittest.TestCase):
]
self.assertTrue(all(complete_results1))
def testDurableTrainable(self):
class TestTrain(DurableTrainable):
def _setup(self, config):
self.state = {"hi": 1, "iter": 0}
def _train(self):
self.state["iter"] += 1
return {"timesteps_this_iter": 1, "done": True}
def _save(self, path):
return self.state
def _restore(self, state):
self.state = state
sync_client = mock_storage_client()
mock_get_client = "ray.tune.durable_trainable.get_cloud_sync_client"
with patch(mock_get_client) as mock_get_cloud_sync_client:
mock_get_cloud_sync_client.return_value = sync_client
test_trainable = TestTrain(remote_checkpoint_dir=MOCK_REMOTE_DIR)
checkpoint_path = test_trainable.save()
test_trainable.train()
test_trainable.state["hi"] = 2
test_trainable.restore(checkpoint_path)
self.assertEqual(test_trainable.state["hi"], 1)
self.addCleanup(shutil.rmtree, MOCK_REMOTE_DIR)
def testCheckpointDict(self):
class TestTrain(Trainable):
def _setup(self, config):
@@ -16,23 +16,28 @@ class CheckpointManagerTest(unittest.TestCase):
def mock_result(i):
return {"i": i}
def checkpoint_manager(self, keep_checkpoints_num):
return CheckpointManager(
keep_checkpoints_num, "i", delete_fn=lambda c: None)
def testOnCheckpointOrdered(self):
"""
Tests increasing priorities. Also tests that that the worst checkpoints
are deleted when necessary.
"""
keep_checkpoints_num = 2
checkpoint_manager = CheckpointManager(keep_checkpoints_num, "i")
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
checkpoints = [
Checkpoint(Checkpoint.DISK, {i}, self.mock_result(i))
Checkpoint(Checkpoint.PERSISTENT, {i}, self.mock_result(i))
for i in range(3)
]
with patch("shutil.rmtree") as rmtree_mock, patch("os.path"):
with patch.object(checkpoint_manager, "delete") as \
delete_mock:
for j in range(3):
checkpoint_manager.on_checkpoint(checkpoints[j])
expected_deletes = 0 if j != 2 else 1
self.assertEqual(rmtree_mock.call_count, expected_deletes)
self.assertEqual(delete_mock.call_count, expected_deletes, j)
self.assertEqual(checkpoint_manager.newest_checkpoint,
checkpoints[j])
@@ -47,17 +52,17 @@ class CheckpointManagerTest(unittest.TestCase):
that the worst checkpoints are deleted when necessary.
"""
keep_checkpoints_num = 2
checkpoint_manager = CheckpointManager(keep_checkpoints_num, "i")
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
checkpoints = [
Checkpoint(Checkpoint.DISK, {i}, self.mock_result(i))
Checkpoint(Checkpoint.PERSISTENT, {i}, self.mock_result(i))
for i in range(3, -1, -1)
]
with patch("shutil.rmtree") as rmtree_mock, patch("os.path"):
with patch.object(checkpoint_manager, "delete") as delete_mock:
for j in range(0, len(checkpoints)):
checkpoint_manager.on_checkpoint(checkpoints[j])
expected_deletes = 0 if j != 3 else 1
self.assertEqual(rmtree_mock.call_count, expected_deletes)
self.assertEqual(delete_mock.call_count, expected_deletes)
self.assertEqual(checkpoint_manager.newest_checkpoint,
checkpoints[j])
@@ -71,7 +76,7 @@ class CheckpointManagerTest(unittest.TestCase):
Tests that the best checkpoints are tracked and ordered correctly.
"""
keep_checkpoints_num = 4
checkpoint_manager = CheckpointManager(keep_checkpoints_num, "i")
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
checkpoints = [
Checkpoint(Checkpoint.MEMORY, i, self.mock_result(i))
for i in range(16)
@@ -92,7 +97,7 @@ class CheckpointManagerTest(unittest.TestCase):
checkpoint has no checkpoint score attribute.
"""
keep_checkpoints_num = 1
checkpoint_manager = CheckpointManager(keep_checkpoints_num, "i")
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
no_attr_checkpoint = Checkpoint(Checkpoint.MEMORY, 0, {})
with patch.object(logger, "error") as log_error_mock:
+124 -56
View File
@@ -9,20 +9,26 @@ import os
import pytest
import shutil
import sys
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch
import ray
from ray import tune
from ray.rllib import _register_all
from ray.cluster_utils import Cluster
from ray.test_utils import run_string_as_driver_nonblocking
from ray.tune import register_trainable
from ray.tune.experiment import Experiment
from ray.tune.error import TuneError
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.experiment import Experiment
from ray.tune.trial import Trial
from ray.tune.resources import Resources
from ray.tune.trial_runner import TrialRunner
from ray.tune.suggest import BasicVariantGenerator
from ray.tune.syncer import Syncer
from ray.tune.trainable import TrainableUtil
from ray.tune.trial import Trial
from ray.tune.trial_runner import TrialRunner
from ray.tune.utils.mock import (MockDurableTrainer, MockRemoteTrainer,
MockNodeSyncer, mock_storage_client,
MOCK_REMOTE_DIR)
def _start_new_cluster():
@@ -36,6 +42,8 @@ def _start_new_cluster():
})
})
# Pytest doesn't play nicely with imports
register_trainable("__fake_remote", MockRemoteTrainer)
register_trainable("__fake_durable", MockDurableTrainer)
_register_all()
return cluster
@@ -53,7 +61,6 @@ def start_connected_cluster():
@pytest.fixture
def start_connected_emptyhead_cluster():
"""Starts head with no resources."""
cluster = Cluster(
initialize_head=True,
connect=True,
@@ -65,6 +72,8 @@ def start_connected_emptyhead_cluster():
})
# Pytest doesn't play nicely with imports
_register_all()
register_trainable("__fake_remote", MockRemoteTrainer)
register_trainable("__fake_durable", MockDurableTrainer)
yield cluster
# The code after the yield will run as teardown code.
ray.shutdown()
@@ -208,7 +217,8 @@ def test_queue_trials(start_connected_emptyhead_cluster):
assert gpu_trial.status == Trial.TERMINATED
def test_trial_migration(start_connected_emptyhead_cluster):
@pytest.mark.parametrize("trainable_id", ["__fake", "__fake_durable"])
def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
"""Removing a node while cluster has space should migrate trial.
The trial state should also be consistent with the checkpoint.
@@ -220,14 +230,16 @@ def test_trial_migration(start_connected_emptyhead_cluster):
runner = TrialRunner(BasicVariantGenerator())
kwargs = {
"stopping_criterion": {
"training_iteration": 3
"training_iteration": 4
},
"checkpoint_freq": 2,
"max_failures": 2
"max_failures": 2,
"remote_checkpoint_dir": MOCK_REMOTE_DIR,
"sync_to_driver_fn": trainable_id == "__fake",
}
# Test recovery of trial that hasn't been checkpointed
t = Trial("__fake", **kwargs)
t = Trial(trainable_id, **kwargs)
runner.add_trial(t)
runner.step() # start
runner.step() # 1 result
@@ -241,13 +253,13 @@ def test_trial_migration(start_connected_emptyhead_cluster):
# because checkpoint handling is messy and should be refactored
# rather than hotfixed.
# assert t.last_result is None, "Trial result not restored correctly."
for i in range(3):
for i in range(4):
runner.step()
assert t.status == Trial.TERMINATED
# Test recovery of trial that has been checkpointed
t2 = Trial("__fake", **kwargs)
t2 = Trial(trainable_id, **kwargs)
runner.add_trial(t2)
runner.step() # start
runner.step() # 1 result
@@ -256,13 +268,23 @@ def test_trial_migration(start_connected_emptyhead_cluster):
node3 = cluster.add_node(num_cpus=1)
cluster.remove_node(node2)
cluster.wait_for_nodes()
runner.step() # 3 result + start and fail 4 result
runner.step() # Recovery step
runner.step() # Process recovery
runner.step() # result
if t2.status != Trial.TERMINATED:
runner.step()
assert t2.status == Trial.TERMINATED, runner.debug_string()
# Test recovery of trial that won't be checkpointed
t3 = Trial("__fake", **{"stopping_criterion": {"training_iteration": 3}})
kwargs = {
"stopping_criterion": {
"training_iteration": 3
},
"remote_checkpoint_dir": MOCK_REMOTE_DIR,
"sync_to_driver_fn": trainable_id == "__fake",
}
t3 = Trial(trainable_id, **kwargs)
runner.add_trial(t3)
runner.step() # start
runner.step() # 1 result
@@ -278,7 +300,8 @@ def test_trial_migration(start_connected_emptyhead_cluster):
runner.step()
def test_trial_requeue(start_connected_emptyhead_cluster):
@pytest.mark.parametrize("trainable_id", ["__fake", "__fake_durable"])
def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id):
"""Removing a node in full cluster causes Trial to be requeued."""
cluster = start_connected_emptyhead_cluster
node = cluster.add_node(num_cpus=1)
@@ -290,10 +313,12 @@ def test_trial_requeue(start_connected_emptyhead_cluster):
"training_iteration": 5
},
"checkpoint_freq": 1,
"max_failures": 1
"max_failures": 1,
"remote_checkpoint_dir": MOCK_REMOTE_DIR,
"sync_to_driver_fn": trainable_id == "__fake",
}
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
trials = [Trial(trainable_id, **kwargs), Trial(trainable_id, **kwargs)]
for t in trials:
runner.add_trial(t)
@@ -309,7 +334,9 @@ def test_trial_requeue(start_connected_emptyhead_cluster):
runner.step()
def test_migration_checkpoint_removal(start_connected_emptyhead_cluster):
@pytest.mark.parametrize("trainable_id", ["__fake_remote", "__fake_durable"])
def test_migration_checkpoint_removal(start_connected_emptyhead_cluster,
trainable_id):
"""Test checks that trial restarts if checkpoint is lost w/ node fail."""
cluster = start_connected_emptyhead_cluster
node = cluster.add_node(num_cpus=1)
@@ -318,32 +345,59 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster):
runner = TrialRunner(BasicVariantGenerator())
kwargs = {
"stopping_criterion": {
"training_iteration": 3
"training_iteration": 4
},
"checkpoint_freq": 2,
"max_failures": 2
"max_failures": 2,
"remote_checkpoint_dir": MOCK_REMOTE_DIR,
"sync_to_driver_fn": trainable_id == "__fake_remote",
}
# Test recovery of trial that has been checkpointed
t1 = Trial("__fake", **kwargs)
runner.add_trial(t1)
runner.step() # start
runner.step() # 1 result
runner.step() # 2 result and checkpoint
assert t1.has_checkpoint()
cluster.add_node(num_cpus=1)
cluster.remove_node(node)
cluster.wait_for_nodes()
shutil.rmtree(os.path.dirname(t1.checkpoint.value))
# The following patches only affect __fake_remote.
find_checkpoint_dir = TrainableUtil.find_checkpoint_dir
with patch("ray.tune.logger.get_node_syncer") as mock_get_node_syncer:
trainable_util = "ray.tune.ray_trial_executor.TrainableUtil"
with patch(trainable_util + ".find_checkpoint_dir") as mock_find_dir:
runner.step() # Recovery step
for i in range(3):
if t1.status != Trial.TERMINATED:
runner.step()
def mock_get_syncer_fn(local_dir, remote_dir, sync_function):
client = mock_storage_client()
return MockNodeSyncer(local_dir, remote_dir, client)
mock_get_node_syncer.side_effect = mock_get_syncer_fn
def mock_find_dir_fn(checkpoint_path):
"""Converts back to local path first."""
checkpoint_path = checkpoint_path[len(MOCK_REMOTE_DIR):]
checkpoint_path = os.path.join("/", checkpoint_path)
return find_checkpoint_dir(checkpoint_path)
# __fake_remote trainables save to a separate "remote" directory.
# TrainableUtil will not check this path unless we mock it.
mock_find_dir.side_effect = mock_find_dir_fn
# Test recovery of trial that has been checkpointed
t1 = Trial(trainable_id, **kwargs)
runner.add_trial(t1)
runner.step() # start
runner.step() # 1 result
runner.step() # 2 result and checkpoint
assert t1.has_checkpoint()
cluster.add_node(num_cpus=1)
cluster.remove_node(node)
cluster.wait_for_nodes()
shutil.rmtree(os.path.dirname(t1.checkpoint.value))
runner.step() # collect result 3, kick off + fail result 4
runner.step() # Recovery step
runner.step() # Process Recovery + step 4
for i in range(3):
if t1.status != Trial.TERMINATED:
runner.step()
assert t1.status == Trial.TERMINATED, runner.debug_string()
def test_cluster_down_simple(start_connected_cluster, tmpdir):
@pytest.mark.parametrize("trainable_id", ["__fake", "__fake_durable"])
def test_cluster_down_simple(start_connected_cluster, tmpdir, trainable_id):
"""Tests that TrialRunner save/restore works on cluster shutdown."""
cluster = start_connected_cluster
cluster.add_node(num_cpus=1)
@@ -356,9 +410,11 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir):
"training_iteration": 2
},
"checkpoint_freq": 1,
"max_failures": 1
"max_failures": 1,
"remote_checkpoint_dir": MOCK_REMOTE_DIR,
"sync_to_driver_fn": trainable_id == "__fake",
}
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
trials = [Trial(trainable_id, **kwargs), Trial(trainable_id, **kwargs)]
for t in trials:
runner.add_trial(t)
@@ -374,6 +430,7 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir):
cluster = _start_new_cluster()
runner = TrialRunner(resume="LOCAL", local_checkpoint_dir=dirpath)
runner.step() # start
runner.step() # process restore
runner.step() # start2
for i in range(3):
@@ -387,26 +444,31 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir):
cluster.shutdown()
def test_cluster_down_full(start_connected_cluster, tmpdir):
@pytest.mark.parametrize("trainable_id", ["__fake", "__fake_durable"])
def test_cluster_down_full(start_connected_cluster, tmpdir, trainable_id):
"""Tests that run_experiment restoring works on cluster shutdown."""
cluster = start_connected_cluster
dirpath = str(tmpdir)
exp1_args = dict(
run="__fake",
use_default_sync = trainable_id == "__fake"
from ray.tune.result import DEFAULT_RESULTS_DIR
local_dir = DEFAULT_RESULTS_DIR
upload_dir = None if use_default_sync else MOCK_REMOTE_DIR
base_dict = dict(
run=trainable_id,
stop=dict(training_iteration=3),
local_dir=dirpath,
checkpoint_freq=1)
exp2_args = dict(run="__fake", stop=dict(training_iteration=3))
exp3_args = dict(
run="__fake",
stop=dict(training_iteration=3),
config=dict(mock_error=True))
local_dir=local_dir,
upload_dir=upload_dir,
sync_to_driver=use_default_sync,
)
exp1_args = base_dict
exp2_args = dict(base_dict.items(), local_dir=dirpath, checkpoint_freq=1)
exp3_args = dict(base_dict.items(), config=dict(mock_error=True))
exp4_args = dict(
run="__fake",
stop=dict(training_iteration=3),
config=dict(mock_error=True),
checkpoint_freq=1)
base_dict.items(), config=dict(mock_error=True), checkpoint_freq=1)
all_experiments = {
"exp1": exp1_args,
"exp2": exp2_args,
@@ -414,14 +476,20 @@ def test_cluster_down_full(start_connected_cluster, tmpdir):
"exp4": exp4_args
}
tune.run_experiments(all_experiments, raise_on_failed_trial=False)
mock_get_client = "ray.tune.trial_runner.get_cloud_syncer"
with patch(mock_get_client) as mock_get_cloud_syncer:
mock_syncer = Syncer(local_dir, upload_dir, mock_storage_client())
mock_get_cloud_syncer.return_value = mock_syncer
ray.shutdown()
cluster.shutdown()
cluster = _start_new_cluster()
tune.run_experiments(all_experiments, raise_on_failed_trial=False)
ray.shutdown()
cluster.shutdown()
cluster = _start_new_cluster()
trials = tune.run_experiments(
all_experiments, resume=True, raise_on_failed_trial=False)
trials = tune.run_experiments(
all_experiments, resume=True, raise_on_failed_trial=False)
assert len(trials) == 4
assert all(t.status in [Trial.TERMINATED, Trial.ERROR] for t in trials)
ray.shutdown()
+14
View File
@@ -4,11 +4,25 @@ from __future__ import print_function
import unittest
import ray
from ray.rllib import _register_all
from ray.tune import register_trainable
from ray.tune.experiment import Experiment, convert_to_experiment_list
from ray.tune.error import TuneError
class ExperimentTest(unittest.TestCase):
def tearDown(self):
ray.shutdown()
_register_all() # re-register the evicted objects
def setUp(self):
def train(config, reporter):
for i in range(100):
reporter(timesteps_total=i)
register_trainable("f1", train)
def testConvertExperimentFromExperiment(self):
exp1 = Experiment(**{
"name": "foo",
@@ -4,12 +4,9 @@ from __future__ import division
from __future__ import print_function
import json
import sys
import unittest
from unittest.mock import patch
import ray
from ray.exceptions import RayTimeoutError
from ray.rllib import _register_all
from ray.tune import Trainable
from ray.tune.ray_trial_executor import RayTrialExecutor
@@ -41,33 +38,11 @@ class RayTrialExecutorTest(unittest.TestCase):
trial = Trial("__fake")
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
self.trial_executor.save(trial, Checkpoint.DISK)
self.trial_executor.save(trial, Checkpoint.PERSISTENT)
self.trial_executor.restore(trial)
self.trial_executor.stop_trial(trial)
self.assertEqual(Trial.TERMINATED, trial.status)
def testSaveRestoreTimeout(self):
trial = Trial("__fake")
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
self.trial_executor.save(trial, Checkpoint.DISK)
self.trial_executor.set_status(trial, Trial.PAUSED)
ray_get = ray.get
start_trial = self.trial_executor._start_trial
# Timeout on first two attempts, then succeed on subsequent gets.
side_effects = [RayTimeoutError, RayTimeoutError, ray_get, ray_get]
with patch.object(self.trial_executor, "_start_trial") as mock_start:
with patch("ray.get", side_effect=side_effects):
mock_start.side_effect = start_trial
self.trial_executor.start_trial(trial, trial.checkpoint)
# Trial starts successfully on 3rd attempt.
assert mock_start.call_count == 3
self.assertEqual(Trial.RUNNING, trial.status)
self.trial_executor.stop_trial(trial)
def testPauseResume(self):
"""Tests that pausing works for trials in flight."""
trial = Trial("__fake")
@@ -216,4 +191,5 @@ class LocalModeExecutorTest(RayTrialExecutorTest):
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))
@@ -0,0 +1,47 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import pickle
import shutil
import unittest
from ray.tune.trainable import TrainableUtil
class TrainableUtilTest(unittest.TestCase):
def setUp(self):
self.checkpoint_dir = "/tmp/tune/MyTrainable123"
TrainableUtil.make_checkpoint_dir(self.checkpoint_dir)
def tearDown(self):
self.addCleanup(shutil.rmtree, self.checkpoint_dir)
def testFindCheckpointDir(self):
checkpoint_path = os.path.join(self.checkpoint_dir, "my/nested/chkpt")
os.makedirs(checkpoint_path)
found_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
self.assertEquals(self.checkpoint_dir, found_dir)
with self.assertRaises(FileNotFoundError):
parent = os.path.dirname(found_dir)
TrainableUtil.find_checkpoint_dir(parent)
def testPickleCheckpoint(self):
for i in range(5):
path = os.path.join(self.checkpoint_dir, str(i))
with open(path, "w") as f:
f.write(str(i))
checkpoint_path = os.path.join(self.checkpoint_dir, "0")
data_dict = TrainableUtil.pickle_checkpoint(checkpoint_path)
loaded = pickle.loads(data_dict)
checkpoint_name = os.path.basename(checkpoint_path)
self.assertEqual(loaded["checkpoint_name"], checkpoint_name)
for i in range(5):
path = os.path.join(self.checkpoint_dir, str(i))
self.assertEquals(loaded["data"][str(i)], open(path, "rb").read())
+3 -1
View File
@@ -19,9 +19,11 @@ from ray.tune.suggest import BasicVariantGenerator
class TrialRunnerTest(unittest.TestCase):
def setUp(self):
_register_all() # re-register the evicted objects
def tearDown(self):
ray.shutdown()
_register_all() # re-register the evicted objects
def testTrialStatus(self):
ray.init()
@@ -182,9 +182,11 @@ class TrialRunnerTest2(unittest.TestCase):
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(trials[0].num_failures, 1)
runner.step() # Restore step
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(trials[0].num_failures, 2)
runner.step() # Restore step
runner.step()
self.assertEqual(trials[0].status, Trial.ERROR)
self.assertEqual(trials[0].num_failures, 3)
@@ -242,6 +244,7 @@ class TrialRunnerTest2(unittest.TestCase):
runner.step()
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(trials[1].status, Trial.RUNNING)
runner.step() # Restore step
runner.step()
self.assertEqual(trials[1].last_result["timesteps_since_restore"], 10)
self.assertEqual(trials[1].last_result["iterations_since_restore"], 1)
@@ -203,7 +203,7 @@ class _MockTrialExecutor(TrialExecutor):
def restore(self, trial, checkpoint=None):
pass
def save(self, trial, type=Checkpoint.DISK, result=None):
def save(self, trial, type=Checkpoint.PERSISTENT, result=None):
return trial.trainable_name
def reset_trial(self, trial, new_config, new_experiment_tag):
+79 -23
View File
@@ -29,6 +29,57 @@ logger = logging.getLogger(__name__)
SETUP_TIME_THRESHOLD = 10
class TrainableUtil:
@staticmethod
def pickle_checkpoint(checkpoint_path):
"""Pickles checkpoint data."""
checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
data = {}
for basedir, _, file_names in os.walk(checkpoint_dir):
for file_name in file_names:
path = os.path.join(basedir, file_name)
with open(path, "rb") as f:
data[os.path.relpath(path, checkpoint_dir)] = f.read()
# Use normpath so that a directory path isn't mapped to empty string.
name = os.path.basename(os.path.normpath(checkpoint_path))
name += os.path.sep if os.path.isdir(checkpoint_path) else ""
data_dict = pickle.dumps({
"checkpoint_name": name,
"data": data,
})
return data_dict
@staticmethod
def find_checkpoint_dir(checkpoint_path):
"""Returns the directory containing the checkpoint path.
Raises:
FileNotFoundError if the directory is not found.
"""
if not os.path.exists(checkpoint_path):
raise FileNotFoundError("Path does not exist", checkpoint_path)
if os.path.isdir(checkpoint_path):
checkpoint_dir = checkpoint_path
else:
checkpoint_dir = os.path.dirname(checkpoint_path)
while checkpoint_dir != os.path.dirname(checkpoint_dir):
if os.path.exists(os.path.join(checkpoint_dir, ".is_checkpoint")):
break
checkpoint_dir = os.path.dirname(checkpoint_dir)
else:
raise FileNotFoundError("Checkpoint directory not found for {}"
.format(checkpoint_path))
return checkpoint_dir
@staticmethod
def make_checkpoint_dir(checkpoint_dir):
"""Creates a checkpoint directory at the provided path."""
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)
# Drop marker in directory to identify it as a checkpoint dir.
open(os.path.join(checkpoint_dir, ".is_checkpoint"), "a").close()
class Trainable:
"""Abstract class for trainable models, functions, etc.
@@ -119,17 +170,14 @@ class Trainable:
>>> extra_cpu=config["workers"],
>>> extra_gpu=int(config["use_gpu"]) * config["workers"])
"""
return None
@classmethod
def resource_help(cls, config):
"""
"""Returns a help string for configuring this trainable's resources.
Args:
config (dict): The Trainer's config dict.
Returns:
str: A help string for configuring this trainable's resources.
"""
return ""
@@ -258,9 +306,7 @@ class Trainable:
"""
checkpoint_dir = os.path.join(checkpoint_dir or self.logdir,
"checkpoint_{}".format(self._iteration))
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)
TrainableUtil.make_checkpoint_dir(checkpoint_dir)
checkpoint = self._save(checkpoint_dir)
saved_as_dict = False
if isinstance(checkpoint, string_types):
@@ -270,6 +316,10 @@ class Trainable:
"given checkpoint dir {}: {}".format(
checkpoint_dir, checkpoint))
checkpoint_path = checkpoint
if os.path.isdir(checkpoint_path):
# Add trailing slash to prevent tune metadata from
# being written outside the directory.
checkpoint_path = os.path.join(checkpoint_path, "")
elif isinstance(checkpoint, dict):
saved_as_dict = True
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
@@ -302,19 +352,8 @@ class Trainable:
tmpdir = tempfile.mkdtemp("save_to_object", dir=self.logdir)
checkpoint_path = self.save(tmpdir)
# Save all files in subtree.
data = {}
for basedir, _, file_names in os.walk(tmpdir):
for file_name in file_names:
path = os.path.join(basedir, file_name)
with open(path, "rb") as f:
data[os.path.relpath(path, tmpdir)] = f.read()
data_dict = TrainableUtil.pickle_checkpoint(checkpoint_path)
out = io.BytesIO()
data_dict = pickle.dumps({
"checkpoint_name": os.path.relpath(checkpoint_path, tmpdir),
"data": data,
})
if len(data_dict) > 10e6: # getting pretty large
logger.info("Checkpoint size is {} bytes".format(len(data_dict)))
out.write(data_dict)
@@ -348,14 +387,15 @@ class Trainable:
self._timesteps_since_restore = 0
self._iterations_since_restore = 0
self._restored = True
logger.info("Restored from checkpoint: %s", checkpoint_path)
logger.info("Restored on %s from checkpoint: %s", self.current_ip(),
checkpoint_path)
state = {
"_iteration": self._iteration,
"_timesteps_total": self._timesteps_total,
"_time_total": self._time_total,
"_episodes_total": self._episodes_total,
}
logger.info("Current state after restoring: {}".format(state))
logger.info("Current state after restoring: %s", state)
def restore_from_object(self, obj):
"""Restores training state from a checkpoint object.
@@ -379,6 +419,22 @@ class Trainable:
self.restore(checkpoint_path)
shutil.rmtree(tmpdir)
def delete_checkpoint(self, checkpoint_path):
"""Deletes local copy of checkpoint.
Args:
checkpoint_path (str): Path to checkpoint.
"""
try:
checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
except FileNotFoundError:
# The checkpoint won't exist locally if the
# trial was rescheduled to another worker.
logger.debug("Checkpoint not found during garbage collection.")
return
if os.path.exists(checkpoint_dir):
shutil.rmtree(checkpoint_dir)
def export_model(self, export_formats, export_dir=None):
"""Exports model based on export_formats.
@@ -429,7 +485,7 @@ class Trainable:
Note that the current working directory will also be changed to this.
"""
return self._logdir
return os.path.join(self._logdir, "")
@property
def iteration(self):
+117 -41
View File
@@ -6,6 +6,7 @@ import ray.cloudpickle as cloudpickle
import copy
from datetime import datetime
import logging
import shutil
import uuid
import time
import tempfile
@@ -13,6 +14,7 @@ import os
from numbers import Number
from ray.tune import TuneError
from ray.tune.checkpoint_manager import Checkpoint, CheckpointManager
from ray.tune.durable_trainable import DurableTrainable
from ray.tune.logger import pretty_print, UnifiedLogger
from ray.tune.util import flatten_dict
# NOTE(rkn): We import ray.tune.registry here instead of importing the names we
@@ -73,6 +75,30 @@ class ExportFormat:
export_formats[i])
def checkpoint_deleter(trial_id, runner):
"""Returns a checkpoint deleter callback for a runner."""
if not runner:
return lambda checkpoint: None
def delete(checkpoint):
"""Requests checkpoint deletion asynchronously.
Args:
checkpoint (Checkpoint): Checkpoint to delete.
"""
if checkpoint.storage == Checkpoint.PERSISTENT and checkpoint.value:
logger.debug("Trial %s: Deleting checkpoint %s", trial_id,
checkpoint.value)
checkpoint_path = checkpoint.value
# Delete local copy, if any exists.
if os.path.exists(checkpoint_path):
shutil.rmtree(checkpoint_path)
# TODO(ujvl): Batch remote deletes.
runner.delete_checkpoint.remote(checkpoint.value)
return delete
class Trial:
"""A trial object holds the state for one model training run.
@@ -98,6 +124,7 @@ class Trial:
experiment_tag="",
resources=None,
stopping_criterion=None,
remote_checkpoint_dir=None,
checkpoint_freq=0,
checkpoint_at_end=False,
sync_on_checkpoint=True,
@@ -137,7 +164,7 @@ class Trial:
"clear the `resources_per_trial` option.".format(
trainable_cls, default_resources))
resources = default_resources
self.address = Location()
self.location = Location()
self.resources = resources or Resources(cpu=1, gpu=0)
self.stopping_criterion = stopping_criterion or {}
self.loggers = loggers
@@ -148,18 +175,10 @@ class Trial:
# Local trial state that is updated during the run
self.last_result = {}
self.last_update_time = -float("inf")
self.checkpoint_freq = checkpoint_freq
self.checkpoint_at_end = checkpoint_at_end
# stores in memory max/min/last result for each metric by trial
self.metric_analysis = {}
self.sync_on_checkpoint = sync_on_checkpoint
newest_checkpoint = Checkpoint(Checkpoint.DISK, restore_path)
self.checkpoint_manager = CheckpointManager(keep_checkpoints_num,
checkpoint_score_attr)
self.checkpoint_manager.newest_checkpoint = newest_checkpoint
self.export_formats = export_formats
self.status = Trial.PENDING
self.start_time = None
@@ -169,9 +188,27 @@ class Trial:
self.last_debug = 0
self.error_file = None
self.error_msg = None
self.num_failures = 0
self.custom_trial_name = None
# Checkpointing fields
if remote_checkpoint_dir:
self.remote_checkpoint_dir_prefix = remote_checkpoint_dir
else:
self.remote_checkpoint_dir_prefix = None
self.checkpoint_freq = checkpoint_freq
self.checkpoint_at_end = checkpoint_at_end
self.sync_on_checkpoint = sync_on_checkpoint
newest_checkpoint = Checkpoint(Checkpoint.PERSISTENT, restore_path)
self.checkpoint_manager = CheckpointManager(
keep_checkpoints_num, checkpoint_score_attr,
checkpoint_deleter(str(self), self.runner))
self.checkpoint_manager.newest_checkpoint = newest_checkpoint
# Restoration fields
self.restoring_from = None
self.num_failures = 0
self.num_consecutive_start_attempts = 0
# AutoML fields
self.results = None
self.best_result = None
@@ -179,7 +216,6 @@ class Trial:
self.extra_arg = None
self._nonjson_fields = [
"checkpoint",
"loggers",
"sync_to_driver_fn",
"results",
@@ -192,7 +228,7 @@ class Trial:
@property
def node_ip(self):
return self.address.hostname
return self.location.hostname
@property
def checkpoint(self):
@@ -202,6 +238,14 @@ class Trial:
def generate_id(cls):
return str(uuid.uuid1().hex)[:8]
@property
def remote_checkpoint_dir(self):
assert self.logdir, "Trial {}: logdir not initialized.".format(self)
if not self.remote_checkpoint_dir_prefix:
return None
logdir_name = os.path.basename(self.logdir)
return os.path.join(self.remote_checkpoint_dir_prefix, logdir_name)
@classmethod
def create_logdir(cls, identifier, local_dir):
local_dir = os.path.expanduser(local_dir)
@@ -213,7 +257,6 @@ class Trial:
def init_logger(self):
"""Init logger."""
if not self.result_logger:
if not self.logdir:
self.logdir = Trial.create_logdir(str(self), self.local_dir)
@@ -239,24 +282,20 @@ class Trial:
raise ValueError("Cannot update resources while Trial is running.")
self.resources = Resources(cpu, gpu, **kwargs)
def sync_logger_to_new_location(self, worker_ip):
"""Updates the logger location.
Also pushes logdir to worker_ip, allowing for cross-node recovery.
"""
if self.result_logger:
self.result_logger.sync_results_to_new_location(worker_ip)
self.set_location(Location(worker_ip))
def set_runner(self, runner):
self.runner = runner
self.checkpoint_manager.delete = checkpoint_deleter(str(self), runner)
def set_location(self, location):
"""Sets the location of the trial."""
self.address = location
self.location = location
def set_status(self, status):
"""Sets the status of the trial."""
if status == Trial.RUNNING and self.start_time is None:
self.start_time = time.time()
self.status = status
if status == Trial.RUNNING:
if self.start_time is None:
self.start_time = time.time()
def close_logger(self):
"""Closes logger."""
@@ -266,7 +305,7 @@ class Trial:
def write_error_log(self, error_msg):
if error_msg and self.logdir:
self.num_failures += 1 # may be moved to outer scope?
self.num_failures += 1
self.error_file = os.path.join(self.logdir, "error.txt")
with open(self.error_file, "a+") as f:
f.write("Failure # {} (occurred at {})\n".format(
@@ -276,7 +315,6 @@ class Trial:
def should_stop(self, result):
"""Whether the given result meets this trial's stopping criteria."""
if result.get(DONE):
return True
@@ -309,6 +347,7 @@ class Trial:
def clear_checkpoint(self):
self.checkpoint.value = None
self.restoring_from = None
def on_checkpoint(self, checkpoint):
"""Hook for handling checkpoints taken by the Trainable.
@@ -316,20 +355,52 @@ class Trial:
Args:
checkpoint (Checkpoint): Checkpoint taken.
"""
if self.sync_on_checkpoint and checkpoint.storage == Checkpoint.DISK:
# Wait for any other syncs to finish. We need to sync again after
# this to handle checkpoints taken mid-sync.
self.result_logger.wait()
# Force sync down and wait before tracking the new checkpoint. This
# prevents attempts to restore from partially synced checkpoints.
if self.result_logger.sync_down():
if checkpoint.storage == Checkpoint.MEMORY:
# TODO(ujvl): Handle this separately to avoid restoration failure.
self.checkpoint_manager.on_checkpoint(checkpoint)
return
if self.sync_on_checkpoint:
try:
# Wait for any other syncs to finish. We need to sync again
# after this to handle checkpoints taken mid-sync.
self.result_logger.wait()
else:
except TuneError as e:
# Errors occurring during this wait are not fatal for this
# checkpoint, so it should just be logged.
logger.error(
"Trial %s: Checkpoint sync skipped. "
"This should not happen.", self)
"Trial %s: An error occurred during the "
"checkpoint pre-sync wait.", str(e))
# Force sync down and wait before tracking the new checkpoint.
try:
if self.result_logger.sync_down():
self.result_logger.wait()
else:
logger.error(
"Trial %s: Checkpoint sync skipped. "
"This should not happen.", self)
except TuneError as e:
if issubclass(self.get_trainable_cls(), DurableTrainable):
# Even though rsync failed the trainable can restore
# from remote durable storage.
logger.error("Trial %s: Sync error - %s", self, str(e))
else:
# If the trainable didn't have remote storage to upload
# to then this checkpoint may have been lost, so we
# shouldn't track it with the checkpoint_manager.
raise e
if not issubclass(self.get_trainable_cls(), DurableTrainable):
if not os.path.exists(checkpoint.value):
raise TuneError("Trial {}: Checkpoint path {} not "
"found after successful sync down.".format(
self, checkpoint.value))
self.checkpoint_manager.on_checkpoint(checkpoint)
def on_restore(self):
"""Handles restoration completion."""
assert self.is_restoring
self.last_result = self.restoring_from.result
self.restoring_from = None
def should_recover(self):
"""Returns whether the trial qualifies for retrying.
@@ -375,7 +446,11 @@ class Trial:
self.verbose = verbose
def is_finished(self):
return self.status in [Trial.TERMINATED, Trial.ERROR]
return self.status in [Trial.ERROR, Trial.TERMINATED]
@property
def is_restoring(self):
return self.restoring_from is not None
def __repr__(self):
return str(self)
@@ -383,7 +458,7 @@ class Trial:
def __str__(self):
"""Combines ``env`` with ``trainable_name`` and ``trial_id``.
Can be overriden with a custom string creator.
Can be overridden with a custom string creator.
"""
if self.custom_trial_name:
return self.custom_trial_name
@@ -402,9 +477,9 @@ class Trial:
"""Memento generator for Trial.
Sets RUNNING trials to PENDING, and flushes the result logger.
Note this can only occur if the trial holds a DISK checkpoint.
Note this can only occur if the trial holds a PERSISTENT checkpoint.
"""
assert self.checkpoint.storage == Checkpoint.DISK, (
assert self.checkpoint.storage == Checkpoint.PERSISTENT, (
"Checkpoint must not be in-memory.")
state = self.__dict__.copy()
state["resources"] = resources_to_json(self.resources)
@@ -415,7 +490,7 @@ class Trial:
state["runner"] = None
state["result_logger"] = None
if self.result_logger:
self.result_logger.flush()
self.result_logger.flush(sync_down=False)
state["__logger_started__"] = True
else:
state["__logger_started__"] = False
@@ -424,6 +499,7 @@ class Trial:
def __setstate__(self, state):
logger_started = state.pop("__logger_started__")
state["resources"] = json_to_resources(state["resources"])
if state["status"] == Trial.RUNNING:
state["status"] = Trial.PENDING
for key in self._nonjson_fields:
+8 -4
View File
@@ -39,8 +39,11 @@ class TrialExecutor:
trial (Trial): Trial to checkpoint.
status (Trial.status): Status to set trial to.
"""
logger.debug("Trial %s: Changing status from %s to %s.", trial,
trial.status, status)
if trial.status == status:
logger.debug("Trial %s: Status %s unchanged.", trial, trial.status)
else:
logger.debug("Trial %s: Changing status from %s to %s.", trial,
trial.status, status)
trial.set_status(status)
if status in [Trial.TERMINATED, Trial.ERROR]:
self.try_checkpoint_metadata(trial)
@@ -226,14 +229,15 @@ class TrialExecutor:
raise NotImplementedError("Subclasses of TrialExecutor must provide "
"restore() method")
def save(self, trial, storage=Checkpoint.DISK, result=None):
def save(self, trial, storage=Checkpoint.PERSISTENT, result=None):
"""Saves training state of this trial to a checkpoint.
If result is None, this trial's last result will be used.
Args:
trial (Trial): The state of this trial to be saved.
storage (str): Where to store the checkpoint. Defaults to DISK.
storage (str): Where to store the checkpoint. Defaults to
PERSISTENT.
result (dict): The state of this trial as a dictionary to be saved.
Return:
+57 -35
View File
@@ -7,7 +7,6 @@ from datetime import datetime
import json
import logging
import os
import re
import time
import traceback
import types
@@ -31,12 +30,6 @@ MAX_DEBUG_TRIALS = 20
logger = logging.getLogger(__name__)
def _naturalize(string):
"""Provides a natural representation for string for nice sorting."""
splits = re.split("([0-9]+)", string)
return [int(text) if text.isdigit() else text.lower() for text in splits]
def _find_newest_ckpt(ckpt_dir):
"""Returns path to most recently modified checkpoint."""
full_paths = [
@@ -222,6 +215,9 @@ class TrialRunner:
# Try syncing down the upload directory.
logger.info("Downloading from %s", self._remote_checkpoint_dir)
# TODO(ujvl): Note that this syncs down the entire directory,
# which may also contain trial checkpoints. We should selectively
# sync the necessary files instead.
self._syncer.sync_down_if_needed()
if not self.checkpoint_exists(self._local_checkpoint_dir):
@@ -417,11 +413,18 @@ class TrialRunner:
with warn_if_slow("process_failed_trial"):
self._process_trial_failure(failed_trial, error_msg=error_msg)
else:
# TODO(ujvl): Consider combining get_next_available_trial and
# fetch_result functionality so that we don't timeout on fetch.
trial = self.trial_executor.get_next_available_trial() # blocking
with warn_if_slow("process_trial"):
self._process_trial(trial)
if trial.is_restoring:
with warn_if_slow("process_trial_restore"):
self._process_trial_restore(trial)
else:
with warn_if_slow("process_trial"):
self._process_trial(trial)
def _process_trial(self, trial):
"""Processes a trial result."""
try:
result = self.trial_executor.fetch_result(trial)
@@ -479,7 +482,24 @@ class TrialRunner:
assert False, "Invalid scheduling decision: {}".format(
decision)
except Exception:
logger.exception("Error processing event.")
logger.exception("Trial %s: Error processing event.", trial)
self._process_trial_failure(trial, traceback.format_exc())
def _process_trial_restore(self, trial):
"""Processes a trial restore.
Args:
trial: Trial being restored.
"""
logger.debug("Trial %s: Processing trial restore.", trial)
try:
self.trial_executor.fetch_result(trial)
trial.on_restore()
logger.debug("Trial %s: Restore processed successfully", trial)
self.trial_executor.set_status(trial, Trial.RUNNING)
self.trial_executor.continue_training(trial)
except Exception:
logger.exception("Trial %s: Error processing restore.", trial)
self._process_trial_failure(trial, traceback.format_exc())
def _process_trial_failure(self, trial, error_msg):
@@ -504,8 +524,8 @@ class TrialRunner:
"""Checkpoints trial based off trial.last_result."""
if trial.should_checkpoint() or force:
# Save trial runtime if possible
if hasattr(trial, "runner") and trial.runner:
self.trial_executor.save(trial, storage=Checkpoint.DISK)
if trial.runner:
self.trial_executor.save(trial, storage=Checkpoint.PERSISTENT)
self.trial_executor.try_checkpoint_metadata(trial)
def _try_recover(self, trial, error_msg):
@@ -517,30 +537,32 @@ class TrialRunner:
trial (Trial): Trial to recover.
error_msg (str): Error message from prior to invoking this method.
"""
try:
self.trial_executor.stop_trial(
trial,
error=error_msg is not None,
error_msg=error_msg,
stop_logger=False)
trial.result_logger.flush()
if self.trial_executor.has_resources(trial.resources):
logger.info(
"Trial %s: Attempting to recover "
"trial state from last checkpoint.", trial)
self.trial_executor.start_trial(trial)
if trial.status == Trial.ERROR:
logger.error("Trial %s: Did not start correctly.", trial)
raise RuntimeError("Trial did not start correctly.")
logger.debug("Trial %s: Started correctly.", trial)
if trial.is_restoring:
# Restore was unsuccessful, try again without checkpoint.
trial.clear_checkpoint()
self.trial_executor.stop_trial(
trial,
error=error_msg is not None,
error_msg=error_msg,
stop_logger=False)
trial.result_logger.flush()
if self.trial_executor.has_resources(trial.resources):
logger.info(
"Trial %s: Attempting to restore "
"trial state from last checkpoint.", trial)
self.trial_executor.start_trial(trial)
if trial.status == Trial.ERROR:
logger.exception(
"Trial %s: Error restoring trial from checkpoint, abort.",
trial)
self._scheduler_alg.on_trial_error(self, trial)
self._search_alg.on_trial_complete(trial.trial_id, error=True)
else:
logger.debug("Trial %s: Notifying Scheduler and requeueing.",
trial)
self._requeue_trial(trial)
except Exception:
logger.exception("Error recovering trial from checkpoint, abort.")
self._scheduler_alg.on_trial_error(self, trial)
self._search_alg.on_trial_complete(trial.trial_id, error=True)
logger.debug("Trial %s: Restore dispatched correctly.", trial)
else:
logger.debug("Trial %s: Notifying Scheduler and requeueing.",
trial)
self._requeue_trial(trial)
def _requeue_trial(self, trial):
"""Notification to TrialScheduler and requeue trial.
+10 -10
View File
@@ -119,8 +119,8 @@ def run(run_or_experiment,
`num_samples` of times.
local_dir (str): Local dir to save training results to.
Defaults to ``~/ray_results``.
upload_dir (str): Optional URI to sync training results to
(e.g. ``s3://bucket``).
upload_dir (str): Optional URI to sync training results and checkpoints
to (e.g. ``s3://bucket`` or ``gs://bucket``).
trial_name_creator (func): Optional function for generating
the trial string representation.
loggers (list): List of logger creators to be used with
@@ -130,20 +130,20 @@ def run(run_or_experiment,
from upload_dir. If string, then it must be a string template that
includes `{source}` and `{target}` for the syncer to run. If not
provided, the sync command defaults to standard S3 or gsutil sync
comamnds.
sync_to_driver (func|str): Function for syncing trial logdir from
commands.
sync_to_driver (func|str|bool): Function for syncing trial logdir from
remote node to local. If string, then it must be a string template
that includes `{source}` and `{target}` for the syncer to run.
If not provided, defaults to using rsync.
If True or not provided, it defaults to using rsync. If False,
syncing to driver is disabled.
checkpoint_freq (int): How many training iterations between
checkpoints. A value of 0 (default) disables checkpointing.
checkpoint_at_end (bool): Whether to checkpoint at the end of the
experiment regardless of the checkpoint_freq. Default is False.
sync_on_checkpoint (bool): Force sync-down of trial checkpoint, to
guarantee recoverability. If set to False, checkpoint syncing from
worker to driver is asynchronous. Set this to False only if
synchronous checkpointing is too slow and trial restoration
failures can be tolerated. Defaults to True.
sync_on_checkpoint (bool): Force sync-down of trial checkpoint to
driver. If set to False, checkpoint syncing from worker to driver
is asynchronous and best-effort. This does not affect persistent
storage syncing. Defaults to True.
keep_checkpoints_num (int): Number of checkpoints to keep. A value of
`None` keeps all checkpoints. Defaults to `None`. If set, need
to provide `checkpoint_score_attr`.
View File
+59
View File
@@ -0,0 +1,59 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from ray.rllib.agents.mock import _MockTrainer
from ray.tune import DurableTrainable
from ray.tune.sync_client import get_sync_client
from ray.tune.syncer import NodeSyncer
MOCK_REMOTE_DIR = "/tmp/mock-tune-remote/"
# Sync and delete templates that operate on local directories.
LOCAL_SYNC_TEMPLATE = "mkdir -p {target} && rsync -avz {source}/ {target}/"
LOCAL_DELETE_TEMPLATE = "rm -rf {target}"
def mock_storage_client():
"""Mocks storage client that treats a local dir as durable storage."""
return get_sync_client(LOCAL_SYNC_TEMPLATE, LOCAL_DELETE_TEMPLATE)
class MockNodeSyncer(NodeSyncer):
"""Mock NodeSyncer that syncs to and from /tmp"""
def has_remote_target(self):
return True
@property
def _remote_path(self):
if self._remote_dir.startswith("/"):
self._remote_dir = self._remote_dir[1:]
return os.path.join(MOCK_REMOTE_DIR, self._remote_dir)
class MockRemoteTrainer(_MockTrainer):
"""Mock Trainable that saves at tmp for simulated clusters."""
def __init__(self, *args, **kwargs):
super(MockRemoteTrainer, self).__init__(*args, **kwargs)
if self._logdir.startswith("/"):
self._logdir = self._logdir[1:]
self._logdir = os.path.join(MOCK_REMOTE_DIR, self._logdir)
if not os.path.exists(self._logdir):
os.makedirs(self._logdir)
class MockDurableTrainer(DurableTrainable, _MockTrainer):
"""Mock DurableTrainable that saves at tmp for simulated clusters."""
# TODO(ujvl): This class uses multiple inheritance; it should be cleaned
# up once the durable training API converges.
def __init__(self, remote_checkpoint_dir, *args, **kwargs):
_MockTrainer.__init__(self, *args, **kwargs)
DurableTrainable.__init__(self, remote_checkpoint_dir, *args, **kwargs)
def _create_storage_client(self):
return mock_storage_client()