mirror of
https://github.com/wassname/pytorch-soft-actor-critic.git
synced 2026-06-27 17:01:47 +08:00
Add files via upload
This commit is contained in:
@@ -0,0 +1,85 @@
|
||||
import argparse
|
||||
import math
|
||||
import gym
|
||||
import numpy as np
|
||||
from gym import wrappers
|
||||
import torch
|
||||
from sac import SAC
|
||||
from plot import plot_line
|
||||
from normalized_actions import NormalizedActions
|
||||
from replay_memory import ReplayMemory
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch REINFORCE example')
|
||||
parser.add_argument('--algo', default='SAC(GMM)',
|
||||
help='algorithm to use: SAC | SAC(GMM)')
|
||||
parser.add_argument('--env-name', default="HalfCheetah-v2",
|
||||
help='name of the environment to run')
|
||||
parser.add_argument('--reparam', type=bool, default=True,
|
||||
help='reparameterize the policy (default:True)')
|
||||
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
|
||||
help='discount factor for reward (default: 0.99)')
|
||||
parser.add_argument('--tau', type=float, default=0.005, metavar='G',
|
||||
help='target smoothing coefficient(τ) (default: 0.005)')
|
||||
parser.add_argument('--k', type=int, default=4, metavar='G',
|
||||
help='No. of Mixtures (default: 4)')
|
||||
parser.add_argument('--scale_R', type=int, default=5, metavar='G',
|
||||
help='reward scaling (default: 5)')
|
||||
parser.add_argument('--seed', type=int, default=543, metavar='N',
|
||||
help='random seed (default: 543)')
|
||||
parser.add_argument('--batch_size', type=int, default=256, metavar='N',
|
||||
help='batch size (default: 256)')
|
||||
parser.add_argument('--num_steps', type=int, default=1000, metavar='N',
|
||||
help='max episode length (default: 1000)')
|
||||
parser.add_argument('--num_episodes', type=int, default=1000, metavar='N',
|
||||
help='number of episodes (default: 1000)')
|
||||
parser.add_argument('--hidden_size', type=int, default=256, metavar='N',
|
||||
help='hidden size (default: 256)')
|
||||
parser.add_argument('--updates_per_step', type=int, default=1, metavar='N',
|
||||
help='model updates per simulator step (default: 1)')
|
||||
parser.add_argument('--replay_size', type=int, default=1000000, metavar='N',
|
||||
help='size of replay buffer (default: 10000000)')
|
||||
args = parser.parse_args()
|
||||
|
||||
env = NormalizedActions(gym.make(args.env_name))
|
||||
|
||||
env.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
agent = SAC(env.observation_space.shape[0], env.action_space, args)
|
||||
|
||||
|
||||
memory = ReplayMemory(args.replay_size)
|
||||
|
||||
|
||||
rewards = []
|
||||
total_numsteps = 0
|
||||
updates = 0
|
||||
|
||||
for i_episode in range(args.num_episodes):
|
||||
state = env.reset()
|
||||
|
||||
episode_reward = 0
|
||||
while True:
|
||||
action = agent.select_action(state)
|
||||
next_state, reward, done, _ = env.step(action)
|
||||
mask = not done
|
||||
|
||||
memory.push(state, action, reward, next_state, mask)
|
||||
if len(memory) > args.batch_size:
|
||||
for i in range(args.updates_per_step):
|
||||
state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(args.batch_size)
|
||||
agent.update_parameters(state_batch, action_batch, reward_batch, next_state_batch, mask_batch, total_numsteps)
|
||||
|
||||
state = next_state
|
||||
total_numsteps += 1
|
||||
episode_reward += reward
|
||||
|
||||
if done:
|
||||
break
|
||||
|
||||
rewards.append(episode_reward)
|
||||
plot_line(total_numsteps, rewards, args.algo)
|
||||
print("Episode: {}, total numsteps: {}, reward: {}, average reward: {}".format(i_episode, total_numsteps, np.round(rewards[-1],2),
|
||||
np.round(np.mean(rewards[-100:]),2)))
|
||||
|
||||
env.close()
|
||||
@@ -0,0 +1,155 @@
|
||||
import sys
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.distributions import Normal
|
||||
from utils import create_log_gaussian, logsumexp
|
||||
|
||||
LOG_SIG_MAX = 2
|
||||
LOG_SIG_MIN = -20
|
||||
|
||||
|
||||
class ValueNetwork(nn.Module):
|
||||
def __init__(self, state_dim, hidden_dim, init_w=1e-3):
|
||||
super(ValueNetwork, self).__init__()
|
||||
|
||||
self.linear1 = nn.Linear(state_dim, hidden_dim)
|
||||
self.linear2 = nn.Linear(hidden_dim, hidden_dim)
|
||||
self.linear3 = nn.Linear(hidden_dim, 1)
|
||||
|
||||
self.linear3.weight.data.uniform_(-init_w, init_w)
|
||||
self.linear3.bias.data.uniform_(-init_w, init_w)
|
||||
|
||||
def forward(self, state):
|
||||
x = F.relu(self.linear1(state))
|
||||
x = F.relu(self.linear2(x))
|
||||
x = self.linear3(x)
|
||||
return x
|
||||
|
||||
|
||||
class QNetwork(nn.Module):
|
||||
def __init__(self, num_inputs, num_actions, hidden_size, init_w=1e-3):
|
||||
super(QNetwork, self).__init__()
|
||||
|
||||
self.linear1 = nn.Linear(num_inputs + num_actions, hidden_size)
|
||||
self.linear2 = nn.Linear(hidden_size, hidden_size)
|
||||
self.linear3 = nn.Linear(hidden_size, 1)
|
||||
|
||||
self.linear3.weight.data.uniform_(-init_w, init_w)
|
||||
self.linear3.bias.data.uniform_(-init_w, init_w)
|
||||
|
||||
def forward(self, state, action):
|
||||
x = torch.cat([state, action], 1)
|
||||
x = F.relu(self.linear1(x))
|
||||
x = F.relu(self.linear2(x))
|
||||
x = self.linear3(x)
|
||||
return x
|
||||
|
||||
|
||||
class GaussianPolicy(nn.Module):
|
||||
def __init__(self, num_inputs, num_actions, hidden_size, init_w=1e-3):
|
||||
super(GaussianPolicy, self).__init__()
|
||||
|
||||
self.linear1 = nn.Linear(num_inputs, hidden_size)
|
||||
self.linear2 = nn.Linear(hidden_size, hidden_size)
|
||||
|
||||
self.mean_linear = nn.Linear(hidden_size, num_actions)
|
||||
self.mean_linear.weight.data.uniform_(-init_w, init_w)
|
||||
self.mean_linear.bias.data.uniform_(-init_w, init_w)
|
||||
|
||||
self.log_std_linear = nn.Linear(hidden_size, num_actions)
|
||||
self.log_std_linear.weight.data.uniform_(-init_w, init_w)
|
||||
self.log_std_linear.bias.data.uniform_(-init_w, init_w)
|
||||
|
||||
def forward(self, state):
|
||||
x = F.relu(self.linear1(state))
|
||||
x = F.relu(self.linear2(x))
|
||||
|
||||
mean = self.mean_linear(x)
|
||||
log_std = self.log_std_linear(x)
|
||||
log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)
|
||||
|
||||
return mean, log_std
|
||||
|
||||
def evaluate(self, state, reparam=False, epsilon=1e-6):
|
||||
mean, log_std = self.forward(state)
|
||||
std = log_std.exp()
|
||||
|
||||
normal = Normal(mean, std)
|
||||
|
||||
if reparam == True:
|
||||
x_t = mean + std * torch.randn(1,6)
|
||||
else:
|
||||
x_t = normal.sample()
|
||||
|
||||
action = torch.tanh(x_t)
|
||||
|
||||
log_prob = normal.log_prob(x_t) - torch.log(1 - action.pow(2) + epsilon)
|
||||
log_prob = log_prob.sum(-1, keepdim=True)
|
||||
|
||||
return action, log_prob, x_t, mean, log_std
|
||||
|
||||
def get_action(self, state):
|
||||
state = torch.FloatTensor(state).unsqueeze(0)
|
||||
_, _, x_t, _, _ = self.evaluate(state)
|
||||
action = torch.tanh(x_t)
|
||||
action = action.detach().cpu().numpy()
|
||||
return action[0]
|
||||
|
||||
|
||||
class GaussianMixturePolicy(nn.Module):
|
||||
def __init__(self, num_inputs, num_actions, hidden_size, k):
|
||||
super(GaussianMixturePolicy, self).__init__()
|
||||
self.actions = num_actions
|
||||
self.k = k
|
||||
self.log_std_max = LOG_SIG_MAX
|
||||
|
||||
self.linear1 = nn.Linear(num_inputs, hidden_size)
|
||||
self.linear2 = nn.Linear(hidden_size, hidden_size)
|
||||
|
||||
self.out_linear = nn.Linear(hidden_size, (k * 2 * self.actions) + k)
|
||||
|
||||
def forward(self, state):
|
||||
x = F.relu(self.linear1(state))
|
||||
x = F.relu(self.linear2(x))
|
||||
|
||||
out = self.out_linear(x)
|
||||
out = out.view(-1, self.k, (2 * self.actions) + 1)
|
||||
log_w = out[:, :, 0]
|
||||
mean = out[:, :, 1:1 + self.actions]
|
||||
log_std = torch.clamp(out[:, :, 1 + self.actions:], min=LOG_SIG_MIN, max=LOG_SIG_MAX)
|
||||
|
||||
return log_w, mean, log_std
|
||||
|
||||
def evaluate(self, state, reparam=False, epsilon=1e-6):
|
||||
log_w, mean, log_std = self.forward(state)
|
||||
std = log_std.exp()
|
||||
W = F.softmax(log_w, dim=1)
|
||||
pi_picked = torch.multinomial(W, num_samples=1)
|
||||
for i, r in enumerate(pi_picked):
|
||||
means = mean[:, r, :]
|
||||
means = means[:, 0, :]
|
||||
stds = std[:, r, :]
|
||||
stds = stds[:, 0, :]
|
||||
|
||||
|
||||
# We can only reparameterize if there was one component in the GMM,
|
||||
# in which case one should use GaussianPolicy
|
||||
normal = Normal(means, stds)
|
||||
x_t = normal.sample()
|
||||
action = torch.tanh(x_t)
|
||||
|
||||
log_prob = create_log_gaussian(mean, log_std, x_t[:, None, :]) - torch.log(1 - action.pow(2) + epsilon).sum(
|
||||
dim=-1, keepdim=True)
|
||||
log_prob = logsumexp(log_prob + log_w, dim=-1, keepdim=True)
|
||||
log_prob = log_prob - logsumexp(log_w, dim=-1, keepdim=True)
|
||||
return action, log_prob, x_t, mean, log_std
|
||||
|
||||
def get_action(self, state):
|
||||
state = torch.FloatTensor(state).unsqueeze(0)
|
||||
_, _, x_t, _, _ = self.evaluate(state)
|
||||
action = torch.tanh(x_t)
|
||||
|
||||
action = action.detach().cpu().numpy()
|
||||
return action[0]
|
||||
@@ -0,0 +1,16 @@
|
||||
import gym
|
||||
|
||||
|
||||
class NormalizedActions(gym.ActionWrapper):
|
||||
|
||||
def _action(self, action):
|
||||
action = (action + 1) / 2 # [-1, 1] => [0, 1]
|
||||
action *= (self.action_space.high - self.action_space.low)
|
||||
action += self.action_space.low
|
||||
return action
|
||||
|
||||
def _reverse_action(self, action):
|
||||
action -= self.action_space.low
|
||||
action /= (self.action_space.high - self.action_space.low)
|
||||
action = action * 2 - 1
|
||||
return actions
|
||||
@@ -0,0 +1,44 @@
|
||||
import plotly
|
||||
from plotly.graph_objs import Scatter, Line
|
||||
import torch
|
||||
|
||||
|
||||
steps = []
|
||||
def plot_line(xs, ys_population, algo):
|
||||
steps.append(xs)
|
||||
if algo == "SAC":
|
||||
colour = 'rgb(0, 172, 237)'
|
||||
elif algo == "SAC(GMM)":
|
||||
colour = 'rgb(0, 172, 12)'
|
||||
else:
|
||||
colour = 'rgb(172, 12, 0)'
|
||||
|
||||
ys = torch.Tensor(ys_population)
|
||||
|
||||
ys = ys.squeeze()
|
||||
|
||||
trace = Scatter(x=steps, y=ys.numpy(), line=Line(color=colour), name='Reward')
|
||||
|
||||
if algo == "SAC(GMM)":
|
||||
plotly.offline.plot({
|
||||
'data': [trace],
|
||||
'layout': dict(title='SAC(GMM)',
|
||||
xaxis={'title': 'Steps'},
|
||||
yaxis={'title': 'Reward'})
|
||||
}, filename='SAC(GMM).html', auto_open=False)
|
||||
elif algo == "SAC":
|
||||
plotly.offline.plot({
|
||||
'data': [trace],
|
||||
'layout': dict(title='SAC',
|
||||
xaxis={'title': 'Steps'},
|
||||
yaxis={'title': 'Reward'})
|
||||
}, filename='SAC.html', auto_open=False)
|
||||
else:
|
||||
plotly.offline.plot({
|
||||
'data': [trace],
|
||||
'layout': dict(title=algo,
|
||||
xaxis={'title': 'Steps'},
|
||||
yaxis={'title': 'Reward'})
|
||||
}, filename='{}.html'.format(algo), auto_open=False)
|
||||
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
import random
|
||||
import numpy as np
|
||||
from collections import namedtuple
|
||||
|
||||
|
||||
class ReplayMemory:
|
||||
def __init__(self, capacity):
|
||||
self.capacity = capacity
|
||||
self.buffer = []
|
||||
self.position = 0
|
||||
|
||||
def push(self, state, action, reward, next_state, done):
|
||||
if len(self.buffer) < self.capacity:
|
||||
self.buffer.append(None)
|
||||
self.buffer[self.position] = (state, action, reward, next_state, done)
|
||||
self.position = (self.position + 1) % self.capacity
|
||||
|
||||
def sample(self, batch_size):
|
||||
batch = random.sample(self.buffer, batch_size)
|
||||
state, action, reward, next_state, done = map(np.stack, zip(*batch))
|
||||
return state, action, reward, next_state, done
|
||||
|
||||
def __len__(self):
|
||||
return len(self.buffer)
|
||||
@@ -0,0 +1,128 @@
|
||||
import sys
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.optim import Adam
|
||||
from torch.distributions import MultivariateNormal
|
||||
from utils import soft_update, hard_update
|
||||
from model import GaussianMixturePolicy, GaussianPolicy, QNetwork, ValueNetwork
|
||||
|
||||
|
||||
class SAC(object):
|
||||
def __init__(self, num_inputs, action_space, args):
|
||||
|
||||
self.num_inputs = num_inputs
|
||||
self.action_space = action_space.shape[0]
|
||||
self.gamma = args.gamma
|
||||
self.tau = args.tau
|
||||
self.k = args.k
|
||||
self.scale_R = args.scale_R
|
||||
self.algo = args.algo
|
||||
self.reparam = args.reparam
|
||||
|
||||
if args.algo == "SAC":
|
||||
self.policy = GaussianPolicy(self.num_inputs, self.action_space, args.hidden_size)
|
||||
self.policy_optim = Adam(self.policy.parameters(), lr=3e-4)
|
||||
else:
|
||||
self.policy = GaussianMixturePolicy(self.num_inputs, self.action_space, args.hidden_size, self.k)
|
||||
self.policy_optim = Adam(self.policy.parameters(), lr=3e-4)
|
||||
|
||||
self.critic = QNetwork(self.num_inputs, self.action_space, args.hidden_size)
|
||||
self.critic_optim = Adam(self.critic.parameters(), lr=3e-4)
|
||||
|
||||
self.value = ValueNetwork(self.num_inputs, args.hidden_size)
|
||||
self.value_target = ValueNetwork(self.num_inputs, args.hidden_size)
|
||||
self.value_optim = Adam(self.value.parameters(), lr=3e-4)
|
||||
|
||||
self.value_criterion = nn.MSELoss()
|
||||
self.soft_q_criterion = nn.MSELoss()
|
||||
self.action_prior = "uniform"
|
||||
# Make sure target is with the same weight
|
||||
hard_update(self.value_target, self.value)
|
||||
|
||||
def select_action(self, state):
|
||||
action = self.policy.get_action(state)
|
||||
return action
|
||||
|
||||
|
||||
def update_parameters(self, state_batch, action_batch, reward_batch, next_state_batch, mask_batch, step):
|
||||
state_batch = torch.FloatTensor(state_batch)
|
||||
next_state_batch = torch.FloatTensor(next_state_batch)
|
||||
action_batch = torch.FloatTensor(action_batch)
|
||||
reward_batch = torch.FloatTensor(reward_batch)
|
||||
mask_batch = torch.FloatTensor(np.float32(mask_batch))
|
||||
|
||||
expected_q_value = self.critic(state_batch, action_batch)
|
||||
expected_value = self.value(state_batch)
|
||||
|
||||
new_action, log_prob, x_t, mean, log_std = self.policy.evaluate(state_batch, reparam=self.reparam)
|
||||
if self.action_prior == "normal":
|
||||
act = new_action
|
||||
act = act.size()
|
||||
policy_prior = MultivariateNormal(torch.zeros(act[-1]), torch.eye(act[-1]))
|
||||
policy_prior_log_probs = policy_prior.log_prob(new_action)
|
||||
policy_prior_log_probs = policy_prior_log_probs.unsqueeze(1)
|
||||
else:
|
||||
policy_prior_log_probs = 0.0
|
||||
|
||||
target_value = self.value_target(next_state_batch)
|
||||
reward_batch = reward_batch.unsqueeze(1)
|
||||
mask_batch = mask_batch.unsqueeze(1)
|
||||
next_q_value = self.scale_R * reward_batch + mask_batch * self.gamma * target_value
|
||||
q_value_loss = self.soft_q_criterion(expected_q_value, next_q_value.detach())
|
||||
|
||||
expected_new_q_value = self.critic(state_batch, new_action)
|
||||
next_value = expected_new_q_value - log_prob + policy_prior_log_probs
|
||||
value_loss = self.value_criterion(expected_value, next_value.detach())
|
||||
|
||||
log_prob_target = expected_new_q_value - expected_value
|
||||
if self.reparam == True and self.algo == "SAC":
|
||||
policy_loss = (log_prob - expected_new_q_value).mean()
|
||||
else:
|
||||
policy_loss = (log_prob * (log_prob - policy_prior_log_probs - log_prob_target).detach()).mean()
|
||||
|
||||
mean_loss = 0.001 * mean.pow(2).mean()
|
||||
std_loss = 0.001 * log_std.pow(2).mean()
|
||||
x_t_loss = 0.0 * x_t.pow(2).sum(1).mean()
|
||||
|
||||
policy_loss += mean_loss + std_loss + x_t_loss
|
||||
|
||||
self.critic_optim.zero_grad()
|
||||
q_value_loss.backward()
|
||||
self.critic_optim.step()
|
||||
|
||||
self.value_optim.zero_grad()
|
||||
value_loss.backward()
|
||||
self.value_optim.step()
|
||||
|
||||
self.policy_optim.zero_grad()
|
||||
policy_loss.backward()
|
||||
self.policy_optim.step()
|
||||
|
||||
soft_update(self.value_target, self.value, self.tau)
|
||||
|
||||
|
||||
def save_model(self, env_name, suffix="", actor_path=None, critic_path=None, value_path=None):
|
||||
if not os.path.exists('models/'):
|
||||
os.makedirs('models/')
|
||||
|
||||
if actor_path is None:
|
||||
actor_path = "models/sac_actor_{}_{}".format(env_name, suffix)
|
||||
if critic_path is None:
|
||||
critic_path = "models/sac_critic_{}_{}".format(env_name, suffix)
|
||||
if value_path is None:
|
||||
value_path = "models/sac_value_{}_{}".format(env_name, suffix)
|
||||
print('Saving models to {}, {} and {}'.format(actor_path, critic_path, value_path))
|
||||
torch.save(self.value.state_dict(), value_path)
|
||||
torch.save(self.policy.state_dict(), actor_path)
|
||||
torch.save(self.critic.state_dict(), critic_path)
|
||||
|
||||
def load_model(self, actor_path, critic_path, value_path):
|
||||
print('Loading models from {}, {} and {}'.format(actor_path, critic_path, value_path))
|
||||
if actor_path is not None:
|
||||
self.policy.load_state_dict(torch.load(actor_path))
|
||||
if critic_path is not None:
|
||||
self.critic.load_state_dict(torch.load(critic_path))
|
||||
if value_path is not None:
|
||||
self.value.load_state_dict(torch.load(value_path))
|
||||
@@ -0,0 +1,28 @@
|
||||
import math
|
||||
import torch
|
||||
|
||||
def create_log_gaussian(mean, log_std, t):
|
||||
quadratic = -((0.5 * (t - mean) / (log_std.exp())).pow(2))
|
||||
l = mean.shape
|
||||
log_z = log_std
|
||||
z = l[-1] * math.log(2 * math.pi)
|
||||
log_p = quadratic.sum(dim=-1) - log_z.sum(dim=-1) - 0.5 * z
|
||||
return log_p
|
||||
|
||||
def logsumexp(inputs, dim=None, keepdim=False):
|
||||
if dim is None:
|
||||
inputs = inputs.view(-1)
|
||||
dim = 0
|
||||
s, _ = torch.max(inputs, dim=dim, keepdim=True)
|
||||
outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log()
|
||||
if not keepdim:
|
||||
outputs = outputs.squeeze(dim)
|
||||
return outputs
|
||||
|
||||
def soft_update(target, source, tau):
|
||||
for target_param, param in zip(target.parameters(), source.parameters()):
|
||||
target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)
|
||||
|
||||
def hard_update(target, source):
|
||||
for target_param, param in zip(target.parameters(), source.parameters()):
|
||||
target_param.data.copy_(param.data)
|
||||
Reference in New Issue
Block a user