diff --git a/sac.py b/sac.py index 103d1a6..f494344 100644 --- a/sac.py +++ b/sac.py @@ -27,7 +27,7 @@ class SAC(object): if self.policy_type == "Gaussian": # Target Entropy = −dim(A) (e.g. , -6 for HalfCheetah-v2) as given in the paper - if self.automatic_entropy_tuning == True: + if self.automatic_entropy_tuning is True: self.target_entropy = -torch.prod(torch.Tensor(action_space.shape).to(self.device)).item() self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device) self.alpha_optim = Adam([self.log_alpha], lr=args.lr) @@ -43,7 +43,7 @@ class SAC(object): def select_action(self, state, eval=False): state = torch.FloatTensor(state).to(self.device).unsqueeze(0) - if eval == False: + if eval is False: action, _, _ = self.policy.sample(state) else: _, _, action = self.policy.sample(state)