diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index 92734302f..432e384f2 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -233,7 +233,8 @@ class DynamicTFPolicy(TFPolicy): tf.float32, [None], name="prev_reward"), }) # Placeholder for (sampling steps) timestep (int). - timestep = tf1.placeholder(tf.int64, (), name="timestep") + timestep = tf1.placeholder_with_default( + tf.zeros((), dtype=tf.int64), (), name="timestep") # Placeholder for `is_exploring` flag. explore = tf1.placeholder_with_default( True, (), name="is_exploring") diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index abcfae503..f6e48dad2 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -188,7 +188,8 @@ class TFPolicy(Policy): self._apply_op = None self._stats_fetches = {} self._timestep = timestep if timestep is not None else \ - tf1.placeholder(tf.int64, (), name="timestep") + tf1.placeholder_with_default( + tf.zeros((), dtype=tf.int64), (), name="timestep") self._optimizer = None self._grads_and_vars = None diff --git a/rllib/utils/framework.py b/rllib/utils/framework.py index 5615d0fe0..1d0a8afcd 100644 --- a/rllib/utils/framework.py +++ b/rllib/utils/framework.py @@ -88,6 +88,7 @@ def try_import_tf(error=False): tf1_module = tf_module.compat.v1 if not was_imported: tf1_module.disable_v2_behavior() + tf1_module.enable_resource_variables() # No compat.v1 -> return tf as is. except AttributeError: tf1_module = tf_module