mirror of
https://github.com/wassname/pytorch-soft-actor-critic.git
synced 2026-06-27 18:06:10 +08:00
Clean Up
This commit is contained in:
@@ -75,14 +75,6 @@ class SAC(object):
|
||||
qf1_loss = F.mse_loss(qf1, next_q_value) # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
|
||||
qf2_loss = F.mse_loss(qf2, next_q_value) # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
|
||||
|
||||
self.critic_optim.zero_grad()
|
||||
qf1_loss.backward()
|
||||
self.critic_optim.step()
|
||||
|
||||
self.critic_optim.zero_grad()
|
||||
qf2_loss.backward()
|
||||
self.critic_optim.step()
|
||||
|
||||
pi, log_pi, _ = self.policy.sample(state_batch)
|
||||
|
||||
qf1_pi, qf2_pi = self.critic(state_batch, pi)
|
||||
@@ -90,6 +82,14 @@ class SAC(object):
|
||||
|
||||
policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean() # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]
|
||||
|
||||
self.critic_optim.zero_grad()
|
||||
qf1_loss.backward()
|
||||
self.critic_optim.step()
|
||||
|
||||
self.critic_optim.zero_grad()
|
||||
qf2_loss.backward()
|
||||
self.critic_optim.step()
|
||||
|
||||
self.policy_optim.zero_grad()
|
||||
policy_loss.backward()
|
||||
self.policy_optim.step()
|
||||
|
||||
Reference in New Issue
Block a user