This commit is contained in:
wassname
2020-10-29 12:22:47 +08:00
parent 052fd6596c
commit 4cbc7b4073
12 changed files with 12050 additions and 700 deletions
File diff suppressed because it is too large Load Diff
+19 -1
View File
@@ -311,6 +311,18 @@ results = defaultdict(dict)
from seq2seq_time.metrics import rmse, smape
for Dataset in datasets:
dataset_name = Dataset.__name__
dataset = Dataset(datasets_root)
ds_train, ds_test = dataset.to_datasets(window_past=window_past,
window_future=window_future)
# Init data
x_past, y_past, x_future, y_future = ds_train.get_rows(10)
input_size = x_past.shape[-1]
output_size = y_future.shape[-1]
# +
for Dataset in datasets:
dataset_name = Dataset.__name__
@@ -397,7 +409,13 @@ df_results = pd.concat({k:pd.DataFrame(v) for k,v in results.items()})
display(df_results)
# +
# EarlyStopping?
# File "/media/wassname/Storage5/projects2/3ST/seq2seq-time/seq2seq_time/models/transformer.py", line 54, in forward
# outputs = self.encoder(x, mask=mask#, src_key_padding_mask=x_key_padding_mask
# File "/media/wassname/Storage5/projects2/3ST/seq2seq-time/seq2seq_time/models/transformer.py", line 54, in forward
# outputs = self.encoder(x, mask=mask#, src_key_padding_mask=x_key_padding_mask
# -
df_results.xs('nll', level=1).round(2)
# +
# ds_preds.to_netcdf(trainer.logger.experiment.log_dir+'/ds_preds2.nc')