mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 01:43:50 +08:00
[tune] Change the log syncing behavior (#4450)
* Change the log syncing behavior * fix up abstractions for syncer * Finished checkpoint syncing * Code * Set of changes to get things running * Fixes for log syncing * Fix parts * Lint and other fixes * fix some test * Remove extra parsing functionality * some test fixes * Fix up cloud syncing * Another thing to do * Fix up tests and local sync Changes LogSync into a mixin, and adds tests for different functionalities. * Fix up tests, start on local migration * fix distributed migrations * comments * formatting * Better checkpoint directory handling * fix tests * fix tests * fix click * comments * formatting comments * formatting and comments * sync function deprecations * syncfunction * Add documentation for Syncing and Uploading * nit * BaseSyncer as base for Mixin in edge case * more docs * clean up assertions * validate * nit * Update test_cluster.py * betterdoc * Update tune-usage.rst * cleanup * nit
This commit is contained in:
committed by
Richard Liaw
parent
71d4637b75
commit
9e0192bc0b
@@ -10,6 +10,7 @@ import yaml
|
||||
import ray
|
||||
from ray.tests.cluster_utils import Cluster
|
||||
from ray.tune.config_parser import make_parser
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
from ray.tune.trial import resources_to_json
|
||||
from ray.tune.tune import _make_scheduler, run_experiments
|
||||
|
||||
@@ -71,6 +72,17 @@ def create_parser(parser_creator=None):
|
||||
default="default",
|
||||
type=str,
|
||||
help="Name of the subdirectory under `local_dir` to put results in.")
|
||||
parser.add_argument(
|
||||
"--local-dir",
|
||||
default=DEFAULT_RESULTS_DIR,
|
||||
type=str,
|
||||
help="Local dir to save training results to. Defaults to '{}'.".format(
|
||||
DEFAULT_RESULTS_DIR))
|
||||
parser.add_argument(
|
||||
"--upload-dir",
|
||||
default="",
|
||||
type=str,
|
||||
help="Optional URI to sync training results to (e.g. s3://bucket).")
|
||||
parser.add_argument(
|
||||
"--resume",
|
||||
action="store_true",
|
||||
|
||||
@@ -10,7 +10,6 @@ import os
|
||||
from six import string_types
|
||||
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
from ray.tune.trial import Trial, json_to_resources
|
||||
from ray.tune.logger import _SafeFallbackEncoder
|
||||
|
||||
@@ -65,17 +64,6 @@ def make_parser(parser_creator=None, **kwargs):
|
||||
default=1,
|
||||
type=int,
|
||||
help="Number of times to repeat each trial.")
|
||||
parser.add_argument(
|
||||
"--local-dir",
|
||||
default=DEFAULT_RESULTS_DIR,
|
||||
type=str,
|
||||
help="Local dir to save training results to. Defaults to '{}'.".format(
|
||||
DEFAULT_RESULTS_DIR))
|
||||
parser.add_argument(
|
||||
"--upload-dir",
|
||||
default="",
|
||||
type=str,
|
||||
help="Optional URI to sync training results to (e.g. s3://bucket).")
|
||||
parser.add_argument(
|
||||
"--checkpoint-freq",
|
||||
default=0,
|
||||
@@ -183,7 +171,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
|
||||
trainable_name=spec["run"],
|
||||
# json.load leads to str -> unicode in py2.7
|
||||
config=spec.get("config", {}),
|
||||
local_dir=os.path.join(args.local_dir, output_path),
|
||||
local_dir=os.path.join(spec["local_dir"], output_path),
|
||||
# json.load leads to str -> unicode in py2.7
|
||||
stopping_criterion=spec.get("stop", {}),
|
||||
checkpoint_freq=args.checkpoint_freq,
|
||||
@@ -193,10 +181,9 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
|
||||
export_formats=spec.get("export_formats", []),
|
||||
# str(None) doesn't create None
|
||||
restore_path=spec.get("restore"),
|
||||
upload_dir=args.upload_dir,
|
||||
trial_name_creator=spec.get("trial_name_creator"),
|
||||
loggers=spec.get("loggers"),
|
||||
# str(None) doesn't create None
|
||||
sync_function=spec.get("sync_function"),
|
||||
sync_to_driver_fn=spec.get("sync_to_driver"),
|
||||
max_failures=args.max_failures,
|
||||
**trial_kwargs)
|
||||
|
||||
@@ -11,9 +11,8 @@ import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune import Trainable, run, Experiment
|
||||
from ray.tune import Trainable, run
|
||||
|
||||
|
||||
class TestLogger(tune.logger.Logger):
|
||||
@@ -60,11 +59,11 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
ray.init()
|
||||
exp = Experiment(
|
||||
|
||||
trials = run(
|
||||
MyTrainableClass,
|
||||
name="hyperband_test",
|
||||
run=MyTrainableClass,
|
||||
num_samples=1,
|
||||
num_samples=5,
|
||||
trial_name_creator=tune.function(trial_str_creator),
|
||||
loggers=[TestLogger],
|
||||
stop={"training_iteration": 1 if args.smoke_test else 99999},
|
||||
@@ -73,5 +72,3 @@ if __name__ == "__main__":
|
||||
lambda spec: 10 + int(90 * random.random())),
|
||||
"height": tune.sample_from(lambda spec: int(100 * random.random()))
|
||||
})
|
||||
|
||||
trials = run(exp)
|
||||
|
||||
@@ -52,7 +52,6 @@ class Experiment(object):
|
||||
>>> },
|
||||
>>> num_samples=10,
|
||||
>>> local_dir="~/ray_results",
|
||||
>>> upload_dir="s3://your_bucket/path",
|
||||
>>> checkpoint_freq=10,
|
||||
>>> max_failures=2)
|
||||
"""
|
||||
@@ -68,7 +67,7 @@ class Experiment(object):
|
||||
upload_dir=None,
|
||||
trial_name_creator=None,
|
||||
loggers=None,
|
||||
sync_function=None,
|
||||
sync_to_driver=None,
|
||||
checkpoint_freq=0,
|
||||
checkpoint_at_end=False,
|
||||
keep_checkpoints_num=None,
|
||||
@@ -78,18 +77,16 @@ class Experiment(object):
|
||||
restore=None,
|
||||
repeat=None,
|
||||
trial_resources=None,
|
||||
custom_loggers=None):
|
||||
if sync_function:
|
||||
assert upload_dir, "Need `upload_dir` if sync_function given."
|
||||
|
||||
custom_loggers=None,
|
||||
sync_function=None):
|
||||
if repeat:
|
||||
_raise_deprecation_note("repeat", "num_samples", soft=False)
|
||||
if trial_resources:
|
||||
_raise_deprecation_note(
|
||||
"trial_resources", "resources_per_trial", soft=False)
|
||||
if custom_loggers:
|
||||
_raise_deprecation_note("custom_loggers", "loggers", soft=False)
|
||||
|
||||
if sync_function:
|
||||
_raise_deprecation_note(
|
||||
"sync_function", "sync_to_driver", soft=False)
|
||||
run_identifier = Experiment._register_if_needed(run)
|
||||
spec = {
|
||||
"run": run_identifier,
|
||||
@@ -98,10 +95,10 @@ class Experiment(object):
|
||||
"resources_per_trial": resources_per_trial,
|
||||
"num_samples": num_samples,
|
||||
"local_dir": os.path.expanduser(local_dir or DEFAULT_RESULTS_DIR),
|
||||
"upload_dir": upload_dir or "", # argparse converts None to "null"
|
||||
"upload_dir": upload_dir,
|
||||
"trial_name_creator": trial_name_creator,
|
||||
"loggers": loggers,
|
||||
"sync_function": sync_function,
|
||||
"sync_to_driver": sync_to_driver,
|
||||
"checkpoint_freq": checkpoint_freq,
|
||||
"checkpoint_at_end": checkpoint_at_end,
|
||||
"keep_checkpoints_num": keep_checkpoints_num,
|
||||
@@ -182,7 +179,13 @@ class Experiment(object):
|
||||
|
||||
@property
|
||||
def checkpoint_dir(self):
|
||||
return os.path.join(self.spec["local_dir"], self.name)
|
||||
if self.local_dir:
|
||||
return os.path.join(self.local_dir, self.name)
|
||||
|
||||
@property
|
||||
def remote_checkpoint_dir(self):
|
||||
if self.spec["upload_dir"]:
|
||||
return os.path.join(self.spec["upload_dir"], self.name)
|
||||
|
||||
|
||||
def convert_to_experiment_list(experiments):
|
||||
|
||||
+41
-208
@@ -4,11 +4,6 @@ from __future__ import print_function
|
||||
|
||||
import distutils.spawn
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
import types
|
||||
|
||||
try: # py3
|
||||
from shlex import quote
|
||||
@@ -17,231 +12,69 @@ except ImportError: # py2
|
||||
|
||||
import ray
|
||||
from ray.tune.cluster_info import get_ssh_key, get_ssh_user
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
from ray.tune.sample import function as tune_function
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_log_sync_warned = False
|
||||
|
||||
# Map from (logdir, remote_dir) -> syncer
|
||||
_syncers = {}
|
||||
|
||||
S3_PREFIX = "s3://"
|
||||
GCS_PREFIX = "gs://"
|
||||
ALLOWED_REMOTE_PREFIXES = (S3_PREFIX, GCS_PREFIX)
|
||||
def log_sync_template():
|
||||
"""Syncs the local_dir between driver and worker if possible.
|
||||
|
||||
Requires ray cluster to be started with the autoscaler. Also requires
|
||||
rsync to be installed.
|
||||
|
||||
def get_syncer(local_dir, remote_dir=None, sync_function=None):
|
||||
if remote_dir:
|
||||
if not sync_function and not any(
|
||||
remote_dir.startswith(prefix)
|
||||
for prefix in ALLOWED_REMOTE_PREFIXES):
|
||||
raise TuneError("Upload uri must start with one of: {}"
|
||||
"".format(ALLOWED_REMOTE_PREFIXES))
|
||||
|
||||
if (remote_dir.startswith(S3_PREFIX)
|
||||
and not distutils.spawn.find_executable("aws")):
|
||||
raise TuneError(
|
||||
"Upload uri starting with '{}' requires awscli tool"
|
||||
" to be installed".format(S3_PREFIX))
|
||||
elif (remote_dir.startswith(GCS_PREFIX)
|
||||
and not distutils.spawn.find_executable("gsutil")):
|
||||
raise TuneError(
|
||||
"Upload uri starting with '{}' requires gsutil tool"
|
||||
" to be installed".format(GCS_PREFIX))
|
||||
|
||||
if local_dir.startswith(DEFAULT_RESULTS_DIR + "/"):
|
||||
rel_path = os.path.relpath(local_dir, DEFAULT_RESULTS_DIR)
|
||||
remote_dir = os.path.join(remote_dir, rel_path)
|
||||
|
||||
key = (local_dir, remote_dir)
|
||||
if key not in _syncers:
|
||||
_syncers[key] = _LogSyncer(local_dir, remote_dir, sync_function)
|
||||
|
||||
return _syncers[key]
|
||||
|
||||
|
||||
def wait_for_log_sync():
|
||||
for syncer in _syncers.values():
|
||||
syncer.wait()
|
||||
|
||||
|
||||
def validate_sync_function(sync_function):
|
||||
if sync_function is None:
|
||||
return
|
||||
elif isinstance(sync_function, str):
|
||||
assert "{remote_dir}" in sync_function, (
|
||||
"Sync template missing '{remote_dir}'.")
|
||||
assert "{local_dir}" in sync_function, (
|
||||
"Sync template missing '{local_dir}'.")
|
||||
elif not (isinstance(sync_function, types.FunctionType)
|
||||
or isinstance(sync_function, tune_function)):
|
||||
raise ValueError("Sync function {} must be string or function".format(
|
||||
sync_function))
|
||||
|
||||
|
||||
class _LogSyncer(object):
|
||||
"""Log syncer for tune.
|
||||
|
||||
This syncs files from workers to the local node, and optionally also from
|
||||
the local node to a remote directory (e.g. S3).
|
||||
|
||||
Arguments:
|
||||
logdir (str): Directory to sync from.
|
||||
upload_uri (str): Directory to sync to.
|
||||
sync_function (func|str): Function for syncing the local_dir to
|
||||
upload_dir. If string, then it must be a string template
|
||||
for syncer to run and needs to include replacement fields
|
||||
'{local_dir}' and '{remote_dir}'.
|
||||
"""
|
||||
if not distutils.spawn.find_executable("rsync"):
|
||||
logger.error("Log sync requires rsync to be installed.")
|
||||
return
|
||||
global _log_sync_warned
|
||||
ssh_key = get_ssh_key()
|
||||
if ssh_key is None:
|
||||
if not _log_sync_warned:
|
||||
logger.error("Log sync requires cluster to be setup with "
|
||||
"`ray up`.")
|
||||
_log_sync_warned = True
|
||||
return
|
||||
|
||||
def __init__(self, local_dir, remote_dir=None, sync_function=None):
|
||||
self.local_dir = local_dir
|
||||
self.remote_dir = remote_dir
|
||||
self.logfile = tempfile.NamedTemporaryFile(
|
||||
prefix="log_sync", dir=self.local_dir, suffix=".log", delete=False)
|
||||
return ("""rsync -savz -e "ssh -i {ssh_key} -o ConnectTimeout=120s """
|
||||
"""-o StrictHostKeyChecking=no" {{source}} {{target}}"""
|
||||
).format(ssh_key=quote(ssh_key))
|
||||
|
||||
# Resolve sync_function into template or function
|
||||
self.sync_func = None
|
||||
self.sync_cmd_tmpl = None
|
||||
if isinstance(sync_function, types.FunctionType) or isinstance(
|
||||
sync_function, tune_function):
|
||||
self.sync_func = sync_function
|
||||
elif isinstance(sync_function, str):
|
||||
self.sync_cmd_tmpl = sync_function
|
||||
self.last_sync_time = 0
|
||||
self.sync_process = None
|
||||
|
||||
class NodeSyncMixin():
|
||||
"""Mixin for syncing files to/from a remote dir to a local dir."""
|
||||
|
||||
def __init__(self):
|
||||
assert hasattr(self, "_remote_dir"), "Mixin not mixed with Syncer."
|
||||
self.local_ip = ray.services.get_node_ip_address()
|
||||
self.worker_ip = None
|
||||
logger.debug("Created LogSyncer for {} -> {}".format(
|
||||
local_dir, remote_dir))
|
||||
|
||||
def close(self):
|
||||
self.logfile.close()
|
||||
|
||||
def set_worker_ip(self, worker_ip):
|
||||
"""Set the worker ip to sync logs from."""
|
||||
self.worker_ip = worker_ip
|
||||
|
||||
def sync_if_needed(self):
|
||||
if time.time() - self.last_sync_time > 300:
|
||||
self.sync_now()
|
||||
|
||||
def sync_to_worker_if_possible(self):
|
||||
"""Syncs the local logdir on driver to worker if possible.
|
||||
|
||||
Requires ray cluster to be started with the autoscaler. Also requires
|
||||
rsync to be installed.
|
||||
"""
|
||||
def _check_valid_worker_ip(self):
|
||||
if not self.worker_ip:
|
||||
logger.info("Worker ip unknown, skipping log sync for {}".format(
|
||||
self._local_dir))
|
||||
return False
|
||||
if self.worker_ip == self.local_ip:
|
||||
return
|
||||
ssh_key = get_ssh_key()
|
||||
logger.debug(
|
||||
"Worker ip is local ip, skipping log sync for {}".format(
|
||||
self._local_dir))
|
||||
return False
|
||||
return True
|
||||
|
||||
@property
|
||||
def _remote_path(self):
|
||||
ssh_user = get_ssh_user()
|
||||
global _log_sync_warned
|
||||
if ssh_key is None or ssh_user is None:
|
||||
if not self._check_valid_worker_ip():
|
||||
return
|
||||
if ssh_user is None:
|
||||
if not _log_sync_warned:
|
||||
logger.error("Log sync requires cluster to be setup with "
|
||||
"`ray up`.")
|
||||
_log_sync_warned = True
|
||||
return
|
||||
if not distutils.spawn.find_executable("rsync"):
|
||||
logger.error("Log sync requires rsync to be installed.")
|
||||
return
|
||||
source = "{}/".format(self.local_dir)
|
||||
target = "{}@{}:{}/".format(ssh_user, self.worker_ip, self.local_dir)
|
||||
final_cmd = (("""rsync -savz -e "ssh -i {} -o ConnectTimeout=120s """
|
||||
"""-o StrictHostKeyChecking=no" {} {}""").format(
|
||||
quote(ssh_key), quote(source), quote(target)))
|
||||
logger.info("Syncing results to %s", str(self.worker_ip))
|
||||
sync_process = subprocess.Popen(
|
||||
final_cmd, shell=True, stdout=self.logfile)
|
||||
sync_process.wait()
|
||||
|
||||
def sync_now(self, force=False):
|
||||
self.last_sync_time = time.time()
|
||||
if not self.worker_ip:
|
||||
logger.debug("Worker ip unknown, skipping log sync for {}".format(
|
||||
self.local_dir))
|
||||
return
|
||||
|
||||
if self.worker_ip == self.local_ip:
|
||||
worker_to_local_sync_cmd = None # don't need to rsync
|
||||
else:
|
||||
ssh_key = get_ssh_key()
|
||||
ssh_user = get_ssh_user()
|
||||
global _log_sync_warned
|
||||
if ssh_key is None or ssh_user is None:
|
||||
if not _log_sync_warned:
|
||||
logger.error("Log sync requires cluster to be setup with "
|
||||
"`ray up`.")
|
||||
_log_sync_warned = True
|
||||
return
|
||||
if not distutils.spawn.find_executable("rsync"):
|
||||
logger.error("Log sync requires rsync to be installed.")
|
||||
return
|
||||
source = "{}@{}:{}/".format(ssh_user, self.worker_ip,
|
||||
self.local_dir)
|
||||
target = "{}/".format(self.local_dir)
|
||||
worker_to_local_sync_cmd = ((
|
||||
"""rsync -savz -e "ssh -i {} -o ConnectTimeout=120s """
|
||||
"""-o StrictHostKeyChecking=no" {} {}""").format(
|
||||
quote(ssh_key), quote(source), quote(target)))
|
||||
|
||||
if self.remote_dir:
|
||||
if self.sync_func:
|
||||
local_to_remote_sync_cmd = None
|
||||
try:
|
||||
self.sync_func(self.local_dir, self.remote_dir)
|
||||
except Exception:
|
||||
logger.exception("Sync function failed.")
|
||||
else:
|
||||
local_to_remote_sync_cmd = self.get_remote_sync_cmd()
|
||||
else:
|
||||
local_to_remote_sync_cmd = None
|
||||
|
||||
if self.sync_process:
|
||||
self.sync_process.poll()
|
||||
if self.sync_process.returncode is None:
|
||||
if force:
|
||||
self.sync_process.kill()
|
||||
else:
|
||||
logger.warning("Last sync is still in progress, skipping.")
|
||||
return
|
||||
|
||||
if worker_to_local_sync_cmd or local_to_remote_sync_cmd:
|
||||
final_cmd = ""
|
||||
if worker_to_local_sync_cmd:
|
||||
final_cmd += worker_to_local_sync_cmd
|
||||
if local_to_remote_sync_cmd:
|
||||
if final_cmd:
|
||||
final_cmd += " && "
|
||||
final_cmd += local_to_remote_sync_cmd
|
||||
logger.debug("Running log sync: {}".format(final_cmd))
|
||||
self.sync_process = subprocess.Popen(
|
||||
final_cmd, shell=True, stdout=self.logfile)
|
||||
|
||||
def wait(self):
|
||||
if self.sync_process:
|
||||
self.sync_process.wait()
|
||||
|
||||
def get_remote_sync_cmd(self):
|
||||
if self.sync_cmd_tmpl:
|
||||
local_to_remote_sync_cmd = (self.sync_cmd_tmpl.format(
|
||||
local_dir=quote(self.local_dir),
|
||||
remote_dir=quote(self.remote_dir)))
|
||||
elif self.remote_dir.startswith(S3_PREFIX):
|
||||
local_to_remote_sync_cmd = (
|
||||
"aws s3 sync {local_dir} {remote_dir}".format(
|
||||
local_dir=quote(self.local_dir),
|
||||
remote_dir=quote(self.remote_dir)))
|
||||
elif self.remote_dir.startswith(GCS_PREFIX):
|
||||
local_to_remote_sync_cmd = (
|
||||
"gsutil rsync -r {local_dir} {remote_dir}".format(
|
||||
local_dir=quote(self.local_dir),
|
||||
remote_dir=quote(self.remote_dir)))
|
||||
else:
|
||||
logger.warning("Remote sync unsupported, skipping.")
|
||||
local_to_remote_sync_cmd = None
|
||||
|
||||
return local_to_remote_sync_cmd
|
||||
return "{}@{}:{}/".format(ssh_user, self.worker_ip, self._remote_dir)
|
||||
|
||||
+18
-22
@@ -13,7 +13,7 @@ import numbers
|
||||
import numpy as np
|
||||
|
||||
import ray.cloudpickle as cloudpickle
|
||||
from ray.tune.log_sync import get_syncer
|
||||
from ray.tune.syncer import get_log_syncer
|
||||
from ray.tune.result import NODE_IP, TRAINING_ITERATION, TIME_TOTAL_S, \
|
||||
TIMESTEPS_TOTAL
|
||||
|
||||
@@ -33,13 +33,11 @@ class Logger(object):
|
||||
Arguments:
|
||||
config: Configuration passed to all logger creators.
|
||||
logdir: Directory for all logger creators to log to.
|
||||
upload_uri (str): Optional URI where the logdir is sync'ed to.
|
||||
"""
|
||||
|
||||
def __init__(self, config, logdir, upload_uri=None):
|
||||
def __init__(self, config, logdir):
|
||||
self.config = config
|
||||
self.logdir = logdir
|
||||
self.uri = upload_uri
|
||||
self._init()
|
||||
|
||||
def _init(self):
|
||||
@@ -196,24 +194,16 @@ DEFAULT_LOGGERS = (JsonLogger, CSVLogger, TFLogger)
|
||||
class UnifiedLogger(Logger):
|
||||
"""Unified result logger for TensorBoard, rllab/viskit, plain json.
|
||||
|
||||
This class also periodically syncs output to the given upload uri.
|
||||
|
||||
Arguments:
|
||||
config: Configuration passed to all logger creators.
|
||||
logdir: Directory for all logger creators to log to.
|
||||
upload_uri (str): Optional URI where the logdir is sync'ed to.
|
||||
loggers (list): List of logger creators. Defaults to CSV, Tensorboard,
|
||||
and JSON loggers.
|
||||
sync_function (func|str): Optional function for syncer to run.
|
||||
See ray/python/ray/tune/log_sync.py
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
logdir,
|
||||
upload_uri=None,
|
||||
loggers=None,
|
||||
sync_function=None):
|
||||
def __init__(self, config, logdir, loggers=None, sync_function=None):
|
||||
if loggers is None:
|
||||
self._logger_cls_list = DEFAULT_LOGGERS
|
||||
else:
|
||||
@@ -221,24 +211,26 @@ class UnifiedLogger(Logger):
|
||||
self._sync_function = sync_function
|
||||
self._log_syncer = None
|
||||
|
||||
Logger.__init__(self, config, logdir, upload_uri)
|
||||
super(UnifiedLogger, self).__init__(config, logdir)
|
||||
|
||||
def _init(self):
|
||||
self._loggers = []
|
||||
for cls in self._logger_cls_list:
|
||||
try:
|
||||
self._loggers.append(cls(self.config, self.logdir, self.uri))
|
||||
self._loggers.append(cls(self.config, self.logdir))
|
||||
except Exception:
|
||||
logger.warning("Could not instantiate {} - skipping.".format(
|
||||
str(cls)))
|
||||
self._log_syncer = get_syncer(
|
||||
self.logdir, self.uri, sync_function=self._sync_function)
|
||||
self._log_syncer = get_log_syncer(
|
||||
self.logdir,
|
||||
remote_dir=self.logdir,
|
||||
sync_function=self._sync_function)
|
||||
|
||||
def on_result(self, result):
|
||||
for _logger in self._loggers:
|
||||
_logger.on_result(result)
|
||||
self._log_syncer.set_worker_ip(result.get(NODE_IP))
|
||||
self._log_syncer.sync_if_needed()
|
||||
self._log_syncer.sync_down_if_needed()
|
||||
|
||||
def update_config(self, config):
|
||||
for _logger in self._loggers:
|
||||
@@ -247,13 +239,12 @@ class UnifiedLogger(Logger):
|
||||
def close(self):
|
||||
for _logger in self._loggers:
|
||||
_logger.close()
|
||||
self._log_syncer.sync_now(force=False)
|
||||
self._log_syncer.close()
|
||||
self._log_syncer.sync_down()
|
||||
|
||||
def flush(self):
|
||||
for _logger in self._loggers:
|
||||
_logger.flush()
|
||||
self._log_syncer.sync_now(force=False)
|
||||
self._log_syncer.sync_down()
|
||||
|
||||
def sync_results_to_new_location(self, worker_ip):
|
||||
"""Sends the current log directory to the remote node.
|
||||
@@ -262,8 +253,13 @@ class UnifiedLogger(Logger):
|
||||
with the Ray autoscaler.
|
||||
"""
|
||||
if worker_ip != self._log_syncer.worker_ip:
|
||||
logger.info("Syncing (blocking) results to {}".format(worker_ip))
|
||||
self._log_syncer.reset()
|
||||
self._log_syncer.set_worker_ip(worker_ip)
|
||||
self._log_syncer.sync_to_worker_if_possible()
|
||||
self._log_syncer.sync_up()
|
||||
# TODO: change this because this is blocking. But failures
|
||||
# are rare, so maybe this is OK?
|
||||
self._log_syncer.wait()
|
||||
|
||||
|
||||
class _SafeFallbackEncoder(json.JSONEncoder):
|
||||
|
||||
@@ -0,0 +1,266 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import distutils
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
import types
|
||||
|
||||
try: # py3
|
||||
from shlex import quote
|
||||
except ImportError: # py2
|
||||
from pipes import quote
|
||||
|
||||
from ray.tune.sample import function as tune_function
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.log_sync import log_sync_template, NodeSyncMixin
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
S3_PREFIX = "s3://"
|
||||
GS_PREFIX = "gs://"
|
||||
ALLOWED_REMOTE_PREFIXES = (S3_PREFIX, GS_PREFIX)
|
||||
SYNC_PERIOD = 300
|
||||
|
||||
_syncers = {}
|
||||
|
||||
|
||||
def validate_sync_string(sync_string):
    """Check that a sync command template has both replacement fields.

    Args:
        sync_string (str): Command template to inspect.

    Raises:
        ValueError: If '{source}' or '{target}' is missing from the
            template. '{source}' is checked first.
    """
    required_fields = (
        ("{source}", "Sync template missing '{source}'."),
        ("{target}", "Sync template missing '{target}'."),
    )
    for placeholder, message in required_fields:
        if placeholder not in sync_string:
            raise ValueError(message)
|
||||
|
||||
|
||||
def wait_for_sync():
    """Call ``wait()`` on every syncer held in the module-level registry."""
    for registry_key in _syncers:
        _syncers[registry_key].wait()
|
||||
|
||||
|
||||
class BaseSyncer(object):
    """Syncs a local directory with a remote directory through a callable.

    ``sync_up`` pushes the local directory to the remote path and
    ``sync_down`` pulls the remote path into the local directory. The time
    of the latest attempt in each direction is kept in
    ``last_sync_up_time`` / ``last_sync_down_time``.
    """

    def __init__(self, local_dir, remote_dir, sync_function=None):
        """Syncs between two directories with the sync_function.

        Arguments:
            local_dir (str): Directory to sync. Uniquely identifies the syncer.
            remote_dir (str): Remote directory to sync with.
            sync_function (func): Function for syncing the local_dir to
                remote_dir. Defaults to a Noop.
        """
        # Give a non-empty local_dir a trailing separator; falsy values
        # (None, "") are stored unchanged.
        self._local_dir = (os.path.join(local_dir, "")
                           if local_dir else local_dir)
        self._remote_dir = remote_dir
        # -inf guarantees the first *_if_needed() call triggers a sync.
        self.last_sync_up_time = float("-inf")
        self.last_sync_down_time = float("-inf")
        self._sync_function = sync_function or (lambda source, target: None)

    def sync_function(self, source, target):
        """Executes sync between source and target.

        Can be overwritten by subclasses for custom sync procedures.

        Args:
            source: Path to source file(s).
            target: Path to target file(s).

        Returns:
            Whatever the wrapped sync function returns.
        """
        if self._sync_function:
            return self._sync_function(source, target)

    def sync(self, source, target):
        """Runs ``sync_function(source, target)``, logging any failure.

        Args:
            source: Path to copy from.
            target: Path to copy to.

        Returns:
            bool: True if the sync function was invoked and did not raise.
                False if the sync was skipped because source or target is
                empty, or if the sync function raised (the exception is
                logged, not propagated).
        """
        if not (source and target):
            logger.debug(
                "Source or target is empty, skipping log sync for {}".format(
                    self._local_dir))
            return False

        try:
            self.sync_function(source, target)
        except Exception:
            # Best effort: a failed sync must not take down the caller.
            logger.exception("Sync function failed.")
            return False
        return True

    def sync_up_if_needed(self):
        """Syncs up if more than SYNC_PERIOD seconds passed since the last.

        Returns:
            The result of ``sync_up()`` if a sync was attempted, else None.
        """
        if time.time() - self.last_sync_up_time > SYNC_PERIOD:
            return self.sync_up()

    def sync_down_if_needed(self):
        """Syncs down if more than SYNC_PERIOD seconds passed since the last.

        Returns:
            The result of ``sync_down()`` if a sync was attempted, else None.
        """
        if time.time() - self.last_sync_down_time > SYNC_PERIOD:
            return self.sync_down()

    def sync_down(self, *args, **kwargs):
        """Syncs the remote path into the local directory.

        Extra arguments are forwarded to ``sync``.

        Returns:
            The result of ``sync`` (see ``sync``).
        """
        result = self.sync(self._remote_path, self._local_dir, *args,
                           **kwargs)
        # Recorded for every attempt, successful or not.
        self.last_sync_down_time = time.time()
        return result

    def sync_up(self, *args, **kwargs):
        """Syncs the local directory to the remote path.

        Extra arguments are forwarded to ``sync``.

        Returns:
            The result of ``sync`` (see ``sync``).
        """
        result = self.sync(self._local_dir, self._remote_path, *args,
                           **kwargs)
        # Recorded for every attempt, successful or not.
        self.last_sync_up_time = time.time()
        return result

    def reset(self):
        """Forgets the last sync times so the next *_if_needed() syncs."""
        self.last_sync_up_time = float("-inf")
        self.last_sync_down_time = float("-inf")

    def wait(self):
        """No-op; subclasses that sync asynchronously can override."""
        pass

    @property
    def _remote_path(self):
        """Protected method for accessing remote_dir.

        Can be overridden in subclass for custom path.
        """
        return self._remote_dir
|
||||
|
||||
|
||||
class CommandSyncer(BaseSyncer):
    """Syncer that runs a shell command template in a background process."""

    def __init__(self, local_dir, remote_dir, sync_template):
        """Syncs between two directories with the given command.

        Arguments:
            local_dir (str): Directory to sync.
            remote_dir (str): Remote directory to sync with.
            sync_template (str): A string template
                for syncer to run and needs to include replacement fields
                '{source}' and '{target}'. Returned when using
                `CommandSyncer.sync_template`, which can be overridden
                by subclass.

        Raises:
            ValueError: If sync_template is not a string or lacks one of
                the replacement fields.
        """
        super(CommandSyncer, self).__init__(local_dir, remote_dir)
        if not isinstance(sync_template, str):
            raise ValueError("{} is not a string.".format(sync_template))
        validate_sync_string(sync_template)
        self._sync_template = sync_template
        # stdout of every sync command is written to this file inside the
        # local directory; delete=False keeps it on disk after closing.
        self.logfile = tempfile.NamedTemporaryFile(
            prefix="log_sync",
            dir=self._local_dir,
            suffix=".log",
            delete=False)

        self.sync_process = None

    def sync_function(self, source, target):
        """Launches the sync command without waiting for it to finish.

        Args:
            source: Path substituted (shell-quoted) for '{source}'.
            target: Path substituted (shell-quoted) for '{target}'.

        Returns:
            bool: True if a new sync process was started, False if the
                previous one is still running and this sync was skipped.
        """
        self.last_sync_time = time.time()
        if self.sync_process:
            self.sync_process.poll()
            if self.sync_process.returncode is None:
                logger.warning("Last sync is still in progress, skipping.")
                return False
        final_cmd = self._sync_template.format(
            source=quote(source), target=quote(target))
        logger.debug("Running sync: {}".format(final_cmd))
        self.sync_process = subprocess.Popen(
            final_cmd, shell=True, stdout=self.logfile)
        return True

    def reset(self):
        """Drops the reference to the last sync process and resets timers.

        A process that is still running is not killed or waited for.
        """
        if self.sync_process:
            # Only warn when the process really is still alive; a finished
            # process is simply forgotten.
            self.sync_process.poll()
            if self.sync_process.returncode is None:
                logger.warning(
                    "Sync process still running but resetting anyways.")
            self.sync_process = None
        super(CommandSyncer, self).reset()

    def wait(self):
        """Blocks until the last launched sync process has exited."""
        if self.sync_process:
            self.sync_process.wait()

    def close(self):
        """Closes the sync log file.

        The syncer must not be used to start new syncs afterwards.
        """
        self.logfile.close()
|
||||
|
||||
|
||||
def _get_sync_cls(sync_function):
|
||||
if not sync_function:
|
||||
return
|
||||
if isinstance(sync_function, types.FunctionType) or isinstance(
|
||||
sync_function, tune_function):
|
||||
return BaseSyncer
|
||||
elif isinstance(sync_function, str):
|
||||
return CommandSyncer
|
||||
else:
|
||||
raise ValueError("Sync function {} must be string or function".format(
|
||||
sync_function))
|
||||
|
||||
|
||||
def get_syncer(local_dir, remote_dir=None, sync_function=None):
    """Returns a Syncer depending on given args.

    This syncer is in charge of syncing the local_dir with upload_dir.
    Syncers are cached per (local_dir, remote_dir) pair, so a second call
    with the same pair returns the existing instance.

    Args:
        local_dir: Source directory for syncing.
        remote_dir: Target directory for syncing. If None,
            returns BaseSyncer with a noop.
        sync_function (func | str): Function for syncing the local_dir to
            remote_dir. If string, then it must be a string template for
            syncer to run. If not provided, it defaults
            to standard S3 or gsutil sync commands.

    Raises:
        TuneError: If no sync_function is given and remote_dir does not
            start with an allowed prefix, or the command line tool needed
            for that prefix is not installed.
        ValueError: If sync_function is neither a string nor a function.
    """
    key = (local_dir, remote_dir)

    if key in _syncers:
        return _syncers[key]

    if not remote_dir:
        _syncers[key] = BaseSyncer(local_dir, remote_dir)
        return _syncers[key]

    sync_cls = _get_sync_cls(sync_function)

    if sync_cls:
        _syncers[key] = sync_cls(local_dir, remote_dir, sync_function)
        return _syncers[key]

    # Importing the `distutils` package alone does not load its `spawn`
    # submodule, so import it explicitly instead of relying on some other
    # module having done so.
    import distutils.spawn

    if remote_dir.startswith(S3_PREFIX):
        if not distutils.spawn.find_executable("aws"):
            raise TuneError(
                "Upload uri starting with '{}' requires awscli tool"
                " to be installed".format(S3_PREFIX))
        _syncers[key] = CommandSyncer(local_dir, remote_dir,
                                      "aws s3 sync {source} {target}")
    elif remote_dir.startswith(GS_PREFIX):
        if not distutils.spawn.find_executable("gsutil"):
            raise TuneError(
                "Upload uri starting with '{}' requires gsutil tool"
                " to be installed".format(GS_PREFIX))
        _syncers[key] = CommandSyncer(local_dir, remote_dir,
                                      "gsutil rsync -r {source} {target}")
    else:
        raise TuneError("Upload uri must start with one of: {}"
                        "".format(ALLOWED_REMOTE_PREFIXES))

    return _syncers[key]
|
||||
|
||||
|
||||
def get_log_syncer(local_dir, remote_dir=None, sync_function=None):
    """Returns a Syncer depending on given args.

    This syncer is in charge of syncing the local_dir with remote local_dir.
    Syncers are cached per (local_dir, remote_dir) pair, so a second call
    with the same pair returns the existing instance.

    Args:
        local_dir: Source directory for syncing.
        remote_dir: Target directory for syncing. If None,
            returns BaseSyncer with noop.
        sync_function (func | str): Function for syncing the local_dir to
            remote_dir. If string, then it must be a string template for
            syncer to run. If not provided, it defaults to rsync.

    Raises:
        ValueError: If sync_function is neither a string nor a function.
    """
    key = (local_dir, remote_dir)

    if key in _syncers:
        return _syncers[key]

    if sync_function:
        sync_cls = _get_sync_cls(sync_function)
    else:
        sync_cls = CommandSyncer
        # The default template may be None, which is handled below.
        sync_function = log_sync_template()

    if not remote_dir or sync_function is None:
        sync_cls = BaseSyncer
        # BaseSyncer *calls* its sync function. A command template string
        # is not callable and would raise on every sync, so drop it and
        # let BaseSyncer use its noop default.
        if isinstance(sync_function, str):
            sync_function = None

    class MixedSyncer(NodeSyncMixin, sync_cls):
        def __init__(self, *args, **kwargs):
            # The syncer base must be initialised first: the mixin asserts
            # that `_remote_dir` already exists.
            sync_cls.__init__(self, *args, **kwargs)
            NodeSyncMixin.__init__(self)

    _syncers[key] = MixedSyncer(local_dir, remote_dir, sync_function)
    return _syncers[key]
|
||||
@@ -272,8 +272,7 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir):
|
||||
cluster.wait_for_nodes()
|
||||
|
||||
dirpath = str(tmpdir)
|
||||
runner = TrialRunner(
|
||||
BasicVariantGenerator(), metadata_checkpoint_dir=dirpath)
|
||||
runner = TrialRunner(BasicVariantGenerator(), local_checkpoint_dir=dirpath)
|
||||
kwargs = {
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 2
|
||||
@@ -295,7 +294,7 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir):
|
||||
ray.shutdown()
|
||||
|
||||
cluster = _start_new_cluster()
|
||||
runner = TrialRunner.restore(dirpath)
|
||||
runner = TrialRunner(resume="LOCAL", local_checkpoint_dir=dirpath)
|
||||
runner.step() # start
|
||||
runner.step() # start2
|
||||
|
||||
@@ -377,18 +376,19 @@ tune.run_experiments(
|
||||
# Wait until the right checkpoint is saved.
|
||||
# The trainable returns every 0.5 seconds, so this should not miss
|
||||
# the checkpoint.
|
||||
metadata_checkpoint_dir = os.path.join(dirpath, "experiment")
|
||||
local_checkpoint_dir = os.path.join(dirpath, "experiment")
|
||||
for i in range(100):
|
||||
if TrialRunner.checkpoint_exists(metadata_checkpoint_dir):
|
||||
if TrialRunner.checkpoint_exists(local_checkpoint_dir):
|
||||
# Inspect the internal trialrunner
|
||||
runner = TrialRunner.restore(metadata_checkpoint_dir)
|
||||
runner = TrialRunner(
|
||||
resume="LOCAL", local_checkpoint_dir=local_checkpoint_dir)
|
||||
trials = runner.get_trials()
|
||||
last_res = trials[0].last_result
|
||||
if last_res and last_res.get("training_iteration"):
|
||||
break
|
||||
time.sleep(0.3)
|
||||
|
||||
if not TrialRunner.checkpoint_exists(metadata_checkpoint_dir):
|
||||
if not TrialRunner.checkpoint_exists(local_checkpoint_dir):
|
||||
raise RuntimeError("Checkpoint file didn't appear.")
|
||||
|
||||
ray.shutdown()
|
||||
@@ -469,18 +469,19 @@ tune.run_experiments(
|
||||
# Wait until the right checkpoint is saved.
|
||||
# The trainable returns every 0.5 seconds, so this should not miss
|
||||
# the checkpoint.
|
||||
metadata_checkpoint_dir = os.path.join(dirpath, "experiment")
|
||||
local_checkpoint_dir = os.path.join(dirpath, "experiment")
|
||||
for i in range(50):
|
||||
if TrialRunner.checkpoint_exists(metadata_checkpoint_dir):
|
||||
if TrialRunner.checkpoint_exists(local_checkpoint_dir):
|
||||
# Inspect the internal trialrunner
|
||||
runner = TrialRunner.restore(metadata_checkpoint_dir)
|
||||
runner = TrialRunner(
|
||||
resume="LOCAL", local_checkpoint_dir=local_checkpoint_dir)
|
||||
trials = runner.get_trials()
|
||||
last_res = trials[0].last_result
|
||||
if last_res and last_res.get("training_iteration") == 3:
|
||||
break
|
||||
time.sleep(0.2)
|
||||
|
||||
if not TrialRunner.checkpoint_exists(metadata_checkpoint_dir):
|
||||
if not TrialRunner.checkpoint_exists(local_checkpoint_dir):
|
||||
raise RuntimeError("Checkpoint file didn't appear.")
|
||||
|
||||
ray.shutdown()
|
||||
@@ -489,7 +490,8 @@ tune.run_experiments(
|
||||
Experiment._register_if_needed(_Mock)
|
||||
|
||||
# Inspect the internal trialrunner
|
||||
runner = TrialRunner.restore(metadata_checkpoint_dir)
|
||||
runner = TrialRunner(
|
||||
resume="LOCAL", local_checkpoint_dir=local_checkpoint_dir)
|
||||
trials = runner.get_trials()
|
||||
assert trials[0].last_result["training_iteration"] == 3
|
||||
assert trials[0].status == Trial.PENDING
|
||||
|
||||
@@ -114,8 +114,8 @@ class ExperimentAnalysisSuite(unittest.TestCase):
|
||||
runner_data = self.ea.runner_data()
|
||||
|
||||
self.assertTrue(isinstance(runner_data, dict))
|
||||
self.assertTrue("_metadata_checkpoint_dir" in runner_data)
|
||||
self.assertEqual(runner_data["_metadata_checkpoint_dir"],
|
||||
self.assertTrue("_local_checkpoint_dir" in runner_data)
|
||||
self.assertEqual(runner_data["_local_checkpoint_dir"],
|
||||
os.path.expanduser(self.test_path))
|
||||
|
||||
def testBestLogdir(self):
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
@@ -315,21 +316,6 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
}
|
||||
})
|
||||
|
||||
def testUploadDirNone(self):
|
||||
def train(config, reporter):
|
||||
reporter(timesteps_total=1)
|
||||
|
||||
[trial] = run_experiments({
|
||||
"foo": {
|
||||
"run": train,
|
||||
"upload_dir": None,
|
||||
"config": {
|
||||
"a": "b"
|
||||
},
|
||||
}
|
||||
})
|
||||
self.assertFalse(trial.upload_dir)
|
||||
|
||||
def testLogdirStartingWithTilde(self):
|
||||
local_dir = "~/ray_results/local_dir"
|
||||
|
||||
@@ -930,50 +916,190 @@ class RunExperimentTest(unittest.TestCase):
|
||||
str(trial), "{}_{}_321".format(trial.trainable_name,
|
||||
trial.trial_id))
|
||||
|
||||
def testSyncFunction(self):
|
||||
def fail_sync_local():
|
||||
[trial] = run_experiments({
|
||||
"foo": {
|
||||
"run": "__fake",
|
||||
|
||||
class TestSyncFunctionality(unittest.TestCase):
|
||||
def setUp(self):
|
||||
ray.init()
|
||||
|
||||
def tearDown(self):
|
||||
ray.shutdown()
|
||||
_register_all() # re-register the evicted objects
|
||||
|
||||
@patch("ray.tune.syncer.S3_PREFIX", "test")
|
||||
def testNoUploadDir(self):
|
||||
"""No Upload Dir is given."""
|
||||
with self.assertRaises(AssertionError):
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
**{
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"sync_to_cloud": "echo {source} {target}"
|
||||
})
|
||||
|
||||
@patch("ray.tune.syncer.S3_PREFIX", "test")
|
||||
def testCloudProperString(self):
|
||||
with self.assertRaises(ValueError):
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
**{
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"upload_dir": "test",
|
||||
"sync_function": "ls {remote_dir}"
|
||||
}
|
||||
})
|
||||
"sync_to_cloud": "ls {target}"
|
||||
})
|
||||
|
||||
self.assertRaises(AssertionError, fail_sync_local)
|
||||
|
||||
def fail_sync_remote():
|
||||
[trial] = run_experiments({
|
||||
"foo": {
|
||||
"run": "__fake",
|
||||
with self.assertRaises(ValueError):
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
**{
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"upload_dir": "test",
|
||||
"sync_function": "ls {local_dir}"
|
||||
}
|
||||
})
|
||||
"sync_to_cloud": "ls {source}"
|
||||
})
|
||||
|
||||
self.assertRaises(AssertionError, fail_sync_remote)
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
logfile = os.path.join(tmpdir, "test.log")
|
||||
|
||||
def sync_func(local, remote):
|
||||
with open(os.path.join(local, "test.log"), "w") as f:
|
||||
f.write(remote)
|
||||
|
||||
[trial] = run_experiments({
|
||||
"foo": {
|
||||
"run": "__fake",
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
**{
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"upload_dir": "test",
|
||||
"sync_function": tune.function(sync_func)
|
||||
}
|
||||
})
|
||||
self.assertTrue(os.path.exists(os.path.join(trial.logdir, "test.log")))
|
||||
"sync_to_cloud": "echo {source} {target} > " + logfile
|
||||
})
|
||||
with open(logfile) as f:
|
||||
lines = f.read()
|
||||
self.assertTrue("test" in lines)
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
def testClusterProperString(self):
|
||||
"""Tests that invalid commands throw."""
|
||||
with self.assertRaises(TuneError):
|
||||
# This raises TuneError because logger is init in safe zone.
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
**{
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"sync_to_driver": "ls {target}"
|
||||
})
|
||||
|
||||
with self.assertRaises(TuneError):
|
||||
# This raises TuneError because logger is init in safe zone.
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
**{
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"sync_to_driver": "ls {source}"
|
||||
})
|
||||
|
||||
with patch("ray.tune.syncer.CommandSyncer.sync_function"
|
||||
) as mock_fn, patch(
|
||||
"ray.services.get_node_ip_address") as mock_sync:
|
||||
mock_sync.return_value = "0.0.0.0"
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
**{
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"sync_to_driver": "echo {source} {target}"
|
||||
})
|
||||
self.assertGreater(mock_fn.call_count, 0)
|
||||
|
||||
def testCloudFunctions(self):
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
tmpdir2 = tempfile.mkdtemp()
|
||||
os.mkdir(os.path.join(tmpdir2, "foo"))
|
||||
|
||||
def sync_func(local, remote):
|
||||
for filename in glob.glob(os.path.join(local, "*.json")):
|
||||
shutil.copy(filename, remote)
|
||||
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
local_dir=tmpdir,
|
||||
stop={"training_iteration": 1},
|
||||
upload_dir=tmpdir2,
|
||||
sync_to_cloud=tune.function(sync_func))
|
||||
test_file_path = glob.glob(os.path.join(tmpdir2, "foo", "*.json"))
|
||||
self.assertTrue(test_file_path)
|
||||
shutil.rmtree(tmpdir)
|
||||
shutil.rmtree(tmpdir2)
|
||||
|
||||
def testClusterSyncFunction(self):
|
||||
def sync_func_driver(source, target):
|
||||
assert ":" in source, "Source not a remote path."
|
||||
assert ":" not in target, "Target is supposed to be local."
|
||||
with open(os.path.join(target, "test.log2"), "w") as f:
|
||||
print("writing to", f.name)
|
||||
f.write(source)
|
||||
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
stop={"training_iteration": 1},
|
||||
sync_to_driver=tune.function(sync_func_driver))
|
||||
test_file_path = os.path.join(trial.logdir, "test.log2")
|
||||
self.assertFalse(os.path.exists(test_file_path))
|
||||
|
||||
with patch("ray.services.get_node_ip_address") as mock_sync:
|
||||
mock_sync.return_value = "0.0.0.0"
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
stop={"training_iteration": 1},
|
||||
sync_to_driver=tune.function(sync_func_driver))
|
||||
test_file_path = os.path.join(trial.logdir, "test.log2")
|
||||
self.assertTrue(os.path.exists(test_file_path))
|
||||
os.remove(test_file_path)
|
||||
|
||||
def testNoSync(self):
|
||||
def sync_func(source, target):
|
||||
pass
|
||||
|
||||
with patch("ray.tune.syncer.CommandSyncer.sync_function") as mock_sync:
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
**{
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"upload_dir": "test",
|
||||
"sync_to_driver": tune.function(sync_func),
|
||||
"sync_to_cloud": tune.function(sync_func)
|
||||
})
|
||||
self.assertEqual(mock_sync.call_count, 0)
|
||||
|
||||
|
||||
class VariantGeneratorTest(unittest.TestCase):
|
||||
@@ -1960,7 +2086,7 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
ray.init(num_cpus=3)
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
|
||||
runner = TrialRunner(metadata_checkpoint_dir=tmpdir)
|
||||
runner = TrialRunner(local_checkpoint_dir=tmpdir)
|
||||
trials = [
|
||||
Trial(
|
||||
"__fake",
|
||||
@@ -1999,7 +2125,7 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
self.assertEquals(len(runner.trial_executor.get_checkpoints()), 3)
|
||||
self.assertEquals(trials[2].status, Trial.RUNNING)
|
||||
|
||||
runner2 = TrialRunner.restore(tmpdir)
|
||||
runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=tmpdir)
|
||||
for tid in ["trial_terminate", "trial_fail"]:
|
||||
original_trial = runner.get_trial(tid)
|
||||
restored_trial = runner2.get_trial(tid)
|
||||
@@ -2019,7 +2145,7 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
ray.init(num_cpus=3)
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
|
||||
runner = TrialRunner(metadata_checkpoint_dir=tmpdir)
|
||||
runner = TrialRunner(local_checkpoint_dir=tmpdir)
|
||||
|
||||
runner.add_trial(
|
||||
Trial(
|
||||
@@ -2051,7 +2177,7 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
runner.step()
|
||||
runner.step()
|
||||
|
||||
runner2 = TrialRunner.restore(tmpdir)
|
||||
runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=tmpdir)
|
||||
new_trials = runner2.get_trials()
|
||||
self.assertEquals(len(new_trials), 3)
|
||||
self.assertTrue(
|
||||
@@ -2074,13 +2200,13 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
},
|
||||
checkpoint_freq=1)
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
runner = TrialRunner(metadata_checkpoint_dir=tmpdir)
|
||||
runner = TrialRunner(local_checkpoint_dir=tmpdir)
|
||||
runner.add_trial(trial)
|
||||
for i in range(5):
|
||||
runner.step()
|
||||
# force checkpoint
|
||||
runner.checkpoint()
|
||||
runner2 = TrialRunner.restore(tmpdir)
|
||||
runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=tmpdir)
|
||||
new_trial = runner2.get_trials()[0]
|
||||
self.assertTrue("callbacks" in new_trial.config)
|
||||
self.assertTrue("on_episode_start" in new_trial.config["callbacks"])
|
||||
@@ -2095,7 +2221,7 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
ray.init()
|
||||
trial = Trial("__fake", checkpoint_freq=1)
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
runner = TrialRunner(metadata_checkpoint_dir=tmpdir)
|
||||
runner = TrialRunner(local_checkpoint_dir=tmpdir)
|
||||
runner.add_trial(trial)
|
||||
for i in range(5):
|
||||
runner.step()
|
||||
@@ -2103,7 +2229,7 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
runner.checkpoint()
|
||||
self.assertEquals(count_checkpoints(tmpdir), 1)
|
||||
|
||||
runner2 = TrialRunner.restore(tmpdir)
|
||||
runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=tmpdir)
|
||||
for i in range(5):
|
||||
runner2.step()
|
||||
self.assertEquals(count_checkpoints(tmpdir), 2)
|
||||
|
||||
@@ -19,7 +19,6 @@ from six import string_types
|
||||
|
||||
import ray
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.log_sync import validate_sync_function
|
||||
from ray.tune.logger import pretty_print, UnifiedLogger
|
||||
# NOTE(rkn): We import ray.tune.registry here instead of importing the names we
|
||||
# need because there are cyclic imports that may cause specific names to not
|
||||
@@ -276,10 +275,9 @@ class Trial(object):
|
||||
checkpoint_score_attr="",
|
||||
export_formats=None,
|
||||
restore_path=None,
|
||||
upload_dir=None,
|
||||
trial_name_creator=None,
|
||||
loggers=None,
|
||||
sync_function=None,
|
||||
sync_to_driver_fn=None,
|
||||
max_failures=0):
|
||||
"""Initialize a new trial.
|
||||
|
||||
@@ -308,10 +306,8 @@ class Trial(object):
|
||||
resources = default_resources
|
||||
self.resources = resources or Resources(cpu=1, gpu=0)
|
||||
self.stopping_criterion = stopping_criterion or {}
|
||||
self.upload_dir = upload_dir
|
||||
self.loggers = loggers
|
||||
self.sync_function = sync_function
|
||||
validate_sync_function(sync_function)
|
||||
self.sync_to_driver_fn = sync_to_driver_fn
|
||||
self.verbose = True
|
||||
self.max_failures = max_failures
|
||||
|
||||
@@ -352,7 +348,7 @@ class Trial(object):
|
||||
self._nonjson_fields = [
|
||||
"_checkpoint",
|
||||
"loggers",
|
||||
"sync_function",
|
||||
"sync_to_driver_fn",
|
||||
"results",
|
||||
"best_result",
|
||||
"param_config",
|
||||
@@ -394,9 +390,8 @@ class Trial(object):
|
||||
self.result_logger = UnifiedLogger(
|
||||
self.config,
|
||||
self.logdir,
|
||||
upload_uri=self.upload_dir,
|
||||
loggers=self.loggers,
|
||||
sync_function=self.sync_function)
|
||||
sync_function=self.sync_to_driver_fn)
|
||||
|
||||
def update_resources(self, cpu, gpu, **kwargs):
|
||||
"""EXPERIMENTAL: Updates the resource requirements.
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import click
|
||||
import collections
|
||||
from datetime import datetime
|
||||
import json
|
||||
@@ -15,6 +16,7 @@ import ray.cloudpickle as cloudpickle
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.ray_trial_executor import RayTrialExecutor
|
||||
from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE
|
||||
from ray.tune.syncer import get_syncer
|
||||
from ray.tune.trial import Trial, Checkpoint
|
||||
from ray.tune.sample import function
|
||||
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
|
||||
@@ -97,12 +99,16 @@ class TrialRunner(object):
|
||||
"""
|
||||
|
||||
CKPT_FILE_TMPL = "experiment_state-{}.json"
|
||||
VALID_RESUME_TYPES = [True, "LOCAL", "REMOTE", "PROMPT"]
|
||||
|
||||
def __init__(self,
|
||||
search_alg=None,
|
||||
scheduler=None,
|
||||
launch_web_server=False,
|
||||
metadata_checkpoint_dir=None,
|
||||
local_checkpoint_dir=None,
|
||||
remote_checkpoint_dir=None,
|
||||
sync_to_cloud=None,
|
||||
resume=False,
|
||||
server_port=TuneServer.DEFAULT_PORT,
|
||||
verbose=True,
|
||||
trial_executor=None):
|
||||
@@ -113,15 +119,16 @@ class TrialRunner(object):
|
||||
Trial objects.
|
||||
scheduler (TrialScheduler): Defaults to FIFOScheduler.
|
||||
launch_web_server (bool): Flag for starting TuneServer
|
||||
metadata_checkpoint_dir (str): Path where
|
||||
local_checkpoint_dir (str): Path where
|
||||
global checkpoints are stored and restored from.
|
||||
server_port (int): Port number for launching TuneServer
|
||||
remote_checkpoint_dir (str): Remote path where
|
||||
global checkpoints are stored and restored from. Used
|
||||
if `resume` == REMOTE.
|
||||
resume (str|False): see `tune.py:run`.
|
||||
sync_to_cloud (func|str): see `tune.py:run`.
|
||||
server_port (int): Port number for launching TuneServer.
|
||||
verbose (bool): Flag for verbosity. If False, trial results
|
||||
will not be output.
|
||||
reuse_actors (bool): Whether to reuse actors between different
|
||||
trials when possible. This can drastically speed up experiments
|
||||
that start and stop actors often (e.g., PBT in
|
||||
time-multiplexing mode).
|
||||
trial_executor (TrialExecutor): Defaults to RayTrialExecutor.
|
||||
"""
|
||||
self._search_alg = search_alg or BasicVariantGenerator()
|
||||
@@ -143,12 +150,73 @@ class TrialRunner(object):
|
||||
|
||||
self._trials = []
|
||||
self._stop_queue = []
|
||||
self._metadata_checkpoint_dir = metadata_checkpoint_dir
|
||||
self._local_checkpoint_dir = local_checkpoint_dir
|
||||
|
||||
if self._local_checkpoint_dir and not os.path.exists(
|
||||
self._local_checkpoint_dir):
|
||||
os.makedirs(self._local_checkpoint_dir)
|
||||
|
||||
self._remote_checkpoint_dir = remote_checkpoint_dir
|
||||
self._syncer = get_syncer(local_checkpoint_dir, remote_checkpoint_dir,
|
||||
sync_to_cloud)
|
||||
|
||||
self._resumed = False
|
||||
|
||||
if self._validate_resume(resume_type=resume):
|
||||
try:
|
||||
self.resume()
|
||||
logger.info("Resuming trial.")
|
||||
self._resumed = True
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Runner restore failed. Restarting experiment.")
|
||||
else:
|
||||
logger.info("Starting a new experiment.")
|
||||
|
||||
self._start_time = time.time()
|
||||
self._session_str = datetime.fromtimestamp(
|
||||
self._start_time).strftime("%Y-%m-%d_%H-%M-%S")
|
||||
|
||||
def _validate_resume(self, resume_type):
|
||||
"""Checks whether to resume experiment.
|
||||
|
||||
Args:
|
||||
resume_type: One of True, "REMOTE", "LOCAL", "PROMPT".
|
||||
"""
|
||||
if not resume_type:
|
||||
return False
|
||||
assert resume_type in self.VALID_RESUME_TYPES, (
|
||||
"resume_type {} is not one of {}".format(resume_type,
|
||||
self.VALID_RESUME_TYPES))
|
||||
# Not clear if we need this assertion, since we should always have a
|
||||
# local checkpoint dir.
|
||||
assert self._local_checkpoint_dir or self._remote_checkpoint_dir
|
||||
if resume_type in [True, "LOCAL", "PROMPT"]:
|
||||
if not self.checkpoint_exists(self._local_checkpoint_dir):
|
||||
raise ValueError("Called resume when no checkpoint exists "
|
||||
"in local directory.")
|
||||
elif resume_type == "PROMPT":
|
||||
if click.confirm("Resume from local directory?"):
|
||||
return True
|
||||
|
||||
if resume_type in ["REMOTE", "PROMPT"]:
|
||||
if resume_type == "PROMPT" and not click.confirm(
|
||||
"Try downloading from remote directory?"):
|
||||
return False
|
||||
if not self._remote_checkpoint_dir:
|
||||
raise ValueError(
|
||||
"Called resume from remote without remote directory.")
|
||||
|
||||
# Try syncing down the upload directory.
|
||||
logger.info("Downloading from {}".format(
|
||||
self._remote_checkpoint_dir))
|
||||
self._syncer.sync_down_if_needed()
|
||||
|
||||
if not self.checkpoint_exists(self._local_checkpoint_dir):
|
||||
raise ValueError("Called resume when no checkpoint exists "
|
||||
"in remote or local directory.")
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def checkpoint_exists(cls, directory):
|
||||
if not os.path.exists(directory):
|
||||
@@ -157,17 +225,21 @@ class TrialRunner(object):
|
||||
(fname.startswith("experiment_state") and fname.endswith(".json"))
|
||||
for fname in os.listdir(directory))
|
||||
|
||||
def add_experiment(self, experiment):
|
||||
if not self._resumed:
|
||||
self._search_alg.add_configurations([experiment])
|
||||
else:
|
||||
logger.info("TrialRunner resumed, ignoring new add_experiment.")
|
||||
|
||||
def checkpoint(self):
|
||||
"""Saves execution state to `self._metadata_checkpoint_dir`.
|
||||
"""Saves execution state to `self._local_checkpoint_dir`.
|
||||
|
||||
Overwrites the current session checkpoint, which starts when self
|
||||
is instantiated.
|
||||
"""
|
||||
if not self._metadata_checkpoint_dir:
|
||||
if not self._local_checkpoint_dir:
|
||||
return
|
||||
metadata_checkpoint_dir = self._metadata_checkpoint_dir
|
||||
if not os.path.exists(metadata_checkpoint_dir):
|
||||
os.makedirs(metadata_checkpoint_dir)
|
||||
|
||||
runner_state = {
|
||||
"checkpoints": list(
|
||||
self.trial_executor.get_checkpoints().values()),
|
||||
@@ -177,55 +249,37 @@ class TrialRunner(object):
|
||||
"timestamp": time.time()
|
||||
}
|
||||
}
|
||||
tmp_file_name = os.path.join(metadata_checkpoint_dir,
|
||||
tmp_file_name = os.path.join(self._local_checkpoint_dir,
|
||||
".tmp_checkpoint")
|
||||
with open(tmp_file_name, "w") as f:
|
||||
json.dump(runner_state, f, indent=2, cls=_TuneFunctionEncoder)
|
||||
|
||||
os.rename(
|
||||
tmp_file_name,
|
||||
os.path.join(metadata_checkpoint_dir,
|
||||
os.path.join(self._local_checkpoint_dir,
|
||||
TrialRunner.CKPT_FILE_TMPL.format(self._session_str)))
|
||||
return metadata_checkpoint_dir
|
||||
self._syncer.sync_up_if_needed()
|
||||
return self._local_checkpoint_dir
|
||||
|
||||
@classmethod
|
||||
def restore(cls,
|
||||
metadata_checkpoint_dir,
|
||||
search_alg=None,
|
||||
scheduler=None,
|
||||
trial_executor=None):
|
||||
"""Restores all checkpointed trials from previous run.
|
||||
def resume(self):
|
||||
"""Resumes all checkpointed trials from previous run.
|
||||
|
||||
Requires user to manually re-register their objects. Also stops
|
||||
all ongoing trials.
|
||||
|
||||
Args:
|
||||
metadata_checkpoint_dir (str): Path to metadata checkpoints.
|
||||
search_alg (SearchAlgorithm): Search Algorithm. Defaults to
|
||||
BasicVariantGenerator.
|
||||
scheduler (TrialScheduler): Scheduler for executing
|
||||
the experiment.
|
||||
trial_executor (TrialExecutor): Manage the execution of trials.
|
||||
|
||||
Returns:
|
||||
runner (TrialRunner): A TrialRunner to resume experiments from.
|
||||
"""
|
||||
|
||||
newest_ckpt_path = _find_newest_ckpt(metadata_checkpoint_dir)
|
||||
newest_ckpt_path = _find_newest_ckpt(self._local_checkpoint_dir)
|
||||
with open(newest_ckpt_path, "r") as f:
|
||||
runner_state = json.load(f, cls=_TuneFunctionDecoder)
|
||||
|
||||
logger.warning("".join([
|
||||
"Attempting to resume experiment from {}. ".format(
|
||||
metadata_checkpoint_dir), "This feature is experimental, "
|
||||
self._local_checkpoint_dir), "This feature is experimental, "
|
||||
"and may not work with all search algorithms. ",
|
||||
"This will ignore any new changes to the specification."
|
||||
]))
|
||||
|
||||
runner = TrialRunner(
|
||||
search_alg, scheduler=scheduler, trial_executor=trial_executor)
|
||||
|
||||
runner.__setstate__(runner_state["runner_data"])
|
||||
self.__setstate__(runner_state["runner_data"])
|
||||
|
||||
trials = []
|
||||
for trial_cp in runner_state["checkpoints"]:
|
||||
@@ -234,8 +288,7 @@ class TrialRunner(object):
|
||||
trials += [new_trial]
|
||||
for trial in sorted(
|
||||
trials, key=lambda t: t.last_update_time, reverse=True):
|
||||
runner.add_trial(trial)
|
||||
return runner
|
||||
self.add_trial(trial)
|
||||
|
||||
def is_finished(self):
|
||||
"""Returns whether all trials have finished running."""
|
||||
@@ -626,6 +679,7 @@ class TrialRunner(object):
|
||||
"_search_alg",
|
||||
"_scheduler_alg",
|
||||
"trial_executor",
|
||||
"_syncer",
|
||||
]:
|
||||
del state[k]
|
||||
state["launch_web_server"] = bool(self._server)
|
||||
|
||||
+43
-70
@@ -2,7 +2,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import click
|
||||
import logging
|
||||
import time
|
||||
|
||||
@@ -12,7 +11,7 @@ from ray.tune.analysis import ExperimentAnalysis
|
||||
from ray.tune.suggest import BasicVariantGenerator
|
||||
from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL
|
||||
from ray.tune.ray_trial_executor import RayTrialExecutor
|
||||
from ray.tune.log_sync import wait_for_log_sync
|
||||
from ray.tune.syncer import wait_for_sync
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler,
|
||||
FIFOScheduler, MedianStoppingRule)
|
||||
@@ -36,36 +35,6 @@ def _make_scheduler(args):
|
||||
args.scheduler, _SCHEDULERS.keys()))
|
||||
|
||||
|
||||
def _find_checkpoint_dir(exp):
|
||||
# TODO(rliaw): Make sure the checkpoint_dir is resolved earlier.
|
||||
# Right now it is resolved somewhere far down the trial generation process
|
||||
return exp.checkpoint_dir
|
||||
|
||||
|
||||
def _prompt_restore(checkpoint_dir, resume):
|
||||
restore = False
|
||||
if TrialRunner.checkpoint_exists(checkpoint_dir):
|
||||
if resume == "prompt":
|
||||
msg = ("Found incomplete experiment at {}. "
|
||||
"Would you like to resume it?".format(checkpoint_dir))
|
||||
restore = click.confirm(msg, default=False)
|
||||
if restore:
|
||||
logger.info("Tip: to always resume, "
|
||||
"pass resume=True to run()")
|
||||
else:
|
||||
logger.info("Tip: to always start a new experiment, "
|
||||
"pass resume=False to run()")
|
||||
elif resume:
|
||||
restore = True
|
||||
else:
|
||||
logger.info("Tip: to resume incomplete experiments, "
|
||||
"pass resume='prompt' or resume=True to run()")
|
||||
else:
|
||||
logger.info(
|
||||
"Did not find checkpoint file in {}.".format(checkpoint_dir))
|
||||
return restore
|
||||
|
||||
|
||||
def run(run_or_experiment,
|
||||
name=None,
|
||||
stop=None,
|
||||
@@ -76,7 +45,8 @@ def run(run_or_experiment,
|
||||
upload_dir=None,
|
||||
trial_name_creator=None,
|
||||
loggers=None,
|
||||
sync_function=None,
|
||||
sync_to_cloud=None,
|
||||
sync_to_driver=None,
|
||||
checkpoint_freq=0,
|
||||
checkpoint_at_end=False,
|
||||
export_formats=None,
|
||||
@@ -93,7 +63,8 @@ def run(run_or_experiment,
|
||||
trial_executor=None,
|
||||
raise_on_failed_trial=True,
|
||||
return_trials=True,
|
||||
ray_auto_init=True):
|
||||
ray_auto_init=True,
|
||||
sync_function=None):
|
||||
"""Executes training.
|
||||
|
||||
Args:
|
||||
@@ -129,10 +100,15 @@ def run(run_or_experiment,
|
||||
loggers (list): List of logger creators to be used with
|
||||
each Trial. If None, defaults to ray.tune.logger.DEFAULT_LOGGERS.
|
||||
See `ray/tune/logger.py`.
|
||||
sync_function (func|str): Function for syncing the local_dir to
|
||||
upload_dir. If string, then it must be a string template for
|
||||
syncer to run. If not provided, the sync command defaults
|
||||
to standard S3 or gsutil sync commands.
|
||||
sync_to_cloud (func|str): Function for syncing the local_dir to and
|
||||
from upload_dir. If string, then it must be a string template
|
||||
that includes `{source}` and `{target}` for the syncer to run.
|
||||
If not provided, the sync command defaults to standard
|
||||
S3 or gsutil sync commands.
|
||||
sync_to_driver (func|str): Function for syncing trial logdir from
|
||||
remote node to local. If string, then it must be a string template
|
||||
that includes `{source}` and `{target}` for the syncer to run.
|
||||
If not provided, defaults to using rsync.
|
||||
checkpoint_freq (int): How many training iterations between
|
||||
checkpoints. A value of 0 (default) disables checkpointing.
|
||||
checkpoint_at_end (bool): Whether to checkpoint at the end of the
|
||||
@@ -155,9 +131,12 @@ def run(run_or_experiment,
|
||||
server_port (int): Port number for launching TuneServer.
|
||||
verbose (int): 0, 1, or 2. Verbosity mode. 0 = silent,
|
||||
1 = only status updates, 2 = status and trial results.
|
||||
resume (bool|"prompt"): If checkpoint exists, the experiment will
|
||||
resume from there. If resume is "prompt", Tune will prompt if
|
||||
checkpoint detected.
|
||||
resume (str|bool): One of "LOCAL", "REMOTE", "PROMPT", or bool.
|
||||
LOCAL/True restores the checkpoint from the local_checkpoint_dir.
|
||||
REMOTE restores the checkpoint from remote_checkpoint_dir.
|
||||
PROMPT provides CLI feedback. False forces a new
|
||||
experiment. If resume is set but checkpoint does not exist,
|
||||
ValueError will be thrown.
|
||||
queue_trials (bool): Whether to queue trials when the cluster does
|
||||
not currently have enough resources to launch one. This should
|
||||
be set to True when running on an autoscaling cluster to enable
|
||||
@@ -172,6 +151,8 @@ def run(run_or_experiment,
|
||||
ray_auto_init (bool): Automatically starts a local Ray cluster
|
||||
if using a RayTrialExecutor (which is the default) and
|
||||
if Ray is not initialized. Defaults to True.
|
||||
sync_function: Deprecated. See `sync_to_cloud` and
|
||||
`sync_to_driver`.
|
||||
|
||||
Returns:
|
||||
List of Trial objects.
|
||||
@@ -199,53 +180,45 @@ def run(run_or_experiment,
|
||||
ray_auto_init=ray_auto_init)
|
||||
experiment = run_or_experiment
|
||||
if not isinstance(run_or_experiment, Experiment):
|
||||
run_identifier = Experiment._register_if_needed(run_or_experiment)
|
||||
experiment = Experiment(
|
||||
name=name,
|
||||
run=run_or_experiment,
|
||||
run=run_identifier,
|
||||
stop=stop,
|
||||
config=config,
|
||||
resources_per_trial=resources_per_trial,
|
||||
num_samples=num_samples,
|
||||
local_dir=local_dir,
|
||||
upload_dir=upload_dir,
|
||||
sync_to_driver=sync_to_driver,
|
||||
trial_name_creator=trial_name_creator,
|
||||
loggers=loggers,
|
||||
sync_function=sync_function,
|
||||
checkpoint_freq=checkpoint_freq,
|
||||
checkpoint_at_end=checkpoint_at_end,
|
||||
export_formats=export_formats,
|
||||
max_failures=max_failures,
|
||||
restore=restore)
|
||||
restore=restore,
|
||||
sync_function=sync_function)
|
||||
else:
|
||||
logger.debug("Ignoring some parameters passed into tune.run.")
|
||||
|
||||
checkpoint_dir = _find_checkpoint_dir(experiment)
|
||||
should_restore = _prompt_restore(checkpoint_dir, resume)
|
||||
if sync_to_cloud:
|
||||
assert experiment.remote_checkpoint_dir, (
|
||||
"Need `upload_dir` if `sync_to_cloud` given.")
|
||||
|
||||
runner = None
|
||||
if should_restore:
|
||||
try:
|
||||
runner = TrialRunner.restore(checkpoint_dir, search_alg, scheduler,
|
||||
trial_executor)
|
||||
except Exception:
|
||||
logger.exception("Runner restore failed. Restarting experiment.")
|
||||
else:
|
||||
logger.info("Starting a new experiment.")
|
||||
runner = TrialRunner(
|
||||
search_alg=search_alg or BasicVariantGenerator(),
|
||||
scheduler=scheduler or FIFOScheduler(),
|
||||
local_checkpoint_dir=experiment.checkpoint_dir,
|
||||
remote_checkpoint_dir=experiment.remote_checkpoint_dir,
|
||||
sync_to_cloud=sync_to_cloud,
|
||||
resume=resume,
|
||||
launch_web_server=with_server,
|
||||
server_port=server_port,
|
||||
verbose=bool(verbose > 1),
|
||||
trial_executor=trial_executor)
|
||||
|
||||
if not runner:
|
||||
scheduler = scheduler or FIFOScheduler()
|
||||
search_alg = search_alg or BasicVariantGenerator()
|
||||
|
||||
search_alg.add_configurations([experiment])
|
||||
|
||||
runner = TrialRunner(
|
||||
search_alg=search_alg,
|
||||
scheduler=scheduler,
|
||||
metadata_checkpoint_dir=checkpoint_dir,
|
||||
launch_web_server=with_server,
|
||||
server_port=server_port,
|
||||
verbose=bool(verbose > 1),
|
||||
trial_executor=trial_executor)
|
||||
runner.add_experiment(experiment)
|
||||
|
||||
if verbose:
|
||||
print(runner.debug_string(max_debug=99999))
|
||||
@@ -261,7 +234,7 @@ def run(run_or_experiment,
|
||||
if verbose:
|
||||
print(runner.debug_string(max_debug=99999))
|
||||
|
||||
wait_for_log_sync()
|
||||
wait_for_sync()
|
||||
|
||||
errored_trials = []
|
||||
for trial in runner.get_trials():
|
||||
|
||||
Reference in New Issue
Block a user