mirror of
https://github.com/wassname/DeepTime.git
synced 2026-06-27 18:40:49 +08:00
misc
This commit is contained in:
+6
-4
@@ -32,6 +32,7 @@ class DeepTIMe3(nn.Module):
|
||||
# nf=32, depth=6,
|
||||
nf=17, depth=3,
|
||||
bn=True,
|
||||
dilation=6,
|
||||
ks=[39, 19, 3],
|
||||
coord=True, fc_dropout=dropout,
|
||||
)
|
||||
@@ -57,17 +58,18 @@ class DeepTIMe3(nn.Module):
|
||||
representation = decode(h_past, coords)
|
||||
i = length of past, so we can offset the coords
|
||||
"""
|
||||
|
||||
# we summarize the past into a single hidden layer. Then repeat it for each coordinate
|
||||
past_len = time.shape[1]
|
||||
encoded_x = self.encoder(past_x.transpose(2, 1))
|
||||
encoded_x = repeat(encoded_x, "b f -> b t f", t=past_len)
|
||||
|
||||
# relative coordinates are the same for each batch, so we make them once and repeat them
|
||||
past_len = time.shape[1]
|
||||
encoded_x = repeat(encoded_x, "b f -> b t f", t=past_len)
|
||||
coords = self.get_coords(past_len).to(time.device) + offset
|
||||
coords = repeat(coords, "1 t 1 -> b t 1", b=time.shape[0])
|
||||
|
||||
|
||||
# combine and run INR to decode the representation
|
||||
context_input = torch.cat([encoded_x, coords, time], dim=-1)
|
||||
|
||||
context_repr = self.inr(context_input)
|
||||
return context_repr
|
||||
|
||||
|
||||
+91
-113
File diff suppressed because one or more lines are too long
+132
-5
@@ -18,7 +18,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "8f9ebcf0",
|
||||
"id": "7f9e3d73",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-22T02:31:36.717738Z",
|
||||
@@ -248,7 +248,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 42,
|
||||
"id": "bca39ab0",
|
||||
"id": "83dca123",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-22T07:05:16.231276Z",
|
||||
@@ -290,6 +290,83 @@
|
||||
"1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 49,
|
||||
"id": "cfc602db",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-22T07:08:30.855574Z",
|
||||
"start_time": "2022-11-22T07:08:30.549156Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "AttributeError",
|
||||
"evalue": "",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn [49], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mplot_multi\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[43m \u001b[49m\u001b[43msave_paths\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m \u001b[49m\u001b[43mPath\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mstorage/experiments/Stocks/96M/repeat=0\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;43;03m# Path(\"storage/experiments/Stocks/96Mplus/repeat=0\"),\u001b[39;49;00m\n\u001b[1;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 6\u001b[0m \u001b[43m \u001b[49m\u001b[43mi\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m60\u001b[39;49m\n\u001b[1;32m 7\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;241m1\u001b[39m\n",
|
||||
"Cell \u001b[0;32mIn [39], line 18\u001b[0m, in \u001b[0;36mplot_multi\u001b[0;34m(save_paths, i, title, plot)\u001b[0m\n\u001b[1;32m 14\u001b[0m model\u001b[38;5;241m.\u001b[39mload_state_dict(torch\u001b[38;5;241m.\u001b[39mload(save_path\u001b[38;5;241m/\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmodel.pth\u001b[39m\u001b[38;5;124m'\u001b[39m))\n\u001b[1;32m 15\u001b[0m model \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39meval()\n\u001b[0;32m---> 18\u001b[0m b \u001b[38;5;241m=\u001b[39m \u001b[43mtrain_set\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 19\u001b[0m b \u001b[38;5;241m=\u001b[39m [bb[\u001b[38;5;28;01mNone\u001b[39;00m, :] \u001b[38;5;28;01mfor\u001b[39;00m bb \u001b[38;5;129;01min\u001b[39;00m b]\n\u001b[1;32m 21\u001b[0m b \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mnext\u001b[39m(\u001b[38;5;28miter\u001b[39m(train_loader))\n",
|
||||
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/investing/DeepTime/data/datasets.py:145\u001b[0m, in \u001b[0;36mForecastDataset.__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m 143\u001b[0m cx_start \u001b[38;5;241m=\u001b[39m idx\n\u001b[1;32m 144\u001b[0m cx_end \u001b[38;5;241m=\u001b[39m cx_start \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlookback_len\n\u001b[0;32m--> 145\u001b[0m c_start \u001b[38;5;241m=\u001b[39m cx_end \u001b[38;5;241m+\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgap\u001b[49m\n\u001b[1;32m 146\u001b[0m c_end \u001b[38;5;241m=\u001b[39m c_start \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhorizon_len\n\u001b[1;32m 148\u001b[0m qx_start \u001b[38;5;241m=\u001b[39m cx_end \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgap\n",
|
||||
"File \u001b[0;32m~/miniforge3/envs/deeptime/lib/python3.8/site-packages/torch/utils/data/dataset.py:83\u001b[0m, in \u001b[0;36mDataset.__getattr__\u001b[0;34m(self, attribute_name)\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m function\n\u001b[1;32m 82\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 83\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m\n",
|
||||
"\u001b[0;31mAttributeError\u001b[0m: "
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"plot_multi(\n",
|
||||
" save_paths=[\n",
|
||||
" Path(\"storage/experiments/Stocks/96M/repeat=0\"),\n",
|
||||
"# Path(\"storage/experiments/Stocks/96Mplus/repeat=0\"),\n",
|
||||
" ],\n",
|
||||
" i=60\n",
|
||||
" )\n",
|
||||
"1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 50,
|
||||
"id": "b3b005c0",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-22T07:08:43.569996Z",
|
||||
"start_time": "2022-11-22T07:08:39.011776Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"> \u001b[0;32m/home/wassname/miniforge3/envs/deeptime/lib/python3.8/site-packages/torch/utils/data/dataset.py\u001b[0m(83)\u001b[0;36m__getattr__\u001b[0;34m()\u001b[0m\n",
|
||||
"\u001b[0;32m 81 \u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunction\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 82 \u001b[0;31m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m---> 83 \u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mAttributeError\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 84 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 85 \u001b[0;31m \u001b[0;34m@\u001b[0m\u001b[0mclassmethod\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\n",
|
||||
"ipdb> u\n",
|
||||
"> \u001b[0;32m/media/wassname/SGIronWolf/projects5/investing/DeepTime/data/datasets.py\u001b[0m(145)\u001b[0;36m__getitem__\u001b[0;34m()\u001b[0m\n",
|
||||
"\u001b[0;32m 143 \u001b[0;31m \u001b[0mcx_start\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 144 \u001b[0;31m \u001b[0mcx_end\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcx_start\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlookback_len\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m--> 145 \u001b[0;31m \u001b[0mc_start\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcx_end\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgap\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 146 \u001b[0;31m \u001b[0mc_end\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mc_start\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhorizon_len\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 147 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\n",
|
||||
"ipdb> self.gap\n",
|
||||
"*** AttributeError\n",
|
||||
"ipdb> q\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%debug"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 43,
|
||||
@@ -334,6 +411,56 @@
|
||||
"1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 44,
|
||||
"id": "7486f672",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-22T07:07:06.097048Z",
|
||||
"start_time": "2022-11-22T07:06:49.173386Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"> \u001b[0;32m/media/wassname/SGIronWolf/projects5/investing/DeepTime/data/datasets.py\u001b[0m(105)\u001b[0;36mload_data\u001b[0;34m()\u001b[0m\n",
|
||||
"\u001b[0;32m 103 \u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata_x\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mborder1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mborder2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 104 \u001b[0;31m \u001b[0;31m# y is just the col we predict\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m--> 105 \u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata_y\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mborder1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mborder2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtarget\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 106 \u001b[0;31m self.timestamps = get_time_features(pd.to_datetime(df_raw.date[border1:border2].values),\n",
|
||||
"\u001b[0m\u001b[0;32m 107 \u001b[0;31m \u001b[0mnormalise\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnormalise_time_features\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\n",
|
||||
"ipdb> self.target\n",
|
||||
"'RSMKs_18_144_72'\n",
|
||||
"ipdb> data[border1:border2]\n",
|
||||
"array([[-0.31080395],\n",
|
||||
" [-0.30072371],\n",
|
||||
" [-0.29116544],\n",
|
||||
" ...,\n",
|
||||
" [-0.14084614],\n",
|
||||
" [-0.15264537],\n",
|
||||
" [-0.16396183]])\n",
|
||||
"ipdb> data\n",
|
||||
"array([[ 0.08264075],\n",
|
||||
" [ 0.08548946],\n",
|
||||
" [ 0.08766026],\n",
|
||||
" ...,\n",
|
||||
" [-0.14084614],\n",
|
||||
" [-0.15264537],\n",
|
||||
" [-0.16396183]])\n",
|
||||
"ipdb> data.shape\n",
|
||||
"(53398, 1)\n",
|
||||
"ipdb> q\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%debug"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 38,
|
||||
@@ -404,7 +531,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "69e00093",
|
||||
"id": "7fee5ab7",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-20T00:57:32.922740Z",
|
||||
@@ -500,7 +627,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ba05f2b0",
|
||||
"id": "e8680326",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
@@ -508,7 +635,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6877ee38",
|
||||
"id": "c20bb663",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
|
||||
Reference in New Issue
Block a user