From cf0c4745f4c2953a774bbf2a944c3cdc76bdae24 Mon Sep 17 00:00:00 2001 From: adoda Date: Thu, 20 Dec 2018 12:27:24 +0800 Subject: [PATCH] [rllib] support running older version tensorflow(version < 1.5.0) (#3571) --- python/ray/rllib/agents/impala/vtrace_policy_graph.py | 2 +- python/ray/rllib/agents/ppo/ppo_policy_graph.py | 2 +- python/ray/tune/logger.py | 10 +++++++++- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/python/ray/rllib/agents/impala/vtrace_policy_graph.py b/python/ray/rllib/agents/impala/vtrace_policy_graph.py index 5eed0a6e7..12c0c30fb 100644 --- a/python/ray/rllib/agents/impala/vtrace_policy_graph.py +++ b/python/ray/rllib/agents/impala/vtrace_policy_graph.py @@ -168,7 +168,7 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph): mask = tf.sequence_mask(self.model.seq_lens, max_seq_len) mask = tf.reshape(mask, [-1]) else: - mask = tf.ones_like(rewards) + mask = tf.ones_like(rewards, dtype=tf.bool) # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc. self.loss = VTraceLoss( diff --git a/python/ray/rllib/agents/ppo/ppo_policy_graph.py b/python/ray/rllib/agents/ppo/ppo_policy_graph.py index 6948d810a..80ec01ea5 100644 --- a/python/ray/rllib/agents/ppo/ppo_policy_graph.py +++ b/python/ray/rllib/agents/ppo/ppo_policy_graph.py @@ -206,7 +206,7 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph): mask = tf.sequence_mask(self.model.seq_lens, max_seq_len) mask = tf.reshape(mask, [-1]) else: - mask = tf.ones_like(adv_ph) + mask = tf.ones_like(adv_ph, dtype=tf.bool) self.loss_obj = PPOLoss( action_space, diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index 471b5b3fd..75c888c2a 100644 --- a/python/ray/tune/logger.py +++ b/python/ray/tune/logger.py @@ -8,6 +8,7 @@ import logging import numpy as np import os import yaml +import distutils.version import ray.cloudpickle as cloudpickle from ray.tune.log_sync import get_syncer @@ -18,8 +19,11 @@ logger = logging.getLogger(__name__) try: import tensorflow as tf + use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >= + distutils.version.LooseVersion("1.5.0")) except ImportError: tf = None + use_tf150_api = True logger.warning("Couldn't import TensorFlow - " "disabling TensorBoard logging.") @@ -155,7 +159,11 @@ def to_tf_values(result, path): values = [] for attr, value in result.items(): if value is not None: - if type(value) in [int, float, np.float32, np.float64, np.int32]: + if use_tf150_api: + type_list = [int, float, np.float32, np.float64, np.int32] + else: + type_list = [int, float] + if type(value) in type_list: values.append( tf.Summary.Value( tag="/".join(path + [attr]), simple_value=value))