mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 00:29:38 +08:00
[tune] logging fixes, better warnings, better cluster support (#3906)
This commit is contained in:
@@ -12,12 +12,14 @@ def get_ssh_user():
|
||||
return getpass.getuser()
|
||||
|
||||
|
||||
# TODO(ekl) this currently only works for clusters launched with
|
||||
# ray create_or_update
|
||||
def get_ssh_key():
|
||||
"""Returns ssh key to connecting to cluster workers."""
|
||||
"""Returns ssh key to connecting to cluster workers.
|
||||
|
||||
path = os.path.expanduser("~/ray_bootstrap_key.pem")
|
||||
If the env var TUNE_CLUSTER_SSH_KEY is provided, then this key
|
||||
will be used for syncing across different nodes.
|
||||
"""
|
||||
path = os.environ.get("TUNE_CLUSTER_SSH_KEY",
|
||||
os.path.expanduser("~/ray_bootstrap_key.pem"))
|
||||
if os.path.exists(path):
|
||||
return path
|
||||
return None
|
||||
|
||||
@@ -32,11 +32,14 @@ class StatusReporter(object):
|
||||
def __call__(self, **kwargs):
|
||||
"""Report updated training status.
|
||||
|
||||
Pass in `done=True` when the training job is completed.
|
||||
|
||||
Args:
|
||||
kwargs: Latest training result status.
|
||||
|
||||
Example:
|
||||
>>> reporter(mean_accuracy=1, training_iteration=4)
|
||||
>>> reporter(mean_accuracy=1, training_iteration=4, done=True)
|
||||
"""
|
||||
|
||||
with self._lock:
|
||||
@@ -48,6 +51,9 @@ class StatusReporter(object):
|
||||
if self._done and not self._latest_result:
|
||||
if not self._last_result:
|
||||
raise TuneError("Trial finished without reporting result!")
|
||||
logger.warning("Trial detected as completed; re-reporting "
|
||||
"last result. To avoid this, include done=True "
|
||||
"upon the last reporter call.")
|
||||
self._last_result.update(done=True)
|
||||
return self._last_result
|
||||
with self._lock:
|
||||
|
||||
@@ -117,6 +117,9 @@ class _LogSyncer(object):
|
||||
logger.debug("Created LogSyncer for {} -> {}".format(
|
||||
local_dir, remote_dir))
|
||||
|
||||
def close(self):
|
||||
self.logfile.close()
|
||||
|
||||
def set_worker_ip(self, worker_ip):
|
||||
"""Set the worker ip to sync logs from."""
|
||||
self.worker_ip = worker_ip
|
||||
|
||||
@@ -86,7 +86,7 @@ class UnifiedLogger(Logger):
|
||||
upload_uri=None,
|
||||
custom_loggers=None,
|
||||
sync_function=None):
|
||||
self._logger_list = [_JsonLogger, _TFLogger, _VisKitLogger]
|
||||
self._logger_list = [_JsonLogger, _TFLogger, _CSVLogger]
|
||||
self._sync_function = sync_function
|
||||
self._log_syncer = None
|
||||
if custom_loggers:
|
||||
@@ -101,7 +101,7 @@ class UnifiedLogger(Logger):
|
||||
try:
|
||||
self._loggers.append(cls(self.config, self.logdir, self.uri))
|
||||
except Exception:
|
||||
logger.exception("Could not instantiate {} - skipping.".format(
|
||||
logger.warning("Could not instantiate {} - skipping.".format(
|
||||
str(cls)))
|
||||
self._log_syncer = get_syncer(
|
||||
self.logdir, self.uri, sync_function=self._sync_function)
|
||||
@@ -116,6 +116,7 @@ class UnifiedLogger(Logger):
|
||||
for _logger in self._loggers:
|
||||
_logger.close()
|
||||
self._log_syncer.sync_now(force=True)
|
||||
self._log_syncer.close()
|
||||
|
||||
def flush(self):
|
||||
for _logger in self._loggers:
|
||||
@@ -216,7 +217,7 @@ class _TFLogger(Logger):
|
||||
self._file_writer.close()
|
||||
|
||||
|
||||
class _VisKitLogger(Logger):
|
||||
class _CSVLogger(Logger):
|
||||
def _init(self):
|
||||
"""CSV outputted with Headers as first set of results."""
|
||||
# Note that we assume params.json was already created by JsonLogger
|
||||
@@ -230,7 +231,9 @@ class _VisKitLogger(Logger):
|
||||
self._csv_out = csv.DictWriter(self._file, result.keys())
|
||||
if not self._continuing:
|
||||
self._csv_out.writeheader()
|
||||
self._csv_out.writerow(result.copy())
|
||||
self._csv_out.writerow(
|
||||
{k: v
|
||||
for k, v in result.items() if k in self._csv_out.fieldnames})
|
||||
|
||||
def flush(self):
|
||||
self._file.flush()
|
||||
|
||||
+10
-6
@@ -57,7 +57,7 @@ def run_experiments(experiments,
|
||||
scheduler=None,
|
||||
with_server=False,
|
||||
server_port=TuneServer.DEFAULT_PORT,
|
||||
verbose=True,
|
||||
verbose=2,
|
||||
resume=False,
|
||||
queue_trials=False,
|
||||
trial_executor=None,
|
||||
@@ -75,7 +75,8 @@ def run_experiments(experiments,
|
||||
with_server (bool): Starts a background Tune server. Needed for
|
||||
using the Client API.
|
||||
server_port (int): Port number for launching TuneServer.
|
||||
verbose (bool): How much output should be printed for each trial.
|
||||
verbose (int): 0, 1, or 2. Verbosity mode. 0 = silent,
|
||||
1 = only status updates, 2 = status and trial results.
|
||||
resume (bool|"prompt"): If checkpoint exists, the experiment will
|
||||
resume from there. If resume is "prompt", Tune will prompt if
|
||||
checkpoint detected.
|
||||
@@ -158,20 +159,23 @@ def run_experiments(experiments,
|
||||
metadata_checkpoint_dir=checkpoint_dir,
|
||||
launch_web_server=with_server,
|
||||
server_port=server_port,
|
||||
verbose=verbose,
|
||||
verbose=int(verbose > 1),
|
||||
queue_trials=queue_trials,
|
||||
trial_executor=trial_executor)
|
||||
|
||||
print(runner.debug_string(max_debug=99999))
|
||||
if verbose:
|
||||
print(runner.debug_string(max_debug=99999))
|
||||
|
||||
last_debug = 0
|
||||
while not runner.is_finished():
|
||||
runner.step()
|
||||
if time.time() - last_debug > DEBUG_PRINT_INTERVAL:
|
||||
print(runner.debug_string())
|
||||
if verbose:
|
||||
print(runner.debug_string())
|
||||
last_debug = time.time()
|
||||
|
||||
print(runner.debug_string(max_debug=99999))
|
||||
if verbose:
|
||||
print(runner.debug_string(max_debug=99999))
|
||||
|
||||
wait_for_log_sync()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user