mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 19:48:35 +08:00
[RLlib] Issue 12244: Unable to restore multi-agent PPOTFPolicy's Model (from exported). (#12786)
This commit is contained in:
@@ -233,7 +233,8 @@ class DynamicTFPolicy(TFPolicy):
|
||||
tf.float32, [None], name="prev_reward"),
|
||||
})
|
||||
# Placeholder for (sampling steps) timestep (int).
|
||||
- timestep = tf1.placeholder(tf.int64, (), name="timestep")
|
||||
+ timestep = tf1.placeholder_with_default(
|
||||
+ tf.zeros((), dtype=tf.int64), (), name="timestep")
|
||||
# Placeholder for `is_exploring` flag.
|
||||
explore = tf1.placeholder_with_default(
|
||||
True, (), name="is_exploring")
|
||||
|
||||
@@ -188,7 +188,8 @@ class TFPolicy(Policy):
|
||||
self._apply_op = None
|
||||
self._stats_fetches = {}
|
||||
self._timestep = timestep if timestep is not None else \
|
||||
- tf1.placeholder(tf.int64, (), name="timestep")
|
||||
+ tf1.placeholder_with_default(
|
||||
+ tf.zeros((), dtype=tf.int64), (), name="timestep")
|
||||
|
||||
self._optimizer = None
|
||||
self._grads_and_vars = None
|
||||
|
||||
@@ -88,6 +88,7 @@ def try_import_tf(error=False):
|
||||
tf1_module = tf_module.compat.v1
|
||||
if not was_imported:
|
||||
tf1_module.disable_v2_behavior()
|
||||
+ tf1_module.enable_resource_variables()
|
||||
# No compat.v1 -> return tf as is.
|
||||
except AttributeError:
|
||||
tf1_module = tf_module
|
||||
|
||||
Reference in New Issue
Block a user