This commit is contained in:
wassname
2022-11-22 16:58:41 +08:00
parent 9072613d03
commit 8818a2b17b
3 changed files with 229 additions and 122 deletions
+6 -4
View File
@@ -32,6 +32,7 @@ class DeepTIMe3(nn.Module):
# nf=32, depth=6,
nf=17, depth=3,
bn=True,
dilation=6,
ks=[39, 19, 3],
coord=True, fc_dropout=dropout,
)
@@ -57,17 +58,18 @@ class DeepTIMe3(nn.Module):
representation = decode(h_past, coords)
i = length of past, so we can offset the coords
"""
# we summarize the past into a single hidden layer. Then repeat it for each coordinate
past_len = time.shape[1]
encoded_x = self.encoder(past_x.transpose(2, 1))
encoded_x = repeat(encoded_x, "b f -> b t f", t=past_len)
# relative coordinates are the same for each batch, so we make them once and repeat them
past_len = time.shape[1]
encoded_x = repeat(encoded_x, "b f -> b t f", t=past_len)
coords = self.get_coords(past_len).to(time.device) + offset
coords = repeat(coords, "1 t 1 -> b t 1", b=time.shape[0])
# combine and run INR to decode the representation
context_input = torch.cat([encoded_x, coords, time], dim=-1)
context_repr = self.inr(context_input)
return context_repr
+91 -113
View File
File diff suppressed because one or more lines are too long
+132 -5
View File
@@ -18,7 +18,7 @@
{
"cell_type": "code",
"execution_count": 1,
"id": "8f9ebcf0",
"id": "7f9e3d73",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-22T02:31:36.717738Z",
@@ -248,7 +248,7 @@
{
"cell_type": "code",
"execution_count": 42,
"id": "bca39ab0",
"id": "83dca123",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-22T07:05:16.231276Z",
@@ -290,6 +290,83 @@
"1"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "cfc602db",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-22T07:08:30.855574Z",
"start_time": "2022-11-22T07:08:30.549156Z"
}
},
"outputs": [
{
"ename": "AttributeError",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn [49], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mplot_multi\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[43m \u001b[49m\u001b[43msave_paths\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m \u001b[49m\u001b[43mPath\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mstorage/experiments/Stocks/96M/repeat=0\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;43;03m# Path(\"storage/experiments/Stocks/96Mplus/repeat=0\"),\u001b[39;49;00m\n\u001b[1;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 6\u001b[0m \u001b[43m \u001b[49m\u001b[43mi\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m60\u001b[39;49m\n\u001b[1;32m 7\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;241m1\u001b[39m\n",
"Cell \u001b[0;32mIn [39], line 18\u001b[0m, in \u001b[0;36mplot_multi\u001b[0;34m(save_paths, i, title, plot)\u001b[0m\n\u001b[1;32m 14\u001b[0m model\u001b[38;5;241m.\u001b[39mload_state_dict(torch\u001b[38;5;241m.\u001b[39mload(save_path\u001b[38;5;241m/\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmodel.pth\u001b[39m\u001b[38;5;124m'\u001b[39m))\n\u001b[1;32m 15\u001b[0m model \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39meval()\n\u001b[0;32m---> 18\u001b[0m b \u001b[38;5;241m=\u001b[39m \u001b[43mtrain_set\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 19\u001b[0m b \u001b[38;5;241m=\u001b[39m [bb[\u001b[38;5;28;01mNone\u001b[39;00m, :] \u001b[38;5;28;01mfor\u001b[39;00m bb \u001b[38;5;129;01min\u001b[39;00m b]\n\u001b[1;32m 21\u001b[0m b \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mnext\u001b[39m(\u001b[38;5;28miter\u001b[39m(train_loader))\n",
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/investing/DeepTime/data/datasets.py:145\u001b[0m, in \u001b[0;36mForecastDataset.__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m 143\u001b[0m cx_start \u001b[38;5;241m=\u001b[39m idx\n\u001b[1;32m 144\u001b[0m cx_end \u001b[38;5;241m=\u001b[39m cx_start \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlookback_len\n\u001b[0;32m--> 145\u001b[0m c_start \u001b[38;5;241m=\u001b[39m cx_end \u001b[38;5;241m+\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgap\u001b[49m\n\u001b[1;32m 146\u001b[0m c_end \u001b[38;5;241m=\u001b[39m c_start \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhorizon_len\n\u001b[1;32m 148\u001b[0m qx_start \u001b[38;5;241m=\u001b[39m cx_end \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgap\n",
"File \u001b[0;32m~/miniforge3/envs/deeptime/lib/python3.8/site-packages/torch/utils/data/dataset.py:83\u001b[0m, in \u001b[0;36mDataset.__getattr__\u001b[0;34m(self, attribute_name)\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m function\n\u001b[1;32m 82\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 83\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m\n",
"\u001b[0;31mAttributeError\u001b[0m: "
]
}
],
"source": [
"plot_multi(\n",
" save_paths=[\n",
" Path(\"storage/experiments/Stocks/96M/repeat=0\"),\n",
"# Path(\"storage/experiments/Stocks/96Mplus/repeat=0\"),\n",
" ],\n",
" i=60\n",
" )\n",
"1"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "b3b005c0",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-22T07:08:43.569996Z",
"start_time": "2022-11-22T07:08:39.011776Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"> \u001b[0;32m/home/wassname/miniforge3/envs/deeptime/lib/python3.8/site-packages/torch/utils/data/dataset.py\u001b[0m(83)\u001b[0;36m__getattr__\u001b[0;34m()\u001b[0m\n",
"\u001b[0;32m 81 \u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunction\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m 82 \u001b[0;31m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m---> 83 \u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mAttributeError\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m 84 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m 85 \u001b[0;31m \u001b[0;34m@\u001b[0m\u001b[0mclassmethod\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\n",
"ipdb> u\n",
"> \u001b[0;32m/media/wassname/SGIronWolf/projects5/investing/DeepTime/data/datasets.py\u001b[0m(145)\u001b[0;36m__getitem__\u001b[0;34m()\u001b[0m\n",
"\u001b[0;32m 143 \u001b[0;31m \u001b[0mcx_start\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m 144 \u001b[0;31m \u001b[0mcx_end\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcx_start\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlookback_len\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m--> 145 \u001b[0;31m \u001b[0mc_start\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcx_end\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgap\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m 146 \u001b[0;31m \u001b[0mc_end\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mc_start\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhorizon_len\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m 147 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\n",
"ipdb> self.gap\n",
"*** AttributeError\n",
"ipdb> q\n"
]
}
],
"source": [
"%debug"
]
},
{
"cell_type": "code",
"execution_count": 43,
@@ -334,6 +411,56 @@
"1"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "7486f672",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-22T07:07:06.097048Z",
"start_time": "2022-11-22T07:06:49.173386Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"> \u001b[0;32m/media/wassname/SGIronWolf/projects5/investing/DeepTime/data/datasets.py\u001b[0m(105)\u001b[0;36mload_data\u001b[0;34m()\u001b[0m\n",
"\u001b[0;32m 103 \u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata_x\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mborder1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mborder2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m 104 \u001b[0;31m \u001b[0;31m# y is just the col we predict\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m--> 105 \u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata_y\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mborder1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mborder2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtarget\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m 106 \u001b[0;31m self.timestamps = get_time_features(pd.to_datetime(df_raw.date[border1:border2].values),\n",
"\u001b[0m\u001b[0;32m 107 \u001b[0;31m \u001b[0mnormalise\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnormalise_time_features\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\n",
"ipdb> self.target\n",
"'RSMKs_18_144_72'\n",
"ipdb> data[border1:border2]\n",
"array([[-0.31080395],\n",
" [-0.30072371],\n",
" [-0.29116544],\n",
" ...,\n",
" [-0.14084614],\n",
" [-0.15264537],\n",
" [-0.16396183]])\n",
"ipdb> data\n",
"array([[ 0.08264075],\n",
" [ 0.08548946],\n",
" [ 0.08766026],\n",
" ...,\n",
" [-0.14084614],\n",
" [-0.15264537],\n",
" [-0.16396183]])\n",
"ipdb> data.shape\n",
"(53398, 1)\n",
"ipdb> q\n"
]
}
],
"source": [
"%debug"
]
},
{
"cell_type": "code",
"execution_count": 38,
@@ -404,7 +531,7 @@
},
{
"cell_type": "markdown",
"id": "69e00093",
"id": "7fee5ab7",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-20T00:57:32.922740Z",
@@ -500,7 +627,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "ba05f2b0",
"id": "e8680326",
"metadata": {},
"outputs": [],
"source": []
@@ -508,7 +635,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "6877ee38",
"id": "c20bb663",
"metadata": {},
"outputs": [],
"source": []