diff --git a/python/ray/rllib/policy_gradient/policy_gradient.py b/python/ray/rllib/policy_gradient/policy_gradient.py index 0adb4f224..db6c2127e 100644 --- a/python/ray/rllib/policy_gradient/policy_gradient.py +++ b/python/ray/rllib/policy_gradient/policy_gradient.py @@ -266,4 +266,5 @@ class PolicyGradient(Algorithm): for (a, o) in zip(self.agents, extra_data[4])]) def compute_action(self, observation): + observation = self.model.observation_filter(observation) return self.model.common_policy.compute([observation])[0][0]