diff --git a/rllib/offline/off_policy_estimator.py b/rllib/offline/off_policy_estimator.py index ac2a04ed6..43aa2cc20 100644 --- a/rllib/offline/off_policy_estimator.py +++ b/rllib/offline/off_policy_estimator.py @@ -70,7 +70,8 @@ class OffPolicyEstimator: state_batches=[batch[k] for k in state_keys], prev_action_batch=batch.data.get(SampleBatch.PREV_ACTIONS), prev_reward_batch=batch.data.get(SampleBatch.PREV_REWARDS)) - return convert_to_numpy(log_likelihoods) + log_likelihoods = convert_to_numpy(log_likelihoods) + return np.exp(log_likelihoods) @DeveloperAPI def process(self, batch: SampleBatchType):