diff --git a/README.md b/README.md index 5e2c36a..157b968 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # pytorch-a2c-ppo-acktr +## Update 09/27/2017: now supports both Atari and MuJoCo/Roboschool! + This is a PyTorch implementation of * Advantage Actor Critic (A2C), a synchronous deterministic version of [A3C](https://arxiv.org/pdf/1602.01783v1.pdf) * Proximal Policy Optimization [PPO](https://arxiv.org/pdf/1707.06347.pdf) @@ -13,47 +15,85 @@ This implementation is inspired by the OpenAI baselines for [A2C](https://github Contributions are very welcome. If you know how to make this code better, don't hesitate to send a pull request. Also see a todo list below. +Also I'm searching for volunteers to run all experiments on Atari and MuJoCo (with multiple random seeds). + +## Disclaimer + +It's extremely difficult to reproduce results for Reinforcement Learning methods. See ["Deep Reinforcement Learning that Matters"](https://arxiv.org/abs/1709.06560) for more information. I tried to reproduce OpenAI results as closely as possible. However, majors differences in performance can be caused even by minor differences in TensorFlow and PyTorch libraries. + ### TODO -* Add MuJoCo and continuous actions +* Improve this README file. Rearrange images. * Improve performance of KFAC, see kfac.py for more information * Run evaluation for all games and algorithms ## Usage -### A2C +### Atari +#### A2C ``` python main.py --env-name "PongNoFrameskip-v4" ``` -### PPO +#### PPO ``` python main.py --env-name "PongNoFrameskip-v4" --algo ppo --use-gae --num-processes 8 --num-steps 256 --vis-interval 1 --log-interval 1 ``` -### ACKTR +#### ACKTR ``` python main.py --env-name "PongNoFrameskip-v4" --algo acktr --num-processes 32 --num-steps 20 ``` +### MuJoCo +#### A2C + +``` +python main.py --env-name "Reacher-v1" --num-stack 1 --num-frames 1000000 +``` + +#### PPO + +``` +python main.py --env-name "Reacher-v1" --algo ppo --use-gae --vis-interval 1 --log-interval 1 --num-stack 1 --num-steps 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --ppo-epoch 10 --batch-size 64 --gamma 0.99 --tau 0.95 --num-frames 1000000 +``` + +#### ACKTR + +ACKTR requires some modifications to be made specifically for MuJoCo. But at the moment, I want to keep this code as unified as possible. Thus, I'm going for better ways to integrate it into the codebase. + ## Results ### A2C -![BreakoutNoFrameskip-v4](imgs/breakout.png) +![BreakoutNoFrameskip-v4](imgs/a2c_breakout.png) -![SeaquestNoFrameskip-v4](imgs/seaquest.png) +![SeaquestNoFrameskip-v4](imgs/a2c_seaquest.png) -![QbertNoFrameskip-v4](imgs/qbert.png) +![QbertNoFrameskip-v4](imgs/a2c_qbert.png) -![beamriderNoFrameskip-v4](imgs/beamrider.png) +![beamriderNoFrameskip-v4](imgs/a2c_beamrider.png) ### PPO -Coming soon. + +![BreakoutNoFrameskip-v4](imgs/ppo_halfcheetah.png) + +![SeaquestNoFrameskip-v4](imgs/ppo_hopper.png) + +![QbertNoFrameskip-v4](imgs/ppo_reacher.png) + +![beamriderNoFrameskip-v4](imgs/ppo_walker.png) + ### ACKTR -Coming soon. +![BreakoutNoFrameskip-v4](imgs/acktr_breakout.png) + +![SeaquestNoFrameskip-v4](imgs/acktr_seaquest.png) + +![QbertNoFrameskip-v4](imgs/acktr_qbert.png) + +![beamriderNoFrameskip-v4](imgs/acktr_beamrider.png) diff --git a/envs.py b/envs.py index 1d221a7..626bb15 100755 --- a/envs.py +++ b/envs.py @@ -4,7 +4,7 @@ import gym from gym.spaces.box import Box from baselines import bench -from baselines.common.atari_wrappers import * +from baselines.common.atari_wrappers import wrap_deepmind def make_env(env_id, seed, rank, log_dir): @@ -14,8 +14,10 @@ def make_env(env_id, seed, rank, log_dir): env = bench.Monitor(env, os.path.join(log_dir, "{}.monitor.json".format(rank))) - env = wrap_deepmind(env) - env = WrapPyTorch(env) + # Ugly hack to detect atari. + if env.action_space.__class__.__name__ == 'Discrete': + env = wrap_deepmind(env) + env = WrapPyTorch(env) return env return _thunk diff --git a/imgs/beamrider.png b/imgs/a2c_beamrider.png similarity index 100% rename from imgs/beamrider.png rename to imgs/a2c_beamrider.png diff --git a/imgs/breakout.png b/imgs/a2c_breakout.png similarity index 100% rename from imgs/breakout.png rename to imgs/a2c_breakout.png diff --git a/imgs/qbert.png b/imgs/a2c_qbert.png similarity index 100% rename from imgs/qbert.png rename to imgs/a2c_qbert.png diff --git a/imgs/seaquest.png b/imgs/a2c_seaquest.png similarity index 100% rename from imgs/seaquest.png rename to imgs/a2c_seaquest.png diff --git a/imgs/acktr_beamrider.png b/imgs/acktr_beamrider.png new file mode 100644 index 0000000..c95f3e4 Binary files /dev/null and b/imgs/acktr_beamrider.png differ diff --git a/imgs/acktr_breakout.png b/imgs/acktr_breakout.png new file mode 100644 index 0000000..d6264df Binary files /dev/null and b/imgs/acktr_breakout.png differ diff --git a/imgs/acktr_qbert.png b/imgs/acktr_qbert.png new file mode 100644 index 0000000..ef2472c Binary files /dev/null and b/imgs/acktr_qbert.png differ diff --git a/imgs/acktr_seaquest.png b/imgs/acktr_seaquest.png new file mode 100644 index 0000000..a009c88 Binary files /dev/null and b/imgs/acktr_seaquest.png differ diff --git a/imgs/ppo_halfcheetah.png b/imgs/ppo_halfcheetah.png new file mode 100644 index 0000000..0719534 Binary files /dev/null and b/imgs/ppo_halfcheetah.png differ diff --git a/imgs/ppo_hopper.png b/imgs/ppo_hopper.png new file mode 100644 index 0000000..d9075d1 Binary files /dev/null and b/imgs/ppo_hopper.png differ diff --git a/imgs/ppo_reacher.png b/imgs/ppo_reacher.png new file mode 100644 index 0000000..5927b33 Binary files /dev/null and b/imgs/ppo_reacher.png differ diff --git a/imgs/ppo_walker.png b/imgs/ppo_walker.png new file mode 100644 index 0000000..5ab33b1 Binary files /dev/null and b/imgs/ppo_walker.png differ diff --git a/kfac.py b/kfac.py index f15e6a6..19f875a 100644 --- a/kfac.py +++ b/kfac.py @@ -11,8 +11,6 @@ import torch.optim as optim def _extract_patches(x, kernel_size, stride, padding): - #result = P.im2col(Variable(x), kernel_size, stride, padding).data - #return result.view(result.size(0), -1, result.size(-2), result.size(-1)) if padding[0] + padding[1] > 0: x = F.pad(x, (padding[1], padding[1], padding[0], padding[0])).data # Actually check dims @@ -164,7 +162,6 @@ class KFACOptimizer(optim.Optimizer): raise NotImplementedError( 'Layer {} is not supported'.format(classname)) - #@profile def step(self): # Add weight decay if self.weight_decay > 0: @@ -187,10 +184,12 @@ class KFACOptimizer(optim.Optimizer): self.m_aa[m].cpu().double(), eigenvectors=True) self.d_g[m], self.Q_g[m] = torch.symeig( self.m_gg[m].cpu().double(), eigenvectors=True) - self.d_a[m], self.Q_a[m] = self.d_a[ - m].float().cuda(), self.Q_a[m].float().cuda() - self.d_g[m], self.Q_g[m] = self.d_g[ - m].float().cuda(), self.Q_g[m].float().cuda() + self.d_a[m], self.Q_a[m] = self.d_a[m].float(), self.Q_a[m].float() + self.d_g[m], self.Q_g[m] = self.d_g[m].float(), self.Q_g[m].float() + if self.m_aa[m].is_cuda: + self.d_a[m], self.Q_a[m] = self.d_a[m].cuda(), self.Q_a[m].cuda() + self.d_g[m], self.Q_g[m] = self.d_g[m].cuda(), self.Q_g[m].cuda() + self.d_a[m].mul_((self.d_a[m] > 1e-6).float()) self.d_g[m].mul_((self.d_g[m] > 1e-6).float()) diff --git a/main.py b/main.py index 1fbacfc..f09877c 100755 --- a/main.py +++ b/main.py @@ -15,9 +15,9 @@ from arguments import get_args from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv from envs import make_env from kfac import KFACOptimizer -from model import ActorCritic +from model import CNNPolicy, MLPPolicy from storage import RolloutStorage -from vizualize_atari import visdom_plot +from visualize import visdom_plot args = get_args() @@ -59,7 +59,12 @@ def main(): obs_shape = envs.observation_space.shape obs_shape = (obs_shape[0] * args.num_stack, *obs_shape[1:]) - actor_critic = ActorCritic(obs_shape[0], envs.action_space) + if envs.action_space.__class__.__name__ == 'Discrete': + actor_critic = CNNPolicy(obs_shape[0], envs.action_space) + action_shape = 1 + else: + actor_critic = MLPPolicy(obs_shape[0], envs.action_space) + action_shape = envs.action_space.shape[0] if args.cuda: actor_critic.cuda() @@ -71,13 +76,15 @@ def main(): elif args.algo == 'acktr': optimizer = KFACOptimizer(actor_critic) - rollouts = RolloutStorage(args.num_steps, args.num_processes, obs_shape, envs.action_space.n) - + rollouts = RolloutStorage(args.num_steps, args.num_processes, obs_shape, envs.action_space) current_state = torch.zeros(args.num_processes, *obs_shape) + def update_current_state(state): - state = torch.from_numpy(np.stack(state)).float() - current_state[:, :-1] = current_state[:, 1:] - current_state[:, -1] = state + shape_dim0 = envs.observation_space.shape[0] + state = torch.from_numpy(state).float() + if args.num_stack > 1: + current_state[:, :-shape_dim0] = current_state[:, shape_dim0:] + current_state[:, -shape_dim0:] = state state = envs.reset() update_current_state(state) @@ -103,7 +110,6 @@ def main(): # Obser reward and next state state, reward, done, info = envs.step(cpu_actions) - reward = torch.from_numpy(np.expand_dims(np.stack(reward), 1)).float() episode_rewards += reward @@ -115,17 +121,24 @@ def main(): if args.cuda: masks = masks.cuda() - current_state *= masks.unsqueeze(2).unsqueeze(2) + + if current_state.dim() == 4: + current_state *= masks.unsqueeze(2).unsqueeze(2) + else: + current_state *= masks update_current_state(state) rollouts.insert(step, current_state, action.data, value.data, reward, masks) next_value = actor_critic(Variable(rollouts.states[-1], volatile=True))[0].data + if hasattr(actor_critic, 'obs_filter'): + actor_critic.obs_filter.update(rollouts.states[:-1].view(-1, *obs_shape)) + rollouts.compute_returns(next_value, args.use_gae, args.gamma, args.tau) if args.algo in ['a2c', 'acktr']: - values, action_log_probs, dist_entropy = actor_critic.evaluate_actions(Variable(rollouts.states[:-1].view(-1, *obs_shape)), Variable(rollouts.actions.view(-1, 1))) + values, action_log_probs, dist_entropy = actor_critic.evaluate_actions(Variable(rollouts.states[:-1].view(-1, *obs_shape)), Variable(rollouts.actions.view(-1, action_shape))) values = values.view(args.num_steps, args.num_processes, 1) action_log_probs = action_log_probs.view(args.num_steps, args.num_processes, 1) @@ -164,6 +177,8 @@ def main(): advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-5) old_model.load_state_dict(actor_critic.state_dict()) + if hasattr(actor_critic, 'obs_filter'): + old_model.obs_filter = actor_critic.obs_filter for _ in range(args.ppo_epoch): sampler = BatchSampler(SubsetRandomSampler(range(args.num_processes * args.num_steps)), args.batch_size * args.num_processes, drop_last=False) @@ -171,8 +186,8 @@ def main(): indices = torch.LongTensor(indices) if args.cuda: indices = indices.cuda() - states_batch = rollouts.states[:-1].view(-1, *rollouts.states.size()[-3:])[indices] - actions_batch = rollouts.actions.view(-1, 1)[indices] + states_batch = rollouts.states[:-1].view(-1, *obs_shape)[indices] + actions_batch = rollouts.actions.view(-1, action_shape)[indices] return_batch = rollouts.returns[:-1].view(-1, 1)[indices] # Reshape to do in a single forward pass for all steps @@ -183,7 +198,7 @@ def main(): ratio = torch.exp(action_log_probs - Variable(old_action_log_probs.data)) adv_targ = Variable(advantages.view(-1, 1)[indices]) surr1 = ratio * adv_targ - surr2 = ratio.clamp(1.0 - args.clip_param, 1.0 + args.clip_param) * adv_targ + surr2 = torch.clamp(ratio, 1.0 - args.clip_param, 1.0 + args.clip_param) * adv_targ action_loss = -torch.min(surr1, surr2).mean() # PPO's pessimistic surrogate (L^CLIP) value_loss = (Variable(return_batch) - values).pow(2).mean() diff --git a/model.py b/model.py index 515c8af..918f066 100755 --- a/model.py +++ b/model.py @@ -3,6 +3,8 @@ import math import torch import torch.nn as nn import torch.nn.functional as F +from torch.autograd import Variable +from running_stat import ObsNorm def weights_init(m): @@ -28,9 +30,9 @@ class AddBias(nn.Module): return x + bias -class ActorCritic(torch.nn.Module): +class CNNPolicy(torch.nn.Module): def __init__(self, num_inputs, action_space): - super(ActorCritic, self).__init__() + super(CNNPolicy, self).__init__() self.conv1 = nn.Conv2d(num_inputs, 32, 8, stride=4, bias=False) self.ab1 = AddBias(32) self.conv2 = nn.Conv2d(32, 64, 4, stride=2, bias=False) @@ -41,19 +43,20 @@ class ActorCritic(torch.nn.Module): self.linear1 = nn.Linear(32 * 7 * 7, 512, bias=False) self.ab_fc1 = AddBias(512) - num_outputs = action_space.n self.critic_linear = nn.Linear(512, 1, bias=False) self.ab_fc2 = AddBias(1) + num_outputs = action_space.n self.actor_linear = nn.Linear(512, num_outputs, bias=False) self.ab_fc3 = AddBias(num_outputs) self.apply(weights_init) - self.conv1.weight.data.mul_(math.sqrt(2)) # Multiplier for relu - self.conv2.weight.data.mul_(math.sqrt(2)) # Multiplier for relu - self.conv3.weight.data.mul_(math.sqrt(2)) # Multiplier for relu - self.linear1.weight.data.mul_(math.sqrt(2)) # Multiplier for relu + relu_gain = nn.init.calculate_gain('relu') + self.conv1.weight.data.mul_(relu_gain) + self.conv2.weight.data.mul_(relu_gain) + self.conv3.weight.data.mul_(relu_gain) + self.linear1.weight.data.mul_(relu_gain) self.train() @@ -97,3 +100,112 @@ class ActorCritic(torch.nn.Module): dist_entropy = -(log_probs * probs).sum(-1).mean() return values, action_log_probs, dist_entropy + + +def weights_init_mlp(m): + classname = m.__class__.__name__ + if classname.find('Linear') != -1: + m.weight.data.normal_(0, 1) + m.weight.data *= 1 / torch.sqrt(m.weight.data.pow(2).sum(1, keepdim=True)) + if m.bias is not None: + m.bias.data.fill_(0) + +class MLPPolicy(torch.nn.Module): + def __init__(self, num_inputs, action_space): + super(MLPPolicy, self).__init__() + + self.obs_filter = ObsNorm((1, num_inputs), clip=5) + self.action_space = action_space + + self.a_fc1 = nn.Linear(num_inputs, 64, bias=False) + self.a_ab1 = AddBias(64) + self.a_fc2 = nn.Linear(64, 64, bias=False) + self.a_ab2 = AddBias(64) + self.a_fc_mean = nn.Linear(64, action_space.shape[0], bias=False) + self.a_ab_mean = AddBias(action_space.shape[0]) + self.a_ab_logstd = AddBias(action_space.shape[0]) + + self.v_fc1 = nn.Linear(num_inputs, 64, bias=False) + self.v_ab1 = AddBias(64) + self.v_fc2 = nn.Linear(64, 64, bias=False) + self.v_ab2 = AddBias(64) + self.v_fc3 = nn.Linear(64, 1, bias=False) + self.v_ab3 = AddBias(1) + + self.apply(weights_init_mlp) + + tanh_gain = nn.init.calculate_gain('tanh') + #self.a_fc1.weight.data.mul_(tanh_gain) + #self.a_fc2.weight.data.mul_(tanh_gain) + self.a_fc_mean.weight.data.mul_(0.01) + #self.v_fc1.weight.data.mul_(tanh_gain) + #self.v_fc2.weight.data.mul_(tanh_gain) + + self.train() + + def cuda(self, **args): + super(MLPPolicy, self).cuda(**args) + self.obs_filter.cuda() + + def forward(self, inputs): + inputs.data = self.obs_filter(inputs.data) + + x = self.v_fc1(inputs) + x = self.v_ab1(x) + x = F.tanh(x) + + x = self.v_fc2(x) + x = self.v_ab2(x) + x = F.tanh(x) + + x = self.v_fc3(x) + x = self.v_ab3(x) + value = x + + x = self.a_fc1(inputs) + x = self.a_ab1(x) + x = F.tanh(x) + + x = self.a_fc2(x) + x = self.a_ab2(x) + x = F.tanh(x) + + x = self.a_fc_mean(x) + x = self.a_ab_mean(x) + action_mean = x + + # An ugly hack for my KFAC implementation. + zeros = Variable(torch.zeros(x.size()), volatile=x.volatile) + if x.is_cuda: + zeros = zeros.cuda() + + x = self.a_ab_logstd(zeros) + action_logstd = x + + return value, action_mean, action_logstd + + def act(self, inputs): + value, action_mean, action_logstd = self(inputs) + + action_std = action_logstd.exp() + + noise = Variable(torch.randn(action_std.size())) + if action_std.is_cuda: + noise = noise.cuda() + + action = action_mean + action_std * noise + return value, action + + def evaluate_actions(self, inputs, actions): + assert inputs.dim() == 2, "Expect to have inputs in num_processes * num_steps x ... format" + + value, action_mean, action_logstd = self(inputs) + + action_std = action_logstd.exp() + + action_log_probs = -0.5 * ((actions - action_mean) / action_std).pow(2) - 0.5 * math.log(2 * math.pi) - action_logstd + action_log_probs = action_log_probs.sum(1, keepdim=True) + dist_entropy = 0.5 + math.log(2 * math.pi) + action_log_probs + dist_entropy = dist_entropy.sum(-1).mean() + + return value, action_log_probs, dist_entropy diff --git a/running_stat.py b/running_stat.py new file mode 100644 index 0000000..41fe711 --- /dev/null +++ b/running_stat.py @@ -0,0 +1,44 @@ +import random + +import torch + +class ObsNorm(object): + def __init__(self, shape, demean=True, destd=True, clip=10.0): + self.demean = demean + self.destd = destd + self.clip = clip + + self.count = torch.zeros(1).double() + 1e-2 + self.sum = torch.zeros(shape).double() + self.sum_sqr = torch.zeros(shape).double() + 1e-2 + + self.mean = torch.zeros(shape) + self.std = torch.ones(shape) + + def cuda(self): + self.count = self.count.cuda() + self.sum = self.sum.cuda() + self.sum_sqr = self.sum_sqr.cuda() + + self.mean = self.mean.cuda() + self.std = self.std.cuda() + + def update(self, x): + self.count += x.size(0) + self.sum += x.sum(0, keepdim=True).double() + self.sum_sqr += x.pow(2).sum(0, keepdim=True).double() + + self.mean = self.sum / self.count + self.std = (self.sum_sqr / self.count - self.mean.pow(2)).clamp(1e-2, 1e9).sqrt() + + self.mean = self.mean.float() + self.std = self.std.float() + + def __call__(self, x): + if self.demean: + x = x - self.mean + if self.destd: + x = x / self.std + if self.clip: + x = x.clamp(-self.clip, self.clip) + return x diff --git a/storage.py b/storage.py index e449fd8..325b58e 100644 --- a/storage.py +++ b/storage.py @@ -2,13 +2,19 @@ import torch class RolloutStorage(object): - def __init__(self, num_steps, num_processes, obs_shape, action_shape): + def __init__(self, num_steps, num_processes, obs_shape, action_space): self.states = torch.zeros(num_steps + 1, num_processes, *obs_shape) self.rewards = torch.zeros(num_steps, num_processes, 1) self.value_preds = torch.zeros(num_steps + 1, num_processes, 1) self.returns = torch.zeros(num_steps + 1, num_processes, 1) - self.actions = torch.LongTensor(num_steps, num_processes, 1) - self.masks = torch.zeros(num_steps, num_processes, 1) + if action_space.__class__.__name__ == 'Discrete': + action_shape = 1 + else: + action_shape = action_space.shape[0] + self.actions = torch.zeros(num_steps, num_processes, action_shape) + if action_space.__class__.__name__ == 'Discrete': + self.actions = self.actions.long() + self.masks = torch.ones(num_steps + 1, num_processes, 1) def cuda(self): self.states = self.states.cuda() @@ -30,8 +36,7 @@ class RolloutStorage(object): self.value_preds[-1] = next_value gae = 0 for step in reversed(range(self.rewards.size(0))): - delta = self.rewards[step] + gamma * self.value_preds[step + - 1] * self.masks[step] - self.value_preds[step] + delta = self.rewards[step] + gamma * self.value_preds[step + 1] * self.masks[step] - self.value_preds[step] gae = delta + gamma * tau * self.masks[step] * gae self.returns[step] = gae + self.value_preds[step] else: diff --git a/vizualize_atari.py b/visualize.py similarity index 88% rename from vizualize_atari.py rename to visualize.py index c67c3b8..ef80268 100644 --- a/vizualize_atari.py +++ b/visualize.py @@ -106,12 +106,20 @@ def visdom_plot(viz, win, folder, game, name, bin_size=100, smooth=1): fig = plt.figure() plt.plot(tx, ty, label="{}".format(name)) - plt.xticks([4*1e6, 4*2e6, 4*4e6, 4*6e6, 4*8e6, 4*10e6], - ["1M", "2M", "4M", "6M", "8M", "10M"]) + + # Ugly hack to detect atari + if game.find('NoFrameskip') > -1: + plt.xticks([4*1e6, 4*2e6, 4*4e6, 4*6e6, 4*8e6, 4*10e6], + ["1M", "2M", "4M", "6M", "8M", "10M"]) + plt.xlim(0, 40e6) + else: + plt.xticks([1e5, 2e5, 4e5, 6e5, 8e5, 1e5], + ["0.1M", "0.2M", "0.4M", "0.6M", "0.8M", "1M"]) + plt.xlim(0, 1e6) + plt.xlabel('Number of Timesteps') plt.ylabel('Rewards') - plt.xlim(0, 40e6) plt.title(game) plt.legend(loc=4) @@ -130,5 +138,4 @@ def visdom_plot(viz, win, folder, game, name, bin_size=100, smooth=1): if __name__ == "__main__": from visdom import Visdom viz = Visdom() - visdom_plot( - viz, None, '/tmp/gym/', 'BreakOut', 'a2c', bin_size=100, smooth=1) + visdom_plot(viz, None, '/tmp/gym/', 'BreakOut', 'a2c', bin_size=100, smooth=1)