[tune] logging fixes, better warnings, better cluster support (#3906)

This commit is contained in:
Richard Liaw
2019-02-02 19:14:03 -08:00
committed by GitHub
parent 002531b199
commit eab6dd72b5
5 changed files with 32 additions and 14 deletions
+6 -4
View File
@@ -12,12 +12,14 @@ def get_ssh_user():
return getpass.getuser()
# TODO(ekl) this currently only works for clusters launched with
# ray create_or_update
def get_ssh_key():
"""Returns ssh key to connecting to cluster workers."""
"""Returns ssh key to connecting to cluster workers.
path = os.path.expanduser("~/ray_bootstrap_key.pem")
If the env var TUNE_CLUSTER_SSH_KEY is provided, then this key
will be used for syncing across different nodes.
"""
path = os.environ.get("TUNE_CLUSTER_SSH_KEY",
os.path.expanduser("~/ray_bootstrap_key.pem"))
if os.path.exists(path):
return path
return None
+6
View File
@@ -32,11 +32,14 @@ class StatusReporter(object):
def __call__(self, **kwargs):
"""Report updated training status.
Pass in `done=True` when the training job is completed.
Args:
kwargs: Latest training result status.
Example:
>>> reporter(mean_accuracy=1, training_iteration=4)
>>> reporter(mean_accuracy=1, training_iteration=4, done=True)
"""
with self._lock:
@@ -48,6 +51,9 @@ class StatusReporter(object):
if self._done and not self._latest_result:
if not self._last_result:
raise TuneError("Trial finished without reporting result!")
logger.warning("Trial detected as completed; re-reporting "
"last result. To avoid this, include done=True "
"upon the last reporter call.")
self._last_result.update(done=True)
return self._last_result
with self._lock:
+3
View File
@@ -117,6 +117,9 @@ class _LogSyncer(object):
logger.debug("Created LogSyncer for {} -> {}".format(
local_dir, remote_dir))
def close(self):
self.logfile.close()
def set_worker_ip(self, worker_ip):
"""Set the worker ip to sync logs from."""
self.worker_ip = worker_ip
+7 -4
View File
@@ -86,7 +86,7 @@ class UnifiedLogger(Logger):
upload_uri=None,
custom_loggers=None,
sync_function=None):
self._logger_list = [_JsonLogger, _TFLogger, _VisKitLogger]
self._logger_list = [_JsonLogger, _TFLogger, _CSVLogger]
self._sync_function = sync_function
self._log_syncer = None
if custom_loggers:
@@ -101,7 +101,7 @@ class UnifiedLogger(Logger):
try:
self._loggers.append(cls(self.config, self.logdir, self.uri))
except Exception:
logger.exception("Could not instantiate {} - skipping.".format(
logger.warning("Could not instantiate {} - skipping.".format(
str(cls)))
self._log_syncer = get_syncer(
self.logdir, self.uri, sync_function=self._sync_function)
@@ -116,6 +116,7 @@ class UnifiedLogger(Logger):
for _logger in self._loggers:
_logger.close()
self._log_syncer.sync_now(force=True)
self._log_syncer.close()
def flush(self):
for _logger in self._loggers:
@@ -216,7 +217,7 @@ class _TFLogger(Logger):
self._file_writer.close()
class _VisKitLogger(Logger):
class _CSVLogger(Logger):
def _init(self):
"""CSV outputted with Headers as first set of results."""
# Note that we assume params.json was already created by JsonLogger
@@ -230,7 +231,9 @@ class _VisKitLogger(Logger):
self._csv_out = csv.DictWriter(self._file, result.keys())
if not self._continuing:
self._csv_out.writeheader()
self._csv_out.writerow(result.copy())
self._csv_out.writerow(
{k: v
for k, v in result.items() if k in self._csv_out.fieldnames})
def flush(self):
self._file.flush()
+10 -6
View File
@@ -57,7 +57,7 @@ def run_experiments(experiments,
scheduler=None,
with_server=False,
server_port=TuneServer.DEFAULT_PORT,
verbose=True,
verbose=2,
resume=False,
queue_trials=False,
trial_executor=None,
@@ -75,7 +75,8 @@ def run_experiments(experiments,
with_server (bool): Starts a background Tune server. Needed for
using the Client API.
server_port (int): Port number for launching TuneServer.
verbose (bool): How much output should be printed for each trial.
verbose (int): 0, 1, or 2. Verbosity mode. 0 = silent,
1 = only status updates, 2 = status and trial results.
resume (bool|"prompt"): If checkpoint exists, the experiment will
resume from there. If resume is "prompt", Tune will prompt if
checkpoint detected.
@@ -158,20 +159,23 @@ def run_experiments(experiments,
metadata_checkpoint_dir=checkpoint_dir,
launch_web_server=with_server,
server_port=server_port,
verbose=verbose,
verbose=int(verbose > 1),
queue_trials=queue_trials,
trial_executor=trial_executor)
print(runner.debug_string(max_debug=99999))
if verbose:
print(runner.debug_string(max_debug=99999))
last_debug = 0
while not runner.is_finished():
runner.step()
if time.time() - last_debug > DEBUG_PRINT_INTERVAL:
print(runner.debug_string())
if verbose:
print(runner.debug_string())
last_debug = time.time()
print(runner.debug_string(max_debug=99999))
if verbose:
print(runner.debug_string(max_debug=99999))
wait_for_log_sync()