nicer plots, more classes

This commit is contained in:
wassname
2020-10-24 20:21:07 +08:00
parent ddeba12bc7
commit fd6defbdc5
11 changed files with 9692 additions and 117 deletions
+1
View File
@@ -1,4 +1,5 @@
lightning_logs/
dataset_folder/
logs/
# Byte-compiled / optimized / DLL files
File diff suppressed because one or more lines are too long
@@ -0,0 +1,805 @@
# -*- coding: utf-8 -*-
# ---
# jupyter:
# jupytext:
# formats: ipynb,py:light
# text_representation:
# extension: .py
# format_name: light
# format_version: '1.5'
# jupytext_version: 1.6.0
# kernelspec:
# display_name: seq2seq-time
# language: python
# name: seq2seq-time
# ---
# # Sequence to Sequence Models for Timeseries Regression
#
#
# In this notebook we are going to tackle a harder problem:
# - predicting the future on a timeseries
# - using an LSTM
# - with rough uncertainty (uncalibrated)
# - outputting a sequence of predictions
#
# <img src="../reports/figures/Seq2Seq for regression.png" />
#
#
# https://medium.com/@boitemailjeanmid/smart-meters-in-london-part1-description-and-first-insights-jean-michel-d-db97af2de71b
#
# OPTIONAL: Load the "autoreload" extension so that code can change. But blacklist large modules
# %load_ext autoreload
# %autoreload 2
# %aimport -pandas
# %aimport -torch
# %aimport -numpy
# %aimport -matplotlib
# %aimport -dask
# %aimport -tqdm
# %matplotlib inline
# +
# Imports
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.autograd import Variable
import torch
import torch.utils.data
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (12.0, 3.0)
plt.style.use('ggplot')
from pathlib import Path
from tqdm.auto import tqdm
import pytorch_lightning as pl
# -
import warnings
warnings.simplefilter('once')
from seq2seq_time.data.dataset import Seq2SeqDataSet, Seq2SeqDataSets
from seq2seq_time.predict import predict, predict_multi
import logging, sys
# logging.basicConfig(stream=sys.stdout, level=logging.INFO)
# ## Parameters
# +
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f'using {device}')
# Column(s) the models are trained to predict.
columns_target=['energy(kWh/hh)']
# Window lengths are counted in rows; at freq='30T' there are 48 rows per day,
# so 48*2 rows covers two days of history and two days of forecast.
window_past = 48*2
window_future = 48*2
batch_size = 128
# Handed to the DataLoaders as `num_workers`.
num_workers = 5
# Resampling frequency used for the weather table and the per-block frames.
freq = '30T'
# Maximum number of rows kept from the loaded frame. Written as a float,
# so it is cast with int() at the point of use.
max_rows = 5e5
# -
# ## Load data
# +
def _load_weather(indir):
    """Read the hourly weather table and resample it onto the module-level `freq` grid."""
    df_weather = pd.read_csv(indir/'weather_hourly_darksky.csv', parse_dates=[3])
    use_cols = ['visibility', 'windBearing', 'temperature', 'time', 'dewPoint',
                'pressure', 'apparentTemperature', 'windSpeed',
                'humidity']
    df_weather = df_weather[use_cols].set_index('time')
    # Use first() rather than mean(): windBearing is a bearing, which cannot be averaged
    return df_weather.resample(freq).first().ffill()


def _load_holidays(indir):
    """Read the bank-holiday table and return its dates as a set of day-rounded timestamps."""
    df_hols = pd.read_csv(indir/'uk_bank_holidays.csv', parse_dates=[0])
    return set(df_hols['Bank holidays'].dt.round('D'))


def _iso_week(time):
    """Return the ISO week number of each timestamp as int64."""
    try:
        # Series.dt.week is deprecated; isocalendar() is its replacement.
        # Cast away the nullable UInt32 dtype so downstream scalers see plain integers.
        return time.dt.isocalendar().week.astype('int64')
    except AttributeError:
        # Older pandas without isocalendar()
        return time.dt.week


def get_smartmeter_df(indir=Path('../data/raw/smart-meters-in-london'), max_files=8):
    """
    Load the London smart-meter data, one frame per block file, and concatenate them.

    Data loading and cleaning is always messy, so understanding this code is optional.

    For each of the first `max_files` csv files in `indir/halfhourly_dataset`
    the readings are summed per timestamp, zero/NaN rows are dropped, calendar
    features are added, and the weather table (resampled to the module-level
    `freq`) and a bank-holiday flag are joined on.

    Args:
        indir: directory holding `halfhourly_dataset/`, `weather_hourly_darksky.csv`
            and `uk_bank_holidays.csv`.
        max_files: maximum number of block files to read (in sorted filename order).

    Returns:
        A DataFrame indexed by `Date` with a `block` column naming the source file.
    """
    # Load csv files
    csv_files = sorted((indir/'halfhourly_dataset').glob('*.csv'))[:max_files]

    # The weather and holiday tables do not depend on the block, so read them once
    df_weather = _load_weather(indir)
    holidays = _load_holidays(indir)

    dfs = []
    for f in csv_files:
        df = (pd.read_csv(f, parse_dates=[1], na_values=['Null'])
              .groupby('tstp')
              # Be explicit: only numeric columns are summed, string columns are left out
              .sum(numeric_only=True)
              .sort_index()
              )
        df['block'] = f.stem

        # Drop nan and 0's
        df = df[df['energy(kWh/hh)'] != 0]
        df = df.dropna()

        # Add time features
        time = df.index.to_series()
        df["month"] = time.dt.month
        df['day'] = time.dt.day
        df['week'] = _iso_week(time)
        df['hour'] = time.dt.hour
        df['minute'] = time.dt.minute
        df['dayofweek'] = time.dt.dayofweek

        # Join weather and energy data
        df = pd.merge(df, df_weather, how='inner', left_index=True, right_index=True, sort=True)

        # Holidays: flag every row whose calendar day is a bank holiday
        days = df.index.floor('D')
        holiday_mapping = (days.unique()
                           .to_series()
                           .apply(lambda dt: dt in holidays)
                           .astype(int)
                           .to_dict())
        df['holiday'] = days.to_series().map(holiday_mapping).values

        df.index.name = 'Date'
        df = df.loc['2012-09':]  # Weird value before this
        dfs.append(df)
    return pd.concat(dfs)
# -
# Our dataset is the London smart-meter data, sampled at half-hour intervals.
# +
# Load and concatenate the first 12 block files.
df = get_smartmeter_df(max_files=12)
# # Just get the first one for now
# dfs = list(dfs)
# # df = df.resample(freq).first().dropna() # Where empty we will backfill, this will respect causality, and mostly maintain the mean
# The frame is a concatenation of per-block frames, so tail() keeps the last
# blocks whole and may truncate or drop the earlier ones.
# max_rows is a float, hence the int() cast.
df = df.tail(int(max_rows)).copy() # Just use last X rows
# df = pd.concat(dfs[:6], 0)
# # df = dfs[0]
# Number of rows left per block after truncation.
print(df.block.value_counts())
df
# -
# ### Plot/explore
# +
import holoviews as hv
from holoviews import opts
from holoviews.plotting.links import RangeToolLink
import datashader as ds
from holoviews.operation.datashader import datashade, shade, dynspread, rasterize
from holoviews.operation import decimate
hv.extension('bokeh')
# def house_curve(Name=None):
# if isinstance(Name, int):
# name = df.block.unique()[Name]
# d = df[df.block == Name]
# d_curve = hv.Curve(d, 'Date', 'energy(kWh/hh)', label=Name).opts(framewise=True)
# return d_curve
# dmap = hv.DynamicMap(house_curve, kdims=['Name'])
# dmap = dmap.redim.values(Name=list(df.block.unique()))
# dynspread(datashade(dmap).opts(width=800,
# height=300,
# tools=['xwheel_zoom', 'pan'],
# active_tools=['xwheel_zoom', 'pan'],
# default_tools=['reset', 'save', 'hover']
# ))
# -
# ### Profiling
# +
# from pandas_profiling import ProfileReport
# profile = ProfileReport(df, title="Pandas Profiling Report", minimal=True)
# profile
# -
# ### Norm
df.describe()
# +
import sklearn
from sklearn.preprocessing import StandardScaler, OrdinalEncoder
from sklearn_pandas import DataFrameMapper
# Numeric (or boolean) columns other than the target are scaled as inputs.
# select_dtypes is the public counterpart of the private `_get_numeric_data()`.
columns_input_numeric = list(
    df.drop(columns=columns_target).select_dtypes(include=['number', 'bool']).columns
)
# Everything else is treated as categorical. Walk df.columns instead of taking a
# set difference: set iteration order is not stable between interpreter runs, and
# the column order of df_norm defines the feature order the models are trained on.
columns_categorical = [
    c for c in df.columns
    if c not in columns_input_numeric and c not in columns_target
]
output_scalers = [([n], StandardScaler()) for n in columns_target]
transformers = (
    output_scalers
    + [([n], StandardScaler()) for n in columns_input_numeric]
    + [([n], OrdinalEncoder()) for n in columns_categorical]
)
scaler = DataFrameMapper(transformers, df_out=True)
df_norm = scaler.fit_transform(df)
df_norm
# -
# Pick the transformer registered for the (first) target column; it is passed to
# the prediction helpers later on as `scaler`.
output_scaler = next(filter(lambda r: r[0][0] in columns_target, scaler.features))[-1]
output_scaler
# ### Split
# +
# split data, with the test in the future
d0 = df_norm.index.min()
d1 = df_norm.index.max()
# Train on the first 80% of the covered time span, test on the rest
split_time = d0 + (d1 - d0) * 0.8
split_time = split_time.round('1D')
print(split_time)
# Label slices on a DatetimeIndex include both end points, so slicing both halves
# at split_time would put rows stamped exactly split_time into train AND test.
# The test half therefore uses a strict `>` comparison.
df_train = (df_norm.groupby('block')
            .apply(lambda d: d.loc[:split_time])
            .reset_index(level=0, drop=True))
df_test = (df_norm.groupby('block')
           .apply(lambda d: d.loc[d.index > split_time])
           .reset_index(level=0, drop=True))
# df_test
# +
# # Show split
# df_train['energy(kWh/hh)'].plot(label='train')
# df_test['energy(kWh/hh)'].plot(label='test')
# plt.ylabel('energy(kWh/hh)')
# plt.legend()
# -
# # Show split
# Overlay the target series of the train split (blue) and the test split (red),
# grouped by block and rasterised with datashader.
scatter = dynspread(datashade(hv.Curve(df_train, kdims=['Date'], vdims=['energy(kWh/hh)', 'block']).groupby('block'), cmap='blue'))
scatter *= dynspread(datashade(hv.Curve(df_test, kdims=['Date'], vdims=['energy(kWh/hh)', 'block']).groupby('block'), cmap='red'))
scatter = scatter.opts(plot=dict(width=800))
scatter
# ### Dataset
# +
# ### Dataset
# These are the columns that we wont know in the future
# We need to blank them out in x_future
columns_blank=['visibility',
'windBearing', 'temperature', 'dewPoint', 'pressure',
'apparentTemperature', 'windSpeed', 'humidity']
# One frame per block, each put on a regular `freq` grid: gaps are forward-filled
# and any rows that still contain NaN afterwards are dropped.
df_trains = [d.resample(freq).first().ffill().dropna() for _,d in df_train.groupby('block')]
df_tests = [d.resample(freq).first().ffill().dropna() for _,d in df_test.groupby('block')]
# Wrap the per-block frames in windowed datasets; train and test share the same
# window lengths and blanked columns.
ds_train = Seq2SeqDataSets(df_trains,
window_past=window_past,
window_future=window_future,
columns_blank=columns_blank)
ds_test = Seq2SeqDataSets(df_tests,
window_past=window_past,
window_future=window_future,
columns_blank=columns_blank)
print(ds_train)
print(ds_test)
# -
# we can treat it like an array
# Integer indexing, len() and negative indexing all work on the dataset.
ds_train[0]
len(ds_train)
ds_train[-1]
# +
# We can get rows
# get_rows returns the same four pieces the models consume, in this order.
x_past, y_past, x_future, y_future = ds_train.get_rows(10)
# Plot one instance, this is what the model sees
y_past['energy(kWh/hh)'].plot(label='past')
# Draw the future on the same axes as the past.
y_future['energy(kWh/hh)'].plot(ax=plt.gca(), label='future')
plt.legend()
plt.ylabel('energy(kWh/hh)')
# Notice we've added on two new columns tsp (time since present) and is_past
x_past.tail()
# -
# Notice we've hidden some future columns to prevent cheating
x_future.tail()
# ## Plot helpers
# +
def plot_prediction(ds_preds, i):
    """Plot a prediction into the future, at a single point in time.

    Args:
        ds_preds: prediction dataset. The fields read here are `t_source`,
            `t_target`, `t_past`, `y_pred`, `y_pred_std`, `y_true`, `y_past`
            and `nll`.
        i: integer position along `t_source` of the forecast to draw.
    """
    d = ds_preds.isel(t_source=i)

    # Get arrays
    xf = d.t_target
    yp = d.y_pred
    s = d.y_pred_std
    yt = d.y_true
    now = d.t_source.squeeze()

    plt.figure(figsize=(12, 4))

    plt.scatter(xf, yt, label='true', c='k', s=6)
    # Remember the limits fitted to the true future values; they are restored
    # below so a wide uncertainty band cannot blow up the y-axis.
    ylim = plt.ylim()

    # plot prediction
    plt.fill_between(xf, yp - 2 * s, yp + 2 * s, alpha=0.25,
                     facecolor="b",
                     interpolate=True,
                     label="2 std",)
    plt.plot(xf, yp, label='pred', c='b')

    # plot the observed past
    plt.scatter(
        d.t_past,
        d.y_past,
        c='k',
        s=6
    )

    # plot a red line for now. vlines works in data coordinates, so span the
    # saved y-limits rather than 0..1 to get a line over the full height.
    plt.vlines(x=now, ymin=ylim[0], ymax=ylim[1], label='now', color='r')
    plt.ylim(*ylim)

    now = pd.Timestamp(now.values)
    plt.title(f'Prediction NLL={d.nll.mean().item():2.2g}')
    plt.xlabel(f'{now.date()}')
    plt.ylabel('energy(kWh/hh)')
    plt.legend()
    plt.xticks(rotation=45)
    plt.show()
def plot_performance(ds_preds, full=False):
    """Draw a set of diagnostic plots for one prediction dataset.

    Always shows a single example forecast and the mean NLL per forecast
    horizon. With `full=True` it also shows the NLL over source time and a
    true-vs-predicted scatter plot.

    Args:
        ds_preds: prediction dataset with `t_source` and `t_ahead` dimensions
            and `t_ahead_hours`, `nll`, `y_true` and `y_pred` fields.
        full: when True, draw the two extra plots.
    """
    n = len(ds_preds.t_source)

    # Example forecast: sample 24 when it exists, otherwise the last available one
    plot_prediction(ds_preds, min(24, n - 1))

    # Mean over all predictions
    ds_preds.mean('t_source').plot.scatter('t_ahead_hours', 'nll')
    plt.ylabel('Negative Log Likelihood (lower is better)')
    plt.xlabel('Hours ahead')
    plt.title(f'NLL vs time ahead (no. samples={n})')
    plt.show()

    if not full:
        return

    # Make a plot of the NLL over time. Does this solution get worse with time?
    ds_preds.mean('t_ahead').groupby('t_source').mean().plot.scatter('t_source', 'nll')
    plt.xticks(rotation=45)
    plt.title('NLL over source time (lower is better)')
    plt.show()

    # A scatter plot is easy with xarray
    plt.figure(figsize=(5, 5))
    ds_preds.plot.scatter('y_true', 'y_pred', s=.01)
    plt.show()
# -
def plot_hist(trainer):
    """Plot and return the per-epoch history written by the trainer's CSV logger.

    Reads `trainer.logger.experiment.metrics_file_path`, forward-fills the
    `epoch` column and averages every metric per epoch. When more than one
    epoch is available the `loss/train` and `loss/val` columns are plotted.

    Args:
        trainer: object exposing `logger.experiment.metrics_file_path`.

    Returns:
        The per-epoch DataFrame, or None when the history could not be read or
        processed (a warning describing the cause is emitted in that case).
    """
    try:
        df_hist = pd.read_csv(trainer.logger.experiment.metrics_file_path)
        df_hist['epoch'] = df_hist['epoch'].ffill()
        df_histe = df_hist.set_index('epoch').groupby('epoch').mean()
        if len(df_histe) > 1:
            df_histe[['loss/train', 'loss/val']].plot(title='history')
        return df_histe
    except Exception as err:
        # Showing the history is best effort and must not abort a training run,
        # but say why nothing is shown instead of failing silently.
        warnings.warn(f'plot_hist: could not load training history ({err!r})')
        return None
# ## Lightning
# +
import pytorch_lightning as pl
class PL_MODEL(pl.LightningModule):
    """Lightning wrapper that trains a wrapped sequence model on its negative log-likelihood."""

    def __init__(self, model, lr=3e-4, patience=2, weight_decay=0):
        super().__init__()
        self._model = model
        # Optimiser / scheduler settings consumed by configure_optimizers
        self.lr, self.patience, self.weight_decay = lr, patience, weight_decay

    def forward(self, x_past, y_past, x_future, y_future=None):
        """Run the wrapped model and hand back its (distribution, extras) pair."""
        dist, extras = self._model(x_past, y_past, x_future, y_future)
        return dist, extras

    def training_step(self, batch, batch_idx, phase='train'):
        """Compute and log the loss for one batch; `phase` selects the log key suffix."""
        _x_past, _y_past, _x_future, y_future = batch
        dist, extras = self.forward(*batch)

        # Mean negative log-likelihood of the true future under the prediction
        nll = -dist.log_prob(y_future).mean()
        self.log_dict({f'loss/{phase}': nll})

        if phase != 'train' or 'loss' not in extras:
            return nll

        # some models supply their own training loss; optimise that one instead
        model_loss = extras['loss']
        self.log_dict({f'model_loss/{phase}': model_loss})
        return model_loss

    def validation_step(self, batch, batch_idx):
        """Same computation as training_step, logged under the 'val' phase."""
        return self.training_step(batch, batch_idx, phase='val')

    def configure_optimizers(self):
        """AdamW plus a reduce-on-plateau schedule driven by the validation loss."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
        plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            patience=self.patience,
            verbose=True,
            min_lr=1e-7,
        )
        return {'optimizer': optimizer, 'lr_scheduler': plateau, 'monitor': 'loss/val'}
# -
# # Run
from torch.utils.data import DataLoader
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
# +
# Init data
# Pull one example to read the feature widths off its last axis.
x_past, y_past, x_future, y_future = ds_train.get_rows(10)
input_size = x_past.shape[-1]
output_size = y_future.shape[-1]
# Training batches are shuffled. pin_memory evaluates to True only when
# num_workers == 0.
dl_train = DataLoader(ds_train,
batch_size=batch_size,
shuffle=True,
pin_memory=num_workers==0,
num_workers=num_workers)
# The test loader keeps the dataset order (no shuffle).
dl_test = DataLoader(ds_test, batch_size=batch_size, num_workers=num_workers)
# +
import gc
def free_mem():
    """Release memory between runs: collect garbage, empty the CUDA cache, collect again."""
    for release in (gc.collect, torch.cuda.empty_cache, gc.collect):
        release()
# -
from seq2seq_time.models.lstm_seq2seq import LSTMSeq2Seq
from seq2seq_time.models.lstm_seq import LSTMSeq
from seq2seq_time.models.lstm import LSTM
from seq2seq_time.models.baseline import BaselineLast
from seq2seq_time.models.transformer import Transformer
from seq2seq_time.models.transformer_autor import TransformerAutoR
from seq2seq_time.models.transformer_seq2seq import TransformerSeq2Seq
from seq2seq_time.models.transformer_seq import TransformerSeq
from seq2seq_time.models.neural_process import RANP
from seq2seq_time.models.transformer_process import TransformerProcess
# ## Plots
# +
# PL_MODEL(TransformerAutoR(input_size, output_size, hidden_out_size=32),
# patience=patience,
# lr=2e-5,
# weight_decay=1e-3)
# -
# Candidate models to compare. Each entry is a zero-argument factory rather than
# an instance, so a model is only constructed when it is called.
models = [
# TransformerAutoR2(input_size,
# output_size),
lambda: TransformerAutoR(input_size,
output_size, hidden_out_size=32),
lambda: RANP(input_size,
output_size, hidden_dim=32,
latent_dim=64, n_decoder_layers=4),
lambda: LSTM(input_size,
output_size,
hidden_size=80,
lstm_layers=3,
lstm_dropout=0.3),
lambda: LSTMSeq2Seq(input_size,
output_size,
hidden_size=64,
lstm_layers=2,
lstm_dropout=0.25),
lambda: TransformerSeq2Seq(input_size,
output_size,
hidden_size=128,
nhead=8,
nlayers=4,
attention_dropout=0.2),
lambda: Transformer(input_size,
output_size,
attention_dropout=0.2,
nhead=8,
nlayers=8,
hidden_size=128),
# lambda: TransformerSeq(input_size,
# output_size),
# lambda: LSTMSeq(input_size,
# output_size),
lambda :TransformerProcess(input_size,
output_size, hidden_size=16,
latent_dim=8, dropout=0.5,
nlayers=4,)
]
models
# Baseline model
pt_model = BaselineLast()
model = PL_MODEL(pt_model).to(device)
# Only request a GPU when one exists, so this cell also runs on the CPU
# fallback that `device` allows for.
trainer = pl.Trainer(gpus=1 if torch.cuda.is_available() else 0,
                     max_epochs=1,
                     limit_train_batches=0.01,
                     logger=CSVLogger("logs",
                                      name=type(pt_model).__name__),
                     )
trainer.fit(model, dl_train, dl_test)
print(plot_hist(trainer))
# Score the baseline on every test block; this is the number to beat
ds_predss = predict_multi(model.to(device),
                          ds_test.datasets,
                          batch_size*8,
                          device=device,
                          scaler=output_scaler)
print(f'baseline nll: {ds_predss.nll.mean().item():2.2g}')
# ## Train
for m_fn in models:
    # Build the model only now, so a single network is alive at a time
    pt_model = m_fn()
    name = type(pt_model).__name__
    print(name)

    # Wrap in lightning
    patience = 2
    model = PL_MODEL(pt_model, patience=patience, lr=2e-5, weight_decay=1e-3).to(device)

    # Trainer. GPU and 16-bit precision are only requested when CUDA is
    # available, so the loop also runs on the CPU fallback of `device`.
    use_gpu = torch.cuda.is_available()
    trainer = pl.Trainer(gpus=1 if use_gpu else 0,
                         min_epochs=2,
                         max_epochs=30,
                         amp_level='O1',
                         precision=16 if use_gpu else 32,
                         gradient_clip_val=1,
                         logger=CSVLogger("logs", name=name),
                         callbacks=[
                             # Stop after twice the LR scheduler's patience without improvement
                             EarlyStopping(monitor='loss/val', patience=patience*2),
                             # PrintTableMetricsCallback2()
                         ],
                         )

    # Train
    trainer.fit(model, dl_train, dl_test)

    # Evaluate on every test block
    ds_predss = predict_multi(model.to(device),
                              ds_test.datasets,
                              batch_size*2,
                              device=device,
                              scaler=output_scaler)
    print(name)
    print(f'mean_NLL {ds_predss.nll.mean().item():2.2f}')

    # Performance (plots use the first block only)
    ds_preds = ds_predss.isel(block=0)
    print(plot_hist(trainer))
    plot_performance(ds_preds)

    # Move the weights off the GPU and release cached memory before the next model
    model.cpu()
    free_mem()
# # Plots
# +
# Get latest checkpoint for a model type...
pt_model = models[1]()
name = type(pt_model).__name__


def sort_checkpoints(f):
    """Return N for a `version_N` log directory, so versions sort numerically."""
    return int(f.stem.split('_')[-1])


checkpoints = sorted((Path('logs')/name).glob('version_*'), key=sort_checkpoints)
latest_checkpoint = checkpoints[-1]
# Pick the most recently written file: a lexicographic sort of the names would
# rank e.g. 'epoch=9' after 'epoch=10'.
checkpoint_f = max(latest_checkpoint.glob('checkpoints/*.ckpt'),
                   key=lambda f: f.stat().st_mtime)
print('pt model name', name)
print('latest_checkpoint', checkpoint_f)
# -
# Load
# load_from_checkpoint is a classmethod that returns a new module, so keep its result
model = PL_MODEL.load_from_checkpoint(str(checkpoint_f), model=pt_model).to(device)
ds_predss = predict_multi(model.to(device),
                          ds_test.datasets,
                          batch_size*4,
                          device=device,
                          scaler=output_scaler)
ds_predss.nll.mean().item()
# Predictions of a single block, used by the interactive plots
ds_pred_block = ds_predss.isel(block=1)
# # holoviews pred
# +
import holoviews as hv
from holoviews import opts
import holoviews as hv
from holoviews import opts
import datashader as ds
from holoviews.operation.datashader import datashade, shade, dynspread, rasterize
from holoviews.operation import decimate
hv.extension('bokeh')
# +
# A few diagnostic plots
# NLL averaged over forecast horizon and block, as a function of source time.
d_source = ds_predss.mean(['t_ahead',
'block'])['nll'].groupby('t_source').mean()
nll_vs_time = (hv.Curve(d_source).opts(width=600,
height=200,
title='Error vs time of prediction'))
# NLL averaged over source time and block, as a function of hours ahead.
d_ahead = ds_predss.mean(['t_source',
'block'])['nll'].groupby('t_ahead_hours').mean()
nll_vs_tahead = (hv.Curve(
(d_ahead.t_ahead_hours,
d_ahead)).redim(x='hours ahead',
y='nll').opts(width=600,
height=200,
title='Error vs time ahead'))
# True vs predicted values, rasterised with datashader.
true_vs_pred = datashade(hv.Scatter(
(ds_predss.y_true,
ds_predss.y_pred))).redim(x='true', y='pred').opts(title='Scatter plot')
true_vs_pred = dynspread(true_vs_pred)
# Stack the three plots in a single column, each with its own axes.
l = nll_vs_time + nll_vs_tahead + true_vs_pred
l.cols(1).opts(
framewise=True,
shared_axes=False,
)
# +
def hv_predict_from_time(t_source):
    """Build a holoviews overlay of the forecast issued at `t_source`.

    `t_source` may be a timestamp or an integer position into
    `ds_pred_block.t_source`. Returns None when no prediction matches.
    """
    # An int is treated as a position and translated to its timestamp
    if isinstance(t_source, int):
        t_source = ds_pred_block.t_source[t_source].to_pandas()
    d = ds_pred_block.sel(t_source=t_source)

    # The selection may still carry a t_source axis (duplicate times):
    # nothing selected -> no plot, otherwise keep the first match
    source_shape = d.t_source.shape
    if len(source_shape):
        if source_shape[0] == 0:
            return None
        d = d.isel(t_source=0)

    now = d.t_source.to_pandas()

    # Observed values: past and future in one scatter
    times = np.concatenate([d.t_past, d.t_target])
    observed = np.concatenate([d.y_past, d.y_true])
    truth = hv.Scatter({'x': times, 'y': observed}, label='true').opts(color='black')

    # Predicted mean with a band of two standard deviations either side
    target_times = d.t_target.values
    pred_mean = d.y_pred
    pred_std = d.y_pred_std
    mean_curve = hv.Curve({'x': target_times, 'y': pred_mean}, label='pred').opts(color='blue')
    band = hv.Area(
        (target_times, pred_mean - 2 * pred_std, pred_mean + 2 * pred_std),
        vdims=['y', 'y2'],
        label='2*std',
    ).opts(alpha=0.5, line_width=0)

    # Red marker at the moment the forecast is made
    now_line = hv.VLine(now, label='now').opts(color='red', framewise=True)

    overlay = truth * mean_curve * band * now_line
    return overlay.opts(title=f'Prediction at {now}. NLL={d.nll.mean().item():2.2f}')
# Interactive view: one forecast plot per source time, with the selectable
# values of `t_source` taken from the block's prediction times.
dmap_hv_predict_from_time = (hv.DynamicMap(hv_predict_from_time, kdims=['t_source'])
.redim.values(t_source=ds_pred_block.t_source.to_pandas())
.opts(width=800,
height=300,
))
dmap_hv_predict_from_time
# +
def hv_plot_predictions_vs_time(it_ahead=6,
                                std=False,
                                ds_pred_block=ds_pred_block):
    """Plot true vs predicted values over source time for one fixed horizon.

    `it_ahead` is the position along `t_ahead`. With `std=True` a band of two
    standard deviations is overlaid; otherwise the overlay is datashaded.
    """
    # One row per source time at the chosen horizon
    d = ds_pred_block.isel(t_ahead=it_ahead).groupby('t_source').first()

    truth = hv.Scatter({
        'x': d.t_source,
        'y': d.y_true
    }, label='true').opts(color='black', size=2)

    source_times = d.t_source.values
    pred_mean = d.y_pred
    pred_std = d.y_pred_std

    # Mean
    overlay = truth * hv.Curve({'x': source_times, 'y': pred_mean}, label='pred').opts(color='blue')
    if not std:
        overlay = datashade(overlay)
    else:
        overlay = overlay * hv.Spread(
            (source_times, pred_mean, pred_std * 2),
            label='2*std',
        ).opts(alpha=0.5, line_width=0)

    horizon = it_ahead * pd.Timedelta(freq)
    plot_opts = dict(
        title=f'Prediction at {horizon} ahead. NLL={d.nll.mean().item():2.2f}',
        width=800,
        height=300,
        tools=['xwheel_zoom'],
        active_tools=['xwheel_zoom', 'pan'],
    )
    return overlay.opts(**plot_opts)
# Horizon index 6 with the uncertainty band, restricted to source-time
# positions 100..3999 of the block.
p = hv_plot_predictions_vs_time(
6, std=True, ds_pred_block=ds_pred_block.isel(t_source=slice(100, 4000)))
p
# -
# # Summarize experiments
# # LR finder
# +
# # Run learning rate finder
# lr_finder = trainer.tuner.lr_find(model)
# # Results can be found in
# lr_finder.results
# # Plot with
# fig = lr_finder.plot(suggest=True)
# fig.show()
# # Pick point based on plot, or get suggestion
# new_lr = lr_finder.suggestion()
# -
+5
View File
@@ -255,7 +255,12 @@ dependencies:
- pyqt5-sip==4.19.18
- pyqtchart==5.12
- pyqtwebengine==5.12.1
- pytorch-fast-transformers==0.3.0
- pytorch-lightning-bolts==0.2.5
- sklearn==0.0
- sklearn-pandas==2.0.2
- torchsummaryx==1.3.0
- ucimlr==0.3.0
- unlzw==0.1.1
- xlrd==1.2.0
prefix: /home/wassname/anaconda/envs/seq2seq-time
+5 -3
View File
@@ -20,9 +20,11 @@ class LSTM(nn.Module):
def forward(self, past_x, past_y, future_x, future_y=None):
device = next(self.parameters()).device
future_y_fake = (
torch.ones(past_y.shape[0], future_x.shape[1], past_y.shape[2]).float().to(device) * self.nan_value
)
B, S, _ = future_x.shape
future_y_fake = past_y[:, -1:, :].repeat(1, S, 1).to(device)
# future_y_fake = (
# torch.ones(past_y.shape[0], future_x.shape[1], past_y.shape[2]).float().to(device) * self.nan_value
# )
context = torch.cat([past_x, past_y], -1).detach()
target = torch.cat([future_x, future_y_fake], -1).detach()
x = torch.cat([context, target * 1], 1).detach()
+62 -100
View File
@@ -32,7 +32,11 @@ class LSTMBlock(nn.Module):
class NPBlockRelu2d(nn.Module):
"""Block for Neural Processes."""
"""
Block for Neural Processes.
We want to apply batchnorm and dropout to the channels. We reshape so we can use Dropout2d & BatchNorm2d
"""
def __init__(
self, in_channels, out_channels, dropout=0, batchnorm=False, bias=False
@@ -101,7 +105,6 @@ class Attention(nn.Module):
def __init__(
self,
hidden_dim,
attention_type,
attention_layers=2,
n_heads=8,
x_dim=1,
@@ -155,48 +158,33 @@ class LatentEncoder(nn.Module):
input_dim,
hidden_dim=32,
latent_dim=32,
self_attention_type="dot",
n_encoder_layers=3,
min_std=0.01,
batchnorm=False,
dropout=0,
attention_dropout=0,
use_self_attn=True,
attention_layers=2,
use_lstm=False,
):
super().__init__()
# self._input_layer = nn.Linear(input_dim, hidden_dim)
if use_lstm:
self._encoder = LSTMBlock(
input_dim,
hidden_dim,
batchnorm=batchnorm,
dropout=dropout,
num_layers=n_encoder_layers,
)
else:
self._encoder = BatchMLP(
input_dim,
hidden_dim,
batchnorm=batchnorm,
dropout=dropout,
num_layers=n_encoder_layers,
)
if use_self_attn:
self._self_attention = Attention(
hidden_dim,
self_attention_type,
attention_layers,
rep="identity",
dropout=attention_dropout,
)
self._encoder = BatchMLP(
input_dim,
hidden_dim,
batchnorm=batchnorm,
dropout=dropout,
num_layers=n_encoder_layers,
)
self._self_attention = Attention(
hidden_dim,
attention_layers,
rep="identity",
dropout=attention_dropout,
)
self._penultimate_layer = nn.Linear(hidden_dim, hidden_dim)
self._mean = nn.Linear(hidden_dim, latent_dim)
self._log_var = nn.Linear(hidden_dim, latent_dim)
self._min_std = min_std
self._use_lstm = use_lstm
self._use_self_attn = use_self_attn
def forward(self, x, y):
encoder_input = torch.cat([x, y], dim=-1)
@@ -205,11 +193,8 @@ class LatentEncoder(nn.Module):
encoded = self._encoder(encoder_input)
# Aggregator: take the mean over all points
if self._use_self_attn:
attention_output = self._self_attention(encoded, encoded, encoded)
mean_repr = attention_output.mean(dim=1)
else:
mean_repr = encoded.mean(dim=1)
attention_output = self._self_attention(encoded, encoded, encoded)
mean_repr = attention_output.mean(dim=1)
# Have further MLP layers that map to the parameters of the Gaussian latent
mean_repr = torch.relu(self._penultimate_layer(mean_repr))
@@ -230,45 +215,28 @@ class DeterministicEncoder(nn.Module):
x_dim,
hidden_dim=32,
n_d_encoder_layers=3,
self_attention_type="dot",
cross_attention_type="dot",
use_self_attn=True,
attention_layers=2,
batchnorm=False,
dropout=0,
attention_dropout=0,
use_lstm=False,
):
super().__init__()
self._use_self_attn = use_self_attn
# self._input_layer = nn.Linear(input_dim, hidden_dim)
if use_lstm:
self._d_encoder = LSTMBlock(
input_dim,
hidden_dim,
batchnorm=batchnorm,
dropout=dropout,
num_layers=n_d_encoder_layers,
)
else:
self._d_encoder = BatchMLP(
input_dim,
hidden_dim,
batchnorm=batchnorm,
dropout=dropout,
num_layers=n_d_encoder_layers,
)
if use_self_attn:
self._self_attention = Attention(
hidden_dim,
self_attention_type,
attention_layers,
rep="identity",
dropout=attention_dropout,
)
self._d_encoder = BatchMLP(
input_dim,
hidden_dim,
batchnorm=batchnorm,
dropout=dropout,
num_layers=n_d_encoder_layers,
)
self._self_attention = Attention(
hidden_dim,
attention_layers,
rep="identity",
dropout=attention_dropout,
)
self._cross_attention = Attention(
hidden_dim,
cross_attention_type,
x_dim=x_dim,
attention_layers=attention_layers,
)
@@ -280,8 +248,7 @@ class DeterministicEncoder(nn.Module):
# Pass final axis through MLP
d_encoded = self._d_encoder(d_encoder_input)
if self._use_self_attn:
d_encoded = self._self_attention(d_encoded, d_encoded, d_encoded)
d_encoded = self._self_attention(d_encoded, d_encoded, d_encoded)
# Apply attention as mean aggregation
h = self._cross_attention(past_x, d_encoded, future_x)
@@ -301,7 +268,6 @@ class Decoder(nn.Module):
min_std=0.01,
batchnorm=False,
dropout=0,
use_lstm=False,
):
super(Decoder, self).__init__()
self._future_transform = nn.Linear(x_dim, hidden_dim)
@@ -310,22 +276,14 @@ class Decoder(nn.Module):
else:
hidden_dim_2 = hidden_dim + latent_dim
if use_lstm:
self._decoder = LSTMBlock(
hidden_dim_2,
hidden_dim_2,
batchnorm=batchnorm,
dropout=dropout,
num_layers=n_decoder_layers,
)
else:
self._decoder = BatchMLP(
hidden_dim_2,
hidden_dim_2,
batchnorm=batchnorm,
dropout=dropout,
num_layers=n_decoder_layers,
)
self._decoder = BatchMLP(
hidden_dim_2,
hidden_dim_2,
batchnorm=batchnorm,
dropout=dropout,
num_layers=n_decoder_layers,
)
self._mean = nn.Linear(hidden_dim_2, y_dim)
self._std = nn.Linear(hidden_dim_2, y_dim)
self._use_deterministic_path = use_deterministic_path
@@ -363,18 +321,14 @@ class RANP(nn.Module):
latent_dim=32, # size of latent space
n_latent_encoder_layers=2,
n_det_encoder_layers=2, # number of deterministic encoder layers
n_decoder_layers=2,
n_decoder_layers=4,
use_deterministic_path=True,
min_std=0.01, # To avoid collapse use a minimum standard deviation, should be much smaller than variation in labels
dropout=0,
use_self_attn=True,
attention_dropout=0,
batchnorm=False,
attention_layers=2,
use_rnn=True, # use RNN/LSTM
use_lstm_le=False, # use another LSTM in latent encoder instead of MLP
use_lstm_de=False, # use another LSTM in determinstic encoder instead of MLP
use_lstm_d=False, # use another lstm in decoder instead of MLP
**kwargs,
):
@@ -399,11 +353,9 @@ class RANP(nn.Module):
n_encoder_layers=n_latent_encoder_layers,
attention_layers=attention_layers,
dropout=dropout,
use_self_attn=use_self_attn,
attention_dropout=attention_dropout,
batchnorm=batchnorm,
min_std=min_std,
use_lstm=use_lstm_le,
)
self._deterministic_encoder = DeterministicEncoder(
@@ -412,11 +364,9 @@ class RANP(nn.Module):
hidden_dim=hidden_dim,
n_d_encoder_layers=n_det_encoder_layers,
attention_layers=attention_layers,
use_self_attn=use_self_attn,
dropout=dropout,
batchnorm=batchnorm,
attention_dropout=attention_dropout,
use_lstm=use_lstm_de,
)
self._decoder = Decoder(
@@ -429,7 +379,6 @@ class RANP(nn.Module):
min_std=min_std,
n_decoder_layers=n_decoder_layers,
use_deterministic_path=use_deterministic_path,
use_lstm=use_lstm_d,
)
self._use_deterministic_path = use_deterministic_path
@@ -443,19 +392,17 @@ class RANP(nn.Module):
x, _ = self._lstm(x)
past_x = x[:, :S]
future_x = x[:, S:]
# future_x, _ = self._lstm(future_x)
# past_x, _ = self._lstm(past_x)
dist_prior, log_var_prior = self._latent_encoder(past_x, past_y)
if (future_y is not None):
dist_post, log_var_post = self._latent_encoder(future_x, future_y)
y = torch.cat([past_y, future_y], 1)
dist_post, log_var_post = self._latent_encoder(x, y)
if self.training:
z = dist_prior.rsample()
else:
z = dist_prior.loc
num_targets = future_x.size(1)
z = z.unsqueeze(1).repeat(1, num_targets, 1) # [B, T_target, H]
@@ -478,5 +425,20 @@ class RANP(nn.Module):
:, : past_x.size(1)
].mean()
loss = (kl_loss - log_p).mean()
return dist, {'loss':loss}
return dist, {'loss': loss}
# class NP(RANP):
# """Recurrent Attentive Neural Process for Sequential Data."""
# def __init__(
# self,
# use_self_attn=True,
# # TODO use cross attention flag
# use_rnn=True, # use RNN/LSTM
# use_lstm_le=False, # use another LSTM in latent encoder instead of MLP
# use_lstm_de=False, # use another LSTM in determinstic encoder instead of MLP
# use_lstm_d=False, # use another lstm in decoder instead of MLP
# **kwargs,
# ):
# kwargs
# super().__init__(**kwargs)
+13 -6
View File
@@ -2,12 +2,13 @@ import torch
from torch import nn
from torch.nn import functional as F
from ..util import mask_upper_triangular
class Transformer(nn.Module):
"""
A single transformer, masking nan or 0
"""
def __init__(self, x_dim, y_dim, attention_dropout=0, nhead=8, nlayers=2, hidden_size=16, nan_value=0, min_std=0.01):
def __init__(self, x_dim, y_dim, attention_dropout=0, nhead=8, nlayers=8, hidden_size=32, nan_value=0, min_std=0.01):
super().__init__()
self._min_std = min_std
self.nan_value = nan_value
@@ -17,7 +18,7 @@ class Transformer(nn.Module):
encoder_norm = nn.LayerNorm(hidden_size)
layer_enc = nn.TransformerEncoderLayer(
d_model=hidden_size,
dim_feedforward=hidden_size*4,
dim_feedforward=hidden_size*8,
dropout=attention_dropout,
nhead=nhead,
# activation
@@ -30,9 +31,11 @@ class Transformer(nn.Module):
def forward(self, past_x, past_y, future_x, future_y=None):
device = next(self.parameters()).device
future_y_fake = (
torch.ones(past_y.shape[0], future_x.shape[1], past_y.shape[2]).float().to(device) * self.nan_value
)
B, S, _ = future_x.shape
future_y_fake = past_y[:, -1:, :].repeat(1, S, 1).to(device)
# future_y_fake = (
# torch.ones(past_y.shape[0], future_x.shape[1], past_y.shape[2]).float().to(device) * past_y[:, -1].repeat(B, S, 1)
# )
context = torch.cat([past_x, past_y], -1).detach()
target = torch.cat([future_x, future_y_fake], -1).detach()
x = torch.cat([context, target * 1], 1).detach()
@@ -44,8 +47,12 @@ class Transformer(nn.Module):
x_key_padding_mask = ~x_mask.any(-1)
x = self.enc_emb(x).permute(1, 0, 2)
B, S, _ = x.shape
mask = mask_upper_triangular(S, device)
outputs = self.encoder(x, src_key_padding_mask=x_key_padding_mask).permute(
outputs = self.encoder(x, mask=mask#, src_key_padding_mask=x_key_padding_mask
).permute(
1, 0, 2
)
+73
View File
@@ -0,0 +1,73 @@
from tqdm.auto import tqdm
from torch import nn
import torch
from torch.nn import functional as F
import fast_transformers
from fast_transformers.builders import TransformerEncoderBuilder
class TransformerAutoR(nn.Module):
    """Causally masked transformer encoder that predicts a Normal distribution per future step.

    Past and future inputs are concatenated along the time axis, embedded with
    a linear layer and passed through a `fast_transformers` encoder built with
    `attention_type="causal-linear"`. The encoder outputs for the future steps
    are mapped to a mean and a standard deviation.
    """

    def __init__(self, x_dim, y_dim, hidden_out_size=256, nlayers=8, n_heads=8, use_lstm=False, attention_dropout=0, dropout=0, min_std=0.01):
        """
        Args:
            x_dim: number of input features per step.
            y_dim: number of target features per step.
            hidden_out_size: requested model width; it is floor-divided by
                `n_heads` to get the per-head width.
            nlayers: number of encoder layers.
            n_heads: number of attention heads.
            use_lstm: when True, pass the inputs through an LSTMBlock before embedding.
            attention_dropout: dropout inside the attention layers.
            dropout: dropout of the encoder layers.
            min_std: lower bound of the predicted standard deviation.
        """
        super().__init__()
        self._min_std = min_std
        self.use_lstm = use_lstm
        # Per-head width; the model width is hidden_out_size * n_heads
        hidden_out_size = hidden_out_size // n_heads
        model_size = hidden_out_size * n_heads

        x_size = x_dim + y_dim

        # TODO embed both X's the same
        if use_lstm:
            # LSTMBlock was referenced here without being imported, which raised
            # NameError whenever use_lstm=True.
            # NOTE(review): assumes LSTMBlock lives in the sibling
            # `neural_process` module - confirm the import path.
            from .neural_process import LSTMBlock
            self.x_emb = LSTMBlock(x_size, x_size)
        self.enc_emb = nn.Linear(x_size, model_size)

        self.encoder = fast_transformers.builders.TransformerEncoderBuilder.from_kwargs(
            attention_type="causal-linear",
            n_layers=nlayers,
            n_heads=n_heads,
            feed_forward_dimensions=model_size * 8,
            query_dimensions=hidden_out_size,
            value_dimensions=hidden_out_size,
            attention_dropout=attention_dropout,
            dropout=dropout,
        ).get()

        self.mean = nn.Linear(model_size, y_dim)
        self.std = nn.Linear(model_size, y_dim)

    def forward(self, past_x, past_y, future_x, future_y=None, mask_context=True, mask_target=True):
        """Predict the distribution of the future targets.

        Args:
            past_x, past_y, future_x: 3-d tensors indexed (batch, step, feature).
            future_y: accepted for interface compatibility; not used.
            mask_context, mask_target: accepted but not used.

        Returns:
            A tuple `(y_dist, {})` where `y_dist` is a `torch.distributions.Normal`
            over the future steps.
        """
        device = next(self.parameters()).device
        _, future_steps, _ = future_x.shape

        # The future targets are unknown, so stand in for them by repeating the
        # last observed target at every future step.
        future_y_fake = past_y[:, -1:, :].repeat(1, future_steps, 1).to(device)

        context = torch.cat([past_x, past_y], -1)
        target = torch.cat([future_x, future_y_fake], -1)
        x = torch.cat([context, target], 1).detach()

        # LSTM
        if self.use_lstm:
            x = self.x_emb(x)
            # Size([B, T, Y]) -> Size([B, T, Y])

        # Embed
        x = self.enc_emb(x)
        # requires (B, C, hidden_dim)

        past_steps = past_y.shape[1]
        total_steps = x.shape[1]
        # Causal mask: a step can only attend to itself and earlier steps
        mask = fast_transformers.masking.TriangularCausalMask(total_steps, device=device)
        # Keep only the encoder outputs that belong to the future steps
        outputs = self.encoder(x, attn_mask=mask)[:, past_steps:, :]

        # Size([B, T, emb_dim])
        mean = self.mean(outputs)
        log_sigma = self.std(outputs)
        # softplus keeps sigma positive; min_std bounds it away from zero
        sigma = self._min_std + (1 - self._min_std) * F.softplus(log_sigma)
        y_dist = torch.distributions.Normal(mean, sigma)
        return (
            y_dist,
            {}
        )
+9 -4
View File
@@ -31,8 +31,8 @@ class LatentEncoder(nn.Module):
self.encoder = nn.TransformerEncoder(
layer_enc, num_layers=num_layers, norm=encoder_norm
)
self.mean = nn.Linear(hidden_size, latent_dim)
self.log_var = nn.Linear(hidden_size, latent_dim)
self.mean = nn.Linear(hidden_size*3, latent_dim)
self.log_var = nn.Linear(hidden_size*3, latent_dim)
self._min_std = min_std
def forward(self, x, y):
@@ -48,7 +48,13 @@ class LatentEncoder(nn.Module):
r = self.encoder(x, mask=mask)
r = r.permute(1, 0, 2) # (S,B,hidden_size) -> (B,S,hidden_size)
r = r.mean(1) # (B,S,hidden_size) -> (B,hidden_size)
# Aggregation (max/mean/last)
r_mean = r.mean(1) # (B,S,hidden_size) -> (B,hidden_size)
r_last = r[:, -1, :]
r_max = r.max(1)[0]
r = torch.cat([r_mean, r_last, r_max], -1)
mean = self.mean(r)
log_sigma = self.log_var(r)
sigma = self._min_std + (1 - self._min_std) * torch.sigmoid(log_sigma * 0.5)
@@ -56,7 +62,6 @@ class LatentEncoder(nn.Module):
return dist
class Decoder(nn.Module):
def __init__(
self,
+1
View File
@@ -2,6 +2,7 @@ import torch
from torch import nn
from torch.nn import functional as F
from ..util import mask_upper_triangular
class TransformerSeq(nn.Module):
"""
+3 -4
View File
@@ -2,7 +2,7 @@ import torch
from torch import nn
from torch.nn import functional as F
from ..util import mask_upper_triangular
class TransformerSeq2Seq(nn.Module):
def __init__(self, x_size, y_size, hidden_size=16, nhead=8, nlayers=2, attention_dropout=0, min_std=0.01, nan_value=0):
@@ -16,7 +16,7 @@ class TransformerSeq2Seq(nn.Module):
encoder_norm = nn.LayerNorm(hidden_size)
layer_enc = nn.TransformerEncoderLayer(
d_model=hidden_size,
dim_feedforward=hidden_size*4,
dim_feedforward=hidden_size*8,
dropout=attention_dropout,
nhead=nhead,
# activation
@@ -27,7 +27,7 @@ class TransformerSeq2Seq(nn.Module):
layer_dec = nn.TransformerDecoderLayer(
d_model=hidden_size,
dim_feedforward=hidden_size*4,
dim_feedforward=hidden_size*8,
dropout=attention_dropout,
nhead=nhead,
)
@@ -67,7 +67,6 @@ class TransformerSeq2Seq(nn.Module):
# In transformers the memory and future_x need to be the same length. Lets use a permutation invariant agg on the context
# Then expand it, so it's available as we decode, conditional on future_x
memory = memory.max(dim=0, keepdim=True)[0].expand_as(future_x)
outputs = self.decoder(future_x, memory, tgt_key_padding_mask=tgt_key_padding_mask)
# [T, B, emb_dim] -> [B, T, emb_dim]