mirror of
https://github.com/wassname/pytorch-soft-actor-critic.git
synced 2026-06-27 16:46:28 +08:00
fix bugs of action rescaling
This commit is contained in:
@@ -62,7 +62,7 @@ class QNetwork(nn.Module):
|
||||
|
||||
|
||||
class GaussianPolicy(nn.Module):
|
||||
def __init__(self, num_inputs, num_actions, hidden_dim):
|
||||
def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None):
|
||||
super(GaussianPolicy, self).__init__()
|
||||
|
||||
self.linear1 = nn.Linear(num_inputs, hidden_dim)
|
||||
@@ -73,6 +73,16 @@ class GaussianPolicy(nn.Module):
|
||||
|
||||
self.apply(weights_init_)
|
||||
|
||||
# action rescaling
|
||||
if action_space is None:
|
||||
self.action_scale = torch.tensor(1.)
|
||||
self.action_bias = torch.tensor(0.)
|
||||
else:
|
||||
self.action_scale = torch.FloatTensor(
|
||||
(action_space.high - action_space.low) / 2.)
|
||||
self.action_bias = torch.FloatTensor(
|
||||
(action_space.high + action_space.low) / 2.)
|
||||
|
||||
def forward(self, state):
|
||||
x = F.relu(self.linear1(state))
|
||||
x = F.relu(self.linear2(x))
|
||||
@@ -86,15 +96,21 @@ class GaussianPolicy(nn.Module):
|
||||
std = log_std.exp()
|
||||
normal = Normal(mean, std)
|
||||
x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1))
|
||||
action = torch.tanh(x_t)
|
||||
action = torch.tanh(x_t) * self.action_scale + self.action_bias
|
||||
log_prob = normal.log_prob(x_t)
|
||||
# Enforcing Action Bound
|
||||
log_prob -= torch.log(1 - action.pow(2) + epsilon)
|
||||
log_prob -= torch.log(self.action_scale * (1 - action.pow(2)) + epsilon)
|
||||
log_prob = log_prob.sum(1, keepdim=True)
|
||||
return action, log_prob, torch.tanh(mean)
|
||||
|
||||
def to(self, device):
|
||||
self.action_scale = self.action_scale.to(device)
|
||||
self.action_bias = self.action_bias.to(device)
|
||||
return super(GaussianPolicy, self).to(device)
|
||||
|
||||
|
||||
class DeterministicPolicy(nn.Module):
|
||||
def __init__(self, num_inputs, num_actions, hidden_dim):
|
||||
def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None):
|
||||
super(DeterministicPolicy, self).__init__()
|
||||
self.linear1 = nn.Linear(num_inputs, hidden_dim)
|
||||
self.linear2 = nn.Linear(hidden_dim, hidden_dim)
|
||||
@@ -104,17 +120,30 @@ class DeterministicPolicy(nn.Module):
|
||||
|
||||
self.apply(weights_init_)
|
||||
|
||||
# action rescaling
|
||||
if action_space is None:
|
||||
self.action_scale = 1.
|
||||
self.action_bias = 0.
|
||||
else:
|
||||
self.action_scale = torch.FloatTensor(
|
||||
(action_space.high - action_space.low) / 2.)
|
||||
self.action_bias = torch.FloatTensor(
|
||||
(action_space.high + action_space.low) / 2.)
|
||||
|
||||
def forward(self, state):
|
||||
x = F.relu(self.linear1(state))
|
||||
x = F.relu(self.linear2(x))
|
||||
mean = torch.tanh(self.mean(x))
|
||||
mean = torch.tanh(self.mean(x)) * self.action_scale + self.action_bias
|
||||
return mean
|
||||
|
||||
|
||||
def sample(self, state):
|
||||
mean = self.forward(state)
|
||||
noise = self.noise.normal_(0., std=0.1)
|
||||
noise = noise.clamp(-0.25, 0.25)
|
||||
action = mean + noise
|
||||
return action, torch.tensor(0.), mean
|
||||
|
||||
|
||||
def to(self, device):
|
||||
self.action_scale = self.action_scale.to(device)
|
||||
self.action_bias = self.action_bias.to(device)
|
||||
return super(GaussianPolicy, self).to(device)
|
||||
|
||||
@@ -12,7 +12,6 @@ class SAC(object):
|
||||
self.gamma = args.gamma
|
||||
self.tau = args.tau
|
||||
self.alpha = args.alpha
|
||||
self.action_range = [action_space.low, action_space.high]
|
||||
|
||||
self.policy_type = args.policy
|
||||
self.target_update_interval = args.target_update_interval
|
||||
@@ -33,14 +32,13 @@ class SAC(object):
|
||||
self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
|
||||
self.alpha_optim = Adam([self.log_alpha], lr=args.lr)
|
||||
|
||||
|
||||
self.policy = GaussianPolicy(num_inputs, action_space.shape[0], args.hidden_size).to(self.device)
|
||||
self.policy = GaussianPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device)
|
||||
self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)
|
||||
|
||||
else:
|
||||
self.alpha = 0
|
||||
self.automatic_entropy_tuning = False
|
||||
self.policy = DeterministicPolicy(num_inputs, action_space.shape[0], args.hidden_size).to(self.device)
|
||||
self.policy = DeterministicPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device)
|
||||
self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)
|
||||
|
||||
def select_action(self, state, eval=False):
|
||||
@@ -49,12 +47,7 @@ class SAC(object):
|
||||
action, _, _ = self.policy.sample(state)
|
||||
else:
|
||||
_, _, action = self.policy.sample(state)
|
||||
action = action.detach().cpu().numpy()[0]
|
||||
return self.rescale_action(action)
|
||||
|
||||
def rescale_action(self, action):
|
||||
return action * (self.action_range[1] - self.action_range[0]) / 2.0 +\
|
||||
(self.action_range[1] + self.action_range[0]) / 2.0
|
||||
return action.detach().cpu().numpy()[0]
|
||||
|
||||
def update_parameters(self, memory, batch_size, updates):
|
||||
# Sample a batch from memory
|
||||
|
||||
Reference in New Issue
Block a user