mirror of
https://github.com/wassname/seq2seq-time.git
synced 2026-06-27 19:32:35 +08:00
84 lines
3.7 KiB
Python
84 lines
3.7 KiB
Python
import xarray as xr
|
|
import torch
|
|
from tqdm.auto import tqdm
|
|
import pandas as pd
|
|
|
|
from .util import to_numpy
|
|
|
|
def predict(model, ds_test, batch_size, device='cpu', scaler=None):
    """
    Gather all predictions for one dataset into an xarray.Dataset.

    When we generate a prediction with a sequence to sequence model we start at a
    source time and predict N steps into the future, so there are two time
    dimensions: source time and target time. We also care about how far ahead each
    step was predicted, which gives a third: source time, target time, time ahead.

    Pandas handles data with such derived ("virtual") dimensions poorly, so we use
    xarray. Its interface is similar to pandas but it also supports non-index
    coordinates, which is how the derived dimensions are stored here.

    Args:
        model: torch module called as ``model(x_past, y_past, x_future)`` that returns
            ``(y_dist, extra)``. ``y_dist`` must provide ``log_prob``, ``loc`` and
            ``scale`` (presumably a ``torch.distributions.Normal`` - confirm for new
            models). The model is switched to eval mode but is NOT moved to ``device``.
        ds_test: dataset yielding ``(x_past, y_past, x_future, y_future)`` batches and
            exposing ``df`` (whose index has a ``freq``), ``window_past``,
            ``window_future``, ``freq`` and ``columns_target``.
        batch_size: batch size for the (unshuffled) DataLoader.
        device: device each batch tensor is moved to.
        scaler: optional fitted scaler providing ``inverse_transform`` and ``scale_``
            (e.g. sklearn's StandardScaler); used to undo target scaling.

    Returns:
        xarray.Dataset with data variables ``y_past`` (t_source, t_behind) and
        ``nll``, ``y_pred``, ``y_pred_std``, ``y_true`` (t_source, t_ahead), plus the
        derived coordinates ``t_target``, ``t_past`` and ``t_ahead_hours``.

    Raises:
        ValueError: if ``ds_test`` yields no batches.
    """
    load_test = torch.utils.data.dataloader.DataLoader(ds_test, batch_size=batch_size)
    freq = ds_test.df.index.freq
    wp = ds_test.window_past

    # The time offsets are identical for every batch, so build them once.
    # Future steps lie 1..window_future periods after the source time. Start at 0 and
    # drop it: an integer start would be read as nanoseconds, not periods.
    t_ahead_index = pd.timedelta_range(start=0, periods=ds_test.window_future + 1, freq=freq)[1:]
    t_ahead = t_ahead_index.values
    # Past steps end at the source time itself: -(window_past - 1) * freq .. 0
    t_behind = pd.timedelta_range(end=0, periods=wp, freq=freq)
    # Some plots don't like timedeltas, so also keep time ahead in (fractional) hours.
    # Divide by a Timedelta: scaling raw timedelta64 values keeps integer nanoseconds
    # and truncates sub-hour offsets.
    t_ahead_hours = (t_ahead_index / pd.Timedelta(hours=1)).to_numpy(dtype=float)

    model.eval()
    xrs = []
    n_seen = 0  # samples processed so far = row offset of the current batch
    with torch.no_grad():
        for batch in tqdm(load_test, desc='predict', leave=False):
            x_past, y_past, x_future, y_future = [d.to(device) for d in batch]
            y_dist, _extra = model(x_past, y_past, x_future)
            nll = -y_dist.log_prob(y_future)

            # Convert to numpy, dropping the trailing singleton target axis
            mean = to_numpy(y_dist.loc.squeeze(-1))
            std = to_numpy(y_dist.scale.squeeze(-1))
            nll = to_numpy(nll.squeeze(-1))
            y_future = to_numpy(y_future.squeeze(-1))
            y_past = to_numpy(y_past.squeeze(-1))

            # Source time of sample j is taken as df.index[wp - 1 + j], the last past
            # step (presumably how ds_test windows df - confirm). Use the running count
            # rather than batch_index * len(batch): the final batch is usually smaller,
            # which would misplace its timestamps.
            bs = y_future.shape[0]
            start = wp - 1 + n_seen
            t_source = ds_test.df.index[start:start + bs].values
            n_seen += bs

            # Make an xarray.Dataset for the data
            xr_out = xr.Dataset(
                {
                    # Format> name: ([dimensions,...], array),
                    "y_past": (["t_source", "t_behind"], y_past),
                    "nll": (["t_source", "t_ahead"], nll),
                    "y_pred": (["t_source", "t_ahead"], mean),
                    "y_pred_std": (["t_source", "t_ahead"], std),
                    "y_true": (["t_source", "t_ahead"], y_future),
                },
                coords={"t_source": t_source, "t_ahead": t_ahead, "t_behind": t_behind},
                attrs={'freq': str(ds_test.freq), "model": str(type(model)), "targets": ds_test.columns_target}
            )
            xrs.append(xr_out)

    if not xrs:
        raise ValueError("predict: ds_test yielded no batches, nothing to predict")

    # Join all batches
    ds_preds = xr.concat(xrs, dim="t_source")

    # undo scaling on y
    # NOTE(review): inverse_transform is applied to (t_source, time) arrays, which
    # presumably relies on a single-target scaler broadcasting over the time axis.
    if scaler:
        ds_preds['y_pred_std'].values = ds_preds.y_pred_std * scaler.scale_
        ds_preds['y_past'].values = scaler.inverse_transform(ds_preds.y_past)
        ds_preds['y_pred'].values = scaler.inverse_transform(ds_preds.y_pred)
        ds_preds['y_true'].values = scaler.inverse_transform(ds_preds.y_true)

    # Add some derived coordinates, they will be the ones not in bold
    # The target time is a function of the source time and how far we predict ahead
    ds_preds = ds_preds.assign_coords(t_target=ds_preds.t_source + ds_preds.t_ahead)
    ds_preds = ds_preds.assign_coords(t_past=ds_preds.t_source + ds_preds.t_behind)
    ds_preds = ds_preds.assign_coords(t_ahead_hours=("t_ahead", t_ahead_hours))
    return ds_preds
|
|
|
|
def predict_multi(model, datasets, batch_size, device='cpu', scaler=None):
    """Run `predict` on every dataset and stack the results.

    Args:
        model: model handed to `predict`; ``model.to(device)`` is called before
            each dataset is processed.
        datasets: iterable of datasets accepted by `predict`.
        batch_size: forwarded unchanged to `predict`.
        device: forwarded unchanged to `predict`.
        scaler: forwarded unchanged to `predict`.

    Returns:
        The per-dataset results concatenated along a new ``block`` dimension,
        one entry per dataset in iteration order.
    """
    per_block = []
    for dataset in tqdm(datasets, desc='predict_multi'):
        block_preds = predict(
            model.to(device),
            dataset,
            batch_size,
            device=device,
            scaler=scaler,
        )
        per_block.append(block_preds)
    return xr.concat(per_block, dim='block')
|