This commit is contained in:
Toshiki Watanabe
2019-07-23 11:59:59 +09:00
parent d3a6ffda45
commit d4cce3869e
+3 -2
View File
@@ -96,10 +96,11 @@ class GaussianPolicy(nn.Module):
  std = log_std.exp()
  normal = Normal(mean, std)
  x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1))
- action = torch.tanh(x_t) * self.action_scale + self.action_bias
+ y_t = torch.tanh(x_t)
+ action = y_t * self.action_scale + self.action_bias
  log_prob = normal.log_prob(x_t)
  # Enforcing Action Bound
- log_prob -= torch.log(self.action_scale * (1 - action.pow(2)) + epsilon)
+ log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon)
  log_prob = log_prob.sum(1, keepdim=True)
  return action, log_prob, torch.tanh(mean)