mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 01:00:10 +08:00
[rllib] support running older version tensorflow(version < 1.5.0) (#3571)
This commit is contained in:
@@ -168,7 +168,7 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
mask = tf.sequence_mask(self.model.seq_lens, max_seq_len)
|
||||
mask = tf.reshape(mask, [-1])
|
||||
else:
|
||||
mask = tf.ones_like(rewards)
|
||||
mask = tf.ones_like(rewards, dtype=tf.bool)
|
||||
|
||||
# Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
|
||||
self.loss = VTraceLoss(
|
||||
|
||||
@@ -206,7 +206,7 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
mask = tf.sequence_mask(self.model.seq_lens, max_seq_len)
|
||||
mask = tf.reshape(mask, [-1])
|
||||
else:
|
||||
mask = tf.ones_like(adv_ph)
|
||||
mask = tf.ones_like(adv_ph, dtype=tf.bool)
|
||||
|
||||
self.loss_obj = PPOLoss(
|
||||
action_space,
|
||||
|
||||
@@ -8,6 +8,7 @@ import logging
|
||||
import numpy as np
|
||||
import os
|
||||
import yaml
|
||||
import distutils.version
|
||||
|
||||
import ray.cloudpickle as cloudpickle
|
||||
from ray.tune.log_sync import get_syncer
|
||||
@@ -18,8 +19,11 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >=
|
||||
distutils.version.LooseVersion("1.5.0"))
|
||||
except ImportError:
|
||||
tf = None
|
||||
use_tf150_api = True
|
||||
logger.warning("Couldn't import TensorFlow - "
|
||||
"disabling TensorBoard logging.")
|
||||
|
||||
@@ -155,7 +159,11 @@ def to_tf_values(result, path):
|
||||
values = []
|
||||
for attr, value in result.items():
|
||||
if value is not None:
|
||||
if type(value) in [int, float, np.float32, np.float64, np.int32]:
|
||||
if use_tf150_api:
|
||||
type_list = [int, float, np.float32, np.float64, np.int32]
|
||||
else:
|
||||
type_list = [int, float]
|
||||
if type(value) in type_list:
|
||||
values.append(
|
||||
tf.Summary.Value(
|
||||
tag="/".join(path + [attr]), simple_value=value))
|
||||
|
||||
Reference in New Issue
Block a user