diff --git a/sac.py b/sac.py index aef23f0..1aa523f 100644 --- a/sac.py +++ b/sac.py @@ -73,9 +73,6 @@ class SAC(object): qf_loss.backward() self.critic_optim.step() - for c_param in self.critic.parameters(): - c_param.requires_grad = False - pi, log_pi, _ = self.policy.sample(state_batch) qf1_pi, qf2_pi = self.critic(state_batch, pi) @@ -87,9 +84,6 @@ class SAC(object): policy_loss.backward() self.policy_optim.step() - for c_param in self.critic.parameters(): - c_param.requires_grad = True - if self.automatic_entropy_tuning: alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()