mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 21:29:15 +08:00
[tune] Fault tolerance improvements (#5877)
* Precede ray.get with ray.wait. * Trigger checkpoint deletes locally in Trainable * Clean-up code. * Minor changes. * Track best checkpoint so far again * Pulled checkpoint GC out of Trainable. * Added comments, error logging. * Immediate pull after checkpoint taken; rsync source delete on pull * Minor doc fixes * Fix checkpoint manager bug * Fix bugs, tests, formatting * Fix bugs, feature flag for force sync. * Fix test. * Fix minor bugs: clear proc and less verbose sync_on_checkpoint warnings. * Fix bug: update IP of last_result. * Fixed message. * Added a lot of logging. * Changes to ray trial executor. * More bug fixes (logging after failure), better logging. * Fix richards bug and logging * Add comments. * try-except * Fix heapq bug. * . * Move handling of no available trials to ray_trial_executor (#1) * Fix formatting bug, lint. * Addressed Richard's comments * Revert tests. * fix rebase * Fix trial location reporting. * Fix test * Fix lint * Rebase, use ray.get w/ timeout, lint. * lint * fix rebase * Address richard's comments
This commit is contained in:
committed by
Richard Liaw
parent
66edebce3a
commit
2965dc1b72
@@ -13,8 +13,22 @@ from ray.tune.sample import (function, sample_from, uniform, choice, randint,
|
||||
randn, loguniform)
|
||||
|
||||
__all__ = [
|
||||
"Trainable", "TuneError", "grid_search", "register_env",
|
||||
"register_trainable", "run", "run_experiments", "Experiment", "function",
|
||||
"sample_from", "track", "uniform", "choice", "randint", "randn",
|
||||
"loguniform", "progress_reporter", "ExperimentAnalysis", "Analysis"
|
||||
"Trainable",
|
||||
"TuneError",
|
||||
"grid_search",
|
||||
"register_env",
|
||||
"register_trainable",
|
||||
"run",
|
||||
"run_experiments",
|
||||
"Experiment",
|
||||
"function",
|
||||
"sample_from",
|
||||
"track",
|
||||
"uniform",
|
||||
"choice",
|
||||
"randint",
|
||||
"randn",
|
||||
"loguniform",
|
||||
"ExperimentAnalysis",
|
||||
"Analysis",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,142 @@
|
||||
# coding: utf-8
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import heapq
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
|
||||
try:
|
||||
FileNotFoundError
|
||||
except NameError:
|
||||
FileNotFoundError = IOError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Checkpoint(object):
    """A snapshot of trial state, kept either in memory or on disk.

    Attributes:
        storage (str): Storage type, one of ``Checkpoint.MEMORY`` or
            ``Checkpoint.DISK``.
        value: For ``MEMORY`` storage, the Python object holding the state.
            For ``DISK`` storage, the path of the checkpoint on disk.
        result (dict): The result passed at construction time, or an empty
            dict when none (or a falsy one) was supplied.
    """

    MEMORY = "memory"
    DISK = "disk"

    def __init__(self, storage, value, result=None):
        self.storage = storage
        self.value = value
        self.result = result if result else {}

    def delete(self):
        """Removes the checkpoint data from disk; no-op for other storage.

        When ``value`` names a file, the directory containing that file is
        removed; otherwise ``value`` itself is removed as a directory tree.

        Raises:
            FileNotFoundError: If ``value`` does not exist on disk.
        """
        # Only disk checkpoints with a non-empty path own data to remove.
        if self.storage != Checkpoint.DISK or not self.value:
            return

        path = self.value
        if not os.path.exists(path):
            raise FileNotFoundError(
                "Attempted to delete checkpoint at {} but "
                "path was not found.".format(path))

        # A file path stands for its enclosing checkpoint directory.
        target = os.path.dirname(path) if os.path.isfile(path) else path
        shutil.rmtree(target)

    @staticmethod
    def from_object(value=None):
        """Builds an in-memory checkpoint wrapping ``value``."""
        return Checkpoint(Checkpoint.MEMORY, value)
|
||||
|
||||
|
||||
class QueueItem(object):
    """Pairs a value with the priority used to order it in a heap."""

    def __init__(self, priority, value):
        self.priority = priority
        self.value = value

    def __cmp__(self, other):
        """Three-way comparison on priority (Python 2.7 compatibility)."""
        mine, theirs = self.priority, other.priority
        if mine == theirs:
            return 0
        if mine < theirs:
            return -1
        return 1

    def __lt__(self, other):
        """Orders items by priority only; ``value`` is never compared."""
        return self.priority < other.priority
|
||||
|
||||
|
||||
class CheckpointManager(object):
    """Manages checkpoints on the driver for a trial.

    Remembers the newest checkpoint and keeps a min-heap of up to
    ``keep_checkpoints_num`` best-scoring checkpoints. A checkpoint that is
    neither the newest one nor tracked in the heap is removed through
    ``Checkpoint.delete``.
    """

    def __init__(self, keep_checkpoints_num, checkpoint_score_attr):
        """Initializes a new CheckpointManager.

        Args:
            keep_checkpoints_num (int): Maximum number of best checkpoints to
                track. A falsy value (``None`` or ``0``) means no limit.
            checkpoint_score_attr (str): Key in a checkpoint's result dict
                used to score it. A ``"min-"`` prefix means lower values are
                better; the prefix is stripped from the key.

        Raises:
            AssertionError: If ``keep_checkpoints_num`` is negative (only
                while assertions are enabled).
        """
        self.keep_checkpoints_num = keep_checkpoints_num or float("inf")
        assert self.keep_checkpoints_num > 0, (
            "keep_checkpoints_num must be greater than 0.")
        self._checkpoint_score_desc = checkpoint_score_attr.startswith("min-")
        if self._checkpoint_score_desc:
            self._checkpoint_score_attr = checkpoint_score_attr[4:]
        else:
            self._checkpoint_score_attr = checkpoint_score_attr

        self.newest_checkpoint = Checkpoint(Checkpoint.MEMORY, None)
        # Min-heap of QueueItems; index 0 is the worst tracked checkpoint.
        self._best_checkpoints = []
        # Checkpoints currently held in the heap, for O(1) membership tests.
        self._membership = set()

    def on_checkpoint(self, checkpoint):
        """Starts tracking checkpoint metadata on checkpoint.

        Sets the newest checkpoint. Deletes the previous newest checkpoint
        unless it is one of the best ones. When the heap is at capacity and
        the new checkpoint scores strictly better than the worst tracked one,
        the worst is evicted and deleted. A new checkpoint that merely ties
        with the worst is not tracked, and is left on disk because it is
        still the newest checkpoint.

        If the score attribute is missing from ``checkpoint.result``, an
        error is logged and the checkpoint is not ranked.

        Args:
            checkpoint (Checkpoint): Trial state checkpoint.

        Raises:
            FileNotFoundError: Propagated from ``Checkpoint.delete`` when a
                disk checkpoint to be removed no longer exists.
        """
        old_checkpoint = self.newest_checkpoint
        self.newest_checkpoint = checkpoint

        try:
            queue_item = QueueItem(self._priority(checkpoint), checkpoint)
        except KeyError:
            if old_checkpoint not in self._membership:
                old_checkpoint.delete()
            logger.error("Result dict has no key: {}. "
                         "checkpoint_score_attr must be set to a key in the "
                         "result dict.".format(self._checkpoint_score_attr))
            return

        # Set when a tracked checkpoint is evicted (and deleted) below.
        evicted = None
        if len(self._best_checkpoints) < self.keep_checkpoints_num:
            heapq.heappush(self._best_checkpoints, queue_item)
            self._membership.add(checkpoint)
        elif queue_item.priority >= self._best_checkpoints[0].priority:
            worst = heapq.heappushpop(self._best_checkpoints, queue_item).value
            # On a tie heappushpop hands back the pushed item itself. The new
            # checkpoint is then not tracked and must not be deleted: it is
            # the newest checkpoint. It gets cleaned up on a later call once
            # it is superseded.
            if worst is not checkpoint:
                self._membership.add(checkpoint)
                self._membership.discard(worst)
                worst.delete()
                evicted = worst

        # Remove the old checkpoint if it isn't one of the best ones. Skip it
        # when it was just evicted above, since its data is already gone and
        # a second delete would raise FileNotFoundError.
        if (old_checkpoint is not evicted
                and old_checkpoint not in self._membership):
            old_checkpoint.delete()

    def best_checkpoints(self):
        """Returns the tracked checkpoints in ascending priority order.

        The lowest-priority (worst) checkpoint comes first.
        """
        checkpoints = sorted(self._best_checkpoints, key=lambda c: c.priority)
        return [queue_item.value for queue_item in checkpoints]

    def _priority(self, checkpoint):
        """Returns the heap priority of ``checkpoint``; higher is better.

        Raises:
            KeyError: If the score attribute is absent from the result dict.
        """
        priority = checkpoint.result[self._checkpoint_score_attr]
        # For "min-" attributes negate so that smaller scores rank higher.
        return -priority if self._checkpoint_score_desc else priority
|
||||
@@ -76,11 +76,19 @@ def make_parser(parser_creator=None, **kwargs):
|
||||
action="store_true",
|
||||
help="Whether to checkpoint at the end of the experiment. "
|
||||
"Default is False.")
|
||||
parser.add_argument(
|
||||
"--no-sync-on-checkpoint",
|
||||
action="store_true",
|
||||
help="Disable sync-down of trial checkpoint, which is enabled by "
|
||||
"default to guarantee recoverability. If set, checkpoint syncing from "
|
||||
"worker to driver is asynchronous. Set this only if synchronous "
|
||||
"checkpointing is too slow and trial restoration failures can be "
|
||||
"tolerated")
|
||||
parser.add_argument(
|
||||
"--keep-checkpoints-num",
|
||||
default=None,
|
||||
type=int,
|
||||
help="Number of last checkpoints to keep. Others get "
|
||||
help="Number of best checkpoints to keep. Others get "
|
||||
"deleted. Default (None) keeps all checkpoints.")
|
||||
parser.add_argument(
|
||||
"--checkpoint-score-attr",
|
||||
@@ -177,6 +185,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
|
||||
stopping_criterion=spec.get("stop", {}),
|
||||
checkpoint_freq=args.checkpoint_freq,
|
||||
checkpoint_at_end=args.checkpoint_at_end,
|
||||
sync_on_checkpoint=not args.no_sync_on_checkpoint,
|
||||
keep_checkpoints_num=args.keep_checkpoints_num,
|
||||
checkpoint_score_attr=args.checkpoint_score_attr,
|
||||
export_formats=spec.get("export_formats", []),
|
||||
|
||||
@@ -72,6 +72,7 @@ class Experiment(object):
|
||||
sync_to_driver=None,
|
||||
checkpoint_freq=0,
|
||||
checkpoint_at_end=False,
|
||||
sync_on_checkpoint=True,
|
||||
keep_checkpoints_num=None,
|
||||
checkpoint_score_attr=None,
|
||||
export_formats=None,
|
||||
@@ -80,6 +81,11 @@ class Experiment(object):
|
||||
repeat=None,
|
||||
trial_resources=None,
|
||||
sync_function=None):
|
||||
"""Initialize a new Experiment.
|
||||
|
||||
The args here take the same meaning as the command line flags defined
|
||||
in `tune.py:run`.
|
||||
"""
|
||||
if repeat:
|
||||
_raise_deprecation_note("repeat", "num_samples", soft=False)
|
||||
if trial_resources:
|
||||
@@ -102,7 +108,7 @@ class Experiment(object):
|
||||
"criteria must take exactly 2 parameters.".format(stop))
|
||||
|
||||
config = config or {}
|
||||
self._run_identifier = Experiment._register_if_needed(run)
|
||||
self._run_identifier = Experiment.register_if_needed(run)
|
||||
spec = {
|
||||
"run": self._run_identifier,
|
||||
"stop": stop,
|
||||
@@ -117,6 +123,7 @@ class Experiment(object):
|
||||
"sync_to_driver": sync_to_driver,
|
||||
"checkpoint_freq": checkpoint_freq,
|
||||
"checkpoint_at_end": checkpoint_at_end,
|
||||
"sync_on_checkpoint": sync_on_checkpoint,
|
||||
"keep_checkpoints_num": keep_checkpoints_num,
|
||||
"checkpoint_score_attr": checkpoint_score_attr,
|
||||
"export_formats": export_formats or [],
|
||||
@@ -156,7 +163,7 @@ class Experiment(object):
|
||||
return exp
|
||||
|
||||
@classmethod
|
||||
def _register_if_needed(cls, run_object):
|
||||
def register_if_needed(cls, run_object):
|
||||
"""Registers Trainable or Function at runtime.
|
||||
|
||||
Assumes already registered if run_object is a string.
|
||||
|
||||
+38
-11
@@ -17,12 +17,17 @@ logger = logging.getLogger(__name__)
|
||||
_log_sync_warned = False
|
||||
|
||||
|
||||
def log_sync_template():
|
||||
def log_sync_template(options=""):
|
||||
"""Syncs the local_dir between driver and worker if possible.
|
||||
|
||||
Requires ray cluster to be started with the autoscaler. Also requires
|
||||
rsync to be installed.
|
||||
|
||||
Args:
|
||||
options (str): Additional rsync options.
|
||||
|
||||
Returns:
|
||||
Sync template with source and target parameters.
|
||||
"""
|
||||
if not distutils.spawn.find_executable("rsync"):
|
||||
logger.error("Log sync requires rsync to be installed.")
|
||||
@@ -36,12 +41,14 @@ def log_sync_template():
|
||||
_log_sync_warned = True
|
||||
return
|
||||
|
||||
return ("""rsync -savz -e "ssh -i {ssh_key} -o ConnectTimeout=120s """
|
||||
"""-o StrictHostKeyChecking=no" {{source}} {{target}}"""
|
||||
).format(ssh_key=quote(ssh_key))
|
||||
rsh = "ssh -i {ssh_key} -o ConnectTimeout=120s -o StrictHostKeyChecking=no"
|
||||
rsh = rsh.format(ssh_key=quote(ssh_key))
|
||||
template = """rsync {options} -savz -e "{rsh}" {{source}} {{target}}"""
|
||||
return template.format(options=options, rsh=rsh)
|
||||
|
||||
|
||||
class NodeSyncMixin(object):
|
||||
# TODO(ujvl): Refactor this code.
|
||||
"""Mixin for syncing files to/from a remote dir to a local dir."""
|
||||
|
||||
def __init__(self):
|
||||
@@ -53,23 +60,43 @@ class NodeSyncMixin(object):
|
||||
"""Set the worker ip to sync logs from."""
|
||||
self.worker_ip = worker_ip
|
||||
|
||||
def _check_valid_worker_ip(self):
|
||||
def has_remote_target(self):
|
||||
"""Returns whether the Syncer has a remote target."""
|
||||
if not self.worker_ip:
|
||||
logger.debug("Worker ip unknown, skipping log sync for {}".format(
|
||||
self._local_dir))
|
||||
logger.debug("Worker IP unknown, skipping log sync for %s",
|
||||
self._local_dir)
|
||||
return False
|
||||
if self.worker_ip == self.local_ip:
|
||||
logger.debug(
|
||||
"Worker ip is local ip, skipping log sync for {}".format(
|
||||
self._local_dir))
|
||||
logger.debug("Worker IP is local IP, skipping log sync for %s",
|
||||
self._local_dir)
|
||||
return False
|
||||
return True
|
||||
|
||||
def sync_up_if_needed(self):
|
||||
if not self.has_remote_target():
|
||||
return True
|
||||
super(NodeSyncMixin, self).sync_up()
|
||||
|
||||
def sync_down_if_needed(self):
|
||||
if not self.has_remote_target():
|
||||
return True
|
||||
super(NodeSyncMixin, self).sync_down()
|
||||
|
||||
def sync_down(self):
|
||||
if not self.has_remote_target():
|
||||
return True
|
||||
return super(NodeSyncMixin, self).sync_down()
|
||||
|
||||
def sync_up(self):
|
||||
if not self.has_remote_target():
|
||||
return True
|
||||
return super(NodeSyncMixin, self).sync_up()
|
||||
|
||||
@property
|
||||
def _remote_path(self):
|
||||
ssh_user = get_ssh_user()
|
||||
global _log_sync_warned
|
||||
if not self._check_valid_worker_ip():
|
||||
if not self.has_remote_target():
|
||||
return
|
||||
if ssh_user is None:
|
||||
if not _log_sync_warned:
|
||||
|
||||
@@ -409,8 +409,8 @@ class UnifiedLogger(Logger):
|
||||
try:
|
||||
self._loggers.append(cls(self.config, self.logdir, self.trial))
|
||||
except Exception as exc:
|
||||
logger.warning("Could not instantiate {}: {}.".format(
|
||||
cls.__name__, str(exc)))
|
||||
logger.warning("Could not instantiate %s: %s.", cls.__name__,
|
||||
str(exc))
|
||||
self._log_syncer = get_log_syncer(
|
||||
self.logdir,
|
||||
remote_dir=self.logdir,
|
||||
@@ -429,12 +429,21 @@ class UnifiedLogger(Logger):
|
||||
def close(self):
|
||||
for _logger in self._loggers:
|
||||
_logger.close()
|
||||
self._log_syncer.sync_down()
|
||||
|
||||
def flush(self):
|
||||
for _logger in self._loggers:
|
||||
_logger.flush()
|
||||
self._log_syncer.sync_down()
|
||||
if not self._log_syncer.sync_down():
|
||||
logger.warning("Trial %s: Post-flush sync skipped.", self.trial)
|
||||
|
||||
def sync_up(self):
|
||||
return self._log_syncer.sync_up()
|
||||
|
||||
def sync_down(self):
|
||||
return self._log_syncer.sync_down()
|
||||
|
||||
def wait(self):
|
||||
self._log_syncer.wait()
|
||||
|
||||
def sync_results_to_new_location(self, worker_ip):
|
||||
"""Sends the current log directory to the remote node.
|
||||
@@ -443,13 +452,19 @@ class UnifiedLogger(Logger):
|
||||
with the Ray autoscaler.
|
||||
"""
|
||||
if worker_ip != self._log_syncer.worker_ip:
|
||||
logger.info("Syncing (blocking) results to {}".format(worker_ip))
|
||||
logger.info("Trial %s: Syncing (blocking) results to %s",
|
||||
self.trial, worker_ip)
|
||||
self._log_syncer.reset()
|
||||
self._log_syncer.set_worker_ip(worker_ip)
|
||||
self._log_syncer.sync_up()
|
||||
# TODO: change this because this is blocking. But failures
|
||||
# are rare, so maybe this is OK?
|
||||
if not self._log_syncer.sync_up():
|
||||
logger.error(
|
||||
"Trial %s: Sync up to new location skipped. "
|
||||
"This should not occur.", self.trial)
|
||||
self._log_syncer.wait()
|
||||
else:
|
||||
logger.error(
|
||||
"Trial %s: Sync attempted to same IP %s. This "
|
||||
"should not occur.", self.trial, worker_ip)
|
||||
|
||||
|
||||
class _SafeFallbackEncoder(json.JSONEncoder):
|
||||
|
||||
@@ -4,18 +4,18 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import ray
|
||||
from ray.exceptions import RayTimeoutError
|
||||
from ray import ray_constants
|
||||
from ray.resource_spec import ResourceSpec
|
||||
from ray.tune.error import AbortTrialExecution
|
||||
from ray.tune.logger import NoopLogger
|
||||
from ray.tune.trial import Trial, Checkpoint
|
||||
from ray.tune.trial import Trial, Checkpoint, Location
|
||||
from ray.tune.resources import Resources
|
||||
from ray.tune.trial_executor import TrialExecutor
|
||||
from ray.tune.util import warn_if_slow
|
||||
@@ -25,6 +25,7 @@ logger = logging.getLogger(__name__)
|
||||
RESOURCE_REFRESH_PERIOD = 0.5 # Refresh resources every 500 ms
|
||||
BOTTLENECK_WARN_PERIOD_S = 60
|
||||
NONTRIVIAL_WAIT_TIME_THRESHOLD_S = 1e-3
|
||||
DEFAULT_GET_TIMEOUT = 30.0 # seconds
|
||||
|
||||
|
||||
class _LocalWrapper(object):
|
||||
@@ -37,7 +38,7 @@ class _LocalWrapper(object):
|
||||
|
||||
|
||||
class RayTrialExecutor(TrialExecutor):
|
||||
"""An implemention of TrialExecutor based on Ray."""
|
||||
"""An implementation of TrialExecutor based on Ray."""
|
||||
|
||||
def __init__(self,
|
||||
queue_trials=False,
|
||||
@@ -71,36 +72,18 @@ class RayTrialExecutor(TrialExecutor):
|
||||
if ray.is_initialized():
|
||||
self._update_avail_resources()
|
||||
|
||||
def _setup_runner(self, trial, reuse_allowed):
|
||||
def _setup_remote_runner(self, trial, reuse_allowed):
|
||||
trial.init_logger()
|
||||
# We checkpoint metadata here to try mitigating logdir duplication
|
||||
self.try_checkpoint_metadata(trial)
|
||||
remote_logdir = trial.logdir
|
||||
|
||||
if (self._reuse_actors and reuse_allowed
|
||||
and self._cached_actor is not None):
|
||||
logger.debug("Reusing cached runner {} for {}".format(
|
||||
self._cached_actor, trial.trial_id))
|
||||
existing_runner = self._cached_actor
|
||||
self._cached_actor = None
|
||||
else:
|
||||
if self._cached_actor:
|
||||
logger.debug(
|
||||
"Cannot reuse cached runner {} for new trial".format(
|
||||
self._cached_actor))
|
||||
self._cached_actor.stop.remote()
|
||||
self._cached_actor.__ray_terminate__.remote()
|
||||
self._cached_actor = None
|
||||
existing_runner = None
|
||||
cls = ray.remote(
|
||||
num_cpus=trial.resources.cpu,
|
||||
num_gpus=trial.resources.gpu,
|
||||
memory=trial.resources.memory,
|
||||
object_store_memory=trial.resources.object_store_memory,
|
||||
resources=trial.resources.custom_resources)(
|
||||
trial.get_trainable_cls())
|
||||
|
||||
trial.init_logger()
|
||||
# We checkpoint metadata here to try mitigating logdir duplication
|
||||
self.try_checkpoint_metadata(trial)
|
||||
remote_logdir = trial.logdir
|
||||
|
||||
if existing_runner:
|
||||
trial.runner = existing_runner
|
||||
if not self.reset_trial(trial, trial.config, trial.experiment_tag):
|
||||
raise AbortTrialExecution(
|
||||
@@ -108,6 +91,21 @@ class RayTrialExecutor(TrialExecutor):
|
||||
"implemented and return True.")
|
||||
return existing_runner
|
||||
|
||||
if self._cached_actor:
|
||||
logger.debug("Cannot reuse cached runner {} for new trial".format(
|
||||
self._cached_actor))
|
||||
self._cached_actor.stop.remote()
|
||||
self._cached_actor.__ray_terminate__.remote()
|
||||
self._cached_actor = None
|
||||
|
||||
cls = ray.remote(
|
||||
num_cpus=trial.resources.cpu,
|
||||
num_gpus=trial.resources.gpu,
|
||||
memory=trial.resources.memory,
|
||||
object_store_memory=trial.resources.object_store_memory,
|
||||
resources=trial.resources.custom_resources)(
|
||||
trial.get_trainable_cls())
|
||||
|
||||
def logger_creator(config):
|
||||
# Set the working dir in the remote process, for user file writes
|
||||
if not os.path.exists(remote_logdir):
|
||||
@@ -116,6 +114,10 @@ class RayTrialExecutor(TrialExecutor):
|
||||
os.chdir(remote_logdir)
|
||||
return NoopLogger(config, remote_logdir)
|
||||
|
||||
# Clear the Trial's location (to be updated later on result)
|
||||
# since we don't know where the remote runner is placed.
|
||||
trial.set_location(Location())
|
||||
logger.info("Trial %s: Setting up new remote runner.", trial)
|
||||
# Logging for trials is handled centrally by TrialRunner, so
|
||||
# configure the remote runner to use a noop-logger.
|
||||
return cls.remote(config=trial.config, logger_creator=logger_creator)
|
||||
@@ -136,22 +138,20 @@ class RayTrialExecutor(TrialExecutor):
|
||||
"""Starts trial and restores last result if trial was paused.
|
||||
|
||||
Raises:
|
||||
ValueError if restoring from checkpoint fails.
|
||||
RuntimeError if restoring from checkpoint fails.
|
||||
"""
|
||||
prior_status = trial.status
|
||||
self.set_status(trial, Trial.RUNNING)
|
||||
trial.runner = self._setup_runner(
|
||||
trial.runner = self._setup_remote_runner(
|
||||
trial,
|
||||
reuse_allowed=checkpoint is not None
|
||||
or trial._checkpoint.value is not None)
|
||||
reuse_allowed=checkpoint is not None or trial.has_checkpoint())
|
||||
if not self.restore(trial, checkpoint):
|
||||
if trial.status == Trial.ERROR:
|
||||
raise RuntimeError(
|
||||
"Restore from checkpoint failed for Trial {}.".format(
|
||||
str(trial)))
|
||||
"Trial {}: Restore from checkpoint failed.".format(trial))
|
||||
|
||||
previous_run = self._find_item(self._paused, trial)
|
||||
if (prior_status == Trial.PAUSED and previous_run):
|
||||
if prior_status == Trial.PAUSED and previous_run:
|
||||
# If Trial was in flight when paused, self._paused stores result.
|
||||
self._paused.pop(previous_run[0])
|
||||
self._running[previous_run[0]] = trial
|
||||
@@ -175,10 +175,8 @@ class RayTrialExecutor(TrialExecutor):
|
||||
if stop_logger:
|
||||
trial.close_logger()
|
||||
|
||||
if error:
|
||||
self.set_status(trial, Trial.ERROR)
|
||||
else:
|
||||
self.set_status(trial, Trial.TERMINATED)
|
||||
self.set_status(trial, Trial.ERROR if error else Trial.TERMINATED)
|
||||
trial.set_location(Location())
|
||||
|
||||
try:
|
||||
trial.write_error_log(error_msg)
|
||||
@@ -188,12 +186,11 @@ class RayTrialExecutor(TrialExecutor):
|
||||
logger.debug("Reusing actor for {}".format(trial.runner))
|
||||
self._cached_actor = trial.runner
|
||||
else:
|
||||
logger.debug(
|
||||
"Destroying actor for trial {}.".format(trial))
|
||||
logger.debug("Trial %s: Destroying actor.", trial)
|
||||
trial.runner.stop.remote()
|
||||
trial.runner.__ray_terminate__.remote()
|
||||
except Exception:
|
||||
logger.exception("Error stopping runner for Trial %s", str(trial))
|
||||
logger.exception("Trial %s: Error stopping runner.", trial)
|
||||
self.set_status(trial, Trial.ERROR)
|
||||
finally:
|
||||
trial.runner = None
|
||||
@@ -208,32 +205,35 @@ class RayTrialExecutor(TrialExecutor):
|
||||
checkpoint (Checkpoint): A Python object or path storing the state
|
||||
of trial.
|
||||
"""
|
||||
|
||||
self._commit_resources(trial.resources)
|
||||
try:
|
||||
self._start_trial(trial, checkpoint)
|
||||
except Exception as e:
|
||||
logger.exception("Error starting runner for Trial %s", str(trial))
|
||||
error_msg = traceback.format_exc()
|
||||
except AbortTrialExecution:
|
||||
logger.exception("Trial %s: Error starting runner, aborting!",
|
||||
trial)
|
||||
time.sleep(2)
|
||||
error_msg = traceback.format_exc()
|
||||
self._stop_trial(trial, error=True, error_msg=error_msg)
|
||||
return # don't retry fatal Tune errors
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Trial %s: Error starting runner. Attempting "
|
||||
"restart without checkpoint.", trial)
|
||||
time.sleep(2)
|
||||
error_msg = traceback.format_exc()
|
||||
self._stop_trial(trial, error=True, error_msg=error_msg)
|
||||
if isinstance(e, AbortTrialExecution):
|
||||
return # don't retry fatal Tune errors
|
||||
try:
|
||||
# This forces the trial to not start from checkpoint.
|
||||
trial.clear_checkpoint()
|
||||
logger.info(
|
||||
"Trying to start runner for Trial %s without checkpoint.",
|
||||
str(trial))
|
||||
self._start_trial(trial)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Error starting runner for Trial %s, aborting!",
|
||||
str(trial))
|
||||
"Trial %s: Error starting runner on second "
|
||||
"attempt, aborting!", trial)
|
||||
error_msg = traceback.format_exc()
|
||||
self._stop_trial(trial, error=True, error_msg=error_msg)
|
||||
# note that we don't return the resources, since they may
|
||||
# have been lost
|
||||
# Note that we don't return the resources, since they may
|
||||
# have been lost. TODO(ujvl): is this the right thing to do?
|
||||
|
||||
def _find_item(self, dictionary, item):
|
||||
out = [rid for rid, t in dictionary.items() if t is item]
|
||||
@@ -245,7 +245,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||
self._stop_trial(
|
||||
trial, error=error, error_msg=error_msg, stop_logger=stop_logger)
|
||||
if prior_status == Trial.RUNNING:
|
||||
logger.debug("Returning resources for Trial %s.", str(trial))
|
||||
logger.debug("Trial %s: Returning resources.", trial)
|
||||
self._return_resources(trial.resources)
|
||||
out = self._find_item(self._running, trial)
|
||||
for result_id in out:
|
||||
@@ -253,7 +253,6 @@ class RayTrialExecutor(TrialExecutor):
|
||||
|
||||
def continue_training(self, trial):
|
||||
"""Continues the training of this trial."""
|
||||
|
||||
self._train(trial)
|
||||
|
||||
def pause_trial(self, trial):
|
||||
@@ -262,7 +261,6 @@ class RayTrialExecutor(TrialExecutor):
|
||||
If trial is in-flight, preserves return value in separate queue
|
||||
before pausing, which is restored when Trial is resumed.
|
||||
"""
|
||||
|
||||
trial_future = self._find_item(self._running, trial)
|
||||
if trial_future:
|
||||
self._paused[trial_future[0]] = trial
|
||||
@@ -285,7 +283,13 @@ class RayTrialExecutor(TrialExecutor):
|
||||
trial.config = new_config
|
||||
trainable = trial.runner
|
||||
with warn_if_slow("reset_config"):
|
||||
reset_val = ray.get(trainable.reset_config.remote(new_config))
|
||||
try:
|
||||
reset_val = ray.get(
|
||||
trainable.reset_config.remote(new_config),
|
||||
DEFAULT_GET_TIMEOUT)
|
||||
except RayTimeoutError:
|
||||
logger.exception("Trial %s: reset_config timed out.")
|
||||
return False
|
||||
return reset_val
|
||||
|
||||
def get_running_trials(self):
|
||||
@@ -351,7 +355,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||
raise ValueError("Trial was not running.")
|
||||
self._running.pop(trial_future[0])
|
||||
with warn_if_slow("fetch_result"):
|
||||
result = ray.get(trial_future[0])
|
||||
result = ray.get(trial_future[0], DEFAULT_GET_TIMEOUT)
|
||||
|
||||
# For local mode
|
||||
if isinstance(result, _LocalWrapper):
|
||||
@@ -530,48 +534,28 @@ class RayTrialExecutor(TrialExecutor):
|
||||
|
||||
def save(self, trial, storage=Checkpoint.DISK):
|
||||
"""Saves the trial's state to a checkpoint."""
|
||||
trial._checkpoint.storage = storage
|
||||
trial._checkpoint.last_result = trial.last_result
|
||||
if storage == Checkpoint.MEMORY:
|
||||
trial._checkpoint.value = trial.runner.save_to_object.remote()
|
||||
value = trial.runner.save_to_object.remote()
|
||||
checkpoint = Checkpoint(storage, value, trial.last_result)
|
||||
else:
|
||||
# Keeps only highest performing checkpoints if enabled
|
||||
if trial.keep_checkpoints_num:
|
||||
try:
|
||||
last_attr_val = trial.last_result[
|
||||
trial.checkpoint_score_attr]
|
||||
if (trial.compare_checkpoints(last_attr_val)
|
||||
and not math.isnan(last_attr_val)):
|
||||
trial.best_checkpoint_attr_value = last_attr_val
|
||||
self._checkpoint_and_erase(trial)
|
||||
except KeyError:
|
||||
logger.warning(
|
||||
"Result dict has no key: {}. keep"
|
||||
"_checkpoints_num flag will not work".format(
|
||||
trial.checkpoint_score_attr))
|
||||
else:
|
||||
with warn_if_slow("save_to_disk"):
|
||||
trial._checkpoint.value = ray.get(
|
||||
trial.runner.save.remote())
|
||||
with warn_if_slow("save_checkpoint_to_disk"):
|
||||
value = ray.get(trial.runner.save.remote())
|
||||
checkpoint = Checkpoint(storage, value, trial.last_result)
|
||||
|
||||
return trial._checkpoint.value
|
||||
|
||||
def _checkpoint_and_erase(self, trial):
|
||||
"""Checkpoints the model and erases old checkpoints
|
||||
if needed.
|
||||
Parameters
|
||||
----------
|
||||
trial : trial to save
|
||||
"""
|
||||
|
||||
with warn_if_slow("save_to_disk"):
|
||||
trial._checkpoint.value = ray.get(trial.runner.save.remote())
|
||||
|
||||
if len(trial.history) >= trial.keep_checkpoints_num:
|
||||
ray.get(trial.runner.delete_checkpoint.remote(trial.history[-1]))
|
||||
trial.history.pop()
|
||||
|
||||
trial.history.insert(0, trial._checkpoint.value)
|
||||
with warn_if_slow("on_checkpoint", DEFAULT_GET_TIMEOUT) as profile:
|
||||
try:
|
||||
trial.on_checkpoint(checkpoint)
|
||||
except Exception:
|
||||
logger.exception("Trial %s: Error handling checkpoint %s",
|
||||
trial, checkpoint.value)
|
||||
return None
|
||||
if profile.too_slow and trial.sync_on_checkpoint:
|
||||
logger.warning(
|
||||
"Consider turning off forced head-worker trial checkpoint "
|
||||
"syncs by setting sync_on_checkpoint=False. Note that this "
|
||||
"might result in faulty trial restoration for some worker "
|
||||
"failure modes.")
|
||||
return checkpoint.value
|
||||
|
||||
def restore(self, trial, checkpoint=None):
|
||||
"""Restores training state from a given model checkpoint.
|
||||
@@ -580,11 +564,13 @@ class RayTrialExecutor(TrialExecutor):
|
||||
if restoring on a different node.
|
||||
"""
|
||||
if checkpoint is None or checkpoint.value is None:
|
||||
checkpoint = trial._checkpoint
|
||||
checkpoint = trial.checkpoint
|
||||
if checkpoint is None or checkpoint.value is None:
|
||||
return True
|
||||
if trial.runner is None:
|
||||
logger.error("Unable to restore - no runner.")
|
||||
logger.error(
|
||||
"Trial %s: Unable to restore - no runner. "
|
||||
"Setting status to ERROR.", trial)
|
||||
self.set_status(trial, Trial.ERROR)
|
||||
return False
|
||||
try:
|
||||
@@ -593,23 +579,31 @@ class RayTrialExecutor(TrialExecutor):
|
||||
assert type(value) != Checkpoint, type(value)
|
||||
trial.runner.restore_from_object.remote(value)
|
||||
else:
|
||||
# TODO: Somehow, the call to get the current IP on the
|
||||
# remote actor can be very slow - a better fix would
|
||||
# be to use an actor table to detect the IP of the Trainable
|
||||
# and rsync the files there.
|
||||
# See https://github.com/ray-project/ray/issues/5168
|
||||
logger.info("Trial %s: Attempting restoration from %s", trial,
|
||||
checkpoint.value)
|
||||
with warn_if_slow("get_current_ip"):
|
||||
worker_ip = ray.get(trial.runner.current_ip.remote())
|
||||
worker_ip = ray.get(trial.runner.current_ip.remote(),
|
||||
DEFAULT_GET_TIMEOUT)
|
||||
with warn_if_slow("sync_to_new_location"):
|
||||
trial.sync_logger_to_new_location(worker_ip)
|
||||
with warn_if_slow("restore_from_disk"):
|
||||
ray.get(trial.runner.restore.remote(value))
|
||||
trial.last_result = checkpoint.last_result
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Error restoring runner for Trial %s.", trial)
|
||||
ray.get(
|
||||
trial.runner.restore.remote(value),
|
||||
DEFAULT_GET_TIMEOUT)
|
||||
except RayTimeoutError:
|
||||
logger.exception(
|
||||
"Trial %s: Unable to restore - runner task timed "
|
||||
"out. Setting status to ERROR", trial)
|
||||
self.set_status(trial, Trial.ERROR)
|
||||
return False
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Trial %s: Unable to restore. Setting status to ERROR", trial)
|
||||
self.set_status(trial, Trial.ERROR)
|
||||
return False
|
||||
|
||||
trial.last_result = checkpoint.result
|
||||
return True
|
||||
|
||||
def export_trial_if_needed(self, trial):
|
||||
"""Exports model of this trial based on trial.export_formats.
|
||||
@@ -619,7 +613,8 @@ class RayTrialExecutor(TrialExecutor):
|
||||
"""
|
||||
if trial.export_formats and len(trial.export_formats) > 0:
|
||||
return ray.get(
|
||||
trial.runner.export_model.remote(trial.export_formats))
|
||||
trial.runner.export_model.remote(trial.export_formats),
|
||||
DEFAULT_GET_TIMEOUT)
|
||||
return {}
|
||||
|
||||
def has_gpus(self):
|
||||
|
||||
+217
-147
@@ -28,26 +28,142 @@ SYNC_PERIOD = 300
|
||||
_syncers = {}
|
||||
|
||||
|
||||
def validate_sync_string(sync_string):
    """Check that a sync command template contains both placeholders.

    Args:
        sync_string (str): Command template to validate.

    Raises:
        ValueError: If '{source}' or '{target}' does not occur in the
            template. '{source}' is checked first.
    """
    for placeholder in ("{source}", "{target}"):
        if placeholder not in sync_string:
            raise ValueError(
                "Sync template missing '{}'.".format(placeholder))
|
||||
|
||||
|
||||
def wait_for_sync():
    """Call ``wait()`` on every syncer held in the module-level registry.

    Syncers are visited in the registry's iteration order; an exception
    raised by one ``wait()`` propagates immediately.
    """
    for registered_syncer in _syncers.values():
        registered_syncer.wait()
|
||||
|
||||
|
||||
class BaseSyncer(object):
|
||||
def __init__(self, local_dir, remote_dir, sync_function=None):
|
||||
class SyncClient(object):
    """Interface for clients that transfer files between two paths.

    ``sync_up`` and ``sync_down`` must be provided by subclasses; ``wait``
    and ``reset`` do nothing in this base class.
    """

    def sync_up(self, source, target):
        """Sync up from source to target.

        Args:
            source (str): Source path.
            target (str): Target path.

        Returns:
            True if sync initiation successful, False otherwise.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def sync_down(self, source, target):
        """Sync down from source to target.

        Args:
            source (str): Source path.
            target (str): Target path.

        Returns:
            True if sync initiation successful, False otherwise.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def wait(self):
        """Wait for current sync to complete, if asynchronously started."""
        return None

    def reset(self):
        """Resets state."""
        return None
|
||||
|
||||
|
||||
class FunctionBasedClient(SyncClient):
    """Sync client that delegates to two caller-supplied callables."""

    def __init__(self, sync_up_func, sync_down_func):
        """Store the callables used for each sync direction.

        Args:
            sync_up_func: Invoked as ``sync_up_func(source, target)``.
            sync_down_func: Invoked as ``sync_down_func(source, target)``.
        """
        self.sync_up_func, self.sync_down_func = sync_up_func, sync_down_func

    def sync_up(self, source, target):
        """Run the sync-up callable and report success.

        The callable's return value is ignored; any exception it raises
        propagates to the caller.

        Returns:
            True, always.
        """
        self.sync_up_func(source, target)
        return True

    def sync_down(self, source, target):
        """Run the sync-down callable and report success.

        The callable's return value is ignored; any exception it raises
        propagates to the caller.

        Returns:
            True, always.
        """
        self.sync_down_func(source, target)
        return True
|
||||
|
||||
|
||||
class CommandBasedClient(SyncClient):
    """Sync client that runs a shell command template in a subprocess.

    Only one sync process is tracked at a time: ``execute`` refuses to start
    a new command while the previously started one has not exited.
    """

    def __init__(self, sync_up_template, sync_down_template):
        """Syncs between two directories with the given command.

        Arguments:
            sync_up_template (str): A runnable string template; needs to
                include replacement fields '{source}' and '{target}'.
            sync_down_template (str): A runnable string template; needs to
                include replacement fields '{source}' and '{target}'.

        Raises:
            ValueError: If either template is not a string, or is missing
                one of the replacement fields.
        """
        if not isinstance(sync_up_template, str):
            raise ValueError("{} is not a string.".format(sync_up_template))
        if not isinstance(sync_down_template, str):
            raise ValueError("{} is not a string.".format(sync_down_template))
        self._validate_sync_string(sync_up_template)
        self._validate_sync_string(sync_down_template)
        self.sync_up_template = sync_up_template
        self.sync_down_template = sync_down_template
        # Stays None until set_logdir() is called; Popen(stdout=None) then
        # lets the child inherit this process's stdout.
        self.logfile = None
        self.sync_process = None

    def set_logdir(self, logdir):
        """Sets the directory to log sync execution output in.

        A new ``log_sync*.log`` file is created in ``logdir`` and is not
        deleted when closed.

        Args:
            logdir: Log directory.
        """
        self.logfile = tempfile.NamedTemporaryFile(
            prefix="log_sync", dir=logdir, suffix=".log", delete=False)

    def sync_up(self, source, target):
        """Starts the sync-up command; see ``execute`` for the result."""
        return self.execute(self.sync_up_template, source, target)

    def sync_down(self, source, target):
        """Starts the sync-down command; see ``execute`` for the result."""
        return self.execute(self.sync_down_template, source, target)

    def execute(self, sync_template, source, target):
        """Executes sync_template on source and target.

        Args:
            sync_template (str): Command template with '{source}' and
                '{target}' fields.
            source (str): Source path; shell-quoted before substitution.
            target (str): Target path; shell-quoted before substitution.

        Returns:
            True if a new sync process was started, False if the previous
            one is still running and this request was skipped.
        """
        if self.sync_process:
            self.sync_process.poll()
            if self.sync_process.returncode is None:
                logger.warning("Last sync is still in progress, skipping.")
                return False
        final_cmd = sync_template.format(
            source=quote(source), target=quote(target))
        logger.debug("Running sync: {}".format(final_cmd))
        # NOTE(review): stderr is only drained by wait(); a command that
        # writes a lot to stderr may stall until wait() is called.
        self.sync_process = subprocess.Popen(
            final_cmd, shell=True, stderr=subprocess.PIPE, stdout=self.logfile)
        return True

    def wait(self):
        """Waits for the running sync process, if any, to finish.

        Raises:
            TuneError: If the process exited with a non-zero return code.
                The message carries the code and the captured stderr.
        """
        if self.sync_process:
            _, error_msg = self.sync_process.communicate()
            code = self.sync_process.returncode
            # Drop the handle first so that nothing below can leave a
            # finished process registered as the current sync.
            self.sync_process = None
            # Sync tools may print non-ASCII text (e.g. file names); decode
            # leniently so the real failure is reported instead of a
            # UnicodeDecodeError.
            error_msg = error_msg.decode("utf-8", "replace")
            if code != 0:
                raise TuneError("Sync error ({}): {}".format(code, error_msg))

    def reset(self):
        """Forgets the tracked sync process without waiting for it."""
        if self.sync_process:
            logger.warning("Sync process still running but resetting anyways.")
            self.sync_process = None

    @staticmethod
    def _validate_sync_string(sync_string):
        """Raises ValueError unless both replacement fields are present."""
        if "{source}" not in sync_string:
            raise ValueError("Sync template missing '{source}'.")
        if "{target}" not in sync_string:
            raise ValueError("Sync template missing '{target}'.")
|
||||
|
||||
|
||||
NOOP = FunctionBasedClient(lambda s, t: None, lambda s, t: None)
|
||||
|
||||
|
||||
class Syncer(object):
|
||||
def __init__(self, local_dir, remote_dir, sync_client=NOOP):
|
||||
"""Syncs between two directories with the sync_function.
|
||||
|
||||
Arguments:
|
||||
local_dir (str): Directory to sync. Uniquely identifies the syncer.
|
||||
remote_dir (str): Remote directory to sync with.
|
||||
sync_function (func): Function for syncing the local_dir to
|
||||
sync_client (SyncClient): Client for syncing between local_dir and
|
||||
remote_dir. Defaults to a Noop.
|
||||
"""
|
||||
self._local_dir = (os.path.join(local_dir, "")
|
||||
@@ -55,32 +171,7 @@ class BaseSyncer(object):
|
||||
self._remote_dir = remote_dir
|
||||
self.last_sync_up_time = float("-inf")
|
||||
self.last_sync_down_time = float("-inf")
|
||||
self._sync_function = sync_function or (lambda source, target: None)
|
||||
|
||||
def sync_function(self, source, target):
|
||||
"""Executes sync between source and target.
|
||||
|
||||
Can be overwritten by subclasses for custom sync procedures.
|
||||
|
||||
Args:
|
||||
source: Path to source file(s).
|
||||
target: Path to target file(s).
|
||||
"""
|
||||
if self._sync_function:
|
||||
return self._sync_function(source, target)
|
||||
|
||||
def sync(self, source, target):
|
||||
if not (source and target):
|
||||
logger.debug(
|
||||
"Source or target is empty, skipping log sync for {}".format(
|
||||
self._local_dir))
|
||||
return
|
||||
|
||||
try:
|
||||
self.sync_function(source, target)
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Sync function failed.")
|
||||
self.sync_client = sync_client
|
||||
|
||||
def sync_up_if_needed(self):
|
||||
if time.time() - self.last_sync_up_time > SYNC_PERIOD:
|
||||
@@ -90,120 +181,86 @@ class BaseSyncer(object):
|
||||
if time.time() - self.last_sync_down_time > SYNC_PERIOD:
|
||||
self.sync_down()
|
||||
|
||||
def sync_down(self, *args, **kwargs):
|
||||
self.sync(self._remote_path, self._local_dir, *args, **kwargs)
|
||||
self.last_sync_down_time = time.time()
|
||||
def sync_up(self):
|
||||
"""Attempts to start the sync-up to the remote path.
|
||||
|
||||
def sync_up(self, *args, **kwargs):
|
||||
self.sync(self._local_dir, self._remote_path, *args, **kwargs)
|
||||
self.last_sync_up_time = time.time()
|
||||
Returns:
|
||||
Whether the sync (if feasible) was successfully started.
|
||||
"""
|
||||
result = False
|
||||
if self.validate_hosts(self._local_dir, self._remote_path):
|
||||
try:
|
||||
result = self.sync_client.sync_up(self._local_dir,
|
||||
self._remote_path)
|
||||
self.last_sync_up_time = time.time()
|
||||
except Exception:
|
||||
logger.exception("Sync execution failed.")
|
||||
return result
|
||||
|
||||
def sync_down(self):
    """Attempts to start the sync-down from the remote path.

    The last sync-down timestamp is refreshed whenever the client call
    returns without raising. Exceptions from the client are logged and
    swallowed.

    Returns:
        Whether the sync (if feasible) was successfully started.
    """
    started = False
    if not self.validate_hosts(self._local_dir, self._remote_path):
        return started
    try:
        started = self.sync_client.sync_down(self._remote_path,
                                             self._local_dir)
        self.last_sync_down_time = time.time()
    except Exception:
        logger.exception("Sync execution failed.")
    return started
|
||||
|
||||
def validate_hosts(self, source, target):
    """Report whether both endpoints of a sync are non-empty.

    Args:
        source: Source path of the sync.
        target: Target path of the sync.

    Returns:
        True if both are truthy; otherwise logs a debug message and
        returns False.
    """
    if source and target:
        return True
    logger.debug("Source or target is empty, skipping log sync for "
                 "{}".format(self._local_dir))
    return False
|
||||
|
||||
def wait(self):
    """Delegates to the sync client's ``wait()`` for the current sync."""
    client = self.sync_client
    client.wait()
|
||||
|
||||
def reset(self):
|
||||
self.last_sync_up_time = float("-inf")
|
||||
self.last_sync_down_time = float("-inf")
|
||||
|
||||
def wait(self):
|
||||
pass
|
||||
self.sync_client.reset()
|
||||
|
||||
@property
def _remote_path(self):
    """Remote directory used as the sync endpoint.

    Subclasses may override this property to supply a custom path.
    """
    remote = self._remote_dir
    return remote
|
||||
|
||||
|
||||
class CommandSyncer(BaseSyncer):
    """Syncer that runs a shell command template in a subprocess."""

    def __init__(self, local_dir, remote_dir, sync_template):
        """Syncs between two directories with the given command.

        Arguments:
            local_dir (str): Directory to sync.
            remote_dir (str): Remote directory to sync with.
            sync_template (str): A string template
                for syncer to run and needs to include replacement fields
                '{source}' and '{target}'. Returned when using
                `CommandSyncer.sync_template`, which can be overridden
                by subclass.

        Raises:
            ValueError: If the template is not a string or lacks a
                replacement field.
        """
        super(CommandSyncer, self).__init__(local_dir, remote_dir)
        if not isinstance(sync_template, str):
            raise ValueError("{} is not a string.".format(sync_template))
        validate_sync_string(sync_template)
        self._sync_template = sync_template
        # Command stdout is collected in a persistent log file next to the
        # synced data.
        log_options = dict(
            prefix="log_sync",
            dir=self._local_dir,
            suffix=".log",
            delete=False)
        self.logfile = tempfile.NamedTemporaryFile(**log_options)

        self.sync_process = None

    def sync_function(self, source, target):
        """Start the sync command unless the previous one is still running.

        Returns:
            True if a new process was started, None if it was skipped.
        """
        self.last_sync_time = time.time()
        previous = self.sync_process
        if previous:
            previous.poll()
            if previous.returncode is None:
                logger.warning("Last sync is still in progress, skipping.")
                return
        quoted_source = quote(source)
        quoted_target = quote(target)
        final_cmd = self._sync_template.format(
            source=quoted_source, target=quoted_target)
        logger.debug("Running sync: {}".format(final_cmd))
        self.sync_process = subprocess.Popen(
            final_cmd, shell=True, stdout=self.logfile)
        return True

    def reset(self):
        """Forget the tracked process, then reset the base syncer state."""
        if self.sync_process:
            logger.warning("Sync process still running but resetting anyways.")
            self.sync_process = None
        super(CommandSyncer, self).reset()

    def wait(self):
        """Block until the tracked sync process, if any, has exited."""
        running = self.sync_process
        if running:
            running.wait()
|
||||
|
||||
|
||||
def _get_sync_cls(sync_function):
|
||||
if not sync_function:
|
||||
return
|
||||
if isinstance(sync_function, types.FunctionType):
|
||||
return BaseSyncer
|
||||
elif isinstance(sync_function, str):
|
||||
return CommandSyncer
|
||||
else:
|
||||
raise ValueError("Sync function {} must be string or function".format(
|
||||
sync_function))
|
||||
|
||||
|
||||
def get_syncer(local_dir, remote_dir=None, sync_function=None):
|
||||
"""Returns a Syncer depending on given args.
|
||||
def get_cloud_syncer(local_dir, remote_dir=None, sync_function=None):
|
||||
"""Returns a Syncer.
|
||||
|
||||
This syncer is in charge of syncing the local_dir with upload_dir.
|
||||
|
||||
Args:
|
||||
local_dir: Source directory for syncing.
|
||||
remote_dir: Target directory for syncing. If None,
|
||||
returns BaseSyncer with a noop.
|
||||
local_dir (str): Source directory for syncing.
|
||||
remote_dir (str): Target directory for syncing. If not provided, a
|
||||
no-op Syncer is returned.
|
||||
sync_function (func | str): Function for syncing the local_dir to
|
||||
remote_dir. If string, then it must be a string template for
|
||||
syncer to run. If not provided, it defaults
|
||||
to standard S3 or gsutil sync commands.
|
||||
"""
|
||||
"""
|
||||
key = (local_dir, remote_dir)
|
||||
|
||||
if key in _syncers:
|
||||
return _syncers[key]
|
||||
|
||||
if not remote_dir:
|
||||
_syncers[key] = BaseSyncer(local_dir, remote_dir)
|
||||
_syncers[key] = Syncer(local_dir, remote_dir, NOOP)
|
||||
return _syncers[key]
|
||||
|
||||
sync_cls = _get_sync_cls(sync_function)
|
||||
client = _get_sync_client(sync_function)
|
||||
|
||||
if sync_cls:
|
||||
_syncers[key] = sync_cls(local_dir, remote_dir, sync_function)
|
||||
if client:
|
||||
_syncers[key] = Syncer(local_dir, remote_dir, client)
|
||||
return _syncers[key]
|
||||
|
||||
if remote_dir.startswith(S3_PREFIX):
|
||||
@@ -211,15 +268,17 @@ def get_syncer(local_dir, remote_dir=None, sync_function=None):
|
||||
raise TuneError(
|
||||
"Upload uri starting with '{}' requires awscli tool"
|
||||
" to be installed".format(S3_PREFIX))
|
||||
_syncers[key] = CommandSyncer(local_dir, remote_dir,
|
||||
"aws s3 sync {source} {target}")
|
||||
template = "aws s3 sync {source} {target}"
|
||||
s3_client = CommandBasedClient(template, template)
|
||||
_syncers[key] = Syncer(local_dir, remote_dir, s3_client)
|
||||
elif remote_dir.startswith(GS_PREFIX):
|
||||
if not distutils.spawn.find_executable("gsutil"):
|
||||
raise TuneError(
|
||||
"Upload uri starting with '{}' requires gsutil tool"
|
||||
" to be installed".format(GS_PREFIX))
|
||||
_syncers[key] = CommandSyncer(local_dir, remote_dir,
|
||||
"gsutil rsync -r {source} {target}")
|
||||
template = "gsutil rsync -r {source} {target}"
|
||||
gs_client = CommandBasedClient(template, template)
|
||||
_syncers[key] = Syncer(local_dir, remote_dir, gs_client)
|
||||
else:
|
||||
raise TuneError("Upload uri must start with one of: {}"
|
||||
"".format(ALLOWED_REMOTE_PREFIXES))
|
||||
@@ -228,37 +287,48 @@ def get_syncer(local_dir, remote_dir=None, sync_function=None):
|
||||
|
||||
|
||||
def get_log_syncer(local_dir, remote_dir=None, sync_function=None):
|
||||
"""Returns a Syncer depending on given args.
|
||||
|
||||
This syncer is in charge of syncing the local_dir with remote local_dir.
|
||||
"""Returns a log Syncer.
|
||||
|
||||
Args:
|
||||
local_dir: Source directory for syncing.
|
||||
remote_dir: Target directory for syncing. If None,
|
||||
returns BaseSyncer with noop.
|
||||
sync_function (func | str): Function for syncing the local_dir to
|
||||
local_dir (str): Source directory for syncing.
|
||||
remote_dir (str): Target directory for syncing. If not provided, a
|
||||
no-op Syncer is returned.
|
||||
sync_function (func|str): Function for syncing the local_dir to
|
||||
remote_dir. If string, then it must be a string template for
|
||||
syncer to run. If not provided, it defaults rsync.
|
||||
"""
|
||||
"""
|
||||
key = (local_dir, remote_dir)
|
||||
|
||||
if key in _syncers:
|
||||
return _syncers[key]
|
||||
|
||||
sync_cls = None
|
||||
if sync_function:
|
||||
sync_cls = _get_sync_cls(sync_function)
|
||||
elif not remote_dir:
|
||||
sync_client = NOOP
|
||||
elif sync_function:
|
||||
sync_client = _get_sync_client(sync_function)
|
||||
else:
|
||||
sync_cls = CommandSyncer
|
||||
sync_function = log_sync_template()
|
||||
sync_up = log_sync_template()
|
||||
sync_down = log_sync_template(options="--remove-source-files")
|
||||
if sync_up and sync_down:
|
||||
sync_client = CommandBasedClient(sync_up, sync_down)
|
||||
sync_client.set_logdir(local_dir)
|
||||
else:
|
||||
sync_client = NOOP
|
||||
|
||||
if not remote_dir or sync_function is None:
|
||||
sync_cls = BaseSyncer
|
||||
|
||||
class MixedSyncer(NodeSyncMixin, sync_cls):
|
||||
class MixedSyncer(NodeSyncMixin, Syncer):
|
||||
def __init__(self, *args, **kwargs):
|
||||
sync_cls.__init__(self, *args, **kwargs)
|
||||
Syncer.__init__(self, *args, **kwargs)
|
||||
NodeSyncMixin.__init__(self)
|
||||
|
||||
_syncers[key] = MixedSyncer(local_dir, remote_dir, sync_function)
|
||||
_syncers[key] = MixedSyncer(local_dir, remote_dir, sync_client)
|
||||
return _syncers[key]
|
||||
|
||||
|
||||
def _get_sync_client(sync_function):
|
||||
if not sync_function:
|
||||
return None
|
||||
if isinstance(sync_function, types.FunctionType):
|
||||
return FunctionBasedClient(sync_function, sync_function)
|
||||
elif isinstance(sync_function, str):
|
||||
return CommandBasedClient(sync_function, sync_function)
|
||||
else:
|
||||
raise ValueError("Sync function {} must be string or function".format(
|
||||
sync_function))
|
||||
|
||||
@@ -0,0 +1,106 @@
|
||||
# coding: utf-8
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import random
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
from ray.tune.checkpoint_manager import Checkpoint, CheckpointManager, logger
|
||||
|
||||
if sys.version_info >= (3, 3):
|
||||
from unittest.mock import patch
|
||||
else:
|
||||
from mock import patch
|
||||
|
||||
|
||||
class CheckpointManagerTest(unittest.TestCase):
    """Unit tests for CheckpointManager bookkeeping and deletion."""

    @staticmethod
    def mock_result(i):
        """Returns a result dict whose score attribute "i" equals ``i``."""
        return {"i": i}

    def testOnCheckpointOrdered(self):
        """
        Tests increasing priorities. Also tests that the worst checkpoints
        are deleted when necessary.
        """
        keep_checkpoints_num = 2
        checkpoint_manager = CheckpointManager(keep_checkpoints_num, "i")
        # NOTE(review): the checkpoint value is a one-element set, not a
        # path; os.path and shutil.rmtree are patched below, so presumably
        # any truthy placeholder works -- confirm against CheckpointManager.
        checkpoints = [
            Checkpoint(Checkpoint.DISK, {i}, self.mock_result(i))
            for i in range(3)
        ]

        with patch("shutil.rmtree") as rmtree_mock, patch("os.path"):
            for j, checkpoint in enumerate(checkpoints):
                checkpoint_manager.on_checkpoint(checkpoint)
                # Only the third checkpoint exceeds the keep limit.
                expected_deletes = 0 if j != 2 else 1
                self.assertEqual(rmtree_mock.call_count, expected_deletes)
                self.assertEqual(checkpoint_manager.newest_checkpoint,
                                 checkpoint)

        best_checkpoints = checkpoint_manager.best_checkpoints()
        self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
        self.assertIn(checkpoints[1], best_checkpoints)
        self.assertIn(checkpoints[2], best_checkpoints)

    def testOnCheckpointUnordered(self):
        """
        Tests priorities that aren't inserted in ascending order. Also tests
        that the worst checkpoints are deleted when necessary.
        """
        keep_checkpoints_num = 2
        checkpoint_manager = CheckpointManager(keep_checkpoints_num, "i")
        checkpoints = [
            Checkpoint(Checkpoint.DISK, {i}, self.mock_result(i))
            for i in range(3, -1, -1)
        ]

        with patch("shutil.rmtree") as rmtree_mock, patch("os.path"):
            for j, checkpoint in enumerate(checkpoints):
                checkpoint_manager.on_checkpoint(checkpoint)
                expected_deletes = 0 if j != 3 else 1
                self.assertEqual(rmtree_mock.call_count, expected_deletes)
                self.assertEqual(checkpoint_manager.newest_checkpoint,
                                 checkpoint)

        best_checkpoints = checkpoint_manager.best_checkpoints()
        self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
        self.assertIn(checkpoints[0], best_checkpoints)
        self.assertIn(checkpoints[1], best_checkpoints)

    def testBestCheckpoints(self):
        """
        Tests that the best checkpoints are tracked and ordered correctly.
        """
        keep_checkpoints_num = 4
        checkpoint_manager = CheckpointManager(keep_checkpoints_num, "i")
        checkpoints = [
            Checkpoint(Checkpoint.MEMORY, i, self.mock_result(i))
            for i in range(16)
        ]
        random.shuffle(checkpoints)

        for checkpoint in checkpoints:
            checkpoint_manager.on_checkpoint(checkpoint)

        best_checkpoints = checkpoint_manager.best_checkpoints()
        self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
        # The four highest scores (12..15) are expected in ascending order.
        for position, best in enumerate(best_checkpoints):
            self.assertEqual(best.value, position + 12)

    def testOnCheckpointUnavailableAttribute(self):
        """
        Tests that an error is logged when the associated result of the
        checkpoint has no checkpoint score attribute.
        """
        keep_checkpoints_num = 1
        checkpoint_manager = CheckpointManager(keep_checkpoints_num, "i")

        no_attr_checkpoint = Checkpoint(Checkpoint.MEMORY, 0, {})
        with patch.object(logger, "error") as log_error_mock:
            checkpoint_manager.on_checkpoint(no_attr_checkpoint)
            # Check call_count explicitly: on older mock releases
            # assert_called_once() is an auto-created attribute that
            # passes without checking anything.
            self.assertEqual(log_error_mock.call_count, 1)
            # The newest checkpoint should still be set despite this error.
            self.assertEqual(checkpoint_manager.newest_checkpoint,
                             no_attr_checkpoint)
|
||||
@@ -336,7 +336,7 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster):
|
||||
cluster.add_node(num_cpus=1)
|
||||
cluster.remove_node(node)
|
||||
cluster.wait_for_nodes()
|
||||
shutil.rmtree(os.path.dirname(t1._checkpoint.value))
|
||||
shutil.rmtree(os.path.dirname(t1.checkpoint.value))
|
||||
|
||||
runner.step() # Recovery step
|
||||
for i in range(3):
|
||||
@@ -569,7 +569,7 @@ tune.run(
|
||||
ray.shutdown()
|
||||
cluster.shutdown()
|
||||
cluster = _start_new_cluster()
|
||||
Experiment._register_if_needed(_Mock)
|
||||
Experiment.register_if_needed(_Mock)
|
||||
|
||||
# Inspect the internal trialrunner
|
||||
runner = TrialRunner(
|
||||
|
||||
@@ -26,6 +26,7 @@ from ray.tune.result import (
|
||||
EPISODES_TOTAL, TRAINING_ITERATION, TIMESTEPS_THIS_ITER, TIME_THIS_ITER_S,
|
||||
TIME_TOTAL_S, TRIAL_ID, EXPERIMENT_TAG)
|
||||
from ray.tune.logger import Logger
|
||||
from ray.tune.syncer import CommandBasedClient
|
||||
from ray.tune.util import pin_in_object_store, get_pinned_object, flatten_dict
|
||||
from ray.tune.experiment import Experiment
|
||||
from ray.tune.trial import Trial, ExportFormat
|
||||
@@ -1124,21 +1125,20 @@ class TestSyncFunctionality(unittest.TestCase):
|
||||
"sync_to_driver": "ls {source}"
|
||||
}).trials
|
||||
|
||||
with patch("ray.tune.syncer.CommandSyncer.sync_function"
|
||||
) as mock_fn, patch(
|
||||
"ray.services.get_node_ip_address") as mock_sync:
|
||||
mock_sync.return_value = "0.0.0.0"
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
**{
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"sync_to_driver": "echo {source} {target}"
|
||||
}).trials
|
||||
self.assertGreater(mock_fn.call_count, 0)
|
||||
with patch.object(CommandBasedClient, "execute") as mock_fn:
|
||||
with patch("ray.services.get_node_ip_address") as mock_sync:
|
||||
mock_sync.return_value = "0.0.0.0"
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
**{
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"sync_to_driver": "echo {source} {target}"
|
||||
}).trials
|
||||
self.assertGreater(mock_fn.call_count, 0)
|
||||
|
||||
def testCloudFunctions(self):
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
@@ -1166,7 +1166,7 @@ class TestSyncFunctionality(unittest.TestCase):
|
||||
|
||||
def testClusterSyncFunction(self):
|
||||
def sync_func_driver(source, target):
|
||||
assert ":" in source, "Source not a remote path."
|
||||
assert ":" in source, "Source {} not a remote path.".format(source)
|
||||
assert ":" not in target, "Target is supposed to be local."
|
||||
with open(os.path.join(target, "test.log2"), "w") as f:
|
||||
print("writing to", f.name)
|
||||
@@ -1203,7 +1203,7 @@ class TestSyncFunctionality(unittest.TestCase):
|
||||
def sync_func(source, target):
|
||||
pass
|
||||
|
||||
with patch("ray.tune.syncer.CommandSyncer.sync_function") as mock_sync:
|
||||
with patch.object(CommandBasedClient, "execute") as mock_sync:
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
|
||||
@@ -103,6 +103,7 @@ class TrackSession(object):
|
||||
self.trial_config["end_time"] = datetime.now().isoformat()
|
||||
# TODO(rliaw): Have Tune support updated configs
|
||||
self._logger.update_config(self.trial_config)
|
||||
self._logger.flush()
|
||||
self._logger.close()
|
||||
|
||||
@property
|
||||
|
||||
@@ -70,7 +70,6 @@ class Trainable(object):
|
||||
|
||||
self._experiment_id = uuid.uuid4().hex
|
||||
self.config = config or {}
|
||||
log_sys_usage = self.config.get("log_sys_usage", False)
|
||||
|
||||
if logger_creator:
|
||||
self._result_logger = logger_creator(self.config)
|
||||
@@ -92,6 +91,7 @@ class Trainable(object):
|
||||
self._timesteps_since_restore = 0
|
||||
self._iterations_since_restore = 0
|
||||
self._restored = False
|
||||
|
||||
start_time = time.time()
|
||||
self._setup(copy.deepcopy(self.config))
|
||||
setup_time = time.time() - start_time
|
||||
@@ -101,6 +101,7 @@ class Trainable(object):
|
||||
"reuse_actors=True to reduce actor creation "
|
||||
"overheads.".format(setup_time))
|
||||
self._local_ip = ray.services.get_node_ip_address()
|
||||
log_sys_usage = self.config.get("log_sys_usage", False)
|
||||
self._monitor = UtilMonitor(start=log_sys_usage)
|
||||
|
||||
@classmethod
|
||||
@@ -112,11 +113,11 @@ class Trainable(object):
|
||||
|
||||
Example:
|
||||
>>> def default_resource_request(cls, config):
|
||||
return Resources(
|
||||
cpu=0,
|
||||
gpu=0,
|
||||
extra_cpu=config["workers"],
|
||||
extra_gpu=int(config["use_gpu"]) * config["workers"])
|
||||
>>> return Resources(
|
||||
>>> cpu=0,
|
||||
>>> gpu=0,
|
||||
>>> extra_cpu=config["workers"],
|
||||
>>> extra_gpu=int(config["use_gpu"]) * config["workers"])
|
||||
"""
|
||||
|
||||
return None
|
||||
@@ -171,7 +172,6 @@ class Trainable(object):
|
||||
Returns:
|
||||
A dict that describes training progress.
|
||||
"""
|
||||
|
||||
start = time.time()
|
||||
result = self._train()
|
||||
assert isinstance(result, dict), "_train() needs to return a dict."
|
||||
@@ -239,17 +239,6 @@ class Trainable(object):
|
||||
|
||||
return result
|
||||
|
||||
def delete_checkpoint(self, checkpoint_dir):
    """Removes subdirectory within checkpoint_folder

    When the given path is a regular file, the directory containing it is
    removed; otherwise the path itself is removed as a directory tree.

    Args:
        checkpoint_dir : path to checkpoint
    """
    points_at_file = os.path.isfile(checkpoint_dir)
    doomed_dir = (os.path.dirname(checkpoint_dir)
                  if points_at_file else checkpoint_dir)
    shutil.rmtree(doomed_dir)
|
||||
|
||||
def save(self, checkpoint_dir=None):
|
||||
"""Saves the current model state to a checkpoint.
|
||||
|
||||
@@ -262,9 +251,9 @@ class Trainable(object):
|
||||
Returns:
|
||||
Checkpoint path or prefix that may be passed to restore().
|
||||
"""
|
||||
|
||||
checkpoint_dir = os.path.join(checkpoint_dir or self.logdir,
|
||||
"checkpoint_{}".format(self._iteration))
|
||||
|
||||
if not os.path.exists(checkpoint_dir):
|
||||
os.makedirs(checkpoint_dir)
|
||||
checkpoint = self._save(checkpoint_dir)
|
||||
@@ -293,7 +282,7 @@ class Trainable(object):
|
||||
"time_total": self._time_total,
|
||||
"episodes_total": self._episodes_total,
|
||||
"saved_as_dict": saved_as_dict,
|
||||
"ray_version": ray.__version__
|
||||
"ray_version": ray.__version__,
|
||||
}, f)
|
||||
return checkpoint_path
|
||||
|
||||
@@ -305,10 +294,8 @@ class Trainable(object):
|
||||
Returns:
|
||||
Object holding checkpoint data.
|
||||
"""
|
||||
|
||||
tmpdir = tempfile.mkdtemp("save_to_object", dir=self.logdir)
|
||||
checkpoint_path = self.save(tmpdir)
|
||||
|
||||
# Save all files in subtree.
|
||||
data = {}
|
||||
for basedir, _, file_names in os.walk(tmpdir):
|
||||
@@ -356,7 +343,7 @@ class Trainable(object):
|
||||
self._timesteps_since_restore = 0
|
||||
self._iterations_since_restore = 0
|
||||
self._restored = True
|
||||
logger.info("Restored from checkpoint: {}".format(checkpoint_path))
|
||||
logger.info("Restored from checkpoint: %s", checkpoint_path)
|
||||
state = {
|
||||
"_iteration": self._iteration,
|
||||
"_timesteps_total": self._timesteps_total,
|
||||
@@ -398,7 +385,7 @@ class Trainable(object):
|
||||
export_dir (str): Optional dir to place the exported model.
|
||||
Defaults to self.logdir.
|
||||
|
||||
Return:
|
||||
Returns:
|
||||
A dict that maps ExportFormats to successfully exported models.
|
||||
"""
|
||||
export_dir = export_dir or self.logdir
|
||||
@@ -422,7 +409,7 @@ class Trainable(object):
|
||||
|
||||
def stop(self):
|
||||
"""Releases all resources used by this trainable."""
|
||||
|
||||
self._result_logger.flush()
|
||||
self._result_logger.close()
|
||||
self._stop()
|
||||
|
||||
|
||||
+65
-74
@@ -12,6 +12,7 @@ import tempfile
|
||||
import os
|
||||
from numbers import Number
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.checkpoint_manager import Checkpoint, CheckpointManager
|
||||
from ray.tune.logger import pretty_print, UnifiedLogger
|
||||
from ray.tune.util import flatten_dict
|
||||
# NOTE(rkn): We import ray.tune.registry here instead of importing the names we
|
||||
@@ -31,29 +32,20 @@ def date_str():
|
||||
return datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
|
||||
|
||||
class Checkpoint(object):
|
||||
"""Describes a checkpoint of trial state.
|
||||
class Location(object):
|
||||
"""Describes the location at which Trial is placed to run."""
|
||||
|
||||
Checkpoint may be saved in different storage.
|
||||
def __init__(self, hostname=None, pid=None):
|
||||
self.hostname = hostname
|
||||
self.pid = pid
|
||||
|
||||
Attributes:
|
||||
storage (str): Storage type.
|
||||
value (str): If storage==MEMORY,value is a Python object.
|
||||
If storage==DISK,value is a path points to the checkpoint in disk.
|
||||
"""
|
||||
|
||||
MEMORY = "memory"
|
||||
DISK = "disk"
|
||||
|
||||
def __init__(self, storage, value, last_result=None):
|
||||
self.storage = storage
|
||||
self.value = value
|
||||
self.last_result = last_result or {}
|
||||
|
||||
@staticmethod
|
||||
def from_object(value=None):
|
||||
"""Creates a checkpoint from a Python object."""
|
||||
return Checkpoint(Checkpoint.MEMORY, value)
|
||||
def __str__(self):
|
||||
if not self.pid:
|
||||
return ""
|
||||
elif self.hostname == os.uname()[1]:
|
||||
return "pid={}".format(self.pid)
|
||||
else:
|
||||
return "{}:{}".format(self.hostname, self.pid)
|
||||
|
||||
|
||||
class ExportFormat(object):
|
||||
@@ -108,8 +100,9 @@ class Trial(object):
|
||||
stopping_criterion=None,
|
||||
checkpoint_freq=0,
|
||||
checkpoint_at_end=False,
|
||||
sync_on_checkpoint=True,
|
||||
keep_checkpoints_num=None,
|
||||
checkpoint_score_attr="",
|
||||
checkpoint_score_attr=TRAINING_ITERATION,
|
||||
export_formats=None,
|
||||
restore_path=None,
|
||||
trial_name_creator=None,
|
||||
@@ -121,7 +114,6 @@ class Trial(object):
|
||||
The args here take the same meaning as the command line flags defined
|
||||
in ray.tune.config_parser.
|
||||
"""
|
||||
|
||||
validate_trainable(trainable_name)
|
||||
# Trial config
|
||||
self.trainable_name = trainable_name
|
||||
@@ -145,6 +137,7 @@ class Trial(object):
|
||||
"clear the `resources_per_trial` option.".format(
|
||||
trainable_cls, default_resources))
|
||||
resources = default_resources
|
||||
self.address = Location()
|
||||
self.resources = resources or Resources(cpu=1, gpu=0)
|
||||
self.stopping_criterion = stopping_criterion or {}
|
||||
self.loggers = loggers
|
||||
@@ -161,17 +154,12 @@ class Trial(object):
|
||||
# stores in memory max/min/last result for each metric by trial
|
||||
self.metric_analysis = {}
|
||||
|
||||
self.history = []
|
||||
self.keep_checkpoints_num = keep_checkpoints_num
|
||||
self._cmp_greater = not checkpoint_score_attr.startswith("min-")
|
||||
self.best_checkpoint_attr_value = -float("inf") \
|
||||
if self._cmp_greater else float("inf")
|
||||
# Strip off "min-" from checkpoint attribute
|
||||
self.checkpoint_score_attr = checkpoint_score_attr \
|
||||
if self._cmp_greater else checkpoint_score_attr[4:]
|
||||
self.sync_on_checkpoint = sync_on_checkpoint
|
||||
newest_checkpoint = Checkpoint(Checkpoint.DISK, restore_path)
|
||||
self.checkpoint_manager = CheckpointManager(keep_checkpoints_num,
|
||||
checkpoint_score_attr)
|
||||
self.checkpoint_manager.newest_checkpoint = newest_checkpoint
|
||||
|
||||
self._checkpoint = Checkpoint(
|
||||
storage=Checkpoint.DISK, value=restore_path)
|
||||
self.export_formats = export_formats
|
||||
self.status = Trial.PENDING
|
||||
self.logdir = None
|
||||
@@ -190,7 +178,7 @@ class Trial(object):
|
||||
self.extra_arg = None
|
||||
|
||||
self._nonjson_fields = [
|
||||
"_checkpoint",
|
||||
"checkpoint",
|
||||
"loggers",
|
||||
"sync_to_driver_fn",
|
||||
"results",
|
||||
@@ -201,6 +189,14 @@ class Trial(object):
|
||||
if trial_name_creator:
|
||||
self.custom_trial_name = trial_name_creator(self)
|
||||
|
||||
@property
|
||||
def node_ip(self):
|
||||
return self.address.hostname
|
||||
|
||||
@property
|
||||
def checkpoint(self):
|
||||
return self.checkpoint_manager.newest_checkpoint
|
||||
|
||||
@classmethod
|
||||
def generate_id(cls):
|
||||
return str(uuid.uuid1().hex)[:8]
|
||||
@@ -249,10 +245,14 @@ class Trial(object):
|
||||
"""
|
||||
if self.result_logger:
|
||||
self.result_logger.sync_results_to_new_location(worker_ip)
|
||||
self.set_location(Location(worker_ip))
|
||||
|
||||
def set_location(self, location):
|
||||
"""Sets the location of the trial."""
|
||||
self.address = location
|
||||
|
||||
def close_logger(self):
|
||||
"""Close logger."""
|
||||
|
||||
"""Closes logger."""
|
||||
if self.result_logger:
|
||||
self.result_logger.close()
|
||||
self.result_logger = None
|
||||
@@ -262,8 +262,9 @@ class Trial(object):
|
||||
self.num_failures += 1 # may be moved to outer scope?
|
||||
error_file = os.path.join(self.logdir,
|
||||
"error_{}.txt".format(date_str()))
|
||||
with open(error_file, "w") as f:
|
||||
f.write(error_msg)
|
||||
with open(error_file, "a+") as f:
|
||||
f.write("Failure # {}".format(self.num_failures) + "\n")
|
||||
f.write(error_msg + "\n")
|
||||
self.error_file = error_file
|
||||
self.error_msg = error_msg
|
||||
|
||||
@@ -292,21 +293,36 @@ class Trial(object):
|
||||
def should_checkpoint(self):
|
||||
"""Whether this trial is due for checkpointing."""
|
||||
result = self.last_result or {}
|
||||
|
||||
if result.get(DONE) and self.checkpoint_at_end:
|
||||
return True
|
||||
|
||||
if self.checkpoint_freq:
|
||||
return result.get(TRAINING_ITERATION,
|
||||
0) % self.checkpoint_freq == 0
|
||||
else:
|
||||
return False
|
||||
return (self.checkpoint_freq and
|
||||
result.get(TRAINING_ITERATION, 0) % self.checkpoint_freq == 0)
|
||||
|
||||
def has_checkpoint(self):
|
||||
return self._checkpoint.value is not None
|
||||
return self.checkpoint.value is not None
|
||||
|
||||
def clear_checkpoint(self):
|
||||
self._checkpoint.value = None
|
||||
self.checkpoint.value = None
|
||||
|
||||
def on_checkpoint(self, checkpoint):
|
||||
"""Hook for handling checkpoints taken by the Trainable.
|
||||
|
||||
Args:
|
||||
checkpoint (Checkpoint): Checkpoint taken.
|
||||
"""
|
||||
if self.sync_on_checkpoint and checkpoint.storage == Checkpoint.DISK:
|
||||
# Wait for any other syncs to finish. We need to sync again after
|
||||
# this to handle checkpoints taken mid-sync.
|
||||
self.result_logger.wait()
|
||||
# Force sync down and wait before tracking the new checkpoint. This
|
||||
# prevents attempts to restore from partially synced checkpoints.
|
||||
if self.result_logger.sync_down():
|
||||
self.result_logger.wait()
|
||||
else:
|
||||
logger.error(
|
||||
"Trial %s: Checkpoint sync skipped. "
|
||||
"This should not happen.", self)
|
||||
self.checkpoint_manager.on_checkpoint(checkpoint)
|
||||
|
||||
def should_recover(self):
|
||||
"""Returns whether the trial qualifies for retrying.
|
||||
@@ -327,6 +343,7 @@ class Trial(object):
|
||||
print("Result for {}:".format(self))
|
||||
print(" {}".format(pretty_print(result).replace("\n", "\n ")))
|
||||
self.last_debug = time.time()
|
||||
self.set_location(Location(result.get("node_ip"), result.get("pid")))
|
||||
self.last_result = result
|
||||
self.last_update_time = time.time()
|
||||
self.result_logger.on_result(self.last_result)
|
||||
@@ -345,28 +362,6 @@ class Trial(object):
|
||||
value, self.metric_analysis[metric]["min"])
|
||||
self.metric_analysis[metric]["last"] = value
|
||||
|
||||
def compare_checkpoints(self, attr_mean):
|
||||
"""Compares two checkpoints based on the attribute attr_mean param.
|
||||
Greater than is used by default. If command-line parameter
|
||||
checkpoint_score_attr starts with "min-" less than is used.
|
||||
|
||||
Arguments:
|
||||
attr_mean: mean of attribute value for the current checkpoint
|
||||
|
||||
Returns:
|
||||
True: when attr_mean is greater than previous checkpoint attr_mean
|
||||
and greater than function is selected
|
||||
when attr_mean is less than previous checkpoint attr_mean and
|
||||
less than function is selected
|
||||
False: when attr_mean is not in alignment with selected cmp fn
|
||||
"""
|
||||
if self._cmp_greater and attr_mean > self.best_checkpoint_attr_value:
|
||||
return True
|
||||
elif (not self._cmp_greater
|
||||
and attr_mean < self.best_checkpoint_attr_value):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_trainable_cls(self):
|
||||
return get_trainable_cls(self.trainable_name)
|
||||
|
||||
@@ -376,10 +371,6 @@ class Trial(object):
|
||||
def is_finished(self):
|
||||
return self.status in [Trial.TERMINATED, Trial.ERROR]
|
||||
|
||||
@property
|
||||
def node_ip(self):
|
||||
return self.last_result.get("node_ip")
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
@@ -407,7 +398,7 @@ class Trial(object):
|
||||
Sets RUNNING trials to PENDING, and flushes the result logger.
|
||||
Note this can only occur if the trial holds a DISK checkpoint.
|
||||
"""
|
||||
assert self._checkpoint.storage == Checkpoint.DISK, (
|
||||
assert self.checkpoint.storage == Checkpoint.DISK, (
|
||||
"Checkpoint must not be in-memory.")
|
||||
state = self.__dict__.copy()
|
||||
state["resources"] = resources_to_json(self.resources)
|
||||
|
||||
@@ -6,6 +6,7 @@ from __future__ import print_function
|
||||
import logging
|
||||
|
||||
from ray.tune.trial import Trial, Checkpoint
|
||||
from ray.tune.error import TuneError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -38,6 +39,8 @@ class TrialExecutor(object):
|
||||
trial (Trial): Trial to checkpoint.
|
||||
status (Trial.status): Status to set trial to.
|
||||
"""
|
||||
logger.debug("Trial %s: Changing status from %s to %s.", trial,
|
||||
trial.status, status)
|
||||
trial.status = status
|
||||
if status in [Trial.TERMINATED, Trial.ERROR]:
|
||||
self.try_checkpoint_metadata(trial)
|
||||
@@ -48,14 +51,16 @@ class TrialExecutor(object):
|
||||
Args:
|
||||
trial (Trial): Trial to checkpoint.
|
||||
"""
|
||||
if trial._checkpoint.storage == Checkpoint.MEMORY:
|
||||
logger.debug("Not saving data for trial w/ memory checkpoint.")
|
||||
if trial.checkpoint.storage == Checkpoint.MEMORY:
|
||||
logger.debug("Trial %s: Not saving data for memory checkpoint.",
|
||||
trial)
|
||||
return
|
||||
try:
|
||||
logger.debug("Saving trial metadata.")
|
||||
logger.debug("Trial %s: Saving trial metadata.", trial)
|
||||
self._cached_trial_state[trial.trial_id] = trial.__getstate__()
|
||||
except Exception:
|
||||
logger.exception("Error checkpointing trial metadata.")
|
||||
logger.exception("Trial %s: Error checkpointing trial metadata.",
|
||||
trial)
|
||||
|
||||
def get_checkpoints(self):
|
||||
"""Returns a copy of mapping of the trial ID to pickled metadata."""
|
||||
@@ -118,7 +123,6 @@ class TrialExecutor(object):
|
||||
|
||||
def resume_trial(self, trial):
|
||||
"""Resumes PAUSED trials. This is a blocking call."""
|
||||
|
||||
assert trial.status == Trial.PAUSED, trial.status
|
||||
self.start_trial(trial)
|
||||
|
||||
@@ -150,6 +154,27 @@ class TrialExecutor(object):
|
||||
"""A hook called after running one step of the trial event loop."""
|
||||
pass
|
||||
|
||||
def on_no_available_trials(self, trial_runner):
|
||||
if self._queue_trials:
|
||||
return
|
||||
for trial in trial_runner.get_trials():
|
||||
if trial.status == Trial.PENDING:
|
||||
if not self.has_resources(trial.resources):
|
||||
raise TuneError(
|
||||
("Insufficient cluster resources to launch trial: "
|
||||
"trial requested {} but the cluster has only {}. "
|
||||
"Pass `queue_trials=True` in "
|
||||
"ray.tune.run() or on the command "
|
||||
"line to queue trials until the cluster scales "
|
||||
"up or resources become available. {}").format(
|
||||
trial.resources.summary_string(),
|
||||
self.resource_string(),
|
||||
trial.get_trainable_cls().resource_help(
|
||||
trial.config)))
|
||||
elif trial.status == Trial.PAUSED:
|
||||
raise TuneError("There are paused trials, but no more pending "
|
||||
"trials with sufficient resources.")
|
||||
|
||||
def get_next_available_trial(self):
|
||||
"""Blocking call that waits until one result is ready.
|
||||
|
||||
@@ -188,7 +213,7 @@ class TrialExecutor(object):
|
||||
def restore(self, trial, checkpoint=None):
|
||||
"""Restores training state from a checkpoint.
|
||||
|
||||
If checkpoint is None, try to restore from trial._checkpoint.
|
||||
If checkpoint is None, try to restore from trial.checkpoint.
|
||||
If restoring fails, the trial status will be set to ERROR.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -17,8 +17,8 @@ from ray.tune import TuneError
|
||||
from ray.tune.ray_trial_executor import RayTrialExecutor
|
||||
from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE,
|
||||
SHOULD_CHECKPOINT)
|
||||
from ray.tune.syncer import get_syncer
|
||||
from ray.tune.trial import Trial, Checkpoint
|
||||
from ray.tune.syncer import get_cloud_syncer
|
||||
from ray.tune.trial import Checkpoint, Trial
|
||||
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.suggest import BasicVariantGenerator
|
||||
from ray.tune.util import warn_if_slow, flatten_dict
|
||||
@@ -126,7 +126,7 @@ class TrialRunner(object):
|
||||
global checkpoints are stored and restored from. Used
|
||||
if `resume` == REMOTE.
|
||||
resume (str|False): see `tune.py:run`.
|
||||
sync_to_cloud (func|str): see `tune.py:run`.
|
||||
sync_to_cloud (func|str): See `tune.py:run`.
|
||||
server_port (int): Port number for launching TuneServer.
|
||||
verbose (bool): Flag for verbosity. If False, trial results
|
||||
will not be output.
|
||||
@@ -158,8 +158,8 @@ class TrialRunner(object):
|
||||
os.makedirs(self._local_checkpoint_dir)
|
||||
|
||||
self._remote_checkpoint_dir = remote_checkpoint_dir
|
||||
self._syncer = get_syncer(local_checkpoint_dir, remote_checkpoint_dir,
|
||||
sync_to_cloud)
|
||||
self._syncer = get_cloud_syncer(local_checkpoint_dir,
|
||||
remote_checkpoint_dir, sync_to_cloud)
|
||||
|
||||
self._resumed = False
|
||||
|
||||
@@ -216,8 +216,7 @@ class TrialRunner(object):
|
||||
"Called resume from remote without remote directory.")
|
||||
|
||||
# Try syncing down the upload directory.
|
||||
logger.info("Downloading from {}".format(
|
||||
self._remote_checkpoint_dir))
|
||||
logger.info("Downloading from %s", self._remote_checkpoint_dir)
|
||||
self._syncer.sync_down_if_needed()
|
||||
|
||||
if not self.checkpoint_exists(self._local_checkpoint_dir):
|
||||
@@ -334,24 +333,7 @@ class TrialRunner(object):
|
||||
elif self.trial_executor.get_running_trials():
|
||||
self._process_events() # blocking
|
||||
else:
|
||||
for trial in self._trials:
|
||||
if trial.status == Trial.PENDING:
|
||||
if not self.has_resources(trial.resources):
|
||||
raise TuneError(
|
||||
("Insufficient cluster resources to launch trial: "
|
||||
"trial requested {} but the cluster has only {}. "
|
||||
"Pass `queue_trials=True` in "
|
||||
"ray.tune.run() or on the command "
|
||||
"line to queue trials until the cluster scales "
|
||||
"up. {}").format(
|
||||
trial.resources.summary_string(),
|
||||
self.trial_executor.resource_string(),
|
||||
trial.get_trainable_cls().resource_help(
|
||||
trial.config)))
|
||||
elif trial.status == Trial.PAUSED:
|
||||
raise TuneError(
|
||||
"There are paused trials, but no more pending "
|
||||
"trials with sufficient resources.")
|
||||
self.trial_executor.on_no_available_trials(self)
|
||||
|
||||
try:
|
||||
with warn_if_slow("experiment_checkpoint"):
|
||||
@@ -422,12 +404,12 @@ class TrialRunner(object):
|
||||
def _process_events(self):
|
||||
failed_trial = self.trial_executor.get_next_failed_trial()
|
||||
if failed_trial:
|
||||
error_msg = (
|
||||
"{} (IP: {}) detected as stale. This is likely because the "
|
||||
"node was lost").format(failed_trial, failed_trial.node_ip)
|
||||
logger.info(error_msg)
|
||||
with warn_if_slow("process_failed_trial"):
|
||||
self._process_trial_failure(
|
||||
failed_trial,
|
||||
error_msg="{} (ip: {}) detected as stale. This is likely"
|
||||
"because the node was lost".format(failed_trial,
|
||||
failed_trial.node_ip))
|
||||
self._process_trial_failure(failed_trial, error_msg=error_msg)
|
||||
else:
|
||||
trial = self.trial_executor.get_next_available_trial() # blocking
|
||||
with warn_if_slow("process_trial"):
|
||||
@@ -537,12 +519,17 @@ class TrialRunner(object):
|
||||
stop_logger=False)
|
||||
trial.result_logger.flush()
|
||||
if self.trial_executor.has_resources(trial.resources):
|
||||
logger.info("Attempting to recover trial.")
|
||||
logger.info(
|
||||
"Trial %s: Attempting to recover "
|
||||
"trial state from last checkpoint.", trial)
|
||||
self.trial_executor.start_trial(trial)
|
||||
if trial.status == Trial.ERROR:
|
||||
logger.error("Trial %s: Did not start correctly.", trial)
|
||||
raise RuntimeError("Trial did not start correctly.")
|
||||
logger.debug("Trial %s: Started correctly.", trial)
|
||||
else:
|
||||
logger.debug("Notifying Scheduler and requeueing trial.")
|
||||
logger.debug("Trial %s: Notifying Scheduler and requeueing.",
|
||||
trial)
|
||||
self._requeue_trial(trial)
|
||||
except Exception:
|
||||
logger.exception("Error recovering trial from checkpoint, abort.")
|
||||
|
||||
+14
-7
@@ -69,6 +69,7 @@ def run(run_or_experiment,
|
||||
sync_to_driver=None,
|
||||
checkpoint_freq=0,
|
||||
checkpoint_at_end=False,
|
||||
sync_on_checkpoint=True,
|
||||
keep_checkpoints_num=None,
|
||||
checkpoint_score_attr=None,
|
||||
global_checkpoint_period=10,
|
||||
@@ -118,18 +119,18 @@ def run(run_or_experiment,
|
||||
`num_samples` of times.
|
||||
local_dir (str): Local dir to save training results to.
|
||||
Defaults to ``~/ray_results``.
|
||||
upload_dir (str): Optional URI to sync training results
|
||||
to (e.g. ``s3://bucket``).
|
||||
upload_dir (str): Optional URI to sync training results to
|
||||
(e.g. ``s3://bucket``).
|
||||
trial_name_creator (func): Optional function for generating
|
||||
the trial string representation.
|
||||
loggers (list): List of logger creators to be used with
|
||||
each Trial. If None, defaults to ray.tune.logger.DEFAULT_LOGGERS.
|
||||
See `ray/tune/logger.py`.
|
||||
sync_to_cloud (func|str): Function for syncing the local_dir to and
|
||||
from upload_dir. If string, then it must be a string template
|
||||
that includes `{source}` and `{target}` for the syncer to run.
|
||||
If not provided, the sync command defaults to standard
|
||||
S3 or gsutil sync commands.
|
||||
from upload_dir. If string, then it must be a string template that
|
||||
includes `{source}` and `{target}` for the syncer to run. If not
|
||||
provided, the sync command defaults to standard S3 or gsutil sync
|
||||
commands.
|
||||
sync_to_driver (func|str): Function for syncing trial logdir from
|
||||
remote node to local. If string, then it must be a string template
|
||||
that includes `{source}` and `{target}` for the syncer to run.
|
||||
@@ -138,6 +139,11 @@ def run(run_or_experiment,
|
||||
checkpoints. A value of 0 (default) disables checkpointing.
|
||||
checkpoint_at_end (bool): Whether to checkpoint at the end of the
|
||||
experiment regardless of the checkpoint_freq. Default is False.
|
||||
sync_on_checkpoint (bool): Force sync-down of trial checkpoint, to
|
||||
guarantee recoverability. If set to False, checkpoint syncing from
|
||||
worker to driver is asynchronous. Set this to False only if
|
||||
synchronous checkpointing is too slow and trial restoration
|
||||
failures can be tolerated. Defaults to True.
|
||||
keep_checkpoints_num (int): Number of checkpoints to keep. A value of
|
||||
`None` keeps all checkpoints. Defaults to `None`. If set, need
|
||||
to provide `checkpoint_score_attr`.
|
||||
@@ -223,7 +229,7 @@ def run(run_or_experiment,
|
||||
"not work with certain features.")
|
||||
for i, exp in enumerate(experiments):
|
||||
if not isinstance(exp, Experiment):
|
||||
run_identifier = Experiment._register_if_needed(exp)
|
||||
run_identifier = Experiment.register_if_needed(exp)
|
||||
experiments[i] = Experiment(
|
||||
name=name,
|
||||
run=run_identifier,
|
||||
@@ -238,6 +244,7 @@ def run(run_or_experiment,
|
||||
loggers=loggers,
|
||||
checkpoint_freq=checkpoint_freq,
|
||||
checkpoint_at_end=checkpoint_at_end,
|
||||
sync_on_checkpoint=sync_on_checkpoint,
|
||||
keep_checkpoints_num=keep_checkpoints_num,
|
||||
checkpoint_score_attr=checkpoint_score_attr,
|
||||
export_formats=export_formats,
|
||||
|
||||
+12
-5
@@ -119,18 +119,25 @@ class warn_if_slow(object):
|
||||
... ray.get(something)
|
||||
"""
|
||||
|
||||
def __init__(self, name):
|
||||
DEFAULT_THRESHOLD = 0.5
|
||||
|
||||
def __init__(self, name, threshold=None):
|
||||
self.name = name
|
||||
self.threshold = threshold or self.DEFAULT_THRESHOLD
|
||||
self.too_slow = False
|
||||
|
||||
def __enter__(self):
|
||||
self.start = time.time()
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
now = time.time()
|
||||
if now - self.start > 0.5 and now - START_OF_TIME > 60.0:
|
||||
logger.warning("The `{}` operation took {} seconds to complete, ".
|
||||
format(self.name, now - self.start) +
|
||||
"which may be a performance bottleneck.")
|
||||
if now - self.start > self.threshold and now - START_OF_TIME > 60.0:
|
||||
self.too_slow = True
|
||||
logger.warning(
|
||||
"The `%s` operation took %s seconds to complete, "
|
||||
"which may be a performance bottleneck.", self.name,
|
||||
now - self.start)
|
||||
|
||||
|
||||
def merge_dicts(d1, d2):
|
||||
|
||||
@@ -311,10 +311,6 @@ class Worker(object):
|
||||
"which is not an ray.ObjectID.".format(object_id))
|
||||
|
||||
if self.mode == LOCAL_MODE:
|
||||
# TODO(ujvl): Remove check when local mode moved to core worker.
|
||||
if timeout is not None:
|
||||
raise ValueError(
|
||||
"`get` must be called with timeout=None in local mode.")
|
||||
return self.local_mode_manager.get_objects(object_ids)
|
||||
|
||||
timeout_ms = int(timeout * 1000) if timeout else -1
|
||||
@@ -1407,8 +1403,8 @@ def get(object_ids, timeout=None):
|
||||
Args:
|
||||
object_ids: Object ID of the object to get or a list of object IDs to
|
||||
get.
|
||||
timeout (float): The maximum amount of time in seconds to wait before
|
||||
returning.
|
||||
timeout (Optional[float]): The maximum amount of time in seconds to
|
||||
wait before returning.
|
||||
|
||||
Returns:
|
||||
A Python object or a list of Python objects.
|
||||
|
||||
Reference in New Issue
Block a user