plot axis

This commit is contained in:
wassname
2017-10-30 15:25:34 +08:00
parent 8e294a095c
commit 4e5df66088
2 changed files with 8 additions and 4 deletions
@@ -1,6 +1,8 @@
import os
import numpy as np
from matplotlib import pyplot as plt
import matplotlib
import logging
class LivePlotNotebook(object):
@@ -15,8 +17,10 @@ class LivePlotNotebook(object):
liveplot.update(x, [ya,yb])
"""
def __init__(self, log_dir=None, episode=0, labels=[], title=''):
# TODO check warn if note using right matplotlib backend
def __init__(self, log_dir=None, episode=0, labels=[], title='', ylabel='returns'):
if not matplotlib.rcParams['backend'] == 'nbAgg':
logging.warn("The liveplot callback only work when matplotlib is using the nbAgg backend. Execute 'matplotlib.use('nbAgg', force=True)'' or '%matplotlib notebook'")
self.log_dir = log_dir
self.i = episode
@@ -325,7 +325,7 @@ class PortfolioEnv(gym.Env):
# plot prices and performance
if not self._plot:
self._plot = LivePlotNotebook(
'/tmp', title='prices & performance', labels=self.sim.asset_names + ["Portfolio"])
'/tmp', title='prices & performance', labels=self.sim.asset_names + ["Portfolio"], ylabel='value')
x = df_info.index
y_portfolio = df_info["portfolio_value"]
y_assets = [df_info['price_' + name].cumprod()
@@ -335,7 +335,7 @@ class PortfolioEnv(gym.Env):
# plot portfolio weights
if not self._plot2:
self._plot2 = LivePlotNotebook(
'/tmp', labels=self.sim.asset_names, title='weights')
'/tmp', labels=self.sim.asset_names, title='weights', ylabel='weight')
ys = [df_info['weight_' + name] for name in self.sim.asset_names]
self._plot2.update(x, ys)