mirror of
https://github.com/wassname/pytorch-soft-actor-critic.git
synced 2026-06-27 16:31:16 +08:00
45 lines
1.3 KiB
Python
import plotly
|
|
from plotly.graph_objs import Scatter, Line
|
|
import torch
|
|
|
|
|
|
# x-axis values accumulated across calls to plot_line: every call appends its
# ``xs`` here, and the whole history is used as the x data of the plot.
steps = []

# Line colour per known algorithm name; any other name falls back to the
# default colour below.
_ALGO_COLOURS = {
    "SAC": 'rgb(0, 172, 237)',
    "SAC(GMM)": 'rgb(0, 172, 12)',
}
_DEFAULT_COLOUR = 'rgb(172, 12, 0)'


def _trace_colour(algo):
    """Return the plotly colour string used for ``algo``'s reward line."""
    return _ALGO_COLOURS.get(algo, _DEFAULT_COLOUR)


def _as_1d_array(ys_population):
    """Convert ``ys_population`` to a float32 NumPy array with at least one dim.

    ``squeeze()`` drops singleton dimensions (e.g. an ``(n, 1)`` column becomes
    ``(n,)``). A single value would collapse to a 0-d array, which is not a
    valid y sequence for a trace, so it is re-expanded to shape ``(1,)``.
    """
    # as_tensor (unlike the legacy torch.Tensor constructor) treats a bare
    # Python number as a value rather than as a size.
    ys = torch.as_tensor(ys_population, dtype=torch.float32)
    # detach/cpu so tensors that track gradients or live on a GPU can still
    # be converted with .numpy().
    ys = ys.detach().cpu().squeeze()
    if ys.dim() == 0:
        ys = ys.unsqueeze(0)
    return ys.numpy()


def plot_line(xs, ys_population, algo):
    """Record ``xs`` and write an offline plotly HTML plot of the reward curve.

    Args:
        xs: value appended to the module-level ``steps`` list; the full
            ``steps`` history is the plot's x data. Presumably the current
            training step count -- confirm against callers.
        ys_population: reward values (list, array or tensor) plotted as y.
            NOTE(review): callers appear to pass the full reward history so
            that its length matches ``steps``; this is not checked here.
        algo: algorithm name. It selects the line colour ("SAC", "SAC(GMM)",
            or a default for anything else) and is used as both the plot
            title and the output file name ``'<algo>.html'``.

    The HTML file is written to the current working directory and is not
    opened automatically.
    """
    steps.append(xs)

    trace = Scatter(
        x=steps,
        y=_as_1d_array(ys_population),
        line=dict(color=_trace_colour(algo)),
        name='Reward',
    )

    plotly.offline.plot({
        'data': [trace],
        'layout': dict(title=algo,
                       xaxis={'title': 'Steps'},
                       yaxis={'title': 'Reward'})
    }, filename='{}.html'.format(algo), auto_open=False)
|
|
|
|
|