From b94d85fb5d58efb08319443e85d281fca9e005a1 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 2 Oct 2017 17:16:52 -0700 Subject: [PATCH] [rllib] Fix logging to Athena (#1058) * Fix logging to Athena * fixes --- python/ray/rllib/a3c/a3c.py | 5 +++-- python/ray/rllib/common.py | 39 ++++++++++++++++++++++++++++++++----- python/ray/rllib/ppo/ppo.py | 5 +++-- 3 files changed, 40 insertions(+), 9 deletions(-) diff --git a/python/ray/rllib/a3c/a3c.py b/python/ray/rllib/a3c/a3c.py index 79f2bdcef..352a87eb8 100644 --- a/python/ray/rllib/a3c/a3c.py +++ b/python/ray/rllib/a3c/a3c.py @@ -11,7 +11,7 @@ import os import ray from ray.rllib.a3c.runner import RunnerThread, process_rollout from ray.rllib.a3c.envs import create_env -from ray.rllib.common import Agent, TrainingResult +from ray.rllib.common import Agent, TrainingResult, get_tensorflow_log_dir from ray.rllib.a3c.shared_model import SharedModel from ray.rllib.a3c.shared_model_lstm import SharedModelLSTM @@ -73,8 +73,9 @@ class Runner(object): return completed def start(self): + logdir = get_tensorflow_log_dir(self.logdir) summary_writer = tf.summary.FileWriter( - os.path.join(self.logdir, "agent_%d" % self.id)) + os.path.join(logdir, "agent_%d" % self.id)) self.summary_writer = summary_writer self.runner.start_runner(self.policy.sess, summary_writer) diff --git a/python/ray/rllib/common.py b/python/ray/rllib/common.py index 30a14b913..965c98e90 100644 --- a/python/ray/rllib/common.py +++ b/python/ray/rllib/common.py @@ -20,14 +20,43 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) +def get_tensorflow_log_dir(logdir): + if logdir.startswith("s3"): + print("WARNING: TensorFlow logging to S3 not supported by" + "TensorFlow, logging to /tmp/ray/ instead") + logdir = "/tmp/ray/" + if not os.path.exists(logdir): + os.makedirs(logdir) + return logdir + + class RLLibEncoder(json.JSONEncoder): + + def __init__(self, nan_str="null", **kwargs): + super(RLLibEncoder, self).__init__(**kwargs) + self.nan_str = nan_str + + def iterencode(self, o, _one_shot=False): + if self.ensure_ascii: + _encoder = json.encoder.encode_basestring_ascii + else: + _encoder = json.encoder.encode_basestring + + def floatstr(o, allow_nan=self.allow_nan, nan_str=self.nan_str): + return repr(o) if not np.isnan(o) else nan_str + + _iterencode = json.encoder._make_iterencode( + None, self.default, _encoder, self.indent, floatstr, + self.key_separator, self.item_separator, self.sort_keys, + self.skipkeys, _one_shot) + return _iterencode(o, 0) + def default(self, value): + if np.isnan(value): + return None if np.issubdtype(value, float): - if np.isnan(value): - return None - else: - return float(value) - elif np.issubdtype(value, int): + return float(value) + if np.issubdtype(value, int): return int(value) diff --git a/python/ray/rllib/ppo/ppo.py b/python/ray/rllib/ppo/ppo.py index 21955cec8..cb0a2edb5 100644 --- a/python/ray/rllib/ppo/ppo.py +++ b/python/ray/rllib/ppo/ppo.py @@ -11,7 +11,7 @@ import tensorflow as tf from tensorflow.python import debug as tf_debug import ray -from ray.rllib.common import Agent, TrainingResult +from ray.rllib.common import Agent, TrainingResult, get_tensorflow_log_dir from ray.rllib.ppo.runner import Runner, RemoteRunner from ray.rllib.ppo.rollout import collect_samples from ray.rllib.ppo.utils import shuffle @@ -99,8 +99,9 @@ class PPOAgent(Agent): for _ in range(self.config["num_workers"])] self.start_time = time.time() if self.config["write_logs"]: + logdir = get_tensorflow_log_dir(self.logdir) self.file_writer = tf.summary.FileWriter( - self.logdir, self.model.sess.graph) + logdir, self.model.sess.graph) else: self.file_writer = None self.saver = tf.train.Saver(max_to_keep=None)