diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py index 40a2c7986..758cfc948 100644 --- a/rllib/policy/eager_tf_policy.py +++ b/rllib/policy/eager_tf_policy.py @@ -393,7 +393,7 @@ def build_eager_tf_policy(name, if action_sampler_fn: dist_inputs = None state_out = [] - actions, logp = self.action_sampler_fn( + actions, logp = action_sampler_fn( self, self.model, input_dict[SampleBatch.CUR_OBS],