mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 05:05:21 +08:00
[tune] Async restores and S3/GCP-capable trial FT (#6376)
* Initial commit for asynchronous save/restore * Set stage for cloud checkpointable trainable. * Refactor log_sync and sync_client. * Add durable trainable impl. * Support delete in cmd based client * Fix some tests and such * Cleanup, comments. * Use upload_dir instead. * Revert files belonging to other PR in split. * Pass upload_dir into trainable init. * Pickle checkpoint at driver, more robust checkpoint_dir discovery. * Cleanup trainable helper functions, fix tests. * Addressed comments. * Fix bugs from cluster testing, add parameterized cluster tests. * Add trainable util test * package_ref * pbt_address * Fix bug after running pbt example (_save returning dir). * get cluster tests running, other bug fixes. * raise_errors * Fix deleter bug, add durable trainable example. * Fix cluster test bugs. * filelock * save/restore bug fixes * . * Working cluster tests. * Lint, revert to tracking memory checkpoints. * Documentation, cleanup * fixinitialsync * fix_one_test * Fix cluster test bug * nit * lint * Revert tune md change * Fix basename bug for directories. * lint * fix_tests * nit_fix * Add __init__ file. * Move to utils package Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
committed by
Richard Liaw
parent
57061a15cf
commit
ca651af1d7
@@ -148,7 +148,14 @@ py_test(
|
||||
deps = [":tune_lib"],
|
||||
tags = ["exclusive"],
|
||||
)
|
||||
|
||||
|
||||
py_test(
|
||||
name = "test_trainable_util",
|
||||
size = "small",
|
||||
srcs = ["tests/test_trainable_util.py"],
|
||||
deps = [":tune_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_trial_scheduler",
|
||||
size = "medium",
|
||||
|
||||
@@ -8,12 +8,14 @@ from ray.tune.experiment import Experiment
|
||||
from ray.tune.analysis import ExperimentAnalysis, Analysis
|
||||
from ray.tune.registry import register_env, register_trainable
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.durable_trainable import DurableTrainable
|
||||
from ray.tune.suggest import grid_search
|
||||
from ray.tune.sample import (function, sample_from, uniform, choice, randint,
|
||||
randn, loguniform)
|
||||
|
||||
__all__ = [
|
||||
"Trainable",
|
||||
"DurableTrainable",
|
||||
"TuneError",
|
||||
"grid_search",
|
||||
"register_env",
|
||||
|
||||
@@ -5,8 +5,6 @@ from __future__ import print_function
|
||||
|
||||
import heapq
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
|
||||
try:
|
||||
FileNotFoundError
|
||||
@@ -23,31 +21,18 @@ class Checkpoint:
|
||||
|
||||
Attributes:
|
||||
storage (str): Storage type.
|
||||
value (str): If storage==MEMORY, value is a Python object.
|
||||
If storage==DISK, value is a path points to the checkpoint in disk.
|
||||
value (str): If storage==MEMORY, it is a Python object.
|
||||
If storage==PERSISTENT, it is a path to persistent storage.
|
||||
"""
|
||||
|
||||
MEMORY = "memory"
|
||||
DISK = "disk"
|
||||
PERSISTENT = "persistent"
|
||||
|
||||
def __init__(self, storage, value, result=None):
|
||||
self.storage = storage
|
||||
self.value = value
|
||||
self.result = result or {}
|
||||
|
||||
def delete(self):
|
||||
"""Deletes checkpoint data if disk checkpoint."""
|
||||
if self.storage == Checkpoint.DISK and self.value:
|
||||
checkpoint_dir = self.value
|
||||
if not os.path.exists(checkpoint_dir):
|
||||
raise FileNotFoundError(
|
||||
"Attempted to delete checkpoint at {} but "
|
||||
"path was not found.".format(checkpoint_dir))
|
||||
elif os.path.isfile(checkpoint_dir):
|
||||
shutil.rmtree(os.path.dirname(checkpoint_dir))
|
||||
else:
|
||||
shutil.rmtree(checkpoint_dir)
|
||||
|
||||
@staticmethod
|
||||
def from_object(value=None):
|
||||
"""Creates a checkpoint from a Python object."""
|
||||
@@ -72,13 +57,15 @@ class QueueItem:
|
||||
class CheckpointManager:
|
||||
"""Manages checkpoints on the driver for a trial."""
|
||||
|
||||
def __init__(self, keep_checkpoints_num, checkpoint_score_attr):
|
||||
def __init__(self, keep_checkpoints_num, checkpoint_score_attr, delete_fn):
|
||||
"""Initializes a new CheckpointManager.
|
||||
|
||||
Args:
|
||||
keep_checkpoints_num (int): Keep at least this many checkpoints.
|
||||
checkpoint_score_attr (str): Attribute to use to determine which
|
||||
checkpoints to keep.
|
||||
delete_fn (function): Function that deletes checkpoints. Must be
|
||||
idempotent.
|
||||
"""
|
||||
self.keep_checkpoints_num = keep_checkpoints_num or float("inf")
|
||||
assert self.keep_checkpoints_num > 0, (
|
||||
@@ -88,7 +75,7 @@ class CheckpointManager:
|
||||
self._checkpoint_score_attr = checkpoint_score_attr[4:]
|
||||
else:
|
||||
self._checkpoint_score_attr = checkpoint_score_attr
|
||||
|
||||
self.delete = delete_fn
|
||||
self.newest_checkpoint = Checkpoint(Checkpoint.MEMORY, None)
|
||||
self._best_checkpoints = []
|
||||
self._membership = set()
|
||||
@@ -101,9 +88,6 @@ class CheckpointManager:
|
||||
|
||||
Args:
|
||||
checkpoint (Checkpoint): Trial state checkpoint.
|
||||
|
||||
Raises:
|
||||
KeyError if checkpoint_score_attr not in result of checkpoint.
|
||||
"""
|
||||
old_checkpoint = self.newest_checkpoint
|
||||
self.newest_checkpoint = checkpoint
|
||||
@@ -112,7 +96,7 @@ class CheckpointManager:
|
||||
queue_item = QueueItem(self._priority(checkpoint), checkpoint)
|
||||
except KeyError:
|
||||
if old_checkpoint not in self._membership:
|
||||
old_checkpoint.delete()
|
||||
self.delete(old_checkpoint)
|
||||
logger.error("Result dict has no key: {}. "
|
||||
"checkpoint_score_attr must be set to a key in the "
|
||||
"result dict.".format(self._checkpoint_score_attr))
|
||||
@@ -126,11 +110,11 @@ class CheckpointManager:
|
||||
self._membership.add(checkpoint)
|
||||
if worst in self._membership:
|
||||
self._membership.remove(worst)
|
||||
worst.delete()
|
||||
self.delete(worst)
|
||||
|
||||
# Remove the old checkpoint if it isn't one of the best ones.
|
||||
if old_checkpoint not in self._membership:
|
||||
old_checkpoint.delete()
|
||||
if old_checkpoint.value and old_checkpoint not in self._membership:
|
||||
self.delete(old_checkpoint)
|
||||
|
||||
def best_checkpoints(self):
|
||||
"""Returns best checkpoints, sorted by score."""
|
||||
@@ -140,3 +124,13 @@ class CheckpointManager:
|
||||
def _priority(self, checkpoint):
|
||||
priority = checkpoint.result[self._checkpoint_score_attr]
|
||||
return -priority if self._checkpoint_score_desc else priority
|
||||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
# Avoid serializing lambda since it may capture cyclical dependencies.
|
||||
state.pop("delete")
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__.update(state)
|
||||
self.delete = None
|
||||
|
||||
@@ -183,6 +183,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
|
||||
local_dir=os.path.join(spec["local_dir"], output_path),
|
||||
# json.load leads to str -> unicode in py2.7
|
||||
stopping_criterion=spec.get("stop", {}),
|
||||
remote_checkpoint_dir=spec.get("remote_checkpoint_dir"),
|
||||
checkpoint_freq=args.checkpoint_freq,
|
||||
checkpoint_at_end=args.checkpoint_at_end,
|
||||
sync_on_checkpoint=not args.no_sync_on_checkpoint,
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from ray.tune.trainable import Trainable, TrainableUtil
|
||||
from ray.tune.syncer import get_cloud_sync_client
|
||||
|
||||
|
||||
class DurableTrainable(Trainable):
|
||||
"""Abstract class for a remote-storage backed fault-tolerant Trainable.
|
||||
|
||||
Supports checkpointing to and restoring from remote storage. To use this
|
||||
class, implement the same private methods as ray.tune.Trainable (`_save`,
|
||||
`_train`, `_restore`, `reset_config`, `_setup`, `_stop`).
|
||||
|
||||
.. warning:: This class is currently **experimental** and may
|
||||
be subject to change.
|
||||
|
||||
Run this with Tune as follows. Setting `sync_to_driver=False` disables
|
||||
syncing to the driver to avoid keeping redundant checkpoints around, as
|
||||
well as preventing the driver from syncing up the same checkpoint.
|
||||
|
||||
See ``tune/trainable.py``.
|
||||
|
||||
Attributes:
|
||||
remote_checkpoint_dir (str): Upload directory (S3 or GS path).
|
||||
storage_client: Tune-internal interface for interacting with external
|
||||
storage.
|
||||
|
||||
>>> tune.run(MyDurableTrainable, sync_to_driver=False)
|
||||
"""
|
||||
|
||||
def __init__(self, remote_checkpoint_dir, *args, **kwargs):
|
||||
"""Initializes a DurableTrainable.
|
||||
|
||||
Args:
|
||||
remote_checkpoint_dir (str): Upload directory (S3 or GS path).
|
||||
"""
|
||||
super(DurableTrainable, self).__init__(*args, **kwargs)
|
||||
self.remote_checkpoint_dir = remote_checkpoint_dir
|
||||
self.storage_client = self._create_storage_client()
|
||||
|
||||
def save(self, checkpoint_dir=None):
|
||||
"""Saves the current model state to a checkpoint, persisted remotely.
|
||||
|
||||
The storage client must provide durability for
|
||||
restoration to work. That is, once ``storage.client.wait()``
|
||||
returns after a checkpoint `sync up`, the checkpoint is considered
|
||||
committed and can be used to restore the trainable.
|
||||
|
||||
Args:
|
||||
checkpoint_dir (Optional[str]): Optional dir to place the
|
||||
checkpoint. Must be ``logdir`` or a sub-directory.
|
||||
|
||||
Returns:
|
||||
Checkpoint path or prefix that may be passed to restore().
|
||||
"""
|
||||
if checkpoint_dir:
|
||||
if checkpoint_dir.starts_with(os.path.abspath(self.logdir)):
|
||||
raise ValueError("`checkpoint_dir` must be `self.logdir`, or "
|
||||
"a sub-directory.")
|
||||
|
||||
checkpoint_path = super(DurableTrainable, self).save(checkpoint_dir)
|
||||
self.storage_client.sync_up(self.logdir, self.remote_checkpoint_dir)
|
||||
self.storage_client.wait()
|
||||
return checkpoint_path
|
||||
|
||||
def restore(self, checkpoint_path):
|
||||
"""Restores training state from a given checkpoint persisted remotely.
|
||||
|
||||
These checkpoints are returned from calls to save().
|
||||
|
||||
Args:
|
||||
checkpoint_path (str): Local path to checkpoint.
|
||||
"""
|
||||
self.storage_client.sync_down(self.remote_checkpoint_dir, self.logdir)
|
||||
self.storage_client.wait()
|
||||
super(DurableTrainable, self).restore(checkpoint_path)
|
||||
|
||||
def delete_checkpoint(self, checkpoint_path):
|
||||
"""Deletes checkpoint from both local and remote storage.
|
||||
|
||||
Args:
|
||||
checkpoint_path (str): Local path to checkpoint.
|
||||
"""
|
||||
super(DurableTrainable, self).delete_checkpoint(checkpoint_path)
|
||||
local_dirpath = TrainableUtil.find_checkpoint_dir(checkpoint_path)
|
||||
self.storage_client.delete(self._storage_path(local_dirpath))
|
||||
|
||||
def _create_storage_client(self):
|
||||
"""Returns a storage client."""
|
||||
return get_cloud_sync_client(self.remote_checkpoint_dir)
|
||||
|
||||
def _storage_path(self, local_path):
|
||||
rel_local_path = os.path.relpath(local_path, self.logdir)
|
||||
return os.path.join(self.remote_checkpoint_dir, rel_local_path)
|
||||
@@ -0,0 +1,126 @@
|
||||
import argparse
|
||||
import numpy as np
|
||||
import time
|
||||
import logging
|
||||
import os
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune import DurableTrainable
|
||||
from ray.tune.sync_client import get_sync_client
|
||||
|
||||
import cloudpickle
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MockDurableTrainable(DurableTrainable):
|
||||
"""Mocks the storage client on initialization to store data locally."""
|
||||
|
||||
def __init__(self, remote_checkpoint_dir, *args, **kwargs):
|
||||
# Mock the path as a local path.
|
||||
local_dir_suffix = remote_checkpoint_dir.split("://")[1]
|
||||
remote_checkpoint_dir = os.path.join("/tmp", local_dir_suffix)
|
||||
# Disallow malformed relative paths for delete safety.
|
||||
assert os.path.abspath(remote_checkpoint_dir).startswith("/tmp")
|
||||
logger.info("Using %s as the mocked remote checkpoint directory.",
|
||||
self.remote_checkpoint_dir)
|
||||
super(MockDurableTrainable, self).__init__(remote_checkpoint_dir,
|
||||
*args, **kwargs)
|
||||
|
||||
def _create_storage_client(self):
|
||||
sync = "mkdir -p {target} && rsync -avz {source} {target}"
|
||||
delete = "rm -rf {target}"
|
||||
return get_sync_client(sync, delete)
|
||||
|
||||
|
||||
class OptimusFn(object):
|
||||
def __init__(self, params, max_t=10000):
|
||||
self.params = params
|
||||
self.noise = np.random.normal(size=max_t) * 0.005
|
||||
|
||||
def eval(self, k, add_noise=True):
|
||||
b0, b1, b2 = self.params
|
||||
score = (b0 * k / 100 + 0.1 * b1 + 0.5)**(-1) + b2 * 0.01
|
||||
if add_noise:
|
||||
return score + abs(self.noise[k])
|
||||
else:
|
||||
return score
|
||||
|
||||
|
||||
def get_optimus_trainable(parent_cls):
|
||||
class OptimusTrainable(parent_cls):
|
||||
def _setup(self, config):
|
||||
self.iter = 0
|
||||
if config.get("seed"):
|
||||
np.random.seed(config["seed"])
|
||||
time.sleep(config.get("startup_delay", 0))
|
||||
params = [config["param1"], config["param2"], config["param3"]]
|
||||
self.func = OptimusFn(params=params)
|
||||
self.initial_samples_per_step = 500
|
||||
self.mock_data = open("/dev/urandom", "rb").read(1024)
|
||||
|
||||
def _train(self):
|
||||
self.iter += 1
|
||||
new_loss = self.func.eval(self.iter)
|
||||
time.sleep(0.5)
|
||||
return {
|
||||
"mean_loss": float(new_loss),
|
||||
"mean_accuracy": (2 - new_loss) / 2,
|
||||
"samples": self.initial_samples_per_step
|
||||
}
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
time.sleep(0.5)
|
||||
return {
|
||||
"func": cloudpickle.dumps(self.func),
|
||||
"seed": np.random.get_state(),
|
||||
"data": self.mock_data,
|
||||
"iter": self.iter
|
||||
}
|
||||
|
||||
def _restore(self, checkpoint):
|
||||
self.func = cloudpickle.loads(checkpoint["func"])
|
||||
self.data = checkpoint["data"]
|
||||
self.iter = checkpoint["iter"]
|
||||
np.random.set_state(checkpoint["seed"])
|
||||
|
||||
return OptimusTrainable
|
||||
|
||||
|
||||
def parse():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--local", action="store_true", default=False)
|
||||
parser.add_argument("--mock-storage", action="store_true", default=False)
|
||||
parser.add_argument("--remote-dir", type=str)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse()
|
||||
address = None if args.local else "auto"
|
||||
ray.init(address=address)
|
||||
|
||||
config = {
|
||||
"seed": None,
|
||||
"startup_delay": 0.001,
|
||||
"param1": tune.sample_from(lambda spec: np.random.exponential(0.1)),
|
||||
"param2": tune.sample_from(lambda _: np.random.rand()),
|
||||
"param3": tune.sample_from(lambda _: np.random.rand()),
|
||||
}
|
||||
|
||||
parent = MockDurableTrainable if args.mock_storage else DurableTrainable
|
||||
analysis = tune.run(
|
||||
get_optimus_trainable(parent),
|
||||
name="durableTrainable" + str(time.time()),
|
||||
config=config,
|
||||
num_samples=4,
|
||||
verbose=1,
|
||||
queue_trials=True,
|
||||
# fault tolerance parameters
|
||||
max_failures=-1,
|
||||
checkpoint_freq=20,
|
||||
sync_to_driver=False,
|
||||
sync_on_checkpoint=False,
|
||||
upload_dir="s3://ray-tune-test/exps/",
|
||||
checkpoint_score_attr="training_iteration",
|
||||
)
|
||||
@@ -11,6 +11,7 @@ from ray.tune.trial import ExportFormat
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from filelock import FileLock
|
||||
import random
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -248,7 +249,8 @@ class PytorchTrainable(tune.Trainable):
|
||||
self.netG.parameters(),
|
||||
lr=config.get("lr", 0.01),
|
||||
betas=(beta1, 0.999))
|
||||
self.dataloader = get_data_loader()
|
||||
with FileLock(os.path.expanduser("~/.data.lock")):
|
||||
self.dataloader = get_data_loader()
|
||||
|
||||
def _train(self):
|
||||
lossG, lossD, is_score = train(
|
||||
|
||||
@@ -10,7 +10,7 @@ import six
|
||||
import types
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.registry import register_trainable
|
||||
from ray.tune.registry import register_trainable, get_trainable_cls
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
from ray.tune.sample import sample_from
|
||||
|
||||
@@ -34,6 +34,22 @@ def _raise_deprecation_note(deprecated, replacement, soft=False):
|
||||
raise DeprecationWarning(error_msg)
|
||||
|
||||
|
||||
def _raise_on_durable(trainable_name, sync_to_driver, upload_dir):
|
||||
trainable_cls = get_trainable_cls(trainable_name)
|
||||
from ray.tune.durable_trainable import DurableTrainable
|
||||
if issubclass(trainable_cls, DurableTrainable):
|
||||
if sync_to_driver is not False:
|
||||
raise ValueError(
|
||||
"EXPERIMENTAL: DurableTrainable will automatically sync "
|
||||
"results to the provided upload_dir. "
|
||||
"Set `sync_to_driver=False` to avoid data inconsistencies.")
|
||||
if not upload_dir:
|
||||
raise ValueError(
|
||||
"EXPERIMENTAL: DurableTrainable will automatically sync "
|
||||
"results to the provided upload_dir. "
|
||||
"`upload_dir` must be provided.")
|
||||
|
||||
|
||||
class Experiment:
|
||||
"""Tracks experiment specifications.
|
||||
|
||||
@@ -109,6 +125,14 @@ class Experiment:
|
||||
|
||||
config = config or {}
|
||||
self._run_identifier = Experiment.register_if_needed(run)
|
||||
self.name = name or self._run_identifier
|
||||
if upload_dir:
|
||||
self.remote_checkpoint_dir = os.path.join(upload_dir, self.name)
|
||||
else:
|
||||
self.remote_checkpoint_dir = None
|
||||
|
||||
_raise_on_durable(self._run_identifier, sync_to_driver, upload_dir)
|
||||
|
||||
spec = {
|
||||
"run": self._run_identifier,
|
||||
"stop": stop,
|
||||
@@ -118,6 +142,7 @@ class Experiment:
|
||||
"local_dir": os.path.abspath(
|
||||
os.path.expanduser(local_dir or DEFAULT_RESULTS_DIR)),
|
||||
"upload_dir": upload_dir,
|
||||
"remote_checkpoint_dir": self.remote_checkpoint_dir,
|
||||
"trial_name_creator": trial_name_creator,
|
||||
"loggers": loggers,
|
||||
"sync_to_driver": sync_to_driver,
|
||||
@@ -131,8 +156,6 @@ class Experiment:
|
||||
"restore": os.path.abspath(os.path.expanduser(restore))
|
||||
if restore else None
|
||||
}
|
||||
|
||||
self.name = name or self._run_identifier
|
||||
self.spec = spec
|
||||
|
||||
@classmethod
|
||||
@@ -204,11 +227,6 @@ class Experiment:
|
||||
if self.local_dir:
|
||||
return os.path.join(self.local_dir, self.name)
|
||||
|
||||
@property
|
||||
def remote_checkpoint_dir(self):
|
||||
if self.spec["upload_dir"]:
|
||||
return os.path.join(self.spec["upload_dir"], self.name)
|
||||
|
||||
@property
|
||||
def run_identifier(self):
|
||||
"""Returns a string representing the trainable identifier."""
|
||||
|
||||
@@ -430,11 +430,13 @@ class UnifiedLogger(Logger):
|
||||
for _logger in self._loggers:
|
||||
_logger.close()
|
||||
|
||||
def flush(self):
|
||||
def flush(self, sync_down=True):
|
||||
for _logger in self._loggers:
|
||||
_logger.flush()
|
||||
if not self._log_syncer.sync_down():
|
||||
logger.warning("Trial %s: Post-flush sync skipped.", self.trial)
|
||||
if sync_down:
|
||||
if not self._log_syncer.sync_down():
|
||||
logger.warning("Trial %s: Post-flush sync skipped.",
|
||||
self.trial)
|
||||
|
||||
def sync_up(self):
|
||||
return self._log_syncer.sync_up()
|
||||
|
||||
@@ -238,7 +238,7 @@ def _get_trial_info(trial, parameters, metrics):
|
||||
metrics (List[str]): Names of metrics to include.
|
||||
"""
|
||||
result = flatten_dict(trial.last_result)
|
||||
trial_info = [str(trial), trial.status, str(trial.address)]
|
||||
trial_info = [str(trial), trial.status, str(trial.location)]
|
||||
trial_info += [result.get(CONFIG_PREFIX + param) for param in parameters]
|
||||
trial_info += [result.get(metric) for metric in metrics]
|
||||
return trial_info
|
||||
|
||||
@@ -13,10 +13,12 @@ import ray
|
||||
from ray.exceptions import RayTimeoutError
|
||||
from ray import ray_constants
|
||||
from ray.resource_spec import ResourceSpec
|
||||
from ray.tune.durable_trainable import DurableTrainable
|
||||
from ray.tune.error import AbortTrialExecution
|
||||
from ray.tune.logger import NoopLogger
|
||||
from ray.tune.trial import Trial, Checkpoint, Location
|
||||
from ray.tune.resources import Resources
|
||||
from ray.tune.trainable import TrainableUtil
|
||||
from ray.tune.trial import Trial, Checkpoint, Location
|
||||
from ray.tune.trial_executor import TrialExecutor
|
||||
from ray.tune.util import warn_if_slow
|
||||
from ray.tune.error import TuneError
|
||||
@@ -27,7 +29,6 @@ RESOURCE_REFRESH_PERIOD = 0.5 # Refresh resources every 500 ms
|
||||
BOTTLENECK_WARN_PERIOD_S = 60
|
||||
NONTRIVIAL_WAIT_TIME_THRESHOLD_S = 1e-3
|
||||
DEFAULT_GET_TIMEOUT = 30.0 # seconds
|
||||
TRIAL_START_ATTEMPTS = 3
|
||||
|
||||
|
||||
class _LocalWrapper:
|
||||
@@ -86,7 +87,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||
self._cached_actor)
|
||||
existing_runner = self._cached_actor
|
||||
self._cached_actor = None
|
||||
trial.runner = existing_runner
|
||||
trial.set_runner(existing_runner)
|
||||
if not self.reset_trial(trial, trial.config, trial.experiment_tag):
|
||||
raise AbortTrialExecution(
|
||||
"Trainable runner reuse requires reset_config() to be "
|
||||
@@ -122,11 +123,16 @@ class RayTrialExecutor(TrialExecutor):
|
||||
logger.debug("Trial %s: Setting up new remote runner.", trial)
|
||||
# Logging for trials is handled centrally by TrialRunner, so
|
||||
# configure the remote runner to use a noop-logger.
|
||||
return cls.remote(config=trial.config, logger_creator=logger_creator)
|
||||
kwargs = {
|
||||
"config": trial.config,
|
||||
"logger_creator": logger_creator,
|
||||
}
|
||||
if issubclass(trial.get_trainable_cls(), DurableTrainable):
|
||||
kwargs["remote_checkpoint_dir"] = trial.remote_checkpoint_dir
|
||||
return cls.remote(**kwargs)
|
||||
|
||||
def _train(self, trial):
|
||||
"""Start one iteration of training and save remote id."""
|
||||
|
||||
if self._find_item(self._paused, trial):
|
||||
raise TuneError(
|
||||
"Should not call `train` on PAUSED trial {}. "
|
||||
@@ -168,9 +174,11 @@ class RayTrialExecutor(TrialExecutor):
|
||||
"""
|
||||
prior_status = trial.status
|
||||
self.set_status(trial, Trial.RUNNING)
|
||||
trial.runner = runner or self._setup_remote_runner(
|
||||
trial,
|
||||
reuse_allowed=checkpoint is not None or trial.has_checkpoint())
|
||||
trial.set_runner(
|
||||
runner or self._setup_remote_runner(
|
||||
trial,
|
||||
reuse_allowed=checkpoint is not None
|
||||
or trial.has_checkpoint()))
|
||||
self.restore(trial, checkpoint)
|
||||
|
||||
previous_run = self._find_item(self._paused, trial)
|
||||
@@ -178,7 +186,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||
# If Trial was in flight when paused, self._paused stores result.
|
||||
self._paused.pop(previous_run[0])
|
||||
self._running[previous_run[0]] = trial
|
||||
else:
|
||||
elif not trial.is_restoring:
|
||||
self._train(trial)
|
||||
|
||||
def _stop_trial(self, trial, error=False, error_msg=None,
|
||||
@@ -194,7 +202,6 @@ class RayTrialExecutor(TrialExecutor):
|
||||
error_msg (str): Optional error message.
|
||||
stop_logger (bool): Whether to shut down the trial logger.
|
||||
"""
|
||||
|
||||
if stop_logger:
|
||||
trial.close_logger()
|
||||
|
||||
@@ -206,7 +213,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||
if hasattr(trial, "runner") and trial.runner:
|
||||
if (not error and self._reuse_actors
|
||||
and self._cached_actor is None):
|
||||
logger.debug("Reusing actor for {}".format(trial.runner))
|
||||
logger.debug("Reusing actor for %s", trial.runner)
|
||||
self._cached_actor = trial.runner
|
||||
else:
|
||||
logger.debug("Trial %s: Destroying actor.", trial)
|
||||
@@ -216,7 +223,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||
logger.exception("Trial %s: Error stopping runner.", trial)
|
||||
self.set_status(trial, Trial.ERROR)
|
||||
finally:
|
||||
trial.runner = None
|
||||
trial.set_runner(None)
|
||||
|
||||
def start_trial(self, trial, checkpoint=None):
|
||||
"""Starts the trial.
|
||||
@@ -229,49 +236,22 @@ class RayTrialExecutor(TrialExecutor):
|
||||
of trial.
|
||||
"""
|
||||
self._commit_resources(trial.resources)
|
||||
remote_runner = None
|
||||
attempts = 0
|
||||
while attempts < TRIAL_START_ATTEMPTS:
|
||||
attempts += 1
|
||||
if attempts > 1:
|
||||
logger.warning("Trial %s: Start attempt #%s...", trial,
|
||||
attempts)
|
||||
try:
|
||||
self._start_trial(trial, checkpoint, remote_runner)
|
||||
break
|
||||
except AbortTrialExecution:
|
||||
logger.exception("Trial %s: Error starting runner, aborting!",
|
||||
trial)
|
||||
time.sleep(2)
|
||||
error_msg = traceback.format_exc()
|
||||
self._stop_trial(trial, error=True, error_msg=error_msg)
|
||||
break # don't retry fatal Tune errors
|
||||
except RayTimeoutError:
|
||||
# Reuse the existing runner on retries.
|
||||
remote_runner = trial.runner
|
||||
warning = ("Runner task timed out. This could be due to "
|
||||
"slow worker startup.")
|
||||
if attempts == TRIAL_START_ATTEMPTS:
|
||||
error_msg = traceback.format_exc()
|
||||
self._stop_trial(trial, error=True, error_msg=error_msg)
|
||||
else:
|
||||
warning += " Reusing the same runner."
|
||||
logger.warning("Trial %s: %s", trial, warning)
|
||||
except Exception:
|
||||
logger.exception("Trial %s: Error starting runner.", trial)
|
||||
time.sleep(2)
|
||||
error_msg = traceback.format_exc()
|
||||
self._stop_trial(trial, error=True, error_msg=error_msg)
|
||||
remote_runner = None
|
||||
# This forces the trial to not start from checkpoint.
|
||||
checkpoint = None
|
||||
trial.clear_checkpoint()
|
||||
# Note that we don't return the resources, since they may
|
||||
# have been lost. TODO(ujvl): is this the right thing to do?
|
||||
else:
|
||||
logger.exception(
|
||||
"Trial %s: Aborting trial after %s start "
|
||||
"attempts!", trial, TRIAL_START_ATTEMPTS)
|
||||
try:
|
||||
self._start_trial(trial, checkpoint)
|
||||
except AbortTrialExecution:
|
||||
logger.exception("Trial %s: Error starting runner, aborting!",
|
||||
trial)
|
||||
time.sleep(2)
|
||||
error_msg = traceback.format_exc()
|
||||
self._stop_trial(trial, error=True, error_msg=error_msg)
|
||||
except Exception:
|
||||
logger.exception("Trial %s: Unexpected error starting runner.",
|
||||
trial)
|
||||
time.sleep(2)
|
||||
error_msg = traceback.format_exc()
|
||||
self._stop_trial(trial, error=True, error_msg=error_msg)
|
||||
# Note that we don't return the resources, since they may
|
||||
# have been lost. TODO(ujvl): is this the right thing to do?
|
||||
|
||||
def _find_item(self, dictionary, item):
|
||||
out = [rid for rid, t in dictionary.items() if t is item]
|
||||
@@ -332,7 +312,6 @@ class RayTrialExecutor(TrialExecutor):
|
||||
|
||||
def get_running_trials(self):
|
||||
"""Returns the running trials."""
|
||||
|
||||
return list(self._running.values())
|
||||
|
||||
def get_alive_node_ips(self):
|
||||
@@ -387,7 +366,8 @@ class RayTrialExecutor(TrialExecutor):
|
||||
"""Fetches one result of the running trials.
|
||||
|
||||
Returns:
|
||||
Result of the most recent trial training run."""
|
||||
Result of the most recent trial training run.
|
||||
"""
|
||||
trial_future = self._find_item(self._running, trial)
|
||||
if not trial_future:
|
||||
raise ValueError("Trial was not running.")
|
||||
@@ -437,6 +417,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||
"Resource invalid: {}".format(resources))
|
||||
|
||||
def _update_avail_resources(self, num_retries=5):
|
||||
resources = None
|
||||
for i in range(num_retries):
|
||||
try:
|
||||
resources = ray.cluster_resources()
|
||||
@@ -520,7 +501,6 @@ class RayTrialExecutor(TrialExecutor):
|
||||
|
||||
def debug_string(self):
|
||||
"""Returns a human readable message for printing to the console."""
|
||||
|
||||
if self._resources_initialized:
|
||||
status = ("Resources requested: {}/{} CPUs, {}/{} GPUs, "
|
||||
"{}/{} GiB heap, {}/{} GiB objects".format(
|
||||
@@ -548,7 +528,6 @@ class RayTrialExecutor(TrialExecutor):
|
||||
|
||||
def resource_string(self):
|
||||
"""Returns a string describing the total resources available."""
|
||||
|
||||
if self._resources_initialized:
|
||||
res_str = ("{} CPUs, {} GPUs, "
|
||||
"{} GiB heap, {} GiB objects".format(
|
||||
@@ -570,18 +549,28 @@ class RayTrialExecutor(TrialExecutor):
|
||||
"""Before step() called, update the available resources."""
|
||||
self._update_avail_resources()
|
||||
|
||||
def save(self, trial, storage=Checkpoint.DISK, result=None):
|
||||
"""Saves the trial's state to a checkpoint."""
|
||||
result = result or trial.last_result
|
||||
def save(self, trial, storage=Checkpoint.PERSISTENT, result=None):
|
||||
"""Saves the trial's state to a checkpoint.
|
||||
|
||||
Args:
|
||||
trial (Trial): The state of this trial to be saved.
|
||||
storage (str): Where to store the checkpoint. Defaults to
|
||||
PERSISTENT.
|
||||
result (dict): The state of this trial as a dictionary to be saved.
|
||||
If result is None, the trial's last result will be used.
|
||||
|
||||
Returns:
|
||||
Checkpoint future, or None if an Exception occurs.
|
||||
"""
|
||||
result = result or trial.last_result
|
||||
if storage == Checkpoint.MEMORY:
|
||||
value = trial.runner.save_to_object.remote()
|
||||
checkpoint = Checkpoint(storage, value, result)
|
||||
else:
|
||||
with warn_if_slow("save_checkpoint_to_disk"):
|
||||
with warn_if_slow("save_checkpoint_to_storage"):
|
||||
# TODO(ujvl): Make this asynchronous.
|
||||
value = ray.get(trial.runner.save.remote())
|
||||
checkpoint = Checkpoint(storage, value, result)
|
||||
|
||||
with warn_if_slow("on_checkpoint", DEFAULT_GET_TIMEOUT) as profile:
|
||||
try:
|
||||
trial.on_checkpoint(checkpoint)
|
||||
@@ -600,13 +589,10 @@ class RayTrialExecutor(TrialExecutor):
|
||||
def restore(self, trial, checkpoint=None):
|
||||
"""Restores training state from a given model checkpoint.
|
||||
|
||||
This will also sync the trial results to a new location
|
||||
if restoring on a different node.
|
||||
|
||||
Raises:
|
||||
RuntimeError: This error is raised if no runner is found.
|
||||
RayTimeoutError: This error is raised if a remote call to the
|
||||
runner times out.
|
||||
AbortTrialExecution: This error is raised if the trial is
|
||||
ineligible for restoration, given the Tune input arguments.
|
||||
"""
|
||||
if checkpoint is None or checkpoint.value is None:
|
||||
checkpoint = trial.checkpoint
|
||||
@@ -617,19 +603,29 @@ class RayTrialExecutor(TrialExecutor):
|
||||
"Trial {}: Unable to restore - no runner found.".format(trial))
|
||||
value = checkpoint.value
|
||||
if checkpoint.storage == Checkpoint.MEMORY:
|
||||
assert not isinstance(value, Checkpoint), type(value)
|
||||
logger.debug("Trial %s: Attempting restore from object", trial)
|
||||
# Note that we don't store the remote since in-memory checkpoints
|
||||
# don't guarantee fault tolerance and don't need to be waited on.
|
||||
trial.runner.restore_from_object.remote(value)
|
||||
else:
|
||||
logger.info("Trial %s: Attempting restore from %s", trial, value)
|
||||
with warn_if_slow("get_current_ip"):
|
||||
worker_ip = ray.get(trial.runner.current_ip.remote(),
|
||||
DEFAULT_GET_TIMEOUT)
|
||||
with warn_if_slow("sync_to_new_location"):
|
||||
trial.sync_logger_to_new_location(worker_ip)
|
||||
with warn_if_slow("restore_from_disk"):
|
||||
# TODO(ujvl): Take blocking restores out of the control loop.
|
||||
ray.get(trial.runner.restore.remote(value))
|
||||
trial.last_result = checkpoint.result
|
||||
logger.debug("Trial %s: Attempting restore from %s", trial, value)
|
||||
if issubclass(trial.get_trainable_cls(), DurableTrainable):
|
||||
remote = trial.runner.restore.remote(value)
|
||||
elif trial.sync_on_checkpoint:
|
||||
# This provides FT backwards compatibility in the
|
||||
# case where a DurableTrainable is not provided.
|
||||
logger.warning("Trial %s: Reading checkpoint into memory.",
|
||||
trial)
|
||||
data_dict = TrainableUtil.pickle_checkpoint(value)
|
||||
remote = trial.runner.restore_from_object.remote(data_dict)
|
||||
else:
|
||||
raise AbortTrialExecution(
|
||||
"Pass in `sync_on_checkpoint=True` for driver-based trial"
|
||||
"restoration. Pass in an `upload_dir` and a Trainable "
|
||||
"extending `DurableTrainable` for remote storage-based "
|
||||
"restoration")
|
||||
self._running[remote] = trial
|
||||
trial.restoring_from = checkpoint
|
||||
|
||||
def export_trial_if_needed(self, trial):
|
||||
"""Exports model of this trial based on trial.export_formats.
|
||||
|
||||
@@ -28,25 +28,31 @@ def noop(*args):
|
||||
return
|
||||
|
||||
|
||||
def get_sync_client(sync_function):
|
||||
def get_sync_client(sync_function, delete_function=None):
|
||||
"""Returns a sync client.
|
||||
|
||||
Args:
|
||||
sync_function (Optional[str|function]): Sync function.
|
||||
delete_function (Optional[str|function]): Delete function. Must be
|
||||
the same type as sync_function if it is provided.
|
||||
|
||||
Raises:
|
||||
ValueError if sync_function is malformed.
|
||||
ValueError if sync_function or delete_function are malformed.
|
||||
"""
|
||||
if sync_function is None:
|
||||
return None
|
||||
if delete_function and type(sync_function) != type(delete_function):
|
||||
raise ValueError("Sync and delete functions must be of same type.")
|
||||
if isinstance(sync_function, types.FunctionType):
|
||||
delete_function = delete_function or noop
|
||||
client_cls = FunctionBasedClient
|
||||
elif isinstance(sync_function, str):
|
||||
delete_function = delete_function or noop_template
|
||||
client_cls = CommandBasedClient
|
||||
else:
|
||||
raise ValueError("Sync function {} must be string or function".format(
|
||||
sync_function))
|
||||
return client_cls(sync_function, sync_function)
|
||||
return client_cls(sync_function, sync_function, delete_function)
|
||||
|
||||
|
||||
def get_cloud_sync_client(remote_path):
|
||||
@@ -63,17 +69,19 @@ def get_cloud_sync_client(remote_path):
|
||||
raise ValueError(
|
||||
"Upload uri starting with '{}' requires awscli tool"
|
||||
" to be installed".format(S3_PREFIX))
|
||||
template = "aws s3 sync {source} {target}"
|
||||
template = "aws s3 sync {source} {target} --only-show-errors"
|
||||
delete_template = "aws s3 rm {target} --recursive --only-show-errors"
|
||||
elif remote_path.startswith(GS_PREFIX):
|
||||
if not distutils.spawn.find_executable("gsutil"):
|
||||
raise ValueError(
|
||||
"Upload uri starting with '{}' requires gsutil tool"
|
||||
" to be installed".format(GS_PREFIX))
|
||||
template = "gsutil rsync -r {source} {target}"
|
||||
delete_template = "gsutil rm -r {target}"
|
||||
else:
|
||||
raise ValueError("Upload uri must start with one of: {}"
|
||||
"".format(ALLOWED_REMOTE_PREFIXES))
|
||||
return CommandBasedClient(template, template)
|
||||
return CommandBasedClient(template, template, delete_template)
|
||||
|
||||
|
||||
class SyncClient:
|
||||
@@ -103,6 +111,17 @@ class SyncClient:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def delete(self, target):
|
||||
"""Deletes target.
|
||||
|
||||
Args:
|
||||
target (str): Target path.
|
||||
|
||||
Returns:
|
||||
True if delete initiation successful, False otherwise.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def wait(self):
|
||||
"""Waits for current sync to complete, if asynchronously started."""
|
||||
pass
|
||||
@@ -113,9 +132,10 @@ class SyncClient:
|
||||
|
||||
|
||||
class FunctionBasedClient(SyncClient):
|
||||
def __init__(self, sync_up_func, sync_down_func):
|
||||
def __init__(self, sync_up_func, sync_down_func, delete_func=None):
|
||||
self.sync_up_func = sync_up_func
|
||||
self.sync_down_func = sync_down_func
|
||||
self.delete_func = delete_func or noop
|
||||
|
||||
def sync_up(self, source, target):
|
||||
self.sync_up_func(source, target)
|
||||
@@ -125,12 +145,19 @@ class FunctionBasedClient(SyncClient):
|
||||
self.sync_down_func(source, target)
|
||||
return True
|
||||
|
||||
def delete(self, target):
|
||||
self.delete_func(target)
|
||||
return True
|
||||
|
||||
|
||||
NOOP = FunctionBasedClient(noop, noop)
|
||||
|
||||
|
||||
class CommandBasedClient(SyncClient):
|
||||
def __init__(self, sync_up_template, sync_down_template):
|
||||
def __init__(self,
|
||||
sync_up_template,
|
||||
sync_down_template,
|
||||
delete_template=noop_template):
|
||||
"""Syncs between two directories with the given command.
|
||||
|
||||
Arguments:
|
||||
@@ -138,11 +165,14 @@ class CommandBasedClient(SyncClient):
|
||||
include replacement fields '{source}' and '{target}'.
|
||||
sync_down_template (str): A runnable string template; needs to
|
||||
include replacement fields '{source}' and '{target}'.
|
||||
delete_template (Optional[str]): A runnable string template; needs
|
||||
to include replacement field '{target}'. Noop by default.
|
||||
"""
|
||||
self._validate_sync_string(sync_up_template)
|
||||
self._validate_sync_string(sync_down_template)
|
||||
self.sync_up_template = sync_up_template
|
||||
self.sync_down_template = sync_down_template
|
||||
self.delete_template = delete_template
|
||||
self.logfile = None
|
||||
self.cmd_process = None
|
||||
|
||||
@@ -153,7 +183,7 @@ class CommandBasedClient(SyncClient):
|
||||
logdir (str): Log directory.
|
||||
"""
|
||||
self.logfile = tempfile.NamedTemporaryFile(
|
||||
prefix="log_sync", dir=logdir, suffix=".log", delete=False)
|
||||
prefix="log_sync_out", dir=logdir, suffix=".log", delete=False)
|
||||
|
||||
def sync_up(self, source, target):
|
||||
return self._execute(self.sync_up_template, source, target)
|
||||
@@ -161,14 +191,27 @@ class CommandBasedClient(SyncClient):
|
||||
def sync_down(self, source, target):
|
||||
return self._execute(self.sync_down_template, source, target)
|
||||
|
||||
def delete(self, target):
|
||||
if self.is_running:
|
||||
logger.warning("Last sync client cmd still in progress, skipping.")
|
||||
return False
|
||||
final_cmd = self.delete_template.format(target=quote(target))
|
||||
logger.debug("Running delete: {}".format(final_cmd))
|
||||
self.cmd_process = subprocess.Popen(
|
||||
final_cmd, shell=True, stderr=subprocess.PIPE, stdout=self.logfile)
|
||||
return True
|
||||
|
||||
def wait(self):
|
||||
if self.cmd_process:
|
||||
_, error_msg = self.cmd_process.communicate()
|
||||
error_msg = error_msg.decode("ascii")
|
||||
code = self.cmd_process.returncode
|
||||
args = self.cmd_process.args
|
||||
self.cmd_process = None
|
||||
if code != 0:
|
||||
raise TuneError("Sync error ({}): {}".format(code, error_msg))
|
||||
raise TuneError("Sync error. Ran command: {}\n"
|
||||
"Error message ({}): {}".format(
|
||||
args, code, error_msg))
|
||||
|
||||
def reset(self):
|
||||
if self.is_running:
|
||||
@@ -177,7 +220,7 @@ class CommandBasedClient(SyncClient):
|
||||
|
||||
@property
|
||||
def is_running(self):
|
||||
"""Returns whether a sync process is running."""
|
||||
"""Returns whether a sync or delete process is running."""
|
||||
if self.cmd_process:
|
||||
self.cmd_process.poll()
|
||||
return self.cmd_process.returncode is None
|
||||
|
||||
@@ -191,6 +191,8 @@ class NodeSyncer(Syncer):
|
||||
def sync_down(self):
|
||||
if not self.has_remote_target():
|
||||
return True
|
||||
logger.debug("Syncing from %s to %s", self._remote_path,
|
||||
self._local_dir)
|
||||
return super(NodeSyncer, self).sync_down()
|
||||
|
||||
@property
|
||||
@@ -250,23 +252,23 @@ def get_node_syncer(local_dir, remote_dir=None, sync_function=None):
|
||||
Args:
|
||||
local_dir (str): Source directory for syncing.
|
||||
remote_dir (str): Target directory for syncing. If not provided, a
|
||||
no-op Syncer is returned.
|
||||
sync_function (func|str): Function for syncing the local_dir to
|
||||
noop Syncer is returned.
|
||||
sync_function (func|str|bool): Function for syncing the local_dir to
|
||||
remote_dir. If string, then it must be a string template for
|
||||
syncer to run. If not provided, it defaults rsync.
|
||||
syncer to run. If True or not provided, it defaults rsync. If
|
||||
False, a noop Syncer is returned.
|
||||
"""
|
||||
key = (local_dir, remote_dir)
|
||||
if key in _syncers:
|
||||
return _syncers[key]
|
||||
elif not remote_dir:
|
||||
elif not remote_dir or sync_function is False:
|
||||
sync_client = NOOP
|
||||
elif sync_function:
|
||||
elif sync_function and sync_function is not True:
|
||||
sync_client = get_sync_client(sync_function)
|
||||
else:
|
||||
sync_up = log_sync_template()
|
||||
sync_down = log_sync_template(options="--remove-source-files")
|
||||
if sync_up and sync_down:
|
||||
sync_client = CommandBasedClient(sync_up, sync_down)
|
||||
sync = log_sync_template()
|
||||
if sync:
|
||||
sync_client = CommandBasedClient(sync, sync)
|
||||
sync_client.set_logdir(local_dir)
|
||||
else:
|
||||
sync_client = NOOP
|
||||
|
||||
@@ -2,16 +2,19 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import shutil
|
||||
|
||||
import copy
|
||||
import os
|
||||
import time
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import ray
|
||||
from ray.rllib import _register_all
|
||||
|
||||
from ray import tune
|
||||
from ray.tune import Trainable, TuneError
|
||||
from ray.tune import DurableTrainable, Trainable, TuneError
|
||||
from ray.tune import register_env, register_trainable, run_experiments
|
||||
from ray.tune.schedulers import TrialScheduler, FIFOScheduler
|
||||
from ray.tune.trial import Trial
|
||||
@@ -25,6 +28,7 @@ from ray.tune.experiment import Experiment
|
||||
from ray.tune.resources import Resources
|
||||
from ray.tune.suggest import grid_search
|
||||
from ray.tune.suggest.suggestion import _MockSuggestionAlgorithm
|
||||
from ray.tune.utils.mock import mock_storage_client, MOCK_REMOTE_DIR
|
||||
|
||||
|
||||
class TrainableFunctionApiTest(unittest.TestCase):
|
||||
@@ -703,6 +707,34 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
]
|
||||
self.assertTrue(all(complete_results1))
|
||||
|
||||
def testDurableTrainable(self):
|
||||
class TestTrain(DurableTrainable):
|
||||
def _setup(self, config):
|
||||
self.state = {"hi": 1, "iter": 0}
|
||||
|
||||
def _train(self):
|
||||
self.state["iter"] += 1
|
||||
return {"timesteps_this_iter": 1, "done": True}
|
||||
|
||||
def _save(self, path):
|
||||
return self.state
|
||||
|
||||
def _restore(self, state):
|
||||
self.state = state
|
||||
|
||||
sync_client = mock_storage_client()
|
||||
mock_get_client = "ray.tune.durable_trainable.get_cloud_sync_client"
|
||||
with patch(mock_get_client) as mock_get_cloud_sync_client:
|
||||
mock_get_cloud_sync_client.return_value = sync_client
|
||||
test_trainable = TestTrain(remote_checkpoint_dir=MOCK_REMOTE_DIR)
|
||||
checkpoint_path = test_trainable.save()
|
||||
test_trainable.train()
|
||||
test_trainable.state["hi"] = 2
|
||||
test_trainable.restore(checkpoint_path)
|
||||
self.assertEqual(test_trainable.state["hi"], 1)
|
||||
|
||||
self.addCleanup(shutil.rmtree, MOCK_REMOTE_DIR)
|
||||
|
||||
def testCheckpointDict(self):
|
||||
class TestTrain(Trainable):
|
||||
def _setup(self, config):
|
||||
|
||||
@@ -16,23 +16,28 @@ class CheckpointManagerTest(unittest.TestCase):
|
||||
def mock_result(i):
|
||||
return {"i": i}
|
||||
|
||||
def checkpoint_manager(self, keep_checkpoints_num):
|
||||
return CheckpointManager(
|
||||
keep_checkpoints_num, "i", delete_fn=lambda c: None)
|
||||
|
||||
def testOnCheckpointOrdered(self):
|
||||
"""
|
||||
Tests increasing priorities. Also tests that that the worst checkpoints
|
||||
are deleted when necessary.
|
||||
"""
|
||||
keep_checkpoints_num = 2
|
||||
checkpoint_manager = CheckpointManager(keep_checkpoints_num, "i")
|
||||
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
|
||||
checkpoints = [
|
||||
Checkpoint(Checkpoint.DISK, {i}, self.mock_result(i))
|
||||
Checkpoint(Checkpoint.PERSISTENT, {i}, self.mock_result(i))
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
with patch("shutil.rmtree") as rmtree_mock, patch("os.path"):
|
||||
with patch.object(checkpoint_manager, "delete") as \
|
||||
delete_mock:
|
||||
for j in range(3):
|
||||
checkpoint_manager.on_checkpoint(checkpoints[j])
|
||||
expected_deletes = 0 if j != 2 else 1
|
||||
self.assertEqual(rmtree_mock.call_count, expected_deletes)
|
||||
self.assertEqual(delete_mock.call_count, expected_deletes, j)
|
||||
self.assertEqual(checkpoint_manager.newest_checkpoint,
|
||||
checkpoints[j])
|
||||
|
||||
@@ -47,17 +52,17 @@ class CheckpointManagerTest(unittest.TestCase):
|
||||
that the worst checkpoints are deleted when necessary.
|
||||
"""
|
||||
keep_checkpoints_num = 2
|
||||
checkpoint_manager = CheckpointManager(keep_checkpoints_num, "i")
|
||||
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
|
||||
checkpoints = [
|
||||
Checkpoint(Checkpoint.DISK, {i}, self.mock_result(i))
|
||||
Checkpoint(Checkpoint.PERSISTENT, {i}, self.mock_result(i))
|
||||
for i in range(3, -1, -1)
|
||||
]
|
||||
|
||||
with patch("shutil.rmtree") as rmtree_mock, patch("os.path"):
|
||||
with patch.object(checkpoint_manager, "delete") as delete_mock:
|
||||
for j in range(0, len(checkpoints)):
|
||||
checkpoint_manager.on_checkpoint(checkpoints[j])
|
||||
expected_deletes = 0 if j != 3 else 1
|
||||
self.assertEqual(rmtree_mock.call_count, expected_deletes)
|
||||
self.assertEqual(delete_mock.call_count, expected_deletes)
|
||||
self.assertEqual(checkpoint_manager.newest_checkpoint,
|
||||
checkpoints[j])
|
||||
|
||||
@@ -71,7 +76,7 @@ class CheckpointManagerTest(unittest.TestCase):
|
||||
Tests that the best checkpoints are tracked and ordered correctly.
|
||||
"""
|
||||
keep_checkpoints_num = 4
|
||||
checkpoint_manager = CheckpointManager(keep_checkpoints_num, "i")
|
||||
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
|
||||
checkpoints = [
|
||||
Checkpoint(Checkpoint.MEMORY, i, self.mock_result(i))
|
||||
for i in range(16)
|
||||
@@ -92,7 +97,7 @@ class CheckpointManagerTest(unittest.TestCase):
|
||||
checkpoint has no checkpoint score attribute.
|
||||
"""
|
||||
keep_checkpoints_num = 1
|
||||
checkpoint_manager = CheckpointManager(keep_checkpoints_num, "i")
|
||||
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
|
||||
|
||||
no_attr_checkpoint = Checkpoint(Checkpoint.MEMORY, 0, {})
|
||||
with patch.object(logger, "error") as log_error_mock:
|
||||
|
||||
@@ -9,20 +9,26 @@ import os
|
||||
import pytest
|
||||
import shutil
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib import _register_all
|
||||
from ray.cluster_utils import Cluster
|
||||
from ray.test_utils import run_string_as_driver_nonblocking
|
||||
from ray.tune import register_trainable
|
||||
from ray.tune.experiment import Experiment
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.ray_trial_executor import RayTrialExecutor
|
||||
from ray.tune.experiment import Experiment
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.resources import Resources
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.suggest import BasicVariantGenerator
|
||||
from ray.tune.syncer import Syncer
|
||||
from ray.tune.trainable import TrainableUtil
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.utils.mock import (MockDurableTrainer, MockRemoteTrainer,
|
||||
MockNodeSyncer, mock_storage_client,
|
||||
MOCK_REMOTE_DIR)
|
||||
|
||||
|
||||
def _start_new_cluster():
|
||||
@@ -36,6 +42,8 @@ def _start_new_cluster():
|
||||
})
|
||||
})
|
||||
# Pytest doesn't play nicely with imports
|
||||
register_trainable("__fake_remote", MockRemoteTrainer)
|
||||
register_trainable("__fake_durable", MockDurableTrainer)
|
||||
_register_all()
|
||||
return cluster
|
||||
|
||||
@@ -53,7 +61,6 @@ def start_connected_cluster():
|
||||
@pytest.fixture
|
||||
def start_connected_emptyhead_cluster():
|
||||
"""Starts head with no resources."""
|
||||
|
||||
cluster = Cluster(
|
||||
initialize_head=True,
|
||||
connect=True,
|
||||
@@ -65,6 +72,8 @@ def start_connected_emptyhead_cluster():
|
||||
})
|
||||
# Pytest doesn't play nicely with imports
|
||||
_register_all()
|
||||
register_trainable("__fake_remote", MockRemoteTrainer)
|
||||
register_trainable("__fake_durable", MockDurableTrainer)
|
||||
yield cluster
|
||||
# The code after the yield will run as teardown code.
|
||||
ray.shutdown()
|
||||
@@ -208,7 +217,8 @@ def test_queue_trials(start_connected_emptyhead_cluster):
|
||||
assert gpu_trial.status == Trial.TERMINATED
|
||||
|
||||
|
||||
def test_trial_migration(start_connected_emptyhead_cluster):
|
||||
@pytest.mark.parametrize("trainable_id", ["__fake", "__fake_durable"])
|
||||
def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
|
||||
"""Removing a node while cluster has space should migrate trial.
|
||||
|
||||
The trial state should also be consistent with the checkpoint.
|
||||
@@ -220,14 +230,16 @@ def test_trial_migration(start_connected_emptyhead_cluster):
|
||||
runner = TrialRunner(BasicVariantGenerator())
|
||||
kwargs = {
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 3
|
||||
"training_iteration": 4
|
||||
},
|
||||
"checkpoint_freq": 2,
|
||||
"max_failures": 2
|
||||
"max_failures": 2,
|
||||
"remote_checkpoint_dir": MOCK_REMOTE_DIR,
|
||||
"sync_to_driver_fn": trainable_id == "__fake",
|
||||
}
|
||||
|
||||
# Test recovery of trial that hasn't been checkpointed
|
||||
t = Trial("__fake", **kwargs)
|
||||
t = Trial(trainable_id, **kwargs)
|
||||
runner.add_trial(t)
|
||||
runner.step() # start
|
||||
runner.step() # 1 result
|
||||
@@ -241,13 +253,13 @@ def test_trial_migration(start_connected_emptyhead_cluster):
|
||||
# because checkpoint handling is messy and should be refactored
|
||||
# rather than hotfixed.
|
||||
# assert t.last_result is None, "Trial result not restored correctly."
|
||||
for i in range(3):
|
||||
for i in range(4):
|
||||
runner.step()
|
||||
|
||||
assert t.status == Trial.TERMINATED
|
||||
|
||||
# Test recovery of trial that has been checkpointed
|
||||
t2 = Trial("__fake", **kwargs)
|
||||
t2 = Trial(trainable_id, **kwargs)
|
||||
runner.add_trial(t2)
|
||||
runner.step() # start
|
||||
runner.step() # 1 result
|
||||
@@ -256,13 +268,23 @@ def test_trial_migration(start_connected_emptyhead_cluster):
|
||||
node3 = cluster.add_node(num_cpus=1)
|
||||
cluster.remove_node(node2)
|
||||
cluster.wait_for_nodes()
|
||||
runner.step() # 3 result + start and fail 4 result
|
||||
runner.step() # Recovery step
|
||||
runner.step() # Process recovery
|
||||
runner.step() # result
|
||||
if t2.status != Trial.TERMINATED:
|
||||
runner.step()
|
||||
assert t2.status == Trial.TERMINATED, runner.debug_string()
|
||||
|
||||
# Test recovery of trial that won't be checkpointed
|
||||
t3 = Trial("__fake", **{"stopping_criterion": {"training_iteration": 3}})
|
||||
kwargs = {
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 3
|
||||
},
|
||||
"remote_checkpoint_dir": MOCK_REMOTE_DIR,
|
||||
"sync_to_driver_fn": trainable_id == "__fake",
|
||||
}
|
||||
t3 = Trial(trainable_id, **kwargs)
|
||||
runner.add_trial(t3)
|
||||
runner.step() # start
|
||||
runner.step() # 1 result
|
||||
@@ -278,7 +300,8 @@ def test_trial_migration(start_connected_emptyhead_cluster):
|
||||
runner.step()
|
||||
|
||||
|
||||
def test_trial_requeue(start_connected_emptyhead_cluster):
|
||||
@pytest.mark.parametrize("trainable_id", ["__fake", "__fake_durable"])
|
||||
def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id):
|
||||
"""Removing a node in full cluster causes Trial to be requeued."""
|
||||
cluster = start_connected_emptyhead_cluster
|
||||
node = cluster.add_node(num_cpus=1)
|
||||
@@ -290,10 +313,12 @@ def test_trial_requeue(start_connected_emptyhead_cluster):
|
||||
"training_iteration": 5
|
||||
},
|
||||
"checkpoint_freq": 1,
|
||||
"max_failures": 1
|
||||
"max_failures": 1,
|
||||
"remote_checkpoint_dir": MOCK_REMOTE_DIR,
|
||||
"sync_to_driver_fn": trainable_id == "__fake",
|
||||
}
|
||||
|
||||
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
|
||||
trials = [Trial(trainable_id, **kwargs), Trial(trainable_id, **kwargs)]
|
||||
for t in trials:
|
||||
runner.add_trial(t)
|
||||
|
||||
@@ -309,7 +334,9 @@ def test_trial_requeue(start_connected_emptyhead_cluster):
|
||||
runner.step()
|
||||
|
||||
|
||||
def test_migration_checkpoint_removal(start_connected_emptyhead_cluster):
|
||||
@pytest.mark.parametrize("trainable_id", ["__fake_remote", "__fake_durable"])
|
||||
def test_migration_checkpoint_removal(start_connected_emptyhead_cluster,
|
||||
trainable_id):
|
||||
"""Test checks that trial restarts if checkpoint is lost w/ node fail."""
|
||||
cluster = start_connected_emptyhead_cluster
|
||||
node = cluster.add_node(num_cpus=1)
|
||||
@@ -318,32 +345,59 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster):
|
||||
runner = TrialRunner(BasicVariantGenerator())
|
||||
kwargs = {
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 3
|
||||
"training_iteration": 4
|
||||
},
|
||||
"checkpoint_freq": 2,
|
||||
"max_failures": 2
|
||||
"max_failures": 2,
|
||||
"remote_checkpoint_dir": MOCK_REMOTE_DIR,
|
||||
"sync_to_driver_fn": trainable_id == "__fake_remote",
|
||||
}
|
||||
|
||||
# Test recovery of trial that has been checkpointed
|
||||
t1 = Trial("__fake", **kwargs)
|
||||
runner.add_trial(t1)
|
||||
runner.step() # start
|
||||
runner.step() # 1 result
|
||||
runner.step() # 2 result and checkpoint
|
||||
assert t1.has_checkpoint()
|
||||
cluster.add_node(num_cpus=1)
|
||||
cluster.remove_node(node)
|
||||
cluster.wait_for_nodes()
|
||||
shutil.rmtree(os.path.dirname(t1.checkpoint.value))
|
||||
# The following patches only affect __fake_remote.
|
||||
find_checkpoint_dir = TrainableUtil.find_checkpoint_dir
|
||||
with patch("ray.tune.logger.get_node_syncer") as mock_get_node_syncer:
|
||||
trainable_util = "ray.tune.ray_trial_executor.TrainableUtil"
|
||||
with patch(trainable_util + ".find_checkpoint_dir") as mock_find_dir:
|
||||
|
||||
runner.step() # Recovery step
|
||||
for i in range(3):
|
||||
if t1.status != Trial.TERMINATED:
|
||||
runner.step()
|
||||
def mock_get_syncer_fn(local_dir, remote_dir, sync_function):
|
||||
client = mock_storage_client()
|
||||
return MockNodeSyncer(local_dir, remote_dir, client)
|
||||
|
||||
mock_get_node_syncer.side_effect = mock_get_syncer_fn
|
||||
|
||||
def mock_find_dir_fn(checkpoint_path):
|
||||
"""Converts back to local path first."""
|
||||
checkpoint_path = checkpoint_path[len(MOCK_REMOTE_DIR):]
|
||||
checkpoint_path = os.path.join("/", checkpoint_path)
|
||||
return find_checkpoint_dir(checkpoint_path)
|
||||
|
||||
# __fake_remote trainables save to a separate "remote" directory.
|
||||
# TrainableUtil will not check this path unless we mock it.
|
||||
mock_find_dir.side_effect = mock_find_dir_fn
|
||||
|
||||
# Test recovery of trial that has been checkpointed
|
||||
t1 = Trial(trainable_id, **kwargs)
|
||||
runner.add_trial(t1)
|
||||
runner.step() # start
|
||||
runner.step() # 1 result
|
||||
runner.step() # 2 result and checkpoint
|
||||
assert t1.has_checkpoint()
|
||||
cluster.add_node(num_cpus=1)
|
||||
cluster.remove_node(node)
|
||||
cluster.wait_for_nodes()
|
||||
shutil.rmtree(os.path.dirname(t1.checkpoint.value))
|
||||
|
||||
runner.step() # collect result 3, kick off + fail result 4
|
||||
runner.step() # Recovery step
|
||||
runner.step() # Process Recovery + step 4
|
||||
for i in range(3):
|
||||
if t1.status != Trial.TERMINATED:
|
||||
runner.step()
|
||||
assert t1.status == Trial.TERMINATED, runner.debug_string()
|
||||
|
||||
|
||||
def test_cluster_down_simple(start_connected_cluster, tmpdir):
|
||||
@pytest.mark.parametrize("trainable_id", ["__fake", "__fake_durable"])
|
||||
def test_cluster_down_simple(start_connected_cluster, tmpdir, trainable_id):
|
||||
"""Tests that TrialRunner save/restore works on cluster shutdown."""
|
||||
cluster = start_connected_cluster
|
||||
cluster.add_node(num_cpus=1)
|
||||
@@ -356,9 +410,11 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir):
|
||||
"training_iteration": 2
|
||||
},
|
||||
"checkpoint_freq": 1,
|
||||
"max_failures": 1
|
||||
"max_failures": 1,
|
||||
"remote_checkpoint_dir": MOCK_REMOTE_DIR,
|
||||
"sync_to_driver_fn": trainable_id == "__fake",
|
||||
}
|
||||
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
|
||||
trials = [Trial(trainable_id, **kwargs), Trial(trainable_id, **kwargs)]
|
||||
for t in trials:
|
||||
runner.add_trial(t)
|
||||
|
||||
@@ -374,6 +430,7 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir):
|
||||
cluster = _start_new_cluster()
|
||||
runner = TrialRunner(resume="LOCAL", local_checkpoint_dir=dirpath)
|
||||
runner.step() # start
|
||||
runner.step() # process restore
|
||||
runner.step() # start2
|
||||
|
||||
for i in range(3):
|
||||
@@ -387,26 +444,31 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir):
|
||||
cluster.shutdown()
|
||||
|
||||
|
||||
def test_cluster_down_full(start_connected_cluster, tmpdir):
|
||||
@pytest.mark.parametrize("trainable_id", ["__fake", "__fake_durable"])
|
||||
def test_cluster_down_full(start_connected_cluster, tmpdir, trainable_id):
|
||||
"""Tests that run_experiment restoring works on cluster shutdown."""
|
||||
cluster = start_connected_cluster
|
||||
dirpath = str(tmpdir)
|
||||
|
||||
exp1_args = dict(
|
||||
run="__fake",
|
||||
use_default_sync = trainable_id == "__fake"
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
local_dir = DEFAULT_RESULTS_DIR
|
||||
upload_dir = None if use_default_sync else MOCK_REMOTE_DIR
|
||||
|
||||
base_dict = dict(
|
||||
run=trainable_id,
|
||||
stop=dict(training_iteration=3),
|
||||
local_dir=dirpath,
|
||||
checkpoint_freq=1)
|
||||
exp2_args = dict(run="__fake", stop=dict(training_iteration=3))
|
||||
exp3_args = dict(
|
||||
run="__fake",
|
||||
stop=dict(training_iteration=3),
|
||||
config=dict(mock_error=True))
|
||||
local_dir=local_dir,
|
||||
upload_dir=upload_dir,
|
||||
sync_to_driver=use_default_sync,
|
||||
)
|
||||
|
||||
exp1_args = base_dict
|
||||
exp2_args = dict(base_dict.items(), local_dir=dirpath, checkpoint_freq=1)
|
||||
exp3_args = dict(base_dict.items(), config=dict(mock_error=True))
|
||||
exp4_args = dict(
|
||||
run="__fake",
|
||||
stop=dict(training_iteration=3),
|
||||
config=dict(mock_error=True),
|
||||
checkpoint_freq=1)
|
||||
base_dict.items(), config=dict(mock_error=True), checkpoint_freq=1)
|
||||
|
||||
all_experiments = {
|
||||
"exp1": exp1_args,
|
||||
"exp2": exp2_args,
|
||||
@@ -414,14 +476,20 @@ def test_cluster_down_full(start_connected_cluster, tmpdir):
|
||||
"exp4": exp4_args
|
||||
}
|
||||
|
||||
tune.run_experiments(all_experiments, raise_on_failed_trial=False)
|
||||
mock_get_client = "ray.tune.trial_runner.get_cloud_syncer"
|
||||
with patch(mock_get_client) as mock_get_cloud_syncer:
|
||||
mock_syncer = Syncer(local_dir, upload_dir, mock_storage_client())
|
||||
mock_get_cloud_syncer.return_value = mock_syncer
|
||||
|
||||
ray.shutdown()
|
||||
cluster.shutdown()
|
||||
cluster = _start_new_cluster()
|
||||
tune.run_experiments(all_experiments, raise_on_failed_trial=False)
|
||||
|
||||
ray.shutdown()
|
||||
cluster.shutdown()
|
||||
cluster = _start_new_cluster()
|
||||
|
||||
trials = tune.run_experiments(
|
||||
all_experiments, resume=True, raise_on_failed_trial=False)
|
||||
|
||||
trials = tune.run_experiments(
|
||||
all_experiments, resume=True, raise_on_failed_trial=False)
|
||||
assert len(trials) == 4
|
||||
assert all(t.status in [Trial.TERMINATED, Trial.ERROR] for t in trials)
|
||||
ray.shutdown()
|
||||
|
||||
@@ -4,11 +4,25 @@ from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib import _register_all
|
||||
from ray.tune import register_trainable
|
||||
from ray.tune.experiment import Experiment, convert_to_experiment_list
|
||||
from ray.tune.error import TuneError
|
||||
|
||||
|
||||
class ExperimentTest(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
ray.shutdown()
|
||||
_register_all() # re-register the evicted objects
|
||||
|
||||
def setUp(self):
|
||||
def train(config, reporter):
|
||||
for i in range(100):
|
||||
reporter(timesteps_total=i)
|
||||
|
||||
register_trainable("f1", train)
|
||||
|
||||
def testConvertExperimentFromExperiment(self):
|
||||
exp1 = Experiment(**{
|
||||
"name": "foo",
|
||||
|
||||
@@ -4,12 +4,9 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import ray
|
||||
from ray.exceptions import RayTimeoutError
|
||||
from ray.rllib import _register_all
|
||||
from ray.tune import Trainable
|
||||
from ray.tune.ray_trial_executor import RayTrialExecutor
|
||||
@@ -41,33 +38,11 @@ class RayTrialExecutorTest(unittest.TestCase):
|
||||
trial = Trial("__fake")
|
||||
self.trial_executor.start_trial(trial)
|
||||
self.assertEqual(Trial.RUNNING, trial.status)
|
||||
self.trial_executor.save(trial, Checkpoint.DISK)
|
||||
self.trial_executor.save(trial, Checkpoint.PERSISTENT)
|
||||
self.trial_executor.restore(trial)
|
||||
self.trial_executor.stop_trial(trial)
|
||||
self.assertEqual(Trial.TERMINATED, trial.status)
|
||||
|
||||
def testSaveRestoreTimeout(self):
|
||||
trial = Trial("__fake")
|
||||
self.trial_executor.start_trial(trial)
|
||||
self.assertEqual(Trial.RUNNING, trial.status)
|
||||
self.trial_executor.save(trial, Checkpoint.DISK)
|
||||
self.trial_executor.set_status(trial, Trial.PAUSED)
|
||||
|
||||
ray_get = ray.get
|
||||
start_trial = self.trial_executor._start_trial
|
||||
|
||||
# Timeout on first two attempts, then succeed on subsequent gets.
|
||||
side_effects = [RayTimeoutError, RayTimeoutError, ray_get, ray_get]
|
||||
with patch.object(self.trial_executor, "_start_trial") as mock_start:
|
||||
with patch("ray.get", side_effect=side_effects):
|
||||
mock_start.side_effect = start_trial
|
||||
self.trial_executor.start_trial(trial, trial.checkpoint)
|
||||
|
||||
# Trial starts successfully on 3rd attempt.
|
||||
assert mock_start.call_count == 3
|
||||
self.assertEqual(Trial.RUNNING, trial.status)
|
||||
self.trial_executor.stop_trial(trial)
|
||||
|
||||
def testPauseResume(self):
|
||||
"""Tests that pausing works for trials in flight."""
|
||||
trial = Trial("__fake")
|
||||
@@ -216,4 +191,5 @@ class LocalModeExecutorTest(RayTrialExecutorTest):
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import shutil
|
||||
import unittest
|
||||
|
||||
from ray.tune.trainable import TrainableUtil
|
||||
|
||||
|
||||
class TrainableUtilTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.checkpoint_dir = "/tmp/tune/MyTrainable123"
|
||||
TrainableUtil.make_checkpoint_dir(self.checkpoint_dir)
|
||||
|
||||
def tearDown(self):
|
||||
self.addCleanup(shutil.rmtree, self.checkpoint_dir)
|
||||
|
||||
def testFindCheckpointDir(self):
|
||||
checkpoint_path = os.path.join(self.checkpoint_dir, "my/nested/chkpt")
|
||||
os.makedirs(checkpoint_path)
|
||||
found_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
|
||||
self.assertEquals(self.checkpoint_dir, found_dir)
|
||||
|
||||
with self.assertRaises(FileNotFoundError):
|
||||
parent = os.path.dirname(found_dir)
|
||||
TrainableUtil.find_checkpoint_dir(parent)
|
||||
|
||||
def testPickleCheckpoint(self):
|
||||
for i in range(5):
|
||||
path = os.path.join(self.checkpoint_dir, str(i))
|
||||
with open(path, "w") as f:
|
||||
f.write(str(i))
|
||||
|
||||
checkpoint_path = os.path.join(self.checkpoint_dir, "0")
|
||||
|
||||
data_dict = TrainableUtil.pickle_checkpoint(checkpoint_path)
|
||||
loaded = pickle.loads(data_dict)
|
||||
|
||||
checkpoint_name = os.path.basename(checkpoint_path)
|
||||
self.assertEqual(loaded["checkpoint_name"], checkpoint_name)
|
||||
|
||||
for i in range(5):
|
||||
path = os.path.join(self.checkpoint_dir, str(i))
|
||||
self.assertEquals(loaded["data"][str(i)], open(path, "rb").read())
|
||||
@@ -19,9 +19,11 @@ from ray.tune.suggest import BasicVariantGenerator
|
||||
|
||||
|
||||
class TrialRunnerTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
_register_all() # re-register the evicted objects
|
||||
|
||||
def tearDown(self):
|
||||
ray.shutdown()
|
||||
_register_all() # re-register the evicted objects
|
||||
|
||||
def testTrialStatus(self):
|
||||
ray.init()
|
||||
|
||||
@@ -182,9 +182,11 @@ class TrialRunnerTest2(unittest.TestCase):
|
||||
runner.step()
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
self.assertEqual(trials[0].num_failures, 1)
|
||||
runner.step() # Restore step
|
||||
runner.step()
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
self.assertEqual(trials[0].num_failures, 2)
|
||||
runner.step() # Restore step
|
||||
runner.step()
|
||||
self.assertEqual(trials[0].status, Trial.ERROR)
|
||||
self.assertEqual(trials[0].num_failures, 3)
|
||||
@@ -242,6 +244,7 @@ class TrialRunnerTest2(unittest.TestCase):
|
||||
runner.step()
|
||||
self.assertEqual(trials[0].status, Trial.TERMINATED)
|
||||
self.assertEqual(trials[1].status, Trial.RUNNING)
|
||||
runner.step() # Restore step
|
||||
runner.step()
|
||||
self.assertEqual(trials[1].last_result["timesteps_since_restore"], 10)
|
||||
self.assertEqual(trials[1].last_result["iterations_since_restore"], 1)
|
||||
|
||||
@@ -203,7 +203,7 @@ class _MockTrialExecutor(TrialExecutor):
|
||||
def restore(self, trial, checkpoint=None):
|
||||
pass
|
||||
|
||||
def save(self, trial, type=Checkpoint.DISK, result=None):
|
||||
def save(self, trial, type=Checkpoint.PERSISTENT, result=None):
|
||||
return trial.trainable_name
|
||||
|
||||
def reset_trial(self, trial, new_config, new_experiment_tag):
|
||||
|
||||
@@ -29,6 +29,57 @@ logger = logging.getLogger(__name__)
|
||||
SETUP_TIME_THRESHOLD = 10
|
||||
|
||||
|
||||
class TrainableUtil:
|
||||
@staticmethod
|
||||
def pickle_checkpoint(checkpoint_path):
|
||||
"""Pickles checkpoint data."""
|
||||
checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
|
||||
data = {}
|
||||
for basedir, _, file_names in os.walk(checkpoint_dir):
|
||||
for file_name in file_names:
|
||||
path = os.path.join(basedir, file_name)
|
||||
with open(path, "rb") as f:
|
||||
data[os.path.relpath(path, checkpoint_dir)] = f.read()
|
||||
# Use normpath so that a directory path isn't mapped to empty string.
|
||||
name = os.path.basename(os.path.normpath(checkpoint_path))
|
||||
name += os.path.sep if os.path.isdir(checkpoint_path) else ""
|
||||
data_dict = pickle.dumps({
|
||||
"checkpoint_name": name,
|
||||
"data": data,
|
||||
})
|
||||
return data_dict
|
||||
|
||||
@staticmethod
|
||||
def find_checkpoint_dir(checkpoint_path):
|
||||
"""Returns the directory containing the checkpoint path.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError if the directory is not found.
|
||||
"""
|
||||
if not os.path.exists(checkpoint_path):
|
||||
raise FileNotFoundError("Path does not exist", checkpoint_path)
|
||||
if os.path.isdir(checkpoint_path):
|
||||
checkpoint_dir = checkpoint_path
|
||||
else:
|
||||
checkpoint_dir = os.path.dirname(checkpoint_path)
|
||||
while checkpoint_dir != os.path.dirname(checkpoint_dir):
|
||||
if os.path.exists(os.path.join(checkpoint_dir, ".is_checkpoint")):
|
||||
break
|
||||
checkpoint_dir = os.path.dirname(checkpoint_dir)
|
||||
else:
|
||||
raise FileNotFoundError("Checkpoint directory not found for {}"
|
||||
.format(checkpoint_path))
|
||||
return checkpoint_dir
|
||||
|
||||
@staticmethod
|
||||
def make_checkpoint_dir(checkpoint_dir):
|
||||
"""Creates a checkpoint directory at the provided path."""
|
||||
if not os.path.exists(checkpoint_dir):
|
||||
os.makedirs(checkpoint_dir)
|
||||
# Drop marker in directory to identify it as a checkpoint dir.
|
||||
open(os.path.join(checkpoint_dir, ".is_checkpoint"), "a").close()
|
||||
|
||||
|
||||
class Trainable:
|
||||
"""Abstract class for trainable models, functions, etc.
|
||||
|
||||
@@ -119,17 +170,14 @@ class Trainable:
|
||||
>>> extra_cpu=config["workers"],
|
||||
>>> extra_gpu=int(config["use_gpu"]) * config["workers"])
|
||||
"""
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def resource_help(cls, config):
|
||||
"""
|
||||
"""Returns a help string for configuring this trainable's resources.
|
||||
|
||||
Args:
|
||||
config (dict): The Trainer's config dict.
|
||||
|
||||
Returns:
|
||||
str: A help string for configuring this trainable's resources.
|
||||
"""
|
||||
return ""
|
||||
|
||||
@@ -258,9 +306,7 @@ class Trainable:
|
||||
"""
|
||||
checkpoint_dir = os.path.join(checkpoint_dir or self.logdir,
|
||||
"checkpoint_{}".format(self._iteration))
|
||||
|
||||
if not os.path.exists(checkpoint_dir):
|
||||
os.makedirs(checkpoint_dir)
|
||||
TrainableUtil.make_checkpoint_dir(checkpoint_dir)
|
||||
checkpoint = self._save(checkpoint_dir)
|
||||
saved_as_dict = False
|
||||
if isinstance(checkpoint, string_types):
|
||||
@@ -270,6 +316,10 @@ class Trainable:
|
||||
"given checkpoint dir {}: {}".format(
|
||||
checkpoint_dir, checkpoint))
|
||||
checkpoint_path = checkpoint
|
||||
if os.path.isdir(checkpoint_path):
|
||||
# Add trailing slash to prevent tune metadata from
|
||||
# being written outside the directory.
|
||||
checkpoint_path = os.path.join(checkpoint_path, "")
|
||||
elif isinstance(checkpoint, dict):
|
||||
saved_as_dict = True
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
@@ -302,19 +352,8 @@ class Trainable:
|
||||
tmpdir = tempfile.mkdtemp("save_to_object", dir=self.logdir)
|
||||
checkpoint_path = self.save(tmpdir)
|
||||
# Save all files in subtree.
|
||||
data = {}
|
||||
for basedir, _, file_names in os.walk(tmpdir):
|
||||
for file_name in file_names:
|
||||
path = os.path.join(basedir, file_name)
|
||||
|
||||
with open(path, "rb") as f:
|
||||
data[os.path.relpath(path, tmpdir)] = f.read()
|
||||
|
||||
data_dict = TrainableUtil.pickle_checkpoint(checkpoint_path)
|
||||
out = io.BytesIO()
|
||||
data_dict = pickle.dumps({
|
||||
"checkpoint_name": os.path.relpath(checkpoint_path, tmpdir),
|
||||
"data": data,
|
||||
})
|
||||
if len(data_dict) > 10e6: # getting pretty large
|
||||
logger.info("Checkpoint size is {} bytes".format(len(data_dict)))
|
||||
out.write(data_dict)
|
||||
@@ -348,14 +387,15 @@ class Trainable:
|
||||
self._timesteps_since_restore = 0
|
||||
self._iterations_since_restore = 0
|
||||
self._restored = True
|
||||
logger.info("Restored from checkpoint: %s", checkpoint_path)
|
||||
logger.info("Restored on %s from checkpoint: %s", self.current_ip(),
|
||||
checkpoint_path)
|
||||
state = {
|
||||
"_iteration": self._iteration,
|
||||
"_timesteps_total": self._timesteps_total,
|
||||
"_time_total": self._time_total,
|
||||
"_episodes_total": self._episodes_total,
|
||||
}
|
||||
logger.info("Current state after restoring: {}".format(state))
|
||||
logger.info("Current state after restoring: %s", state)
|
||||
|
||||
def restore_from_object(self, obj):
|
||||
"""Restores training state from a checkpoint object.
|
||||
@@ -379,6 +419,22 @@ class Trainable:
|
||||
self.restore(checkpoint_path)
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
def delete_checkpoint(self, checkpoint_path):
|
||||
"""Deletes local copy of checkpoint.
|
||||
|
||||
Args:
|
||||
checkpoint_path (str): Path to checkpoint.
|
||||
"""
|
||||
try:
|
||||
checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
|
||||
except FileNotFoundError:
|
||||
# The checkpoint won't exist locally if the
|
||||
# trial was rescheduled to another worker.
|
||||
logger.debug("Checkpoint not found during garbage collection.")
|
||||
return
|
||||
if os.path.exists(checkpoint_dir):
|
||||
shutil.rmtree(checkpoint_dir)
|
||||
|
||||
def export_model(self, export_formats, export_dir=None):
|
||||
"""Exports model based on export_formats.
|
||||
|
||||
@@ -429,7 +485,7 @@ class Trainable:
|
||||
Note that the current working directory will also be changed to this.
|
||||
|
||||
"""
|
||||
return self._logdir
|
||||
return os.path.join(self._logdir, "")
|
||||
|
||||
@property
|
||||
def iteration(self):
|
||||
|
||||
+117
-41
@@ -6,6 +6,7 @@ import ray.cloudpickle as cloudpickle
|
||||
import copy
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import shutil
|
||||
import uuid
|
||||
import time
|
||||
import tempfile
|
||||
@@ -13,6 +14,7 @@ import os
|
||||
from numbers import Number
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.checkpoint_manager import Checkpoint, CheckpointManager
|
||||
from ray.tune.durable_trainable import DurableTrainable
|
||||
from ray.tune.logger import pretty_print, UnifiedLogger
|
||||
from ray.tune.util import flatten_dict
|
||||
# NOTE(rkn): We import ray.tune.registry here instead of importing the names we
|
||||
@@ -73,6 +75,30 @@ class ExportFormat:
|
||||
export_formats[i])
|
||||
|
||||
|
||||
def checkpoint_deleter(trial_id, runner):
|
||||
"""Returns a checkpoint deleter callback for a runner."""
|
||||
if not runner:
|
||||
return lambda checkpoint: None
|
||||
|
||||
def delete(checkpoint):
|
||||
"""Requests checkpoint deletion asynchronously.
|
||||
|
||||
Args:
|
||||
checkpoint (Checkpoint): Checkpoint to delete.
|
||||
"""
|
||||
if checkpoint.storage == Checkpoint.PERSISTENT and checkpoint.value:
|
||||
logger.debug("Trial %s: Deleting checkpoint %s", trial_id,
|
||||
checkpoint.value)
|
||||
checkpoint_path = checkpoint.value
|
||||
# Delete local copy, if any exists.
|
||||
if os.path.exists(checkpoint_path):
|
||||
shutil.rmtree(checkpoint_path)
|
||||
# TODO(ujvl): Batch remote deletes.
|
||||
runner.delete_checkpoint.remote(checkpoint.value)
|
||||
|
||||
return delete
|
||||
|
||||
|
||||
class Trial:
|
||||
"""A trial object holds the state for one model training run.
|
||||
|
||||
@@ -98,6 +124,7 @@ class Trial:
|
||||
experiment_tag="",
|
||||
resources=None,
|
||||
stopping_criterion=None,
|
||||
remote_checkpoint_dir=None,
|
||||
checkpoint_freq=0,
|
||||
checkpoint_at_end=False,
|
||||
sync_on_checkpoint=True,
|
||||
@@ -137,7 +164,7 @@ class Trial:
|
||||
"clear the `resources_per_trial` option.".format(
|
||||
trainable_cls, default_resources))
|
||||
resources = default_resources
|
||||
self.address = Location()
|
||||
self.location = Location()
|
||||
self.resources = resources or Resources(cpu=1, gpu=0)
|
||||
self.stopping_criterion = stopping_criterion or {}
|
||||
self.loggers = loggers
|
||||
@@ -148,18 +175,10 @@ class Trial:
|
||||
# Local trial state that is updated during the run
|
||||
self.last_result = {}
|
||||
self.last_update_time = -float("inf")
|
||||
self.checkpoint_freq = checkpoint_freq
|
||||
self.checkpoint_at_end = checkpoint_at_end
|
||||
|
||||
# stores in memory max/min/last result for each metric by trial
|
||||
self.metric_analysis = {}
|
||||
|
||||
self.sync_on_checkpoint = sync_on_checkpoint
|
||||
newest_checkpoint = Checkpoint(Checkpoint.DISK, restore_path)
|
||||
self.checkpoint_manager = CheckpointManager(keep_checkpoints_num,
|
||||
checkpoint_score_attr)
|
||||
self.checkpoint_manager.newest_checkpoint = newest_checkpoint
|
||||
|
||||
self.export_formats = export_formats
|
||||
self.status = Trial.PENDING
|
||||
self.start_time = None
|
||||
@@ -169,9 +188,27 @@ class Trial:
|
||||
self.last_debug = 0
|
||||
self.error_file = None
|
||||
self.error_msg = None
|
||||
self.num_failures = 0
|
||||
self.custom_trial_name = None
|
||||
|
||||
# Checkpointing fields
|
||||
if remote_checkpoint_dir:
|
||||
self.remote_checkpoint_dir_prefix = remote_checkpoint_dir
|
||||
else:
|
||||
self.remote_checkpoint_dir_prefix = None
|
||||
self.checkpoint_freq = checkpoint_freq
|
||||
self.checkpoint_at_end = checkpoint_at_end
|
||||
self.sync_on_checkpoint = sync_on_checkpoint
|
||||
newest_checkpoint = Checkpoint(Checkpoint.PERSISTENT, restore_path)
|
||||
self.checkpoint_manager = CheckpointManager(
|
||||
keep_checkpoints_num, checkpoint_score_attr,
|
||||
checkpoint_deleter(str(self), self.runner))
|
||||
self.checkpoint_manager.newest_checkpoint = newest_checkpoint
|
||||
|
||||
# Restoration fields
|
||||
self.restoring_from = None
|
||||
self.num_failures = 0
|
||||
self.num_consecutive_start_attempts = 0
|
||||
|
||||
# AutoML fields
|
||||
self.results = None
|
||||
self.best_result = None
|
||||
@@ -179,7 +216,6 @@ class Trial:
|
||||
self.extra_arg = None
|
||||
|
||||
self._nonjson_fields = [
|
||||
"checkpoint",
|
||||
"loggers",
|
||||
"sync_to_driver_fn",
|
||||
"results",
|
||||
@@ -192,7 +228,7 @@ class Trial:
|
||||
|
||||
@property
|
||||
def node_ip(self):
|
||||
return self.address.hostname
|
||||
return self.location.hostname
|
||||
|
||||
@property
|
||||
def checkpoint(self):
|
||||
@@ -202,6 +238,14 @@ class Trial:
|
||||
def generate_id(cls):
|
||||
return str(uuid.uuid1().hex)[:8]
|
||||
|
||||
@property
|
||||
def remote_checkpoint_dir(self):
|
||||
assert self.logdir, "Trial {}: logdir not initialized.".format(self)
|
||||
if not self.remote_checkpoint_dir_prefix:
|
||||
return None
|
||||
logdir_name = os.path.basename(self.logdir)
|
||||
return os.path.join(self.remote_checkpoint_dir_prefix, logdir_name)
|
||||
|
||||
@classmethod
|
||||
def create_logdir(cls, identifier, local_dir):
|
||||
local_dir = os.path.expanduser(local_dir)
|
||||
@@ -213,7 +257,6 @@ class Trial:
|
||||
|
||||
def init_logger(self):
|
||||
"""Init logger."""
|
||||
|
||||
if not self.result_logger:
|
||||
if not self.logdir:
|
||||
self.logdir = Trial.create_logdir(str(self), self.local_dir)
|
||||
@@ -239,24 +282,20 @@ class Trial:
|
||||
raise ValueError("Cannot update resources while Trial is running.")
|
||||
self.resources = Resources(cpu, gpu, **kwargs)
|
||||
|
||||
def sync_logger_to_new_location(self, worker_ip):
|
||||
"""Updates the logger location.
|
||||
|
||||
Also pushes logdir to worker_ip, allowing for cross-node recovery.
|
||||
"""
|
||||
if self.result_logger:
|
||||
self.result_logger.sync_results_to_new_location(worker_ip)
|
||||
self.set_location(Location(worker_ip))
|
||||
def set_runner(self, runner):
|
||||
self.runner = runner
|
||||
self.checkpoint_manager.delete = checkpoint_deleter(str(self), runner)
|
||||
|
||||
def set_location(self, location):
|
||||
"""Sets the location of the trial."""
|
||||
self.address = location
|
||||
self.location = location
|
||||
|
||||
def set_status(self, status):
|
||||
"""Sets the status of the trial."""
|
||||
if status == Trial.RUNNING and self.start_time is None:
|
||||
self.start_time = time.time()
|
||||
self.status = status
|
||||
if status == Trial.RUNNING:
|
||||
if self.start_time is None:
|
||||
self.start_time = time.time()
|
||||
|
||||
def close_logger(self):
|
||||
"""Closes logger."""
|
||||
@@ -266,7 +305,7 @@ class Trial:
|
||||
|
||||
def write_error_log(self, error_msg):
|
||||
if error_msg and self.logdir:
|
||||
self.num_failures += 1 # may be moved to outer scope?
|
||||
self.num_failures += 1
|
||||
self.error_file = os.path.join(self.logdir, "error.txt")
|
||||
with open(self.error_file, "a+") as f:
|
||||
f.write("Failure # {} (occurred at {})\n".format(
|
||||
@@ -276,7 +315,6 @@ class Trial:
|
||||
|
||||
def should_stop(self, result):
|
||||
"""Whether the given result meets this trial's stopping criteria."""
|
||||
|
||||
if result.get(DONE):
|
||||
return True
|
||||
|
||||
@@ -309,6 +347,7 @@ class Trial:
|
||||
|
||||
def clear_checkpoint(self):
|
||||
self.checkpoint.value = None
|
||||
self.restoring_from = None
|
||||
|
||||
def on_checkpoint(self, checkpoint):
|
||||
"""Hook for handling checkpoints taken by the Trainable.
|
||||
@@ -316,20 +355,52 @@ class Trial:
|
||||
Args:
|
||||
checkpoint (Checkpoint): Checkpoint taken.
|
||||
"""
|
||||
if self.sync_on_checkpoint and checkpoint.storage == Checkpoint.DISK:
|
||||
# Wait for any other syncs to finish. We need to sync again after
|
||||
# this to handle checkpoints taken mid-sync.
|
||||
self.result_logger.wait()
|
||||
# Force sync down and wait before tracking the new checkpoint. This
|
||||
# prevents attempts to restore from partially synced checkpoints.
|
||||
if self.result_logger.sync_down():
|
||||
if checkpoint.storage == Checkpoint.MEMORY:
|
||||
# TODO(ujvl): Handle this separately to avoid restoration failure.
|
||||
self.checkpoint_manager.on_checkpoint(checkpoint)
|
||||
return
|
||||
if self.sync_on_checkpoint:
|
||||
try:
|
||||
# Wait for any other syncs to finish. We need to sync again
|
||||
# after this to handle checkpoints taken mid-sync.
|
||||
self.result_logger.wait()
|
||||
else:
|
||||
except TuneError as e:
|
||||
# Errors occurring during this wait are not fatal for this
|
||||
# checkpoint, so it should just be logged.
|
||||
logger.error(
|
||||
"Trial %s: Checkpoint sync skipped. "
|
||||
"This should not happen.", self)
|
||||
"Trial %s: An error occurred during the "
|
||||
"checkpoint pre-sync wait.", str(e))
|
||||
# Force sync down and wait before tracking the new checkpoint.
|
||||
try:
|
||||
if self.result_logger.sync_down():
|
||||
self.result_logger.wait()
|
||||
else:
|
||||
logger.error(
|
||||
"Trial %s: Checkpoint sync skipped. "
|
||||
"This should not happen.", self)
|
||||
except TuneError as e:
|
||||
if issubclass(self.get_trainable_cls(), DurableTrainable):
|
||||
# Even though rsync failed the trainable can restore
|
||||
# from remote durable storage.
|
||||
logger.error("Trial %s: Sync error - %s", self, str(e))
|
||||
else:
|
||||
# If the trainable didn't have remote storage to upload
|
||||
# to then this checkpoint may have been lost, so we
|
||||
# shouldn't track it with the checkpoint_manager.
|
||||
raise e
|
||||
if not issubclass(self.get_trainable_cls(), DurableTrainable):
|
||||
if not os.path.exists(checkpoint.value):
|
||||
raise TuneError("Trial {}: Checkpoint path {} not "
|
||||
"found after successful sync down.".format(
|
||||
self, checkpoint.value))
|
||||
self.checkpoint_manager.on_checkpoint(checkpoint)
|
||||
|
||||
def on_restore(self):
|
||||
"""Handles restoration completion."""
|
||||
assert self.is_restoring
|
||||
self.last_result = self.restoring_from.result
|
||||
self.restoring_from = None
|
||||
|
||||
def should_recover(self):
|
||||
"""Returns whether the trial qualifies for retrying.
|
||||
|
||||
@@ -375,7 +446,11 @@ class Trial:
|
||||
self.verbose = verbose
|
||||
|
||||
def is_finished(self):
|
||||
return self.status in [Trial.TERMINATED, Trial.ERROR]
|
||||
return self.status in [Trial.ERROR, Trial.TERMINATED]
|
||||
|
||||
@property
|
||||
def is_restoring(self):
|
||||
return self.restoring_from is not None
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
@@ -383,7 +458,7 @@ class Trial:
|
||||
def __str__(self):
|
||||
"""Combines ``env`` with ``trainable_name`` and ``trial_id``.
|
||||
|
||||
Can be overriden with a custom string creator.
|
||||
Can be overridden with a custom string creator.
|
||||
"""
|
||||
if self.custom_trial_name:
|
||||
return self.custom_trial_name
|
||||
@@ -402,9 +477,9 @@ class Trial:
|
||||
"""Memento generator for Trial.
|
||||
|
||||
Sets RUNNING trials to PENDING, and flushes the result logger.
|
||||
Note this can only occur if the trial holds a DISK checkpoint.
|
||||
Note this can only occur if the trial holds a PERSISTENT checkpoint.
|
||||
"""
|
||||
assert self.checkpoint.storage == Checkpoint.DISK, (
|
||||
assert self.checkpoint.storage == Checkpoint.PERSISTENT, (
|
||||
"Checkpoint must not be in-memory.")
|
||||
state = self.__dict__.copy()
|
||||
state["resources"] = resources_to_json(self.resources)
|
||||
@@ -415,7 +490,7 @@ class Trial:
|
||||
state["runner"] = None
|
||||
state["result_logger"] = None
|
||||
if self.result_logger:
|
||||
self.result_logger.flush()
|
||||
self.result_logger.flush(sync_down=False)
|
||||
state["__logger_started__"] = True
|
||||
else:
|
||||
state["__logger_started__"] = False
|
||||
@@ -424,6 +499,7 @@ class Trial:
|
||||
def __setstate__(self, state):
|
||||
logger_started = state.pop("__logger_started__")
|
||||
state["resources"] = json_to_resources(state["resources"])
|
||||
|
||||
if state["status"] == Trial.RUNNING:
|
||||
state["status"] = Trial.PENDING
|
||||
for key in self._nonjson_fields:
|
||||
|
||||
@@ -39,8 +39,11 @@ class TrialExecutor:
|
||||
trial (Trial): Trial to checkpoint.
|
||||
status (Trial.status): Status to set trial to.
|
||||
"""
|
||||
logger.debug("Trial %s: Changing status from %s to %s.", trial,
|
||||
trial.status, status)
|
||||
if trial.status == status:
|
||||
logger.debug("Trial %s: Status %s unchanged.", trial, trial.status)
|
||||
else:
|
||||
logger.debug("Trial %s: Changing status from %s to %s.", trial,
|
||||
trial.status, status)
|
||||
trial.set_status(status)
|
||||
if status in [Trial.TERMINATED, Trial.ERROR]:
|
||||
self.try_checkpoint_metadata(trial)
|
||||
@@ -226,14 +229,15 @@ class TrialExecutor:
|
||||
raise NotImplementedError("Subclasses of TrialExecutor must provide "
|
||||
"restore() method")
|
||||
|
||||
def save(self, trial, storage=Checkpoint.DISK, result=None):
|
||||
def save(self, trial, storage=Checkpoint.PERSISTENT, result=None):
|
||||
"""Saves training state of this trial to a checkpoint.
|
||||
|
||||
If result is None, this trial's last result will be used.
|
||||
|
||||
Args:
|
||||
trial (Trial): The state of this trial to be saved.
|
||||
storage (str): Where to store the checkpoint. Defaults to DISK.
|
||||
storage (str): Where to store the checkpoint. Defaults to
|
||||
PERSISTENT.
|
||||
result (dict): The state of this trial as a dictionary to be saved.
|
||||
|
||||
Return:
|
||||
|
||||
@@ -7,7 +7,6 @@ from datetime import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
import types
|
||||
@@ -31,12 +30,6 @@ MAX_DEBUG_TRIALS = 20
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _naturalize(string):
|
||||
"""Provides a natural representation for string for nice sorting."""
|
||||
splits = re.split("([0-9]+)", string)
|
||||
return [int(text) if text.isdigit() else text.lower() for text in splits]
|
||||
|
||||
|
||||
def _find_newest_ckpt(ckpt_dir):
|
||||
"""Returns path to most recently modified checkpoint."""
|
||||
full_paths = [
|
||||
@@ -222,6 +215,9 @@ class TrialRunner:
|
||||
|
||||
# Try syncing down the upload directory.
|
||||
logger.info("Downloading from %s", self._remote_checkpoint_dir)
|
||||
# TODO(ujvl): Note that this syncs down the entire directory,
|
||||
# which may also contain trial checkpoints. We should selectively
|
||||
# sync the necessary files instead.
|
||||
self._syncer.sync_down_if_needed()
|
||||
|
||||
if not self.checkpoint_exists(self._local_checkpoint_dir):
|
||||
@@ -417,11 +413,18 @@ class TrialRunner:
|
||||
with warn_if_slow("process_failed_trial"):
|
||||
self._process_trial_failure(failed_trial, error_msg=error_msg)
|
||||
else:
|
||||
# TODO(ujvl): Consider combining get_next_available_trial and
|
||||
# fetch_result functionality so that we don't timeout on fetch.
|
||||
trial = self.trial_executor.get_next_available_trial() # blocking
|
||||
with warn_if_slow("process_trial"):
|
||||
self._process_trial(trial)
|
||||
if trial.is_restoring:
|
||||
with warn_if_slow("process_trial_restore"):
|
||||
self._process_trial_restore(trial)
|
||||
else:
|
||||
with warn_if_slow("process_trial"):
|
||||
self._process_trial(trial)
|
||||
|
||||
def _process_trial(self, trial):
|
||||
"""Processes a trial result."""
|
||||
try:
|
||||
result = self.trial_executor.fetch_result(trial)
|
||||
|
||||
@@ -479,7 +482,24 @@ class TrialRunner:
|
||||
assert False, "Invalid scheduling decision: {}".format(
|
||||
decision)
|
||||
except Exception:
|
||||
logger.exception("Error processing event.")
|
||||
logger.exception("Trial %s: Error processing event.", trial)
|
||||
self._process_trial_failure(trial, traceback.format_exc())
|
||||
|
||||
def _process_trial_restore(self, trial):
|
||||
"""Processes a trial restore.
|
||||
|
||||
Args:
|
||||
trial: Trial being restored.
|
||||
"""
|
||||
logger.debug("Trial %s: Processing trial restore.", trial)
|
||||
try:
|
||||
self.trial_executor.fetch_result(trial)
|
||||
trial.on_restore()
|
||||
logger.debug("Trial %s: Restore processed successfully", trial)
|
||||
self.trial_executor.set_status(trial, Trial.RUNNING)
|
||||
self.trial_executor.continue_training(trial)
|
||||
except Exception:
|
||||
logger.exception("Trial %s: Error processing restore.", trial)
|
||||
self._process_trial_failure(trial, traceback.format_exc())
|
||||
|
||||
def _process_trial_failure(self, trial, error_msg):
|
||||
@@ -504,8 +524,8 @@ class TrialRunner:
|
||||
"""Checkpoints trial based off trial.last_result."""
|
||||
if trial.should_checkpoint() or force:
|
||||
# Save trial runtime if possible
|
||||
if hasattr(trial, "runner") and trial.runner:
|
||||
self.trial_executor.save(trial, storage=Checkpoint.DISK)
|
||||
if trial.runner:
|
||||
self.trial_executor.save(trial, storage=Checkpoint.PERSISTENT)
|
||||
self.trial_executor.try_checkpoint_metadata(trial)
|
||||
|
||||
def _try_recover(self, trial, error_msg):
|
||||
@@ -517,30 +537,32 @@ class TrialRunner:
|
||||
trial (Trial): Trial to recover.
|
||||
error_msg (str): Error message from prior to invoking this method.
|
||||
"""
|
||||
try:
|
||||
self.trial_executor.stop_trial(
|
||||
trial,
|
||||
error=error_msg is not None,
|
||||
error_msg=error_msg,
|
||||
stop_logger=False)
|
||||
trial.result_logger.flush()
|
||||
if self.trial_executor.has_resources(trial.resources):
|
||||
logger.info(
|
||||
"Trial %s: Attempting to recover "
|
||||
"trial state from last checkpoint.", trial)
|
||||
self.trial_executor.start_trial(trial)
|
||||
if trial.status == Trial.ERROR:
|
||||
logger.error("Trial %s: Did not start correctly.", trial)
|
||||
raise RuntimeError("Trial did not start correctly.")
|
||||
logger.debug("Trial %s: Started correctly.", trial)
|
||||
if trial.is_restoring:
|
||||
# Restore was unsuccessful, try again without checkpoint.
|
||||
trial.clear_checkpoint()
|
||||
self.trial_executor.stop_trial(
|
||||
trial,
|
||||
error=error_msg is not None,
|
||||
error_msg=error_msg,
|
||||
stop_logger=False)
|
||||
trial.result_logger.flush()
|
||||
if self.trial_executor.has_resources(trial.resources):
|
||||
logger.info(
|
||||
"Trial %s: Attempting to restore "
|
||||
"trial state from last checkpoint.", trial)
|
||||
self.trial_executor.start_trial(trial)
|
||||
if trial.status == Trial.ERROR:
|
||||
logger.exception(
|
||||
"Trial %s: Error restoring trial from checkpoint, abort.",
|
||||
trial)
|
||||
self._scheduler_alg.on_trial_error(self, trial)
|
||||
self._search_alg.on_trial_complete(trial.trial_id, error=True)
|
||||
else:
|
||||
logger.debug("Trial %s: Notifying Scheduler and requeueing.",
|
||||
trial)
|
||||
self._requeue_trial(trial)
|
||||
except Exception:
|
||||
logger.exception("Error recovering trial from checkpoint, abort.")
|
||||
self._scheduler_alg.on_trial_error(self, trial)
|
||||
self._search_alg.on_trial_complete(trial.trial_id, error=True)
|
||||
logger.debug("Trial %s: Restore dispatched correctly.", trial)
|
||||
else:
|
||||
logger.debug("Trial %s: Notifying Scheduler and requeueing.",
|
||||
trial)
|
||||
self._requeue_trial(trial)
|
||||
|
||||
def _requeue_trial(self, trial):
|
||||
"""Notification to TrialScheduler and requeue trial.
|
||||
|
||||
+10
-10
@@ -119,8 +119,8 @@ def run(run_or_experiment,
|
||||
`num_samples` of times.
|
||||
local_dir (str): Local dir to save training results to.
|
||||
Defaults to ``~/ray_results``.
|
||||
upload_dir (str): Optional URI to sync training results to
|
||||
(e.g. ``s3://bucket``).
|
||||
upload_dir (str): Optional URI to sync training results and checkpoints
|
||||
to (e.g. ``s3://bucket`` or ``gs://bucket``).
|
||||
trial_name_creator (func): Optional function for generating
|
||||
the trial string representation.
|
||||
loggers (list): List of logger creators to be used with
|
||||
@@ -130,20 +130,20 @@ def run(run_or_experiment,
|
||||
from upload_dir. If string, then it must be a string template that
|
||||
includes `{source}` and `{target}` for the syncer to run. If not
|
||||
provided, the sync command defaults to standard S3 or gsutil sync
|
||||
comamnds.
|
||||
sync_to_driver (func|str): Function for syncing trial logdir from
|
||||
commands.
|
||||
sync_to_driver (func|str|bool): Function for syncing trial logdir from
|
||||
remote node to local. If string, then it must be a string template
|
||||
that includes `{source}` and `{target}` for the syncer to run.
|
||||
If not provided, defaults to using rsync.
|
||||
If True or not provided, it defaults to using rsync. If False,
|
||||
syncing to driver is disabled.
|
||||
checkpoint_freq (int): How many training iterations between
|
||||
checkpoints. A value of 0 (default) disables checkpointing.
|
||||
checkpoint_at_end (bool): Whether to checkpoint at the end of the
|
||||
experiment regardless of the checkpoint_freq. Default is False.
|
||||
sync_on_checkpoint (bool): Force sync-down of trial checkpoint, to
|
||||
guarantee recoverability. If set to False, checkpoint syncing from
|
||||
worker to driver is asynchronous. Set this to False only if
|
||||
synchronous checkpointing is too slow and trial restoration
|
||||
failures can be tolerated. Defaults to True.
|
||||
sync_on_checkpoint (bool): Force sync-down of trial checkpoint to
|
||||
driver. If set to False, checkpoint syncing from worker to driver
|
||||
is asynchronous and best-effort. This does not affect persistent
|
||||
storage syncing. Defaults to True.
|
||||
keep_checkpoints_num (int): Number of checkpoints to keep. A value of
|
||||
`None` keeps all checkpoints. Defaults to `None`. If set, need
|
||||
to provide `checkpoint_score_attr`.
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from ray.rllib.agents.mock import _MockTrainer
|
||||
from ray.tune import DurableTrainable
|
||||
from ray.tune.sync_client import get_sync_client
|
||||
from ray.tune.syncer import NodeSyncer
|
||||
|
||||
MOCK_REMOTE_DIR = "/tmp/mock-tune-remote/"
|
||||
# Sync and delete templates that operate on local directories.
|
||||
LOCAL_SYNC_TEMPLATE = "mkdir -p {target} && rsync -avz {source}/ {target}/"
|
||||
LOCAL_DELETE_TEMPLATE = "rm -rf {target}"
|
||||
|
||||
|
||||
def mock_storage_client():
|
||||
"""Mocks storage client that treats a local dir as durable storage."""
|
||||
return get_sync_client(LOCAL_SYNC_TEMPLATE, LOCAL_DELETE_TEMPLATE)
|
||||
|
||||
|
||||
class MockNodeSyncer(NodeSyncer):
|
||||
"""Mock NodeSyncer that syncs to and from /tmp"""
|
||||
|
||||
def has_remote_target(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def _remote_path(self):
|
||||
if self._remote_dir.startswith("/"):
|
||||
self._remote_dir = self._remote_dir[1:]
|
||||
return os.path.join(MOCK_REMOTE_DIR, self._remote_dir)
|
||||
|
||||
|
||||
class MockRemoteTrainer(_MockTrainer):
|
||||
"""Mock Trainable that saves at tmp for simulated clusters."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MockRemoteTrainer, self).__init__(*args, **kwargs)
|
||||
if self._logdir.startswith("/"):
|
||||
self._logdir = self._logdir[1:]
|
||||
self._logdir = os.path.join(MOCK_REMOTE_DIR, self._logdir)
|
||||
if not os.path.exists(self._logdir):
|
||||
os.makedirs(self._logdir)
|
||||
|
||||
|
||||
class MockDurableTrainer(DurableTrainable, _MockTrainer):
|
||||
"""Mock DurableTrainable that saves at tmp for simulated clusters."""
|
||||
|
||||
# TODO(ujvl): This class uses multiple inheritance; it should be cleaned
|
||||
# up once the durable training API converges.
|
||||
|
||||
def __init__(self, remote_checkpoint_dir, *args, **kwargs):
|
||||
_MockTrainer.__init__(self, *args, **kwargs)
|
||||
DurableTrainable.__init__(self, remote_checkpoint_dir, *args, **kwargs)
|
||||
|
||||
def _create_storage_client(self):
|
||||
return mock_storage_client()
|
||||
Reference in New Issue
Block a user