Fix bugs in action rescaling

This commit is contained in:
Toshiki Watanabe
2019-07-23 11:30:36 +09:00
parent 3f64157068
commit ab2c461af0
2 changed files with 39 additions and 17 deletions
+36 -7
View File
@@ -62,7 +62,7 @@ class QNetwork(nn.Module):
class GaussianPolicy(nn.Module):
def __init__(self, num_inputs, num_actions, hidden_dim):
def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None):
super(GaussianPolicy, self).__init__()
self.linear1 = nn.Linear(num_inputs, hidden_dim)
@@ -73,6 +73,16 @@ class GaussianPolicy(nn.Module):
self.apply(weights_init_)
# action rescaling
if action_space is None:
self.action_scale = torch.tensor(1.)
self.action_bias = torch.tensor(0.)
else:
self.action_scale = torch.FloatTensor(
(action_space.high - action_space.low) / 2.)
self.action_bias = torch.FloatTensor(
(action_space.high + action_space.low) / 2.)
def forward(self, state):
x = F.relu(self.linear1(state))
x = F.relu(self.linear2(x))
@@ -86,15 +96,21 @@ class GaussianPolicy(nn.Module):
std = log_std.exp()
normal = Normal(mean, std)
x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1))
action = torch.tanh(x_t)
action = torch.tanh(x_t) * self.action_scale + self.action_bias
log_prob = normal.log_prob(x_t)
# Enforcing Action Bound
log_prob -= torch.log(1 - action.pow(2) + epsilon)
log_prob -= torch.log(self.action_scale * (1 - action.pow(2)) + epsilon)
log_prob = log_prob.sum(1, keepdim=True)
return action, log_prob, torch.tanh(mean)
def to(self, device):
    """Move the policy and its action-rescaling tensors to ``device``.

    ``action_scale`` and ``action_bias`` are stored as plain tensor
    attributes, so they are transferred here by hand before the rest of
    the module is handed to the parent class's ``to``.

    Args:
        device: Target forwarded to ``Tensor.to`` and to the parent ``to``.

    Returns:
        The result of the parent class's ``to(device)``.
    """
    for attr_name in ("action_scale", "action_bias"):
        moved = getattr(self, attr_name).to(device)
        setattr(self, attr_name, moved)
    return super(GaussianPolicy, self).to(device)
class DeterministicPolicy(nn.Module):
def __init__(self, num_inputs, num_actions, hidden_dim):
def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None):
super(DeterministicPolicy, self).__init__()
self.linear1 = nn.Linear(num_inputs, hidden_dim)
self.linear2 = nn.Linear(hidden_dim, hidden_dim)
@@ -104,17 +120,30 @@ class DeterministicPolicy(nn.Module):
self.apply(weights_init_)
# action rescaling
if action_space is None:
self.action_scale = 1.
self.action_bias = 0.
else:
self.action_scale = torch.FloatTensor(
(action_space.high - action_space.low) / 2.)
self.action_bias = torch.FloatTensor(
(action_space.high + action_space.low) / 2.)
def forward(self, state):
x = F.relu(self.linear1(state))
x = F.relu(self.linear2(x))
mean = torch.tanh(self.mean(x))
mean = torch.tanh(self.mean(x)) * self.action_scale + self.action_bias
return mean
def sample(self, state):
    """Return a noisy action for ``state`` together with the noise-free mean.

    The stored ``self.noise`` tensor is refilled in place with Gaussian
    samples (mean 0, std 0.1), and a clamped copy limited to
    [-0.25, 0.25] is added to the output of ``forward``.

    Args:
        state: Input passed unchanged to ``self.forward``.

    Returns:
        A tuple ``(action, zero, mean)`` where ``action`` is ``mean`` plus
        the clamped noise, ``zero`` is ``torch.tensor(0.)`` and ``mean``
        is the output of ``forward``.
    """
    mean = self.forward(state)
    # normal_ overwrites self.noise in place; clamp returns a new tensor.
    perturbation = self.noise.normal_(0., std=0.1).clamp(-0.25, 0.25)
    return mean + perturbation, torch.tensor(0.), mean
def to(self, device):
    """Move the policy and its action-rescaling constants to ``device``.

    ``action_scale`` and ``action_bias`` are plain attributes rather than
    parameters, so they are transferred here explicitly before the rest of
    the module is handed to the parent class's ``to``.

    Args:
        device: Target forwarded to ``Tensor.to`` and to the parent ``to``.

    Returns:
        The result of the parent class's ``to(device)``.
    """
    # Without an action_space the constants are Python floats, which have
    # no .to(); wrap them as tensors first (existing tensors pass through).
    self.action_scale = torch.as_tensor(self.action_scale).to(device)
    self.action_bias = torch.as_tensor(self.action_bias).to(device)
    # super() must name this class: super(GaussianPolicy, self) raises
    # TypeError because this class does not derive from GaussianPolicy.
    return super(DeterministicPolicy, self).to(device)
+3 -10
View File
@@ -12,7 +12,6 @@ class SAC(object):
self.gamma = args.gamma
self.tau = args.tau
self.alpha = args.alpha
self.action_range = [action_space.low, action_space.high]
self.policy_type = args.policy
self.target_update_interval = args.target_update_interval
@@ -33,14 +32,13 @@ class SAC(object):
self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
self.alpha_optim = Adam([self.log_alpha], lr=args.lr)
self.policy = GaussianPolicy(num_inputs, action_space.shape[0], args.hidden_size).to(self.device)
self.policy = GaussianPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device)
self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)
else:
self.alpha = 0
self.automatic_entropy_tuning = False
self.policy = DeterministicPolicy(num_inputs, action_space.shape[0], args.hidden_size).to(self.device)
self.policy = DeterministicPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device)
self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)
def select_action(self, state, eval=False):
@@ -49,12 +47,7 @@ class SAC(object):
action, _, _ = self.policy.sample(state)
else:
_, _, action = self.policy.sample(state)
action = action.detach().cpu().numpy()[0]
return self.rescale_action(action)
def rescale_action(self, action):
    """Map ``action`` linearly using the bounds stored in ``self.action_range``.

    Args:
        action: Value to rescale; multiplied by half the range width and
            shifted by the range midpoint.

    Returns:
        ``action * (high - low) / 2.0 + (high + low) / 2.0`` where
        ``low, high`` are the two entries of ``self.action_range``.
    """
    low, high = self.action_range[0], self.action_range[1]
    scaled = action * (high - low) / 2.0
    midpoint = (high + low) / 2.0
    return scaled + midpoint
return action.detach().cpu().numpy()[0]
def update_parameters(self, memory, batch_size, updates):
# Sample a batch from memory