mirror of
https://github.com/wassname/pytorch-soft-actor-critic.git
synced 2026-06-27 18:43:57 +08:00
fix bugs
This commit is contained in:
@@ -96,10 +96,11 @@ class GaussianPolicy(nn.Module):
|
||||
std = log_std.exp()
|
||||
normal = Normal(mean, std)
|
||||
x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1))
|
||||
action = torch.tanh(x_t) * self.action_scale + self.action_bias
|
||||
y_t = torch.tanh(x_t)
|
||||
action = y_t * self.action_scale + self.action_bias
|
||||
log_prob = normal.log_prob(x_t)
|
||||
# Enforcing Action Bound
|
||||
log_prob -= torch.log(self.action_scale * (1 - action.pow(2)) + epsilon)
|
||||
log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon)
|
||||
log_prob = log_prob.sum(1, keepdim=True)
|
||||
return action, log_prob, torch.tanh(mean)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user