[tune] Sync logs from workers and improve tensorboard reporting (#1567)

This commit is contained in:
Eric Liang
2018-02-26 11:35:51 -08:00
committed by Richard Liaw
parent aefefcb0cd
commit 87e107edd8
6 changed files with 133 additions and 35 deletions
+23
View File
@@ -0,0 +1,23 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import getpass
import os
def get_ssh_user():
"""Returns ssh username for connecting to cluster workers."""
return getpass.getuser()
# TODO(ekl) this currently only works for clusters launched with
# ray create_or_update
def get_ssh_key():
"""Returns ssh key to connecting to cluster workers."""
path = os.path.expanduser("~/ray_bootstrap_key.pem")
if os.path.exists(path):
return path
return None
+71 -17
View File
@@ -4,9 +4,12 @@ from __future__ import print_function
import distutils.spawn
import os
import pipes
import subprocess
import time
import ray
from ray.tune.cluster_info import get_ssh_key, get_ssh_user
from ray.tune.error import TuneError
from ray.tune.result import DEFAULT_RESULTS_DIR
@@ -15,20 +18,21 @@ from ray.tune.result import DEFAULT_RESULTS_DIR
_syncers = {}
def get_syncer(local_dir, remote_dir):
if not remote_dir.startswith("s3://"):
raise TuneError("Upload uri must start with s3://")
def get_syncer(local_dir, remote_dir=None):
if remote_dir:
if not remote_dir.startswith("s3://"):
raise TuneError("Upload uri must start with s3://")
if not distutils.spawn.find_executable("aws"):
raise TuneError("Upload uri requires awscli tool to be installed")
if not distutils.spawn.find_executable("aws"):
raise TuneError("Upload uri requires awscli tool to be installed")
if local_dir.startswith(DEFAULT_RESULTS_DIR + "/"):
rel_path = os.path.relpath(local_dir, DEFAULT_RESULTS_DIR)
remote_dir = os.path.join(remote_dir, rel_path)
if local_dir.startswith(DEFAULT_RESULTS_DIR + "/"):
rel_path = os.path.relpath(local_dir, DEFAULT_RESULTS_DIR)
remote_dir = os.path.join(remote_dir, rel_path)
key = (local_dir, remote_dir)
if key not in _syncers:
_syncers[key] = _S3LogSyncer(local_dir, remote_dir)
_syncers[key] = _LogSyncer(local_dir, remote_dir)
return _syncers[key]
@@ -38,23 +42,64 @@ def wait_for_log_sync():
syncer.wait()
class _S3LogSyncer(object):
def __init__(self, local_dir, remote_dir):
class _LogSyncer(object):
"""Log syncer for tune.
This syncs files from workers to the local node, and optionally also from
the local node to a remote directory (e.g. S3)."""
def __init__(self, local_dir, remote_dir=None):
self.local_dir = local_dir
self.remote_dir = remote_dir
self.last_sync_time = 0
self.sync_process = None
print("Created S3LogSyncer for {} -> {}".format(local_dir, remote_dir))
self.local_ip = ray.services.get_node_ip_address()
self.worker_ip = None
print("Created LogSyncer for {} -> {}".format(local_dir, remote_dir))
def set_worker_ip(self, worker_ip):
"""Set the worker ip to sync logs from."""
self.worker_ip = worker_ip
def sync_if_needed(self):
if time.time() - self.last_sync_time > 300:
self.sync_now()
def sync_now(self, force=False):
print(
"Syncing files from {} -> {}".format(
self.local_dir, self.remote_dir))
self.last_sync_time = time.time()
if not self.worker_ip:
print(
"Worker ip unknown, skipping log sync for {}".format(
self.local_dir))
return
if self.worker_ip == self.local_ip:
worker_to_local_sync_cmd = None # don't need to rsync
else:
ssh_key = get_ssh_key()
ssh_user = get_ssh_user()
if ssh_key is None or ssh_user is None:
print(
"Error: log sync requires cluster to be setup with "
"`ray create_or_update`.")
return
if not distutils.spawn.find_executable("rsync"):
print("Error: log sync requires rsync to be installed.")
return
worker_to_local_sync_cmd = (
("""rsync -avz -e "ssh -i '{}' -o ConnectTimeout=120s """
"""-o StrictHostKeyChecking=no" '{}@{}:{}/' '{}/'""").format(
ssh_key, ssh_user, self.worker_ip,
pipes.quote(self.local_dir), pipes.quote(self.local_dir)))
if self.remote_dir:
local_to_remote_sync_cmd = (
"aws s3 sync '{}' '{}'".format(
pipes.quote(self.local_dir), pipes.quote(self.remote_dir)))
else:
local_to_remote_sync_cmd = None
if self.sync_process:
self.sync_process.poll()
if self.sync_process.returncode is None:
@@ -63,8 +108,17 @@ class _S3LogSyncer(object):
else:
print("Warning: last sync is still in progress, skipping")
return
self.sync_process = subprocess.Popen(
["aws", "s3", "sync", self.local_dir, self.remote_dir])
if worker_to_local_sync_cmd or local_to_remote_sync_cmd:
final_cmd = ""
if worker_to_local_sync_cmd:
final_cmd += worker_to_local_sync_cmd
if local_to_remote_sync_cmd:
if final_cmd:
final_cmd += " && "
final_cmd += local_to_remote_sync_cmd
print("Running log sync: {}".format(final_cmd))
self.sync_process = subprocess.Popen(final_cmd, shell=True)
def wait(self):
if self.sync_process:
+32 -18
View File
@@ -24,10 +24,6 @@ class Logger(object):
multiple formats (TensorBoard, rllab/viskit, plain json) at once.
"""
_attrs_to_log = [
"time_this_iter_s", "mean_loss", "mean_accuracy",
"episode_reward_mean", "episode_len_mean"]
def __init__(self, config, logdir, upload_uri=None):
self.config = config
self.logdir = logdir
@@ -47,6 +43,11 @@ class Logger(object):
pass
def flush(self):
"""Flushes all disk writes to storage."""
pass
class UnifiedLogger(Logger):
"""Unified result logger for TensorBoard, rllab/viskit, plain json.
@@ -60,22 +61,22 @@ class UnifiedLogger(Logger):
print("TF not installed - cannot log with {}...".format(cls))
continue
self._loggers.append(cls(self.config, self.logdir, self.uri))
if self.uri:
self._log_syncer = get_syncer(self.logdir, self.uri)
else:
self._log_syncer = None
self._log_syncer = get_syncer(self.logdir, self.uri)
def on_result(self, result):
for logger in self._loggers:
logger.on_result(result)
if self._log_syncer:
self._log_syncer.sync_if_needed()
self._log_syncer.set_worker_ip(result.node_ip)
self._log_syncer.sync_if_needed()
def close(self):
for logger in self._loggers:
logger.close()
if self._log_syncer:
self._log_syncer.sync_now(force=True)
self._log_syncer.sync_now(force=True)
def flush(self):
self._log_syncer.sync_now(force=True)
self._log_syncer.wait()
class NoopLogger(Logger):
@@ -103,17 +104,30 @@ class _JsonLogger(Logger):
self.local_out.close()
def to_tf_values(result, path):
values = []
for attr, value in result.items():
if value is not None:
if type(value) in [int, float]:
values.append(tf.Summary.Value(
tag="/".join(path + [attr]),
simple_value=value))
elif type(value) is dict:
values.extend(to_tf_values(value, path + [attr]))
return values
class _TFLogger(Logger):
def _init(self):
self._file_writer = tf.summary.FileWriter(self.logdir)
def on_result(self, result):
values = []
for attr in Logger._attrs_to_log:
if getattr(result, attr) is not None:
values.append(tf.Summary.Value(
tag="ray/tune/{}".format(attr),
simple_value=getattr(result, attr)))
tmp = result._asdict()
for k in [
"config", "pid", "timestamp", "time_total_s",
"timesteps_total"]:
del tmp[k] # not useful to tf log these
values = to_tf_values(tmp, ["ray", "tune"])
train_stats = tf.Summary(value=values)
self._file_writer.add_summary(train_stats, result.timesteps_total)
+3
View File
@@ -85,6 +85,9 @@ TrainingResult = namedtuple("TrainingResult", [
# (Auto-filled) The hostname of the machine hosting the training process.
"hostname",
# (Auto-filled) The node ip of the machine hosting the training process.
"node_ip",
# (Auto=filled) The current hyperparameter configuration.
"config",
])
+3
View File
@@ -13,6 +13,7 @@ import tempfile
import time
import uuid
import ray
from ray.tune import TuneError
from ray.tune.logger import UnifiedLogger
from ray.tune.result import DEFAULT_RESULTS_DIR
@@ -87,6 +88,7 @@ class Trainable(object):
self._timesteps_total = 0
self._setup()
self._initialize_ok = True
self._local_ip = ray.services.get_node_ip_address()
def train(self):
"""Runs one logical iteration of training.
@@ -136,6 +138,7 @@ class Trainable(object):
neg_mean_loss=neg_loss,
pid=os.getpid(),
hostname=os.uname()[1],
node_ip=self._local_ip,
config=self.config)
self._result_logger.on_result(result)
+1
View File
@@ -252,6 +252,7 @@ class TrialRunner(object):
try:
print("Attempting to recover trial state from last checkpoint")
trial.stop(error=True, error_msg=error_msg, stop_logger=False)
trial.result_logger.flush() # make sure checkpoint is synced
trial.start()
self._running[trial.train_remote()] = trial
except Exception: