diff --git a/python/ray/rllib/agents/impala/vtrace_policy_graph.py b/python/ray/rllib/agents/impala/vtrace_policy_graph.py index 7f177e404..23f88e51f 100644 --- a/python/ray/rllib/agents/impala/vtrace_policy_graph.py +++ b/python/ray/rllib/agents/impala/vtrace_policy_graph.py @@ -107,10 +107,7 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph): tf.get_variable_scope().name) # Setup the policy loss - if isinstance(action_space, gym.spaces.Box): - ac_size = action_space.shape[0] - actions = tf.placeholder(tf.float32, [None, ac_size], name="ac") - elif isinstance(action_space, gym.spaces.Discrete): + if isinstance(action_space, gym.spaces.Discrete): ac_size = action_space.n actions = tf.placeholder(tf.int64, [None], name="ac") else: diff --git a/python/ray/rllib/test/test_supported_spaces.py b/python/ray/rllib/test/test_supported_spaces.py index 60ca9de8c..cded0c165 100644 --- a/python/ray/rllib/test/test_supported_spaces.py +++ b/python/ray/rllib/test/test_supported_spaces.py @@ -94,6 +94,7 @@ class ModelSupportedSpaces(unittest.TestCase): def testAll(self): ray.init() stats = {} + check_support("IMPALA", {"gpu": False}, stats) check_support("DDPG", {"timesteps_per_iteration": 1}, stats) check_support("DQN", {"timesteps_per_iteration": 1}, stats) check_support("A3C", {