[tune] Change the log syncing behavior (#4450)

* Change the log syncing behavior

* fix up abstractions for syncer

* Finished checkpoint syncing

* Code

* Set of changes to get things running

* Fixes for log syncing

* Fix parts

* Lint and other fixes

* fix some test

* Remove extra parsing functionality

* some test fixes

* Fix up cloud syncing

* Another thing to do

* Fix up tests and local sync

Changes LogSync into a mixin, and adds tests for different
functionalities.

* Fix up tests, start on local migration

* fix distributed migrations

* comments

* formatting

* Better checkpoint directory handling

* fix tests

* fix tests

* fix click

* comments

* formatting comments

* formatting and comments

* sync function deprecations

* sync function

* Add documentation for Syncing and Uploading

* nit

* BaseSyncer as base for Mixin in edge case

* more docs

* clean up assertions

* validate

* nit

* Update test_cluster.py

* better docs

* Update tune-usage.rst

* cleanup

* nit
This commit is contained in:
Kristian Hartikainen
2019-07-02 20:46:00 -07:00
committed by Richard Liaw
parent 71d4637b75
commit 9e0192bc0b
14 changed files with 718 additions and 466 deletions
+12
View File
@@ -10,6 +10,7 @@ import yaml
import ray
from ray.tests.cluster_utils import Cluster
from ray.tune.config_parser import make_parser
from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.trial import resources_to_json
from ray.tune.tune import _make_scheduler, run_experiments
@@ -71,6 +72,17 @@ def create_parser(parser_creator=None):
default="default",
type=str,
help="Name of the subdirectory under `local_dir` to put results in.")
parser.add_argument(
"--local-dir",
default=DEFAULT_RESULTS_DIR,
type=str,
help="Local dir to save training results to. Defaults to '{}'.".format(
DEFAULT_RESULTS_DIR))
parser.add_argument(
"--upload-dir",
default="",
type=str,
help="Optional URI to sync training results to (e.g. s3://bucket).")
parser.add_argument(
"--resume",
action="store_true",
+2 -15
View File
@@ -10,7 +10,6 @@ import os
from six import string_types
from ray.tune import TuneError
from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.trial import Trial, json_to_resources
from ray.tune.logger import _SafeFallbackEncoder
@@ -65,17 +64,6 @@ def make_parser(parser_creator=None, **kwargs):
default=1,
type=int,
help="Number of times to repeat each trial.")
parser.add_argument(
"--local-dir",
default=DEFAULT_RESULTS_DIR,
type=str,
help="Local dir to save training results to. Defaults to '{}'.".format(
DEFAULT_RESULTS_DIR))
parser.add_argument(
"--upload-dir",
default="",
type=str,
help="Optional URI to sync training results to (e.g. s3://bucket).")
parser.add_argument(
"--checkpoint-freq",
default=0,
@@ -183,7 +171,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
trainable_name=spec["run"],
# json.load leads to str -> unicode in py2.7
config=spec.get("config", {}),
local_dir=os.path.join(args.local_dir, output_path),
local_dir=os.path.join(spec["local_dir"], output_path),
# json.load leads to str -> unicode in py2.7
stopping_criterion=spec.get("stop", {}),
checkpoint_freq=args.checkpoint_freq,
@@ -193,10 +181,9 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
export_formats=spec.get("export_formats", []),
# str(None) doesn't create None
restore_path=spec.get("restore"),
upload_dir=args.upload_dir,
trial_name_creator=spec.get("trial_name_creator"),
loggers=spec.get("loggers"),
# str(None) doesn't create None
sync_function=spec.get("sync_function"),
sync_to_driver_fn=spec.get("sync_to_driver"),
max_failures=args.max_failures,
**trial_kwargs)
+5 -8
View File
@@ -11,9 +11,8 @@ import random
import numpy as np
import ray
from ray import tune
from ray.tune import Trainable, run, Experiment
from ray.tune import Trainable, run
class TestLogger(tune.logger.Logger):
@@ -60,11 +59,11 @@ if __name__ == "__main__":
parser.add_argument(
"--smoke-test", action="store_true", help="Finish quickly for testing")
args, _ = parser.parse_known_args()
ray.init()
exp = Experiment(
trials = run(
MyTrainableClass,
name="hyperband_test",
run=MyTrainableClass,
num_samples=1,
num_samples=5,
trial_name_creator=tune.function(trial_str_creator),
loggers=[TestLogger],
stop={"training_iteration": 1 if args.smoke_test else 99999},
@@ -73,5 +72,3 @@ if __name__ == "__main__":
lambda spec: 10 + int(90 * random.random())),
"height": tune.sample_from(lambda spec: int(100 * random.random()))
})
trials = run(exp)
+15 -12
View File
@@ -52,7 +52,6 @@ class Experiment(object):
>>> },
>>> num_samples=10,
>>> local_dir="~/ray_results",
>>> upload_dir="s3://your_bucket/path",
>>> checkpoint_freq=10,
>>> max_failures=2)
"""
@@ -68,7 +67,7 @@ class Experiment(object):
upload_dir=None,
trial_name_creator=None,
loggers=None,
sync_function=None,
sync_to_driver=None,
checkpoint_freq=0,
checkpoint_at_end=False,
keep_checkpoints_num=None,
@@ -78,18 +77,16 @@ class Experiment(object):
restore=None,
repeat=None,
trial_resources=None,
custom_loggers=None):
if sync_function:
assert upload_dir, "Need `upload_dir` if sync_function given."
custom_loggers=None,
sync_function=None):
if repeat:
_raise_deprecation_note("repeat", "num_samples", soft=False)
if trial_resources:
_raise_deprecation_note(
"trial_resources", "resources_per_trial", soft=False)
if custom_loggers:
_raise_deprecation_note("custom_loggers", "loggers", soft=False)
if sync_function:
_raise_deprecation_note(
"sync_function", "sync_to_driver", soft=False)
run_identifier = Experiment._register_if_needed(run)
spec = {
"run": run_identifier,
@@ -98,10 +95,10 @@ class Experiment(object):
"resources_per_trial": resources_per_trial,
"num_samples": num_samples,
"local_dir": os.path.expanduser(local_dir or DEFAULT_RESULTS_DIR),
"upload_dir": upload_dir or "", # argparse converts None to "null"
"upload_dir": upload_dir,
"trial_name_creator": trial_name_creator,
"loggers": loggers,
"sync_function": sync_function,
"sync_to_driver": sync_to_driver,
"checkpoint_freq": checkpoint_freq,
"checkpoint_at_end": checkpoint_at_end,
"keep_checkpoints_num": keep_checkpoints_num,
@@ -182,7 +179,13 @@ class Experiment(object):
@property
def checkpoint_dir(self):
return os.path.join(self.spec["local_dir"], self.name)
if self.local_dir:
return os.path.join(self.local_dir, self.name)
@property
def remote_checkpoint_dir(self):
if self.spec["upload_dir"]:
return os.path.join(self.spec["upload_dir"], self.name)
def convert_to_experiment_list(experiments):
+41 -208
View File
@@ -4,11 +4,6 @@ from __future__ import print_function
import distutils.spawn
import logging
import os
import subprocess
import tempfile
import time
import types
try: # py3
from shlex import quote
@@ -17,231 +12,69 @@ except ImportError: # py2
import ray
from ray.tune.cluster_info import get_ssh_key, get_ssh_user
from ray.tune.error import TuneError
from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.sample import function as tune_function
logger = logging.getLogger(__name__)
_log_sync_warned = False
# Map from (logdir, remote_dir) -> syncer
_syncers = {}
S3_PREFIX = "s3://"
GCS_PREFIX = "gs://"
ALLOWED_REMOTE_PREFIXES = (S3_PREFIX, GCS_PREFIX)
def log_sync_template():
"""Syncs the local_dir between driver and worker if possible.
Requires ray cluster to be started with the autoscaler. Also requires
rsync to be installed.
def get_syncer(local_dir, remote_dir=None, sync_function=None):
if remote_dir:
if not sync_function and not any(
remote_dir.startswith(prefix)
for prefix in ALLOWED_REMOTE_PREFIXES):
raise TuneError("Upload uri must start with one of: {}"
"".format(ALLOWED_REMOTE_PREFIXES))
if (remote_dir.startswith(S3_PREFIX)
and not distutils.spawn.find_executable("aws")):
raise TuneError(
"Upload uri starting with '{}' requires awscli tool"
" to be installed".format(S3_PREFIX))
elif (remote_dir.startswith(GCS_PREFIX)
and not distutils.spawn.find_executable("gsutil")):
raise TuneError(
"Upload uri starting with '{}' requires gsutil tool"
" to be installed".format(GCS_PREFIX))
if local_dir.startswith(DEFAULT_RESULTS_DIR + "/"):
rel_path = os.path.relpath(local_dir, DEFAULT_RESULTS_DIR)
remote_dir = os.path.join(remote_dir, rel_path)
key = (local_dir, remote_dir)
if key not in _syncers:
_syncers[key] = _LogSyncer(local_dir, remote_dir, sync_function)
return _syncers[key]
def wait_for_log_sync():
for syncer in _syncers.values():
syncer.wait()
def validate_sync_function(sync_function):
if sync_function is None:
return
elif isinstance(sync_function, str):
assert "{remote_dir}" in sync_function, (
"Sync template missing '{remote_dir}'.")
assert "{local_dir}" in sync_function, (
"Sync template missing '{local_dir}'.")
elif not (isinstance(sync_function, types.FunctionType)
or isinstance(sync_function, tune_function)):
raise ValueError("Sync function {} must be string or function".format(
sync_function))
class _LogSyncer(object):
"""Log syncer for tune.
This syncs files from workers to the local node, and optionally also from
the local node to a remote directory (e.g. S3).
Arguments:
logdir (str): Directory to sync from.
upload_uri (str): Directory to sync to.
sync_function (func|str): Function for syncing the local_dir to
upload_dir. If string, then it must be a string template
for syncer to run and needs to include replacement fields
'{local_dir}' and '{remote_dir}'.
"""
if not distutils.spawn.find_executable("rsync"):
logger.error("Log sync requires rsync to be installed.")
return
global _log_sync_warned
ssh_key = get_ssh_key()
if ssh_key is None:
if not _log_sync_warned:
logger.error("Log sync requires cluster to be setup with "
"`ray up`.")
_log_sync_warned = True
return
def __init__(self, local_dir, remote_dir=None, sync_function=None):
self.local_dir = local_dir
self.remote_dir = remote_dir
self.logfile = tempfile.NamedTemporaryFile(
prefix="log_sync", dir=self.local_dir, suffix=".log", delete=False)
return ("""rsync -savz -e "ssh -i {ssh_key} -o ConnectTimeout=120s """
"""-o StrictHostKeyChecking=no" {{source}} {{target}}"""
).format(ssh_key=quote(ssh_key))
# Resolve sync_function into template or function
self.sync_func = None
self.sync_cmd_tmpl = None
if isinstance(sync_function, types.FunctionType) or isinstance(
sync_function, tune_function):
self.sync_func = sync_function
elif isinstance(sync_function, str):
self.sync_cmd_tmpl = sync_function
self.last_sync_time = 0
self.sync_process = None
class NodeSyncMixin():
"""Mixin for syncing files to/from a remote dir to a local dir."""
def __init__(self):
assert hasattr(self, "_remote_dir"), "Mixin not mixed with Syncer."
self.local_ip = ray.services.get_node_ip_address()
self.worker_ip = None
logger.debug("Created LogSyncer for {} -> {}".format(
local_dir, remote_dir))
def close(self):
self.logfile.close()
def set_worker_ip(self, worker_ip):
"""Set the worker ip to sync logs from."""
self.worker_ip = worker_ip
def sync_if_needed(self):
if time.time() - self.last_sync_time > 300:
self.sync_now()
def sync_to_worker_if_possible(self):
"""Syncs the local logdir on driver to worker if possible.
Requires ray cluster to be started with the autoscaler. Also requires
rsync to be installed.
"""
def _check_valid_worker_ip(self):
if not self.worker_ip:
logger.info("Worker ip unknown, skipping log sync for {}".format(
self._local_dir))
return False
if self.worker_ip == self.local_ip:
return
ssh_key = get_ssh_key()
logger.debug(
"Worker ip is local ip, skipping log sync for {}".format(
self._local_dir))
return False
return True
@property
def _remote_path(self):
ssh_user = get_ssh_user()
global _log_sync_warned
if ssh_key is None or ssh_user is None:
if not self._check_valid_worker_ip():
return
if ssh_user is None:
if not _log_sync_warned:
logger.error("Log sync requires cluster to be setup with "
"`ray up`.")
_log_sync_warned = True
return
if not distutils.spawn.find_executable("rsync"):
logger.error("Log sync requires rsync to be installed.")
return
source = "{}/".format(self.local_dir)
target = "{}@{}:{}/".format(ssh_user, self.worker_ip, self.local_dir)
final_cmd = (("""rsync -savz -e "ssh -i {} -o ConnectTimeout=120s """
"""-o StrictHostKeyChecking=no" {} {}""").format(
quote(ssh_key), quote(source), quote(target)))
logger.info("Syncing results to %s", str(self.worker_ip))
sync_process = subprocess.Popen(
final_cmd, shell=True, stdout=self.logfile)
sync_process.wait()
def sync_now(self, force=False):
self.last_sync_time = time.time()
if not self.worker_ip:
logger.debug("Worker ip unknown, skipping log sync for {}".format(
self.local_dir))
return
if self.worker_ip == self.local_ip:
worker_to_local_sync_cmd = None # don't need to rsync
else:
ssh_key = get_ssh_key()
ssh_user = get_ssh_user()
global _log_sync_warned
if ssh_key is None or ssh_user is None:
if not _log_sync_warned:
logger.error("Log sync requires cluster to be setup with "
"`ray up`.")
_log_sync_warned = True
return
if not distutils.spawn.find_executable("rsync"):
logger.error("Log sync requires rsync to be installed.")
return
source = "{}@{}:{}/".format(ssh_user, self.worker_ip,
self.local_dir)
target = "{}/".format(self.local_dir)
worker_to_local_sync_cmd = ((
"""rsync -savz -e "ssh -i {} -o ConnectTimeout=120s """
"""-o StrictHostKeyChecking=no" {} {}""").format(
quote(ssh_key), quote(source), quote(target)))
if self.remote_dir:
if self.sync_func:
local_to_remote_sync_cmd = None
try:
self.sync_func(self.local_dir, self.remote_dir)
except Exception:
logger.exception("Sync function failed.")
else:
local_to_remote_sync_cmd = self.get_remote_sync_cmd()
else:
local_to_remote_sync_cmd = None
if self.sync_process:
self.sync_process.poll()
if self.sync_process.returncode is None:
if force:
self.sync_process.kill()
else:
logger.warning("Last sync is still in progress, skipping.")
return
if worker_to_local_sync_cmd or local_to_remote_sync_cmd:
final_cmd = ""
if worker_to_local_sync_cmd:
final_cmd += worker_to_local_sync_cmd
if local_to_remote_sync_cmd:
if final_cmd:
final_cmd += " && "
final_cmd += local_to_remote_sync_cmd
logger.debug("Running log sync: {}".format(final_cmd))
self.sync_process = subprocess.Popen(
final_cmd, shell=True, stdout=self.logfile)
def wait(self):
if self.sync_process:
self.sync_process.wait()
def get_remote_sync_cmd(self):
if self.sync_cmd_tmpl:
local_to_remote_sync_cmd = (self.sync_cmd_tmpl.format(
local_dir=quote(self.local_dir),
remote_dir=quote(self.remote_dir)))
elif self.remote_dir.startswith(S3_PREFIX):
local_to_remote_sync_cmd = (
"aws s3 sync {local_dir} {remote_dir}".format(
local_dir=quote(self.local_dir),
remote_dir=quote(self.remote_dir)))
elif self.remote_dir.startswith(GCS_PREFIX):
local_to_remote_sync_cmd = (
"gsutil rsync -r {local_dir} {remote_dir}".format(
local_dir=quote(self.local_dir),
remote_dir=quote(self.remote_dir)))
else:
logger.warning("Remote sync unsupported, skipping.")
local_to_remote_sync_cmd = None
return local_to_remote_sync_cmd
return "{}@{}:{}/".format(ssh_user, self.worker_ip, self._remote_dir)
+18 -22
View File
@@ -13,7 +13,7 @@ import numbers
import numpy as np
import ray.cloudpickle as cloudpickle
from ray.tune.log_sync import get_syncer
from ray.tune.syncer import get_log_syncer
from ray.tune.result import NODE_IP, TRAINING_ITERATION, TIME_TOTAL_S, \
TIMESTEPS_TOTAL
@@ -33,13 +33,11 @@ class Logger(object):
Arguments:
config: Configuration passed to all logger creators.
logdir: Directory for all logger creators to log to.
upload_uri (str): Optional URI where the logdir is sync'ed to.
"""
def __init__(self, config, logdir, upload_uri=None):
def __init__(self, config, logdir):
self.config = config
self.logdir = logdir
self.uri = upload_uri
self._init()
def _init(self):
@@ -196,24 +194,16 @@ DEFAULT_LOGGERS = (JsonLogger, CSVLogger, TFLogger)
class UnifiedLogger(Logger):
"""Unified result logger for TensorBoard, rllab/viskit, plain json.
This class also periodically syncs output to the given upload uri.
Arguments:
config: Configuration passed to all logger creators.
logdir: Directory for all logger creators to log to.
upload_uri (str): Optional URI where the logdir is sync'ed to.
loggers (list): List of logger creators. Defaults to CSV, Tensorboard,
and JSON loggers.
sync_function (func|str): Optional function for syncer to run.
See ray/python/ray/tune/log_sync.py
"""
def __init__(self,
config,
logdir,
upload_uri=None,
loggers=None,
sync_function=None):
def __init__(self, config, logdir, loggers=None, sync_function=None):
if loggers is None:
self._logger_cls_list = DEFAULT_LOGGERS
else:
@@ -221,24 +211,26 @@ class UnifiedLogger(Logger):
self._sync_function = sync_function
self._log_syncer = None
Logger.__init__(self, config, logdir, upload_uri)
super(UnifiedLogger, self).__init__(config, logdir)
def _init(self):
self._loggers = []
for cls in self._logger_cls_list:
try:
self._loggers.append(cls(self.config, self.logdir, self.uri))
self._loggers.append(cls(self.config, self.logdir))
except Exception:
logger.warning("Could not instantiate {} - skipping.".format(
str(cls)))
self._log_syncer = get_syncer(
self.logdir, self.uri, sync_function=self._sync_function)
self._log_syncer = get_log_syncer(
self.logdir,
remote_dir=self.logdir,
sync_function=self._sync_function)
def on_result(self, result):
for _logger in self._loggers:
_logger.on_result(result)
self._log_syncer.set_worker_ip(result.get(NODE_IP))
self._log_syncer.sync_if_needed()
self._log_syncer.sync_down_if_needed()
def update_config(self, config):
for _logger in self._loggers:
@@ -247,13 +239,12 @@ class UnifiedLogger(Logger):
def close(self):
for _logger in self._loggers:
_logger.close()
self._log_syncer.sync_now(force=False)
self._log_syncer.close()
self._log_syncer.sync_down()
def flush(self):
for _logger in self._loggers:
_logger.flush()
self._log_syncer.sync_now(force=False)
self._log_syncer.sync_down()
def sync_results_to_new_location(self, worker_ip):
"""Sends the current log directory to the remote node.
@@ -262,8 +253,13 @@ class UnifiedLogger(Logger):
with the Ray autoscaler.
"""
if worker_ip != self._log_syncer.worker_ip:
logger.info("Syncing (blocking) results to {}".format(worker_ip))
self._log_syncer.reset()
self._log_syncer.set_worker_ip(worker_ip)
self._log_syncer.sync_to_worker_if_possible()
self._log_syncer.sync_up()
# TODO: change this because this is blocking. But failures
# are rare, so maybe this is OK?
self._log_syncer.wait()
class _SafeFallbackEncoder(json.JSONEncoder):
+266
View File
@@ -0,0 +1,266 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import distutils
import logging
import os
import subprocess
import tempfile
import time
import types
try: # py3
from shlex import quote
except ImportError: # py2
from pipes import quote
from ray.tune.sample import function as tune_function
from ray.tune.error import TuneError
from ray.tune.log_sync import log_sync_template, NodeSyncMixin
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# URI prefixes for which get_syncer() knows a default sync command.
S3_PREFIX = "s3://"
GS_PREFIX = "gs://"
ALLOWED_REMOTE_PREFIXES = (S3_PREFIX, GS_PREFIX)
# Minimum number of seconds between two syncs triggered through the
# `sync_up_if_needed` / `sync_down_if_needed` helpers.
SYNC_PERIOD = 300
# Cache of already-created syncers, keyed by (local_dir, remote_dir).
_syncers = {}
def validate_sync_string(sync_string):
    """Check that a sync command template has both replacement fields.

    Args:
        sync_string (str): Command template to validate.

    Raises:
        ValueError: If '{source}' or '{target}' does not occur in
            `sync_string`. '{source}' is checked first.
    """
    required_fields = (
        ("{source}", "Sync template missing '{source}'."),
        ("{target}", "Sync template missing '{target}'."),
    )
    for placeholder, error_message in required_fields:
        if placeholder not in sync_string:
            raise ValueError(error_message)
def wait_for_sync():
    """Call `wait()` on every syncer held in the module-level cache."""
    cached_syncers = _syncers.values()
    for cached_syncer in cached_syncers:
        cached_syncer.wait()
class BaseSyncer(object):
    """Syncs a local directory with a remote one through a Python callable.

    The syncer keeps track of when the last upload/download was attempted so
    that the `*_if_needed` helpers can rate-limit syncing to one attempt per
    `SYNC_PERIOD` seconds.
    """

    def __init__(self, local_dir, remote_dir, sync_function=None):
        """Syncs between two directories with the sync_function.

        Arguments:
            local_dir (str): Directory to sync. Uniquely identifies the syncer.
            remote_dir (str): Remote directory to sync with.
            sync_function (func): Function for syncing the local_dir to
                remote_dir. Called as `sync_function(source, target)`.
                Defaults to a Noop.
        """
        # os.path.join(path, "") appends a trailing separator when missing;
        # a falsy local_dir (None / "") is stored untouched.
        self._local_dir = (os.path.join(local_dir, "")
                           if local_dir else local_dir)
        self._remote_dir = remote_dir
        # -inf guarantees that the first `*_if_needed` call always syncs.
        self.last_sync_up_time = float("-inf")
        self.last_sync_down_time = float("-inf")
        self._sync_function = sync_function or (lambda source, target: None)

    def sync_function(self, source, target):
        """Executes sync between source and target.

        Can be overwritten by subclasses for custom sync procedures.

        Args:
            source: Path to source file(s).
            target: Path to target file(s).

        Returns:
            Whatever the wrapped callable returns.
        """
        if self._sync_function:
            return self._sync_function(source, target)

    def sync(self, source, target):
        """Runs `sync_function(source, target)` and reports whether it ran.

        Exceptions raised by the sync function are logged, not propagated.

        Args:
            source: Path to sync from.
            target: Path to sync to.

        Returns:
            bool: True if the sync function was invoked without raising,
                False if the sync was skipped (empty source or target) or
                the sync function raised.
        """
        if not (source and target):
            logger.debug(
                "Source or target is empty, skipping log sync for {}".format(
                    self._local_dir))
            return False
        try:
            self.sync_function(source, target)
        except Exception:
            logger.exception("Sync function failed.")
            return False
        return True

    def sync_up_if_needed(self):
        """Syncs up if the last attempt is older than `SYNC_PERIOD`."""
        if time.time() - self.last_sync_up_time > SYNC_PERIOD:
            self.sync_up()

    def sync_down_if_needed(self):
        """Syncs down if the last attempt is older than `SYNC_PERIOD`."""
        if time.time() - self.last_sync_down_time > SYNC_PERIOD:
            self.sync_down()

    def sync_down(self, *args, **kwargs):
        """Syncs from the remote path into the local directory.

        Extra arguments are forwarded to `sync`.

        Returns:
            bool: The result of `sync`.
        """
        succeeded = self.sync(self._remote_path, self._local_dir, *args,
                              **kwargs)
        # The timestamp records the attempt, whether or not it succeeded.
        self.last_sync_down_time = time.time()
        return succeeded

    def sync_up(self, *args, **kwargs):
        """Syncs from the local directory to the remote path.

        Extra arguments are forwarded to `sync`.

        Returns:
            bool: The result of `sync`.
        """
        succeeded = self.sync(self._local_dir, self._remote_path, *args,
                              **kwargs)
        # The timestamp records the attempt, whether or not it succeeded.
        self.last_sync_up_time = time.time()
        return succeeded

    def reset(self):
        """Forgets previous sync attempts so the next `*_if_needed` syncs."""
        self.last_sync_up_time = float("-inf")
        self.last_sync_down_time = float("-inf")

    def wait(self):
        """No-op here; subclasses may block until a pending sync finishes."""
        pass

    @property
    def _remote_path(self):
        """Protected method for accessing remote_dir.

        Can be overridden in subclass for custom path.
        """
        return self._remote_dir
class CommandSyncer(BaseSyncer):
    """Syncs two directories by running a shell command in a subprocess.

    The command is started in the background; at most one sync process is
    tracked at a time and its stdout is written to a log file created inside
    the local directory.
    """

    def __init__(self, local_dir, remote_dir, sync_template):
        """Syncs between two directories with the given command.

        Arguments:
            local_dir (str): Directory to sync.
            remote_dir (str): Remote directory to sync with.
            sync_template (str): A string template
                for syncer to run and needs to include replacement fields
                '{source}' and '{target}'. Returned when using
                `CommandSyncer.sync_template`, which can be overridden
                by subclass.

        Raises:
            ValueError: If `sync_template` is not a string or lacks one of
                the replacement fields.
        """
        super(CommandSyncer, self).__init__(local_dir, remote_dir)
        if not isinstance(sync_template, str):
            raise ValueError("{} is not a string.".format(sync_template))
        validate_sync_string(sync_template)
        self._sync_template = sync_template
        # delete=False: the log file is kept on disk after it is closed.
        self.logfile = tempfile.NamedTemporaryFile(
            prefix="log_sync",
            dir=self._local_dir,
            suffix=".log",
            delete=False)
        self.sync_process = None
        # Defined up front so the attribute exists before the first sync.
        self.last_sync_time = float("-inf")

    def sync_function(self, source, target):
        """Launches the sync command unless a previous one is still running.

        Args:
            source: Path to sync from; shell-quoted into the template.
            target: Path to sync to; shell-quoted into the template.

        Returns:
            True if a new sync process was started, None if the previous
            sync is still in progress and this one was skipped.
        """
        self.last_sync_time = time.time()
        if self.sync_process:
            # poll() refreshes `returncode`; None means still running.
            self.sync_process.poll()
            if self.sync_process.returncode is None:
                logger.warning("Last sync is still in progress, skipping.")
                return
        final_cmd = self._sync_template.format(
            source=quote(source), target=quote(target))
        logger.debug("Running sync: {}".format(final_cmd))
        self.sync_process = subprocess.Popen(
            final_cmd, shell=True, stdout=self.logfile)
        return True

    def reset(self):
        """Drops the tracked sync process and resets the sync timestamps.

        A running process is not terminated, only forgotten; a warning is
        logged in that case.
        """
        if self.sync_process:
            # Only warn when the process has really not exited yet; a
            # finished process is dropped silently.
            if self.sync_process.poll() is None:
                logger.warning(
                    "Sync process still running but resetting anyways.")
            self.sync_process = None
        super(CommandSyncer, self).reset()

    def wait(self):
        """Blocks until the tracked sync process (if any) has exited."""
        if self.sync_process:
            self.sync_process.wait()
def _get_sync_cls(sync_function):
    """Pick the syncer class matching the kind of `sync_function`.

    Args:
        sync_function (func | str): User-supplied sync function or command
            template.

    Returns:
        None for a falsy `sync_function`, `BaseSyncer` for a plain function
        or `tune.function` wrapper, `CommandSyncer` for a string.

    Raises:
        ValueError: If `sync_function` is truthy but of another type.
    """
    if not sync_function:
        return None

    callable_types = (types.FunctionType, tune_function)
    if isinstance(sync_function, callable_types):
        return BaseSyncer
    if isinstance(sync_function, str):
        return CommandSyncer
    raise ValueError("Sync function {} must be string or function".format(
        sync_function))
def get_syncer(local_dir, remote_dir=None, sync_function=None):
    """Returns a Syncer depending on given args.

    This syncer is in charge of syncing the local_dir with upload_dir.
    Syncers are cached per (local_dir, remote_dir) pair, so repeated calls
    with the same directories return the same object.

    Args:
        local_dir: Source directory for syncing.
        remote_dir: Target directory for syncing. If None,
            returns BaseSyncer with a noop.
        sync_function (func | str): Function for syncing the local_dir to
            remote_dir. If string, then it must be a string template for
            syncer to run. If not provided, it defaults
            to standard S3 or gsutil sync commands.

    Raises:
        TuneError: If no `sync_function` is given and `remote_dir` does not
            start with a supported prefix, or the command line tool needed
            for that prefix is not installed.
        ValueError: If `sync_function` is neither a string nor a function.
    """
    # A bare `import distutils` does not load the `spawn` submodule, so
    # import it explicitly instead of relying on another module having
    # done so.
    import distutils.spawn

    key = (local_dir, remote_dir)

    if key in _syncers:
        return _syncers[key]

    if not remote_dir:
        _syncers[key] = BaseSyncer(local_dir, remote_dir)
        return _syncers[key]

    # A user-provided sync function/template takes precedence over the
    # built-in cloud commands and skips the prefix validation below.
    sync_cls = _get_sync_cls(sync_function)

    if sync_cls:
        _syncers[key] = sync_cls(local_dir, remote_dir, sync_function)
        return _syncers[key]

    if remote_dir.startswith(S3_PREFIX):
        if not distutils.spawn.find_executable("aws"):
            raise TuneError(
                "Upload uri starting with '{}' requires awscli tool"
                " to be installed".format(S3_PREFIX))
        _syncers[key] = CommandSyncer(local_dir, remote_dir,
                                      "aws s3 sync {source} {target}")
    elif remote_dir.startswith(GS_PREFIX):
        if not distutils.spawn.find_executable("gsutil"):
            raise TuneError(
                "Upload uri starting with '{}' requires gsutil tool"
                " to be installed".format(GS_PREFIX))
        _syncers[key] = CommandSyncer(local_dir, remote_dir,
                                      "gsutil rsync -r {source} {target}")
    else:
        raise TuneError("Upload uri must start with one of: {}"
                        "".format(ALLOWED_REMOTE_PREFIXES))

    return _syncers[key]
def get_log_syncer(local_dir, remote_dir=None, sync_function=None):
    """Returns a Syncer depending on given args.

    This syncer is in charge of syncing the local_dir with remote local_dir.
    The returned object combines `NodeSyncMixin` with the syncer class
    selected from the arguments, and is cached per (local_dir, remote_dir).

    Args:
        local_dir: Source directory for syncing.
        remote_dir: Target directory for syncing. If None,
            returns BaseSyncer with noop.
        sync_function (func | str): Function for syncing the local_dir to
            remote_dir. If string, then it must be a string template for
            syncer to run. If not provided, it defaults to rsync.
    """
    cache_key = (local_dir, remote_dir)
    try:
        return _syncers[cache_key]
    except KeyError:
        pass

    if sync_function:
        base_cls = _get_sync_cls(sync_function)
    else:
        # No user choice: fall back to the command template produced by
        # log_sync_template(), which may be None.
        sync_function = log_sync_template()
        base_cls = CommandSyncer

    # Without a remote dir or a usable sync function, use the plain
    # BaseSyncer.
    if sync_function is None or not remote_dir:
        base_cls = BaseSyncer

    class MixedSyncer(NodeSyncMixin, base_cls):
        def __init__(self, *args, **kwargs):
            # The syncer base is initialised first: NodeSyncMixin.__init__
            # asserts that `_remote_dir` already exists.
            base_cls.__init__(self, *args, **kwargs)
            NodeSyncMixin.__init__(self)

    log_syncer = MixedSyncer(local_dir, remote_dir, sync_function)
    _syncers[cache_key] = log_syncer
    return log_syncer
+14 -12
View File
@@ -272,8 +272,7 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir):
cluster.wait_for_nodes()
dirpath = str(tmpdir)
runner = TrialRunner(
BasicVariantGenerator(), metadata_checkpoint_dir=dirpath)
runner = TrialRunner(BasicVariantGenerator(), local_checkpoint_dir=dirpath)
kwargs = {
"stopping_criterion": {
"training_iteration": 2
@@ -295,7 +294,7 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir):
ray.shutdown()
cluster = _start_new_cluster()
runner = TrialRunner.restore(dirpath)
runner = TrialRunner(resume="LOCAL", local_checkpoint_dir=dirpath)
runner.step() # start
runner.step() # start2
@@ -377,18 +376,19 @@ tune.run_experiments(
# Wait until the right checkpoint is saved.
# The trainable returns every 0.5 seconds, so this should not miss
# the checkpoint.
metadata_checkpoint_dir = os.path.join(dirpath, "experiment")
local_checkpoint_dir = os.path.join(dirpath, "experiment")
for i in range(100):
if TrialRunner.checkpoint_exists(metadata_checkpoint_dir):
if TrialRunner.checkpoint_exists(local_checkpoint_dir):
# Inspect the internal trialrunner
runner = TrialRunner.restore(metadata_checkpoint_dir)
runner = TrialRunner(
resume="LOCAL", local_checkpoint_dir=local_checkpoint_dir)
trials = runner.get_trials()
last_res = trials[0].last_result
if last_res and last_res.get("training_iteration"):
break
time.sleep(0.3)
if not TrialRunner.checkpoint_exists(metadata_checkpoint_dir):
if not TrialRunner.checkpoint_exists(local_checkpoint_dir):
raise RuntimeError("Checkpoint file didn't appear.")
ray.shutdown()
@@ -469,18 +469,19 @@ tune.run_experiments(
# Wait until the right checkpoint is saved.
# The trainable returns every 0.5 seconds, so this should not miss
# the checkpoint.
metadata_checkpoint_dir = os.path.join(dirpath, "experiment")
local_checkpoint_dir = os.path.join(dirpath, "experiment")
for i in range(50):
if TrialRunner.checkpoint_exists(metadata_checkpoint_dir):
if TrialRunner.checkpoint_exists(local_checkpoint_dir):
# Inspect the internal trialrunner
runner = TrialRunner.restore(metadata_checkpoint_dir)
runner = TrialRunner(
resume="LOCAL", local_checkpoint_dir=local_checkpoint_dir)
trials = runner.get_trials()
last_res = trials[0].last_result
if last_res and last_res.get("training_iteration") == 3:
break
time.sleep(0.2)
if not TrialRunner.checkpoint_exists(metadata_checkpoint_dir):
if not TrialRunner.checkpoint_exists(local_checkpoint_dir):
raise RuntimeError("Checkpoint file didn't appear.")
ray.shutdown()
@@ -489,7 +490,8 @@ tune.run_experiments(
Experiment._register_if_needed(_Mock)
# Inspect the internal trialrunner
runner = TrialRunner.restore(metadata_checkpoint_dir)
runner = TrialRunner(
resume="LOCAL", local_checkpoint_dir=local_checkpoint_dir)
trials = runner.get_trials()
assert trials[0].last_result["training_iteration"] == 3
assert trials[0].status == Trial.PENDING
@@ -114,8 +114,8 @@ class ExperimentAnalysisSuite(unittest.TestCase):
runner_data = self.ea.runner_data()
self.assertTrue(isinstance(runner_data, dict))
self.assertTrue("_metadata_checkpoint_dir" in runner_data)
self.assertEqual(runner_data["_metadata_checkpoint_dir"],
self.assertTrue("_local_checkpoint_dir" in runner_data)
self.assertEqual(runner_data["_local_checkpoint_dir"],
os.path.expanduser(self.test_path))
def testBestLogdir(self):
+178 -52
View File
@@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
import copy
import glob
import os
import shutil
import sys
@@ -315,21 +316,6 @@ class TrainableFunctionApiTest(unittest.TestCase):
}
})
def testUploadDirNone(self):
def train(config, reporter):
reporter(timesteps_total=1)
[trial] = run_experiments({
"foo": {
"run": train,
"upload_dir": None,
"config": {
"a": "b"
},
}
})
self.assertFalse(trial.upload_dir)
def testLogdirStartingWithTilde(self):
local_dir = "~/ray_results/local_dir"
@@ -930,50 +916,190 @@ class RunExperimentTest(unittest.TestCase):
str(trial), "{}_{}_321".format(trial.trainable_name,
trial.trial_id))
def testSyncFunction(self):
def fail_sync_local():
[trial] = run_experiments({
"foo": {
"run": "__fake",
class TestSyncFunctionality(unittest.TestCase):
def setUp(self):
ray.init()
def tearDown(self):
ray.shutdown()
_register_all() # re-register the evicted objects
@patch("ray.tune.syncer.S3_PREFIX", "test")
def testNoUploadDir(self):
"""No Upload Dir is given."""
with self.assertRaises(AssertionError):
[trial] = tune.run(
"__fake",
name="foo",
max_failures=0,
**{
"stop": {
"training_iteration": 1
},
"sync_to_cloud": "echo {source} {target}"
})
@patch("ray.tune.syncer.S3_PREFIX", "test")
def testCloudProperString(self):
with self.assertRaises(ValueError):
[trial] = tune.run(
"__fake",
name="foo",
max_failures=0,
**{
"stop": {
"training_iteration": 1
},
"upload_dir": "test",
"sync_function": "ls {remote_dir}"
}
})
"sync_to_cloud": "ls {target}"
})
self.assertRaises(AssertionError, fail_sync_local)
def fail_sync_remote():
[trial] = run_experiments({
"foo": {
"run": "__fake",
with self.assertRaises(ValueError):
[trial] = tune.run(
"__fake",
name="foo",
max_failures=0,
**{
"stop": {
"training_iteration": 1
},
"upload_dir": "test",
"sync_function": "ls {local_dir}"
}
})
"sync_to_cloud": "ls {source}"
})
self.assertRaises(AssertionError, fail_sync_remote)
tmpdir = tempfile.mkdtemp()
logfile = os.path.join(tmpdir, "test.log")
def sync_func(local, remote):
with open(os.path.join(local, "test.log"), "w") as f:
f.write(remote)
[trial] = run_experiments({
"foo": {
"run": "__fake",
[trial] = tune.run(
"__fake",
name="foo",
max_failures=0,
**{
"stop": {
"training_iteration": 1
},
"upload_dir": "test",
"sync_function": tune.function(sync_func)
}
})
self.assertTrue(os.path.exists(os.path.join(trial.logdir, "test.log")))
"sync_to_cloud": "echo {source} {target} > " + logfile
})
with open(logfile) as f:
lines = f.read()
self.assertTrue("test" in lines)
shutil.rmtree(tmpdir)
def testClusterProperString(self):
"""Tests that invalid commands throw.."""
with self.assertRaises(TuneError):
# This raises TuneError because logger is init in safe zone.
[trial] = tune.run(
"__fake",
name="foo",
max_failures=0,
**{
"stop": {
"training_iteration": 1
},
"sync_to_driver": "ls {target}"
})
with self.assertRaises(TuneError):
# This raises TuneError because logger is init in safe zone.
[trial] = tune.run(
"__fake",
name="foo",
max_failures=0,
**{
"stop": {
"training_iteration": 1
},
"sync_to_driver": "ls {source}"
})
with patch("ray.tune.syncer.CommandSyncer.sync_function"
) as mock_fn, patch(
"ray.services.get_node_ip_address") as mock_sync:
mock_sync.return_value = "0.0.0.0"
[trial] = tune.run(
"__fake",
name="foo",
max_failures=0,
**{
"stop": {
"training_iteration": 1
},
"sync_to_driver": "echo {source} {target}"
})
self.assertGreater(mock_fn.call_count, 0)
def testCloudFunctions(self):
    """Checks that a callable `sync_to_cloud` is used to upload results.

    Runs one single-iteration trial with a temporary `local_dir` and a
    second temporary directory standing in for `upload_dir`, then asserts
    that the custom sync function copied at least one JSON file into the
    pre-created "foo" subdirectory of the upload directory.
    """
    tmpdir = tempfile.mkdtemp()
    tmpdir2 = tempfile.mkdtemp()
    # Cleanup lives in `finally` so the scratch directories are removed
    # even when the run or the assertion fails.
    try:
        os.mkdir(os.path.join(tmpdir2, "foo"))

        def sync_func(local, remote):
            # Copy every JSON file found directly under `local`.
            for filename in glob.glob(os.path.join(local, "*.json")):
                shutil.copy(filename, remote)

        [trial] = tune.run(
            "__fake",
            name="foo",
            max_failures=0,
            local_dir=tmpdir,
            stop={"training_iteration": 1},
            upload_dir=tmpdir2,
            sync_to_cloud=tune.function(sync_func))
        test_file_path = glob.glob(os.path.join(tmpdir2, "foo", "*.json"))
        self.assertTrue(test_file_path)
    finally:
        shutil.rmtree(tmpdir)
        shutil.rmtree(tmpdir2)
def testClusterSyncFunction(self):
    """Tests a callable `sync_to_driver`.

    The first run asserts the marker file is absent from the trial
    logdir; the second run, with `ray.services.get_node_ip_address`
    patched to return "0.0.0.0", asserts the marker file was written.
    """
    marker_name = "test.log2"

    def sync_func_driver(source, target):
        assert ":" in source, "Source not a remote path."
        assert ":" not in target, "Target is supposed to be local."
        with open(os.path.join(target, marker_name), "w") as f:
            print("writing to", f.name)
            f.write(source)

    def run_single_trial():
        # Unpacking requires that exactly one trial is returned.
        [trial] = tune.run(
            "__fake",
            name="foo",
            max_failures=0,
            stop={"training_iteration": 1},
            sync_to_driver=tune.function(sync_func_driver))
        return trial

    trial = run_single_trial()
    test_file_path = os.path.join(trial.logdir, marker_name)
    self.assertFalse(os.path.exists(test_file_path))

    with patch("ray.services.get_node_ip_address") as mock_sync:
        mock_sync.return_value = "0.0.0.0"
        trial = run_single_trial()
    test_file_path = os.path.join(trial.logdir, marker_name)
    self.assertTrue(os.path.exists(test_file_path))
    os.remove(test_file_path)
def testNoSync(self):
    """Tests that `CommandSyncer.sync_function` is never called when both
    `sync_to_driver` and `sync_to_cloud` are user-supplied callables."""

    def sync_func(source, target):
        # Intentionally a no-op.
        pass

    command_sync = patch("ray.tune.syncer.CommandSyncer.sync_function")
    with command_sync as mock_sync:
        [trial] = tune.run(
            "__fake",
            name="foo",
            max_failures=0,
            stop={"training_iteration": 1},
            upload_dir="test",
            sync_to_driver=tune.function(sync_func),
            sync_to_cloud=tune.function(sync_func))
        self.assertEqual(mock_sync.call_count, 0)
class VariantGeneratorTest(unittest.TestCase):
@@ -1960,7 +2086,7 @@ class TrialRunnerTest(unittest.TestCase):
ray.init(num_cpus=3)
tmpdir = tempfile.mkdtemp()
runner = TrialRunner(metadata_checkpoint_dir=tmpdir)
runner = TrialRunner(local_checkpoint_dir=tmpdir)
trials = [
Trial(
"__fake",
@@ -1999,7 +2125,7 @@ class TrialRunnerTest(unittest.TestCase):
self.assertEquals(len(runner.trial_executor.get_checkpoints()), 3)
self.assertEquals(trials[2].status, Trial.RUNNING)
runner2 = TrialRunner.restore(tmpdir)
runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=tmpdir)
for tid in ["trial_terminate", "trial_fail"]:
original_trial = runner.get_trial(tid)
restored_trial = runner2.get_trial(tid)
@@ -2019,7 +2145,7 @@ class TrialRunnerTest(unittest.TestCase):
ray.init(num_cpus=3)
tmpdir = tempfile.mkdtemp()
runner = TrialRunner(metadata_checkpoint_dir=tmpdir)
runner = TrialRunner(local_checkpoint_dir=tmpdir)
runner.add_trial(
Trial(
@@ -2051,7 +2177,7 @@ class TrialRunnerTest(unittest.TestCase):
runner.step()
runner.step()
runner2 = TrialRunner.restore(tmpdir)
runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=tmpdir)
new_trials = runner2.get_trials()
self.assertEquals(len(new_trials), 3)
self.assertTrue(
@@ -2074,13 +2200,13 @@ class TrialRunnerTest(unittest.TestCase):
},
checkpoint_freq=1)
tmpdir = tempfile.mkdtemp()
runner = TrialRunner(metadata_checkpoint_dir=tmpdir)
runner = TrialRunner(local_checkpoint_dir=tmpdir)
runner.add_trial(trial)
for i in range(5):
runner.step()
# force checkpoint
runner.checkpoint()
runner2 = TrialRunner.restore(tmpdir)
runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=tmpdir)
new_trial = runner2.get_trials()[0]
self.assertTrue("callbacks" in new_trial.config)
self.assertTrue("on_episode_start" in new_trial.config["callbacks"])
@@ -2095,7 +2221,7 @@ class TrialRunnerTest(unittest.TestCase):
ray.init()
trial = Trial("__fake", checkpoint_freq=1)
tmpdir = tempfile.mkdtemp()
runner = TrialRunner(metadata_checkpoint_dir=tmpdir)
runner = TrialRunner(local_checkpoint_dir=tmpdir)
runner.add_trial(trial)
for i in range(5):
runner.step()
@@ -2103,7 +2229,7 @@ class TrialRunnerTest(unittest.TestCase):
runner.checkpoint()
self.assertEquals(count_checkpoints(tmpdir), 1)
runner2 = TrialRunner.restore(tmpdir)
runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=tmpdir)
for i in range(5):
runner2.step()
self.assertEquals(count_checkpoints(tmpdir), 2)
+4 -9
View File
@@ -19,7 +19,6 @@ from six import string_types
import ray
from ray.tune import TuneError
from ray.tune.log_sync import validate_sync_function
from ray.tune.logger import pretty_print, UnifiedLogger
# NOTE(rkn): We import ray.tune.registry here instead of importing the names we
# need because there are cyclic imports that may cause specific names to not
@@ -276,10 +275,9 @@ class Trial(object):
checkpoint_score_attr="",
export_formats=None,
restore_path=None,
upload_dir=None,
trial_name_creator=None,
loggers=None,
sync_function=None,
sync_to_driver_fn=None,
max_failures=0):
"""Initialize a new trial.
@@ -308,10 +306,8 @@ class Trial(object):
resources = default_resources
self.resources = resources or Resources(cpu=1, gpu=0)
self.stopping_criterion = stopping_criterion or {}
self.upload_dir = upload_dir
self.loggers = loggers
self.sync_function = sync_function
validate_sync_function(sync_function)
self.sync_to_driver_fn = sync_to_driver_fn
self.verbose = True
self.max_failures = max_failures
@@ -352,7 +348,7 @@ class Trial(object):
self._nonjson_fields = [
"_checkpoint",
"loggers",
"sync_function",
"sync_to_driver_fn",
"results",
"best_result",
"param_config",
@@ -394,9 +390,8 @@ class Trial(object):
self.result_logger = UnifiedLogger(
self.config,
self.logdir,
upload_uri=self.upload_dir,
loggers=self.loggers,
sync_function=self.sync_function)
sync_function=self.sync_to_driver_fn)
def update_resources(self, cpu, gpu, **kwargs):
"""EXPERIMENTAL: Updates the resource requirements.
+96 -42
View File
@@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import click
import collections
from datetime import datetime
import json
@@ -15,6 +16,7 @@ import ray.cloudpickle as cloudpickle
from ray.tune import TuneError
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE
from ray.tune.syncer import get_syncer
from ray.tune.trial import Trial, Checkpoint
from ray.tune.sample import function
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
@@ -97,12 +99,16 @@ class TrialRunner(object):
"""
CKPT_FILE_TMPL = "experiment_state-{}.json"
VALID_RESUME_TYPES = [True, "LOCAL", "REMOTE", "PROMPT"]
def __init__(self,
search_alg=None,
scheduler=None,
launch_web_server=False,
metadata_checkpoint_dir=None,
local_checkpoint_dir=None,
remote_checkpoint_dir=None,
sync_to_cloud=None,
resume=False,
server_port=TuneServer.DEFAULT_PORT,
verbose=True,
trial_executor=None):
@@ -113,15 +119,16 @@ class TrialRunner(object):
Trial objects.
scheduler (TrialScheduler): Defaults to FIFOScheduler.
launch_web_server (bool): Flag for starting TuneServer
metadata_checkpoint_dir (str): Path where
local_checkpoint_dir (str): Path where
global checkpoints are stored and restored from.
server_port (int): Port number for launching TuneServer
remote_checkpoint_dir (str): Remote path where
global checkpoints are stored and restored from. Used
if `resume` == REMOTE.
resume (str|False): see `tune.py:run`.
sync_to_cloud (func|str): see `tune.py:run`.
server_port (int): Port number for launching TuneServer.
verbose (bool): Flag for verbosity. If False, trial results
will not be output.
reuse_actors (bool): Whether to reuse actors between different
trials when possible. This can drastically speed up experiments
that start and stop actors often (e.g., PBT in
time-multiplexing mode).
trial_executor (TrialExecutor): Defaults to RayTrialExecutor.
"""
self._search_alg = search_alg or BasicVariantGenerator()
@@ -143,12 +150,73 @@ class TrialRunner(object):
self._trials = []
self._stop_queue = []
self._metadata_checkpoint_dir = metadata_checkpoint_dir
self._local_checkpoint_dir = local_checkpoint_dir
if self._local_checkpoint_dir and not os.path.exists(
self._local_checkpoint_dir):
os.makedirs(self._local_checkpoint_dir)
self._remote_checkpoint_dir = remote_checkpoint_dir
self._syncer = get_syncer(local_checkpoint_dir, remote_checkpoint_dir,
sync_to_cloud)
self._resumed = False
if self._validate_resume(resume_type=resume):
try:
self.resume()
logger.info("Resuming trial.")
self._resumed = True
except Exception:
logger.exception(
"Runner restore failed. Restarting experiment.")
else:
logger.info("Starting a new experiment.")
self._start_time = time.time()
self._session_str = datetime.fromtimestamp(
self._start_time).strftime("%Y-%m-%d_%H-%M-%S")
def _validate_resume(self, resume_type):
    """Decides whether the experiment should be resumed.

    Args:
        resume_type: One of True, "REMOTE", "LOCAL", "PROMPT"; any falsy
            value means resuming was not requested.

    Returns:
        bool: True if the runner should resume, False if resuming was not
            requested or the remote-download prompt was declined.

    Raises:
        ValueError: If no checkpoint exists where one is required, or a
            remote resume is requested without a remote directory.
    """
    if not resume_type:
        return False

    assert resume_type in self.VALID_RESUME_TYPES, (
        "resume_type {} is not one of {}".format(resume_type,
                                                 self.VALID_RESUME_TYPES))
    # Not clear if we need this assertion, since we should always have a
    # local checkpoint dir.
    assert self._local_checkpoint_dir or self._remote_checkpoint_dir

    is_prompt = resume_type == "PROMPT"

    if resume_type in (True, "LOCAL", "PROMPT"):
        # A local checkpoint is mandatory for these modes, PROMPT included.
        if not self.checkpoint_exists(self._local_checkpoint_dir):
            raise ValueError("Called resume when no checkpoint exists "
                             "in local directory.")
        if is_prompt and click.confirm("Resume from local directory?"):
            return True

    if resume_type in ("REMOTE", "PROMPT"):
        if is_prompt and not click.confirm(
                "Try downloading from remote directory?"):
            return False
        if not self._remote_checkpoint_dir:
            raise ValueError(
                "Called resume from remote without remote directory.")

        # Try syncing down the upload directory.
        logger.info("Downloading from {}".format(
            self._remote_checkpoint_dir))
        self._syncer.sync_down_if_needed()

        if not self.checkpoint_exists(self._local_checkpoint_dir):
            raise ValueError("Called resume when no checkpoint exists "
                             "in remote or local directory.")

    return True
@classmethod
def checkpoint_exists(cls, directory):
if not os.path.exists(directory):
@@ -157,17 +225,21 @@ class TrialRunner(object):
(fname.startswith("experiment_state") and fname.endswith(".json"))
for fname in os.listdir(directory))
def add_experiment(self, experiment):
    """Passes `experiment` to the search algorithm unless the runner
    was resumed, in which case the call is logged and ignored.

    Args:
        experiment: Experiment specification handed to
            `self._search_alg.add_configurations` inside a one-item list.
    """
    if self._resumed:
        logger.info("TrialRunner resumed, ignoring new add_experiment.")
        return
    self._search_alg.add_configurations([experiment])
def checkpoint(self):
"""Saves execution state to `self._metadata_checkpoint_dir`.
"""Saves execution state to `self._local_checkpoint_dir`.
Overwrites the current session checkpoint, which starts when self
is instantiated.
"""
if not self._metadata_checkpoint_dir:
if not self._local_checkpoint_dir:
return
metadata_checkpoint_dir = self._metadata_checkpoint_dir
if not os.path.exists(metadata_checkpoint_dir):
os.makedirs(metadata_checkpoint_dir)
runner_state = {
"checkpoints": list(
self.trial_executor.get_checkpoints().values()),
@@ -177,55 +249,37 @@ class TrialRunner(object):
"timestamp": time.time()
}
}
tmp_file_name = os.path.join(metadata_checkpoint_dir,
tmp_file_name = os.path.join(self._local_checkpoint_dir,
".tmp_checkpoint")
with open(tmp_file_name, "w") as f:
json.dump(runner_state, f, indent=2, cls=_TuneFunctionEncoder)
os.rename(
tmp_file_name,
os.path.join(metadata_checkpoint_dir,
os.path.join(self._local_checkpoint_dir,
TrialRunner.CKPT_FILE_TMPL.format(self._session_str)))
return metadata_checkpoint_dir
self._syncer.sync_up_if_needed()
return self._local_checkpoint_dir
@classmethod
def restore(cls,
metadata_checkpoint_dir,
search_alg=None,
scheduler=None,
trial_executor=None):
"""Restores all checkpointed trials from previous run.
def resume(self):
"""Resumes all checkpointed trials from previous run.
Requires user to manually re-register their objects. Also stops
all ongoing trials.
Args:
metadata_checkpoint_dir (str): Path to metadata checkpoints.
search_alg (SearchAlgorithm): Search Algorithm. Defaults to
BasicVariantGenerator.
scheduler (TrialScheduler): Scheduler for executing
the experiment.
trial_executor (TrialExecutor): Manage the execution of trials.
Returns:
runner (TrialRunner): A TrialRunner to resume experiments from.
"""
newest_ckpt_path = _find_newest_ckpt(metadata_checkpoint_dir)
newest_ckpt_path = _find_newest_ckpt(self._local_checkpoint_dir)
with open(newest_ckpt_path, "r") as f:
runner_state = json.load(f, cls=_TuneFunctionDecoder)
logger.warning("".join([
"Attempting to resume experiment from {}. ".format(
metadata_checkpoint_dir), "This feature is experimental, "
self._local_checkpoint_dir), "This feature is experimental, "
"and may not work with all search algorithms. ",
"This will ignore any new changes to the specification."
]))
runner = TrialRunner(
search_alg, scheduler=scheduler, trial_executor=trial_executor)
runner.__setstate__(runner_state["runner_data"])
self.__setstate__(runner_state["runner_data"])
trials = []
for trial_cp in runner_state["checkpoints"]:
@@ -234,8 +288,7 @@ class TrialRunner(object):
trials += [new_trial]
for trial in sorted(
trials, key=lambda t: t.last_update_time, reverse=True):
runner.add_trial(trial)
return runner
self.add_trial(trial)
def is_finished(self):
"""Returns whether all trials have finished running."""
@@ -626,6 +679,7 @@ class TrialRunner(object):
"_search_alg",
"_scheduler_alg",
"trial_executor",
"_syncer",
]:
del state[k]
state["launch_web_server"] = bool(self._server)
+43 -70
View File
@@ -2,7 +2,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import click
import logging
import time
@@ -12,7 +11,7 @@ from ray.tune.analysis import ExperimentAnalysis
from ray.tune.suggest import BasicVariantGenerator
from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.log_sync import wait_for_log_sync
from ray.tune.syncer import wait_for_sync
from ray.tune.trial_runner import TrialRunner
from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler,
FIFOScheduler, MedianStoppingRule)
@@ -36,36 +35,6 @@ def _make_scheduler(args):
args.scheduler, _SCHEDULERS.keys()))
def _find_checkpoint_dir(exp):
    """Returns the `checkpoint_dir` attribute of the given experiment."""
    # TODO(rliaw): Make sure the checkpoint_dir is resolved earlier.
    # Right now it is resolved somewhere far down the trial generation process
    checkpoint_dir = exp.checkpoint_dir
    return checkpoint_dir
def _prompt_restore(checkpoint_dir, resume):
    """Decides whether to restore an experiment from `checkpoint_dir`.

    Args:
        checkpoint_dir: Directory checked via
            `TrialRunner.checkpoint_exists`.
        resume: "prompt" to ask interactively; otherwise its truthiness
            decides.

    Returns:
        The confirmation result when prompting, True when `resume` is
        truthy and a checkpoint exists, False otherwise.
    """
    if not TrialRunner.checkpoint_exists(checkpoint_dir):
        logger.info(
            "Did not find checkpoint file in {}.".format(checkpoint_dir))
        return False

    if resume == "prompt":
        question = ("Found incomplete experiment at {}. "
                    "Would you like to resume it?".format(checkpoint_dir))
        confirmed = click.confirm(question, default=False)
        if confirmed:
            logger.info("Tip: to always resume, "
                        "pass resume=True to run()")
        else:
            logger.info("Tip: to always start a new experiment, "
                        "pass resume=False to run()")
        return confirmed

    if resume:
        return True

    logger.info("Tip: to resume incomplete experiments, "
                "pass resume='prompt' or resume=True to run()")
    return False
def run(run_or_experiment,
name=None,
stop=None,
@@ -76,7 +45,8 @@ def run(run_or_experiment,
upload_dir=None,
trial_name_creator=None,
loggers=None,
sync_function=None,
sync_to_cloud=None,
sync_to_driver=None,
checkpoint_freq=0,
checkpoint_at_end=False,
export_formats=None,
@@ -93,7 +63,8 @@ def run(run_or_experiment,
trial_executor=None,
raise_on_failed_trial=True,
return_trials=True,
ray_auto_init=True):
ray_auto_init=True,
sync_function=None):
"""Executes training.
Args:
@@ -129,10 +100,15 @@ def run(run_or_experiment,
loggers (list): List of logger creators to be used with
each Trial. If None, defaults to ray.tune.logger.DEFAULT_LOGGERS.
See `ray/tune/logger.py`.
sync_function (func|str): Function for syncing the local_dir to
upload_dir. If string, then it must be a string template for
syncer to run. If not provided, the sync command defaults
to standard S3 or gsutil sync commands.
sync_to_cloud (func|str): Function for syncing the local_dir to and
from upload_dir. If string, then it must be a string template
that includes `{source}` and `{target}` for the syncer to run.
If not provided, the sync command defaults to standard
S3 or gsutil sync commands.
sync_to_driver (func|str): Function for syncing trial logdir from
remote node to local. If string, then it must be a string template
that includes `{source}` and `{target}` for the syncer to run.
If not provided, defaults to using rsync.
checkpoint_freq (int): How many training iterations between
checkpoints. A value of 0 (default) disables checkpointing.
checkpoint_at_end (bool): Whether to checkpoint at the end of the
@@ -155,9 +131,12 @@ def run(run_or_experiment,
server_port (int): Port number for launching TuneServer.
verbose (int): 0, 1, or 2. Verbosity mode. 0 = silent,
1 = only status updates, 2 = status and trial results.
resume (bool|"prompt"): If checkpoint exists, the experiment will
resume from there. If resume is "prompt", Tune will prompt if
checkpoint detected.
resume (str|bool): One of "LOCAL", "REMOTE", "PROMPT", or bool.
LOCAL/True restores the checkpoint from the local_checkpoint_dir.
REMOTE restores the checkpoint from remote_checkpoint_dir.
PROMPT provides CLI feedback. False forces a new
experiment. If resume is set but checkpoint does not exist,
ValueError will be thrown.
queue_trials (bool): Whether to queue trials when the cluster does
not currently have enough resources to launch one. This should
be set to True when running on an autoscaling cluster to enable
@@ -172,6 +151,8 @@ def run(run_or_experiment,
ray_auto_init (bool): Automatically starts a local Ray cluster
if using a RayTrialExecutor (which is the default) and
if Ray is not initialized. Defaults to True.
sync_function: Deprecated. See `sync_to_cloud` and
`sync_to_driver`.
Returns:
List of Trial objects.
@@ -199,53 +180,45 @@ def run(run_or_experiment,
ray_auto_init=ray_auto_init)
experiment = run_or_experiment
if not isinstance(run_or_experiment, Experiment):
run_identifier = Experiment._register_if_needed(run_or_experiment)
experiment = Experiment(
name=name,
run=run_or_experiment,
run=run_identifier,
stop=stop,
config=config,
resources_per_trial=resources_per_trial,
num_samples=num_samples,
local_dir=local_dir,
upload_dir=upload_dir,
sync_to_driver=sync_to_driver,
trial_name_creator=trial_name_creator,
loggers=loggers,
sync_function=sync_function,
checkpoint_freq=checkpoint_freq,
checkpoint_at_end=checkpoint_at_end,
export_formats=export_formats,
max_failures=max_failures,
restore=restore)
restore=restore,
sync_function=sync_function)
else:
logger.debug("Ignoring some parameters passed into tune.run.")
checkpoint_dir = _find_checkpoint_dir(experiment)
should_restore = _prompt_restore(checkpoint_dir, resume)
if sync_to_cloud:
assert experiment.remote_checkpoint_dir, (
"Need `upload_dir` if `sync_to_cloud` given.")
runner = None
if should_restore:
try:
runner = TrialRunner.restore(checkpoint_dir, search_alg, scheduler,
trial_executor)
except Exception:
logger.exception("Runner restore failed. Restarting experiment.")
else:
logger.info("Starting a new experiment.")
runner = TrialRunner(
search_alg=search_alg or BasicVariantGenerator(),
scheduler=scheduler or FIFOScheduler(),
local_checkpoint_dir=experiment.checkpoint_dir,
remote_checkpoint_dir=experiment.remote_checkpoint_dir,
sync_to_cloud=sync_to_cloud,
resume=resume,
launch_web_server=with_server,
server_port=server_port,
verbose=bool(verbose > 1),
trial_executor=trial_executor)
if not runner:
scheduler = scheduler or FIFOScheduler()
search_alg = search_alg or BasicVariantGenerator()
search_alg.add_configurations([experiment])
runner = TrialRunner(
search_alg=search_alg,
scheduler=scheduler,
metadata_checkpoint_dir=checkpoint_dir,
launch_web_server=with_server,
server_port=server_port,
verbose=bool(verbose > 1),
trial_executor=trial_executor)
runner.add_experiment(experiment)
if verbose:
print(runner.debug_string(max_debug=99999))
@@ -261,7 +234,7 @@ def run(run_or_experiment,
if verbose:
print(runner.debug_string(max_debug=99999))
wait_for_log_sync()
wait_for_sync()
errored_trials = []
for trial in runner.get_trials():