This commit is contained in:
wassname
2022-11-22 15:07:42 +08:00
parent f62f6b7ff0
commit 00ae1d8b8f
4 changed files with 599 additions and 21 deletions
+17
View File
@@ -46,3 +46,20 @@ python -m experiments.forecast --config_path=storage/experiments/Stocks/96S/repe
make build-all path=experiments/configs/Stocks
./run.sh
```
# So how does deeptime work?
original:
- inr(coords)
- RR
My mods (I added past other variables):
- inr(concat([x, coords]))
- RR
Where INR is one of [mlp, lstm, lstm2, transformer, transforme2, inceptioncausal]
TODO:
- M2S mode
- add other INR's
- add None as learner
+7 -2
View File
@@ -36,18 +36,23 @@ class DeepTIMe(nn.Module):
def forward(self, x: Tensor, x_time: Tensor, y_time: Tensor) -> Tensor:
tgt_horizon_len = y_time.shape[1]
batch_size, lookback_len, _ = x.shape
# relative coordinates are the same for each batch, so we make them once and repeat them
coords = self.get_coords(lookback_len, tgt_horizon_len).to(x.device)
if y_time.shape[-1] != 0:
if y_time.shape[-1] != 0: # if y_time is not empty of features
time = torch.cat([x_time, y_time], dim=1)
coords = repeat(coords, '1 t 1 -> b t 1', b=time.shape[0])
coords = torch.cat([coords, time], dim=-1)
time_reprs = self.inr(coords)
else:
time_reprs = repeat(self.inr(coords), '1 t d -> b t d', b=batch_size)
a = self.inr(coords)
time_reprs = repeat(a, '1 t d -> b t d', b=batch_size)
print(coords.shape, a.shape, batch_size, time_reprs.shape)
lookback_reprs = time_reprs[:, :-tgt_horizon_len]
horizon_reprs = time_reprs[:, -tgt_horizon_len:]
w, b = self.adaptive_weights(lookback_reprs, x)
preds = self.forecast(horizon_reprs, w, b)
return preds
File diff suppressed because one or more lines are too long
+19 -19
View File
@@ -6,8 +6,8 @@
"id": "4e09086b",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-20T05:05:01.730258Z",
"start_time": "2022-11-20T05:05:00.558925Z"
"end_time": "2022-11-20T10:31:28.624566Z",
"start_time": "2022-11-20T10:31:27.336916Z"
},
"lines_to_next_cell": 0
},
@@ -71,8 +71,8 @@
"id": "04499bef",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-20T05:05:01.737023Z",
"start_time": "2022-11-20T05:05:01.731574Z"
"end_time": "2022-11-20T10:31:28.634463Z",
"start_time": "2022-11-20T10:31:28.626256Z"
}
},
"outputs": [],
@@ -125,12 +125,12 @@
},
{
"cell_type": "code",
"execution_count": 34,
"execution_count": 3,
"id": "768530be",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-20T06:52:29.098202Z",
"start_time": "2022-11-20T06:52:29.093424Z"
"end_time": "2022-11-20T10:31:28.667218Z",
"start_time": "2022-11-20T10:31:28.635969Z"
}
},
"outputs": [],
@@ -184,7 +184,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "739ee5e3",
"id": "f6da82f9",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-20T06:50:03.264769Z",
@@ -196,12 +196,12 @@
},
{
"cell_type": "code",
"execution_count": 35,
"execution_count": 4,
"id": "b8843217",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-20T06:52:29.508811Z",
"start_time": "2022-11-20T06:52:29.353139Z"
"end_time": "2022-11-20T10:31:32.788856Z",
"start_time": "2022-11-20T10:31:28.668594Z"
}
},
"outputs": [
@@ -221,7 +221,7 @@
"1"
]
},
"execution_count": 35,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
@@ -249,12 +249,12 @@
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": 5,
"id": "d39ba7c5",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-20T06:52:41.798916Z",
"start_time": "2022-11-20T06:52:41.260013Z"
"end_time": "2022-11-20T10:31:33.408193Z",
"start_time": "2022-11-20T10:31:32.789874Z"
}
},
"outputs": [
@@ -275,7 +275,7 @@
"1"
]
},
"execution_count": 36,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
@@ -301,12 +301,12 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 6,
"id": "9af56ddc",
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-20T06:28:51.826377Z",
"start_time": "2022-11-20T06:28:51.721432Z"
"end_time": "2022-11-20T10:31:33.539388Z",
"start_time": "2022-11-20T10:31:33.409217Z"
}
},
"outputs": [