Update sac.py

This commit is contained in:
Pranjal Tandon
2020-02-03 14:00:34 +05:30
committed by GitHub
parent 5189f44caa
commit 589b56b264
+4 -4
View File
@@ -17,7 +17,7 @@ class SAC(object):
self.target_update_interval = args.target_update_interval
self.automatic_entropy_tuning = args.automatic_entropy_tuning
self.device = torch.device("cuda" if args.cuda else "cpu")
self.device = torch.device("cuda" if args.cuda else "cpu")
self.critic = QNetwork(num_inputs, action_space.shape[0], args.hidden_size).to(device=self.device)
self.critic_optim = Adam(self.critic.parameters(), lr=args.lr)
@@ -41,9 +41,9 @@ class SAC(object):
self.policy = DeterministicPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device)
self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)
def select_action(self, state, eval=False):
def select_action(self, state, evaluate=False):
state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
if eval is False:
if evaluate is False:
action, _, _ = self.policy.sample(state)
else:
_, _, action = self.policy.sample(state)
@@ -119,7 +119,7 @@ class SAC(object):
print('Saving models to {} and {}'.format(actor_path, critic_path))
torch.save(self.policy.state_dict(), actor_path)
torch.save(self.critic.state_dict(), critic_path)
# Load model parameters
def load_model(self, actor_path, critic_path):
print('Loading models from {} and {}'.format(actor_path, critic_path))