diff --git a/python/ray/rllib/agents/marwil/marwil_policy.py b/python/ray/rllib/agents/marwil/marwil_policy.py index 8c8eafba3..47ff12ebd 100644 --- a/python/ray/rllib/agents/marwil/marwil_policy.py +++ b/python/ray/rllib/agents/marwil/marwil_policy.py @@ -110,7 +110,7 @@ class MARWILPolicy(MARWILPostprocessing, TFPolicy): self.output_actions = action_dist.sample() # Training inputs - self.act_t = tf.placeholder(tf.int32, [None], name="action") + self.act_t = ModelCatalog.get_action_placeholder(action_space) self.cum_rew_t = tf.placeholder(tf.float32, [None], name="reward") # v network evaluation