def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``type=bool`` does not work with argparse: the raw argument is a string
    and ``bool("False")`` is ``True`` (any non-empty string is truthy), so
    ``--reparam False`` could never switch the option off.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept as False, which is what bool('') produced before.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


# Command-line options of the training script. Parsing happens at import
# time, exactly as before, and the result is exposed as the module-level
# ``args`` namespace.
parser = argparse.ArgumentParser(description='PyTorch Soft Actor-Critic example')
parser.add_argument('--algo', default='SAC(GMM)',
                    help='algorithm to use: SAC | SAC(GMM)')
parser.add_argument('--env-name', default="HalfCheetah-v2",
                    help='name of the environment to run')
parser.add_argument('--reparam', type=_str2bool, default=True,
                    help='reparameterize the policy (default:True)')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='discount factor for reward (default: 0.99)')
parser.add_argument('--tau', type=float, default=0.005, metavar='G',
                    help='target smoothing coefficient(τ) (default: 0.005)')
parser.add_argument('--k', type=int, default=4, metavar='G',
                    help='No. of Mixtures (default: 4)')
parser.add_argument('--scale_R', type=int, default=5, metavar='G',
                    help='reward scaling (default: 5)')
parser.add_argument('--seed', type=int, default=543, metavar='N',
                    help='random seed (default: 543)')
parser.add_argument('--batch_size', type=int, default=256, metavar='N',
                    help='batch size (default: 256)')
parser.add_argument('--num_steps', type=int, default=1000, metavar='N',
                    help='max episode length (default: 1000)')
parser.add_argument('--num_episodes', type=int, default=1000, metavar='N',
                    help='number of episodes (default: 1000)')
parser.add_argument('--hidden_size', type=int, default=256, metavar='N',
                    help='hidden size (default: 256)')
parser.add_argument('--updates_per_step', type=int, default=1, metavar='N',
                    help='model updates per simulator step (default: 1)')
# The help text used to advertise 10000000 while the real default is 1000000.
parser.add_argument('--replay_size', type=int, default=1000000, metavar='N',
                    help='size of replay buffer (default: 1000000)')
args = parser.parse_args()
class GaussianPolicy(nn.Module):
    """Tanh-squashed diagonal Gaussian policy.

    Two ReLU hidden layers feed two linear heads: one for the mean and one
    for the log standard deviation of the pre-squash action distribution.

    Args:
        num_inputs: size of the state vector fed to the first layer.
        num_actions: size of the action vector produced by both heads.
        hidden_size: width of the two hidden layers.
        init_w: both heads are initialised uniformly in ``[-init_w, init_w]``.
    """

    def __init__(self, num_inputs, num_actions, hidden_size, init_w=1e-3):
        super(GaussianPolicy, self).__init__()

        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)

        self.mean_linear = nn.Linear(hidden_size, num_actions)
        self.mean_linear.weight.data.uniform_(-init_w, init_w)
        self.mean_linear.bias.data.uniform_(-init_w, init_w)

        self.log_std_linear = nn.Linear(hidden_size, num_actions)
        self.log_std_linear.weight.data.uniform_(-init_w, init_w)
        self.log_std_linear.bias.data.uniform_(-init_w, init_w)

    def forward(self, state):
        """Return ``(mean, log_std)``; ``log_std`` is clamped to
        ``[LOG_SIG_MIN, LOG_SIG_MAX]``."""
        x = F.relu(self.linear1(state))
        x = F.relu(self.linear2(x))

        mean = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)

        return mean, log_std

    def evaluate(self, state, reparam=False, epsilon=1e-6):
        """Sample an action and its log-probability for every state in a batch.

        Args:
            state: batch of states accepted by ``forward``.
            reparam: when true the sample is drawn with ``rsample`` so that
                gradients flow back into ``mean`` and ``std``; otherwise a
                non-differentiable ``sample`` is used.
            epsilon: added inside the log of the tanh correction to keep it
                finite when ``action`` is close to +/-1.

        Returns:
            ``(action, log_prob, x_t, mean, log_std)`` where ``x_t`` is the
            pre-squash sample, ``action = tanh(x_t)`` and ``log_prob`` is
            summed over the last axis with ``keepdim=True``.
        """
        mean, log_std = self.forward(state)
        std = log_std.exp()

        normal = Normal(mean, std)

        if reparam:
            # rsample() draws independent noise shaped like `mean`. The
            # previous `torch.randn(1, 6)` hard-coded a 6-dimensional action
            # space and reused one noise vector for every row of the batch.
            x_t = normal.rsample()
        else:
            x_t = normal.sample()

        action = torch.tanh(x_t)

        # Change of variables for the tanh squashing.
        log_prob = normal.log_prob(x_t) - torch.log(1 - action.pow(2) + epsilon)
        log_prob = log_prob.sum(-1, keepdim=True)

        return action, log_prob, x_t, mean, log_std

    def get_action(self, state):
        """Return one sampled action for a single (unbatched) state as a NumPy array."""
        state = torch.FloatTensor(state).unsqueeze(0)
        # Acting never needs gradients, so skip building the autograd graph.
        with torch.no_grad():
            action, _, _, _, _ = self.evaluate(state)
        return action.cpu().numpy()[0]
class NormalizedActions(gym.ActionWrapper):
    """Action wrapper that lets the agent act in ``[-1, 1]``.

    ``_action`` maps an agent action linearly onto the wrapped action space
    bounds ``[low, high]``; ``_reverse_action`` is the inverse mapping.
    Neither method modifies the array it is given.
    """

    def _action(self, action):
        """Map ``action`` from ``[-1, 1]`` to ``[low, high]``."""
        low = self.action_space.low
        high = self.action_space.high
        # [-1, 1] => [0, 1] => [low, high]
        return (action + 1) / 2 * (high - low) + low

    def _reverse_action(self, action):
        """Map ``action`` from ``[low, high]`` back to ``[-1, 1]``.

        The previous version returned the undefined name ``actions``
        (NameError) and rescaled the caller's array in place.
        """
        low = self.action_space.low
        high = self.action_space.high
        return (action - low) / (high - low) * 2 - 1

    # NOTE(review): newer gym releases appear to dispatch to the public
    # ``action`` / ``reverse_action`` names instead of the underscore ones;
    # these aliases keep the wrapper usable with either spelling. Confirm
    # against the installed gym version.
    def action(self, action):
        """Public alias of ``_action``."""
        return self._action(action)

    def reverse_action(self, action):
        """Public alias of ``_reverse_action``."""
        return self._reverse_action(action)
class ReplayMemory:
    """Fixed-capacity ring buffer of transitions with uniform random sampling.

    Once ``capacity`` transitions are stored, each new push overwrites the
    oldest slot.
    """

    def __init__(self, capacity):
        self.position = 0  # index of the slot the next push writes to
        self.buffer = []
        self.capacity = capacity

    def push(self, state, action, reward, next_state, done):
        """Store one transition tuple, overwriting the oldest when full."""
        transition = (state, action, reward, next_state, done)
        slot = self.position
        if len(self.buffer) < self.capacity:
            # Still filling up: open a fresh slot before writing into it.
            self.buffer.append(None)
        self.buffer[slot] = transition
        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions and stack them field by field.

        Returns a 5-tuple of NumPy arrays
        ``(state, action, reward, next_state, done)``.
        """
        picked = random.sample(self.buffer, batch_size)
        columns = [np.stack(column) for column in zip(*picked)]
        state, action, reward, next_state, done = columns
        return state, action, reward, next_state, done

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)
hard_update +from model import GaussianMixturePolicy, GaussianPolicy, QNetwork, ValueNetwork + + +class SAC(object): + def __init__(self, num_inputs, action_space, args): + + self.num_inputs = num_inputs + self.action_space = action_space.shape[0] + self.gamma = args.gamma + self.tau = args.tau + self.k = args.k + self.scale_R = args.scale_R + self.algo = args.algo + self.reparam = args.reparam + + if args.algo == "SAC": + self.policy = GaussianPolicy(self.num_inputs, self.action_space, args.hidden_size) + self.policy_optim = Adam(self.policy.parameters(), lr=3e-4) + else: + self.policy = GaussianMixturePolicy(self.num_inputs, self.action_space, args.hidden_size, self.k) + self.policy_optim = Adam(self.policy.parameters(), lr=3e-4) + + self.critic = QNetwork(self.num_inputs, self.action_space, args.hidden_size) + self.critic_optim = Adam(self.critic.parameters(), lr=3e-4) + + self.value = ValueNetwork(self.num_inputs, args.hidden_size) + self.value_target = ValueNetwork(self.num_inputs, args.hidden_size) + self.value_optim = Adam(self.value.parameters(), lr=3e-4) + + self.value_criterion = nn.MSELoss() + self.soft_q_criterion = nn.MSELoss() + self.action_prior = "uniform" + # Make sure target is with the same weight + hard_update(self.value_target, self.value) + + def select_action(self, state): + action = self.policy.get_action(state) + return action + + + def update_parameters(self, state_batch, action_batch, reward_batch, next_state_batch, mask_batch, step): + state_batch = torch.FloatTensor(state_batch) + next_state_batch = torch.FloatTensor(next_state_batch) + action_batch = torch.FloatTensor(action_batch) + reward_batch = torch.FloatTensor(reward_batch) + mask_batch = torch.FloatTensor(np.float32(mask_batch)) + + expected_q_value = self.critic(state_batch, action_batch) + expected_value = self.value(state_batch) + + new_action, log_prob, x_t, mean, log_std = self.policy.evaluate(state_batch, reparam=self.reparam) + if self.action_prior == "normal": + act = 
new_action + act = act.size() + policy_prior = MultivariateNormal(torch.zeros(act[-1]), torch.eye(act[-1])) + policy_prior_log_probs = policy_prior.log_prob(new_action) + policy_prior_log_probs = policy_prior_log_probs.unsqueeze(1) + else: + policy_prior_log_probs = 0.0 + + target_value = self.value_target(next_state_batch) + reward_batch = reward_batch.unsqueeze(1) + mask_batch = mask_batch.unsqueeze(1) + next_q_value = self.scale_R * reward_batch + mask_batch * self.gamma * target_value + q_value_loss = self.soft_q_criterion(expected_q_value, next_q_value.detach()) + + expected_new_q_value = self.critic(state_batch, new_action) + next_value = expected_new_q_value - log_prob + policy_prior_log_probs + value_loss = self.value_criterion(expected_value, next_value.detach()) + + log_prob_target = expected_new_q_value - expected_value + if self.reparam == True and self.algo == "SAC": + policy_loss = (log_prob - expected_new_q_value).mean() + else: + policy_loss = (log_prob * (log_prob - policy_prior_log_probs - log_prob_target).detach()).mean() + + mean_loss = 0.001 * mean.pow(2).mean() + std_loss = 0.001 * log_std.pow(2).mean() + x_t_loss = 0.0 * x_t.pow(2).sum(1).mean() + + policy_loss += mean_loss + std_loss + x_t_loss + + self.critic_optim.zero_grad() + q_value_loss.backward() + self.critic_optim.step() + + self.value_optim.zero_grad() + value_loss.backward() + self.value_optim.step() + + self.policy_optim.zero_grad() + policy_loss.backward() + self.policy_optim.step() + + soft_update(self.value_target, self.value, self.tau) + + + def save_model(self, env_name, suffix="", actor_path=None, critic_path=None, value_path=None): + if not os.path.exists('models/'): + os.makedirs('models/') + + if actor_path is None: + actor_path = "models/sac_actor_{}_{}".format(env_name, suffix) + if critic_path is None: + critic_path = "models/sac_critic_{}_{}".format(env_name, suffix) + if value_path is None: + value_path = "models/sac_value_{}_{}".format(env_name, suffix) + 
def create_log_gaussian(mean, log_std, t):
    """Log-density of ``t`` under a diagonal Gaussian, summed over the last axis.

    Args:
        mean: Gaussian means; the size of its last axis is the dimensionality.
        log_std: log standard deviations, same shape as ``mean``.
        t: points to evaluate; must broadcast against ``mean``.

    Returns:
        Tensor with the last axis reduced away:
        ``sum(-0.5 * z**2 - log_std) - 0.5 * D * log(2 * pi)`` with
        ``z = (t - mean) / exp(log_std)`` and ``D = mean.shape[-1]``.
    """
    # The 0.5 belongs outside the square. It used to sit inside `.pow(2)`,
    # which turned the exponent into -0.25 * z**2.
    normalized = (t - mean) / log_std.exp()
    quadratic = -0.5 * normalized.pow(2)
    log_norm = 0.5 * mean.shape[-1] * math.log(2 * math.pi)
    return quadratic.sum(dim=-1) - log_std.sum(dim=-1) - log_norm