[rllib] Fix logging to Athena (#1058)

* Fix logging to Athena

* fixes
This commit is contained in:
Philipp Moritz
2017-10-02 17:16:52 -07:00
committed by Richard Liaw
parent 1488975d1b
commit b94d85fb5d
3 changed files with 40 additions and 9 deletions
+3 -2
View File
@@ -11,7 +11,7 @@ import os
import ray
from ray.rllib.a3c.runner import RunnerThread, process_rollout
from ray.rllib.a3c.envs import create_env
from ray.rllib.common import Agent, TrainingResult
from ray.rllib.common import Agent, TrainingResult, get_tensorflow_log_dir
from ray.rllib.a3c.shared_model import SharedModel
from ray.rllib.a3c.shared_model_lstm import SharedModelLSTM
@@ -73,8 +73,9 @@ class Runner(object):
return completed
def start(self):
logdir = get_tensorflow_log_dir(self.logdir)
summary_writer = tf.summary.FileWriter(
os.path.join(self.logdir, "agent_%d" % self.id))
os.path.join(logdir, "agent_%d" % self.id))
self.summary_writer = summary_writer
self.runner.start_runner(self.policy.sess, summary_writer)
+34 -5
View File
@@ -20,14 +20,43 @@ logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def get_tensorflow_log_dir(logdir):
    """Return a local directory that TensorFlow summaries can be written to.

    A ``logdir`` starting with ``"s3"`` is replaced by ``"/tmp/ray/"`` and a
    warning is printed. The resulting directory is created if it does not
    exist yet.

    Args:
        logdir: Requested log directory as a string.

    Returns:
        The directory path to use for TensorFlow logging: ``logdir`` itself,
        or ``"/tmp/ray/"`` when ``logdir`` starts with ``"s3"``.

    Raises:
        OSError: If the directory is missing and cannot be created.
    """
    if logdir.startswith("s3"):
        # Adjacent string literals are joined verbatim, so the first piece
        # must carry its own trailing space ("... supported by TensorFlow").
        print("WARNING: TensorFlow logging to S3 not supported by "
              "TensorFlow, logging to /tmp/ray/ instead")
        logdir = "/tmp/ray/"
    if not os.path.exists(logdir):
        try:
            os.makedirs(logdir)
        except OSError:
            # Another caller may have created the directory between the
            # existence check and makedirs; only a genuine failure is fatal.
            if not os.path.isdir(logdir):
                raise
    return logdir
class RLLibEncoder(json.JSONEncoder):
def __init__(self, nan_str="null", **kwargs):
super(RLLibEncoder, self).__init__(**kwargs)
self.nan_str = nan_str
def iterencode(self, o, _one_shot=False):
    """Yield JSON text chunks for ``o``, replacing non-finite floats.

    Works like ``json.JSONEncoder.iterencode`` but always uses the
    pure-Python encoder with a custom float formatter: NaN, ``inf`` and
    ``-inf`` are written as ``self.nan_str`` so the output is always
    parseable JSON. Unlike the stdlib encoder, ``allow_nan=False`` does
    not raise here.

    Args:
        o: The object to serialize.
        _one_shot: Forwarded to ``json.encoder._make_iterencode``.

    Returns:
        An iterator over string chunks of the encoded document.
    """
    if self.ensure_ascii:
        _encoder = json.encoder.encode_basestring_ascii
    else:
        _encoder = json.encoder.encode_basestring

    def floatstr(o, nan_str=self.nan_str):
        # repr() of an infinity is "inf"/"-inf", which is not valid JSON,
        # so every non-finite value gets the configured placeholder.
        if not np.isfinite(o):
            return nan_str
        # float.__repr__ (as the stdlib encoder uses) keeps float
        # subclasses such as np.float64 from leaking their own repr.
        return float.__repr__(o)

    # Honor check_circular like the stdlib does instead of always
    # disabling cycle detection.
    markers = {} if self.check_circular else None

    _iterencode = json.encoder._make_iterencode(
        markers, self.default, _encoder, self.indent, floatstr,
        self.key_separator, self.item_separator, self.sort_keys,
        self.skipkeys, _one_shot)
    return _iterencode(o, 0)
def default(self, value):
if np.isnan(value):
return None
if np.issubdtype(value, float):
if np.isnan(value):
return None
else:
return float(value)
elif np.issubdtype(value, int):
return float(value)
if np.issubdtype(value, int):
return int(value)
+3 -2
View File
@@ -11,7 +11,7 @@ import tensorflow as tf
from tensorflow.python import debug as tf_debug
import ray
from ray.rllib.common import Agent, TrainingResult
from ray.rllib.common import Agent, TrainingResult, get_tensorflow_log_dir
from ray.rllib.ppo.runner import Runner, RemoteRunner
from ray.rllib.ppo.rollout import collect_samples
from ray.rllib.ppo.utils import shuffle
@@ -99,8 +99,9 @@ class PPOAgent(Agent):
for _ in range(self.config["num_workers"])]
self.start_time = time.time()
if self.config["write_logs"]:
logdir = get_tensorflow_log_dir(self.logdir)
self.file_writer = tf.summary.FileWriter(
self.logdir, self.model.sess.graph)
logdir, self.model.sess.graph)
else:
self.file_writer = None
self.saver = tf.train.Saver(max_to_keep=None)