diff --git a/model.py b/model.py index d6cf038..cc548e0 100644 --- a/model.py +++ b/model.py @@ -102,7 +102,8 @@ class GaussianPolicy(nn.Module): # Enforcing Action Bound log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon) log_prob = log_prob.sum(1, keepdim=True) - return action, log_prob, torch.tanh(mean) + mean = torch.tanh(mean) * self.action_scale + self.action_bias + return action, log_prob, mean def to(self, device): self.action_scale = self.action_scale.to(device)