[tune] Fault tolerance improvements (#5877)

* Precede ray.get with ray.wait.

* Trigger checkpoint deletes locally in Trainable

* Clean-up code.

* Minor changes.

* Track best checkpoint so far again

* Pulled checkpoint GC out of Trainable.

* Added comments, error logging.

* Immediate pull after checkpoint taken; rsync source delete on pull

* Minor doc fixes

* Fix checkpoint manager bug

* Fix bugs, tests, formatting

* Fix bugs, feature flag for force sync.

* Fix test.

* Fix minor bugs: clear proc and less verbose sync_on_checkpoint warnings.

* Fix bug: update IP of last_result.

* Fixed message.

* Added a lot of logging.

* Changes to ray trial executor.

* More bug fixes (logging after failure), better logging.

* Fix richards bug and logging

* Add comments.

* try-except

* Fix heapq bug.

* .

* Move handling of no available trials to ray_trial_executor (#1)

* Fix formatting bug, lint.

* Addressed Richard's comments

* Revert tests.

* fix rebase

* Fix trial location reporting.

* Fix test

* Fix lint

* Rebase, use ray.get w/ timeout, lint.

* lint

* fix rebase

* Address richard's comments
This commit is contained in:
Ujval Misra
2019-11-18 01:14:41 -08:00
committed by Richard Liaw
parent 66edebce3a
commit 2965dc1b72
20 changed files with 846 additions and 460 deletions
+18 -4
View File
@@ -13,8 +13,22 @@ from ray.tune.sample import (function, sample_from, uniform, choice, randint,
randn, loguniform)
__all__ = [
"Trainable", "TuneError", "grid_search", "register_env",
"register_trainable", "run", "run_experiments", "Experiment", "function",
"sample_from", "track", "uniform", "choice", "randint", "randn",
"loguniform", "progress_reporter", "ExperimentAnalysis", "Analysis"
"Trainable",
"TuneError",
"grid_search",
"register_env",
"register_trainable",
"run",
"run_experiments",
"Experiment",
"function",
"sample_from",
"track",
"uniform",
"choice",
"randint",
"randn",
"loguniform",
"ExperimentAnalysis",
"Analysis",
]
+142
View File
@@ -0,0 +1,142 @@
# coding: utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import heapq
import logging
import os
import shutil
try:
FileNotFoundError
except NameError:
FileNotFoundError = IOError
logger = logging.getLogger(__name__)
class Checkpoint(object):
    """Describes a single checkpoint of a trial's state.

    A checkpoint lives either in memory or on disk.

    Attributes:
        storage (str): Storage type, one of Checkpoint.MEMORY or
            Checkpoint.DISK.
        value: For MEMORY storage, a Python object. For DISK storage, the
            path of the checkpoint on disk.
        result (dict): Result associated with the checkpoint; an empty dict
            when none (or a falsy value) was supplied.
    """

    MEMORY = "memory"
    DISK = "disk"

    def __init__(self, storage, value, result=None):
        self.storage = storage
        self.value = value
        self.result = result if result else {}

    def delete(self):
        """Removes the checkpoint data from disk; no-op for other storage.

        If the stored path is a file, its parent directory is removed;
        otherwise the path itself is removed as a directory tree.

        Raises:
            FileNotFoundError: If the stored path does not exist.
        """
        # Nothing to clean up for in-memory or empty checkpoints.
        if self.storage != Checkpoint.DISK or not self.value:
            return

        path = self.value
        if not os.path.exists(path):
            raise FileNotFoundError(
                "Attempted to delete checkpoint at {} but "
                "path was not found.".format(path))

        # A file path identifies a checkpoint by a file inside its directory,
        # so the whole containing directory is removed.
        target = os.path.dirname(path) if os.path.isfile(path) else path
        shutil.rmtree(target)

    @staticmethod
    def from_object(value=None):
        """Builds an in-memory checkpoint wrapping a Python object."""
        return Checkpoint(Checkpoint.MEMORY, value)
class QueueItem(object):
    """Pairs a value with a priority; items compare by priority only."""

    def __init__(self, priority, value):
        self.priority = priority
        self.value = value

    def __cmp__(self, other):
        # Three-way comparison kept for python2.7 compatibility.
        mine, theirs = self.priority, other.priority
        if mine == theirs:
            return 0
        if mine < theirs:
            return -1
        return 1

    def __lt__(self, other):
        # Rich comparison; the value is never consulted.
        return self.priority < other.priority
class CheckpointManager(object):
    """Manages checkpoints on the driver for a trial."""

    def __init__(self, keep_checkpoints_num, checkpoint_score_attr):
        """Initializes a new CheckpointManager.

        Args:
            keep_checkpoints_num (int): Maximum number of best-scoring
                checkpoints to retain. A falsy value (None or 0) means no
                limit.
            checkpoint_score_attr (str): Key of the checkpoint's result dict
                used to score checkpoints. Higher scores are considered
                better unless the name is prefixed with "min-", in which
                case lower scores are considered better.
        """
        self.keep_checkpoints_num = keep_checkpoints_num or float("inf")
        assert self.keep_checkpoints_num > 0, (
            "keep_checkpoints_num must be greater than 0.")
        self._checkpoint_score_desc = checkpoint_score_attr.startswith("min-")
        if self._checkpoint_score_desc:
            # Strip the "min-" prefix to obtain the actual result key.
            self._checkpoint_score_attr = checkpoint_score_attr[4:]
        else:
            self._checkpoint_score_attr = checkpoint_score_attr

        self.newest_checkpoint = Checkpoint(Checkpoint.MEMORY, None)
        # Min-heap of QueueItems: the worst retained checkpoint is at index 0.
        self._best_checkpoints = []
        # Checkpoints currently held in the heap, for O(1) membership tests.
        self._membership = set()

    def on_checkpoint(self, checkpoint):
        """Starts tracking checkpoint metadata on checkpoint.

        Sets newest checkpoint. Deletes previous checkpoint as long as it isn't
        one of the best ones. Also deletes the worst checkpoint if at capacity.
        The newest checkpoint itself is never deleted by this call, even when
        it does not make it into the set of best checkpoints.

        If the score attribute is missing from the checkpoint's result, an
        error is logged and the checkpoint is only tracked as the newest one.

        Args:
            checkpoint (Checkpoint): Trial state checkpoint.

        Raises:
            Any exception raised by Checkpoint.delete() while removing a
            stale checkpoint is propagated to the caller.
        """
        old_checkpoint = self.newest_checkpoint
        self.newest_checkpoint = checkpoint

        try:
            queue_item = QueueItem(self._priority(checkpoint), checkpoint)
        except KeyError:
            if old_checkpoint not in self._membership:
                old_checkpoint.delete()
            logger.error("Result dict has no key: {}. "
                         "checkpoint_score_attr must be set to a key in the "
                         "result dict.".format(self._checkpoint_score_attr))
            return

        evicted = None
        if len(self._best_checkpoints) < self.keep_checkpoints_num:
            heapq.heappush(self._best_checkpoints, queue_item)
            self._membership.add(checkpoint)
        elif queue_item.priority >= self._best_checkpoints[0].priority:
            evicted = heapq.heappushpop(self._best_checkpoints,
                                        queue_item).value
            # On a tie with the current worst entry, heappushpop hands back
            # the item that was just pushed. In that case the heap is
            # unchanged and the new checkpoint must not be deleted: it is
            # still the newest one and is needed for restoration. It gets
            # cleaned up on the next call since it is not in _membership.
            if evicted is not checkpoint:
                self._membership.add(checkpoint)
                self._membership.discard(evicted)
                evicted.delete()

        # Remove the old checkpoint if it isn't one of the best ones. Skip it
        # when it was the entry evicted above: it has already been deleted,
        # and deleting it a second time would fail on the missing path.
        if (old_checkpoint is not evicted
                and old_checkpoint not in self._membership):
            old_checkpoint.delete()

    def best_checkpoints(self):
        """Returns best checkpoints, sorted by score.

        The list is ordered by ascending priority, i.e. the worst retained
        checkpoint comes first and the best one last.
        """
        checkpoints = sorted(self._best_checkpoints, key=lambda c: c.priority)
        return [queue_item.value for queue_item in checkpoints]

    def _priority(self, checkpoint):
        """Returns the heap priority of checkpoint (higher is better).

        Raises:
            KeyError: If the score attribute is not in checkpoint.result.
        """
        priority = checkpoint.result[self._checkpoint_score_attr]
        # For "min-" attributes lower scores are better, so negate them.
        return -priority if self._checkpoint_score_desc else priority
+10 -1
View File
@@ -76,11 +76,19 @@ def make_parser(parser_creator=None, **kwargs):
action="store_true",
help="Whether to checkpoint at the end of the experiment. "
"Default is False.")
parser.add_argument(
"--no-sync-on-checkpoint",
action="store_true",
help="Disable sync-down of trial checkpoint, which is enabled by "
"default to guarantee recoverability. If set, checkpoint syncing from "
"worker to driver is asynchronous. Set this only if synchronous "
"checkpointing is too slow and trial restoration failures can be "
"tolerated")
parser.add_argument(
"--keep-checkpoints-num",
default=None,
type=int,
help="Number of last checkpoints to keep. Others get "
help="Number of best checkpoints to keep. Others get "
"deleted. Default (None) keeps all checkpoints.")
parser.add_argument(
"--checkpoint-score-attr",
@@ -177,6 +185,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
stopping_criterion=spec.get("stop", {}),
checkpoint_freq=args.checkpoint_freq,
checkpoint_at_end=args.checkpoint_at_end,
sync_on_checkpoint=not args.no_sync_on_checkpoint,
keep_checkpoints_num=args.keep_checkpoints_num,
checkpoint_score_attr=args.checkpoint_score_attr,
export_formats=spec.get("export_formats", []),
+9 -2
View File
@@ -72,6 +72,7 @@ class Experiment(object):
sync_to_driver=None,
checkpoint_freq=0,
checkpoint_at_end=False,
sync_on_checkpoint=True,
keep_checkpoints_num=None,
checkpoint_score_attr=None,
export_formats=None,
@@ -80,6 +81,11 @@ class Experiment(object):
repeat=None,
trial_resources=None,
sync_function=None):
"""Initialize a new Experiment.
The args here take the same meaning as the command line flags defined
in `tune.py:run`.
"""
if repeat:
_raise_deprecation_note("repeat", "num_samples", soft=False)
if trial_resources:
@@ -102,7 +108,7 @@ class Experiment(object):
"criteria must take exactly 2 parameters.".format(stop))
config = config or {}
self._run_identifier = Experiment._register_if_needed(run)
self._run_identifier = Experiment.register_if_needed(run)
spec = {
"run": self._run_identifier,
"stop": stop,
@@ -117,6 +123,7 @@ class Experiment(object):
"sync_to_driver": sync_to_driver,
"checkpoint_freq": checkpoint_freq,
"checkpoint_at_end": checkpoint_at_end,
"sync_on_checkpoint": sync_on_checkpoint,
"keep_checkpoints_num": keep_checkpoints_num,
"checkpoint_score_attr": checkpoint_score_attr,
"export_formats": export_formats or [],
@@ -156,7 +163,7 @@ class Experiment(object):
return exp
@classmethod
def _register_if_needed(cls, run_object):
def register_if_needed(cls, run_object):
"""Registers Trainable or Function at runtime.
Assumes already registered if run_object is a string.
+38 -11
View File
@@ -17,12 +17,17 @@ logger = logging.getLogger(__name__)
_log_sync_warned = False
def log_sync_template():
def log_sync_template(options=""):
"""Syncs the local_dir between driver and worker if possible.
Requires ray cluster to be started with the autoscaler. Also requires
rsync to be installed.
Args:
options (str): Additional rsync options.
Returns:
Sync template with source and target parameters.
"""
if not distutils.spawn.find_executable("rsync"):
logger.error("Log sync requires rsync to be installed.")
@@ -36,12 +41,14 @@ def log_sync_template():
_log_sync_warned = True
return
return ("""rsync -savz -e "ssh -i {ssh_key} -o ConnectTimeout=120s """
"""-o StrictHostKeyChecking=no" {{source}} {{target}}"""
).format(ssh_key=quote(ssh_key))
rsh = "ssh -i {ssh_key} -o ConnectTimeout=120s -o StrictHostKeyChecking=no"
rsh = rsh.format(ssh_key=quote(ssh_key))
template = """rsync {options} -savz -e "{rsh}" {{source}} {{target}}"""
return template.format(options=options, rsh=rsh)
class NodeSyncMixin(object):
# TODO(ujvl): Refactor this code.
"""Mixin for syncing files to/from a remote dir to a local dir."""
def __init__(self):
@@ -53,23 +60,43 @@ class NodeSyncMixin(object):
"""Set the worker ip to sync logs from."""
self.worker_ip = worker_ip
def _check_valid_worker_ip(self):
def has_remote_target(self):
"""Returns whether the Syncer has a remote target."""
if not self.worker_ip:
logger.debug("Worker ip unknown, skipping log sync for {}".format(
self._local_dir))
logger.debug("Worker IP unknown, skipping log sync for %s",
self._local_dir)
return False
if self.worker_ip == self.local_ip:
logger.debug(
"Worker ip is local ip, skipping log sync for {}".format(
self._local_dir))
logger.debug("Worker IP is local IP, skipping log sync for %s",
self._local_dir)
return False
return True
def sync_up_if_needed(self):
if not self.has_remote_target():
return True
super(NodeSyncMixin, self).sync_up()
def sync_down_if_needed(self):
if not self.has_remote_target():
return True
super(NodeSyncMixin, self).sync_down()
def sync_down(self):
if not self.has_remote_target():
return True
return super(NodeSyncMixin, self).sync_down()
def sync_up(self):
if not self.has_remote_target():
return True
return super(NodeSyncMixin, self).sync_up()
@property
def _remote_path(self):
ssh_user = get_ssh_user()
global _log_sync_warned
if not self._check_valid_worker_ip():
if not self.has_remote_target():
return
if ssh_user is None:
if not _log_sync_warned:
+23 -8
View File
@@ -409,8 +409,8 @@ class UnifiedLogger(Logger):
try:
self._loggers.append(cls(self.config, self.logdir, self.trial))
except Exception as exc:
logger.warning("Could not instantiate {}: {}.".format(
cls.__name__, str(exc)))
logger.warning("Could not instantiate %s: %s.", cls.__name__,
str(exc))
self._log_syncer = get_log_syncer(
self.logdir,
remote_dir=self.logdir,
@@ -429,12 +429,21 @@ class UnifiedLogger(Logger):
def close(self):
for _logger in self._loggers:
_logger.close()
self._log_syncer.sync_down()
def flush(self):
for _logger in self._loggers:
_logger.flush()
self._log_syncer.sync_down()
if not self._log_syncer.sync_down():
logger.warning("Trial %s: Post-flush sync skipped.", self.trial)
def sync_up(self):
return self._log_syncer.sync_up()
def sync_down(self):
return self._log_syncer.sync_down()
def wait(self):
self._log_syncer.wait()
def sync_results_to_new_location(self, worker_ip):
"""Sends the current log directory to the remote node.
@@ -443,13 +452,19 @@ class UnifiedLogger(Logger):
with the Ray autoscaler.
"""
if worker_ip != self._log_syncer.worker_ip:
logger.info("Syncing (blocking) results to {}".format(worker_ip))
logger.info("Trial %s: Syncing (blocking) results to %s",
self.trial, worker_ip)
self._log_syncer.reset()
self._log_syncer.set_worker_ip(worker_ip)
self._log_syncer.sync_up()
# TODO: change this because this is blocking. But failures
# are rare, so maybe this is OK?
if not self._log_syncer.sync_up():
logger.error(
"Trial %s: Sync up to new location skipped. "
"This should not occur.", self.trial)
self._log_syncer.wait()
else:
logger.error(
"Trial %s: Sync attempted to same IP %s. This "
"should not occur.", self.trial, worker_ip)
class _SafeFallbackEncoder(json.JSONEncoder):
+107 -112
View File
@@ -4,18 +4,18 @@ from __future__ import division
from __future__ import print_function
import logging
import math
import os
import random
import time
import traceback
import ray
from ray.exceptions import RayTimeoutError
from ray import ray_constants
from ray.resource_spec import ResourceSpec
from ray.tune.error import AbortTrialExecution
from ray.tune.logger import NoopLogger
from ray.tune.trial import Trial, Checkpoint
from ray.tune.trial import Trial, Checkpoint, Location
from ray.tune.resources import Resources
from ray.tune.trial_executor import TrialExecutor
from ray.tune.util import warn_if_slow
@@ -25,6 +25,7 @@ logger = logging.getLogger(__name__)
RESOURCE_REFRESH_PERIOD = 0.5 # Refresh resources every 500 ms
BOTTLENECK_WARN_PERIOD_S = 60
NONTRIVIAL_WAIT_TIME_THRESHOLD_S = 1e-3
DEFAULT_GET_TIMEOUT = 30.0 # seconds
class _LocalWrapper(object):
@@ -37,7 +38,7 @@ class _LocalWrapper(object):
class RayTrialExecutor(TrialExecutor):
"""An implemention of TrialExecutor based on Ray."""
"""An implementation of TrialExecutor based on Ray."""
def __init__(self,
queue_trials=False,
@@ -71,36 +72,18 @@ class RayTrialExecutor(TrialExecutor):
if ray.is_initialized():
self._update_avail_resources()
def _setup_runner(self, trial, reuse_allowed):
def _setup_remote_runner(self, trial, reuse_allowed):
trial.init_logger()
# We checkpoint metadata here to try mitigating logdir duplication
self.try_checkpoint_metadata(trial)
remote_logdir = trial.logdir
if (self._reuse_actors and reuse_allowed
and self._cached_actor is not None):
logger.debug("Reusing cached runner {} for {}".format(
self._cached_actor, trial.trial_id))
existing_runner = self._cached_actor
self._cached_actor = None
else:
if self._cached_actor:
logger.debug(
"Cannot reuse cached runner {} for new trial".format(
self._cached_actor))
self._cached_actor.stop.remote()
self._cached_actor.__ray_terminate__.remote()
self._cached_actor = None
existing_runner = None
cls = ray.remote(
num_cpus=trial.resources.cpu,
num_gpus=trial.resources.gpu,
memory=trial.resources.memory,
object_store_memory=trial.resources.object_store_memory,
resources=trial.resources.custom_resources)(
trial.get_trainable_cls())
trial.init_logger()
# We checkpoint metadata here to try mitigating logdir duplication
self.try_checkpoint_metadata(trial)
remote_logdir = trial.logdir
if existing_runner:
trial.runner = existing_runner
if not self.reset_trial(trial, trial.config, trial.experiment_tag):
raise AbortTrialExecution(
@@ -108,6 +91,21 @@ class RayTrialExecutor(TrialExecutor):
"implemented and return True.")
return existing_runner
if self._cached_actor:
logger.debug("Cannot reuse cached runner {} for new trial".format(
self._cached_actor))
self._cached_actor.stop.remote()
self._cached_actor.__ray_terminate__.remote()
self._cached_actor = None
cls = ray.remote(
num_cpus=trial.resources.cpu,
num_gpus=trial.resources.gpu,
memory=trial.resources.memory,
object_store_memory=trial.resources.object_store_memory,
resources=trial.resources.custom_resources)(
trial.get_trainable_cls())
def logger_creator(config):
# Set the working dir in the remote process, for user file writes
if not os.path.exists(remote_logdir):
@@ -116,6 +114,10 @@ class RayTrialExecutor(TrialExecutor):
os.chdir(remote_logdir)
return NoopLogger(config, remote_logdir)
# Clear the Trial's location (to be updated later on result)
# since we don't know where the remote runner is placed.
trial.set_location(Location())
logger.info("Trial %s: Setting up new remote runner.", trial)
# Logging for trials is handled centrally by TrialRunner, so
# configure the remote runner to use a noop-logger.
return cls.remote(config=trial.config, logger_creator=logger_creator)
@@ -136,22 +138,20 @@ class RayTrialExecutor(TrialExecutor):
"""Starts trial and restores last result if trial was paused.
Raises:
ValueError if restoring from checkpoint fails.
RuntimeError if restoring from checkpoint fails.
"""
prior_status = trial.status
self.set_status(trial, Trial.RUNNING)
trial.runner = self._setup_runner(
trial.runner = self._setup_remote_runner(
trial,
reuse_allowed=checkpoint is not None
or trial._checkpoint.value is not None)
reuse_allowed=checkpoint is not None or trial.has_checkpoint())
if not self.restore(trial, checkpoint):
if trial.status == Trial.ERROR:
raise RuntimeError(
"Restore from checkpoint failed for Trial {}.".format(
str(trial)))
"Trial {}: Restore from checkpoint failed.".format(trial))
previous_run = self._find_item(self._paused, trial)
if (prior_status == Trial.PAUSED and previous_run):
if prior_status == Trial.PAUSED and previous_run:
# If Trial was in flight when paused, self._paused stores result.
self._paused.pop(previous_run[0])
self._running[previous_run[0]] = trial
@@ -175,10 +175,8 @@ class RayTrialExecutor(TrialExecutor):
if stop_logger:
trial.close_logger()
if error:
self.set_status(trial, Trial.ERROR)
else:
self.set_status(trial, Trial.TERMINATED)
self.set_status(trial, Trial.ERROR if error else Trial.TERMINATED)
trial.set_location(Location())
try:
trial.write_error_log(error_msg)
@@ -188,12 +186,11 @@ class RayTrialExecutor(TrialExecutor):
logger.debug("Reusing actor for {}".format(trial.runner))
self._cached_actor = trial.runner
else:
logger.debug(
"Destroying actor for trial {}.".format(trial))
logger.debug("Trial %s: Destroying actor.", trial)
trial.runner.stop.remote()
trial.runner.__ray_terminate__.remote()
except Exception:
logger.exception("Error stopping runner for Trial %s", str(trial))
logger.exception("Trial %s: Error stopping runner.", trial)
self.set_status(trial, Trial.ERROR)
finally:
trial.runner = None
@@ -208,32 +205,35 @@ class RayTrialExecutor(TrialExecutor):
checkpoint (Checkpoint): A Python object or path storing the state
of trial.
"""
self._commit_resources(trial.resources)
try:
self._start_trial(trial, checkpoint)
except Exception as e:
logger.exception("Error starting runner for Trial %s", str(trial))
error_msg = traceback.format_exc()
except AbortTrialExecution:
logger.exception("Trial %s: Error starting runner, aborting!",
trial)
time.sleep(2)
error_msg = traceback.format_exc()
self._stop_trial(trial, error=True, error_msg=error_msg)
return # don't retry fatal Tune errors
except Exception:
logger.exception(
"Trial %s: Error starting runner. Attempting "
"restart without checkpoint.", trial)
time.sleep(2)
error_msg = traceback.format_exc()
self._stop_trial(trial, error=True, error_msg=error_msg)
if isinstance(e, AbortTrialExecution):
return # don't retry fatal Tune errors
try:
# This forces the trial to not start from checkpoint.
trial.clear_checkpoint()
logger.info(
"Trying to start runner for Trial %s without checkpoint.",
str(trial))
self._start_trial(trial)
except Exception:
logger.exception(
"Error starting runner for Trial %s, aborting!",
str(trial))
"Trial %s: Error starting runner on second "
"attempt, aborting!", trial)
error_msg = traceback.format_exc()
self._stop_trial(trial, error=True, error_msg=error_msg)
# note that we don't return the resources, since they may
# have been lost
# Note that we don't return the resources, since they may
# have been lost. TODO(ujvl): is this the right thing to do?
def _find_item(self, dictionary, item):
out = [rid for rid, t in dictionary.items() if t is item]
@@ -245,7 +245,7 @@ class RayTrialExecutor(TrialExecutor):
self._stop_trial(
trial, error=error, error_msg=error_msg, stop_logger=stop_logger)
if prior_status == Trial.RUNNING:
logger.debug("Returning resources for Trial %s.", str(trial))
logger.debug("Trial %s: Returning resources.", trial)
self._return_resources(trial.resources)
out = self._find_item(self._running, trial)
for result_id in out:
@@ -253,7 +253,6 @@ class RayTrialExecutor(TrialExecutor):
def continue_training(self, trial):
"""Continues the training of this trial."""
self._train(trial)
def pause_trial(self, trial):
@@ -262,7 +261,6 @@ class RayTrialExecutor(TrialExecutor):
If trial is in-flight, preserves return value in separate queue
before pausing, which is restored when Trial is resumed.
"""
trial_future = self._find_item(self._running, trial)
if trial_future:
self._paused[trial_future[0]] = trial
@@ -285,7 +283,13 @@ class RayTrialExecutor(TrialExecutor):
trial.config = new_config
trainable = trial.runner
with warn_if_slow("reset_config"):
reset_val = ray.get(trainable.reset_config.remote(new_config))
try:
reset_val = ray.get(
trainable.reset_config.remote(new_config),
DEFAULT_GET_TIMEOUT)
except RayTimeoutError:
logger.exception("Trial %s: reset_config timed out.")
return False
return reset_val
def get_running_trials(self):
@@ -351,7 +355,7 @@ class RayTrialExecutor(TrialExecutor):
raise ValueError("Trial was not running.")
self._running.pop(trial_future[0])
with warn_if_slow("fetch_result"):
result = ray.get(trial_future[0])
result = ray.get(trial_future[0], DEFAULT_GET_TIMEOUT)
# For local mode
if isinstance(result, _LocalWrapper):
@@ -530,48 +534,28 @@ class RayTrialExecutor(TrialExecutor):
def save(self, trial, storage=Checkpoint.DISK):
"""Saves the trial's state to a checkpoint."""
trial._checkpoint.storage = storage
trial._checkpoint.last_result = trial.last_result
if storage == Checkpoint.MEMORY:
trial._checkpoint.value = trial.runner.save_to_object.remote()
value = trial.runner.save_to_object.remote()
checkpoint = Checkpoint(storage, value, trial.last_result)
else:
# Keeps only highest performing checkpoints if enabled
if trial.keep_checkpoints_num:
try:
last_attr_val = trial.last_result[
trial.checkpoint_score_attr]
if (trial.compare_checkpoints(last_attr_val)
and not math.isnan(last_attr_val)):
trial.best_checkpoint_attr_value = last_attr_val
self._checkpoint_and_erase(trial)
except KeyError:
logger.warning(
"Result dict has no key: {}. keep"
"_checkpoints_num flag will not work".format(
trial.checkpoint_score_attr))
else:
with warn_if_slow("save_to_disk"):
trial._checkpoint.value = ray.get(
trial.runner.save.remote())
with warn_if_slow("save_checkpoint_to_disk"):
value = ray.get(trial.runner.save.remote())
checkpoint = Checkpoint(storage, value, trial.last_result)
return trial._checkpoint.value
def _checkpoint_and_erase(self, trial):
"""Checkpoints the model and erases old checkpoints
if needed.
Parameters
----------
trial : trial to save
"""
with warn_if_slow("save_to_disk"):
trial._checkpoint.value = ray.get(trial.runner.save.remote())
if len(trial.history) >= trial.keep_checkpoints_num:
ray.get(trial.runner.delete_checkpoint.remote(trial.history[-1]))
trial.history.pop()
trial.history.insert(0, trial._checkpoint.value)
with warn_if_slow("on_checkpoint", DEFAULT_GET_TIMEOUT) as profile:
try:
trial.on_checkpoint(checkpoint)
except Exception:
logger.exception("Trial %s: Error handling checkpoint %s",
trial, checkpoint.value)
return None
if profile.too_slow and trial.sync_on_checkpoint:
logger.warning(
"Consider turning off forced head-worker trial checkpoint "
"syncs by setting sync_on_checkpoint=False. Note that this "
"might result in faulty trial restoration for some worker "
"failure modes.")
return checkpoint.value
def restore(self, trial, checkpoint=None):
"""Restores training state from a given model checkpoint.
@@ -580,11 +564,13 @@ class RayTrialExecutor(TrialExecutor):
if restoring on a different node.
"""
if checkpoint is None or checkpoint.value is None:
checkpoint = trial._checkpoint
checkpoint = trial.checkpoint
if checkpoint is None or checkpoint.value is None:
return True
if trial.runner is None:
logger.error("Unable to restore - no runner.")
logger.error(
"Trial %s: Unable to restore - no runner. "
"Setting status to ERROR.", trial)
self.set_status(trial, Trial.ERROR)
return False
try:
@@ -593,23 +579,31 @@ class RayTrialExecutor(TrialExecutor):
assert type(value) != Checkpoint, type(value)
trial.runner.restore_from_object.remote(value)
else:
# TODO: Somehow, the call to get the current IP on the
# remote actor can be very slow - a better fix would
# be to use an actor table to detect the IP of the Trainable
# and rsync the files there.
# See https://github.com/ray-project/ray/issues/5168
logger.info("Trial %s: Attempting restoration from %s", trial,
checkpoint.value)
with warn_if_slow("get_current_ip"):
worker_ip = ray.get(trial.runner.current_ip.remote())
worker_ip = ray.get(trial.runner.current_ip.remote(),
DEFAULT_GET_TIMEOUT)
with warn_if_slow("sync_to_new_location"):
trial.sync_logger_to_new_location(worker_ip)
with warn_if_slow("restore_from_disk"):
ray.get(trial.runner.restore.remote(value))
trial.last_result = checkpoint.last_result
return True
except Exception:
logger.exception("Error restoring runner for Trial %s.", trial)
ray.get(
trial.runner.restore.remote(value),
DEFAULT_GET_TIMEOUT)
except RayTimeoutError:
logger.exception(
"Trial %s: Unable to restore - runner task timed "
"out. Setting status to ERROR", trial)
self.set_status(trial, Trial.ERROR)
return False
except Exception:
logger.exception(
"Trial %s: Unable to restore. Setting status to ERROR", trial)
self.set_status(trial, Trial.ERROR)
return False
trial.last_result = checkpoint.result
return True
def export_trial_if_needed(self, trial):
"""Exports model of this trial based on trial.export_formats.
@@ -619,7 +613,8 @@ class RayTrialExecutor(TrialExecutor):
"""
if trial.export_formats and len(trial.export_formats) > 0:
return ray.get(
trial.runner.export_model.remote(trial.export_formats))
trial.runner.export_model.remote(trial.export_formats),
DEFAULT_GET_TIMEOUT)
return {}
def has_gpus(self):
+217 -147
View File
@@ -28,26 +28,142 @@ SYNC_PERIOD = 300
_syncers = {}
def validate_sync_string(sync_string):
    """Checks that a sync template contains both replacement fields.

    Args:
        sync_string (str): Command template to validate.

    Raises:
        ValueError: If '{source}' or '{target}' is missing; '{source}' is
            checked first.
    """
    required_fields = (
        ("{source}", "Sync template missing '{source}'."),
        ("{target}", "Sync template missing '{target}'."),
    )
    for placeholder, message in required_fields:
        if placeholder not in sync_string:
            raise ValueError(message)
def wait_for_sync():
    """Calls wait() on every syncer held in the module-level registry."""
    for registered_syncer in _syncers.values():
        registered_syncer.wait()
class BaseSyncer(object):
def __init__(self, local_dir, remote_dir, sync_function=None):
class SyncClient(object):
    """Interface for clients that sync data between two paths.

    Subclasses must implement sync_up and sync_down; wait and reset are
    optional hooks that do nothing by default.
    """

    def sync_up(self, source, target):
        """Starts a sync up from source to target.

        Args:
            source (str): Source path.
            target (str): Target path.

        Returns:
            True if sync initiation successful, False otherwise.
        """
        raise NotImplementedError

    def sync_down(self, source, target):
        """Starts a sync down from source to target.

        Args:
            source (str): Source path.
            target (str): Target path.

        Returns:
            True if sync initiation successful, False otherwise.
        """
        raise NotImplementedError

    def wait(self):
        """Waits for the current sync to finish, if started asynchronously.

        The base implementation does nothing.
        """

    def reset(self):
        """Resets client state. The base implementation does nothing."""
class FunctionBasedClient(SyncClient):
    """Sync client that delegates to a pair of caller-supplied callables.

    Each callable is invoked as func(source, target). Its return value is
    ignored and any exception it raises propagates to the caller.
    """

    def __init__(self, sync_up_func, sync_down_func):
        self.sync_up_func = sync_up_func
        self.sync_down_func = sync_down_func

    def sync_up(self, source, target):
        """Runs the sync-up callable and reports success."""
        upload = self.sync_up_func
        upload(source, target)
        return True

    def sync_down(self, source, target):
        """Runs the sync-down callable and reports success."""
        download = self.sync_down_func
        download(source, target)
        return True
class CommandBasedClient(SyncClient):
    """Sync client that runs a shell command template in a subprocess."""

    def __init__(self, sync_up_template, sync_down_template):
        """Syncs between two directories with the given command.

        Arguments:
            sync_up_template (str): A runnable string template; needs to
                include replacement fields '{source}' and '{target}'.
            sync_down_template (str): A runnable string template; needs to
                include replacement fields '{source}' and '{target}'.

        Raises:
            ValueError: If a template is not a string or is missing one of
                the replacement fields.
        """
        if not isinstance(sync_up_template, str):
            raise ValueError("{} is not a string.".format(sync_up_template))
        if not isinstance(sync_down_template, str):
            raise ValueError("{} is not a string.".format(sync_down_template))
        self._validate_sync_string(sync_up_template)
        self._validate_sync_string(sync_down_template)
        self.sync_up_template = sync_up_template
        self.sync_down_template = sync_down_template
        # File receiving the stdout of sync commands; None until set_logdir.
        self.logfile = None
        # Handle of the most recently started sync command, if any.
        self.sync_process = None

    def set_logdir(self, logdir):
        """Sets the directory to log sync execution output in.

        Creates a new "log_sync*.log" file in logdir; the file is not
        removed when closed.

        Args:
            logdir: Log directory.
        """
        self.logfile = tempfile.NamedTemporaryFile(
            prefix="log_sync", dir=logdir, suffix=".log", delete=False)

    def sync_up(self, source, target):
        """Starts the sync-up command. See execute() for the return value."""
        return self.execute(self.sync_up_template, source, target)

    def sync_down(self, source, target):
        """Starts the sync-down command. See execute() for the return value."""
        return self.execute(self.sync_down_template, source, target)

    def execute(self, sync_template, source, target):
        """Executes sync_template on source and target.

        The command is started through the shell without waiting for it to
        finish; call wait() to collect its outcome. source and target are
        shell-quoted before substitution into the template.

        Returns:
            True if the command was started, False if the previously
            started command is still running (the new sync is skipped).
        """
        if self.sync_process:
            # poll() refreshes returncode without blocking.
            self.sync_process.poll()
            if self.sync_process.returncode is None:
                logger.warning("Last sync is still in progress, skipping.")
                return False
        final_cmd = sync_template.format(
            source=quote(source), target=quote(target))
        logger.debug("Running sync: {}".format(final_cmd))
        self.sync_process = subprocess.Popen(
            final_cmd, shell=True, stderr=subprocess.PIPE, stdout=self.logfile)
        return True

    def wait(self):
        """Waits for the started sync command, if any, to exit.

        Raises:
            TuneError: If the command exited with a non-zero code; the
                message carries the exit code and the captured stderr.
        """
        if self.sync_process:
            _, error_msg = self.sync_process.communicate()
            # The command's stderr is not guaranteed to be ASCII (e.g. file
            # names or localized messages). Decode leniently so undecodable
            # bytes cannot raise here and hide the actual sync outcome.
            error_msg = error_msg.decode("utf-8", errors="replace")
            code = self.sync_process.returncode
            self.sync_process = None
            if code != 0:
                raise TuneError("Sync error ({}): {}".format(code, error_msg))

    def reset(self):
        """Forgets the tracked sync command without stopping it."""
        if self.sync_process:
            logger.warning("Sync process still running but resetting anyways.")
            self.sync_process = None

    @staticmethod
    def _validate_sync_string(sync_string):
        """Raises ValueError if a replacement field is missing."""
        if "{source}" not in sync_string:
            raise ValueError("Sync template missing '{source}'.")
        if "{target}" not in sync_string:
            raise ValueError("Sync template missing '{target}'.")
# Shared do-nothing client: both sync callables ignore their arguments, so
# sync_up/sync_down perform no work and still return True.
NOOP = FunctionBasedClient(lambda s, t: None, lambda s, t: None)
class Syncer(object):
def __init__(self, local_dir, remote_dir, sync_client=NOOP):
"""Syncs between two directories with the sync_function.
Arguments:
local_dir (str): Directory to sync. Uniquely identifies the syncer.
remote_dir (str): Remote directory to sync with.
sync_function (func): Function for syncing the local_dir to
sync_client (SyncClient): Client for syncing between local_dir and
remote_dir. Defaults to a Noop.
"""
self._local_dir = (os.path.join(local_dir, "")
@@ -55,32 +171,7 @@ class BaseSyncer(object):
self._remote_dir = remote_dir
self.last_sync_up_time = float("-inf")
self.last_sync_down_time = float("-inf")
self._sync_function = sync_function or (lambda source, target: None)
def sync_function(self, source, target):
"""Executes sync between source and target.
Can be overwritten by subclasses for custom sync procedures.
Args:
source: Path to source file(s).
target: Path to target file(s).
"""
if self._sync_function:
return self._sync_function(source, target)
def sync(self, source, target):
if not (source and target):
logger.debug(
"Source or target is empty, skipping log sync for {}".format(
self._local_dir))
return
try:
self.sync_function(source, target)
return True
except Exception:
logger.exception("Sync function failed.")
self.sync_client = sync_client
def sync_up_if_needed(self):
if time.time() - self.last_sync_up_time > SYNC_PERIOD:
@@ -90,120 +181,86 @@ class BaseSyncer(object):
if time.time() - self.last_sync_down_time > SYNC_PERIOD:
self.sync_down()
def sync_down(self, *args, **kwargs):
self.sync(self._remote_path, self._local_dir, *args, **kwargs)
self.last_sync_down_time = time.time()
def sync_up(self):
"""Attempts to start the sync-up to the remote path.
def sync_up(self, *args, **kwargs):
self.sync(self._local_dir, self._remote_path, *args, **kwargs)
self.last_sync_up_time = time.time()
Returns:
Whether the sync (if feasible) was successfully started.
"""
result = False
if self.validate_hosts(self._local_dir, self._remote_path):
try:
result = self.sync_client.sync_up(self._local_dir,
self._remote_path)
self.last_sync_up_time = time.time()
except Exception:
logger.exception("Sync execution failed.")
return result
def sync_down(self):
    """Attempts to start the sync-down from the remote path.

    The sync is only attempted when ``validate_hosts`` accepts the local
    directory and the remote path. Any exception raised while starting
    the sync is logged and swallowed.

    Returns:
        Whether the sync (if feasible) was successfully started.
    """
    if not self.validate_hosts(self._local_dir, self._remote_path):
        return False

    started = False
    try:
        started = self.sync_client.sync_down(self._remote_path,
                                             self._local_dir)
        # Only reached when the client call did not raise.
        self.last_sync_down_time = time.time()
    except Exception:
        logger.exception("Sync execution failed.")
    return started
def validate_hosts(self, source, target):
    """Report whether both endpoints of a sync are specified.

    Args:
        source: Sync source; any falsy value counts as missing.
        target: Sync target; any falsy value counts as missing.

    Returns:
        True when both ``source`` and ``target`` are truthy. Otherwise a
        debug message naming the local directory is logged and False is
        returned.
    """
    if source and target:
        return True
    logger.debug("Source or target is empty, skipping log sync for "
                 "{}".format(self._local_dir))
    return False
def wait(self):
"""Waits for the sync client to complete the current sync."""
self.sync_client.wait()
def reset(self):
self.last_sync_up_time = float("-inf")
self.last_sync_down_time = float("-inf")
def wait(self):
pass
self.sync_client.reset()
@property
def _remote_path(self):
"""Protected method for accessing remote_dir.
Can be overridden in subclass for custom path.
"""
return self._remote_dir
class CommandSyncer(BaseSyncer):
def __init__(self, local_dir, remote_dir, sync_template):
"""Syncs between two directories with the given command.
Arguments:
local_dir (str): Directory to sync.
remote_dir (str): Remote directory to sync with.
sync_template (str): A string template
for syncer to run and needs to include replacement fields
'{source}' and '{target}'. Returned when using
`CommandSyncer.sync_template`, which can be overridden
by subclass.
"""
super(CommandSyncer, self).__init__(local_dir, remote_dir)
if not isinstance(sync_template, str):
raise ValueError("{} is not a string.".format(sync_template))
validate_sync_string(sync_template)
self._sync_template = sync_template
self.logfile = tempfile.NamedTemporaryFile(
prefix="log_sync",
dir=self._local_dir,
suffix=".log",
delete=False)
self.sync_process = None
def sync_function(self, source, target):
self.last_sync_time = time.time()
if self.sync_process:
self.sync_process.poll()
if self.sync_process.returncode is None:
logger.warning("Last sync is still in progress, skipping.")
return
final_cmd = self._sync_template.format(
source=quote(source), target=quote(target))
logger.debug("Running sync: {}".format(final_cmd))
self.sync_process = subprocess.Popen(
final_cmd, shell=True, stdout=self.logfile)
return True
def reset(self):
if self.sync_process:
logger.warning("Sync process still running but resetting anyways.")
self.sync_process = None
super(CommandSyncer, self).reset()
def wait(self):
if self.sync_process:
self.sync_process.wait()
def _get_sync_cls(sync_function):
if not sync_function:
return
if isinstance(sync_function, types.FunctionType):
return BaseSyncer
elif isinstance(sync_function, str):
return CommandSyncer
else:
raise ValueError("Sync function {} must be string or function".format(
sync_function))
def get_syncer(local_dir, remote_dir=None, sync_function=None):
"""Returns a Syncer depending on given args.
def get_cloud_syncer(local_dir, remote_dir=None, sync_function=None):
"""Returns a Syncer.
This syncer is in charge of syncing the local_dir with upload_dir.
Args:
local_dir: Source directory for syncing.
remote_dir: Target directory for syncing. If None,
returns BaseSyncer with a noop.
local_dir (str): Source directory for syncing.
remote_dir (str): Target directory for syncing. If not provided, a
no-op Syncer is returned.
sync_function (func | str): Function for syncing the local_dir to
remote_dir. If string, then it must be a string template for
syncer to run. If not provided, it defaults
to standard S3 or gsutil sync commands.
"""
"""
key = (local_dir, remote_dir)
if key in _syncers:
return _syncers[key]
if not remote_dir:
_syncers[key] = BaseSyncer(local_dir, remote_dir)
_syncers[key] = Syncer(local_dir, remote_dir, NOOP)
return _syncers[key]
sync_cls = _get_sync_cls(sync_function)
client = _get_sync_client(sync_function)
if sync_cls:
_syncers[key] = sync_cls(local_dir, remote_dir, sync_function)
if client:
_syncers[key] = Syncer(local_dir, remote_dir, client)
return _syncers[key]
if remote_dir.startswith(S3_PREFIX):
@@ -211,15 +268,17 @@ def get_syncer(local_dir, remote_dir=None, sync_function=None):
raise TuneError(
"Upload uri starting with '{}' requires awscli tool"
" to be installed".format(S3_PREFIX))
_syncers[key] = CommandSyncer(local_dir, remote_dir,
"aws s3 sync {source} {target}")
template = "aws s3 sync {source} {target}"
s3_client = CommandBasedClient(template, template)
_syncers[key] = Syncer(local_dir, remote_dir, s3_client)
elif remote_dir.startswith(GS_PREFIX):
if not distutils.spawn.find_executable("gsutil"):
raise TuneError(
"Upload uri starting with '{}' requires gsutil tool"
" to be installed".format(GS_PREFIX))
_syncers[key] = CommandSyncer(local_dir, remote_dir,
"gsutil rsync -r {source} {target}")
template = "gsutil rsync -r {source} {target}"
gs_client = CommandBasedClient(template, template)
_syncers[key] = Syncer(local_dir, remote_dir, gs_client)
else:
raise TuneError("Upload uri must start with one of: {}"
"".format(ALLOWED_REMOTE_PREFIXES))
@@ -228,37 +287,48 @@ def get_syncer(local_dir, remote_dir=None, sync_function=None):
def get_log_syncer(local_dir, remote_dir=None, sync_function=None):
"""Returns a Syncer depending on given args.
This syncer is in charge of syncing the local_dir with remote local_dir.
"""Returns a log Syncer.
Args:
local_dir: Source directory for syncing.
remote_dir: Target directory for syncing. If None,
returns BaseSyncer with noop.
sync_function (func | str): Function for syncing the local_dir to
local_dir (str): Source directory for syncing.
remote_dir (str): Target directory for syncing. If not provided, a
no-op Syncer is returned.
sync_function (func|str): Function for syncing the local_dir to
remote_dir. If string, then it must be a string template for
syncer to run. If not provided, it defaults rsync.
"""
"""
key = (local_dir, remote_dir)
if key in _syncers:
return _syncers[key]
sync_cls = None
if sync_function:
sync_cls = _get_sync_cls(sync_function)
elif not remote_dir:
sync_client = NOOP
elif sync_function:
sync_client = _get_sync_client(sync_function)
else:
sync_cls = CommandSyncer
sync_function = log_sync_template()
sync_up = log_sync_template()
sync_down = log_sync_template(options="--remove-source-files")
if sync_up and sync_down:
sync_client = CommandBasedClient(sync_up, sync_down)
sync_client.set_logdir(local_dir)
else:
sync_client = NOOP
if not remote_dir or sync_function is None:
sync_cls = BaseSyncer
class MixedSyncer(NodeSyncMixin, sync_cls):
class MixedSyncer(NodeSyncMixin, Syncer):
def __init__(self, *args, **kwargs):
sync_cls.__init__(self, *args, **kwargs)
Syncer.__init__(self, *args, **kwargs)
NodeSyncMixin.__init__(self)
_syncers[key] = MixedSyncer(local_dir, remote_dir, sync_function)
_syncers[key] = MixedSyncer(local_dir, remote_dir, sync_client)
return _syncers[key]
def _get_sync_client(sync_function):
if not sync_function:
return None
if isinstance(sync_function, types.FunctionType):
return FunctionBasedClient(sync_function, sync_function)
elif isinstance(sync_function, str):
return CommandBasedClient(sync_function, sync_function)
else:
raise ValueError("Sync function {} must be string or function".format(
sync_function))
@@ -0,0 +1,106 @@
# coding: utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import sys
import unittest
from ray.tune.checkpoint_manager import Checkpoint, CheckpointManager, logger
if sys.version_info >= (3, 3):
from unittest.mock import patch
else:
from mock import patch
class CheckpointManagerTest(unittest.TestCase):
    """Tests how CheckpointManager tracks the newest and best checkpoints."""

    @staticmethod
    def mock_result(i):
        """Returns a minimal result dict whose score attribute "i" is `i`."""
        return {"i": i}

    def testOnCheckpointOrdered(self):
        """
        Tests increasing priorities. Also tests that the worst checkpoints
        are deleted when necessary.
        """
        keep_checkpoints_num = 2
        checkpoint_manager = CheckpointManager(keep_checkpoints_num, "i")
        checkpoints = [
            Checkpoint(Checkpoint.DISK, {i}, self.mock_result(i))
            for i in range(3)
        ]

        # shutil.rmtree and os.path are patched, so deletions are only
        # counted and nothing on disk is touched.
        with patch("shutil.rmtree") as rmtree_mock, patch("os.path"):
            for j, checkpoint in enumerate(checkpoints):
                checkpoint_manager.on_checkpoint(checkpoint)
                # Only the third checkpoint exceeds the keep limit.
                expected_deletes = 0 if j != 2 else 1
                self.assertEqual(rmtree_mock.call_count, expected_deletes)
                self.assertEqual(checkpoint_manager.newest_checkpoint,
                                 checkpoint)

        best_checkpoints = checkpoint_manager.best_checkpoints()
        self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
        self.assertIn(checkpoints[1], best_checkpoints)
        self.assertIn(checkpoints[2], best_checkpoints)

    def testOnCheckpointUnordered(self):
        """
        Tests priorities that aren't inserted in ascending order. Also tests
        that the worst checkpoints are deleted when necessary.
        """
        keep_checkpoints_num = 2
        checkpoint_manager = CheckpointManager(keep_checkpoints_num, "i")
        # Scores are inserted in descending order: 3, 2, 1, 0.
        checkpoints = [
            Checkpoint(Checkpoint.DISK, {i}, self.mock_result(i))
            for i in range(3, -1, -1)
        ]

        with patch("shutil.rmtree") as rmtree_mock, patch("os.path"):
            for j, checkpoint in enumerate(checkpoints):
                checkpoint_manager.on_checkpoint(checkpoint)
                expected_deletes = 0 if j != 3 else 1
                self.assertEqual(rmtree_mock.call_count, expected_deletes)
                self.assertEqual(checkpoint_manager.newest_checkpoint,
                                 checkpoint)

        best_checkpoints = checkpoint_manager.best_checkpoints()
        self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
        self.assertIn(checkpoints[0], best_checkpoints)
        self.assertIn(checkpoints[1], best_checkpoints)

    def testBestCheckpoints(self):
        """
        Tests that the best checkpoints are tracked and ordered correctly.
        """
        keep_checkpoints_num = 4
        checkpoint_manager = CheckpointManager(keep_checkpoints_num, "i")
        checkpoints = [
            Checkpoint(Checkpoint.MEMORY, i, self.mock_result(i))
            for i in range(16)
        ]
        random.shuffle(checkpoints)

        for checkpoint in checkpoints:
            checkpoint_manager.on_checkpoint(checkpoint)

        best_checkpoints = checkpoint_manager.best_checkpoints()
        self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
        # The four highest scores (12..15) are expected in ascending order.
        for i, best_checkpoint in enumerate(best_checkpoints):
            self.assertEqual(best_checkpoint.value, i + 12)

    def testOnCheckpointUnavailableAttribute(self):
        """
        Tests that an error is logged when the associated result of the
        checkpoint has no checkpoint score attribute.
        """
        keep_checkpoints_num = 1
        checkpoint_manager = CheckpointManager(keep_checkpoints_num, "i")

        no_attr_checkpoint = Checkpoint(Checkpoint.MEMORY, 0, {})
        with patch.object(logger, "error") as log_error_mock:
            checkpoint_manager.on_checkpoint(no_attr_checkpoint)
            # Compare call_count directly: mock versions that predate
            # assert_called_once() silently accept it as an auto-created
            # attribute, so that check could never fail there.
            self.assertEqual(log_error_mock.call_count, 1)
            # The newest checkpoint should still be set despite this error.
            self.assertEqual(checkpoint_manager.newest_checkpoint,
                             no_attr_checkpoint)
+2 -2
View File
@@ -336,7 +336,7 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster):
cluster.add_node(num_cpus=1)
cluster.remove_node(node)
cluster.wait_for_nodes()
shutil.rmtree(os.path.dirname(t1._checkpoint.value))
shutil.rmtree(os.path.dirname(t1.checkpoint.value))
runner.step() # Recovery step
for i in range(3):
@@ -569,7 +569,7 @@ tune.run(
ray.shutdown()
cluster.shutdown()
cluster = _start_new_cluster()
Experiment._register_if_needed(_Mock)
Experiment.register_if_needed(_Mock)
# Inspect the internal trialrunner
runner = TrialRunner(
+17 -17
View File
@@ -26,6 +26,7 @@ from ray.tune.result import (
EPISODES_TOTAL, TRAINING_ITERATION, TIMESTEPS_THIS_ITER, TIME_THIS_ITER_S,
TIME_TOTAL_S, TRIAL_ID, EXPERIMENT_TAG)
from ray.tune.logger import Logger
from ray.tune.syncer import CommandBasedClient
from ray.tune.util import pin_in_object_store, get_pinned_object, flatten_dict
from ray.tune.experiment import Experiment
from ray.tune.trial import Trial, ExportFormat
@@ -1124,21 +1125,20 @@ class TestSyncFunctionality(unittest.TestCase):
"sync_to_driver": "ls {source}"
}).trials
with patch("ray.tune.syncer.CommandSyncer.sync_function"
) as mock_fn, patch(
"ray.services.get_node_ip_address") as mock_sync:
mock_sync.return_value = "0.0.0.0"
[trial] = tune.run(
"__fake",
name="foo",
max_failures=0,
**{
"stop": {
"training_iteration": 1
},
"sync_to_driver": "echo {source} {target}"
}).trials
self.assertGreater(mock_fn.call_count, 0)
with patch.object(CommandBasedClient, "execute") as mock_fn:
with patch("ray.services.get_node_ip_address") as mock_sync:
mock_sync.return_value = "0.0.0.0"
[trial] = tune.run(
"__fake",
name="foo",
max_failures=0,
**{
"stop": {
"training_iteration": 1
},
"sync_to_driver": "echo {source} {target}"
}).trials
self.assertGreater(mock_fn.call_count, 0)
def testCloudFunctions(self):
tmpdir = tempfile.mkdtemp()
@@ -1166,7 +1166,7 @@ class TestSyncFunctionality(unittest.TestCase):
def testClusterSyncFunction(self):
def sync_func_driver(source, target):
assert ":" in source, "Source not a remote path."
assert ":" in source, "Source {} not a remote path.".format(source)
assert ":" not in target, "Target is supposed to be local."
with open(os.path.join(target, "test.log2"), "w") as f:
print("writing to", f.name)
@@ -1203,7 +1203,7 @@ class TestSyncFunctionality(unittest.TestCase):
def sync_func(source, target):
pass
with patch("ray.tune.syncer.CommandSyncer.sync_function") as mock_sync:
with patch.object(CommandBasedClient, "execute") as mock_sync:
[trial] = tune.run(
"__fake",
name="foo",
+1
View File
@@ -103,6 +103,7 @@ class TrackSession(object):
self.trial_config["end_time"] = datetime.now().isoformat()
# TODO(rliaw): Have Tune support updated configs
self._logger.update_config(self.trial_config)
self._logger.flush()
self._logger.close()
@property
+12 -25
View File
@@ -70,7 +70,6 @@ class Trainable(object):
self._experiment_id = uuid.uuid4().hex
self.config = config or {}
log_sys_usage = self.config.get("log_sys_usage", False)
if logger_creator:
self._result_logger = logger_creator(self.config)
@@ -92,6 +91,7 @@ class Trainable(object):
self._timesteps_since_restore = 0
self._iterations_since_restore = 0
self._restored = False
start_time = time.time()
self._setup(copy.deepcopy(self.config))
setup_time = time.time() - start_time
@@ -101,6 +101,7 @@ class Trainable(object):
"reuse_actors=True to reduce actor creation "
"overheads.".format(setup_time))
self._local_ip = ray.services.get_node_ip_address()
log_sys_usage = self.config.get("log_sys_usage", False)
self._monitor = UtilMonitor(start=log_sys_usage)
@classmethod
@@ -112,11 +113,11 @@ class Trainable(object):
Example:
>>> def default_resource_request(cls, config):
return Resources(
cpu=0,
gpu=0,
extra_cpu=config["workers"],
extra_gpu=int(config["use_gpu"]) * config["workers"])
>>> return Resources(
>>> cpu=0,
>>> gpu=0,
>>> extra_cpu=config["workers"],
>>> extra_gpu=int(config["use_gpu"]) * config["workers"])
"""
return None
@@ -171,7 +172,6 @@ class Trainable(object):
Returns:
A dict that describes training progress.
"""
start = time.time()
result = self._train()
assert isinstance(result, dict), "_train() needs to return a dict."
@@ -239,17 +239,6 @@ class Trainable(object):
return result
def delete_checkpoint(self, checkpoint_dir):
"""Removes subdirectory within checkpoint_folder
Args:
checkpoint_dir : path to checkpoint
"""
if os.path.isfile(checkpoint_dir):
shutil.rmtree(os.path.dirname(checkpoint_dir))
else:
shutil.rmtree(checkpoint_dir)
def save(self, checkpoint_dir=None):
"""Saves the current model state to a checkpoint.
@@ -262,9 +251,9 @@ class Trainable(object):
Returns:
Checkpoint path or prefix that may be passed to restore().
"""
checkpoint_dir = os.path.join(checkpoint_dir or self.logdir,
"checkpoint_{}".format(self._iteration))
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)
checkpoint = self._save(checkpoint_dir)
@@ -293,7 +282,7 @@ class Trainable(object):
"time_total": self._time_total,
"episodes_total": self._episodes_total,
"saved_as_dict": saved_as_dict,
"ray_version": ray.__version__
"ray_version": ray.__version__,
}, f)
return checkpoint_path
@@ -305,10 +294,8 @@ class Trainable(object):
Returns:
Object holding checkpoint data.
"""
tmpdir = tempfile.mkdtemp("save_to_object", dir=self.logdir)
checkpoint_path = self.save(tmpdir)
# Save all files in subtree.
data = {}
for basedir, _, file_names in os.walk(tmpdir):
@@ -356,7 +343,7 @@ class Trainable(object):
self._timesteps_since_restore = 0
self._iterations_since_restore = 0
self._restored = True
logger.info("Restored from checkpoint: {}".format(checkpoint_path))
logger.info("Restored from checkpoint: %s", checkpoint_path)
state = {
"_iteration": self._iteration,
"_timesteps_total": self._timesteps_total,
@@ -398,7 +385,7 @@ class Trainable(object):
export_dir (str): Optional dir to place the exported model.
Defaults to self.logdir.
Return:
Returns:
A dict that maps ExportFormats to successfully exported models.
"""
export_dir = export_dir or self.logdir
@@ -422,7 +409,7 @@ class Trainable(object):
def stop(self):
"""Releases all resources used by this trainable."""
self._result_logger.flush()
self._result_logger.close()
self._stop()
+65 -74
View File
@@ -12,6 +12,7 @@ import tempfile
import os
from numbers import Number
from ray.tune import TuneError
from ray.tune.checkpoint_manager import Checkpoint, CheckpointManager
from ray.tune.logger import pretty_print, UnifiedLogger
from ray.tune.util import flatten_dict
# NOTE(rkn): We import ray.tune.registry here instead of importing the names we
@@ -31,29 +32,20 @@ def date_str():
return datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
class Checkpoint(object):
"""Describes a checkpoint of trial state.
class Location(object):
"""Describes the location at which Trial is placed to run."""
Checkpoint may be saved in different storage.
def __init__(self, hostname=None, pid=None):
self.hostname = hostname
self.pid = pid
Attributes:
storage (str): Storage type.
value (str): If storage==MEMORY,value is a Python object.
If storage==DISK,value is a path points to the checkpoint in disk.
"""
MEMORY = "memory"
DISK = "disk"
def __init__(self, storage, value, last_result=None):
self.storage = storage
self.value = value
self.last_result = last_result or {}
@staticmethod
def from_object(value=None):
"""Creates a checkpoint from a Python object."""
return Checkpoint(Checkpoint.MEMORY, value)
def __str__(self):
if not self.pid:
return ""
elif self.hostname == os.uname()[1]:
return "pid={}".format(self.pid)
else:
return "{}:{}".format(self.hostname, self.pid)
class ExportFormat(object):
@@ -108,8 +100,9 @@ class Trial(object):
stopping_criterion=None,
checkpoint_freq=0,
checkpoint_at_end=False,
sync_on_checkpoint=True,
keep_checkpoints_num=None,
checkpoint_score_attr="",
checkpoint_score_attr=TRAINING_ITERATION,
export_formats=None,
restore_path=None,
trial_name_creator=None,
@@ -121,7 +114,6 @@ class Trial(object):
The args here take the same meaning as the command line flags defined
in ray.tune.config_parser.
"""
validate_trainable(trainable_name)
# Trial config
self.trainable_name = trainable_name
@@ -145,6 +137,7 @@ class Trial(object):
"clear the `resources_per_trial` option.".format(
trainable_cls, default_resources))
resources = default_resources
self.address = Location()
self.resources = resources or Resources(cpu=1, gpu=0)
self.stopping_criterion = stopping_criterion or {}
self.loggers = loggers
@@ -161,17 +154,12 @@ class Trial(object):
# stores in memory max/min/last result for each metric by trial
self.metric_analysis = {}
self.history = []
self.keep_checkpoints_num = keep_checkpoints_num
self._cmp_greater = not checkpoint_score_attr.startswith("min-")
self.best_checkpoint_attr_value = -float("inf") \
if self._cmp_greater else float("inf")
# Strip off "min-" from checkpoint attribute
self.checkpoint_score_attr = checkpoint_score_attr \
if self._cmp_greater else checkpoint_score_attr[4:]
self.sync_on_checkpoint = sync_on_checkpoint
newest_checkpoint = Checkpoint(Checkpoint.DISK, restore_path)
self.checkpoint_manager = CheckpointManager(keep_checkpoints_num,
checkpoint_score_attr)
self.checkpoint_manager.newest_checkpoint = newest_checkpoint
self._checkpoint = Checkpoint(
storage=Checkpoint.DISK, value=restore_path)
self.export_formats = export_formats
self.status = Trial.PENDING
self.logdir = None
@@ -190,7 +178,7 @@ class Trial(object):
self.extra_arg = None
self._nonjson_fields = [
"_checkpoint",
"checkpoint",
"loggers",
"sync_to_driver_fn",
"results",
@@ -201,6 +189,14 @@ class Trial(object):
if trial_name_creator:
self.custom_trial_name = trial_name_creator(self)
@property
def node_ip(self):
return self.address.hostname
@property
def checkpoint(self):
return self.checkpoint_manager.newest_checkpoint
@classmethod
def generate_id(cls):
return str(uuid.uuid1().hex)[:8]
@@ -249,10 +245,14 @@ class Trial(object):
"""
if self.result_logger:
self.result_logger.sync_results_to_new_location(worker_ip)
self.set_location(Location(worker_ip))
def set_location(self, location):
"""Sets the location of the trial."""
self.address = location
def close_logger(self):
"""Close logger."""
"""Closes logger."""
if self.result_logger:
self.result_logger.close()
self.result_logger = None
@@ -262,8 +262,9 @@ class Trial(object):
self.num_failures += 1 # may be moved to outer scope?
error_file = os.path.join(self.logdir,
"error_{}.txt".format(date_str()))
with open(error_file, "w") as f:
f.write(error_msg)
with open(error_file, "a+") as f:
f.write("Failure # {}".format(self.num_failures) + "\n")
f.write(error_msg + "\n")
self.error_file = error_file
self.error_msg = error_msg
@@ -292,21 +293,36 @@ class Trial(object):
def should_checkpoint(self):
"""Whether this trial is due for checkpointing."""
result = self.last_result or {}
if result.get(DONE) and self.checkpoint_at_end:
return True
if self.checkpoint_freq:
return result.get(TRAINING_ITERATION,
0) % self.checkpoint_freq == 0
else:
return False
return (self.checkpoint_freq and
result.get(TRAINING_ITERATION, 0) % self.checkpoint_freq == 0)
def has_checkpoint(self):
return self._checkpoint.value is not None
return self.checkpoint.value is not None
def clear_checkpoint(self):
self._checkpoint.value = None
self.checkpoint.value = None
def on_checkpoint(self, checkpoint):
    """Hook for handling checkpoints taken by the Trainable.

    For disk checkpoints with ``sync_on_checkpoint`` enabled, a sync-down
    is forced and awaited before the checkpoint is handed to the
    checkpoint manager.

    Args:
        checkpoint (Checkpoint): Checkpoint taken.
    """
    needs_sync = (self.sync_on_checkpoint
                  and checkpoint.storage == Checkpoint.DISK)
    if needs_sync:
        syncer = self.result_logger
        # Let any in-flight sync finish first; a fresh sync is still
        # required afterwards to cover checkpoints taken mid-sync.
        syncer.wait()
        # Track the checkpoint only after the forced sync-down completed,
        # so restores never see a partially synced checkpoint.
        sync_started = syncer.sync_down()
        if not sync_started:
            logger.error(
                "Trial %s: Checkpoint sync skipped. "
                "This should not happen.", self)
        else:
            syncer.wait()
    self.checkpoint_manager.on_checkpoint(checkpoint)
def should_recover(self):
"""Returns whether the trial qualifies for retrying.
@@ -327,6 +343,7 @@ class Trial(object):
print("Result for {}:".format(self))
print(" {}".format(pretty_print(result).replace("\n", "\n ")))
self.last_debug = time.time()
self.set_location(Location(result.get("node_ip"), result.get("pid")))
self.last_result = result
self.last_update_time = time.time()
self.result_logger.on_result(self.last_result)
@@ -345,28 +362,6 @@ class Trial(object):
value, self.metric_analysis[metric]["min"])
self.metric_analysis[metric]["last"] = value
def compare_checkpoints(self, attr_mean):
"""Compares two checkpoints based on the attribute attr_mean param.
Greater than is used by default. If command-line parameter
checkpoint_score_attr starts with "min-" less than is used.
Arguments:
attr_mean: mean of attribute value for the current checkpoint
Returns:
True: when attr_mean is greater than previous checkpoint attr_mean
and greater than function is selected
when attr_mean is less than previous checkpoint attr_mean and
less than function is selected
False: when attr_mean is not in alignment with selected cmp fn
"""
if self._cmp_greater and attr_mean > self.best_checkpoint_attr_value:
return True
elif (not self._cmp_greater
and attr_mean < self.best_checkpoint_attr_value):
return True
return False
def get_trainable_cls(self):
return get_trainable_cls(self.trainable_name)
@@ -376,10 +371,6 @@ class Trial(object):
def is_finished(self):
return self.status in [Trial.TERMINATED, Trial.ERROR]
@property
def node_ip(self):
return self.last_result.get("node_ip")
def __repr__(self):
return str(self)
@@ -407,7 +398,7 @@ class Trial(object):
Sets RUNNING trials to PENDING, and flushes the result logger.
Note this can only occur if the trial holds a DISK checkpoint.
"""
assert self._checkpoint.storage == Checkpoint.DISK, (
assert self.checkpoint.storage == Checkpoint.DISK, (
"Checkpoint must not be in-memory.")
state = self.__dict__.copy()
state["resources"] = resources_to_json(self.resources)
+31 -6
View File
@@ -6,6 +6,7 @@ from __future__ import print_function
import logging
from ray.tune.trial import Trial, Checkpoint
from ray.tune.error import TuneError
logger = logging.getLogger(__name__)
@@ -38,6 +39,8 @@ class TrialExecutor(object):
trial (Trial): Trial to checkpoint.
status (Trial.status): Status to set trial to.
"""
logger.debug("Trial %s: Changing status from %s to %s.", trial,
trial.status, status)
trial.status = status
if status in [Trial.TERMINATED, Trial.ERROR]:
self.try_checkpoint_metadata(trial)
@@ -48,14 +51,16 @@ class TrialExecutor(object):
Args:
trial (Trial): Trial to checkpoint.
"""
if trial._checkpoint.storage == Checkpoint.MEMORY:
logger.debug("Not saving data for trial w/ memory checkpoint.")
if trial.checkpoint.storage == Checkpoint.MEMORY:
logger.debug("Trial %s: Not saving data for memory checkpoint.",
trial)
return
try:
logger.debug("Saving trial metadata.")
logger.debug("Trial %s: Saving trial metadata.", trial)
self._cached_trial_state[trial.trial_id] = trial.__getstate__()
except Exception:
logger.exception("Error checkpointing trial metadata.")
logger.exception("Trial %s: Error checkpointing trial metadata.",
trial)
def get_checkpoints(self):
"""Returns a copy of mapping of the trial ID to pickled metadata."""
@@ -118,7 +123,6 @@ class TrialExecutor(object):
def resume_trial(self, trial):
"""Resumes PAUSED trials. This is a blocking call."""
assert trial.status == Trial.PAUSED, trial.status
self.start_trial(trial)
@@ -150,6 +154,27 @@ class TrialExecutor(object):
"""A hook called after running one step of the trial event loop."""
pass
def on_no_available_trials(self, trial_runner):
    """Raises if the runner's trials can no longer make progress.

    Does nothing when trial queueing is enabled.

    Args:
        trial_runner: Object whose ``get_trials()`` yields the trials to
            inspect.

    Raises:
        TuneError: If a PENDING trial requests resources that are not
            available, or if a PAUSED trial is encountered.
    """
    if self._queue_trials:
        return

    insufficient_resources = (
        "Insufficient cluster resources to launch trial: "
        "trial requested {} but the cluster has only {}. "
        "Pass `queue_trials=True` in "
        "ray.tune.run() or on the command "
        "line to queue trials until the cluster scales "
        "up or resources become available. {}")

    for trial in trial_runner.get_trials():
        if trial.status == Trial.PENDING:
            if self.has_resources(trial.resources):
                continue
            raise TuneError(
                insufficient_resources.format(
                    trial.resources.summary_string(),
                    self.resource_string(),
                    trial.get_trainable_cls().resource_help(
                        trial.config)))
        if trial.status == Trial.PAUSED:
            raise TuneError("There are paused trials, but no more pending "
                            "trials with sufficient resources.")
def get_next_available_trial(self):
"""Blocking call that waits until one result is ready.
@@ -188,7 +213,7 @@ class TrialExecutor(object):
def restore(self, trial, checkpoint=None):
"""Restores training state from a checkpoint.
If checkpoint is None, try to restore from trial._checkpoint.
If checkpoint is None, try to restore from trial.checkpoint.
If restoring fails, the trial status will be set to ERROR.
Args:
+19 -32
View File
@@ -17,8 +17,8 @@ from ray.tune import TuneError
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE,
SHOULD_CHECKPOINT)
from ray.tune.syncer import get_syncer
from ray.tune.trial import Trial, Checkpoint
from ray.tune.syncer import get_cloud_syncer
from ray.tune.trial import Checkpoint, Trial
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
from ray.tune.suggest import BasicVariantGenerator
from ray.tune.util import warn_if_slow, flatten_dict
@@ -126,7 +126,7 @@ class TrialRunner(object):
global checkpoints are stored and restored from. Used
if `resume` == REMOTE.
resume (str|False): see `tune.py:run`.
sync_to_cloud (func|str): see `tune.py:run`.
sync_to_cloud (func|str): See `tune.py:run`.
server_port (int): Port number for launching TuneServer.
verbose (bool): Flag for verbosity. If False, trial results
will not be output.
@@ -158,8 +158,8 @@ class TrialRunner(object):
os.makedirs(self._local_checkpoint_dir)
self._remote_checkpoint_dir = remote_checkpoint_dir
self._syncer = get_syncer(local_checkpoint_dir, remote_checkpoint_dir,
sync_to_cloud)
self._syncer = get_cloud_syncer(local_checkpoint_dir,
remote_checkpoint_dir, sync_to_cloud)
self._resumed = False
@@ -216,8 +216,7 @@ class TrialRunner(object):
"Called resume from remote without remote directory.")
# Try syncing down the upload directory.
logger.info("Downloading from {}".format(
self._remote_checkpoint_dir))
logger.info("Downloading from %s", self._remote_checkpoint_dir)
self._syncer.sync_down_if_needed()
if not self.checkpoint_exists(self._local_checkpoint_dir):
@@ -334,24 +333,7 @@ class TrialRunner(object):
elif self.trial_executor.get_running_trials():
self._process_events() # blocking
else:
for trial in self._trials:
if trial.status == Trial.PENDING:
if not self.has_resources(trial.resources):
raise TuneError(
("Insufficient cluster resources to launch trial: "
"trial requested {} but the cluster has only {}. "
"Pass `queue_trials=True` in "
"ray.tune.run() or on the command "
"line to queue trials until the cluster scales "
"up. {}").format(
trial.resources.summary_string(),
self.trial_executor.resource_string(),
trial.get_trainable_cls().resource_help(
trial.config)))
elif trial.status == Trial.PAUSED:
raise TuneError(
"There are paused trials, but no more pending "
"trials with sufficient resources.")
self.trial_executor.on_no_available_trials(self)
try:
with warn_if_slow("experiment_checkpoint"):
@@ -422,12 +404,12 @@ class TrialRunner(object):
def _process_events(self):
failed_trial = self.trial_executor.get_next_failed_trial()
if failed_trial:
error_msg = (
"{} (IP: {}) detected as stale. This is likely because the "
"node was lost").format(failed_trial, failed_trial.node_ip)
logger.info(error_msg)
with warn_if_slow("process_failed_trial"):
self._process_trial_failure(
failed_trial,
error_msg="{} (ip: {}) detected as stale. This is likely"
"because the node was lost".format(failed_trial,
failed_trial.node_ip))
self._process_trial_failure(failed_trial, error_msg=error_msg)
else:
trial = self.trial_executor.get_next_available_trial() # blocking
with warn_if_slow("process_trial"):
@@ -537,12 +519,17 @@ class TrialRunner(object):
stop_logger=False)
trial.result_logger.flush()
if self.trial_executor.has_resources(trial.resources):
logger.info("Attempting to recover trial.")
logger.info(
"Trial %s: Attempting to recover "
"trial state from last checkpoint.", trial)
self.trial_executor.start_trial(trial)
if trial.status == Trial.ERROR:
logger.error("Trial %s: Did not start correctly.", trial)
raise RuntimeError("Trial did not start correctly.")
logger.debug("Trial %s: Started correctly.", trial)
else:
logger.debug("Notifying Scheduler and requeueing trial.")
logger.debug("Trial %s: Notifying Scheduler and requeueing.",
trial)
self._requeue_trial(trial)
except Exception:
logger.exception("Error recovering trial from checkpoint, abort.")
+14 -7
View File
@@ -69,6 +69,7 @@ def run(run_or_experiment,
sync_to_driver=None,
checkpoint_freq=0,
checkpoint_at_end=False,
sync_on_checkpoint=True,
keep_checkpoints_num=None,
checkpoint_score_attr=None,
global_checkpoint_period=10,
@@ -118,18 +119,18 @@ def run(run_or_experiment,
`num_samples` of times.
local_dir (str): Local dir to save training results to.
Defaults to ``~/ray_results``.
upload_dir (str): Optional URI to sync training results
to (e.g. ``s3://bucket``).
upload_dir (str): Optional URI to sync training results to
(e.g. ``s3://bucket``).
trial_name_creator (func): Optional function for generating
the trial string representation.
loggers (list): List of logger creators to be used with
each Trial. If None, defaults to ray.tune.logger.DEFAULT_LOGGERS.
See `ray/tune/logger.py`.
sync_to_cloud (func|str): Function for syncing the local_dir to and
from upload_dir. If string, then it must be a string template
that includes `{source}` and `{target}` for the syncer to run.
If not provided, the sync command defaults to standard
S3 or gsutil sync commands.
from upload_dir. If string, then it must be a string template that
includes `{source}` and `{target}` for the syncer to run. If not
provided, the sync command defaults to standard S3 or gsutil sync
commands.
sync_to_driver (func|str): Function for syncing trial logdir from
remote node to local. If string, then it must be a string template
that includes `{source}` and `{target}` for the syncer to run.
@@ -138,6 +139,11 @@ def run(run_or_experiment,
checkpoints. A value of 0 (default) disables checkpointing.
checkpoint_at_end (bool): Whether to checkpoint at the end of the
experiment regardless of the checkpoint_freq. Default is False.
sync_on_checkpoint (bool): Force sync-down of trial checkpoint, to
guarantee recoverability. If set to False, checkpoint syncing from
worker to driver is asynchronous. Set this to False only if
synchronous checkpointing is too slow and trial restoration
failures can be tolerated. Defaults to True.
keep_checkpoints_num (int): Number of checkpoints to keep. A value of
`None` keeps all checkpoints. Defaults to `None`. If set, need
to provide `checkpoint_score_attr`.
@@ -223,7 +229,7 @@ def run(run_or_experiment,
"not work with certain features.")
for i, exp in enumerate(experiments):
if not isinstance(exp, Experiment):
run_identifier = Experiment._register_if_needed(exp)
run_identifier = Experiment.register_if_needed(exp)
experiments[i] = Experiment(
name=name,
run=run_identifier,
@@ -238,6 +244,7 @@ def run(run_or_experiment,
loggers=loggers,
checkpoint_freq=checkpoint_freq,
checkpoint_at_end=checkpoint_at_end,
sync_on_checkpoint=sync_on_checkpoint,
keep_checkpoints_num=keep_checkpoints_num,
checkpoint_score_attr=checkpoint_score_attr,
export_formats=export_formats,
+12 -5
View File
@@ -119,18 +119,25 @@ class warn_if_slow(object):
... ray.get(something)
"""
def __init__(self, name):
DEFAULT_THRESHOLD = 0.5
def __init__(self, name, threshold=None):
self.name = name
self.threshold = threshold or self.DEFAULT_THRESHOLD
self.too_slow = False
def __enter__(self):
self.start = time.time()
return self
def __exit__(self, type, value, traceback):
now = time.time()
if now - self.start > 0.5 and now - START_OF_TIME > 60.0:
logger.warning("The `{}` operation took {} seconds to complete, ".
format(self.name, now - self.start) +
"which may be a performance bottleneck.")
if now - self.start > self.threshold and now - START_OF_TIME > 60.0:
self.too_slow = True
logger.warning(
"The `%s` operation took %s seconds to complete, "
"which may be a performance bottleneck.", self.name,
now - self.start)
def merge_dicts(d1, d2):
+2 -6
View File
@@ -311,10 +311,6 @@ class Worker(object):
"which is not an ray.ObjectID.".format(object_id))
if self.mode == LOCAL_MODE:
# TODO(ujvl): Remove check when local mode moved to core worker.
if timeout is not None:
raise ValueError(
"`get` must be called with timeout=None in local mode.")
return self.local_mode_manager.get_objects(object_ids)
timeout_ms = int(timeout * 1000) if timeout else -1
@@ -1407,8 +1403,8 @@ def get(object_ids, timeout=None):
Args:
object_ids: Object ID of the object to get or a list of object IDs to
get.
timeout (float): The maximum amount of time in seconds to wait before
returning.
timeout (Optional[float]): The maximum amount of time in seconds to
wait before returning.
Returns:
A Python object or a list of Python objects.