mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:23:10 +08:00
committed by
Richard Liaw
parent
1488975d1b
commit
b94d85fb5d
@@ -11,7 +11,7 @@ import os
|
||||
import ray
|
||||
from ray.rllib.a3c.runner import RunnerThread, process_rollout
|
||||
from ray.rllib.a3c.envs import create_env
|
||||
from ray.rllib.common import Agent, TrainingResult
|
||||
from ray.rllib.common import Agent, TrainingResult, get_tensorflow_log_dir
|
||||
from ray.rllib.a3c.shared_model import SharedModel
|
||||
from ray.rllib.a3c.shared_model_lstm import SharedModelLSTM
|
||||
|
||||
@@ -73,8 +73,9 @@ class Runner(object):
|
||||
return completed
|
||||
|
||||
def start(self):
|
||||
logdir = get_tensorflow_log_dir(self.logdir)
|
||||
summary_writer = tf.summary.FileWriter(
|
||||
os.path.join(self.logdir, "agent_%d" % self.id))
|
||||
os.path.join(logdir, "agent_%d" % self.id))
|
||||
self.summary_writer = summary_writer
|
||||
self.runner.start_runner(self.policy.sess, summary_writer)
|
||||
|
||||
|
||||
@@ -20,14 +20,43 @@ logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def get_tensorflow_log_dir(logdir):
|
||||
if logdir.startswith("s3"):
|
||||
print("WARNING: TensorFlow logging to S3 not supported by"
|
||||
"TensorFlow, logging to /tmp/ray/ instead")
|
||||
logdir = "/tmp/ray/"
|
||||
if not os.path.exists(logdir):
|
||||
os.makedirs(logdir)
|
||||
return logdir
|
||||
|
||||
|
||||
class RLLibEncoder(json.JSONEncoder):
|
||||
|
||||
def __init__(self, nan_str="null", **kwargs):
|
||||
super(RLLibEncoder, self).__init__(**kwargs)
|
||||
self.nan_str = nan_str
|
||||
|
||||
def iterencode(self, o, _one_shot=False):
|
||||
if self.ensure_ascii:
|
||||
_encoder = json.encoder.encode_basestring_ascii
|
||||
else:
|
||||
_encoder = json.encoder.encode_basestring
|
||||
|
||||
def floatstr(o, allow_nan=self.allow_nan, nan_str=self.nan_str):
|
||||
return repr(o) if not np.isnan(o) else nan_str
|
||||
|
||||
_iterencode = json.encoder._make_iterencode(
|
||||
None, self.default, _encoder, self.indent, floatstr,
|
||||
self.key_separator, self.item_separator, self.sort_keys,
|
||||
self.skipkeys, _one_shot)
|
||||
return _iterencode(o, 0)
|
||||
|
||||
def default(self, value):
|
||||
if np.isnan(value):
|
||||
return None
|
||||
if np.issubdtype(value, float):
|
||||
if np.isnan(value):
|
||||
return None
|
||||
else:
|
||||
return float(value)
|
||||
elif np.issubdtype(value, int):
|
||||
return float(value)
|
||||
if np.issubdtype(value, int):
|
||||
return int(value)
|
||||
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import tensorflow as tf
|
||||
from tensorflow.python import debug as tf_debug
|
||||
|
||||
import ray
|
||||
from ray.rllib.common import Agent, TrainingResult
|
||||
from ray.rllib.common import Agent, TrainingResult, get_tensorflow_log_dir
|
||||
from ray.rllib.ppo.runner import Runner, RemoteRunner
|
||||
from ray.rllib.ppo.rollout import collect_samples
|
||||
from ray.rllib.ppo.utils import shuffle
|
||||
@@ -99,8 +99,9 @@ class PPOAgent(Agent):
|
||||
for _ in range(self.config["num_workers"])]
|
||||
self.start_time = time.time()
|
||||
if self.config["write_logs"]:
|
||||
logdir = get_tensorflow_log_dir(self.logdir)
|
||||
self.file_writer = tf.summary.FileWriter(
|
||||
self.logdir, self.model.sess.graph)
|
||||
logdir, self.model.sess.graph)
|
||||
else:
|
||||
self.file_writer = None
|
||||
self.saver = tf.train.Saver(max_to_keep=None)
|
||||
|
||||
Reference in New Issue
Block a user