mirror of
https://github.com/wassname/seq2seq-time.git
synced 2026-06-27 16:31:46 +08:00
plotting datasets in hv, also better split with dropna
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -0,0 +1,168 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# ---
|
||||
# jupyter:
|
||||
# jupytext:
|
||||
# formats: ipynb,py:light
|
||||
# text_representation:
|
||||
# extension: .py
|
||||
# format_name: light
|
||||
# format_version: '1.5'
|
||||
# jupytext_version: 1.6.0
|
||||
# kernelspec:
|
||||
# display_name: seq2seq-time
|
||||
# language: python
|
||||
# name: seq2seq-time
|
||||
# ---
|
||||
|
||||
# OPTIONAL: Load the "autoreload" extension so that code can change. But blacklist large modules
|
||||
# %load_ext autoreload
|
||||
# %autoreload 2
|
||||
# %aimport -pandas
|
||||
# %aimport -torch
|
||||
# %aimport -numpy
|
||||
# %aimport -matplotlib
|
||||
# %aimport -dask
|
||||
# %aimport -tqdm
|
||||
# %matplotlib inline
|
||||
|
||||
# +
|
||||
import xarray as xr
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from pathlib import Path
|
||||
from tqdm.auto import tqdm
|
||||
# -
|
||||
import warnings
|
||||
warnings.simplefilter('once')
|
||||
warnings.simplefilter(action='ignore', category=FutureWarning)
|
||||
warnings.simplefilter(action='ignore', category=DeprecationWarning)
|
||||
|
||||
# +
|
||||
import holoviews as hv
|
||||
from holoviews import opts
|
||||
from holoviews.operation.datashader import datashade, dynspread
|
||||
hv.extension('bokeh', inline=True)
|
||||
from seq2seq_time.visualization.hv_ggplot import ggplot_theme
|
||||
hv.renderer('bokeh').theme = ggplot_theme
|
||||
|
||||
# holoview datashader timeseries options
|
||||
# %opts RGB [width=800 height=200 show_grid=True active_tools=["xwheel_zoom"] default_tools=["xpan","xwheel_zoom", "reset", "hover"] toolbar="right"]
|
||||
# %opts Curve [width=800 height=200 show_grid=True active_tools=["xwheel_zoom"] default_tools=["xpan","xwheel_zoom", "reset", "hover"] toolbar="right"]
|
||||
# %opts Scatter [width=800 height=200 show_grid=True active_tools=["xwheel_zoom"] default_tools=["xpan","xwheel_zoom", "reset", "hover"] toolbar="right"]
|
||||
# %opts Layout [width=800 height=200]
|
||||
# -
|
||||
|
||||
|
||||
# ## Parameters
|
||||
|
||||
# +
|
||||
# # device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
# print(f'using {device}')
|
||||
|
||||
window_past = 48*2
|
||||
window_future = 48*2
|
||||
batch_size = 128
|
||||
datasets_root = Path('../data/processed/')
|
||||
# -
|
||||
|
||||
# ## Plot helpers
|
||||
|
||||
# ## Datasets
|
||||
|
||||
# +
|
||||
from seq2seq_time.data.data import IMOSCurrentsVel, AppliancesEnergyPrediction, BejingPM25, GasSensor, MetroInterstateTraffic
|
||||
|
||||
datasets = [IMOSCurrentsVel, BejingPM25, GasSensor, AppliancesEnergyPrediction, MetroInterstateTraffic, ]
|
||||
datasets
|
||||
# -
|
||||
# View train, test, val splits
|
||||
l = hv.Layout()
|
||||
for dataset in datasets:
|
||||
d = dataset(datasets_root)
|
||||
|
||||
p = dynspread(
|
||||
datashade(hv.Scatter(d.df_train[d.columns_target[0]]),
|
||||
cmap='red'))
|
||||
p *= dynspread(
|
||||
datashade(hv.Scatter(d.df_val[d.columns_target[0]]),
|
||||
cmap='green'))
|
||||
p *= dynspread(
|
||||
datashade(hv.Scatter(d.df_test[d.columns_target[0]]),
|
||||
cmap='blue'))
|
||||
p = p.opts(title=f"{dataset}")
|
||||
display(p)
|
||||
|
||||
|
||||
# +
|
||||
# plot a batch
|
||||
def plot_batch_y(ds, i):
|
||||
x_past, y_past, x_future, y_future = ds.get_rows(i)
|
||||
y = pd.concat([y_past, y_future])
|
||||
p = hv.Scatter(y)
|
||||
|
||||
now = y_past.index[-1]
|
||||
p *= hv.VLine(now).relabel('now').opts(color='red')
|
||||
return p
|
||||
|
||||
def plot_batches_y(dataset):
|
||||
ds_name = type(dataset).__name__
|
||||
opts=dict(width=200, height=100, xaxis=None, yaxis=None)
|
||||
ds_train, ds_val, ds_test = d.to_datasets(window_past=window_past,
|
||||
window_future=window_future)
|
||||
n = 4
|
||||
max_i = min(len(ds_train), len(ds_val), len(ds_test))
|
||||
ii = list(np.linspace(0, max_i-10, n-1).astype(int)) + [-1]
|
||||
l = hv.Layout()
|
||||
for i in ii:
|
||||
l += plot_batch_y(ds_train, i).opts(title=f'train {i}', **opts)
|
||||
l += plot_batch_y(ds_val, i).opts(title=f'val {i}', **opts)
|
||||
l += plot_batch_y(ds_test, i).opts(title=f'test {i}', **opts)
|
||||
return l.opts(shared_axes=False, toolbar='right', title=ds_name).cols(3)
|
||||
|
||||
# +
|
||||
|
||||
|
||||
|
||||
# View train, test, val splits
|
||||
for dataset in datasets:
|
||||
d = dataset(datasets_root)
|
||||
display(plot_batches_y(d))
|
||||
|
||||
|
||||
# +
|
||||
def plot_batch_x(ds, i):
|
||||
"""Plot input features"""
|
||||
x_past, y_past, x_future, y_future = ds.get_rows(10)
|
||||
x = pd.concat([x_past, x_future])
|
||||
p = hv.NdOverlay({
|
||||
col: hv.Curve(x[col]) for col in x.columns
|
||||
}, kdims='column')
|
||||
now = y_past.index[-1]
|
||||
p *= hv.VLine(now).relabel('now').opts(color='red')
|
||||
return p
|
||||
|
||||
def plot_batches_x(d):
|
||||
"""Plot input features for multiple batch"""
|
||||
ds_train, ds_val, ds_test = d.to_datasets(window_past=window_past,
|
||||
window_future=window_future)
|
||||
l = plot_batch_x(ds_train, 10) + plot_batch_x(ds_val, 10) + plot_batch_x(ds_test, 10)
|
||||
l = l.cols(1).opts(shared_axes=False, title=f'{type(d).__name__}')
|
||||
return l
|
||||
|
||||
|
||||
# -
|
||||
|
||||
# View train, test, val splits
|
||||
for dataset in datasets:
|
||||
d = dataset(datasets_root)
|
||||
display(plot_batches_x(d))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
+10980
-25244
File diff suppressed because one or more lines are too long
@@ -48,6 +48,11 @@
|
||||
# %aimport -tqdm
|
||||
# %matplotlib inline
|
||||
|
||||
import warnings
|
||||
warnings.simplefilter('once')
|
||||
warnings.simplefilter(action='ignore', category=FutureWarning)
|
||||
warnings.simplefilter(action='ignore', category=DeprecationWarning)
|
||||
|
||||
# +
|
||||
# Imports
|
||||
import torch
|
||||
@@ -61,58 +66,41 @@ import xarray as xr
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
plt.rcParams['figure.figsize'] = (10.0, 2.0)
|
||||
plt.style.use('ggplot')
|
||||
|
||||
from pathlib import Path
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import pytorch_lightning as pl
|
||||
# -
|
||||
|
||||
|
||||
|
||||
import warnings
|
||||
warnings.simplefilter('once')
|
||||
warnings.simplefilter(action='ignore', category=FutureWarning)
|
||||
warnings.simplefilter(action='ignore', category=DeprecationWarning)
|
||||
|
||||
from seq2seq_time.data.dataset import Seq2SeqDataSet, Seq2SeqDataSets
|
||||
from seq2seq_time.predict import predict, predict_multi
|
||||
from seq2seq_time.util import dset_to_nc
|
||||
|
||||
import logging, sys
|
||||
# logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
||||
|
||||
# +
|
||||
import holoviews as hv
|
||||
from holoviews import opts
|
||||
from holoviews.operation.datashader import datashade, dynspread
|
||||
hv.extension('bokeh', inline=False)
|
||||
hv.extension('bokeh', inline=True)
|
||||
from seq2seq_time.visualization.hv_ggplot import ggplot_theme
|
||||
hv.renderer('bokeh').theme = ggplot_theme
|
||||
|
||||
# holoview datashader timeseries options
|
||||
# %opts RGB [width=800 height=200 show_grid=True active_tools=["xwheel_zoom"] default_tools=["xpan","xwheel_zoom", "reset"] toolbar="right"]
|
||||
# %opts Curve [width=800 height=200 show_grid=True active_tools=["xwheel_zoom"] default_tools=["xpan","xwheel_zoom", "reset"] toolbar="right"]
|
||||
# %opts Scatter [width=800 height=200 show_grid=True active_tools=["xwheel_zoom"] default_tools=["xpan","xwheel_zoom", "reset"] toolbar="right"]
|
||||
# %opts RGB [width=800 height=200 show_grid=True active_tools=["xwheel_zoom"] default_tools=["xpan","xwheel_zoom", "reset", "hover"] toolbar="right"]
|
||||
# %opts Curve [width=800 height=200 show_grid=True active_tools=["xwheel_zoom"] default_tools=["xpan","xwheel_zoom", "reset", "hover"] toolbar="right"]
|
||||
# %opts Scatter [width=800 height=200 show_grid=True active_tools=["xwheel_zoom"] default_tools=["xpan","xwheel_zoom", "reset", "hover"] toolbar="right"]
|
||||
# %opts Layout [width=800 height=200]
|
||||
# -
|
||||
|
||||
|
||||
from seq2seq_time.data.dataset import Seq2SeqDataSet, Seq2SeqDataSets
|
||||
from seq2seq_time.predict import predict, predict_multi
|
||||
from seq2seq_time.util import dset_to_nc
|
||||
|
||||
# ## Parameters
|
||||
|
||||
# +
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f'using {device}')
|
||||
|
||||
columns_target=['energy(kWh/hh)']
|
||||
window_past = 48*2
|
||||
window_future = 48*2
|
||||
batch_size = 128
|
||||
num_workers = 5
|
||||
freq = '30T'
|
||||
max_rows = 5e5
|
||||
num_workers = 4
|
||||
datasets_root = Path('../data/processed/')
|
||||
window_past
|
||||
|
||||
@@ -169,8 +157,7 @@ def hv_plot_prediction(d):
|
||||
return p
|
||||
|
||||
|
||||
# -
|
||||
|
||||
# +
|
||||
def plot_performance(ds_preds, full=False):
|
||||
"""Multiple plots using xr_preds"""
|
||||
p = hv_plot_prediction(ds_preds.isel(t_source=10))
|
||||
@@ -205,6 +192,7 @@ def plot_performance(ds_preds, full=False):
|
||||
true_vs_pred = dynspread(true_vs_pred)
|
||||
true_vs_pred
|
||||
display(true_vs_pred)
|
||||
|
||||
def plot_hist(trainer):
|
||||
try:
|
||||
df_hist = pd.read_csv(trainer.logger.experiment.metrics_file_path)
|
||||
@@ -241,21 +229,24 @@ def df_bold_min(data):
|
||||
return pd.DataFrame(np.where(is_min, attr, ''),
|
||||
index=data.index, columns=data.columns)
|
||||
|
||||
def display_results(results, metric='nll', strformat="{:2.2f}"):
|
||||
def format_results(results, metric=None):
|
||||
df_results = pd.concat({k:pd.DataFrame(v) for k,v in results.items()}).T
|
||||
df_results = df_results.rename_axis(index='models', columns=metric)
|
||||
if metric:
|
||||
return df_results.xs(metric, axis=1, level=1).rename_axis(columns=metric)
|
||||
return df_results
|
||||
|
||||
def display_results(results, metric='nll', strformat="{:.2f}"):
|
||||
df_results = format_results(results, metric=metric)
|
||||
|
||||
# display metric
|
||||
display(df_results
|
||||
.xs(metric, axis=1, level=1)
|
||||
.style.format(strformat)
|
||||
.apply(df_bold_min)
|
||||
)
|
||||
return df_results
|
||||
|
||||
|
||||
# -
|
||||
|
||||
|
||||
|
||||
# ## Datasets
|
||||
|
||||
# +
|
||||
@@ -267,10 +258,6 @@ datasets
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# View train, test, val splits
|
||||
l = hv.Layout()
|
||||
for dataset in datasets:
|
||||
@@ -341,7 +328,10 @@ from pytorch_lightning.loggers import CSVLogger
|
||||
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
||||
|
||||
|
||||
# Models
|
||||
# # Models
|
||||
|
||||
# +
|
||||
|
||||
from seq2seq_time.models.lstm_seq2seq import LSTMSeq2Seq
|
||||
from seq2seq_time.models.lstm_seq import LSTMSeq
|
||||
from seq2seq_time.models.lstm import LSTM
|
||||
@@ -354,6 +344,8 @@ from seq2seq_time.models.neural_process import RANP
|
||||
from seq2seq_time.models.transformer_process import TransformerProcess
|
||||
from seq2seq_time.models.tcn import TCNSeq2Seq
|
||||
# ## Plots
|
||||
|
||||
|
||||
# +
|
||||
import gc
|
||||
|
||||
@@ -409,11 +401,6 @@ models = [
|
||||
|
||||
|
||||
|
||||
# ## Train
|
||||
|
||||
from collections import defaultdict
|
||||
results = defaultdict(dict)
|
||||
|
||||
# +
|
||||
# Summarize each models shape and weights
|
||||
Dataset = datasets[0]
|
||||
@@ -434,8 +421,13 @@ for m_fn in models:
|
||||
df_summary, df_total = summary(pt_model, x_past, y_past, x_future, y_future, print_summary=False)
|
||||
sizes.append(df_total.rename(columns={'Totals':model_name}))
|
||||
df_model_sizes = pd.concat(sizes, 1)
|
||||
df_model_sizes
|
||||
df_model_sizes.style.format(pd.io.formats.format.EngFormatter(use_eng_prefix=True))
|
||||
# -
|
||||
# ## Train
|
||||
|
||||
from collections import defaultdict
|
||||
results = defaultdict(dict)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -450,8 +442,8 @@ for Dataset in datasets:
|
||||
|
||||
# Init data
|
||||
x_past, y_past, x_future, y_future = ds_train.get_rows(10)
|
||||
input_size = x_past.shape[-1]
|
||||
output_size = y_future.shape[-1]
|
||||
xs = x_past.shape[-1]
|
||||
ys = y_future.shape[-1]
|
||||
|
||||
# Loaders
|
||||
dl_train = DataLoader(ds_train,
|
||||
@@ -467,7 +459,7 @@ for Dataset in datasets:
|
||||
for m_fn in models:
|
||||
try:
|
||||
free_mem()
|
||||
pt_model = m_fn()
|
||||
pt_model = m_fn(xs, ys)
|
||||
model_name = type(pt_model).__name__
|
||||
print(dataset_name, model_name)
|
||||
|
||||
@@ -486,7 +478,7 @@ for Dataset in datasets:
|
||||
amp_level='O1',
|
||||
precision=16,
|
||||
|
||||
limit_train_batches=500,
|
||||
limit_train_batches=800,
|
||||
limit_val_batches=150,
|
||||
logger=CSVLogger("../outputs", name=f'{dataset_name}_{model_name}'),
|
||||
callbacks=[
|
||||
@@ -527,37 +519,22 @@ for Dataset in datasets:
|
||||
df_results = pd.concat({k:pd.DataFrame(v) for k,v in results.items()})
|
||||
display(df_results)
|
||||
# -
|
||||
|
||||
|
||||
|
||||
print(f'Negative Log-Likelihood (NLL).\nover {window_future} steps')
|
||||
# df_results = pd.concat({k:pd.DataFrame(v) for k,v in results.items()})
|
||||
df_results
|
||||
results
|
||||
|
||||
# # Leaderboard
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
print(f'Symmetric mean absolute percentage error (SMAPE)\nover {window_future} steps')
|
||||
print(f'Negative Log-Likelihood (NLL).\nover {window_future} steps')
|
||||
df_results = pd.concat({k:pd.DataFrame(v) for k,v in results.items()})
|
||||
display_results(results, 'nll')
|
||||
|
||||
# # Plots
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# +
|
||||
|
||||
# # plots
|
||||
# Load saved preds
|
||||
ds_predss = defaultdict(dict)
|
||||
for Dataset in datasets:
|
||||
dataset_name = Dataset.__name__
|
||||
for m_fn in models:
|
||||
pt_model = m_fn()
|
||||
pt_model = m_fn(xs, ys)
|
||||
model_name = type(pt_model).__name__
|
||||
|
||||
checkpoint_name = f"{dataset_name}_{model_name}"
|
||||
@@ -603,3 +580,6 @@ plot_performance(ds_preds, full=True)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
+69
-40
@@ -11,19 +11,21 @@ import zipfile
|
||||
|
||||
from .dataset import Seq2SeqDataSet
|
||||
from .util import normalize_encode_dataframe, timeseries_split
|
||||
from ..util import dset_to_nc
|
||||
from ..util import dset_to_nc, logger
|
||||
from .tidal import generate_tidal_periods
|
||||
|
||||
|
||||
class RegressionForecastData:
|
||||
columns_forecast = None # The input colums which can be included in future (e.g. week or weather forecast)
|
||||
columns_target = None # Target columns
|
||||
|
||||
def __init__(self, datasets_root):
|
||||
self.datasets_root = datasets_root
|
||||
|
||||
name = type(self).__name__
|
||||
self.cache_file = self.datasets_root / f"._cache_{name}.pkl"
|
||||
|
||||
# Process data
|
||||
self.df = self.download()
|
||||
self.df = self.download_cache()
|
||||
self.df_norm, self.scaler = self.normalize(self.df)
|
||||
self.output_scaler = next(filter(lambda r:r[0][0] in self.columns_target, self.scaler.features))[-1]
|
||||
self.df_train, self.df_val, self.df_test = self.split(self.df_norm)
|
||||
@@ -31,6 +33,17 @@ class RegressionForecastData:
|
||||
# Check processing
|
||||
self.check()
|
||||
|
||||
def clear_cache(self, ):
|
||||
print(f'rm {self.cache_file}')
|
||||
os.remove(self.cache_file)
|
||||
|
||||
def download_cache(self):
|
||||
if not self.cache_file.exists():
|
||||
logger.info(f"Using cache file {self.cache_file}")
|
||||
df = self.download()
|
||||
df.to_pickle(self.cache_file)
|
||||
return pd.read_pickle(self.cache_file)
|
||||
|
||||
@property
|
||||
def columns_past(self):
|
||||
return set(self.df.columns)-set(self.columns_forecast)-set(self.columns_target)
|
||||
@@ -45,8 +58,8 @@ class RegressionForecastData:
|
||||
return df_norm, scaler
|
||||
|
||||
def split(self, df_norm: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
|
||||
df_train, df_test = timeseries_split(df_norm, 0.3)
|
||||
df_test, df_val = timeseries_split(df_test, 0.5)
|
||||
df_train, df_test = timeseries_split(df_norm, 0.3, dropna=self.columns_forecast)
|
||||
df_test, df_val = timeseries_split(df_test, 0.5, dropna=self.columns_forecast)
|
||||
return df_train, df_val, df_test
|
||||
|
||||
def check(self) -> None:
|
||||
@@ -84,26 +97,21 @@ class GasSensor(RegressionForecastData):
|
||||
# download if needed
|
||||
# extract_path = self.datasets_root/'gas-sensor-array-temperature-modulation.zip'
|
||||
download_url(url, self.datasets_root)
|
||||
outfile = self.datasets_root / 'gas-sensor-array-temperature-modulation.pk'
|
||||
if not outfile.exists():
|
||||
|
||||
# Load csv's from inside zip
|
||||
zf = zipfile.ZipFile(self.datasets_root / 'gas-sensor-array-temperature-modulation.zip')
|
||||
dfs=[]
|
||||
for f in zf.namelist():
|
||||
if f.endswith('.csv'):
|
||||
now = pd.to_datetime(Path(f).stem, format='%Y%m%d_%H%M%S')
|
||||
df = pd.read_csv(zf.open(f))
|
||||
df.index = pd.to_timedelta(df['Time (s)'], unit='s') + now
|
||||
dfs.append(df)
|
||||
self.df = pd.concat(dfs).dropna(subset=self.columns_target)
|
||||
|
||||
# Load csv's from inside zip
|
||||
zf = zipfile.ZipFile(self.datasets_root / 'gas-sensor-array-temperature-modulation.zip')
|
||||
dfs=[]
|
||||
for f in zf.namelist():
|
||||
if f.endswith('.csv'):
|
||||
now = pd.to_datetime(Path(f).stem, format='%Y%m%d_%H%M%S')
|
||||
df = pd.read_csv(zf.open(f))
|
||||
df.index = pd.to_timedelta(df['Time (s)'], unit='s') + now
|
||||
dfs.append(df)
|
||||
self.df = pd.concat(dfs).dropna(subset=self.columns_target)
|
||||
|
||||
df = df[[ 'CO (ppm)', 'Humidity (%r.h.)', 'Temperature (C)',
|
||||
'Flow rate (mL/min)', 'Heater voltage (V)', 'R1 (MOhm)']]
|
||||
df = df.resample('0.3S').first()
|
||||
|
||||
df.to_pickle(outfile)
|
||||
df = pd.read_pickle(outfile)
|
||||
df = df[[ 'CO (ppm)', 'Humidity (%r.h.)', 'Temperature (C)',
|
||||
'Flow rate (mL/min)', 'Heater voltage (V)', 'R1 (MOhm)']]
|
||||
df = df.resample('0.3S').first()
|
||||
return df
|
||||
|
||||
|
||||
@@ -236,23 +244,23 @@ def get_current_timeseries(
|
||||
if not outfile.exists():
|
||||
|
||||
files = [
|
||||
"IMOS_ANMN-WA_AETVZ_20090715T080000Z_WATR20_FV01_WATR20-0907-Continental-194_END-20090716T181317Z_C-20191122T052830Z.nc",
|
||||
"IMOS_ANMN-WA_AETVZ_20100409T080000Z_WATR20_FV01_WATR20-1004-Continental-194_END-20100430T084500Z_C-20191122T053845Z.nc",
|
||||
"IMOS_ANMN-WA_AETVZ_20101222T080000Z_WATR20_FV01_WATR20-1012-Continental-194_END-20110518T051500Z_C-20200916T020035Z.nc",
|
||||
# "IMOS_ANMN-WA_AETVZ_20090715T080000Z_WATR20_FV01_WATR20-0907-Continental-194_END-20090716T181317Z_C-20191122T052830Z.nc",
|
||||
# "IMOS_ANMN-WA_AETVZ_20100409T080000Z_WATR20_FV01_WATR20-1004-Continental-194_END-20100430T084500Z_C-20191122T053845Z.nc",
|
||||
# "IMOS_ANMN-WA_AETVZ_20101222T080000Z_WATR20_FV01_WATR20-1012-Continental-194_END-20110518T051500Z_C-20200916T020035Z.nc",
|
||||
"IMOS_ANMN-WA_AETVZ_20110608T080000Z_WATR20_FV01_WATR20-1106-Continental-194_END-20111122T035000Z_C-20200916T025619Z.nc",
|
||||
"IMOS_ANMN-WA_AETVZ_20111221T060300Z_WATR20_FV01_WATR20-1112-Continental-194_END-20120704T050500Z_C-20200916T043212Z.nc",
|
||||
"IMOS_ANMN-WA_AETVZ_20120726T044000Z_WATR20_FV01_WATR20-1207-Continental-194_END-20130204T044000Z_C-20200916T032027Z.nc",
|
||||
"IMOS_ANMN-WA_AETVZ_20130221T080000Z_WATR20_FV01_WATR20-1302-Continental-194_END-20131003T035000Z_C-20180529T020609Z.nc",
|
||||
"IMOS_ANMN-WA_AETVZ_20131111T080000Z_WATR20_FV01_WATR20-1311-Continental-194_END-20140519T035000Z_C-20200114T033335Z.nc",
|
||||
"IMOS_ANMN-WA_AETVZ_20140710T080000Z_WATR20_FV01_WATR20-1407-Continental-194_END-20150121T021500Z_C-20180529T055902Z.nc",
|
||||
"IMOS_ANMN-WA_AETVZ_20150213T080000Z_WATR20_FV01_WATR20-1502-Continental-194_END-20150424T134002Z_C-20200114T035347Z.nc",
|
||||
"IMOS_ANMN-WA_AETVZ_20150914T080000Z_WATR20_FV01_WATR20-1509-Continental-194_END-20160331T043000Z_C-20180601T013623Z.nc",
|
||||
"IMOS_ANMN-WA_AETVZ_20160427T080000Z_WATR20_FV01_WATR20-1604-Continental-194_END-20160531T021800Z_C-20180531T071709Z.nc",
|
||||
# "IMOS_ANMN-WA_AETVZ_20140710T080000Z_WATR20_FV01_WATR20-1407-Continental-194_END-20150121T021500Z_C-20180529T055902Z.nc",
|
||||
# "IMOS_ANMN-WA_AETVZ_20150213T080000Z_WATR20_FV01_WATR20-1502-Continental-194_END-20150424T134002Z_C-20200114T035347Z.nc",
|
||||
# "IMOS_ANMN-WA_AETVZ_20150914T080000Z_WATR20_FV01_WATR20-1509-Continental-194_END-20160331T043000Z_C-20180601T013623Z.nc",
|
||||
# "IMOS_ANMN-WA_AETVZ_20160427T080000Z_WATR20_FV01_WATR20-1604-Continental-194_END-20160531T021800Z_C-20180531T071709Z.nc",
|
||||
# "IMOS_ANMN-WA_AETVZ_20170512T080000Z_WATR20_FV01_WATR20-1705-Continental-194_END-20170717T014558Z_C-20190805T004647Z.nc",
|
||||
"IMOS_ANMN-WA_AETVZ_20171204T080000Z_WATR20_FV01_WATR20-1712-Continental-194_END-20180618T030000Z_C-20180620T233149Z.nc",
|
||||
"IMOS_ANMN-WA_AETVZ_20180802T080000Z_WATR20_FV01_WATR20-1807-Continental-194_END-20190225T054500Z_C-20190227T001343Z.nc",
|
||||
"IMOS_ANMN-WA_AETVZ_20190307T080000Z_WATR20_FV01_WATR20-1903-Continental-194_END-20190911T003144Z_C-20200114T045053Z.nc",
|
||||
"IMOS_ANMN-WA_AETVZ_20190926T080000Z_WATR20_FV01_WATR20-1909-Continental-194_END-20200326T030000Z_C-20200420T064334Z.nc",
|
||||
# "IMOS_ANMN-WA_AETVZ_20171204T080000Z_WATR20_FV01_WATR20-1712-Continental-194_END-20180618T030000Z_C-20180620T233149Z.nc",
|
||||
# "IMOS_ANMN-WA_AETVZ_20180802T080000Z_WATR20_FV01_WATR20-1807-Continental-194_END-20190225T054500Z_C-20190227T001343Z.nc",
|
||||
# "IMOS_ANMN-WA_AETVZ_20190307T080000Z_WATR20_FV01_WATR20-1903-Continental-194_END-20190911T003144Z_C-20200114T045053Z.nc",
|
||||
# "IMOS_ANMN-WA_AETVZ_20190926T080000Z_WATR20_FV01_WATR20-1909-Continental-194_END-20200326T030000Z_C-20200420T064334Z.nc",
|
||||
]
|
||||
base = "http://thredds.aodn.org.au/thredds/fileServer/IMOS/ANMN/WA/WATR20/Velocity/"
|
||||
|
||||
@@ -262,15 +270,21 @@ def get_current_timeseries(
|
||||
# load and merge
|
||||
xds = [xr.open_dataset(cache_folder / f) for f in files]
|
||||
vars = [
|
||||
'VCUR', 'UCUR', 'WCUR', 'TEMP', 'PRES_REL', 'DEPTH', 'ROLL',
|
||||
'VCUR', 'VCUR_quality_control', 'UCUR', 'UCUR_quality_control', 'WCUR', 'WCUR_quality_control', 'TEMP', 'TEMP_quality_control', 'PRES_REL', 'PRES_REL_quality_control', 'DEPTH', 'DEPTH_quality_control', 'ROLL',
|
||||
'PITCH'
|
||||
]
|
||||
xds2 = [x[vars].isel(HEIGHT_ABOVE_SENSOR=18) for x in xds]
|
||||
xd = xr.concat(xds2, dim='TIME')
|
||||
xd = xd.where(xd.DEPTH > 150) # remove outliers
|
||||
xd = xd.where(
|
||||
(xd.DEPTH > 150) & (xd.VCUR_quality_control < 2) & (xd.UCUR_quality_control < 2) &
|
||||
(xd.PRES_REL_quality_control < 2) &
|
||||
(xd.TEMP_quality_control < 2)
|
||||
) # remove bad data
|
||||
|
||||
xd['TIME'] = xd['TIME'].dt.round('10T')
|
||||
xd = xd.dropna(dim='TIME', subset=['VCUR', 'UCUR', 'WCUR'])
|
||||
xd['SPD'] = np.sqrt(xd.VCUR**2 + xd.UCUR**2)
|
||||
# xd = xd.resample(TIME='10T').first() # slow
|
||||
|
||||
# Generate tidal freqs
|
||||
t = xd.TIME.to_series()
|
||||
@@ -280,6 +294,8 @@ def get_current_timeseries(
|
||||
xd = xd.merge(df_eta)
|
||||
|
||||
dset_to_nc(xd, outfile)
|
||||
else:
|
||||
logger.debug(f'Using cached file "{outfile}"')
|
||||
return outfile
|
||||
|
||||
|
||||
@@ -301,16 +317,29 @@ class IMOSCurrentsVel(RegressionForecastData):
|
||||
'MK3', 'MM', 'SSA', 'SA'
|
||||
]
|
||||
|
||||
def clear_cache(self):
|
||||
super().clear_cache()
|
||||
cache_file2 = self.datasets_root / 'MOS_ANMN-WA_AETVZ_WATR20_FV01_WATR20-1909-Continental-194_currents.nc'
|
||||
print(f'rm {cache_file2}')
|
||||
os.remove(cache_file2)
|
||||
|
||||
def download(self):
|
||||
outfile = self.datasets_root / 'MOS_ANMN-WA_AETVZ_WATR20_FV01_WATR20-1909-Continental-194_currents.nc'
|
||||
get_current_timeseries(outfile=outfile)
|
||||
|
||||
# made in previous notebook
|
||||
xd = xr.load_dataset(outfile)
|
||||
df = xd.to_dataframe().drop(
|
||||
columns=['HEIGHT_ABOVE_SENSOR', 'NOMINAL_DEPTH'])
|
||||
df = xd.to_dataframe()
|
||||
df['SPD'] = np.sqrt(df.VCUR**2 + df.UCUR**2)
|
||||
df = df[['VCUR', 'UCUR', 'WCUR', 'TEMP', 'DEPTH', 'M2',
|
||||
'S2', 'N2', 'K2', 'K1', 'O1', 'P1', 'Q1', 'M4', 'M6', 'S4', 'MK3', 'MM',
|
||||
'SSA', 'SA', 'SPD']]
|
||||
df.dropna(subset=self.columns_target, inplace=True)
|
||||
df = df.resample('30T').first().loc['2011':'2015-03']
|
||||
|
||||
# Only keep parts with at most 5 nans in last 48 periods
|
||||
has_past = df.SPD.isna().rolling(48).sum()<5
|
||||
df = df[has_past]
|
||||
|
||||
df = df.resample('10T').first()
|
||||
|
||||
return df
|
||||
|
||||
@@ -37,6 +37,7 @@ class Seq2SeqDataSet(torch.utils.data.Dataset):
|
||||
self.window_past = window_past
|
||||
self.window_future = window_future
|
||||
self.columns_target = columns_target
|
||||
self.columns_past = columns_past
|
||||
|
||||
# For speed
|
||||
self._icol_blank = [df.drop(columns = columns_target).columns.tolist().index(n) for n in columns_past]
|
||||
@@ -84,6 +85,9 @@ class Seq2SeqDataSet(torch.utils.data.Dataset):
|
||||
"""
|
||||
Output pandas dataframes for display purposes.
|
||||
"""
|
||||
if i<0:
|
||||
# Handle negative integers
|
||||
i = len(self)+i
|
||||
x_cols = list(self.df.drop(columns=self.columns_target).columns) + ['tsp_days', 'is_past']
|
||||
x_past, y_past, x_future, y_future = self.get_components(i)
|
||||
t_past = self.df.index[i:i+self.window_past]
|
||||
|
||||
@@ -13,7 +13,17 @@ def normalize_encode_dataframe(df, encoder=OrdinalEncoder):
|
||||
df_norm = scaler.fit_transform(df)
|
||||
return df_norm, scaler
|
||||
|
||||
def timeseries_split(df, test_fraction=0.2):
|
||||
def timeseries_split(df, test_fraction=0.2, dropna=None):
|
||||
"""Split timeseries data with test in the future"""
|
||||
i = int(len(df)*test_fraction)
|
||||
return df.iloc[:-i], df.iloc[-i:]
|
||||
|
||||
# If there are lots of nan's we can ignore them when splitting into portions
|
||||
if isinstance(dropna, list):
|
||||
index = df.dropna(subset=dropna).index
|
||||
elif dropna is True:
|
||||
index = df.dropna().index
|
||||
else:
|
||||
index = df.index
|
||||
|
||||
i = int(len(index)*test_fraction)
|
||||
dt = index.values[i]
|
||||
return df.loc[:dt], df.loc[dt:]
|
||||
|
||||
Reference in New Issue
Block a user