mirror of
https://github.com/wassname/seq2seq-time.git
synced 2026-06-27 18:06:49 +08:00
readme
This commit is contained in:
+12025
-48
File diff suppressed because it is too large
Load Diff
@@ -311,6 +311,18 @@ results = defaultdict(dict)
|
||||
|
||||
from seq2seq_time.metrics import rmse, smape
|
||||
|
||||
for Dataset in datasets:
|
||||
dataset_name = Dataset.__name__
|
||||
dataset = Dataset(datasets_root)
|
||||
ds_train, ds_test = dataset.to_datasets(window_past=window_past,
|
||||
window_future=window_future)
|
||||
|
||||
# Init data
|
||||
x_past, y_past, x_future, y_future = ds_train.get_rows(10)
|
||||
input_size = x_past.shape[-1]
|
||||
output_size = y_future.shape[-1]
|
||||
|
||||
|
||||
# +
|
||||
for Dataset in datasets:
|
||||
dataset_name = Dataset.__name__
|
||||
@@ -397,7 +409,13 @@ df_results = pd.concat({k:pd.DataFrame(v) for k,v in results.items()})
|
||||
display(df_results)
|
||||
|
||||
# +
|
||||
# EarlyStopping?
|
||||
# File "/media/wassname/Storage5/projects2/3ST/seq2seq-time/seq2seq_time/models/transformer.py", line 54, in forward
|
||||
# outputs = self.encoder(x, mask=mask#, src_key_padding_mask=x_key_padding_mask
|
||||
# File "/media/wassname/Storage5/projects2/3ST/seq2seq-time/seq2seq_time/models/transformer.py", line 54, in forward
|
||||
# outputs = self.encoder(x, mask=mask#, src_key_padding_mask=x_key_padding_mask
|
||||
# -
|
||||
|
||||
df_results.xs('nll', level=1).round(2)
|
||||
|
||||
# +
|
||||
# ds_preds.to_netcdf(trainer.logger.experiment.log_dir+'/ds_preds2.nc')
|
||||
|
||||
Reference in New Issue
Block a user