mirror of
https://github.com/wassname/rl-portfolio-management.git
synced 2026-06-27 17:02:01 +08:00
tensorboard logger
This commit is contained in:
+3
-2
@@ -1,8 +1,9 @@
|
||||
secrets/secrets.json
|
||||
outputs
|
||||
outputs/
|
||||
.cache
|
||||
.trash
|
||||
models
|
||||
models/
|
||||
logs/
|
||||
|
||||
# Created by https://www.gitignore.io/api/linux,python,windows
|
||||
|
||||
|
||||
@@ -279,4 +279,4 @@ class PortfolioEnv(gym.Env):
|
||||
sharpe_ratio = sharpe(df_info.rate_of_return)
|
||||
title='max_drawdown={: 2.2%} sharpe_ratio={: 2.4f}'.format(mdd,sharpe_ratio)
|
||||
|
||||
df_info[["portfolio_value", "market_value"]].plot(title=title)
|
||||
df_info[["portfolio_value", "market_value"]].plot(title=title, fig=plt.gcf())
|
||||
|
||||
+292
-159
@@ -12,8 +12,8 @@
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2017-10-15T01:42:14.002254Z",
|
||||
"start_time": "2017-10-15T01:42:13.297490Z"
|
||||
"end_time": "2017-10-15T02:29:58.356761Z",
|
||||
"start_time": "2017-10-15T02:29:49.898984Z"
|
||||
},
|
||||
"collapsed": true
|
||||
},
|
||||
@@ -33,6 +33,7 @@
|
||||
"# util\n",
|
||||
"from collections import Counter\n",
|
||||
"import pdb\n",
|
||||
"import glob\n",
|
||||
"import time\n",
|
||||
"import tempfile\n",
|
||||
"import itertools\n",
|
||||
@@ -52,8 +53,8 @@
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2017-10-15T01:42:14.101075Z",
|
||||
"start_time": "2017-10-15T01:42:14.003911Z"
|
||||
"end_time": "2017-10-15T02:29:59.755732Z",
|
||||
"start_time": "2017-10-15T02:29:58.358390Z"
|
||||
},
|
||||
"collapsed": true
|
||||
},
|
||||
@@ -69,8 +70,8 @@
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2017-10-15T01:42:14.157669Z",
|
||||
"start_time": "2017-10-15T01:42:14.102563Z"
|
||||
"end_time": "2017-10-15T02:29:59.821678Z",
|
||||
"start_time": "2017-10-15T02:29:59.757241Z"
|
||||
},
|
||||
"collapsed": true
|
||||
},
|
||||
@@ -84,21 +85,30 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 33,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2017-10-15T01:42:14.201990Z",
|
||||
"start_time": "2017-10-15T01:42:14.158832Z"
|
||||
"end_time": "2017-10-15T02:34:29.996288Z",
|
||||
"start_time": "2017-10-15T02:34:29.918670Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'./outputs/tensorforce-PPO-prioritised/tensorforce-PPO-prioritised_20171012_04-42-55.model'"
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
" \r",
|
||||
"95/|/reward: -0.0000 [-0.0000, -0.0000], portfolio_value: 0.9998 [ 0.9998, 0.9998] expl= 81.00% 0%|| 95/200000 [03:10<105:36:04, 1.90s/it]"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'./outputs/tensorforce-PPO-prioritised/tensorforce-PPO-prioritised_20171015_02-34-23.model'"
|
||||
]
|
||||
},
|
||||
"execution_count": 33,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -106,9 +116,41 @@
|
||||
"source": [
|
||||
"# params\n",
|
||||
"window_length = 50\n",
|
||||
"\n",
|
||||
"save_path = './outputs/tensorforce-PPO-prioritised/tensorforce-PPO-prioritised_20171012_04-42-55.model'\n",
|
||||
"save_path"
|
||||
"import datetime\n",
|
||||
"ts = datetime.datetime.utcnow().strftime('%Y%m%d_%H-%M-%S')\n",
|
||||
"save_path = './outputs/tensorforce-PPO-prioritised/tensorforce-PPO-prioritised_%s.model' % ts\n",
|
||||
"save_path = './outputs/tensorforce-PPO-prioritised/tensorforce-PPO-prioritised_20171015_02-34-23.model'\n",
|
||||
"save_path\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2017-10-15T02:34:30.609819Z",
|
||||
"start_time": "2017-10-15T02:34:30.524610Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'logs/tensorforce-PPO-prioritised_20171015_02-34-23'"
|
||||
]
|
||||
},
|
||||
"execution_count": 34,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"log_dir = os.path.join('logs', os.path.splitext(os.path.basename(save_path))[0])\n",
|
||||
"try:\n",
|
||||
" os.makedirs(log_dir)\n",
|
||||
"except OSError:\n",
|
||||
" pass\n",
|
||||
"log_dir"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -120,11 +162,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2017-10-15T01:42:14.245473Z",
|
||||
"start_time": "2017-10-15T01:42:14.203055Z"
|
||||
"end_time": "2017-10-15T02:29:59.948831Z",
|
||||
"start_time": "2017-10-15T02:29:59.910316Z"
|
||||
},
|
||||
"collapsed": true
|
||||
},
|
||||
@@ -135,11 +177,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 7,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2017-10-15T01:42:14.311788Z",
|
||||
"start_time": "2017-10-15T01:42:14.247810Z"
|
||||
"end_time": "2017-10-15T02:30:00.017952Z",
|
||||
"start_time": "2017-10-15T02:29:59.949921Z"
|
||||
},
|
||||
"collapsed": true
|
||||
},
|
||||
@@ -168,11 +210,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2017-10-15T01:42:14.618015Z",
|
||||
"start_time": "2017-10-15T01:42:14.313685Z"
|
||||
"end_time": "2017-10-15T02:30:00.993572Z",
|
||||
"start_time": "2017-10-15T02:30:00.019275Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
@@ -181,9 +223,9 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:gym.envs.registration:Making new env: CartPole-v0\n",
|
||||
"[2017-10-15 09:42:14,602] Making new env: CartPole-v0\n",
|
||||
"[2017-10-15 10:30:00,929] Making new env: CartPole-v0\n",
|
||||
"INFO:gym.envs.registration:Making new env: CartPole-v0\n",
|
||||
"[2017-10-15 09:42:14,608] Making new env: CartPole-v0\n"
|
||||
"[2017-10-15 10:30:00,987] Making new env: CartPole-v0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -191,7 +233,7 @@
|
||||
"df_train = pd.read_hdf('./data/poloniex_30m.hf',key='train')\n",
|
||||
"env = EnvWrapper(\n",
|
||||
" df=df_train,\n",
|
||||
" steps=30, \n",
|
||||
" steps=300, \n",
|
||||
" scale=True, \n",
|
||||
" augment=0.000,\n",
|
||||
" trading_cost=0, # let just overfit first,\n",
|
||||
@@ -203,7 +245,7 @@
|
||||
"df_test = pd.read_hdf('./data/poloniex_30m.hf',key='test')\n",
|
||||
"env_test = EnvWrapper(\n",
|
||||
" df=df_test,\n",
|
||||
" steps=30, \n",
|
||||
" steps=300, \n",
|
||||
" scale=True, \n",
|
||||
" augment=0.00,\n",
|
||||
" trading_cost=0, # let just overfit first\n",
|
||||
@@ -221,11 +263,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 9,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2017-10-15T01:42:14.690869Z",
|
||||
"start_time": "2017-10-15T01:42:14.619668Z"
|
||||
"end_time": "2017-10-15T02:30:01.088045Z",
|
||||
"start_time": "2017-10-15T02:30:00.994806Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
@@ -233,7 +275,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"0.0 False {'reward': 0.0, 'log_return': 0.0, 'portfolio_value': 1.0, 'return': 1.0038234525884082, 'rate_of_return': 0.0, 'cash_bias': 0.19920207637805584, 'cost': 0.0, 'market_value': 1.0015682040003333, 'date': 1421807400.0, 'steps': 2}\n",
|
||||
"0.0 False {'reward': 0.0, 'log_return': 0.0, 'portfolio_value': 1.0, 'return': 1.0023510634010704, 'rate_of_return': 0.0, 'cash_bias': 0.093732749593602255, 'cost': 0.0, 'market_value': 1.0023516084586421, 'date': 1429261200.0, 'steps': 2}\n",
|
||||
"(5, 50, 3) (5, 50, 3)\n"
|
||||
]
|
||||
}
|
||||
@@ -262,11 +304,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 14,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2017-10-15T01:42:30.420676Z",
|
||||
"start_time": "2017-10-15T01:42:14.692354Z"
|
||||
"end_time": "2017-10-15T02:30:55.981253Z",
|
||||
"start_time": "2017-10-15T02:30:55.907749Z"
|
||||
},
|
||||
"collapsed": true
|
||||
},
|
||||
@@ -279,11 +321,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 15,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2017-10-15T01:42:30.769981Z",
|
||||
"start_time": "2017-10-15T01:42:30.423107Z"
|
||||
"end_time": "2017-10-15T02:30:56.395442Z",
|
||||
"start_time": "2017-10-15T02:30:56.144527Z"
|
||||
},
|
||||
"collapsed": true
|
||||
},
|
||||
@@ -365,11 +407,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": 16,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2017-10-15T01:42:31.635122Z",
|
||||
"start_time": "2017-10-15T01:42:30.771243Z"
|
||||
"end_time": "2017-10-15T02:30:56.471251Z",
|
||||
"start_time": "2017-10-15T02:30:56.396534Z"
|
||||
},
|
||||
"collapsed": true
|
||||
},
|
||||
@@ -400,78 +442,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"execution_count": 18,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2017-10-15T01:43:14.458362Z",
|
||||
"start_time": "2017-10-15T01:43:14.384891Z"
|
||||
},
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"# Check my PR is included\n",
|
||||
"import tensorforce.core.memories\n",
|
||||
"assert isinstance(runner.agent.memory,tensorforce.core.memories.PrioritizedReplay)\n",
|
||||
"assert isinstance(runner.agent, tensorforce.agents.MemoryAgent)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2017-10-15T01:48:15.367117Z",
|
||||
"start_time": "2017-10-15T01:48:15.294327Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'action0': {'continuous': True,\n",
|
||||
" 'max_value': None,\n",
|
||||
" 'min_value': None,\n",
|
||||
" 'shape': ()},\n",
|
||||
" 'action1': {'continuous': True,\n",
|
||||
" 'max_value': None,\n",
|
||||
" 'min_value': None,\n",
|
||||
" 'shape': ()},\n",
|
||||
" 'action2': {'continuous': True,\n",
|
||||
" 'max_value': None,\n",
|
||||
" 'min_value': None,\n",
|
||||
" 'shape': ()},\n",
|
||||
" 'action3': {'continuous': True,\n",
|
||||
" 'max_value': None,\n",
|
||||
" 'min_value': None,\n",
|
||||
" 'shape': ()},\n",
|
||||
" 'action4': {'continuous': True,\n",
|
||||
" 'max_value': None,\n",
|
||||
" 'min_value': None,\n",
|
||||
" 'shape': ()},\n",
|
||||
" 'action5': {'continuous': True,\n",
|
||||
" 'max_value': None,\n",
|
||||
" 'min_value': None,\n",
|
||||
" 'shape': ()}}"
|
||||
]
|
||||
},
|
||||
"execution_count": 31,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"config.actions.as_dict()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 43,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2017-10-15T01:51:06.768136Z",
|
||||
"start_time": "2017-10-15T01:51:06.673474Z"
|
||||
"end_time": "2017-10-15T02:31:00.442210Z",
|
||||
"start_time": "2017-10-15T02:31:00.356486Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
@@ -510,7 +485,7 @@
|
||||
" 'type': 'epsilon_anneal'}},)"
|
||||
]
|
||||
},
|
||||
"execution_count": 43,
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -528,11 +503,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 51,
|
||||
"execution_count": 19,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2017-10-15T01:53:21.251258Z",
|
||||
"start_time": "2017-10-15T01:53:19.469753Z"
|
||||
"end_time": "2017-10-15T02:31:03.222746Z",
|
||||
"start_time": "2017-10-15T02:31:01.171736Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
@@ -541,16 +516,16 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"WARNING:tensorforce.agents.agent:Configuration values not accessed: first_update, memory_capacity, memory, update_frequency, repeat_update\n",
|
||||
"[2017-10-15 09:53:21,245] Configuration values not accessed: first_update, memory_capacity, memory, update_frequency, repeat_update\n"
|
||||
"[2017-10-15 10:31:03,214] Configuration values not accessed: first_update, memory_capacity, memory, update_frequency, repeat_update\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<tensorforce.agents.ppo_agent.PPOAgent at 0x7fdd41bb90f0>"
|
||||
"<tensorforce.agents.ppo_agent.PPOAgent at 0x7fa421752fd0>"
|
||||
]
|
||||
},
|
||||
"execution_count": 51,
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -639,16 +614,18 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Train"
|
||||
"# Train\n",
|
||||
"\n",
|
||||
"## Callbacks"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 57,
|
||||
"execution_count": 21,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2017-10-15T01:53:56.906721Z",
|
||||
"start_time": "2017-10-15T01:53:56.825652Z"
|
||||
"end_time": "2017-10-15T02:31:15.507302Z",
|
||||
"start_time": "2017-10-15T02:31:15.427324Z"
|
||||
},
|
||||
"collapsed": true
|
||||
},
|
||||
@@ -668,11 +645,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 58,
|
||||
"execution_count": 22,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2017-10-15T01:53:57.245985Z",
|
||||
"start_time": "2017-10-15T01:53:57.163430Z"
|
||||
"end_time": "2017-10-15T02:31:15.751990Z",
|
||||
"start_time": "2017-10-15T02:31:15.662780Z"
|
||||
},
|
||||
"code_folding": [],
|
||||
"collapsed": true
|
||||
@@ -717,23 +694,94 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 59,
|
||||
"execution_count": 23,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2017-10-15T01:53:57.934576Z",
|
||||
"start_time": "2017-10-15T01:53:57.833877Z"
|
||||
"end_time": "2017-10-15T02:31:15.934227Z",
|
||||
"start_time": "2017-10-15T02:31:15.825377Z"
|
||||
}
|
||||
},
|
||||
"code_folding": [],
|
||||
"collapsed": true
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"import tensorflow as tf\n",
|
||||
"import numpy as np\n",
|
||||
"class TensorBoardLogger(object):\n",
|
||||
" \"\"\"\n",
|
||||
" Log scalar and histograms/distributions to tensorboard.\n",
|
||||
" Usage:\n",
|
||||
" ```\n",
|
||||
" logger = TensorBoardLogger(log_dir = '/tmp/test')\n",
|
||||
" for i in range(10):\n",
|
||||
" logger.log(\n",
|
||||
" logs=dict(\n",
|
||||
" float_test=np.random.random(),\n",
|
||||
" int_test=np.random.randint(0,4),\n",
|
||||
" ),\n",
|
||||
" histograms=dict(\n",
|
||||
" actions=np.random.randint(0,3,size=np.random.randint(5,20))\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" ```\n",
|
||||
" Ref: https://github.com/fchollet/keras/blob/master/keras/callbacks.py\n",
|
||||
" Url: https://gist.github.com/wassname/b692f8e8686655011618dfbe8d8a9e3f\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" def __init__(self, log_dir, session=None, episode=0):\n",
|
||||
" self.log_dir = log_dir\n",
|
||||
" self.writer = tf.summary.FileWriter(self.log_dir)\n",
|
||||
" self.episode = episode\n",
|
||||
" print('TensorBoardLogger started. Run `tensorboard --logdir={}` to visualize'.format(os.path.abspath(self.log_dir)))\n",
|
||||
"\n",
|
||||
" self.histograms = {}\n",
|
||||
" self.histogram_inputs = {}\n",
|
||||
" self.session = session or tf.get_default_session() or tf.Session()\n",
|
||||
"\n",
|
||||
" def log(self, logs={}, histograms={}, episode=None):\n",
|
||||
" episode = episode or self.episode\n",
|
||||
" # scalar logging\n",
|
||||
" for name, value in logs.items():\n",
|
||||
" summary = tf.Summary()\n",
|
||||
" summary_value = summary.value.add()\n",
|
||||
" summary_value.simple_value = value\n",
|
||||
" summary_value.tag = name\n",
|
||||
" self.writer.add_summary(summary, episode)\n",
|
||||
"\n",
|
||||
" # histograms\n",
|
||||
" for name, value in histograms.items():\n",
|
||||
" if name not in self.histograms:\n",
|
||||
" # make a tensor with no fixed shape\n",
|
||||
" self.histogram_inputs[name] = tf.Variable(value,validate_shape=False)\n",
|
||||
" self.histograms[name] = tf.summary.histogram(name, self.histogram_inputs[name])\n",
|
||||
"\n",
|
||||
" input_tensor = self.histogram_inputs[name]\n",
|
||||
" summary = self.histograms[name]\n",
|
||||
" summary_str = summary.eval(session=self.session, feed_dict={input_tensor.name:value})\n",
|
||||
" self.writer.add_summary(summary_str, episode)\n",
|
||||
"\n",
|
||||
" self.writer.flush()\n",
|
||||
" self.episode += 1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2017-10-15T02:31:16.177868Z",
|
||||
"start_time": "2017-10-15T02:31:16.058908Z"
|
||||
},
|
||||
"code_folding": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Callback EpisodeFinishedTQDM\n",
|
||||
"\n",
|
||||
"from tqdm import tqdm_notebook\n",
|
||||
"class EpisodeFinishedTQDM(EpisodeFinished):\n",
|
||||
" \"\"\"Logger for tensorforce using tqdm_notebook for jupyter-notebook.\"\"\"\n",
|
||||
" \n",
|
||||
" def __init__(self, episodes, log_intv):\n",
|
||||
" def __init__(self, episodes, log_intv, session=None, log_dir=None, episode=0):\n",
|
||||
" \"\"\"\n",
|
||||
" log_intv - print the mean metrics every log_intv episodes\n",
|
||||
" \"\"\"\n",
|
||||
@@ -743,50 +791,120 @@
|
||||
" total=episodes, \n",
|
||||
" leave=True, mininterval=5)\n",
|
||||
" \n",
|
||||
" # tensorboard\n",
|
||||
" if log_dir:\n",
|
||||
" self.log_dir = log_dir\n",
|
||||
" elif save_path:\n",
|
||||
" self.log_dir = '/tmp/'+os.path.basename(save_path)\n",
|
||||
" else:\n",
|
||||
" self.log_dir = '/tmp/StepsProgressBar'\n",
|
||||
" self.tensor_board_logger = TensorBoardLogger(self.log_dir, session=session, episode=episode)\n",
|
||||
" \n",
|
||||
" def __call__(self, r):\n",
|
||||
" super().__call__(r)\n",
|
||||
" desc = \"reward: {reward: 2.4f} [{rewards_min: 2.4f}, {rewards_max: 2.4f}], portfolio_value: {portfolio_value: 2.4f} [{portfolio_value_min: 2.4f}, {portfolio_value_max: 2.4f}]\". format(\n",
|
||||
" oai_env = r.environment.gym.unwrapped\n",
|
||||
" exploration = r.agent.exploration.get('action0', lambda x,y:0)(r.episode, np.sum(r.episode_lengths))\n",
|
||||
" desc = \"reward: {reward: 2.4f} [{rewards_min: 2.4f}, {rewards_max: 2.4f}], portfolio_value: {portfolio_value: 2.4f} [{portfolio_value_min: 2.4f}, {portfolio_value_max: 2.4f}] expl={exploration: 2.2%}\".format(\n",
|
||||
" reward=np.mean(r.episode_rewards[-1:]),\n",
|
||||
" rewards_min=np.min(r.episode_rewards[-1:]),\n",
|
||||
" rewards_max=np.max(r.episode_rewards[-1:]),\n",
|
||||
" portfolio_value=np.mean(self.portfolio_values[-1:]),\n",
|
||||
" portfolio_value_min=np.min(self.portfolio_values[-1:]),\n",
|
||||
" portfolio_value_max=np.max(self.portfolio_values[-1:])\n",
|
||||
" portfolio_value_max=np.max(self.portfolio_values[-1:]),\n",
|
||||
" exploration=exploration\n",
|
||||
" )\n",
|
||||
" self.progbar.desc = desc\n",
|
||||
" self.progbar.update(1) # update\n",
|
||||
" \n",
|
||||
" # log to tensorboard\n",
|
||||
" logs=dict(\n",
|
||||
" rewards=r.episode_rewards[-1],\n",
|
||||
" episode_lengths=r.episode_lengths[-1],\n",
|
||||
" episode_time=r.episode_times[-1],\n",
|
||||
" portfolio_value=np.mean(self.portfolio_values[-1:]),\n",
|
||||
" portfolio_value_min=np.min(self.portfolio_values[-1:]),\n",
|
||||
" portfolio_value_max=np.max(self.portfolio_values[-1:]),\n",
|
||||
" exploration=exploration\n",
|
||||
" )\n",
|
||||
" df_info = pd.DataFrame(oai_env.infos)\n",
|
||||
" ep_infos = df_info.mean().to_dict()\n",
|
||||
" logs.update(ep_infos)\n",
|
||||
" self.tensor_board_logger.log(\n",
|
||||
" logs=logs,\n",
|
||||
" histograms=dict(\n",
|
||||
" returns=df_info['return'].values,\n",
|
||||
" portfolio_value=df_info.portfolio_value.values,\n",
|
||||
" market_value=df_info.market_value.values,\n",
|
||||
" ),\n",
|
||||
" episode=r.episode\n",
|
||||
" )\n",
|
||||
" return True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 60,
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2017-10-15T01:53:59.814205Z",
|
||||
"start_time": "2017-10-15T01:53:59.729623Z"
|
||||
"end_time": "2017-10-15T02:19:28.278977Z",
|
||||
"start_time": "2017-10-15T02:19:28.132177Z"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"## Train"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2017-10-15T02:34:35.038267Z",
|
||||
"start_time": "2017-10-15T02:34:34.963840Z"
|
||||
},
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from tensorforce.execution import Runner\n",
|
||||
"runner = Runner(agent=agent, environment=environment, save_path=save_path, save_episodes=10000)"
|
||||
"runner = Runner(agent=agent, environment=environment, save_path=save_path, save_episodes=1000)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 61,
|
||||
"execution_count": 36,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2017-10-15T01:54:00.406978Z",
|
||||
"start_time": "2017-10-15T01:54:00.346179Z"
|
||||
},
|
||||
"collapsed": true
|
||||
"end_time": "2017-10-15T02:34:35.214512Z",
|
||||
"start_time": "2017-10-15T02:34:35.156499Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#agent.load_model(save_path)"
|
||||
"\n",
|
||||
"# Check my PR is included\n",
|
||||
"import tensorforce.core.memories\n",
|
||||
"assert isinstance(runner.agent.memory,tensorforce.core.memories.PrioritizedReplay)\n",
|
||||
"assert isinstance(runner.agent, tensorforce.agents.MemoryAgent)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2017-10-15T02:34:35.498415Z",
|
||||
"start_time": "2017-10-15T02:34:35.437300Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# resume\n",
|
||||
"saves=glob.glob(save_path+'-*')\n",
|
||||
"if len(saves)>0:\n",
|
||||
" # load saved\n",
|
||||
" last_save = os.path.splitext(saves[0])[0]\n",
|
||||
" runner.agent.load_model(last_save)\n",
|
||||
" print('loaded', last_save)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -794,7 +912,7 @@
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2017-10-15T01:54:00.897Z"
|
||||
"start_time": "2017-10-15T02:34:35.964Z"
|
||||
},
|
||||
"scrolled": true
|
||||
},
|
||||
@@ -802,17 +920,32 @@
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "c02d07f406244008ab9da56c090f0355"
|
||||
"model_id": "492cbe81a26f4eed8dfe1e2f3db257d8"
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"TensorBoardLogger started. Run `tensorboard --logdir=/media/isisilon/Data/My_Documents/Documents/eclipse-workspace/rl_keras_finance/portfolio-rl-jiang_2017/logs/tensorforce-PPO-prioritised_20171015_02-34-23` to visualize\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"episodes=int(6e6/30)\n",
|
||||
"episodes = int(6e6 / 30)\n",
|
||||
"runner.run(\n",
|
||||
" episodes=episodes, max_timesteps=200, episode_finished=EpisodeFinishedTQDM(log_intv=1000, episodes=episodes))"
|
||||
" episodes=episodes,\n",
|
||||
" max_timesteps=200,\n",
|
||||
" episode_finished=EpisodeFinishedTQDM(\n",
|
||||
" log_intv=100, \n",
|
||||
" episodes=episodes,\n",
|
||||
" log_dir=log_dir,\n",
|
||||
" session=runner.agent.model.session, \n",
|
||||
" )\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -820,8 +953,8 @@
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2017-10-15T01:42:46.508299Z",
|
||||
"start_time": "2017-10-15T01:42:13.091Z"
|
||||
"end_time": "2017-10-15T02:30:35.073461Z",
|
||||
"start_time": "2017-10-15T02:29:49.726Z"
|
||||
},
|
||||
"collapsed": true
|
||||
},
|
||||
@@ -857,8 +990,8 @@
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2017-10-15T01:42:46.508945Z",
|
||||
"start_time": "2017-10-15T01:42:13.093Z"
|
||||
"end_time": "2017-10-15T02:30:35.074688Z",
|
||||
"start_time": "2017-10-15T02:29:49.729Z"
|
||||
},
|
||||
"collapsed": true
|
||||
},
|
||||
@@ -921,8 +1054,8 @@
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2017-10-15T01:42:46.509761Z",
|
||||
"start_time": "2017-10-15T01:42:13.095Z"
|
||||
"end_time": "2017-10-15T02:30:35.075752Z",
|
||||
"start_time": "2017-10-15T02:29:49.733Z"
|
||||
},
|
||||
"collapsed": true
|
||||
},
|
||||
@@ -973,8 +1106,8 @@
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2017-10-15T01:42:46.510600Z",
|
||||
"start_time": "2017-10-15T01:42:13.096Z"
|
||||
"end_time": "2017-10-15T02:30:35.076620Z",
|
||||
"start_time": "2017-10-15T02:29:49.737Z"
|
||||
},
|
||||
"collapsed": true
|
||||
},
|
||||
@@ -997,8 +1130,8 @@
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2017-10-15T01:42:46.511486Z",
|
||||
"start_time": "2017-10-15T01:42:13.098Z"
|
||||
"end_time": "2017-10-15T02:30:35.077480Z",
|
||||
"start_time": "2017-10-15T02:29:49.742Z"
|
||||
},
|
||||
"collapsed": true
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user