mirror of
https://github.com/wassname/DeepTime.git
synced 2026-06-27 20:02:21 +08:00
misc
This commit is contained in:
@@ -46,3 +46,20 @@ python -m experiments.forecast --config_path=storage/experiments/Stocks/96S/repe
|
||||
make build-all path=experiments/configs/Stocks
|
||||
./run.sh
|
||||
```
|
||||
|
||||
# So how does deeptime work?
|
||||
|
||||
original:
|
||||
- inr(coords)
|
||||
- RR
|
||||
|
||||
My mods (I added past other variables):
|
||||
- inr(concat([x, coords]))
|
||||
- RR
|
||||
|
||||
Where INR is one of [mlp, lstm, lstm2, transformer, transforme2, inceptioncausal]
|
||||
|
||||
TODO:
|
||||
- M2S mode
|
||||
- add other INR's
|
||||
- add None as learner
|
||||
|
||||
+7
-2
@@ -36,18 +36,23 @@ class DeepTIMe(nn.Module):
|
||||
def forward(self, x: Tensor, x_time: Tensor, y_time: Tensor) -> Tensor:
|
||||
tgt_horizon_len = y_time.shape[1]
|
||||
batch_size, lookback_len, _ = x.shape
|
||||
|
||||
# relative coordinates are the same for each batch, so we make them once and repeat them
|
||||
coords = self.get_coords(lookback_len, tgt_horizon_len).to(x.device)
|
||||
|
||||
if y_time.shape[-1] != 0:
|
||||
if y_time.shape[-1] != 0: # if y_time is not empty of features
|
||||
time = torch.cat([x_time, y_time], dim=1)
|
||||
coords = repeat(coords, '1 t 1 -> b t 1', b=time.shape[0])
|
||||
coords = torch.cat([coords, time], dim=-1)
|
||||
time_reprs = self.inr(coords)
|
||||
else:
|
||||
time_reprs = repeat(self.inr(coords), '1 t d -> b t d', b=batch_size)
|
||||
a = self.inr(coords)
|
||||
time_reprs = repeat(a, '1 t d -> b t d', b=batch_size)
|
||||
print(coords.shape, a.shape, batch_size, time_reprs.shape)
|
||||
|
||||
lookback_reprs = time_reprs[:, :-tgt_horizon_len]
|
||||
horizon_reprs = time_reprs[:, -tgt_horizon_len:]
|
||||
|
||||
w, b = self.adaptive_weights(lookback_reprs, x)
|
||||
preds = self.forecast(horizon_reprs, w, b)
|
||||
return preds
|
||||
|
||||
File diff suppressed because one or more lines are too long
+19
-19
@@ -6,8 +6,8 @@
|
||||
"id": "4e09086b",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-20T05:05:01.730258Z",
|
||||
"start_time": "2022-11-20T05:05:00.558925Z"
|
||||
"end_time": "2022-11-20T10:31:28.624566Z",
|
||||
"start_time": "2022-11-20T10:31:27.336916Z"
|
||||
},
|
||||
"lines_to_next_cell": 0
|
||||
},
|
||||
@@ -71,8 +71,8 @@
|
||||
"id": "04499bef",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-20T05:05:01.737023Z",
|
||||
"start_time": "2022-11-20T05:05:01.731574Z"
|
||||
"end_time": "2022-11-20T10:31:28.634463Z",
|
||||
"start_time": "2022-11-20T10:31:28.626256Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -125,12 +125,12 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"execution_count": 3,
|
||||
"id": "768530be",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-20T06:52:29.098202Z",
|
||||
"start_time": "2022-11-20T06:52:29.093424Z"
|
||||
"end_time": "2022-11-20T10:31:28.667218Z",
|
||||
"start_time": "2022-11-20T10:31:28.635969Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -184,7 +184,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "739ee5e3",
|
||||
"id": "f6da82f9",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-20T06:50:03.264769Z",
|
||||
@@ -196,12 +196,12 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"execution_count": 4,
|
||||
"id": "b8843217",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-20T06:52:29.508811Z",
|
||||
"start_time": "2022-11-20T06:52:29.353139Z"
|
||||
"end_time": "2022-11-20T10:31:32.788856Z",
|
||||
"start_time": "2022-11-20T10:31:28.668594Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
@@ -221,7 +221,7 @@
|
||||
"1"
|
||||
]
|
||||
},
|
||||
"execution_count": 35,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
},
|
||||
@@ -249,12 +249,12 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"execution_count": 5,
|
||||
"id": "d39ba7c5",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-20T06:52:41.798916Z",
|
||||
"start_time": "2022-11-20T06:52:41.260013Z"
|
||||
"end_time": "2022-11-20T10:31:33.408193Z",
|
||||
"start_time": "2022-11-20T10:31:32.789874Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
@@ -275,7 +275,7 @@
|
||||
"1"
|
||||
]
|
||||
},
|
||||
"execution_count": 36,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
},
|
||||
@@ -301,12 +301,12 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 6,
|
||||
"id": "9af56ddc",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2022-11-20T06:28:51.826377Z",
|
||||
"start_time": "2022-11-20T06:28:51.721432Z"
|
||||
"end_time": "2022-11-20T10:31:33.539388Z",
|
||||
"start_time": "2022-11-20T10:31:33.409217Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
|
||||
Reference in New Issue
Block a user