mirror of
https://github.com/wassname/pytorch-a2c-ppo-acktr.git
synced 2026-06-27 16:20:05 +08:00
Add visualization
This commit is contained in:
@@ -1,16 +1,19 @@
|
||||
import gym
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torch.autograd import Variable
|
||||
|
||||
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
|
||||
from envs import make_env
|
||||
from model import ActorCritic
|
||||
from torch.autograd import Variable
|
||||
import argparse
|
||||
from vizualize_atari import visdom_plot
|
||||
|
||||
parser = argparse.ArgumentParser(description='A3C')
|
||||
parser.add_argument('--lr', type=float, default=7e-4,
|
||||
@@ -37,6 +40,8 @@ parser.add_argument('--num-stack', type=int, default=4,
|
||||
help='number of frames to stack (default: 4)')
|
||||
parser.add_argument('--log-interval', type=int, default=10,
|
||||
help='log interval, one log per n updates (default: 10)')
|
||||
parser.add_argument('--vis-interval', type=int, default=100,
|
||||
help='vis interval, one log per n updates (default: 100)')
|
||||
parser.add_argument('--num-frames', type=int, default=10e6,
|
||||
help='number of frames to train (default: 10e6)')
|
||||
parser.add_argument('--env-name', default='PongNoFrameskip-v4',
|
||||
@@ -45,10 +50,13 @@ parser.add_argument('--log-dir', default='/tmp/gym/',
|
||||
help='directory to save agent logs (default: /tmp/gym)')
|
||||
parser.add_argument('--no-cuda', action='store_true', default=False,
|
||||
help='disables CUDA training')
|
||||
parser.add_argument('--no-vis', action='store_true', default=False,
|
||||
help='disables visdom visualization')
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
args.cuda = not args.no_cuda and torch.cuda.is_available()
|
||||
args.vis = not args.no_vis
|
||||
|
||||
num_updates = int(args.num_frames) // args.num_steps // args.num_processes
|
||||
|
||||
@@ -59,17 +67,27 @@ if args.cuda:
|
||||
try:
|
||||
os.makedirs(args.log_dir)
|
||||
except OSError:
|
||||
pass
|
||||
files = glob.glob(os.path.join(args.log_dir, '*.monitor.json'))
|
||||
for f in files:
|
||||
os.remove(f)
|
||||
|
||||
|
||||
def main():
|
||||
print("#######")
|
||||
print("WARNING: All rewards are clipped so you need to use a monitor (see envs.py) to get true rewards")
|
||||
print("WARNING: All rewards are clipped so you need to use a monitor (see envs.py) or visdom plot to get true rewards")
|
||||
print("#######")
|
||||
|
||||
os.environ['OMP_NUM_THREADS'] = '1'
|
||||
|
||||
envs = SubprocVecEnv([make_env(args.env_name, args.seed, i, args.log_dir)
|
||||
for i in range(args.num_processes)])
|
||||
if args.vis:
|
||||
from visdom import Visdom
|
||||
viz = Visdom()
|
||||
win = None
|
||||
|
||||
envs = SubprocVecEnv([
|
||||
make_env(args.env_name, args.seed, i, args.log_dir)
|
||||
for i in range(args.num_processes)
|
||||
])
|
||||
|
||||
actor_critic = ActorCritic(envs.observation_space.shape[0] * args.num_stack, envs.action_space)
|
||||
|
||||
@@ -77,7 +95,6 @@ def main():
|
||||
actor_critic.cuda()
|
||||
|
||||
optimizer = optim.RMSprop(actor_critic.parameters(), args.lr, eps=args.eps, alpha=args.alpha)
|
||||
#optimizer = KFACOptimizer(actor_critic, damping=1e-2, kl_clip=0.01, stat_decay=0.99)
|
||||
|
||||
obs_shape = envs.observation_space.shape
|
||||
obs_shape = (obs_shape[0] * args.num_stack, obs_shape[1], obs_shape[2])
|
||||
@@ -185,8 +202,17 @@ def main():
|
||||
states[0].copy_(states[-1])
|
||||
|
||||
if j % args.log_interval == 0:
|
||||
print("Updates {}, num frames {}, mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}, entropy {:.5f}, value loss {:.5f}, policy loss {:.5f}".format(
|
||||
j, j * args.num_processes * args.num_steps, final_rewards.mean(), final_rewards.median(), final_rewards.min(), final_rewards.max(), -dist_entropy.data[0], value_loss.data[0], action_loss.data[0]))
|
||||
print("Updates {}, num frames {}, mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}, entropy {:.5f}, value loss {:.5f}, policy loss {:.5f}".
|
||||
format(j, j * args.num_processes * args.num_steps,
|
||||
final_rewards.mean(),
|
||||
final_rewards.median(),
|
||||
final_rewards.min(),
|
||||
final_rewards.max(), -dist_entropy.data[0],
|
||||
value_loss.data[0], action_loss.data[0]))
|
||||
|
||||
if j % args.vis_interval == 0:
|
||||
win = visdom_plot(viz, win, args.log_dir, args.env_name, 'a2c')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -0,0 +1,134 @@
|
||||
# Copied from https://github.com/emansim/baselines-mansimov/blob/master/baselines/a2c/visualize_atari.py
|
||||
# and https://github.com/emansim/baselines-mansimov/blob/master/baselines/a2c/load.py
|
||||
# Thanks to the author and OpenAI team!
|
||||
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from scipy.signal import medfilt
|
||||
matplotlib.rcParams.update({'font.size': 8})
|
||||
|
||||
|
||||
def smooth_reward_curve(x, y):
|
||||
# Halfwidth of our smoothing convolution
|
||||
halfwidth = min(31, int(np.ceil(len(x) / 30)))
|
||||
k = halfwidth
|
||||
xsmoo = x[k:-k]
|
||||
ysmoo = np.convolve(y, np.ones(2 * k + 1), mode='valid') / \
|
||||
np.convolve(np.ones_like(y), np.ones(2 * k + 1), mode='valid')
|
||||
downsample = max(int(np.floor(len(xsmoo) / 1e3)), 1)
|
||||
return xsmoo[::downsample], ysmoo[::downsample]
|
||||
|
||||
|
||||
def fix_point(x, y, interval):
|
||||
np.insert(x, 0, 0)
|
||||
np.insert(y, 0, 0)
|
||||
|
||||
fx, fy = [], []
|
||||
pointer = 0
|
||||
|
||||
ninterval = int(max(x) / interval + 1)
|
||||
|
||||
for i in range(ninterval):
|
||||
tmpx = interval * i
|
||||
|
||||
while pointer + 1 < len(x) and tmpx > x[pointer + 1]:
|
||||
pointer += 1
|
||||
|
||||
if pointer + 1 < len(x):
|
||||
alpha = (y[pointer + 1] - y[pointer]) / \
|
||||
(x[pointer + 1] - x[pointer])
|
||||
tmpy = y[pointer] + alpha * (tmpx - x[pointer])
|
||||
fx.append(tmpx)
|
||||
fy.append(tmpy)
|
||||
|
||||
return fx, fy
|
||||
|
||||
|
||||
def load_data(indir, smooth, bin_size):
|
||||
datas = []
|
||||
infiles = glob.glob(os.path.join(indir, '*monitor.json'))
|
||||
|
||||
for inf in infiles:
|
||||
with open(inf, 'r') as f:
|
||||
t_start = float(json.loads(f.readline())['t_start'])
|
||||
for line in f:
|
||||
tmp = json.loads(line)
|
||||
t_time = float(tmp['t']) + t_start
|
||||
tmp = [t_time, int(tmp['l']), float(tmp['r'])]
|
||||
datas.append(tmp)
|
||||
|
||||
datas = sorted(datas, key=lambda d_entry: d_entry[0])
|
||||
result = []
|
||||
timesteps = 0
|
||||
for i in range(len(datas)):
|
||||
result.append([timesteps, datas[i][-1]])
|
||||
timesteps += datas[i][1]
|
||||
|
||||
if len(result) < bin_size:
|
||||
return [None, None]
|
||||
|
||||
x, y = np.array(result)[:, 0], np.array(result)[:, 1]
|
||||
|
||||
if smooth == 1:
|
||||
x, y = smooth_reward_curve(x, y)
|
||||
|
||||
if smooth == 2:
|
||||
y = medfilt(y, kernel_size=9)
|
||||
|
||||
x, y = fix_point(x, y, bin_size)
|
||||
return [x, y]
|
||||
|
||||
|
||||
color_defaults = [
|
||||
'#1f77b4', # muted blue
|
||||
'#ff7f0e', # safety orange
|
||||
'#2ca02c', # cooked asparagus green
|
||||
'#d62728', # brick red
|
||||
'#9467bd', # muted purple
|
||||
'#8c564b', # chestnut brown
|
||||
'#e377c2', # raspberry yogurt pink
|
||||
'#7f7f7f', # middle gray
|
||||
'#bcbd22', # curry yellow-green
|
||||
'#17becf' # blue-teal
|
||||
]
|
||||
|
||||
|
||||
def visdom_plot(viz, win, folder, game, name, bin_size=100, smooth=1):
|
||||
tx, ty = load_data(folder, smooth, bin_size)
|
||||
if tx is None or ty is None:
|
||||
return win
|
||||
|
||||
fig = plt.figure()
|
||||
plt.plot(tx, ty, label="{}".format(name))
|
||||
plt.xticks([4*1e6, 4*2e6, 4*4e6, 4*6e6, 4*8e6, 4*10e6],
|
||||
["1M", "2M", "4M", "6M", "8M", "10M"])
|
||||
plt.xlabel('Number of Timesteps')
|
||||
plt.ylabel('Rewards')
|
||||
|
||||
plt.xlim(0, 40e6)
|
||||
|
||||
plt.title(game)
|
||||
plt.legend(loc=4)
|
||||
plt.show()
|
||||
plt.draw()
|
||||
|
||||
image = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
|
||||
image = image.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
|
||||
plt.close(fig)
|
||||
|
||||
# Show it in visdom
|
||||
image = np.transpose(image, (2, 0, 1))
|
||||
return viz.image(image, win=win)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from visdom import Visdom
|
||||
viz = Visdom()
|
||||
visdom_plot(
|
||||
viz, None, '/tmp/gym/', 'BreakOut', 'a2c', bin_size=100, smooth=1)
|
||||
Reference in New Issue
Block a user