mirror of
https://github.com/wassname/pytorch-a2c-ppo-acktr.git
synced 2026-06-27 16:20:05 +08:00
142 lines
3.8 KiB
Python
142 lines
3.8 KiB
Python
# Copied from https://github.com/emansim/baselines-mansimov/blob/master/baselines/a2c/visualize_atari.py
|
|
# and https://github.com/emansim/baselines-mansimov/blob/master/baselines/a2c/load.py
|
|
# Thanks to the author and OpenAI team!
|
|
|
|
import glob
|
|
import json
|
|
import os
|
|
|
|
import matplotlib
|
|
matplotlib.use('Agg')
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from scipy.signal import medfilt
|
|
matplotlib.rcParams.update({'font.size': 8})
|
|
|
|
|
|
def smooth_reward_curve(x, y):
    """Smooth a reward curve with a centred moving average, then thin it out.

    The half-width ``k`` of the averaging window is about 1/30 of the number
    of points, capped at 31. Only positions where the full ``2 * k + 1``
    window fits are kept, so ``k`` samples are dropped from each end. The
    smoothed curve is then strided so that roughly at most 1000 points remain.

    Args:
        x: 1-D sequence of x values (e.g. timesteps).
        y: 1-D sequence of y values, same length as ``x``.

    Returns:
        Tuple ``(xs, ys)`` of equal length. If the series is too short for a
        full window (fewer than ``2 * k + 1`` points, including empty input),
        ``x`` and ``y`` are returned unchanged.
    """
    n = len(x)
    # Halfwidth of our smoothing convolution
    k = min(31, int(np.ceil(n / 30)))
    window = 2 * k + 1

    # With a window longer than the data, 'valid' convolution would return a
    # result whose length no longer matches x[k:-k] (or raise on empty input).
    if k == 0 or n < window:
        return x, y

    xsmoo = x[k:-k]
    # Every output sample averages exactly `window` inputs, so dividing by
    # `window` is the same as normalising by a convolution of ones.
    ysmoo = np.convolve(y, np.ones(window), mode='valid') / window

    # Keep at most ~1000 points for plotting.
    step = max(len(xsmoo) // 1000, 1)
    return xsmoo[::step], ysmoo[::step]
|
|
|
|
|
|
def fix_point(x, y, interval):
    """Resample a piecewise-linear curve onto a regular grid.

    Produces points at ``0, interval, 2 * interval, ...`` up to ``max(x)``.
    Each ``y`` value is linearly interpolated between the two neighbouring
    samples. The curve is NOT anchored at (0, 0): grid points that lie before
    ``x[0]`` are linearly extrapolated from the first segment.

    Args:
        x: 1-D sequence of sample positions, expected in ascending order (the
            segment search only ever walks forward).
        y: 1-D sequence of sample values, same length as ``x``.
        interval: Grid spacing.

    Returns:
        Tuple ``(fx, fy)`` of lists. Both are empty when fewer than two
        samples are given.
    """
    # Need at least one segment to interpolate on; max() would also fail on
    # an empty input.
    if len(x) < 2:
        return [], []

    fx, fy = [], []
    pointer = 0
    ninterval = int(max(x) / interval + 1)

    for i in range(ninterval):
        tmpx = interval * i

        # Advance to the segment [x[pointer], x[pointer + 1]] covering tmpx.
        while pointer + 1 < len(x) and tmpx > x[pointer + 1]:
            pointer += 1

        if pointer + 1 < len(x):
            dx = x[pointer + 1] - x[pointer]
            if dx == 0:
                # Degenerate (zero-width) segment: avoid dividing by zero.
                tmpy = y[pointer]
            else:
                alpha = (y[pointer + 1] - y[pointer]) / dx
                tmpy = y[pointer] + alpha * (tmpx - x[pointer])
            fx.append(tmpx)
            fy.append(tmpy)

    return fx, fy
|
|
|
|
|
|
def load_data(indir, smooth, bin_size):
    """Build a (timesteps, episode reward) curve from monitor logs in ``indir``.

    Every file matching ``*monitor.json`` is read. Its first line is a JSON
    header carrying ``t_start``. Each later line is a JSON object with ``t``
    (a time offset added to ``t_start``), ``l`` (used as the episode's step
    count) and ``r`` (the episode reward). Episodes from all files are sorted
    by ``t_start + t``. Each episode contributes the point
    ``(timesteps completed before it, its reward)``.

    Args:
        indir: Directory containing the monitor files.
        smooth: 1 = moving-average smoothing (``smooth_reward_curve``),
            2 = median filter with kernel size 9, anything else = none.
        bin_size: Spacing of the output grid, passed to ``fix_point``. Also
            used as the minimum number of episodes required.
            NOTE(review): this compares an episode count with a timestep
            interval; it looks intended as a "not enough data yet" threshold.

    Returns:
        ``[x, y]`` lists resampled every ``bin_size`` timesteps, or
        ``[None, None]`` if fewer than ``bin_size`` episodes were found.
    """
    datas = []
    infiles = glob.glob(os.path.join(indir, '*monitor.json'))

    for inf in infiles:
        with open(inf, 'r', encoding='utf-8') as f:
            header = f.readline()
            if not header.strip():
                # Empty file: there is no header to parse, so skip it.
                continue
            t_start = float(json.loads(header)['t_start'])
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing empty line).
                    continue
                episode = json.loads(line)
                t_time = float(episode['t']) + t_start
                datas.append([t_time, int(episode['l']), float(episode['r'])])

    # Merge all files into one chronological sequence of episodes.
    datas.sort(key=lambda d_entry: d_entry[0])

    result = []
    timesteps = 0
    for _, length, reward in datas:
        result.append([timesteps, reward])
        timesteps += length

    if len(result) < bin_size:
        return [None, None]

    arr = np.array(result)
    x, y = arr[:, 0], arr[:, 1]

    if smooth == 1:
        x, y = smooth_reward_curve(x, y)

    if smooth == 2:
        y = medfilt(y, kernel_size=9)

    x, y = fix_point(x, y, bin_size)
    return [x, y]
|
|
|
|
|
|
# Ten-colour palette, listed in plotting order. The hex codes are those of
# matplotlib's default "tab10" colour cycle.
color_defaults = [
    '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',  # blue, orange, green, red, purple
    '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',  # brown, pink, gray, yellow-green, blue-teal
]
|
|
|
|
|
|
def visdom_plot(viz, win, folder, game, name, bin_size=100, smooth=1):
    """Render the reward curve from ``folder`` and show it in visdom as an image.

    Args:
        viz: visdom client (the ``__main__`` demo passes ``visdom.Visdom()``).
        win: Existing visdom window to update, or ``None`` for a new one.
        folder: Directory of monitor logs, read via ``load_data``.
        game: Plot title. If it contains ``'NoFrameskip'``, Atari axis ticks
            and limits are used.
        name: Legend label for the curve.
        bin_size: Resampling interval passed to ``load_data``.
        smooth: Smoothing mode passed to ``load_data``.

    Returns:
        The value of ``viz.image(...)``, or ``win`` unchanged when there is
        not yet enough data to plot.
    """
    tx, ty = load_data(folder, smooth, bin_size)
    if tx is None or ty is None:
        return win

    fig = plt.figure()
    try:
        plt.plot(tx, ty, label=str(name))

        # Ugly hack to detect atari. Tick positions are 4x their labels;
        # presumably the logged step counts are raw frames (frame skip 4),
        # so confirm this against the env wrappers.
        if 'NoFrameskip' in game:
            plt.xticks([4 * 1e6, 4 * 2e6, 4 * 4e6, 4 * 6e6, 4 * 8e6, 4 * 10e6],
                       ["1M", "2M", "4M", "6M", "8M", "10M"])
            plt.xlim(0, 40e6)
        else:
            plt.xticks([1e5, 2e5, 4e5, 6e5, 8e5, 1e6],
                       ["0.1M", "0.2M", "0.4M", "0.6M", "0.8M", "1M"])
            plt.xlim(0, 1e6)

        plt.xlabel('Number of Timesteps')
        plt.ylabel('Rewards')
        plt.title(game)
        plt.legend(loc=4)

        # Render off-screen and read back the pixels. plt.show() is
        # pointless here because the module selects the non-GUI Agg backend.
        fig.canvas.draw()
        width, height = fig.canvas.get_width_height()
        rgba = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
        # Copy so the image no longer depends on the canvas buffer once the
        # figure is closed.
        image = np.array(rgba.reshape(height, width, 4)[:, :, :3])
    finally:
        # Always release the figure, even if plotting/rendering fails.
        plt.close(fig)

    # Show it in visdom: reorder (H, W, C) -> (C, H, W)
    image = np.transpose(image, (2, 0, 1))
    return viz.image(image, win=win)
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo: plot the monitor logs found in /tmp/gym/ to a visdom server.
    from visdom import Visdom

    viz = Visdom()
    # visdom_plot switches to Atari axis scaling only when the game name
    # contains "NoFrameskip", so pass the full Atari env id here.
    visdom_plot(viz, None, '/tmp/gym/', 'BreakoutNoFrameskip-v4', 'a2c',
                bin_size=100, smooth=1)
|