mirror of
https://github.com/wassname/seq2seq-time.git
synced 2026-06-27 16:31:46 +08:00
working
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -0,0 +1,676 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# ---
|
||||
# jupyter:
|
||||
# jupytext:
|
||||
# formats: ipynb,py:light
|
||||
# text_representation:
|
||||
# extension: .py
|
||||
# format_name: light
|
||||
# format_version: '1.5'
|
||||
# jupytext_version: 1.6.0
|
||||
# kernelspec:
|
||||
# display_name: seq2seq-time
|
||||
# language: python
|
||||
# name: seq2seq-time
|
||||
# ---
|
||||
|
||||
# # Sequence to Sequence Models for Timeseries Regression
|
||||
#
|
||||
#
|
||||
# In this notebook we are going to tackle a harder problem:
|
||||
# - predicting the future on a timeseries
|
||||
# - using an LSTM
|
||||
# - with rough uncertainty (uncalibrated)
|
||||
# - outputing sequence of predictions
|
||||
#
|
||||
# <img src="../reports/figures/Seq2Seq for regression.png" />
|
||||
#
|
||||
#
|
||||
# https://medium.com/@boitemailjeanmid/smart-meters-in-london-part1-description-and-first-insights-jean-michel-d-db97af2de71b
|
||||
#
|
||||
|
||||
# OPTIONAL: Load the "autoreload" extension so that code can change. But blacklist large modules
|
||||
# %load_ext autoreload
|
||||
# %autoreload 2
|
||||
# %aimport -pandas
|
||||
# %aimport -torch
|
||||
# %aimport -numpy
|
||||
# %aimport -matplotlib
|
||||
# %aimport -dask
|
||||
# %aimport -tqdm
|
||||
# %matplotlib inline
|
||||
|
||||
# +
|
||||
# Imports
|
||||
import torch
|
||||
from torch import nn, optim
|
||||
from torch.nn import functional as F
|
||||
from torch.autograd import Variable
|
||||
import torch
|
||||
import torch.utils.data
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
plt.rcParams['figure.figsize'] = (12.0, 3.0)
|
||||
plt.style.use('ggplot')
|
||||
|
||||
from pathlib import Path
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import pytorch_lightning as pl
|
||||
# -
|
||||
|
||||
from seq2seq_time.data.dataset import Seq2SeqDataSet, Seq2SeqDataSets
|
||||
from seq2seq_time.predict import predict
|
||||
|
||||
import logging, sys
|
||||
# logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
||||
|
||||
# ## Parameters
|
||||
|
||||
# +
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f'using {device}')
|
||||
|
||||
columns_target=['energy(kWh/hh)']
|
||||
window_past = 48*2
|
||||
window_future = 48*2
|
||||
batch_size = 256
|
||||
num_workers = 5
|
||||
freq = '30T'
|
||||
max_rows = 2e5
|
||||
|
||||
|
||||
# -
|
||||
|
||||
# ## Load data
|
||||
|
||||
# +
|
||||
|
||||
def get_smartmeter_df(indir=Path('../data/raw/smart-meters-in-london'), max_files=1):
|
||||
"""
|
||||
Data loading and cleanding is always messy, so understand this code is optional.
|
||||
"""
|
||||
|
||||
# Load csv files
|
||||
csv_files = sorted((indir/'halfhourly_dataset').glob('*.csv'))[:max_files]
|
||||
|
||||
# concatendate them
|
||||
df = pd.concat([pd.read_csv(f, parse_dates=[1], na_values=['Null']) for f in csv_files])
|
||||
|
||||
# Add ACORN categories
|
||||
df_households = pd.read_csv(indir/'informations_households.csv')
|
||||
df_households = df_households[['LCLid', 'stdorToU', 'Acorn_grouped']]
|
||||
df = pd.merge(df, df_households, on='LCLid')
|
||||
|
||||
|
||||
df = df.sort_values(['tstp', 'LCLid'])
|
||||
|
||||
df = df.set_index('tstp')
|
||||
|
||||
# Drop nan and 0's
|
||||
df = df[df['energy(kWh/hh)']!=0]
|
||||
df = df.dropna()
|
||||
|
||||
# Add time features
|
||||
time = df.index.to_series()
|
||||
df["month"] = time.dt.month
|
||||
df['day'] = time.dt.day
|
||||
df['week'] = time.dt.week
|
||||
df['hour'] = time.dt.hour
|
||||
df['minute'] = time.dt.minute
|
||||
df['dayofweek'] = time.dt.dayofweek
|
||||
|
||||
# Load weather data
|
||||
df_weather = pd.read_csv(indir/'weather_hourly_darksky.csv', parse_dates=[3])
|
||||
use_cols = ['visibility', 'windBearing', 'temperature', 'time', 'dewPoint',
|
||||
'pressure', 'apparentTemperature', 'windSpeed',
|
||||
'humidity']
|
||||
df_weather = df_weather[use_cols].set_index('time')
|
||||
df_weather = df_weather.resample(freq).first().ffill() # Resample to match energy data
|
||||
|
||||
# Join weather and energy data
|
||||
df = pd.merge(df, df_weather, how='inner', left_index=True, right_index=True, sort=True)
|
||||
|
||||
# Holidays
|
||||
df_hols = pd.read_csv(indir/'uk_bank_holidays.csv', parse_dates=[0])
|
||||
holidays = set(df_hols['Bank holidays'].dt.round('D'))
|
||||
def is_holiday(dt):
|
||||
return dt in holidays
|
||||
days = df.index.floor('D')
|
||||
holiday_mapping = days.unique().to_series().apply(is_holiday).astype(int).to_dict()
|
||||
df['holiday'] = days.to_series().map(holiday_mapping).values
|
||||
|
||||
# sort
|
||||
df = df.reset_index().sort_values(['LCLid', 'index']).set_index('index')
|
||||
df.index.name = 'Date'
|
||||
|
||||
return df
|
||||
# -
|
||||
# Our dataset is the london smartmeter data. But at half hour intervals
|
||||
|
||||
# +
|
||||
df = get_smartmeter_df()
|
||||
|
||||
# # Just get the first one for now
|
||||
# dfs = list(dfs)
|
||||
|
||||
# # df = df.resample(freq).first().dropna() # Where empty we will backfill, this will respect causality, and mostly maintain the mean
|
||||
|
||||
df = df.tail(int(max_rows)).copy() # Just use last X rows
|
||||
# df = pd.concat(dfs[:6], 0)
|
||||
# # df = dfs[0]
|
||||
df.LCLid.value_counts()
|
||||
# -
|
||||
|
||||
|
||||
|
||||
# ### Plot/explore
|
||||
|
||||
df
|
||||
|
||||
|
||||
|
||||
# +
|
||||
import holoviews as hv
|
||||
from holoviews import opts
|
||||
|
||||
from holoviews.plotting.links import RangeToolLink
|
||||
|
||||
import datashader as ds
|
||||
|
||||
from holoviews.operation.datashader import datashade, shade, dynspread, rasterize
|
||||
from holoviews.operation import decimate
|
||||
|
||||
hv.extension('bokeh')
|
||||
|
||||
|
||||
def house_curve(Name=None):
|
||||
if isinstance(Name, int):
|
||||
name = df.LCLid.unique()[Name]
|
||||
d = df[df.LCLid == Name]
|
||||
d_curve = hv.Curve(d, 'Date', 'energy(kWh/hh)', label=Name)
|
||||
return d_curve
|
||||
|
||||
|
||||
dmap = hv.DynamicMap(house_curve, kdims=['Name'])
|
||||
dmap = dmap.redim.values(Name=list(df.LCLid.unique()))
|
||||
dynspread(datashade(dmap).opts(width=800,
|
||||
height=300,
|
||||
tools=['xwheel_zoom', 'pan'],
|
||||
active_tools=['xwheel_zoom', 'pan'],
|
||||
default_tools=['reset', 'save', 'hover']
|
||||
))
|
||||
# -
|
||||
|
||||
|
||||
|
||||
|
||||
# ### Profiling
|
||||
|
||||
# +
|
||||
# from pandas_profiling import ProfileReport
|
||||
# profile = ProfileReport(df, title="Pandas Profiling Report", minimal=True)
|
||||
# profile
|
||||
# -
|
||||
|
||||
# ### Norm
|
||||
|
||||
df.describe()
|
||||
|
||||
# +
|
||||
import sklearn
|
||||
from sklearn.preprocessing import StandardScaler, OrdinalEncoder
|
||||
from sklearn_pandas import DataFrameMapper
|
||||
|
||||
columns_input_numeric = list(df.drop(columns=columns_target)._get_numeric_data().columns)
|
||||
columns_categorical = list(set(df.columns)-set(columns_input_numeric)-set(columns_target))
|
||||
|
||||
output_scalers = [([n], StandardScaler()) for n in columns_target]
|
||||
transformers=output_scalers + \
|
||||
[([n], StandardScaler()) for n in columns_input_numeric] + \
|
||||
[([n], OrdinalEncoder()) for n in columns_categorical]
|
||||
scaler = DataFrameMapper(transformers, df_out=True)
|
||||
df_norm = scaler.fit_transform(df)
|
||||
df_norm
|
||||
# -
|
||||
|
||||
output_scaler = next(filter(lambda r:r[0][0] in columns_target, scaler.features))[-1]
|
||||
output_scaler
|
||||
|
||||
# +
|
||||
# split data, with the test in the future
|
||||
|
||||
d0 =df_norm.index.min()
|
||||
d1 = df_norm.index.max()
|
||||
split_time = d0+(d1-d0)*0.8
|
||||
split_time = split_time.round('1D')
|
||||
print(split_time)
|
||||
df_train = df_norm.groupby('LCLid').apply(lambda d:d.loc[:split_time]).reset_index(level=0, drop=True)
|
||||
df_test = df_norm.groupby('LCLid').apply(lambda d:d.loc[split_time:]).reset_index(level=0, drop=True)
|
||||
# df_test
|
||||
# -
|
||||
|
||||
# Show split
|
||||
df_train['energy(kWh/hh)'].plot(label='train')
|
||||
df_test['energy(kWh/hh)'].plot(label='test')
|
||||
plt.ylabel('energy(kWh/hh)')
|
||||
plt.legend()
|
||||
# ### Dataset
|
||||
# These are the columns that we wont know in the future
|
||||
# We need to blank them out in x_future
|
||||
columns_blank=['visibility',
|
||||
'windBearing', 'temperature', 'dewPoint', 'pressure',
|
||||
'apparentTemperature', 'windSpeed', 'humidity']
|
||||
df_trains = [d.resample(freq).first().ffill().dropna() for _,d in df_train.groupby('LCLid')]
|
||||
df_tests = [d.resample(freq).first().ffill().dropna() for _,d in df_test.groupby('LCLid')]
|
||||
ds_train = Seq2SeqDataSets(df_trains,
|
||||
window_past=window_past,
|
||||
window_future=window_future,
|
||||
columns_blank=columns_blank)
|
||||
ds_test = Seq2SeqDataSets(df_tests,
|
||||
window_past=window_past,
|
||||
window_future=window_future,
|
||||
columns_blank=columns_blank)
|
||||
print(ds_train)
|
||||
print(ds_test)
|
||||
# we can treat it like an array
|
||||
ds_train[0]
|
||||
len(ds_train)
|
||||
ds_train[-1]
|
||||
|
||||
# +
|
||||
# We can get rows
|
||||
x_past, y_past, x_future, y_future = ds_train.get_rows(10)
|
||||
|
||||
# Plot one instance, this is what the model sees
|
||||
y_past['energy(kWh/hh)'].plot(label='past')
|
||||
y_future['energy(kWh/hh)'].plot(ax=plt.gca(), label='future')
|
||||
plt.legend()
|
||||
plt.ylabel('energy(kWh/hh)')
|
||||
|
||||
# Notice we've added on two new columns tsp (time since present) and is_past
|
||||
x_past.tail()
|
||||
# -
|
||||
|
||||
# Notice we've hidden some future columns to prevent cheating
|
||||
x_future.tail()
|
||||
|
||||
# ## Plot helpers
|
||||
|
||||
from seq2seq_time.models.lstm_seq2seq import LSTMSeq2Seq
|
||||
from seq2seq_time.models.lstm import LSTM
|
||||
from seq2seq_time.models.baseline import BaselineLast
|
||||
from seq2seq_time.models.transformer import Transformer
|
||||
from seq2seq_time.models.transformer_seq2seq import TransformerSeq2Seq
|
||||
# ## Plots
|
||||
# +
|
||||
def plot_prediction(ds_preds, i):
|
||||
"""Plot a prediction into the future, at a single point in time."""
|
||||
d = ds_preds.isel(t_source=i)
|
||||
|
||||
# Get arrays
|
||||
xf = d.t_target
|
||||
yp = d.y_pred
|
||||
s = d.y_pred_std
|
||||
yt = d.y_true
|
||||
now = d.t_source.squeeze()
|
||||
|
||||
|
||||
plt.figure(figsize=(12, 4))
|
||||
|
||||
plt.scatter(xf, yt, label='true', c='k', s=6)
|
||||
ylim = plt.ylim()
|
||||
|
||||
# plot prediction
|
||||
plt.fill_between(xf, yp-2*s, yp+2*s, alpha=0.25,
|
||||
facecolor="b",
|
||||
interpolate=True,
|
||||
label="2 std",)
|
||||
plt.plot(xf, yp, label='pred', c='b')
|
||||
|
||||
# plot true
|
||||
plt.scatter(
|
||||
d.t_past,
|
||||
d.y_past,
|
||||
c='k',
|
||||
s=6
|
||||
)
|
||||
|
||||
# plot a red line for now
|
||||
plt.vlines(x=now, ymin=0, ymax=1, label='now', color='r')
|
||||
plt.ylim(*ylim)
|
||||
|
||||
now=pd.Timestamp(now.values)
|
||||
plt.title(f'Prediction NLL={d.nll.mean().item():2.2g}')
|
||||
plt.xlabel(f'{now.date()}')
|
||||
plt.ylabel('energy(kWh/hh)')
|
||||
plt.legend()
|
||||
plt.xticks(rotation=45)
|
||||
plt.show()
|
||||
|
||||
def plot_performance(ds_preds, full=False):
|
||||
"""Multiple plots using xr_preds"""
|
||||
plot_prediction(ds_preds, 24)
|
||||
plot_prediction(ds_preds, 480)
|
||||
|
||||
ds_preds.mean('t_source').plot.scatter('t_ahead_hours', 'nll') # Mean over all predictions
|
||||
n = len(ds_preds.t_source)
|
||||
plt.ylabel('Negative Log Likelihood (lower is better)')
|
||||
plt.xlabel('Hours ahead')
|
||||
plt.title(f'NLL vs time ahead (no. samples={n})')
|
||||
plt.show()
|
||||
|
||||
# Make a plot of the NLL over time. Does this solution get worse with time?
|
||||
if full:
|
||||
d = ds_preds.mean('t_ahead').groupby('t_source').mean().plot.scatter('t_source', 'nll')
|
||||
plt.xticks(rotation=45)
|
||||
plt.title('NLL over source time (lower is better)')
|
||||
plt.show()
|
||||
|
||||
# A scatter plot is easy with xarray
|
||||
if full:
|
||||
plt.figure(figsize=(5, 5))
|
||||
ds_preds.plot.scatter('y_true', 'y_pred', s=.01)
|
||||
plt.show()
|
||||
|
||||
print(f'mean_NLL {ds_preds.nll.mean().item():2.2f}')
|
||||
# -
|
||||
|
||||
|
||||
|
||||
def plot_hist(trainer):
|
||||
try:
|
||||
df_hist = pd.read_csv(trainer.logger.experiment.metrics_file_path)
|
||||
df_hist['epoch'] = df_hist['epoch'].ffill()
|
||||
df_histe = df_hist.set_index('epoch').groupby('epoch').mean()
|
||||
if len(df_histe)>1:
|
||||
df_histe[['loss/train', 'loss/val']].plot(title='history')
|
||||
return df_histe
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# ## Lightning
|
||||
|
||||
# +
|
||||
import pytorch_lightning as pl
|
||||
|
||||
class PL_MODEL(pl.LightningModule):
|
||||
def __init__(self, model, lr=3e-4, patience=2):
|
||||
super().__init__()
|
||||
self._model = model
|
||||
self.lr = lr
|
||||
self.patience = patience
|
||||
|
||||
def forward(self, x_past, y_past, x_future, y_future=None):
|
||||
"""Eval/Predict"""
|
||||
y_dist = self._model(x_past, y_past, x_future)
|
||||
return y_dist
|
||||
|
||||
def training_step(self, batch, batch_idx, phase='train'):
|
||||
x_past, y_past, x_future, y_future = batch
|
||||
y_dist = self.forward(*batch)
|
||||
loss = -y_dist.log_prob(y_future).mean()
|
||||
self.log_dict({f'loss/{phase}':loss})
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
return self.training_step(batch, batch_idx, phase='val')
|
||||
|
||||
def configure_optimizers(self):
|
||||
optim = torch.optim.Adam(self.parameters(), lr=self.lr)
|
||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||
optim,
|
||||
patience=self.patience,
|
||||
verbose=True,
|
||||
min_lr=1e-7,
|
||||
)
|
||||
return {'optimizer': optim, 'lr_scheduler': scheduler, 'monitor': 'loss/val'}
|
||||
|
||||
|
||||
# -
|
||||
|
||||
# # Run
|
||||
from torch.utils.data import DataLoader
|
||||
from pytorch_lightning.loggers import CSVLogger
|
||||
from pl_bolts.callbacks import PrintTableMetricsCallback
|
||||
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
||||
|
||||
|
||||
# +
|
||||
# Init data
|
||||
x_past, y_past, x_future, y_future = ds_train.get_rows(10)
|
||||
input_size = x_past.shape[-1]
|
||||
output_size = y_future.shape[-1]
|
||||
|
||||
dl_train = DataLoader(ds_train,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=num_workers==0,
|
||||
num_workers=num_workers)
|
||||
dl_test = DataLoader(ds_test, batch_size=batch_size, num_workers=num_workers)
|
||||
# -
|
||||
|
||||
# Baseline model
|
||||
pt_model = BaselineLast()
|
||||
model = PL_MODEL(pt_model).to(device)
|
||||
trainer = pl.Trainer(gpus=1,
|
||||
max_epochs=1,
|
||||
limit_train_batches=0.01,
|
||||
logger=CSVLogger("logs",
|
||||
name=type(pt_model).__name__),
|
||||
)
|
||||
trainer.fit(model, dl_train, dl_test)
|
||||
print(plot_hist(trainer))
|
||||
ds_preds = predict(model.to(device), ds_test.datasets[0], batch_size, device=device, scaler=output_scaler)
|
||||
print(f'baseline nll: {ds_preds.nll.mean().item():2.2g}')
|
||||
|
||||
|
||||
|
||||
models = [
|
||||
# BaselineLast(),
|
||||
LSTM(input_size,
|
||||
output_size,
|
||||
hidden_size=80,
|
||||
lstm_layers=3,
|
||||
lstm_dropout=0.3),
|
||||
Transformer(input_size,
|
||||
output_size,
|
||||
attention_dropout=0.3,
|
||||
nhead=8,
|
||||
nlayers=6,
|
||||
hidden_size=64),
|
||||
LSTMSeq2Seq(input_size,
|
||||
output_size,
|
||||
hidden_size=64,
|
||||
lstm_layers=2,
|
||||
lstm_dropout=0.25),
|
||||
TransformerSeq2Seq(input_size,
|
||||
output_size,
|
||||
hidden_size=64,
|
||||
nhead=8,
|
||||
nlayers=4,
|
||||
attention_dropout=0.3),
|
||||
# Transformer(input_size,
|
||||
# output_size,
|
||||
# attention_dropout=0.2,
|
||||
# nhead=8,
|
||||
# nlayers=6,
|
||||
# hidden_size=128),
|
||||
# LSTM(input_size,
|
||||
# output_size,
|
||||
# hidden_size=128,
|
||||
# lstm_layers=3,
|
||||
# lstm_dropout=0.3),
|
||||
]
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
for pt_model in models:
|
||||
name = type(pt_model).__name__
|
||||
print(name)
|
||||
|
||||
# Wrap in lightning
|
||||
patience = 2
|
||||
model = PL_MODEL(pt_model, patience=patience, lr=3e-4).to(device)
|
||||
|
||||
# Trainer
|
||||
|
||||
trainer = pl.Trainer(gpus=1,
|
||||
min_epochs=1,
|
||||
max_epochs=10,
|
||||
amp_level='O1',
|
||||
precision=16,
|
||||
gradient_clip_val=0.5,
|
||||
logger=CSVLogger("logs",
|
||||
name=type(pt_model).__name__),
|
||||
callbacks=[
|
||||
EarlyStopping(monitor='loss/val', patience=patience*2),
|
||||
PrintTableMetricsCallback()
|
||||
],
|
||||
)
|
||||
|
||||
# Train
|
||||
trainer.fit(model, dl_train, dl_test)
|
||||
|
||||
# Performance
|
||||
print(plot_hist(trainer))
|
||||
|
||||
ds_preds = predict(model.to(device),
|
||||
ds_test.datasets[0],
|
||||
batch_size,
|
||||
device=device,
|
||||
scaler=output_scaler)
|
||||
plot_performance(ds_preds)
|
||||
|
||||
|
||||
|
||||
# # holoviews pred
|
||||
|
||||
import holoviews as hv
|
||||
from holoviews import opts
|
||||
|
||||
|
||||
# +
|
||||
def plot_prediction_now(t_source):
|
||||
"""Plot predictions with holoviews"""
|
||||
|
||||
# Let us pass in an int
|
||||
if isinstance(t_source, int):
|
||||
t_source = ds_preds.t_source[t_source].to_pandas()
|
||||
|
||||
d = ds_preds.sel(t_source=t_source)
|
||||
|
||||
# Sometimes there are duplicate time, take the first
|
||||
if len(d.t_source.shape) and d.t_source.shape[0] > 0:
|
||||
d = d.isel(t_source=0)
|
||||
if len(d.t_source.shape) and d.t_source.shape[0] == 0:
|
||||
return None
|
||||
|
||||
now = d.t_source.to_pandas()
|
||||
|
||||
# Plot true
|
||||
x = np.concatenate([d.t_past, d.t_target])
|
||||
yt = np.concatenate([d.y_past, d.y_true])
|
||||
p = hv.Scatter({
|
||||
'x': x,
|
||||
'y': yt
|
||||
}, label='true').opts(color='black', framewise=True)
|
||||
|
||||
# Get arrays
|
||||
xf = d.t_target.values
|
||||
yp = d.y_pred
|
||||
s = d.y_pred_std
|
||||
p *= hv.Curve({
|
||||
'x': xf,
|
||||
'y': yp
|
||||
}, label='pred').opts(color='blue')
|
||||
p *= hv.Area((xf, yp - 2 * s, yp + 2 * s),
|
||||
vdims=['y', 'y2'],
|
||||
label='2*std').opts(alpha=0.5, line_width=0)
|
||||
|
||||
# plot now line
|
||||
p *= hv.VLine(now, label='now').opts(color='red')
|
||||
return p.opts(title=f'Prediction at {now}. NLL={d.nll.mean().item():2.2f}')
|
||||
|
||||
|
||||
dmap_pred = (hv.DynamicMap(plot_prediction_now, kdims=['t_source'])
|
||||
.redim.values(t_source=ds_preds.t_source.to_pandas())
|
||||
.opts(width=800,
|
||||
height=300,
|
||||
))
|
||||
dmap_pred
|
||||
|
||||
|
||||
# +
|
||||
def plot_predictions_vs_time(it_ahead):
|
||||
"""Plot predictions vs time with holoviews"""
|
||||
|
||||
d = ds_preds.isel(t_ahead=it_ahead).groupby('t_source').first()
|
||||
|
||||
p = hv.Scatter({
|
||||
'x': d.t_source,
|
||||
'y': d.y_true
|
||||
}, label='true').opts(color='black')
|
||||
|
||||
# Get arrays
|
||||
xf = d.t_source.values
|
||||
yp = d.y_pred
|
||||
s = d.y_pred_std
|
||||
p *= hv.Curve({
|
||||
'x': xf,
|
||||
'y': yp
|
||||
}, label='pred').opts(color='blue')
|
||||
p *= hv.Area((xf, yp - 2 * s, yp + 2 * s),
|
||||
vdims=['y', 'y2'],
|
||||
label='2*std').opts(alpha=0.5, line_width=0)
|
||||
|
||||
|
||||
return p.opts(title=f'Prediction at {it_ahead * pd.Timedelta(freq)} ahead. NLL={d.nll.mean().item():2.2f}')
|
||||
|
||||
|
||||
dmap_preds = (hv.DynamicMap(plot_predictions_vs_time, kdims=['it_ahead'])
|
||||
.redim.values(it_ahead=range(ds_preds.t_ahead.shape[0]))
|
||||
.opts(width=800,
|
||||
height=300,
|
||||
))
|
||||
dmap_preds
|
||||
# plot_prediction2(10).opts(width=800)
|
||||
# -
|
||||
|
||||
d = ds_preds.mean('t_source')['nll'].groupby('t_ahead_hours').mean()
|
||||
nll_vs_tahead = hv.Curve((d.t_ahead_hours, d)).redim(x='hours ahead', y='nll').opts(width=800)
|
||||
nll_vs_tahead
|
||||
|
||||
d = ds_preds.mean('t_ahead')['nll'].groupby('t_source').mean()
|
||||
nll_vs_time = hv.Curve(d).opts(width=800)
|
||||
nll_vs_time
|
||||
|
||||
true_vs_pred = hv.Scatter((ds_preds.y_true, ds_preds.y_pred))
|
||||
dynspread(datashade(true_vs_pred))
|
||||
|
||||
# # Summarize experiments
|
||||
|
||||
# # LR finder
|
||||
|
||||
# +
|
||||
|
||||
# Run learning rate finder
|
||||
lr_finder = trainer.tuner.lr_find(model)
|
||||
|
||||
# Results can be found in
|
||||
lr_finder.results
|
||||
|
||||
# Plot with
|
||||
fig = lr_finder.plot(suggest=True)
|
||||
fig.show()
|
||||
|
||||
# Pick point based on plot, or get suggestion
|
||||
new_lr = lr_finder.suggestion()
|
||||
# -
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ dependencies:
|
||||
- blas=1.0=mkl
|
||||
- bleach=3.2.1=pyh9f0ad1d_0
|
||||
- blinker=1.4=py_1
|
||||
- bokeh=2.2.2=py37hc8dfbb8_0
|
||||
- botocore=1.18.18=pyh9f0ad1d_0
|
||||
- brotlipy=0.7.0=py37hb5d75c8_1001
|
||||
- c-ares=1.16.1=h516909a_3
|
||||
@@ -30,14 +31,23 @@ dependencies:
|
||||
- cffi=1.14.3=py37h00ebd2e_1
|
||||
- chardet=3.0.4=py37he5f6b98_1008
|
||||
- click=7.1.2=pyh9f0ad1d_0
|
||||
- cloudpickle=1.6.0=py_0
|
||||
- colorama=0.4.3=py_0
|
||||
- colorcet=2.0.2=py_0
|
||||
- confuse=1.3.0=pyh9f0ad1d_0
|
||||
- cryptography=3.1.1=py37hff6837a_1
|
||||
- cudatoolkit=10.2.89=hfd86e86_1
|
||||
- cycler=0.10.0=py_2
|
||||
- cytoolz=0.11.0=py37h8f50634_1
|
||||
- dask=2.30.0=py_0
|
||||
- dask-core=2.30.0=py_0
|
||||
- dataclasses=0.7=py37_0
|
||||
- datashader=0.11.1=pyh9f0ad1d_0
|
||||
- datashape=0.5.4=py_1
|
||||
- dbus=1.13.18=hb2f20db_0
|
||||
- decorator=4.4.2=py_0
|
||||
- defusedxml=0.6.0=py_0
|
||||
- distributed=2.30.0=py37hc8dfbb8_1
|
||||
- docutils=0.15.2=py37_0
|
||||
- entrypoints=0.3=py37hc8dfbb8_1002
|
||||
- expat=2.2.10=he6710b0_2
|
||||
@@ -52,8 +62,12 @@ dependencies:
|
||||
- grpcio=1.31.0=py37hb0870dc_0
|
||||
- gst-plugins-base=1.14.5=h0935bb2_2
|
||||
- gstreamer=1.14.5=h36ae1b5_2
|
||||
- heapdict=1.0.1=py_0
|
||||
- holoviews=1.13.4=pyh9f0ad1d_0
|
||||
- htmlmin=0.1.12=py_1
|
||||
- icu=67.1=he1b5a44_0
|
||||
- idna=2.10=pyh9f0ad1d_0
|
||||
- imagehash=4.1.0=pyh9f0ad1d_0
|
||||
- importlib-metadata=2.0.0=py37hc8dfbb8_0
|
||||
- importlib_metadata=2.0.0=1
|
||||
- iniconfig=1.1.1=py_0
|
||||
@@ -101,16 +115,21 @@ dependencies:
|
||||
- libxcb=1.14=h7b6447c_0
|
||||
- libxkbcommon=0.10.0=he1b5a44_0
|
||||
- libxml2=2.9.10=h68273f3_2
|
||||
- llvmlite=0.34.0=py37h5202443_2
|
||||
- locket=0.2.0=py_2
|
||||
- lz4-c=1.9.2=he1b5a44_3
|
||||
- markdown=3.3.1=pyh9f0ad1d_0
|
||||
- markupsafe=1.1.1=py37hb5d75c8_2
|
||||
- matplotlib=3.3.2=py37hc8dfbb8_1
|
||||
- matplotlib-base=3.3.2=py37hc9afd2a_1
|
||||
- mccabe=0.6.1=py_1
|
||||
- missingno=0.4.2=py_1
|
||||
- mistune=0.8.4=py37h8f50634_1002
|
||||
- mkl=2020.2=256
|
||||
- more-itertools=8.5.0=py_0
|
||||
- msgpack-python=1.0.0=py37h99015e2_2
|
||||
- multidict=4.7.6=py37h7b6447c_1
|
||||
- multipledispatch=0.6.0=py_0
|
||||
- mypy=0.790=py_0
|
||||
- mypy_extensions=0.4.3=py37hc8dfbb8_1
|
||||
- mysql-common=8.0.21=2
|
||||
@@ -120,22 +139,30 @@ dependencies:
|
||||
- nbformat=5.0.8=py_0
|
||||
- ncurses=6.2=he1b5a44_2
|
||||
- nest-asyncio=1.4.1=py_0
|
||||
- networkx=2.5=py_0
|
||||
- ninja=1.10.1=hfc4b9b4_2
|
||||
- notebook=6.1.4=py37hc8dfbb8_1
|
||||
- nspr=4.29=he1b5a44_1
|
||||
- nss=3.58=h27285de_1
|
||||
- numba=0.51.2=py37h9fdb41a_0
|
||||
- numpy=1.19.2=py37h7ea13bd_1
|
||||
- oauthlib=3.1.0=py_0
|
||||
- olefile=0.46=pyh9f0ad1d_1
|
||||
- openssl=1.1.1h=h516909a_0
|
||||
- packaging=20.4=pyh9f0ad1d_0
|
||||
- pandas=1.1.3=py37h9fdb41a_2
|
||||
- pandas-profiling=2.9.0=pyh9f0ad1d_0
|
||||
- pandoc=2.11.0.2=hd18ef5c_0
|
||||
- pandocfilters=1.4.2=py_1
|
||||
- panel=0.9.7=py_0
|
||||
- param=1.9.3=py_0
|
||||
- parso=0.7.1=pyh9f0ad1d_0
|
||||
- partd=1.1.0=py_0
|
||||
- pathspec=0.8.0=pyh9f0ad1d_0
|
||||
- patsy=0.5.1=py_0
|
||||
- pcre=8.44=he1b5a44_0
|
||||
- pexpect=4.8.0=py37hc8dfbb8_1
|
||||
- phik=0.10.0=py_0
|
||||
- pickleshare=0.7.5=py37hc8dfbb8_1002
|
||||
- pillow=8.0.0=py37h718be6c_0
|
||||
- pip=20.2.4=py_0
|
||||
@@ -150,6 +177,7 @@ dependencies:
|
||||
- pyasn1-modules=0.2.8=py_0
|
||||
- pycodestyle=2.6.0=pyh9f0ad1d_0
|
||||
- pycparser=2.20=pyh9f0ad1d_2
|
||||
- pyct=0.4.8=py37_0
|
||||
- pydocstyle=5.1.1=py_0
|
||||
- pyflakes=2.2.0=pyh9f0ad1d_0
|
||||
- pygments=2.7.1=py_0
|
||||
@@ -167,6 +195,8 @@ dependencies:
|
||||
- pytorch=1.6.0=py3.7_cuda10.2.89_cudnn7.6.5_0
|
||||
- pytorch-lightning=1.0.2=py_0
|
||||
- pytz=2020.1=pyh9f0ad1d_0
|
||||
- pyviz_comms=0.7.6=pyh9f0ad1d_0
|
||||
- pywavelets=1.1.1=py37h161383b_3
|
||||
- pyyaml=5.3.1=py37hb5d75c8_1
|
||||
- pyzmq=19.0.2=py37hac76be4_2
|
||||
- qt=5.12.9=h1f2b2cb_0
|
||||
@@ -178,11 +208,17 @@ dependencies:
|
||||
- s3transfer=0.3.3=py37hc8dfbb8_2
|
||||
- scikit-learn=0.23.2=py37h6785257_0
|
||||
- scipy=1.5.2=py37hb14ef9d_2
|
||||
- seaborn=0.11.0=0
|
||||
- seaborn-base=0.11.0=py_0
|
||||
- send2trash=1.5.0=py_0
|
||||
- setuptools=49.6.0=py37he5f6b98_2
|
||||
- six=1.15.0=pyh9f0ad1d_0
|
||||
- snowballstemmer=2.0.0=py_0
|
||||
- sortedcontainers=2.2.2=pyh9f0ad1d_0
|
||||
- sqlite=3.33.0=h4cf870e_1
|
||||
- statsmodels=0.12.0=py37h161383b_1
|
||||
- tangled-up-in-unicode=0.0.6=pyh9f0ad1d_0
|
||||
- tblib=1.7.0=py_0
|
||||
- tensorboard=2.3.0=py_0
|
||||
- tensorboard-plugin-wit=1.6.0=pyh9f0ad1d_0
|
||||
- terminado=0.9.1=py37hc8dfbb8_1
|
||||
@@ -190,6 +226,7 @@ dependencies:
|
||||
- threadpoolctl=2.1.0=pyh5ca1d4c_0
|
||||
- tk=8.6.10=hed695b0_1
|
||||
- toml=0.10.1=pyh9f0ad1d_0
|
||||
- toolz=0.11.1=py_0
|
||||
- torchvision=0.7.0=py37_cu102
|
||||
- tornado=6.0.4=py37h8f50634_2
|
||||
- tqdm=4.50.2=pyh9f0ad1d_0
|
||||
@@ -198,6 +235,7 @@ dependencies:
|
||||
- typing-extensions=3.7.4.3=0
|
||||
- typing_extensions=3.7.4.3=py_0
|
||||
- urllib3=1.25.10=py_0
|
||||
- visions=0.5.0=pyh9f0ad1d_0
|
||||
- wcwidth=0.2.5=pyh9f0ad1d_2
|
||||
- webencodings=0.5.1=py_1
|
||||
- werkzeug=1.0.1=pyh9f0ad1d_0
|
||||
@@ -209,6 +247,7 @@ dependencies:
|
||||
- yapf=0.30.0=pyh9f0ad1d_0
|
||||
- yarl=1.6.2=py37h8f50634_0
|
||||
- zeromq=4.3.3=he1b5a44_2
|
||||
- zict=2.0.0=py_0
|
||||
- zipp=3.3.1=py_0
|
||||
- zlib=1.2.11=h516909a_1010
|
||||
- zstd=1.4.5=h6597ccf_2
|
||||
@@ -216,6 +255,7 @@ dependencies:
|
||||
- pyqt5-sip==4.19.18
|
||||
- pyqtchart==5.12
|
||||
- pyqtwebengine==5.12.1
|
||||
- pytorch-lightning-bolts==0.2.5
|
||||
- sklearn-pandas==2.0.2
|
||||
- torchsummaryx==1.3.0
|
||||
prefix: /home/wassname/anaconda/envs/seq2seq-time
|
||||
|
||||
@@ -23,4 +23,7 @@ dependencies:
|
||||
- pytorch-lightning
|
||||
- yapf
|
||||
- ipywidgets
|
||||
- holoviews
|
||||
- pandas-profiling
|
||||
- datashader
|
||||
prefix: /home/wassname/anaconda/envs/seq2seq-time
|
||||
|
||||
@@ -62,8 +62,8 @@ class Seq2SeqDataSet(torch.utils.data.Dataset):
|
||||
x_future = x[self.window_past:]
|
||||
y_future = y[self.window_past:]
|
||||
|
||||
# Stop it cheating by using future weather measurements
|
||||
x_future[:, self._icol_blank] = 0
|
||||
# Stop it cheating by using future weather measurements. Fill in with last value
|
||||
x_future[:, self._icol_blank] = x_past[0, self._icol_blank]
|
||||
return x_past, y_past, x_future, y_future
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
class BaselineLast(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.std = nn.Parameter(torch.tensor(1.))
|
||||
|
||||
def forward(self, past_x, past_y, future_x, future_y=None):
|
||||
device = next(self.parameters()).device
|
||||
B, S, F = future_x.shape
|
||||
mean = past_y[:, -1:].repeat(1, S, 1)
|
||||
std = (self.std * 1.0).repeat(1, S, 1)
|
||||
return torch.distributions.Normal(mean, std)
|
||||
@@ -0,0 +1,39 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
class LSTM(nn.Module):
|
||||
def __init__(self, input_size, output_size, hidden_size=32, lstm_layers=2, lstm_dropout=0, _min_std = 0.05, nan_value=0):
|
||||
super().__init__()
|
||||
self._min_std = _min_std
|
||||
self.nan_value = nan_value
|
||||
|
||||
self.lstm = nn.LSTM(
|
||||
input_size=input_size + output_size,
|
||||
hidden_size=hidden_size,
|
||||
batch_first=True,
|
||||
num_layers=lstm_layers,
|
||||
dropout=lstm_dropout,
|
||||
)
|
||||
self.mean = nn.Linear(hidden_size, output_size)
|
||||
self.std = nn.Linear(hidden_size, output_size)
|
||||
|
||||
def forward(self, past_x, past_y, future_x, future_y=None):
|
||||
device = next(self.parameters()).device
|
||||
future_y_fake = (
|
||||
torch.ones(past_y.shape[0], future_x.shape[1], past_y.shape[2]).float().to(device) * self.nan_value
|
||||
)
|
||||
context = torch.cat([past_x, past_y], -1).detach()
|
||||
target = torch.cat([future_x, future_y_fake], -1).detach()
|
||||
x = torch.cat([context, target * 1], 1).detach()
|
||||
|
||||
steps = past_y.shape[1]
|
||||
outputs, _ = self.lstm(x)
|
||||
outputs = outputs[:, steps:, :]
|
||||
|
||||
# outputs: [B, T, num_direction * H]
|
||||
mean = self.mean(outputs)
|
||||
log_sigma = self.std(outputs)
|
||||
sigma = self._min_std + (1 - self._min_std) * F.softplus(log_sigma)
|
||||
y_dist = torch.distributions.Normal(mean, sigma)
|
||||
return y_dist
|
||||
@@ -0,0 +1,39 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
class LSTMSeq2Seq(nn.Module):
|
||||
def __init__(self, input_size, output_size, hidden_size=32, lstm_layers=2, lstm_dropout=0, _min_std = 0.05):
|
||||
super().__init__()
|
||||
self._min_std = _min_std
|
||||
|
||||
self.encoder = nn.LSTM(
|
||||
input_size=input_size + output_size,
|
||||
hidden_size=hidden_size,
|
||||
batch_first=True,
|
||||
num_layers=lstm_layers,
|
||||
dropout=lstm_dropout,
|
||||
)
|
||||
self.decoder = nn.LSTM(
|
||||
input_size=input_size,
|
||||
hidden_size=hidden_size,
|
||||
batch_first=True,
|
||||
num_layers=lstm_layers,
|
||||
dropout=lstm_dropout,
|
||||
)
|
||||
self.mean = nn.Linear(hidden_size, output_size)
|
||||
self.std = nn.Linear(hidden_size, output_size)
|
||||
|
||||
def forward(self, past_x, past_y, future_x, future_y=None):
|
||||
x = torch.cat([past_x, past_y], -1)
|
||||
_, (h_out, cell) = self.encoder(x)
|
||||
|
||||
# output = [batch size, seq len, hid dim * n directions]
|
||||
outputs, (_, _) = self.decoder(future_x, (h_out, cell))
|
||||
|
||||
# outputs: [B, T, num_direction * H]
|
||||
mean = self.mean(outputs)
|
||||
log_sigma = self.std(outputs)
|
||||
sigma = self._min_std + (1 - self._min_std) * F.softplus(log_sigma)
|
||||
y_dist = torch.distributions.Normal(mean, sigma)
|
||||
return y_dist
|
||||
@@ -0,0 +1,59 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
"""
|
||||
A single transformer, masking nan or 0
|
||||
"""
|
||||
def __init__(self, x_dim, y_dim, attention_dropout=0, nhead=8, nlayers=2, hidden_size=16, nan_value=0, min_std=0.01):
|
||||
super().__init__()
|
||||
self._min_std = min_std
|
||||
self.nan_value = nan_value
|
||||
enc_x_dim = x_dim + y_dim
|
||||
|
||||
self.enc_emb = nn.Linear(enc_x_dim, hidden_size)
|
||||
encoder_norm = nn.LayerNorm(hidden_size)
|
||||
layer_enc = nn.TransformerEncoderLayer(
|
||||
d_model=hidden_size,
|
||||
dim_feedforward=hidden_size*4,
|
||||
dropout=attention_dropout,
|
||||
nhead=nhead,
|
||||
# activation
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(
|
||||
layer_enc, num_layers=nlayers, norm=encoder_norm
|
||||
)
|
||||
self.mean = nn.Linear(hidden_size, y_dim)
|
||||
self.std = nn.Linear(hidden_size, y_dim)
|
||||
|
||||
def forward(self, past_x, past_y, future_x, future_y=None):
|
||||
device = next(self.parameters()).device
|
||||
future_y_fake = (
|
||||
torch.ones(past_y.shape[0], future_x.shape[1], past_y.shape[2]).float().to(device) * self.nan_value
|
||||
)
|
||||
context = torch.cat([past_x, past_y], -1).detach()
|
||||
target = torch.cat([future_x, future_y_fake], -1).detach()
|
||||
x = torch.cat([context, target * 1], 1).detach()
|
||||
|
||||
# Masks
|
||||
x_mask = torch.isfinite(x) & (x != self.nan_value)
|
||||
x[~x_mask] = 0
|
||||
x = x.detach()
|
||||
x_key_padding_mask = ~x_mask.any(-1)
|
||||
|
||||
x = self.enc_emb(x).permute(1, 0, 2)
|
||||
|
||||
outputs = self.encoder(x, src_key_padding_mask=x_key_padding_mask).permute(
|
||||
1, 0, 2
|
||||
)
|
||||
|
||||
# Seems to help a little, especially with extrapolating out of bounds
|
||||
steps = past_y.shape[1]
|
||||
mean = self.mean(outputs)[:, steps:, :]
|
||||
log_sigma = self.std(outputs)[:, steps:, :]
|
||||
|
||||
sigma = self._min_std + (1 - self._min_std) * F.softplus(log_sigma)
|
||||
return torch.distributions.Normal(mean, sigma)
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
|
||||
class TransformerSeq2Seq(nn.Module):
|
||||
def __init__(self, x_size, y_size, hidden_size=16, nhead=8, nlayers=2, attention_dropout=0, min_std=0.01, nan_value=0):
|
||||
super().__init__()
|
||||
self._min_std = min_std
|
||||
self.nan_value = nan_value
|
||||
|
||||
self.enc_emb = nn.Linear(x_size + y_size, hidden_size)
|
||||
self.dec_emb = nn.Linear(x_size, hidden_size)
|
||||
|
||||
encoder_norm = nn.LayerNorm(hidden_size)
|
||||
layer_enc = nn.TransformerEncoderLayer(
|
||||
d_model=hidden_size,
|
||||
dim_feedforward=hidden_size*4,
|
||||
dropout=attention_dropout,
|
||||
nhead=nhead,
|
||||
# activation
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(
|
||||
layer_enc, num_layers=nlayers, norm=encoder_norm
|
||||
)
|
||||
|
||||
layer_dec = nn.TransformerDecoderLayer(
|
||||
d_model=hidden_size,
|
||||
dim_feedforward=hidden_size*4,
|
||||
dropout=attention_dropout,
|
||||
nhead=nhead,
|
||||
)
|
||||
decoder_norm = nn.LayerNorm(hidden_size)
|
||||
self.decoder = nn.TransformerDecoder(
|
||||
layer_dec, num_layers=nlayers, norm=decoder_norm
|
||||
)
|
||||
self.mean = nn.Linear(hidden_size, y_size)
|
||||
self.std = nn.Linear(hidden_size, y_size)
|
||||
|
||||
|
||||
def forward(self, past_x, past_y, future_x, future_y=None):
|
||||
device = next(self.parameters()).device
|
||||
x = torch.cat([past_x, past_y], -1)
|
||||
|
||||
# Masks
|
||||
future_mask = torch.isfinite(future_x) & (future_x!=self.nan_value)
|
||||
tgt_key_padding_mask = ~future_mask.any(-1)
|
||||
|
||||
past_mask = torch.isfinite(x) & (x!=self.nan_value)
|
||||
src_key_padding_mask = ~past_mask.any(-1)# * float('-inf')
|
||||
|
||||
# Embed
|
||||
x = self.enc_emb(x)
|
||||
# Size([B, C, X]) -> Size([B, C, hidden_dim])
|
||||
future_x = self.dec_emb(future_x)
|
||||
# Size([B, C, T]) -> Size([B, C, hidden_dim])
|
||||
|
||||
x = x.permute(1, 0, 2) # (B,C,hidden_dim) -> (C,B,hidden_dim)
|
||||
future_x = future_x.permute(1, 0, 2)
|
||||
# requires (C, B, hidden_dim)
|
||||
memory = self.encoder(x, src_key_padding_mask=src_key_padding_mask)
|
||||
|
||||
# In transformers the memory and future_x need to be the same length. Lets use a permutation invariant agg on the context
|
||||
# Then expand it, so it's available as we decode, conditional on future_x
|
||||
# (C, B, emb_dim) -> (B, emb_dim) -> (T, B, emb_dim)
|
||||
# In transformers the memory and future_x need to be the same length. Lets use a permutation invariant agg on the context
|
||||
# Then expand it, so it's available as we decode, conditional on future_x
|
||||
memory = memory.max(dim=0, keepdim=True)[0].expand_as(future_x)
|
||||
|
||||
outputs = self.decoder(future_x, memory, tgt_key_padding_mask=tgt_key_padding_mask)
|
||||
|
||||
# [T, B, emb_dim] -> [B, T, emb_dim]
|
||||
outputs = outputs.permute(1, 0, 2).contiguous()
|
||||
# Size([B, T, emb_dim])
|
||||
mean = self.mean(outputs)
|
||||
log_sigma = self.std(outputs)
|
||||
sigma = self._min_std + (1 - self._min_std) * F.softplus(log_sigma)
|
||||
return torch.distributions.Normal(mean, sigma)
|
||||
|
||||
Reference in New Issue
Block a user