diff --git a/model.py b/model.py
index 3acbe8b..71b701d 100644
--- a/model.py
+++ b/model.py
@@ -62,7 +62,7 @@ class QNetwork(nn.Module):
 
 
 class GaussianPolicy(nn.Module):
-    def __init__(self, num_inputs, num_actions, hidden_dim):
+    def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None):
         super(GaussianPolicy, self).__init__()
 
         self.linear1 = nn.Linear(num_inputs, hidden_dim)
@@ -73,6 +73,17 @@ class GaussianPolicy(nn.Module):
 
         self.apply(weights_init_)
 
+        # action rescaling: maps the tanh output in (-1, 1) onto [low, high];
+        # identity (scale 1, bias 0) when no action_space is given
+        if action_space is None:
+            self.action_scale = torch.tensor(1.)
+            self.action_bias = torch.tensor(0.)
+        else:
+            self.action_scale = torch.FloatTensor(
+                (action_space.high - action_space.low) / 2.)
+            self.action_bias = torch.FloatTensor(
+                (action_space.high + action_space.low) / 2.)
+
     def forward(self, state):
         x = F.relu(self.linear1(state))
         x = F.relu(self.linear2(x))
@@ -86,15 +97,28 @@ class GaussianPolicy(nn.Module):
         std = log_std.exp()
         normal = Normal(mean, std)
         x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
-        action = torch.tanh(x_t)
+        y_t = torch.tanh(x_t)  # squashed sample in (-1, 1)
+        action = y_t * self.action_scale + self.action_bias
         log_prob = normal.log_prob(x_t)
         # Enforcing Action Bound
-        log_prob -= torch.log(1 - action.pow(2) + epsilon)
+        # Jacobian of x_t -> action is action_scale * (1 - tanh(x_t)^2), so the
+        # correction must use the unscaled y_t, not the rescaled action
+        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon)
         log_prob = log_prob.sum(1, keepdim=True)
-        return action, log_prob, torch.tanh(mean)
+        # rescale the deterministic action too so evaluation uses env bounds
+        mean = torch.tanh(mean) * self.action_scale + self.action_bias
+        return action, log_prob, mean
 
+    def to(self, device):
+        # action_scale/action_bias are plain tensors (not parameters/buffers),
+        # so nn.Module.to() would not move them; move them explicitly
+        self.action_scale = self.action_scale.to(device)
+        self.action_bias = self.action_bias.to(device)
+        return super(GaussianPolicy, self).to(device)
+
+
 class DeterministicPolicy(nn.Module):
-    def __init__(self, num_inputs, num_actions, hidden_dim):
+    def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None):
         super(DeterministicPolicy, self).__init__()
         self.linear1 = nn.Linear(num_inputs, hidden_dim)
         self.linear2 = nn.Linear(hidden_dim, hidden_dim)
@@ -104,17 +128,33 @@ class DeterministicPolicy(nn.Module):
 
         self.apply(weights_init_)
 
+        # action rescaling: maps the tanh output in (-1, 1) onto [low, high];
+        # tensors (not floats) in both branches so to() can always move them
+        if action_space is None:
+            self.action_scale = torch.tensor(1.)
+            self.action_bias = torch.tensor(0.)
+        else:
+            self.action_scale = torch.FloatTensor(
+                (action_space.high - action_space.low) / 2.)
+            self.action_bias = torch.FloatTensor(
+                (action_space.high + action_space.low) / 2.)
+
     def forward(self, state):
         x = F.relu(self.linear1(state))
         x = F.relu(self.linear2(x))
-        mean = torch.tanh(self.mean(x))
+        mean = torch.tanh(self.mean(x)) * self.action_scale + self.action_bias
         return mean
 
-
     def sample(self, state):
         mean = self.forward(state)
         noise = self.noise.normal_(0., std=0.1)
         noise = noise.clamp(-0.25, 0.25)
         action = mean + noise
         return action, torch.tensor(0.), mean
-
+
+    def to(self, device):
+        # action_scale/action_bias are plain tensors (not parameters/buffers),
+        # so nn.Module.to() would not move them; move them explicitly
+        self.action_scale = self.action_scale.to(device)
+        self.action_bias = self.action_bias.to(device)
+        return super(DeterministicPolicy, self).to(device)
diff --git a/sac.py b/sac.py
index d2e0f84..103d1a6 100644
--- a/sac.py
+++ b/sac.py
@@ -12,7 +12,6 @@ class SAC(object):
         self.gamma = args.gamma
         self.tau = args.tau
         self.alpha = args.alpha
-        self.action_range = [action_space.low, action_space.high]
 
         self.policy_type = args.policy
         self.target_update_interval = args.target_update_interval
@@ -33,14 +32,13 @@ class SAC(object):
                 self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
                 self.alpha_optim = Adam([self.log_alpha], lr=args.lr)
 
-
-            self.policy = GaussianPolicy(num_inputs, action_space.shape[0], args.hidden_size).to(self.device)
+            self.policy = GaussianPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device)
             self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)
 
         else:
             self.alpha = 0
             self.automatic_entropy_tuning = False
-            self.policy = DeterministicPolicy(num_inputs, action_space.shape[0], args.hidden_size).to(self.device)
+            self.policy = DeterministicPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device)
             self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)
 
     def select_action(self, state, eval=False):
@@ -49,12 +47,7 @@ class SAC(object):
             action, _, _ = self.policy.sample(state)
         else:
             _, _, action = self.policy.sample(state)
-        action = action.detach().cpu().numpy()[0]
-        return self.rescale_action(action)
-
-    def rescale_action(self, action):
-        return action * (self.action_range[1] - self.action_range[0]) / 2.0 +\
-                (self.action_range[1] + self.action_range[0]) / 2.0
+        return action.detach().cpu().numpy()[0]
 
     def update_parameters(self, memory, batch_size, updates):
         # Sample a batch from memory