mirror of
https://github.com/wassname/pytorch-soft-actor-critic.git
synced 2026-06-27 16:46:28 +08:00
3664ba4e60
More pytorch-native way would be to use `Module.register_buffer()` method. In that case, buffer won't be used in parameters(), but will be converted to CUDA and CPU with `to()` call transparently.
153 lines
5.0 KiB
Python
153 lines
5.0 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.distributions import Normal
|
|
|
|
LOG_SIG_MAX = 2
|
|
LOG_SIG_MIN = -20
|
|
epsilon = 1e-6
|
|
|
|
# Initialize Policy weights
|
|
def weights_init_(m):
|
|
if isinstance(m, nn.Linear):
|
|
torch.nn.init.xavier_uniform_(m.weight, gain=1)
|
|
torch.nn.init.constant_(m.bias, 0)
|
|
|
|
|
|
class ValueNetwork(nn.Module):
|
|
def __init__(self, num_inputs, hidden_dim):
|
|
super(ValueNetwork, self).__init__()
|
|
|
|
self.linear1 = nn.Linear(num_inputs, hidden_dim)
|
|
self.linear2 = nn.Linear(hidden_dim, hidden_dim)
|
|
self.linear3 = nn.Linear(hidden_dim, 1)
|
|
|
|
self.apply(weights_init_)
|
|
|
|
def forward(self, state):
|
|
x = F.relu(self.linear1(state))
|
|
x = F.relu(self.linear2(x))
|
|
x = self.linear3(x)
|
|
return x
|
|
|
|
|
|
class QNetwork(nn.Module):
|
|
def __init__(self, num_inputs, num_actions, hidden_dim):
|
|
super(QNetwork, self).__init__()
|
|
|
|
# Q1 architecture
|
|
self.linear1 = nn.Linear(num_inputs + num_actions, hidden_dim)
|
|
self.linear2 = nn.Linear(hidden_dim, hidden_dim)
|
|
self.linear3 = nn.Linear(hidden_dim, 1)
|
|
|
|
# Q2 architecture
|
|
self.linear4 = nn.Linear(num_inputs + num_actions, hidden_dim)
|
|
self.linear5 = nn.Linear(hidden_dim, hidden_dim)
|
|
self.linear6 = nn.Linear(hidden_dim, 1)
|
|
|
|
self.apply(weights_init_)
|
|
|
|
def forward(self, state, action):
|
|
xu = torch.cat([state, action], 1)
|
|
|
|
x1 = F.relu(self.linear1(xu))
|
|
x1 = F.relu(self.linear2(x1))
|
|
x1 = self.linear3(x1)
|
|
|
|
x2 = F.relu(self.linear4(xu))
|
|
x2 = F.relu(self.linear5(x2))
|
|
x2 = self.linear6(x2)
|
|
|
|
return x1, x2
|
|
|
|
|
|
class GaussianPolicy(nn.Module):
|
|
def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None):
|
|
super(GaussianPolicy, self).__init__()
|
|
|
|
self.linear1 = nn.Linear(num_inputs, hidden_dim)
|
|
self.linear2 = nn.Linear(hidden_dim, hidden_dim)
|
|
|
|
self.mean_linear = nn.Linear(hidden_dim, num_actions)
|
|
self.log_std_linear = nn.Linear(hidden_dim, num_actions)
|
|
|
|
self.apply(weights_init_)
|
|
|
|
# action rescaling
|
|
if action_space is None:
|
|
self.action_scale = torch.tensor(1.)
|
|
self.action_bias = torch.tensor(0.)
|
|
else:
|
|
self.action_scale = torch.FloatTensor(
|
|
(action_space.high - action_space.low) / 2.)
|
|
self.action_bias = torch.FloatTensor(
|
|
(action_space.high + action_space.low) / 2.)
|
|
|
|
def forward(self, state):
|
|
x = F.relu(self.linear1(state))
|
|
x = F.relu(self.linear2(x))
|
|
mean = self.mean_linear(x)
|
|
log_std = self.log_std_linear(x)
|
|
log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)
|
|
return mean, log_std
|
|
|
|
def sample(self, state):
|
|
mean, log_std = self.forward(state)
|
|
std = log_std.exp()
|
|
normal = Normal(mean, std)
|
|
x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1))
|
|
y_t = torch.tanh(x_t)
|
|
action = y_t * self.action_scale + self.action_bias
|
|
log_prob = normal.log_prob(x_t)
|
|
# Enforcing Action Bound
|
|
log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon)
|
|
log_prob = log_prob.sum(1, keepdim=True)
|
|
mean = torch.tanh(mean) * self.action_scale + self.action_bias
|
|
return action, log_prob, mean
|
|
|
|
def to(self, device):
|
|
self.action_scale = self.action_scale.to(device)
|
|
self.action_bias = self.action_bias.to(device)
|
|
return super(GaussianPolicy, self).to(device)
|
|
|
|
|
|
class DeterministicPolicy(nn.Module):
|
|
def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None):
|
|
super(DeterministicPolicy, self).__init__()
|
|
self.linear1 = nn.Linear(num_inputs, hidden_dim)
|
|
self.linear2 = nn.Linear(hidden_dim, hidden_dim)
|
|
|
|
self.mean = nn.Linear(hidden_dim, num_actions)
|
|
self.noise = torch.Tensor(num_actions)
|
|
|
|
self.apply(weights_init_)
|
|
|
|
# action rescaling
|
|
if action_space is None:
|
|
self.action_scale = 1.
|
|
self.action_bias = 0.
|
|
else:
|
|
self.action_scale = torch.FloatTensor(
|
|
(action_space.high - action_space.low) / 2.)
|
|
self.action_bias = torch.FloatTensor(
|
|
(action_space.high + action_space.low) / 2.)
|
|
|
|
def forward(self, state):
|
|
x = F.relu(self.linear1(state))
|
|
x = F.relu(self.linear2(x))
|
|
mean = torch.tanh(self.mean(x)) * self.action_scale + self.action_bias
|
|
return mean
|
|
|
|
def sample(self, state):
|
|
mean = self.forward(state)
|
|
noise = self.noise.normal_(0., std=0.1)
|
|
noise = noise.clamp(-0.25, 0.25)
|
|
action = mean + noise
|
|
return action, torch.tensor(0.), mean
|
|
|
|
def to(self, device):
|
|
self.action_scale = self.action_scale.to(device)
|
|
self.action_bias = self.action_bias.to(device)
|
|
self.noise = self.noise.to(device)
|
|
return super(DeterministicPolicy, self).to(device)
|