[rllib] support running older version tensorflow(version < 1.5.0) (#3571)

This commit is contained in:
adoda
2018-12-20 12:27:24 +08:00
committed by Richard Liaw
parent a5309bec7c
commit cf0c4745f4
3 changed files with 11 additions and 3 deletions
@@ -168,7 +168,7 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
mask = tf.sequence_mask(self.model.seq_lens, max_seq_len)
mask = tf.reshape(mask, [-1])
else:
mask = tf.ones_like(rewards)
mask = tf.ones_like(rewards, dtype=tf.bool)
# Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
self.loss = VTraceLoss(
@@ -206,7 +206,7 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
mask = tf.sequence_mask(self.model.seq_lens, max_seq_len)
mask = tf.reshape(mask, [-1])
else:
mask = tf.ones_like(adv_ph)
mask = tf.ones_like(adv_ph, dtype=tf.bool)
self.loss_obj = PPOLoss(
action_space,
+9 -1
View File
@@ -8,6 +8,7 @@ import logging
import numpy as np
import os
import yaml
import distutils.version
import ray.cloudpickle as cloudpickle
from ray.tune.log_sync import get_syncer
@@ -18,8 +19,11 @@ logger = logging.getLogger(__name__)
try:
import tensorflow as tf
use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >=
distutils.version.LooseVersion("1.5.0"))
except ImportError:
tf = None
use_tf150_api = True
logger.warning("Couldn't import TensorFlow - "
"disabling TensorBoard logging.")
@@ -155,7 +159,11 @@ def to_tf_values(result, path):
values = []
for attr, value in result.items():
if value is not None:
if type(value) in [int, float, np.float32, np.float64, np.int32]:
if use_tf150_api:
type_list = [int, float, np.float32, np.float64, np.int32]
else:
type_list = [int, float]
if type(value) in type_list:
values.append(
tf.Summary.Value(
tag="/".join(path + [attr]), simple_value=value))