mirror of
https://github.com/wassname/pytorch-soft-actor-critic.git
synced 2026-06-27 18:06:10 +08:00
Update sac.py
This commit is contained in:
@@ -16,7 +16,6 @@ class SAC(object):
|
||||
self.gamma = args.gamma
|
||||
self.tau = args.tau
|
||||
self.alpha = args.alpha
|
||||
self.reparam = args.reparam
|
||||
self.policy_type = args.policy
|
||||
self.target_update_interval = args.target_update_interval
|
||||
|
||||
@@ -46,7 +45,7 @@ class SAC(object):
|
||||
state = torch.FloatTensor(state).unsqueeze(0)
|
||||
if eval == False:
|
||||
self.policy.train()
|
||||
_, _, action, _, _ = self.policy.evaluate(state)
|
||||
action, _, _, _, _ = self.policy.evaluate(state)
|
||||
else:
|
||||
self.policy.eval()
|
||||
_, _, _, action, _ = self.policy.evaluate(state)
|
||||
@@ -73,7 +72,7 @@ class SAC(object):
|
||||
up training, especially on harder task.
|
||||
"""
|
||||
expected_q1_value, expected_q2_value = self.critic(state_batch, action_batch)
|
||||
new_action, log_prob, x_t, mean, log_std = self.policy.evaluate(state_batch, reparam=self.reparam)
|
||||
new_action, log_prob, _, mean, log_std = self.policy.evaluate(state_batch)
|
||||
|
||||
if self.policy_type == "Gaussian":
|
||||
"""
|
||||
@@ -113,21 +112,17 @@ class SAC(object):
|
||||
"""
|
||||
next_value = expected_new_q_value - (self.alpha * log_prob)
|
||||
value_loss = self.value_criterion(expected_value, next_value.detach())
|
||||
log_prob_target = expected_new_q_value - expected_value
|
||||
else:
|
||||
log_prob_target = expected_new_q_value
|
||||
pass
|
||||
|
||||
if self.reparam == True:
|
||||
"""
|
||||
Reparameterization trick is used to get a low variance estimator
|
||||
f(εt;st) = action sampled from the policy
|
||||
εt is an input noise vector, sampled from some fixed distribution
|
||||
Jπ = 𝔼st∼D,εt∼N[logπ(f(εt;st)|st)−Q(st,f(εt;st))]
|
||||
∇Jπ =∇log π + ([∇at log π(at|st) − ∇at Q(st,at)])∇f(εt;st)
|
||||
"""
|
||||
policy_loss = ((self.alpha * log_prob) - expected_new_q_value).mean()
|
||||
else:
|
||||
policy_loss = (log_prob * (log_prob - log_prob_target).detach()).mean() # likelihood ratio gradient estimator
|
||||
"""
|
||||
Reparameterization trick is used to get a low variance estimator
|
||||
f(εt;st) = action sampled from the policy
|
||||
εt is an input noise vector, sampled from some fixed distribution
|
||||
Jπ = 𝔼st∼D,εt∼N[logπ(f(εt;st)|st)−Q(st,f(εt;st))]
|
||||
∇Jπ =∇log π + ([∇at log π(at|st) − ∇at Q(st,at)])∇f(εt;st)
|
||||
"""
|
||||
policy_loss = ((self.alpha * log_prob) - expected_new_q_value).mean()
|
||||
|
||||
# Regularization Loss
|
||||
mean_loss = 0.001 * mean.pow(2).mean()
|
||||
|
||||
Reference in New Issue
Block a user