This commit is contained in:
pranz24
2019-04-07 12:39:47 +05:30
parent 86412e14e1
commit 82d54f0f6a
+8 -8
View File
@@ -75,14 +75,6 @@ class SAC(object):
qf1_loss = F.mse_loss(qf1, next_q_value) # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
qf2_loss = F.mse_loss(qf2, next_q_value) # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
self.critic_optim.zero_grad()
qf1_loss.backward()
self.critic_optim.step()
self.critic_optim.zero_grad()
qf2_loss.backward()
self.critic_optim.step()
pi, log_pi, _ = self.policy.sample(state_batch)
qf1_pi, qf2_pi = self.critic(state_batch, pi)
@@ -90,6 +82,14 @@ class SAC(object):
policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean() # Jπ = 𝔼stD,εtN[α * logπ(f(εt;st)|st) Q(st,f(εt;st))]
self.critic_optim.zero_grad()
qf1_loss.backward()
self.critic_optim.step()
self.critic_optim.zero_grad()
qf2_loss.backward()
self.critic_optim.step()
self.policy_optim.zero_grad()
policy_loss.backward()
self.policy_optim.step()