lots of models, inc NP

This commit is contained in:
wassname
2020-10-19 20:51:23 +08:00
parent e986751e41
commit 7b6c729db5
12 changed files with 3554 additions and 1623 deletions
File diff suppressed because one or more lines are too long
+214 -158
View File
@@ -61,8 +61,11 @@ from tqdm.auto import tqdm
import pytorch_lightning as pl
# -
import warnings
warnings.simplefilter('once')
from seq2seq_time.data.dataset import Seq2SeqDataSet, Seq2SeqDataSets
from seq2seq_time.predict import predict
from seq2seq_time.predict import predict, predict_multi
import logging, sys
# logging.basicConfig(stream=sys.stdout, level=logging.INFO)
@@ -79,7 +82,7 @@ window_future = 48*2
batch_size = 256
num_workers = 5
freq = '30T'
max_rows = 2e5
max_rows = 5e5
# -
@@ -88,7 +91,7 @@ max_rows = 2e5
# +
def get_smartmeter_df(indir=Path('../data/raw/smart-meters-in-london'), max_files=1):
def get_smartmeter_df(indir=Path('../data/raw/smart-meters-in-london'), max_files=8):
"""
Data loading and cleaning is always messy, so understanding this code is optional.
"""
@@ -96,62 +99,65 @@ def get_smartmeter_df(indir=Path('../data/raw/smart-meters-in-london'), max_file
# Load csv files
csv_files = sorted((indir/'halfhourly_dataset').glob('*.csv'))[:max_files]
# concatenate them
df = pd.concat([pd.read_csv(f, parse_dates=[1], na_values=['Null']) for f in csv_files])
dfs = []
for f in csv_files:
df = (pd.read_csv(f, parse_dates=[1], na_values=['Null'])
.groupby('tstp')
.sum()
.sort_index()
)
df['block'] = f.stem
# Drop nan and 0's
df = df[df['energy(kWh/hh)']!=0]
df = df.dropna()
# Add time features
time = df.index.to_series()
df["month"] = time.dt.month
df['day'] = time.dt.day
df['week'] = time.dt.week
df['hour'] = time.dt.hour
df['minute'] = time.dt.minute
df['dayofweek'] = time.dt.dayofweek
# Load weather data
df_weather = pd.read_csv(indir/'weather_hourly_darksky.csv', parse_dates=[3])
use_cols = ['visibility', 'windBearing', 'temperature', 'time', 'dewPoint',
'pressure', 'apparentTemperature', 'windSpeed',
'humidity']
df_weather = df_weather[use_cols].set_index('time')
# Resample to match energy data
# Use first, since we have bearing, and you can't take mean
df_weather = df_weather.resample(freq).first().ffill()
# Join weather and energy data
df = pd.merge(df, df_weather, how='inner', left_index=True, right_index=True, sort=True)
# Holidays
df_hols = pd.read_csv(indir/'uk_bank_holidays.csv', parse_dates=[0])
holidays = set(df_hols['Bank holidays'].dt.round('D'))
def is_holiday(dt):
return dt in holidays
days = df.index.floor('D')
holiday_mapping = days.unique().to_series().apply(is_holiday).astype(int).to_dict()
df['holiday'] = days.to_series().map(holiday_mapping).values
# sort
df.index.name = 'Date'
df = df.loc['2012-09':] # Weird value before this
# Add ACORN categories
df_households = pd.read_csv(indir/'informations_households.csv')
df_households = df_households[['LCLid', 'stdorToU', 'Acorn_grouped']]
df = pd.merge(df, df_households, on='LCLid')
dfs.append(df)
df = df.sort_values(['tstp', 'LCLid'])
df = df.set_index('tstp')
# Drop nan and 0's
df = df[df['energy(kWh/hh)']!=0]
df = df.dropna()
# Add time features
time = df.index.to_series()
df["month"] = time.dt.month
df['day'] = time.dt.day
df['week'] = time.dt.week
df['hour'] = time.dt.hour
df['minute'] = time.dt.minute
df['dayofweek'] = time.dt.dayofweek
# Load weather data
df_weather = pd.read_csv(indir/'weather_hourly_darksky.csv', parse_dates=[3])
use_cols = ['visibility', 'windBearing', 'temperature', 'time', 'dewPoint',
'pressure', 'apparentTemperature', 'windSpeed',
'humidity']
df_weather = df_weather[use_cols].set_index('time')
df_weather = df_weather.resample(freq).first().ffill() # Resample to match energy data
# Join weather and energy data
df = pd.merge(df, df_weather, how='inner', left_index=True, right_index=True, sort=True)
# Holidays
df_hols = pd.read_csv(indir/'uk_bank_holidays.csv', parse_dates=[0])
holidays = set(df_hols['Bank holidays'].dt.round('D'))
def is_holiday(dt):
return dt in holidays
days = df.index.floor('D')
holiday_mapping = days.unique().to_series().apply(is_holiday).astype(int).to_dict()
df['holiday'] = days.to_series().map(holiday_mapping).values
# sort
df = df.reset_index().sort_values(['LCLid', 'index']).set_index('index')
df.index.name = 'Date'
return df
return pd.concat(dfs)
# -
# Our dataset is the London smart meter data, at half-hour intervals.
# +
df = get_smartmeter_df()
df = get_smartmeter_df(max_files=12)
# # Just get the first one for now
# dfs = list(dfs)
@@ -161,14 +167,15 @@ df = get_smartmeter_df()
df = df.tail(int(max_rows)).copy() # Just use last X rows
# df = pd.concat(dfs[:6], 0)
# # df = dfs[0]
df.LCLid.value_counts()
print(df.block.value_counts())
df
# -
# ### Plot/explore
df
@@ -186,22 +193,22 @@ from holoviews.operation import decimate
hv.extension('bokeh')
def house_curve(Name=None):
if isinstance(Name, int):
name = df.LCLid.unique()[Name]
d = df[df.LCLid == Name]
d_curve = hv.Curve(d, 'Date', 'energy(kWh/hh)', label=Name)
return d_curve
# def house_curve(Name=None):
# if isinstance(Name, int):
# name = df.block.unique()[Name]
# d = df[df.block == Name]
# d_curve = hv.Curve(d, 'Date', 'energy(kWh/hh)', label=Name).opts(framewise=True)
# return d_curve
dmap = hv.DynamicMap(house_curve, kdims=['Name'])
dmap = dmap.redim.values(Name=list(df.LCLid.unique()))
dynspread(datashade(dmap).opts(width=800,
height=300,
tools=['xwheel_zoom', 'pan'],
active_tools=['xwheel_zoom', 'pan'],
default_tools=['reset', 'save', 'hover']
))
# dmap = hv.DynamicMap(house_curve, kdims=['Name'])
# dmap = dmap.redim.values(Name=list(df.block.unique()))
# dynspread(datashade(dmap).opts(width=800,
# height=300,
# tools=['xwheel_zoom', 'pan'],
# active_tools=['xwheel_zoom', 'pan'],
# default_tools=['reset', 'save', 'hover']
# ))
# -
@@ -239,6 +246,8 @@ df_norm
output_scaler = next(filter(lambda r:r[0][0] in columns_target, scaler.features))[-1]
output_scaler
# ### Split
# +
# split data, with the test in the future
@@ -247,24 +256,36 @@ d1 = df_norm.index.max()
split_time = d0+(d1-d0)*0.8
split_time = split_time.round('1D')
print(split_time)
df_train = df_norm.groupby('LCLid').apply(lambda d:d.loc[:split_time]).reset_index(level=0, drop=True)
df_test = df_norm.groupby('LCLid').apply(lambda d:d.loc[split_time:]).reset_index(level=0, drop=True)
df_train = df_norm.groupby('block').apply(lambda d:d.loc[:split_time]).reset_index(level=0, drop=True)
df_test = df_norm.groupby('block').apply(lambda d:d.loc[split_time:]).reset_index(level=0, drop=True)
# df_test
# +
# # Show split
# df_train['energy(kWh/hh)'].plot(label='train')
# df_test['energy(kWh/hh)'].plot(label='test')
# plt.ylabel('energy(kWh/hh)')
# plt.legend()
# -
# Show split
df_train['energy(kWh/hh)'].plot(label='train')
df_test['energy(kWh/hh)'].plot(label='test')
plt.ylabel('energy(kWh/hh)')
plt.legend()
# # Show split
scatter = dynspread(datashade(hv.Curve(df_train, kdims=['Date'], vdims=['energy(kWh/hh)', 'block']).groupby('block'), cmap='blue'))
scatter *= dynspread(datashade(hv.Curve(df_test, kdims=['Date'], vdims=['energy(kWh/hh)', 'block']).groupby('block'), cmap='red'))
scatter = scatter.opts(plot=dict(width=800))
scatter
# ### Dataset
# +
# ### Dataset
# These are the columns that we wont know in the future
# We need to blank them out in x_future
columns_blank=['visibility',
'windBearing', 'temperature', 'dewPoint', 'pressure',
'apparentTemperature', 'windSpeed', 'humidity']
df_trains = [d.resample(freq).first().ffill().dropna() for _,d in df_train.groupby('LCLid')]
df_tests = [d.resample(freq).first().ffill().dropna() for _,d in df_test.groupby('LCLid')]
df_trains = [d.resample(freq).first().ffill().dropna() for _,d in df_train.groupby('block')]
df_tests = [d.resample(freq).first().ffill().dropna() for _,d in df_test.groupby('block')]
ds_train = Seq2SeqDataSets(df_trains,
window_past=window_past,
window_future=window_future,
@@ -275,6 +296,7 @@ ds_test = Seq2SeqDataSets(df_tests,
columns_blank=columns_blank)
print(ds_train)
print(ds_test)
# -
# we can treat it like an array
ds_train[0]
len(ds_train)
@@ -297,14 +319,9 @@ x_past.tail()
# Notice we've hidden some future columns to prevent cheating
x_future.tail()
# ## Plot helpers
from seq2seq_time.models.lstm_seq2seq import LSTMSeq2Seq
from seq2seq_time.models.lstm import LSTM
from seq2seq_time.models.baseline import BaselineLast
from seq2seq_time.models.transformer import Transformer
from seq2seq_time.models.transformer_seq2seq import TransformerSeq2Seq
# ## Plots
# +
def plot_prediction(ds_preds, i):
"""Plot a prediction into the future, at a single point in time."""
@@ -352,8 +369,9 @@ def plot_prediction(ds_preds, i):
def plot_performance(ds_preds, full=False):
"""Multiple plots using xr_preds"""
print(f'mean_NLL {ds_preds.nll.mean().item():2.2f}')
plot_prediction(ds_preds, 24)
plot_prediction(ds_preds, 480)
# plot_prediction(ds_preds, 480)
ds_preds.mean('t_source').plot.scatter('t_ahead_hours', 'nll') # Mean over all predictions
n = len(ds_preds.t_source)
@@ -375,7 +393,7 @@ def plot_performance(ds_preds, full=False):
ds_preds.plot.scatter('y_true', 'y_pred', s=.01)
plt.show()
print(f'mean_NLL {ds_preds.nll.mean().item():2.2f}')
# -
@@ -406,14 +424,18 @@ class PL_MODEL(pl.LightningModule):
def forward(self, x_past, y_past, x_future, y_future=None):
"""Eval/Predict"""
y_dist = self._model(x_past, y_past, x_future)
return y_dist
y_dist, extra = self._model(x_past, y_past, x_future, y_future)
return y_dist, extra
def training_step(self, batch, batch_idx, phase='train'):
x_past, y_past, x_future, y_future = batch
y_dist = self.forward(*batch)
y_dist, extra = self.forward(*batch)
loss = -y_dist.log_prob(y_future).mean()
self.log_dict({f'loss/{phase}':loss})
if ('loss' in extra) and (phase=='train'):
# some models have a special loss
loss = extra['loss']
self.log_dict({f'model_loss/{phase}':loss})
return loss
def validation_step(self, batch, batch_idx):
@@ -435,7 +457,6 @@ class PL_MODEL(pl.LightningModule):
# # Run
from torch.utils.data import DataLoader
from pytorch_lightning.loggers import CSVLogger
from pl_bolts.callbacks import PrintTableMetricsCallback
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
@@ -453,6 +474,50 @@ dl_train = DataLoader(ds_train,
dl_test = DataLoader(ds_test, batch_size=batch_size, num_workers=num_workers)
# -
from seq2seq_time.models.lstm_seq2seq import LSTMSeq2Seq
from seq2seq_time.models.lstm_seq import LSTMSeq
from seq2seq_time.models.lstm import LSTM
from seq2seq_time.models.baseline import BaselineLast
from seq2seq_time.models.transformer import Transformer
from seq2seq_time.models.transformer_seq2seq import TransformerSeq2Seq
from seq2seq_time.models.transformer_seq import TransformerSeq
from seq2seq_time.models.anp import RANP
# ## Plots
# +
models = [
RANP(input_size,
output_size),
LSTM(input_size,
output_size,
hidden_size=80,
lstm_layers=3,
lstm_dropout=0.3),
LSTMSeq2Seq(input_size,
output_size,
hidden_size=64,
lstm_layers=2,
lstm_dropout=0.25),
TransformerSeq2Seq(input_size,
output_size,
hidden_size=64,
nhead=8,
nlayers=4,
attention_dropout=0.3),
Transformer(input_size,
output_size,
attention_dropout=0.3,
nhead=8,
nlayers=6,
hidden_size=64),
TransformerSeq(input_size,
output_size),
LSTMSeq(input_size,
output_size),
]
# -
# Baseline model
pt_model = BaselineLast()
model = PL_MODEL(pt_model).to(device)
@@ -467,49 +532,6 @@ print(plot_hist(trainer))
ds_preds = predict(model.to(device), ds_test.datasets[0], batch_size, device=device, scaler=output_scaler)
print(f'baseline nll: {ds_preds.nll.mean().item():2.2g}')
models = [
# BaselineLast(),
LSTM(input_size,
output_size,
hidden_size=80,
lstm_layers=3,
lstm_dropout=0.3),
Transformer(input_size,
output_size,
attention_dropout=0.3,
nhead=8,
nlayers=6,
hidden_size=64),
LSTMSeq2Seq(input_size,
output_size,
hidden_size=64,
lstm_layers=2,
lstm_dropout=0.25),
TransformerSeq2Seq(input_size,
output_size,
hidden_size=64,
nhead=8,
nlayers=4,
attention_dropout=0.3),
# Transformer(input_size,
# output_size,
# attention_dropout=0.2,
# nhead=8,
# nlayers=6,
# hidden_size=128),
# LSTM(input_size,
# output_size,
# hidden_size=128,
# lstm_layers=3,
# lstm_dropout=0.3),
]
for pt_model in models:
name = type(pt_model).__name__
print(name)
@@ -518,36 +540,66 @@ for pt_model in models:
patience = 2
model = PL_MODEL(pt_model, patience=patience, lr=3e-4).to(device)
# Trainer
# Trainer
trainer = pl.Trainer(gpus=1,
min_epochs=1,
min_epochs=2,
max_epochs=10,
amp_level='O1',
precision=16,
gradient_clip_val=0.5,
gradient_clip_val=1,
logger=CSVLogger("logs",
name=type(pt_model).__name__),
callbacks=[
EarlyStopping(monitor='loss/val', patience=patience*2),
PrintTableMetricsCallback()
# PrintTableMetricsCallback2()
],
)
# Train
trainer.fit(model, dl_train, dl_test)
# Performance
print(plot_hist(trainer))
ds_preds = predict(model.to(device),
ds_test.datasets[0],
batch_size,
device=device,
scaler=output_scaler)
print(name)
print(f'mean_NLL {ds_preds.nll.mean().item():2.2f}')
# Performance
print(plot_hist(trainer))
plot_performance(ds_preds)
# %debug
# Predict on the first test dataset with the current `model`.
# (A stray `q` after the first argument made this call a syntax error.)
ds_preds = predict(model.to(device),
                   ds_test.datasets[0],
                   batch_size,
                   device=device,
                   scaler=output_scaler)
# +
# ds_predss = predict_multi(model.to(device),
# ds_test.datasets,
# batch_size,
# device=device,
# scaler=output_scaler)
# -
ds_test.datasets[0].df.index.value_counts()
# TODO why dup?
ds_preds.sel(t_source='2013-11-11 00:30:00')
# TODO why duplicates?
d = ds_preds.isel(t_ahead=0)
d.t_source.to_series().sort_index()#.value_counts()
# np.unique
# d
# # holoviews pred
@@ -565,7 +617,7 @@ def plot_prediction_now(t_source):
d = ds_preds.sel(t_source=t_source)
# Sometimes there are duplicate time, take the first
# Sometimes there are duplicate times, take the first
if len(d.t_source.shape) and d.t_source.shape[0] > 0:
d = d.isel(t_source=0)
if len(d.t_source.shape) and d.t_source.shape[0] == 0:
@@ -579,7 +631,7 @@ def plot_prediction_now(t_source):
p = hv.Scatter({
'x': x,
'y': yt
}, label='true').opts(color='black', framewise=True)
}, label='true').opts(color='black')
# Get arrays
xf = d.t_target.values
@@ -594,7 +646,7 @@ def plot_prediction_now(t_source):
label='2*std').opts(alpha=0.5, line_width=0)
# plot now line
p *= hv.VLine(now, label='now').opts(color='red')
p *= hv.VLine(now, label='now').opts(color='red', framewise=True)
return p.opts(title=f'Prediction at {now}. NLL={d.nll.mean().item():2.2f}')
@@ -611,6 +663,7 @@ def plot_predictions_vs_time(it_ahead):
"""Plot predictions vs time with holoviews"""
d = ds_preds.isel(t_ahead=it_ahead).groupby('t_source').first()
print(d)
p = hv.Scatter({
'x': d.t_source,
@@ -639,19 +692,22 @@ dmap_preds = (hv.DynamicMap(plot_predictions_vs_time, kdims=['it_ahead'])
height=300,
))
dmap_preds
# plot_prediction2(10).opts(width=800)
# TODO fixme
# -
d = ds_preds.mean('t_source')['nll'].groupby('t_ahead_hours').mean()
nll_vs_tahead = hv.Curve((d.t_ahead_hours, d)).redim(x='hours ahead', y='nll').opts(width=800)
nll_vs_tahead
d = ds_preds.mean('t_ahead')['nll'].groupby('t_source').mean()
nll_vs_time = hv.Curve(d).opts(width=800)
nll_vs_time
# +
# d = ds_preds.mean('t_ahead')['nll'].groupby('t_source').mean()
# nll_vs_time = hv.Curve(d).opts(width=800)
# nll_vs_time
true_vs_pred = hv.Scatter((ds_preds.y_true, ds_preds.y_pred))
dynspread(datashade(true_vs_pred))
# +
# true_vs_pred = hv.Scatter((ds_preds.y_true, ds_preds.y_pred))
# dynspread(datashade(true_vs_pred))
# -
# # Summarize experiments
@@ -659,18 +715,18 @@ dynspread(datashade(true_vs_pred))
# +
# Run learning rate finder
lr_finder = trainer.tuner.lr_find(model)
# # Run learning rate finder
# lr_finder = trainer.tuner.lr_find(model)
# Results can be found in
lr_finder.results
# # Results can be found in
# lr_finder.results
# Plot with
fig = lr_finder.plot(suggest=True)
fig.show()
# # Plot with
# fig = lr_finder.plot(suggest=True)
# fig.show()
# Pick point based on plot, or get suggestion
new_lr = lr_finder.suggestion()
# # Pick point based on plot, or get suggestion
# new_lr = lr_finder.suggestion()
# -
+473
View File
@@ -0,0 +1,473 @@
"""Recurrent Attentive Neural Process."""
import torch
from torch import nn
import torch.nn.functional as F
import math
class LSTMBlock(nn.Module):
    """Wrapper around ``nn.LSTM`` that returns only the output sequence.

    Args:
        in_channels: Size of the last axis of the input.
        out_channels: Hidden size of the LSTM (size of the last output axis).
        dropout: Dropout rate between stacked LSTM layers. Only forwarded to
            ``nn.LSTM`` when ``num_layers > 1``; with a single layer there is
            nothing to apply it to.
        batchnorm: Accepted but ignored by this block.
        bias: Whether the LSTM uses bias weights.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        dropout=0,
        batchnorm=False,
        bias=False,
        num_layers=1,
    ):
        super().__init__()
        # nn.LSTM applies dropout only between stacked layers and emits a
        # UserWarning when a non-zero rate is combined with a single layer.
        # Passing 0 in that case is functionally identical and stays quiet.
        inter_layer_dropout = dropout if num_layers > 1 else 0
        self._lstm = nn.LSTM(
            input_size=in_channels,
            hidden_size=out_channels,
            num_layers=num_layers,
            dropout=inter_layer_dropout,
            batch_first=True,
            bias=bias,
        )

    def forward(self, x):
        """Run the LSTM over ``x`` (batch first) and drop the final states."""
        return self._lstm(x)[0]
class NPBlockRelu2d(nn.Module):
    """Block for Neural Processes: Linear -> ReLU -> optional BatchNorm2d -> Dropout2d.

    Args:
        in_channels: Size of the last axis of the input.
        out_channels: Size of the last axis of the output.
        dropout: Probability passed to ``nn.Dropout2d``.
        batchnorm: If true, apply ``nn.BatchNorm2d`` over the output channels.
        bias: Whether the linear layer has a bias term.
    """

    def __init__(
        self, in_channels, out_channels, dropout=0, batchnorm=False, bias=False
    ):
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels, bias=bias)
        self.act = nn.ReLU()
        self.dropout = nn.Dropout2d(dropout)
        # None (rather than False) marks "no normalisation", so the check in
        # forward() does not rely on the truthiness of an nn.Module.
        self.norm = nn.BatchNorm2d(out_channels) if batchnorm else None

    def forward(self, x):
        """Apply the block to ``x`` of shape (Batch, Sequence, Channels)."""
        # The linear layer operates on the Channels axis.
        x = self.act(self.linear(x))

        # Reshape to (Batch, Channels, Sequence, 1) so that BatchNorm2d and
        # Dropout2d act on the channels.
        x = x.permute(0, 2, 1)[:, :, :, None]

        if self.norm is not None:
            x = self.norm(x)

        x = self.dropout(x)
        # Back to (Batch, Sequence, Channels).
        return x[:, :, :, 0].permute(0, 2, 1)
class BatchMLP(nn.Module):
    """MLP applied to the final axis of a 3D tensor.

    The network is one ``NPBlockRelu2d`` mapping ``input_size`` to
    ``output_size``, then ``num_layers - 2`` further ``NPBlockRelu2d`` blocks
    of width ``output_size``, then a plain linear layer.

    Args:
        input_size: Size of the last axis of the input, ``d_in``.
        output_size: Width of every layer and of the output, ``d_out``.
        num_layers: Total layer count; values below 2 add no hidden blocks.
        dropout: Dropout rate handed to each ``NPBlockRelu2d``.
        batchnorm: Whether each ``NPBlockRelu2d`` uses batch normalisation.

    Returns (from ``forward``):
        Tensor of shape [B, n, d_out] for an input of shape [B, n, d_in].
    """

    def __init__(
        self, input_size, output_size, num_layers=2, dropout=0, batchnorm=False
    ):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.num_layers = num_layers

        self.initial = NPBlockRelu2d(
            input_size, output_size, dropout=dropout, batchnorm=batchnorm
        )

        hidden_blocks = []
        for _ in range(num_layers - 2):
            hidden_blocks.append(
                NPBlockRelu2d(
                    output_size, output_size, dropout=dropout, batchnorm=batchnorm
                )
            )
        self.encoder = nn.Sequential(*hidden_blocks)

        self.final = nn.Linear(output_size, output_size)

    def forward(self, x):
        """Pass ``x`` through the initial block, hidden blocks and final linear."""
        return self.final(self.encoder(self.initial(x)))
class Attention(nn.Module):
    """Multi-head attention with optional MLP embedding of keys and queries.

    Args:
        hidden_dim: Embedding size used by the attention layer.
        attention_type: Accepted but not read by this class.
        attention_layers: Number of layers in the key/query MLPs.
        n_heads: Number of attention heads.
        x_dim: Input size of the key/query MLPs (used when ``rep == "mlp"``).
        rep: ``"mlp"`` embeds keys and queries with a ``BatchMLP`` first;
            any other value passes them through unchanged.
        dropout: Dropout for the MLPs and the attention layer.
        batchnorm: Whether the MLPs use batch normalisation.
    """

    def __init__(
        self,
        hidden_dim,
        attention_type,
        attention_layers=2,
        n_heads=8,
        x_dim=1,
        rep="mlp",
        dropout=0,
        batchnorm=False,
    ):
        super().__init__()
        self._rep = rep

        if rep == "mlp":
            mlp_kwargs = dict(dropout=dropout, batchnorm=batchnorm)
            self.batch_mlp_k = BatchMLP(
                x_dim, hidden_dim, attention_layers, **mlp_kwargs
            )
            self.batch_mlp_q = BatchMLP(
                x_dim, hidden_dim, attention_layers, **mlp_kwargs
            )

        self._W = nn.MultiheadAttention(
            hidden_dim, n_heads, bias=False, dropout=dropout
        )
        self._attention_func = self._pytorch_multihead_attention

    def forward(self, k, v, q):
        """Attend from queries ``q`` over keys ``k`` and values ``v`` (batch first)."""
        if self._rep == "mlp":
            k = self.batch_mlp_k(k)
            q = self.batch_mlp_q(q)
        return self._attention_func(k, v, q)

    def _pytorch_multihead_attention(self, k, v, q):
        # PyTorch multi-head attention takes its inputs in a different order
        # (query, key, value) and sequence-first, so swap the first two axes
        # on the way in and on the way out.
        query, key, value = (t.transpose(0, 1) for t in (q, k, v))
        attended = self._W(query, key, value)[0]
        return attended.transpose(0, 1)
class LatentEncoder(nn.Module):
    """Encode a set of (x, y) points into a Gaussian over a latent vector.

    The points are encoded with an MLP (or LSTM when ``use_lstm``), optionally
    passed through self-attention, averaged over the sequence axis, and mapped
    to the mean and standard deviation of a Normal distribution.

    Args:
        input_dim: Size of the concatenated [x, y] feature axis.
        hidden_dim: Width of the encoder and of the penultimate layer.
        latent_dim: Size of the latent vector.
        self_attention_type: Forwarded to ``Attention``.
        n_encoder_layers: Number of layers in the encoder.
        min_std: Lower bound on the returned standard deviation.
        batchnorm: Whether the encoder uses batch normalisation.
        dropout: Dropout rate of the encoder.
        attention_dropout: Dropout rate of the self-attention.
        use_self_attn: Whether to apply self-attention before averaging.
        attention_layers: Forwarded to ``Attention``.
        use_lstm: Use an ``LSTMBlock`` encoder instead of a ``BatchMLP``.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim=32,
        latent_dim=32,
        self_attention_type="dot",
        n_encoder_layers=3,
        min_std=0.01,
        batchnorm=False,
        dropout=0,
        attention_dropout=0,
        use_self_attn=True,
        attention_layers=2,
        use_lstm=False,
    ):
        super().__init__()
        # Both encoder classes accept the same constructor arguments here.
        encoder_cls = LSTMBlock if use_lstm else BatchMLP
        self._encoder = encoder_cls(
            input_dim,
            hidden_dim,
            batchnorm=batchnorm,
            dropout=dropout,
            num_layers=n_encoder_layers,
        )
        if use_self_attn:
            self._self_attention = Attention(
                hidden_dim,
                self_attention_type,
                attention_layers,
                rep="identity",
                dropout=attention_dropout,
            )
        self._penultimate_layer = nn.Linear(hidden_dim, hidden_dim)
        self._mean = nn.Linear(hidden_dim, latent_dim)
        self._log_var = nn.Linear(hidden_dim, latent_dim)
        self._min_std = min_std
        self._use_lstm = use_lstm
        self._use_self_attn = use_self_attn

    def forward(self, x, y):
        """Return ``(dist, log_var)`` for the points ``(x, y)``.

        ``dist`` is a ``torch.distributions.Normal`` over the latent and
        ``log_var`` is the raw output of the variance head.
        """
        # Encode every point from its concatenated features.
        encoded = self._encoder(torch.cat([x, y], dim=-1))

        # Aggregate: optionally self-attend, then average over all points.
        if self._use_self_attn:
            encoded = self._self_attention(encoded, encoded, encoded)
        summary = encoded.mean(dim=1)

        # Further layers map the summary to the Gaussian parameters.
        summary = torch.relu(self._penultimate_layer(summary))
        mean = self._mean(summary)
        log_var = self._log_var(summary)

        # A sigmoid keeps sigma between min_std and 1.
        squashed = torch.sigmoid(log_var * 0.5)
        sigma = self._min_std + (1 - self._min_std) * squashed
        return torch.distributions.Normal(mean, sigma), log_var
class DeterministicEncoder(nn.Module):
    """Deterministic path: encode the past and cross-attend from the future.

    Past (x, y) points are encoded with an MLP (or LSTM when ``use_lstm``),
    optionally refined with self-attention, and then used as the values of a
    cross-attention whose keys are ``past_x`` and queries are ``future_x``.

    Args:
        input_dim: Size of the concatenated [x, y] feature axis.
        x_dim: Size of the x feature axis (input of the cross-attention MLPs).
        hidden_dim: Width of the encoder and the attention layers.
        n_d_encoder_layers: Number of layers in the encoder.
        self_attention_type: Forwarded to the self-attention ``Attention``.
        cross_attention_type: Forwarded to the cross-attention ``Attention``.
        use_self_attn: Whether to apply self-attention to the encoded past.
        attention_layers: Forwarded to both ``Attention`` modules.
        batchnorm: Whether the encoder uses batch normalisation.
        dropout: Dropout rate of the encoder.
        attention_dropout: Dropout rate of the self-attention.
        use_lstm: Use an ``LSTMBlock`` encoder instead of a ``BatchMLP``.
    """

    def __init__(
        self,
        input_dim,
        x_dim,
        hidden_dim=32,
        n_d_encoder_layers=3,
        self_attention_type="dot",
        cross_attention_type="dot",
        use_self_attn=True,
        attention_layers=2,
        batchnorm=False,
        dropout=0,
        attention_dropout=0,
        use_lstm=False,
    ):
        super().__init__()
        self._use_self_attn = use_self_attn

        # Both encoder classes accept the same constructor arguments here.
        encoder_cls = LSTMBlock if use_lstm else BatchMLP
        self._d_encoder = encoder_cls(
            input_dim,
            hidden_dim,
            batchnorm=batchnorm,
            dropout=dropout,
            num_layers=n_d_encoder_layers,
        )
        if use_self_attn:
            self._self_attention = Attention(
                hidden_dim,
                self_attention_type,
                attention_layers,
                rep="identity",
                dropout=attention_dropout,
            )
        self._cross_attention = Attention(
            hidden_dim,
            cross_attention_type,
            x_dim=x_dim,
            attention_layers=attention_layers,
        )

    def forward(self, past_x, past_y, future_x):
        """Return one representation per future step."""
        # Encode each past point from its concatenated x and y features.
        encoded_past = self._d_encoder(torch.cat([past_x, past_y], dim=-1))

        if self._use_self_attn:
            encoded_past = self._self_attention(
                encoded_past, encoded_past, encoded_past
            )

        # Cross-attention aggregates the past for every future query.
        return self._cross_attention(past_x, encoded_past, future_x)
class Decoder(nn.Module):
    """Decode representations and future inputs into a Normal over y.

    Args:
        x_dim: Size of the future x feature axis.
        y_dim: Size of the predicted y feature axis.
        hidden_dim: Width of the embedded future x and of the deterministic
            representation.
        latent_dim: Size of the latent vector.
        n_decoder_layers: Number of layers in the decoder network.
        use_deterministic_path: Whether ``forward`` receives a deterministic
            representation ``r`` to concatenate.
        min_std: Lower bound on the predicted standard deviation.
        batchnorm: Whether the decoder uses batch normalisation.
        dropout: Dropout rate of the decoder.
        use_lstm: Use an ``LSTMBlock`` decoder instead of a ``BatchMLP``.
    """

    def __init__(
        self,
        x_dim,
        y_dim,
        hidden_dim=32,
        latent_dim=32,
        n_decoder_layers=3,
        use_deterministic_path=True,
        min_std=0.01,
        batchnorm=False,
        dropout=0,
        use_lstm=False,
    ):
        super().__init__()
        self._future_transform = nn.Linear(x_dim, hidden_dim)

        # The decoder input is [r, z, x] with the deterministic path, or
        # [z, x] without it.
        n_hidden_parts = 2 if use_deterministic_path else 1
        hidden_dim_2 = n_hidden_parts * hidden_dim + latent_dim

        decoder_cls = LSTMBlock if use_lstm else BatchMLP
        self._decoder = decoder_cls(
            hidden_dim_2,
            hidden_dim_2,
            batchnorm=batchnorm,
            dropout=dropout,
            num_layers=n_decoder_layers,
        )
        self._mean = nn.Linear(hidden_dim_2, y_dim)
        self._std = nn.Linear(hidden_dim_2, y_dim)
        self._use_deterministic_path = use_deterministic_path
        self._min_std = min_std

    def forward(self, r, z, future_x):
        """Return ``(dist, log_sigma)`` for the future steps.

        ``r`` is only read when the deterministic path is enabled.
        """
        # Embed future_x and concatenate it with the representation(s).
        x = self._future_transform(future_x)
        if self._use_deterministic_path:
            pieces = [r, z, x]
        else:
            pieces = [z, x]
        hidden = self._decoder(torch.cat(pieces, dim=-1))

        # Get the mean and the raw scale.
        mean = self._mean(hidden)
        log_sigma = self._std(hidden)

        # Softplus plus an offset keeps sigma at or above min_std.
        sigma = self._min_std + (1 - self._min_std) * F.softplus(log_sigma)
        return torch.distributions.Normal(mean, sigma), log_sigma
class RANP(nn.Module):
    """Recurrent Attentive Neural Process for Sequential Data.

    Combines a latent encoder, an optional deterministic (attention) encoder
    and a decoder. When ``use_rnn`` is set, the x inputs are first replaced by
    the outputs of a shared LSTM.
    """

    def __init__(
        self,
        x_dim,  # features in input
        y_dim,  # number of features in output
        hidden_dim=32,  # size of hidden space
        latent_dim=32,  # size of latent space
        n_latent_encoder_layers=2,
        n_det_encoder_layers=2,  # number of deterministic encoder layers
        n_decoder_layers=2,
        use_deterministic_path=True,
        min_std=0.01,  # To avoid collapse use a minimum standard deviation, should be much smaller than variation in labels
        dropout=0,
        use_self_attn=True,
        attention_dropout=0,
        batchnorm=False,
        attention_layers=2,
        use_rnn=True,  # use RNN/LSTM
        use_lstm_le=False,  # use another LSTM in latent encoder instead of MLP
        use_lstm_de=False,  # use another LSTM in deterministic encoder instead of MLP
        use_lstm_d=False,  # use another lstm in decoder instead of MLP
        **kwargs,  # extra keyword arguments are accepted and ignored
    ):
        super().__init__()

        self._use_rnn = use_rnn

        if self._use_rnn:
            self._lstm = nn.LSTM(
                input_size=x_dim,
                hidden_size=hidden_dim,
                num_layers=attention_layers,
                dropout=dropout,
                batch_first=True,
            )
            # Downstream modules see the LSTM output instead of the raw x.
            x_dim = hidden_dim

        self._latent_encoder = LatentEncoder(
            x_dim + y_dim,
            hidden_dim=hidden_dim,
            latent_dim=latent_dim,
            n_encoder_layers=n_latent_encoder_layers,
            attention_layers=attention_layers,
            dropout=dropout,
            use_self_attn=use_self_attn,
            attention_dropout=attention_dropout,
            batchnorm=batchnorm,
            min_std=min_std,
            use_lstm=use_lstm_le,
        )

        self._deterministic_encoder = DeterministicEncoder(
            input_dim=x_dim + y_dim,
            x_dim=x_dim,
            hidden_dim=hidden_dim,
            n_d_encoder_layers=n_det_encoder_layers,
            attention_layers=attention_layers,
            use_self_attn=use_self_attn,
            dropout=dropout,
            batchnorm=batchnorm,
            attention_dropout=attention_dropout,
            use_lstm=use_lstm_de,
        )

        self._decoder = Decoder(
            x_dim,
            y_dim,
            hidden_dim=hidden_dim,
            latent_dim=latent_dim,
            dropout=dropout,
            batchnorm=batchnorm,
            min_std=min_std,
            n_decoder_layers=n_decoder_layers,
            use_deterministic_path=use_deterministic_path,
            use_lstm=use_lstm_d,
        )
        self._use_deterministic_path = use_deterministic_path

    def forward(self, past_x, past_y, future_x, future_y=None):
        """Predict a distribution over the future targets.

        Args:
            past_x: Past inputs, batch first.
            past_y: Past targets, batch first.
            future_x: Future inputs, batch first.
            future_y: Future targets. When given, the latent is taken from the
                posterior and a loss is computed.

        Returns:
            ``(dist, extra)`` where ``dist`` is a ``torch.distributions.Normal``
            over the future targets and ``extra`` is ``{'loss': loss}``.
            ``loss`` is the mean of (KL(posterior || prior) - log-likelihood),
            or ``None`` when ``future_y`` is not given.
        """
        if self._use_rnn:
            # see https://arxiv.org/abs/1910.09323 where x is substituted with h = RNN(x)
            # x need to be provided as [B, T, H]
            future_x, _ = self._lstm(future_x)
            past_x, _ = self._lstm(past_x)

        dist_prior, _ = self._latent_encoder(past_x, past_y)

        # NOTE(review): the posterior is used whenever future_y is passed,
        # including in eval mode; callers that want predictions that do not
        # see the targets must omit future_y. Confirm this is intended.
        if future_y is not None:
            dist_post, _ = self._latent_encoder(future_x, future_y)
            z = dist_post.loc
        else:
            z = dist_prior.loc

        num_targets = future_x.size(1)
        z = z.unsqueeze(1).repeat(1, num_targets, 1)  # [B, T_target, H]

        if self._use_deterministic_path:
            r = self._deterministic_encoder(
                past_x, past_y, future_x
            )  # [B, T_target, H]
        else:
            r = None

        dist, _ = self._decoder(r, z, future_x)

        loss = None
        if future_y is not None:
            log_p = dist.log_prob(future_y).mean(-1)
            kl_loss = torch.distributions.kl_divergence(dist_post, dist_prior).mean(
                -1
            )  # [B, R].mean(-1)
            # Broadcast the per-sample KL over the target steps.
            kl_loss = kl_loss[:, None].expand(log_p.shape)
            loss = (kl_loss - log_p).mean()

        return dist, {'loss': loss}
+1 -1
View File
@@ -12,4 +12,4 @@ class BaselineLast(nn.Module):
B, S, F = future_x.shape
mean = past_y[:, -1:].repeat(1, S, 1)
std = (self.std * 1.0).repeat(1, S, 1)
return torch.distributions.Normal(mean, std)
return torch.distributions.Normal(mean, std), {}
+2 -2
View File
@@ -3,7 +3,7 @@ from torch import nn
from torch.nn import functional as F
class LSTM(nn.Module):
def __init__(self, input_size, output_size, hidden_size=32, lstm_layers=2, lstm_dropout=0, _min_std = 0.05, nan_value=0):
def __init__(self, input_size, output_size, hidden_size=64, lstm_layers=3, lstm_dropout=0, _min_std = 0.05, nan_value=0):
super().__init__()
self._min_std = _min_std
self.nan_value = nan_value
@@ -36,4 +36,4 @@ class LSTM(nn.Module):
log_sigma = self.std(outputs)
sigma = self._min_std + (1 - self._min_std) * F.softplus(log_sigma)
y_dist = torch.distributions.Normal(mean, sigma)
return y_dist
return y_dist, {}
+34
View File
@@ -0,0 +1,34 @@
import torch
from torch import nn
from torch.nn import functional as F
class LSTMSeq(nn.Module):
    """LSTM over the past sequence whose last outputs are the forecast.

    The LSTM reads the concatenated past inputs and targets. The outputs at
    the last ``future_x.shape[1]`` past positions are mapped to the mean and
    standard deviation of a Normal distribution.

    Args:
        input_size: Number of features in ``past_x``.
        output_size: Number of features in ``past_y`` and in the prediction.
        hidden_size: Hidden size of the LSTM.
        lstm_layers: Number of stacked LSTM layers.
        lstm_dropout: Dropout between stacked LSTM layers.
        _min_std: Lower bound on the predicted standard deviation.
        nan_value: Stored on the module; not read by this class.
    """

    def __init__(self, input_size, output_size, hidden_size=32, lstm_layers=2, lstm_dropout=0, _min_std=0.05, nan_value=0):
        super().__init__()
        self._min_std = _min_std
        self.nan_value = nan_value

        self.lstm = nn.LSTM(
            input_size=input_size + output_size,
            hidden_size=hidden_size,
            batch_first=True,
            num_layers=lstm_layers,
            dropout=lstm_dropout,
        )
        self.mean = nn.Linear(hidden_size, output_size)
        self.std = nn.Linear(hidden_size, output_size)

    def forward(self, past_x, past_y, future_x, future_y=None):
        """Return ``(y_dist, {})``; only the length of ``future_x`` is used.

        ``future_y`` is accepted for interface parity and ignored.
        """
        x = torch.cat([past_x, past_y], -1).detach()
        steps = future_x.shape[1]
        outputs, _ = self.lstm(x)

        # Keep the last `steps` positions. An explicit start index is used
        # because `outputs[:, -steps:]` returns the whole sequence when
        # steps == 0. If steps exceeds the past length, all positions are kept.
        start = max(outputs.shape[1] - steps, 0)
        outputs = outputs[:, start:, :]

        # outputs: [B, T, num_direction * H]
        mean = self.mean(outputs)
        log_sigma = self.std(outputs)
        sigma = self._min_std + (1 - self._min_std) * F.softplus(log_sigma)
        y_dist = torch.distributions.Normal(mean, sigma)
        return y_dist, {}
View File
+1 -1
View File
@@ -55,5 +55,5 @@ class Transformer(nn.Module):
log_sigma = self.std(outputs)[:, steps:, :]
sigma = self._min_std + (1 - self._min_std) * F.softplus(log_sigma)
return torch.distributions.Normal(mean, sigma)
return torch.distributions.Normal(mean, sigma), {}
+54
View File
@@ -0,0 +1,54 @@
import torch
from torch import nn
from torch.nn import functional as F
class TransformerSeq(nn.Module):
    """
    A single transformer, masking nan or 0

    The encoder reads the concatenated past inputs and targets. The outputs at
    the last ``future_x.shape[1]`` past positions are mapped to the mean and
    standard deviation of a Normal distribution.

    Args:
        x_dim: Number of features in ``past_x``.
        y_dim: Number of features in ``past_y`` and in the prediction.
        attention_dropout: Dropout of the encoder layers.
        nhead: Number of attention heads.
        nlayers: Number of encoder layers.
        hidden_size: Model width; the feed-forward width is four times this.
        nan_value: Input value treated as missing, in addition to non-finite
            values.
        min_std: Lower bound on the predicted standard deviation.
    """

    def __init__(self, x_dim, y_dim, attention_dropout=0, nhead=8, nlayers=2, hidden_size=16, nan_value=0, min_std=0.01):
        super().__init__()
        self._min_std = min_std
        self.nan_value = nan_value
        enc_x_dim = x_dim + y_dim

        self.enc_emb = nn.Linear(enc_x_dim, hidden_size)
        encoder_norm = nn.LayerNorm(hidden_size)
        layer_enc = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            dim_feedforward=hidden_size * 4,
            dropout=attention_dropout,
            nhead=nhead,
            # activation
        )
        self.encoder = nn.TransformerEncoder(
            layer_enc, num_layers=nlayers, norm=encoder_norm
        )
        self.mean = nn.Linear(hidden_size, y_dim)
        self.std = nn.Linear(hidden_size, y_dim)

    def forward(self, past_x, past_y, future_x, future_y=None):
        """Return ``(y_dist, {})``; only the length of ``future_x`` is used.

        ``future_y`` is accepted for interface parity and ignored.
        """
        x = torch.cat([past_x, past_y], -1).detach()

        # Masks: an entry is valid when it is finite and not `nan_value`.
        # Invalid entries are zeroed; a time step with no valid entry at all
        # is treated as padding by the encoder.
        x_mask = torch.isfinite(x) & (x != self.nan_value)
        x = x.masked_fill(~x_mask, 0)
        x_key_padding_mask = ~x_mask.any(-1)

        # The encoder is sequence-first, so swap batch and sequence axes.
        x = self.enc_emb(x).permute(1, 0, 2)
        outputs = self.encoder(x, src_key_padding_mask=x_key_padding_mask).permute(
            1, 0, 2
        )

        # Keep the last `steps` positions. An explicit start index is used
        # because `[:, -steps:]` returns the whole sequence when steps == 0.
        # If steps exceeds the past length, all positions are kept.
        steps = future_x.shape[1]
        start = max(outputs.shape[1] - steps, 0)
        mean = self.mean(outputs)[:, start:, :]
        log_sigma = self.std(outputs)[:, start:, :]

        sigma = self._min_std + (1 - self._min_std) * F.softplus(log_sigma)
        return torch.distributions.Normal(mean, sigma), {}
+1 -1
View File
@@ -76,5 +76,5 @@ class TransformerSeq2Seq(nn.Module):
mean = self.mean(outputs)
log_sigma = self.std(outputs)
sigma = self._min_std + (1 - self._min_std) * F.softplus(log_sigma)
return torch.distributions.Normal(mean, sigma)
return torch.distributions.Normal(mean, sigma), {}
+11 -2
View File
@@ -19,11 +19,11 @@ def predict(model, ds_test, batch_size, device='cpu', scaler=None):
load_test = torch.utils.data.dataloader.DataLoader(ds_test, batch_size=batch_size)
freq = ds_test.df.index.freq
xrs = []
for i, batch in enumerate(tqdm(load_test, desc='predict')):
for i, batch in enumerate(tqdm(load_test, desc='predict', leave=False)):
model.eval()
with torch.no_grad():
x_past, y_past, x_future, y_future = [d.to(device) for d in batch]
y_dist = model(x_past, y_past, x_future, y_future)
y_dist, extra = model(x_past, y_past, x_future)
nll = -y_dist.log_prob(y_future)
# Convert to numpy
@@ -70,3 +70,12 @@ def predict(model, ds_test, batch_size, device='cpu', scaler=None):
# Some plots don't like timedeltas, so lets make a coordinate for time ahead in hours
ds_preds = ds_preds.assign_coords(t_ahead_hours=(ds_preds.t_ahead*1.0e-9/60/60).astype(float))
return ds_preds
def predict_multi(model, datasets, batch_size, device='cpu', scaler=None):
    """Predict over multiple datasets.

    Runs `predict` on each dataset and concatenates the results along a new
    `block` dimension.

    Args:
        model: Model to evaluate; moved to `device` before predicting.
        datasets: Iterable of datasets, each accepted by `predict`.
        batch_size: Batch size forwarded to `predict`.
        device: Device forwarded to `predict`.
        scaler: Output scaler forwarded to `predict`.

    Returns:
        The per-dataset predictions concatenated with `xr.concat`.
    """
    model = model.to(device)
    # Pass the `scaler` argument through. The previous code read a global
    # `output_scaler` instead, which ignored the parameter.
    ds_preds = [
        predict(model, d, batch_size, device=device, scaler=scaler)
        for d in tqdm(datasets)
    ]
    return xr.concat(ds_preds, dim='block')