mirror of
https://github.com/wassname/pytorch-soft-actor-critic.git
synced 2026-06-27 16:46:28 +08:00
Update sac.py
This commit is contained in:
@@ -17,7 +17,7 @@ class SAC(object):
|
||||
self.target_update_interval = args.target_update_interval
|
||||
self.automatic_entropy_tuning = args.automatic_entropy_tuning
|
||||
|
||||
self.device = torch.device("cuda" if args.cuda else "cpu")
|
||||
self.device = torch.device("cuda" if args.cuda else "cpu")
|
||||
|
||||
self.critic = QNetwork(num_inputs, action_space.shape[0], args.hidden_size).to(device=self.device)
|
||||
self.critic_optim = Adam(self.critic.parameters(), lr=args.lr)
|
||||
@@ -41,9 +41,9 @@ class SAC(object):
|
||||
self.policy = DeterministicPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device)
|
||||
self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)
|
||||
|
||||
def select_action(self, state, eval=False):
|
||||
def select_action(self, state, evaluate=False):
|
||||
state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
|
||||
if eval is False:
|
||||
if evaluate is False:
|
||||
action, _, _ = self.policy.sample(state)
|
||||
else:
|
||||
_, _, action = self.policy.sample(state)
|
||||
@@ -119,7 +119,7 @@ class SAC(object):
|
||||
print('Saving models to {} and {}'.format(actor_path, critic_path))
|
||||
torch.save(self.policy.state_dict(), actor_path)
|
||||
torch.save(self.critic.state_dict(), critic_path)
|
||||
|
||||
|
||||
# Load model parameters
|
||||
def load_model(self, actor_path, critic_path):
|
||||
print('Loading models from {} and {}'.format(actor_path, critic_path))
|
||||
|
||||
Reference in New Issue
Block a user