Add files via upload

This commit is contained in:
Pranjal Tandon
2018-08-31 17:25:08 +05:30
committed by GitHub
parent 55d30d5448
commit ba7609856d
7 changed files with 480 additions and 0 deletions
+85
View File
@@ -0,0 +1,85 @@
import argparse
import math
import gym
import numpy as np
from gym import wrappers
import torch
from sac import SAC
from plot import plot_line
from normalized_actions import NormalizedActions
from replay_memory import ReplayMemory
parser = argparse.ArgumentParser(description='PyTorch REINFORCE example')
parser.add_argument('--algo', default='SAC(GMM)',
help='algorithm to use: SAC | SAC(GMM)')
parser.add_argument('--env-name', default="HalfCheetah-v2",
help='name of the environment to run')
parser.add_argument('--reparam', type=bool, default=True,
help='reparameterize the policy (default:True)')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
help='discount factor for reward (default: 0.99)')
parser.add_argument('--tau', type=float, default=0.005, metavar='G',
help='target smoothing coefficient(τ) (default: 0.005)')
parser.add_argument('--k', type=int, default=4, metavar='G',
help='No. of Mixtures (default: 4)')
parser.add_argument('--scale_R', type=int, default=5, metavar='G',
help='reward scaling (default: 5)')
parser.add_argument('--seed', type=int, default=543, metavar='N',
help='random seed (default: 543)')
parser.add_argument('--batch_size', type=int, default=256, metavar='N',
help='batch size (default: 256)')
parser.add_argument('--num_steps', type=int, default=1000, metavar='N',
help='max episode length (default: 1000)')
parser.add_argument('--num_episodes', type=int, default=1000, metavar='N',
help='number of episodes (default: 1000)')
parser.add_argument('--hidden_size', type=int, default=256, metavar='N',
help='hidden size (default: 256)')
parser.add_argument('--updates_per_step', type=int, default=1, metavar='N',
help='model updates per simulator step (default: 1)')
parser.add_argument('--replay_size', type=int, default=1000000, metavar='N',
help='size of replay buffer (default: 10000000)')
args = parser.parse_args()
env = NormalizedActions(gym.make(args.env_name))
env.seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
agent = SAC(env.observation_space.shape[0], env.action_space, args)
memory = ReplayMemory(args.replay_size)
rewards = []
total_numsteps = 0
updates = 0
for i_episode in range(args.num_episodes):
state = env.reset()
episode_reward = 0
while True:
action = agent.select_action(state)
next_state, reward, done, _ = env.step(action)
mask = not done
memory.push(state, action, reward, next_state, mask)
if len(memory) > args.batch_size:
for i in range(args.updates_per_step):
state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(args.batch_size)
agent.update_parameters(state_batch, action_batch, reward_batch, next_state_batch, mask_batch, total_numsteps)
state = next_state
total_numsteps += 1
episode_reward += reward
if done:
break
rewards.append(episode_reward)
plot_line(total_numsteps, rewards, args.algo)
print("Episode: {}, total numsteps: {}, reward: {}, average reward: {}".format(i_episode, total_numsteps, np.round(rewards[-1],2),
np.round(np.mean(rewards[-100:]),2)))
env.close()
+155
View File
@@ -0,0 +1,155 @@
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from utils import create_log_gaussian, logsumexp
LOG_SIG_MAX = 2
LOG_SIG_MIN = -20
class ValueNetwork(nn.Module):
def __init__(self, state_dim, hidden_dim, init_w=1e-3):
super(ValueNetwork, self).__init__()
self.linear1 = nn.Linear(state_dim, hidden_dim)
self.linear2 = nn.Linear(hidden_dim, hidden_dim)
self.linear3 = nn.Linear(hidden_dim, 1)
self.linear3.weight.data.uniform_(-init_w, init_w)
self.linear3.bias.data.uniform_(-init_w, init_w)
def forward(self, state):
x = F.relu(self.linear1(state))
x = F.relu(self.linear2(x))
x = self.linear3(x)
return x
class QNetwork(nn.Module):
def __init__(self, num_inputs, num_actions, hidden_size, init_w=1e-3):
super(QNetwork, self).__init__()
self.linear1 = nn.Linear(num_inputs + num_actions, hidden_size)
self.linear2 = nn.Linear(hidden_size, hidden_size)
self.linear3 = nn.Linear(hidden_size, 1)
self.linear3.weight.data.uniform_(-init_w, init_w)
self.linear3.bias.data.uniform_(-init_w, init_w)
def forward(self, state, action):
x = torch.cat([state, action], 1)
x = F.relu(self.linear1(x))
x = F.relu(self.linear2(x))
x = self.linear3(x)
return x
class GaussianPolicy(nn.Module):
def __init__(self, num_inputs, num_actions, hidden_size, init_w=1e-3):
super(GaussianPolicy, self).__init__()
self.linear1 = nn.Linear(num_inputs, hidden_size)
self.linear2 = nn.Linear(hidden_size, hidden_size)
self.mean_linear = nn.Linear(hidden_size, num_actions)
self.mean_linear.weight.data.uniform_(-init_w, init_w)
self.mean_linear.bias.data.uniform_(-init_w, init_w)
self.log_std_linear = nn.Linear(hidden_size, num_actions)
self.log_std_linear.weight.data.uniform_(-init_w, init_w)
self.log_std_linear.bias.data.uniform_(-init_w, init_w)
def forward(self, state):
x = F.relu(self.linear1(state))
x = F.relu(self.linear2(x))
mean = self.mean_linear(x)
log_std = self.log_std_linear(x)
log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)
return mean, log_std
def evaluate(self, state, reparam=False, epsilon=1e-6):
mean, log_std = self.forward(state)
std = log_std.exp()
normal = Normal(mean, std)
if reparam == True:
x_t = mean + std * torch.randn(1,6)
else:
x_t = normal.sample()
action = torch.tanh(x_t)
log_prob = normal.log_prob(x_t) - torch.log(1 - action.pow(2) + epsilon)
log_prob = log_prob.sum(-1, keepdim=True)
return action, log_prob, x_t, mean, log_std
def get_action(self, state):
state = torch.FloatTensor(state).unsqueeze(0)
_, _, x_t, _, _ = self.evaluate(state)
action = torch.tanh(x_t)
action = action.detach().cpu().numpy()
return action[0]
class GaussianMixturePolicy(nn.Module):
def __init__(self, num_inputs, num_actions, hidden_size, k):
super(GaussianMixturePolicy, self).__init__()
self.actions = num_actions
self.k = k
self.log_std_max = LOG_SIG_MAX
self.linear1 = nn.Linear(num_inputs, hidden_size)
self.linear2 = nn.Linear(hidden_size, hidden_size)
self.out_linear = nn.Linear(hidden_size, (k * 2 * self.actions) + k)
def forward(self, state):
x = F.relu(self.linear1(state))
x = F.relu(self.linear2(x))
out = self.out_linear(x)
out = out.view(-1, self.k, (2 * self.actions) + 1)
log_w = out[:, :, 0]
mean = out[:, :, 1:1 + self.actions]
log_std = torch.clamp(out[:, :, 1 + self.actions:], min=LOG_SIG_MIN, max=LOG_SIG_MAX)
return log_w, mean, log_std
def evaluate(self, state, reparam=False, epsilon=1e-6):
log_w, mean, log_std = self.forward(state)
std = log_std.exp()
W = F.softmax(log_w, dim=1)
pi_picked = torch.multinomial(W, num_samples=1)
for i, r in enumerate(pi_picked):
means = mean[:, r, :]
means = means[:, 0, :]
stds = std[:, r, :]
stds = stds[:, 0, :]
# We can only reparameterize if there was one component in the GMM,
# in which case one should use GaussianPolicy
normal = Normal(means, stds)
x_t = normal.sample()
action = torch.tanh(x_t)
log_prob = create_log_gaussian(mean, log_std, x_t[:, None, :]) - torch.log(1 - action.pow(2) + epsilon).sum(
dim=-1, keepdim=True)
log_prob = logsumexp(log_prob + log_w, dim=-1, keepdim=True)
log_prob = log_prob - logsumexp(log_w, dim=-1, keepdim=True)
return action, log_prob, x_t, mean, log_std
def get_action(self, state):
state = torch.FloatTensor(state).unsqueeze(0)
_, _, x_t, _, _ = self.evaluate(state)
action = torch.tanh(x_t)
action = action.detach().cpu().numpy()
return action[0]
+16
View File
@@ -0,0 +1,16 @@
import gym
class NormalizedActions(gym.ActionWrapper):
def _action(self, action):
action = (action + 1) / 2 # [-1, 1] => [0, 1]
action *= (self.action_space.high - self.action_space.low)
action += self.action_space.low
return action
def _reverse_action(self, action):
action -= self.action_space.low
action /= (self.action_space.high - self.action_space.low)
action = action * 2 - 1
return actions
+44
View File
@@ -0,0 +1,44 @@
import plotly
from plotly.graph_objs import Scatter, Line
import torch
steps = []
def plot_line(xs, ys_population, algo):
steps.append(xs)
if algo == "SAC":
colour = 'rgb(0, 172, 237)'
elif algo == "SAC(GMM)":
colour = 'rgb(0, 172, 12)'
else:
colour = 'rgb(172, 12, 0)'
ys = torch.Tensor(ys_population)
ys = ys.squeeze()
trace = Scatter(x=steps, y=ys.numpy(), line=Line(color=colour), name='Reward')
if algo == "SAC(GMM)":
plotly.offline.plot({
'data': [trace],
'layout': dict(title='SAC(GMM)',
xaxis={'title': 'Steps'},
yaxis={'title': 'Reward'})
}, filename='SAC(GMM).html', auto_open=False)
elif algo == "SAC":
plotly.offline.plot({
'data': [trace],
'layout': dict(title='SAC',
xaxis={'title': 'Steps'},
yaxis={'title': 'Reward'})
}, filename='SAC.html', auto_open=False)
else:
plotly.offline.plot({
'data': [trace],
'layout': dict(title=algo,
xaxis={'title': 'Steps'},
yaxis={'title': 'Reward'})
}, filename='{}.html'.format(algo), auto_open=False)
+24
View File
@@ -0,0 +1,24 @@
import random
import numpy as np
from collections import namedtuple
class ReplayMemory:
def __init__(self, capacity):
self.capacity = capacity
self.buffer = []
self.position = 0
def push(self, state, action, reward, next_state, done):
if len(self.buffer) < self.capacity:
self.buffer.append(None)
self.buffer[self.position] = (state, action, reward, next_state, done)
self.position = (self.position + 1) % self.capacity
def sample(self, batch_size):
batch = random.sample(self.buffer, batch_size)
state, action, reward, next_state, done = map(np.stack, zip(*batch))
return state, action, reward, next_state, done
def __len__(self):
return len(self.buffer)
+128
View File
@@ -0,0 +1,128 @@
import sys
import os
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.distributions import MultivariateNormal
from utils import soft_update, hard_update
from model import GaussianMixturePolicy, GaussianPolicy, QNetwork, ValueNetwork
class SAC(object):
def __init__(self, num_inputs, action_space, args):
self.num_inputs = num_inputs
self.action_space = action_space.shape[0]
self.gamma = args.gamma
self.tau = args.tau
self.k = args.k
self.scale_R = args.scale_R
self.algo = args.algo
self.reparam = args.reparam
if args.algo == "SAC":
self.policy = GaussianPolicy(self.num_inputs, self.action_space, args.hidden_size)
self.policy_optim = Adam(self.policy.parameters(), lr=3e-4)
else:
self.policy = GaussianMixturePolicy(self.num_inputs, self.action_space, args.hidden_size, self.k)
self.policy_optim = Adam(self.policy.parameters(), lr=3e-4)
self.critic = QNetwork(self.num_inputs, self.action_space, args.hidden_size)
self.critic_optim = Adam(self.critic.parameters(), lr=3e-4)
self.value = ValueNetwork(self.num_inputs, args.hidden_size)
self.value_target = ValueNetwork(self.num_inputs, args.hidden_size)
self.value_optim = Adam(self.value.parameters(), lr=3e-4)
self.value_criterion = nn.MSELoss()
self.soft_q_criterion = nn.MSELoss()
self.action_prior = "uniform"
# Make sure target is with the same weight
hard_update(self.value_target, self.value)
def select_action(self, state):
action = self.policy.get_action(state)
return action
def update_parameters(self, state_batch, action_batch, reward_batch, next_state_batch, mask_batch, step):
state_batch = torch.FloatTensor(state_batch)
next_state_batch = torch.FloatTensor(next_state_batch)
action_batch = torch.FloatTensor(action_batch)
reward_batch = torch.FloatTensor(reward_batch)
mask_batch = torch.FloatTensor(np.float32(mask_batch))
expected_q_value = self.critic(state_batch, action_batch)
expected_value = self.value(state_batch)
new_action, log_prob, x_t, mean, log_std = self.policy.evaluate(state_batch, reparam=self.reparam)
if self.action_prior == "normal":
act = new_action
act = act.size()
policy_prior = MultivariateNormal(torch.zeros(act[-1]), torch.eye(act[-1]))
policy_prior_log_probs = policy_prior.log_prob(new_action)
policy_prior_log_probs = policy_prior_log_probs.unsqueeze(1)
else:
policy_prior_log_probs = 0.0
target_value = self.value_target(next_state_batch)
reward_batch = reward_batch.unsqueeze(1)
mask_batch = mask_batch.unsqueeze(1)
next_q_value = self.scale_R * reward_batch + mask_batch * self.gamma * target_value
q_value_loss = self.soft_q_criterion(expected_q_value, next_q_value.detach())
expected_new_q_value = self.critic(state_batch, new_action)
next_value = expected_new_q_value - log_prob + policy_prior_log_probs
value_loss = self.value_criterion(expected_value, next_value.detach())
log_prob_target = expected_new_q_value - expected_value
if self.reparam == True and self.algo == "SAC":
policy_loss = (log_prob - expected_new_q_value).mean()
else:
policy_loss = (log_prob * (log_prob - policy_prior_log_probs - log_prob_target).detach()).mean()
mean_loss = 0.001 * mean.pow(2).mean()
std_loss = 0.001 * log_std.pow(2).mean()
x_t_loss = 0.0 * x_t.pow(2).sum(1).mean()
policy_loss += mean_loss + std_loss + x_t_loss
self.critic_optim.zero_grad()
q_value_loss.backward()
self.critic_optim.step()
self.value_optim.zero_grad()
value_loss.backward()
self.value_optim.step()
self.policy_optim.zero_grad()
policy_loss.backward()
self.policy_optim.step()
soft_update(self.value_target, self.value, self.tau)
def save_model(self, env_name, suffix="", actor_path=None, critic_path=None, value_path=None):
if not os.path.exists('models/'):
os.makedirs('models/')
if actor_path is None:
actor_path = "models/sac_actor_{}_{}".format(env_name, suffix)
if critic_path is None:
critic_path = "models/sac_critic_{}_{}".format(env_name, suffix)
if value_path is None:
value_path = "models/sac_value_{}_{}".format(env_name, suffix)
print('Saving models to {}, {} and {}'.format(actor_path, critic_path, value_path))
torch.save(self.value.state_dict(), value_path)
torch.save(self.policy.state_dict(), actor_path)
torch.save(self.critic.state_dict(), critic_path)
def load_model(self, actor_path, critic_path, value_path):
print('Loading models from {}, {} and {}'.format(actor_path, critic_path, value_path))
if actor_path is not None:
self.policy.load_state_dict(torch.load(actor_path))
if critic_path is not None:
self.critic.load_state_dict(torch.load(critic_path))
if value_path is not None:
self.value.load_state_dict(torch.load(value_path))
+28
View File
@@ -0,0 +1,28 @@
import math
import torch
def create_log_gaussian(mean, log_std, t):
quadratic = -((0.5 * (t - mean) / (log_std.exp())).pow(2))
l = mean.shape
log_z = log_std
z = l[-1] * math.log(2 * math.pi)
log_p = quadratic.sum(dim=-1) - log_z.sum(dim=-1) - 0.5 * z
return log_p
def logsumexp(inputs, dim=None, keepdim=False):
if dim is None:
inputs = inputs.view(-1)
dim = 0
s, _ = torch.max(inputs, dim=dim, keepdim=True)
outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log()
if not keepdim:
outputs = outputs.squeeze(dim)
return outputs
def soft_update(target, source, tau):
for target_param, param in zip(target.parameters(), source.parameters()):
target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)
def hard_update(target, source):
for target_param, param in zip(target.parameters(), source.parameters()):
target_param.data.copy_(param.data)