Files
Run-Skeleton-Run/baselines/ppo.py
Kolesnikov Sergey 7401266fe7 pytorch version
2017-11-15 22:18:46 +03:00

172 lines
7.3 KiB
Python

import tensorflow as tf
import numpy as np
import time
from mpi4py import MPI
from collections import deque
from contextlib import contextmanager
from common.logger import Logger
from baselines.baselines_common import Dataset, explained_variance, fmt_row, zipsame
import baselines.baselines_common.tf_util as U
from baselines.baselines_common.mpi_adam import MpiAdam
from baselines.baselines_common.mpi_saver import MpiSaver
from baselines.baselines_common.mpi_moments import mpi_moments
from baselines.trajectories import traj_segment_generator, add_vtarg_and_adv
def learn(env, policy_func, args, *,
          timesteps_per_batch,  # timesteps per actor per update
          clip_param, entcoeff,  # clipping parameter epsilon, entropy coeff
          optim_epochs, optim_stepsize, optim_batchsize,  # optimization hypers
          gamma, lam,  # advantage estimation
          adam_epsilon=1e-5,
          schedule='constant'):  # annealing for stepsize parameters (epsilon and adam)
    """Train a policy with the clipped-surrogate PPO objective under MPI.

    Builds a "pi" and an "oldpi" network with ``policy_func``, then repeatedly
    pulls a trajectory segment, computes advantages, runs ``optim_epochs``
    passes of minibatch Adam updates and logs scalar summaries. The loop runs
    until ``args.max_train_days`` days of wall-clock time have elapsed.

    Args:
        env: Environment exposing ``observation_space`` and ``action_space``;
            also handed to ``traj_segment_generator``.
        policy_func: Callable ``(name, ob_space, ac_space) -> policy``.
        args: Run configuration. Reads ``logdir``, ``thread``,
            ``restore_actor_from`` and ``max_train_days``.
            NOTE: ``args.logdir`` is overwritten with
            ``"<logdir>/thread_<thread>"`` after the saver has been created.
        timesteps_per_batch: Timesteps per actor per update.
        clip_param: Clipping parameter epsilon of the surrogate objective.
        entcoeff: Entropy coefficient (the penalty is ``-entcoeff * entropy``).
        optim_epochs: Number of passes over each batch of rollout data.
        optim_stepsize: Adam step size.
        optim_batchsize: Minibatch size; a falsy value means "use the whole
            segment".
        gamma: Forwarded to ``add_vtarg_and_adv``.
        lam: Forwarded to ``add_vtarg_and_adv``.
        adam_epsilon: Epsilon passed to ``MpiAdam``.
        schedule: Currently ignored; the learning-rate/clip multiplier is
            fixed at 1.0 regardless of the value passed.

    Returns:
        None.
    """
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space, ac_space)  # Construct network for new policy
    oldpi = policy_func("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(dtype=tf.float32,
                           shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(name='lrmult', dtype=tf.float32,
                            shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = U.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg
    pol_surr = - U.mean(tf.minimum(surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = U.mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)

    # Only variables whose top-level scope name starts with "pi" are saved.
    policy_var_list = [v for v in var_list if v.name.split("/")[0].startswith("pi")]
    saver = MpiSaver(policy_var_list, log_prefix=args.logdir)

    assign_old_eq_new = U.function(
        [], [],
        updates=[tf.assign(oldv, newv)
                 for (oldv, newv) in zipsame(oldpi.get_variables(),
                                             pi.get_variables())])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    saver.restore(restore_from=args.restore_actor_from)
    adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, args, timesteps_per_batch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    # NOTE: `schedule` is not applied; the multiplier stays constant.
    cur_lrmult = 1.0

    args.logdir = "{}/thread_{}".format(args.logdir, args.thread)
    logger = Logger(args.logdir)

    while time.time() - tstart < 86400 * args.max_train_days:
        seg = next(seg_gen)
        add_vtarg_and_adv(seg, gamma, lam)

        # Rollout arrays get their own names so they do not shadow the graph
        # placeholders (ob, ac, atarg) defined above.
        seg_ob, seg_ac = seg["ob"], seg["ac"]
        seg_adv, tdlamret = seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update

        # Standardized advantage estimate. A zero spread would turn every
        # advantage into NaN/inf, so fall back to a unit divisor in that case.
        adv_std = seg_adv.std()
        seg_adv = (seg_adv - seg_adv.mean()) / (adv_std if adv_std > 0 else 1.0)

        d = Dataset(dict(ob=seg_ob, ac=seg_ac, atarg=seg_adv, vtarg=tdlamret),
                    shuffle=True)
        optim_batchsize = optim_batchsize or seg_ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(seg_ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values

        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            for batch in d.iterate_once(optim_batchsize):
                # Only the trailing flat gradient is needed during training.
                *_, grad = lossandgrad(batch["ob"], batch["ac"], batch["atarg"],
                                       batch["vtarg"], cur_lrmult)
                adam.update(grad, optim_stepsize * cur_lrmult)

        # NOTE(review): indentation was lost in the reviewed source; the sync
        # is placed once per iteration, after all epochs -- confirm.
        saver.sync()

        # Evaluate the losses on the updated policy.
        eval_losses = []
        for batch in d.iterate_once(optim_batchsize):
            eval_losses.append(
                compute_losses(batch["ob"], batch["ac"], batch["atarg"],
                               batch["vtarg"], cur_lrmult))
        meanlosses, _, _ = mpi_moments(eval_losses, axis=0)

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        # Logging
        logger.scalar_summary("episodes", len(lens), iters_so_far)
        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.scalar_summary(lossname, lossval, episodes_so_far)
        logger.scalar_summary("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret),
                              episodes_so_far)

        # np.max raises ValueError (and np.mean yields NaN) on an empty buffer,
        # which happens until the first episode has finished.
        if lenbuffer:
            logger.scalar_summary("step", np.mean(lenbuffer), episodes_so_far)
        if rewbuffer:
            logger.scalar_summary("reward", np.mean(rewbuffer), episodes_so_far)
            logger.scalar_summary("best reward", np.max(rewbuffer), episodes_so_far)

        elapsed_time = time.time() - tstart
        logger.scalar_summary(
            "episode per minute",
            episodes_so_far / elapsed_time * 60,
            episodes_so_far)
        logger.scalar_summary(
            "step per second",
            timesteps_so_far / elapsed_time,
            episodes_so_far)
def flatten_lists(listoflists):
    """Concatenate the items of every inner iterable into one flat list.

    Exactly one level of nesting is removed and the original order is kept.
    """
    flat = []
    for inner in listoflists:
        flat.extend(inner)
    return flat