Update sac.py

This commit is contained in:
Pranjal Tandon
2018-12-06 04:17:24 +05:30
committed by GitHub
parent 6440e47a0b
commit bcd0d02b01
+11 -16
View File
@@ -16,7 +16,6 @@ class SAC(object):
self.gamma = args.gamma
self.tau = args.tau
self.alpha = args.alpha
self.reparam = args.reparam
self.policy_type = args.policy
self.target_update_interval = args.target_update_interval
@@ -46,7 +45,7 @@ class SAC(object):
state = torch.FloatTensor(state).unsqueeze(0)
if eval == False:
self.policy.train()
_, _, action, _, _ = self.policy.evaluate(state)
action, _, _, _, _ = self.policy.evaluate(state)
else:
self.policy.eval()
_, _, _, action, _ = self.policy.evaluate(state)
@@ -73,7 +72,7 @@ class SAC(object):
up training, especially on harder task.
"""
expected_q1_value, expected_q2_value = self.critic(state_batch, action_batch)
new_action, log_prob, x_t, mean, log_std = self.policy.evaluate(state_batch, reparam=self.reparam)
new_action, log_prob, _, mean, log_std = self.policy.evaluate(state_batch)
if self.policy_type == "Gaussian":
"""
@@ -113,21 +112,17 @@ class SAC(object):
"""
next_value = expected_new_q_value - (self.alpha * log_prob)
value_loss = self.value_criterion(expected_value, next_value.detach())
log_prob_target = expected_new_q_value - expected_value
else:
log_prob_target = expected_new_q_value
pass
if self.reparam == True:
"""
Reparameterization trick is used to get a low variance estimator
f(εt;st) = action sampled from the policy
εt is an input noise vector, sampled from some fixed distribution
Jπ = 𝔼stD,εtN[logπ(f(εt;st)|st)Q(st,f(εt;st))]
∇Jπ =∇log π + ([∇at log π(at|st) ∇at Q(st,at)])∇f(εt;st)
"""
policy_loss = ((self.alpha * log_prob) - expected_new_q_value).mean()
else:
policy_loss = (log_prob * (log_prob - log_prob_target).detach()).mean() # likelihood ratio gradient estimator
"""
Reparameterization trick is used to get a low variance estimator
f(εt;st) = action sampled from the policy
εt is an input noise vector, sampled from some fixed distribution
Jπ = 𝔼stD,εtN[logπ(f(εt;st)|st)Q(st,f(εt;st))]
∇Jπ =∇log π + ([∇at log π(at|st) ∇at Q(st,at)])∇f(εt;st)
"""
policy_loss = ((self.alpha * log_prob) - expected_new_q_value).mean()
# Regularization Loss
mean_loss = 0.001 * mean.pow(2).mean()