mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 10:49:16 +08:00
[tune] Sync logs from workers and improve tensorboard reporting (#1567)
This commit is contained in:
@@ -0,0 +1,23 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import getpass
|
||||
import os
|
||||
|
||||
|
||||
def get_ssh_user():
    """Return the username used to ssh into cluster workers.

    The name is whatever ``getpass.getuser()`` reports for the current
    process.
    """
    username = getpass.getuser()
    return username
|
||||
|
||||
|
||||
# TODO(ekl) this currently only works for clusters launched with
# ray create_or_update
def get_ssh_key():
    """Return the path of the ssh key for connecting to cluster workers.

    Looks for ``~/ray_bootstrap_key.pem`` in the current user's home
    directory.

    Returns:
        The expanded key path if it is an existing regular file,
        otherwise None.
    """
    key_path = os.path.expanduser("~/ray_bootstrap_key.pem")
    # Only a regular file can serve as an ssh identity file; a directory
    # (or other non-file entry) at this path is treated as "no key".
    if os.path.isfile(key_path):
        return key_path
    return None
|
||||
+71
-17
@@ -4,9 +4,12 @@ from __future__ import print_function
|
||||
|
||||
import distutils.spawn
|
||||
import os
|
||||
import pipes
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import ray
|
||||
from ray.tune.cluster_info import get_ssh_key, get_ssh_user
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
|
||||
@@ -15,20 +18,21 @@ from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
_syncers = {}
|
||||
|
||||
|
||||
def get_syncer(local_dir, remote_dir):
|
||||
if not remote_dir.startswith("s3://"):
|
||||
raise TuneError("Upload uri must start with s3://")
|
||||
def get_syncer(local_dir, remote_dir=None):
|
||||
if remote_dir:
|
||||
if not remote_dir.startswith("s3://"):
|
||||
raise TuneError("Upload uri must start with s3://")
|
||||
|
||||
if not distutils.spawn.find_executable("aws"):
|
||||
raise TuneError("Upload uri requires awscli tool to be installed")
|
||||
if not distutils.spawn.find_executable("aws"):
|
||||
raise TuneError("Upload uri requires awscli tool to be installed")
|
||||
|
||||
if local_dir.startswith(DEFAULT_RESULTS_DIR + "/"):
|
||||
rel_path = os.path.relpath(local_dir, DEFAULT_RESULTS_DIR)
|
||||
remote_dir = os.path.join(remote_dir, rel_path)
|
||||
if local_dir.startswith(DEFAULT_RESULTS_DIR + "/"):
|
||||
rel_path = os.path.relpath(local_dir, DEFAULT_RESULTS_DIR)
|
||||
remote_dir = os.path.join(remote_dir, rel_path)
|
||||
|
||||
key = (local_dir, remote_dir)
|
||||
if key not in _syncers:
|
||||
_syncers[key] = _S3LogSyncer(local_dir, remote_dir)
|
||||
_syncers[key] = _LogSyncer(local_dir, remote_dir)
|
||||
|
||||
return _syncers[key]
|
||||
|
||||
@@ -38,23 +42,64 @@ def wait_for_log_sync():
|
||||
syncer.wait()
|
||||
|
||||
|
||||
class _S3LogSyncer(object):
|
||||
def __init__(self, local_dir, remote_dir):
|
||||
class _LogSyncer(object):
|
||||
"""Log syncer for tune.
|
||||
|
||||
This syncs files from workers to the local node, and optionally also from
|
||||
the local node to a remote directory (e.g. S3)."""
|
||||
|
||||
def __init__(self, local_dir, remote_dir=None):
|
||||
self.local_dir = local_dir
|
||||
self.remote_dir = remote_dir
|
||||
self.last_sync_time = 0
|
||||
self.sync_process = None
|
||||
print("Created S3LogSyncer for {} -> {}".format(local_dir, remote_dir))
|
||||
self.local_ip = ray.services.get_node_ip_address()
|
||||
self.worker_ip = None
|
||||
print("Created LogSyncer for {} -> {}".format(local_dir, remote_dir))
|
||||
|
||||
def set_worker_ip(self, worker_ip):
|
||||
"""Set the worker ip to sync logs from."""
|
||||
|
||||
self.worker_ip = worker_ip
|
||||
|
||||
def sync_if_needed(self):
|
||||
if time.time() - self.last_sync_time > 300:
|
||||
self.sync_now()
|
||||
|
||||
def sync_now(self, force=False):
|
||||
print(
|
||||
"Syncing files from {} -> {}".format(
|
||||
self.local_dir, self.remote_dir))
|
||||
self.last_sync_time = time.time()
|
||||
if not self.worker_ip:
|
||||
print(
|
||||
"Worker ip unknown, skipping log sync for {}".format(
|
||||
self.local_dir))
|
||||
return
|
||||
|
||||
if self.worker_ip == self.local_ip:
|
||||
worker_to_local_sync_cmd = None # don't need to rsync
|
||||
else:
|
||||
ssh_key = get_ssh_key()
|
||||
ssh_user = get_ssh_user()
|
||||
if ssh_key is None or ssh_user is None:
|
||||
print(
|
||||
"Error: log sync requires cluster to be setup with "
|
||||
"`ray create_or_update`.")
|
||||
return
|
||||
if not distutils.spawn.find_executable("rsync"):
|
||||
print("Error: log sync requires rsync to be installed.")
|
||||
return
|
||||
worker_to_local_sync_cmd = (
|
||||
("""rsync -avz -e "ssh -i '{}' -o ConnectTimeout=120s """
|
||||
"""-o StrictHostKeyChecking=no" '{}@{}:{}/' '{}/'""").format(
|
||||
ssh_key, ssh_user, self.worker_ip,
|
||||
pipes.quote(self.local_dir), pipes.quote(self.local_dir)))
|
||||
|
||||
if self.remote_dir:
|
||||
local_to_remote_sync_cmd = (
|
||||
"aws s3 sync '{}' '{}'".format(
|
||||
pipes.quote(self.local_dir), pipes.quote(self.remote_dir)))
|
||||
else:
|
||||
local_to_remote_sync_cmd = None
|
||||
|
||||
if self.sync_process:
|
||||
self.sync_process.poll()
|
||||
if self.sync_process.returncode is None:
|
||||
@@ -63,8 +108,17 @@ class _S3LogSyncer(object):
|
||||
else:
|
||||
print("Warning: last sync is still in progress, skipping")
|
||||
return
|
||||
self.sync_process = subprocess.Popen(
|
||||
["aws", "s3", "sync", self.local_dir, self.remote_dir])
|
||||
|
||||
if worker_to_local_sync_cmd or local_to_remote_sync_cmd:
|
||||
final_cmd = ""
|
||||
if worker_to_local_sync_cmd:
|
||||
final_cmd += worker_to_local_sync_cmd
|
||||
if local_to_remote_sync_cmd:
|
||||
if final_cmd:
|
||||
final_cmd += " && "
|
||||
final_cmd += local_to_remote_sync_cmd
|
||||
print("Running log sync: {}".format(final_cmd))
|
||||
self.sync_process = subprocess.Popen(final_cmd, shell=True)
|
||||
|
||||
def wait(self):
|
||||
if self.sync_process:
|
||||
|
||||
+32
-18
@@ -24,10 +24,6 @@ class Logger(object):
|
||||
multiple formats (TensorBoard, rllab/viskit, plain json) at once.
|
||||
"""
|
||||
|
||||
_attrs_to_log = [
|
||||
"time_this_iter_s", "mean_loss", "mean_accuracy",
|
||||
"episode_reward_mean", "episode_len_mean"]
|
||||
|
||||
def __init__(self, config, logdir, upload_uri=None):
|
||||
self.config = config
|
||||
self.logdir = logdir
|
||||
@@ -47,6 +43,11 @@ class Logger(object):
|
||||
|
||||
pass
|
||||
|
||||
def flush(self):
|
||||
"""Flushes all disk writes to storage."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class UnifiedLogger(Logger):
|
||||
"""Unified result logger for TensorBoard, rllab/viskit, plain json.
|
||||
@@ -60,22 +61,22 @@ class UnifiedLogger(Logger):
|
||||
print("TF not installed - cannot log with {}...".format(cls))
|
||||
continue
|
||||
self._loggers.append(cls(self.config, self.logdir, self.uri))
|
||||
if self.uri:
|
||||
self._log_syncer = get_syncer(self.logdir, self.uri)
|
||||
else:
|
||||
self._log_syncer = None
|
||||
self._log_syncer = get_syncer(self.logdir, self.uri)
|
||||
|
||||
def on_result(self, result):
|
||||
for logger in self._loggers:
|
||||
logger.on_result(result)
|
||||
if self._log_syncer:
|
||||
self._log_syncer.sync_if_needed()
|
||||
self._log_syncer.set_worker_ip(result.node_ip)
|
||||
self._log_syncer.sync_if_needed()
|
||||
|
||||
def close(self):
|
||||
for logger in self._loggers:
|
||||
logger.close()
|
||||
if self._log_syncer:
|
||||
self._log_syncer.sync_now(force=True)
|
||||
self._log_syncer.sync_now(force=True)
|
||||
|
||||
def flush(self):
|
||||
self._log_syncer.sync_now(force=True)
|
||||
self._log_syncer.wait()
|
||||
|
||||
|
||||
class NoopLogger(Logger):
|
||||
@@ -103,17 +104,30 @@ class _JsonLogger(Logger):
|
||||
self.local_out.close()
|
||||
|
||||
|
||||
def to_tf_values(result, path):
    """Flatten a (possibly nested) result mapping into TF summary values.

    Args:
        result: Dict mapping attribute names to values. Numeric values
            become summary values, nested dicts are flattened
            recursively, and anything else (including None) is skipped.
        path: List of tag components leading to ``result``; each tag is
            the components plus the attribute name joined with "/".

    Returns:
        A list of ``tf.Summary.Value`` objects.
    """
    values = []
    for attr, value in result.items():
        tag_path = path + [attr]
        if isinstance(value, dict):
            # isinstance (rather than an exact type match) also flattens
            # dict subclasses such as OrderedDict.
            values.extend(to_tf_values(value, tag_path))
        elif isinstance(value, (int, float)) and not isinstance(value, bool):
            # bool is an int subclass; it was never logged, keep skipping it.
            values.append(tf.Summary.Value(
                tag="/".join(tag_path),
                simple_value=value))
    return values
|
||||
|
||||
|
||||
class _TFLogger(Logger):
|
||||
def _init(self):
|
||||
self._file_writer = tf.summary.FileWriter(self.logdir)
|
||||
|
||||
def on_result(self, result):
|
||||
values = []
|
||||
for attr in Logger._attrs_to_log:
|
||||
if getattr(result, attr) is not None:
|
||||
values.append(tf.Summary.Value(
|
||||
tag="ray/tune/{}".format(attr),
|
||||
simple_value=getattr(result, attr)))
|
||||
tmp = result._asdict()
|
||||
for k in [
|
||||
"config", "pid", "timestamp", "time_total_s",
|
||||
"timesteps_total"]:
|
||||
del tmp[k] # not useful to tf log these
|
||||
values = to_tf_values(tmp, ["ray", "tune"])
|
||||
train_stats = tf.Summary(value=values)
|
||||
self._file_writer.add_summary(train_stats, result.timesteps_total)
|
||||
|
||||
|
||||
@@ -85,6 +85,9 @@ TrainingResult = namedtuple("TrainingResult", [
|
||||
# (Auto-filled) The hostname of the machine hosting the training process.
|
||||
"hostname",
|
||||
|
||||
# (Auto-filled) The node ip of the machine hosting the training process.
|
||||
"node_ip",
|
||||
|
||||
# (Auto-filled) The current hyperparameter configuration.
|
||||
"config",
|
||||
])
|
||||
|
||||
@@ -13,6 +13,7 @@ import tempfile
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import ray
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.logger import UnifiedLogger
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
@@ -87,6 +88,7 @@ class Trainable(object):
|
||||
self._timesteps_total = 0
|
||||
self._setup()
|
||||
self._initialize_ok = True
|
||||
self._local_ip = ray.services.get_node_ip_address()
|
||||
|
||||
def train(self):
|
||||
"""Runs one logical iteration of training.
|
||||
@@ -136,6 +138,7 @@ class Trainable(object):
|
||||
neg_mean_loss=neg_loss,
|
||||
pid=os.getpid(),
|
||||
hostname=os.uname()[1],
|
||||
node_ip=self._local_ip,
|
||||
config=self.config)
|
||||
|
||||
self._result_logger.on_result(result)
|
||||
|
||||
@@ -252,6 +252,7 @@ class TrialRunner(object):
|
||||
try:
|
||||
print("Attempting to recover trial state from last checkpoint")
|
||||
trial.stop(error=True, error_msg=error_msg, stop_logger=False)
|
||||
trial.result_logger.flush() # make sure checkpoint is synced
|
||||
trial.start()
|
||||
self._running[trial.train_remote()] = trial
|
||||
except Exception:
|
||||
|
||||
Reference in New Issue
Block a user