diff --git a/README.md b/README.md
index a0a96ad..02071b1 100644
--- a/README.md
+++ b/README.md
@@ -3,6 +3,10 @@ seq2seq-time
Using sequence to sequence interfaces for timeseries regression
+
+
+
+
Project Organization
------------
@@ -10,20 +14,13 @@ Project Organization
├── Makefile <- Makefile with commands like `make data` or `make train`
├── README.md <- The top-level README for developers using this project.
├── data
- │ ├── external <- Data from third party sources.
│ ├── interim <- Intermediate data that has been transformed.
│ ├── processed <- The final, canonical data sets for modeling.
│ └── raw <- The original, immutable data dump.
│
- ├── docs <- A default Sphinx project; see sphinx-doc.org for details
- │
- ├── models <- Trained and serialized models, model predictions, or model summaries
- │
├── notebooks <- Jupyter notebooks. Naming convention is a number (for ordering),
│ the creator's initials, and a short `-` delimited description, e.g.
- │ `1.0-jqp-initial-data-exploration`.
- │
- ├── references <- Data dictionaries, manuals, and all other explanatory materials.
+ │ `1.0-jqp-initial-data-exploration`.
│
├── reports <- Generated analysis as HTML, PDF, LaTeX, etc.
│ └── figures <- Generated graphics and figures to be used in reporting
@@ -33,21 +30,6 @@ Project Organization
│
├── setup.py <- makes project pip installable (pip install -e .) so src can be imported
├── seq2seq_time <- Source code for use in this project.
- │ ├── __init__.py <- Makes src a Python module
- │ │
- │ ├── data <- Scripts to download or generate data
- │ │ └── make_dataset.py
- │ │
- │ ├── features <- Scripts to turn raw data into features for modeling
- │ │ └── build_features.py
- │ │
- │ ├── models <- Scripts to train models and then use trained models to make
- │ │ │ predictions
- │ │ ├── predict_model.py
- │ │ └── train_model.py
- │ │
- │ └── visualization <- Scripts to create exploratory and results oriented visualizations
- │ └── visualize.py
│
└── tox.ini <- tox file with settings for running tox; see tox.readthedocs.io
diff --git a/docs/Makefile b/docs/Makefile
deleted file mode 100644
index 9ad0c23..0000000
--- a/docs/Makefile
+++ /dev/null
@@ -1,153 +0,0 @@
-# Makefile for Sphinx documentation
-#
-
-# You can set these variables from the command line.
-SPHINXOPTS =
-SPHINXBUILD = sphinx-build
-PAPER =
-BUILDDIR = _build
-
-# Internal variables.
-PAPEROPT_a4 = -D latex_paper_size=a4
-PAPEROPT_letter = -D latex_paper_size=letter
-ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
-# the i18n builder cannot share the environment and doctrees with the others
-I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
-
-.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext
-
-help:
- @echo "Please use \`make <target>' where <target> is one of"
- @echo " html to make standalone HTML files"
- @echo " dirhtml to make HTML files named index.html in directories"
- @echo " singlehtml to make a single large HTML file"
- @echo " pickle to make pickle files"
- @echo " json to make JSON files"
- @echo " htmlhelp to make HTML files and a HTML help project"
- @echo " qthelp to make HTML files and a qthelp project"
- @echo " devhelp to make HTML files and a Devhelp project"
- @echo " epub to make an epub"
- @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter"
- @echo " latexpdf to make LaTeX files and run them through pdflatex"
- @echo " text to make text files"
- @echo " man to make manual pages"
- @echo " texinfo to make Texinfo files"
- @echo " info to make Texinfo files and run them through makeinfo"
- @echo " gettext to make PO message catalogs"
- @echo " changes to make an overview of all changed/added/deprecated items"
- @echo " linkcheck to check all external links for integrity"
- @echo " doctest to run all doctests embedded in the documentation (if enabled)"
-
-clean:
- -rm -rf $(BUILDDIR)/*
-
-html:
- $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html
- @echo
- @echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
-
-dirhtml:
- $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
- @echo
- @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml."
-
-singlehtml:
- $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml
- @echo
- @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml."
-
-pickle:
- $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle
- @echo
- @echo "Build finished; now you can process the pickle files."
-
-json:
- $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json
- @echo
- @echo "Build finished; now you can process the JSON files."
-
-htmlhelp:
- $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp
- @echo
- @echo "Build finished; now you can run HTML Help Workshop with the" \
- ".hhp project file in $(BUILDDIR)/htmlhelp."
-
-qthelp:
- $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp
- @echo
- @echo "Build finished; now you can run "qcollectiongenerator" with the" \
- ".qhcp project file in $(BUILDDIR)/qthelp, like this:"
- @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/seq2seq-time.qhcp"
- @echo "To view the help file:"
- @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/seq2seq-time.qhc"
-
-devhelp:
- $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp
- @echo
- @echo "Build finished."
- @echo "To view the help file:"
- @echo "# mkdir -p $$HOME/.local/share/devhelp/seq2seq-time"
- @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/seq2seq-time"
- @echo "# devhelp"
-
-epub:
- $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub
- @echo
- @echo "Build finished. The epub file is in $(BUILDDIR)/epub."
-
-latex:
- $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
- @echo
- @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex."
- @echo "Run \`make' in that directory to run these through (pdf)latex" \
- "(use \`make latexpdf' here to do that automatically)."
-
-latexpdf:
- $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
- @echo "Running LaTeX files through pdflatex..."
- $(MAKE) -C $(BUILDDIR)/latex all-pdf
- @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
-
-text:
- $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text
- @echo
- @echo "Build finished. The text files are in $(BUILDDIR)/text."
-
-man:
- $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man
- @echo
- @echo "Build finished. The manual pages are in $(BUILDDIR)/man."
-
-texinfo:
- $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
- @echo
- @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo."
- @echo "Run \`make' in that directory to run these through makeinfo" \
- "(use \`make info' here to do that automatically)."
-
-info:
- $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
- @echo "Running Texinfo files through makeinfo..."
- make -C $(BUILDDIR)/texinfo info
- @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo."
-
-gettext:
- $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale
- @echo
- @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale."
-
-changes:
- $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes
- @echo
- @echo "The overview file is in $(BUILDDIR)/changes."
-
-linkcheck:
- $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck
- @echo
- @echo "Link check complete; look for any errors in the above output " \
- "or in $(BUILDDIR)/linkcheck/output.txt."
-
-doctest:
- $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest
- @echo "Testing of doctests in the sources finished, look at the " \
- "results in $(BUILDDIR)/doctest/output.txt."
diff --git a/docs/commands.rst b/docs/commands.rst
deleted file mode 100644
index 2d162f3..0000000
--- a/docs/commands.rst
+++ /dev/null
@@ -1,10 +0,0 @@
-Commands
-========
-
-The Makefile contains the central entry points for common tasks related to this project.
-
-Syncing data to S3
-^^^^^^^^^^^^^^^^^^
-
-* `make sync_data_to_s3` will use `aws s3 sync` to recursively sync files in `data/` up to `s3://[OPTIONAL] your-bucket-for-syncing-data (do not include 's3://')/data/`.
-* `make sync_data_from_s3` will use `aws s3 sync` to recursively sync files from `s3://[OPTIONAL] your-bucket-for-syncing-data (do not include 's3://')/data/` to `data/`.
diff --git a/docs/conf.py b/docs/conf.py
deleted file mode 100644
index 6669817..0000000
--- a/docs/conf.py
+++ /dev/null
@@ -1,244 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# seq2seq-time documentation build configuration file, created by
-# sphinx-quickstart.
-#
-# This file is execfile()d with the current directory set to its containing dir.
-#
-# Note that not all possible configuration values are present in this
-# autogenerated file.
-#
-# All configuration values have a default; values that are commented out
-# serve to show the default.
-
-import os
-import sys
-
-# If extensions (or modules to document with autodoc) are in another directory,
-# add these directories to sys.path here. If the directory is relative to the
-# documentation root, use os.path.abspath to make it absolute, like shown here.
-# sys.path.insert(0, os.path.abspath('.'))
-
-# -- General configuration -----------------------------------------------------
-
-# If your documentation needs a minimal Sphinx version, state it here.
-# needs_sphinx = '1.0'
-
-# Add any Sphinx extension module names here, as strings. They can be extensions
-# coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
-extensions = []
-
-# Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
-
-# The suffix of source filenames.
-source_suffix = '.rst'
-
-# The encoding of source files.
-# source_encoding = 'utf-8-sig'
-
-# The master toctree document.
-master_doc = 'index'
-
-# General information about the project.
-project = u'seq2seq-time'
-
-# The version info for the project you're documenting, acts as replacement for
-# |version| and |release|, also used in various other places throughout the
-# built documents.
-#
-# The short X.Y version.
-version = '0.1'
-# The full version, including alpha/beta/rc tags.
-release = '0.1'
-
-# The language for content autogenerated by Sphinx. Refer to documentation
-# for a list of supported languages.
-# language = None
-
-# There are two options for replacing |today|: either, you set today to some
-# non-false value, then it is used:
-# today = ''
-# Else, today_fmt is used as the format for a strftime call.
-# today_fmt = '%B %d, %Y'
-
-# List of patterns, relative to source directory, that match files and
-# directories to ignore when looking for source files.
-exclude_patterns = ['_build']
-
-# The reST default role (used for this markup: `text`) to use for all documents.
-# default_role = None
-
-# If true, '()' will be appended to :func: etc. cross-reference text.
-# add_function_parentheses = True
-
-# If true, the current module name will be prepended to all description
-# unit titles (such as .. function::).
-# add_module_names = True
-
-# If true, sectionauthor and moduleauthor directives will be shown in the
-# output. They are ignored by default.
-# show_authors = False
-
-# The name of the Pygments (syntax highlighting) style to use.
-pygments_style = 'sphinx'
-
-# A list of ignored prefixes for module index sorting.
-# modindex_common_prefix = []
-
-
-# -- Options for HTML output ---------------------------------------------------
-
-# The theme to use for HTML and HTML Help pages. See the documentation for
-# a list of builtin themes.
-html_theme = 'default'
-
-# Theme options are theme-specific and customize the look and feel of a theme
-# further. For a list of options available for each theme, see the
-# documentation.
-# html_theme_options = {}
-
-# Add any paths that contain custom themes here, relative to this directory.
-# html_theme_path = []
-
-# The name for this set of Sphinx documents. If None, it defaults to
-# "<project> v<release> documentation".
-# html_title = None
-
-# A shorter title for the navigation bar. Default is the same as html_title.
-# html_short_title = None
-
-# The name of an image file (relative to this directory) to place at the top
-# of the sidebar.
-# html_logo = None
-
-# The name of an image file (within the static path) to use as favicon of the
-# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
-# pixels large.
-# html_favicon = None
-
-# Add any paths that contain custom static files (such as style sheets) here,
-# relative to this directory. They are copied after the builtin static files,
-# so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
-
-# If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
-# using the given strftime format.
-# html_last_updated_fmt = '%b %d, %Y'
-
-# If true, SmartyPants will be used to convert quotes and dashes to
-# typographically correct entities.
-# html_use_smartypants = True
-
-# Custom sidebar templates, maps document names to template names.
-# html_sidebars = {}
-
-# Additional templates that should be rendered to pages, maps page names to
-# template names.
-# html_additional_pages = {}
-
-# If false, no module index is generated.
-# html_domain_indices = True
-
-# If false, no index is generated.
-# html_use_index = True
-
-# If true, the index is split into individual pages for each letter.
-# html_split_index = False
-
-# If true, links to the reST sources are added to the pages.
-# html_show_sourcelink = True
-
-# If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
-# html_show_sphinx = True
-
-# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
-# html_show_copyright = True
-
-# If true, an OpenSearch description file will be output, and all pages will
-# contain a <link> tag referring to it. The value of this option must be the
-# base URL from which the finished HTML is served.
-# html_use_opensearch = ''
-
-# This is the file name suffix for HTML files (e.g. ".xhtml").
-# html_file_suffix = None
-
-# Output file base name for HTML help builder.
-htmlhelp_basename = 'seq2seq-timedoc'
-
-
-# -- Options for LaTeX output --------------------------------------------------
-
-latex_elements = {
- # The paper size ('letterpaper' or 'a4paper').
- # 'papersize': 'letterpaper',
-
- # The font size ('10pt', '11pt' or '12pt').
- # 'pointsize': '10pt',
-
- # Additional stuff for the LaTeX preamble.
- # 'preamble': '',
-}
-
-# Grouping the document tree into LaTeX files. List of tuples
-# (source start file, target name, title, author, documentclass [howto/manual]).
-latex_documents = [
- ('index',
- 'seq2seq-time.tex',
- u'seq2seq-time Documentation',
- u"3springs", 'manual'),
-]
-
-# The name of an image file (relative to this directory) to place at the top of
-# the title page.
-# latex_logo = None
-
-# For "manual" documents, if this is true, then toplevel headings are parts,
-# not chapters.
-# latex_use_parts = False
-
-# If true, show page references after internal links.
-# latex_show_pagerefs = False
-
-# If true, show URL addresses after external links.
-# latex_show_urls = False
-
-# Documents to append as an appendix to all manuals.
-# latex_appendices = []
-
-# If false, no module index is generated.
-# latex_domain_indices = True
-
-
-# -- Options for manual page output --------------------------------------------
-
-# One entry per manual page. List of tuples
-# (source start file, name, description, authors, manual section).
-man_pages = [
- ('index', 'seq2seq-time', u'seq2seq-time Documentation',
- [u"3springs"], 1)
-]
-
-# If true, show URL addresses after external links.
-# man_show_urls = False
-
-
-# -- Options for Texinfo output ------------------------------------------------
-
-# Grouping the document tree into Texinfo files. List of tuples
-# (source start file, target name, title, author,
-# dir menu entry, description, category)
-texinfo_documents = [
- ('index', 'seq2seq-time', u'seq2seq-time Documentation',
- u"3springs", 'seq2seq-time',
- 'Using sequence to sequence interfaces for timeseries regression', 'Miscellaneous'),
-]
-
-# Documents to append as an appendix to all manuals.
-# texinfo_appendices = []
-
-# If false, no module index is generated.
-# texinfo_domain_indices = True
-
-# How to display URL addresses: 'footnote', 'no', or 'inline'.
-# texinfo_show_urls = 'footnote'
diff --git a/docs/getting-started.rst b/docs/getting-started.rst
deleted file mode 100644
index b4f71c3..0000000
--- a/docs/getting-started.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-Getting started
-===============
-
-This is where you describe how to get set up on a clean install, including the
-commands necessary to get the raw data (using the `sync_data_from_s3` command,
-for example), and then how to make the cleaned, final data sets.
diff --git a/docs/index.rst b/docs/index.rst
deleted file mode 100644
index b48dd13..0000000
--- a/docs/index.rst
+++ /dev/null
@@ -1,24 +0,0 @@
-.. seq2seq-time documentation master file, created by
- sphinx-quickstart.
- You can adapt this file completely to your liking, but it should at least
- contain the root `toctree` directive.
-
-seq2seq-time documentation!
-==============================================
-
-Contents:
-
-.. toctree::
- :maxdepth: 2
-
- getting-started
- commands
-
-
-
-Indices and tables
-==================
-
-* :ref:`genindex`
-* :ref:`modindex`
-* :ref:`search`
diff --git a/docs/make.bat b/docs/make.bat
deleted file mode 100644
index 0e14b1d..0000000
--- a/docs/make.bat
+++ /dev/null
@@ -1,190 +0,0 @@
-@ECHO OFF
-
-REM Command file for Sphinx documentation
-
-if "%SPHINXBUILD%" == "" (
- set SPHINXBUILD=sphinx-build
-)
-set BUILDDIR=_build
-set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% .
-set I18NSPHINXOPTS=%SPHINXOPTS% .
-if NOT "%PAPER%" == "" (
- set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS%
- set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS%
-)
-
-if "%1" == "" goto help
-
-if "%1" == "help" (
- :help
- echo.Please use `make ^<target^>` where ^<target^> is one of
- echo. html to make standalone HTML files
- echo. dirhtml to make HTML files named index.html in directories
- echo. singlehtml to make a single large HTML file
- echo. pickle to make pickle files
- echo. json to make JSON files
- echo. htmlhelp to make HTML files and a HTML help project
- echo. qthelp to make HTML files and a qthelp project
- echo. devhelp to make HTML files and a Devhelp project
- echo. epub to make an epub
- echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter
- echo. text to make text files
- echo. man to make manual pages
- echo. texinfo to make Texinfo files
- echo. gettext to make PO message catalogs
- echo. changes to make an overview over all changed/added/deprecated items
- echo. linkcheck to check all external links for integrity
- echo. doctest to run all doctests embedded in the documentation if enabled
- goto end
-)
-
-if "%1" == "clean" (
- for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i
- del /q /s %BUILDDIR%\*
- goto end
-)
-
-if "%1" == "html" (
- %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The HTML pages are in %BUILDDIR%/html.
- goto end
-)
-
-if "%1" == "dirhtml" (
- %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml.
- goto end
-)
-
-if "%1" == "singlehtml" (
- %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml.
- goto end
-)
-
-if "%1" == "pickle" (
- %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished; now you can process the pickle files.
- goto end
-)
-
-if "%1" == "json" (
- %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished; now you can process the JSON files.
- goto end
-)
-
-if "%1" == "htmlhelp" (
- %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished; now you can run HTML Help Workshop with the ^
-.hhp project file in %BUILDDIR%/htmlhelp.
- goto end
-)
-
-if "%1" == "qthelp" (
- %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished; now you can run "qcollectiongenerator" with the ^
-.qhcp project file in %BUILDDIR%/qthelp, like this:
- echo.^> qcollectiongenerator %BUILDDIR%\qthelp\seq2seq-time.qhcp
- echo.To view the help file:
- echo.^> assistant -collectionFile %BUILDDIR%\qthelp\seq2seq-time.ghc
- goto end
-)
-
-if "%1" == "devhelp" (
- %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished.
- goto end
-)
-
-if "%1" == "epub" (
- %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The epub file is in %BUILDDIR%/epub.
- goto end
-)
-
-if "%1" == "latex" (
- %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished; the LaTeX files are in %BUILDDIR%/latex.
- goto end
-)
-
-if "%1" == "text" (
- %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The text files are in %BUILDDIR%/text.
- goto end
-)
-
-if "%1" == "man" (
- %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The manual pages are in %BUILDDIR%/man.
- goto end
-)
-
-if "%1" == "texinfo" (
- %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo.
- goto end
-)
-
-if "%1" == "gettext" (
- %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The message catalogs are in %BUILDDIR%/locale.
- goto end
-)
-
-if "%1" == "changes" (
- %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes
- if errorlevel 1 exit /b 1
- echo.
- echo.The overview file is in %BUILDDIR%/changes.
- goto end
-)
-
-if "%1" == "linkcheck" (
- %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck
- if errorlevel 1 exit /b 1
- echo.
- echo.Link check complete; look for any errors in the above output ^
-or in %BUILDDIR%/linkcheck/output.txt.
- goto end
-)
-
-if "%1" == "doctest" (
- %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest
- if errorlevel 1 exit /b 1
- echo.
- echo.Testing of doctests in the sources finished, look at the ^
-results in %BUILDDIR%/doctest/output.txt.
- goto end
-)
-
-:end
diff --git a/models/.gitkeep b/models/.gitkeep
deleted file mode 100644
index e69de29..0000000
diff --git a/notebooks/05.0-mc-leaderboard.ipynb b/notebooks/05.0-mc-leaderboard.ipynb
index 9fecb98..b858019 100644
--- a/notebooks/05.0-mc-leaderboard.ipynb
+++ b/notebooks/05.0-mc-leaderboard.ipynb
@@ -29,8 +29,8 @@
"execution_count": 1,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-26T22:42:52.845608Z",
- "start_time": "2020-10-26T22:42:52.384261Z"
+ "end_time": "2020-10-27T09:14:05.383922Z",
+ "start_time": "2020-10-27T09:14:04.747253Z"
}
},
"outputs": [],
@@ -52,8 +52,8 @@
"execution_count": 2,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-26T22:42:53.995538Z",
- "start_time": "2020-10-26T22:42:52.849660Z"
+ "end_time": "2020-10-27T09:14:06.880212Z",
+ "start_time": "2020-10-27T09:14:05.387731Z"
}
},
"outputs": [],
@@ -83,8 +83,8 @@
"execution_count": 3,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-26T22:42:54.025883Z",
- "start_time": "2020-10-26T22:42:53.999469Z"
+ "end_time": "2020-10-27T09:14:06.924470Z",
+ "start_time": "2020-10-27T09:14:06.886088Z"
}
},
"outputs": [],
@@ -98,8 +98,8 @@
"execution_count": 4,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-26T22:42:54.581011Z",
- "start_time": "2020-10-26T22:42:54.032410Z"
+ "end_time": "2020-10-27T09:14:07.738191Z",
+ "start_time": "2020-10-27T09:14:06.928650Z"
}
},
"outputs": [
@@ -123,8 +123,8 @@
"execution_count": 5,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-26T22:42:54.616176Z",
- "start_time": "2020-10-26T22:42:54.584688Z"
+ "end_time": "2020-10-27T09:14:07.777905Z",
+ "start_time": "2020-10-27T09:14:07.747671Z"
}
},
"outputs": [
@@ -147,8 +147,8 @@
"execution_count": 6,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-26T22:42:57.778809Z",
- "start_time": "2020-10-26T22:42:54.620588Z"
+ "end_time": "2020-10-27T09:14:11.913203Z",
+ "start_time": "2020-10-27T09:14:07.781404Z"
},
"lines_to_next_cell": 2
},
@@ -1654,8 +1654,8 @@
"execution_count": 7,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-26T22:42:57.871365Z",
- "start_time": "2020-10-26T22:42:57.782486Z"
+ "end_time": "2020-10-27T09:14:12.003600Z",
+ "start_time": "2020-10-27T09:14:11.917828Z"
}
},
"outputs": [
@@ -1701,8 +1701,8 @@
"execution_count": 8,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-26T22:42:57.935809Z",
- "start_time": "2020-10-26T22:42:57.875189Z"
+ "end_time": "2020-10-27T09:14:12.090684Z",
+ "start_time": "2020-10-27T09:14:12.009300Z"
},
"lines_to_end_of_cell_marker": 2,
"lines_to_next_cell": 0
@@ -1783,8 +1783,8 @@
"execution_count": 9,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-26T22:42:57.989157Z",
- "start_time": "2020-10-26T22:42:57.939518Z"
+ "end_time": "2020-10-27T09:14:12.148553Z",
+ "start_time": "2020-10-27T09:14:12.095302Z"
}
},
"outputs": [],
@@ -1813,8 +1813,8 @@
"execution_count": 10,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-26T22:42:58.053266Z",
- "start_time": "2020-10-26T22:42:57.996130Z"
+ "end_time": "2020-10-27T09:14:12.208328Z",
+ "start_time": "2020-10-27T09:14:12.153093Z"
}
},
"outputs": [],
@@ -1866,8 +1866,8 @@
"execution_count": 11,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-26T22:42:58.102340Z",
- "start_time": "2020-10-26T22:42:58.057346Z"
+ "end_time": "2020-10-27T09:14:12.275835Z",
+ "start_time": "2020-10-27T09:14:12.212068Z"
},
"lines_to_next_cell": 2
},
@@ -1884,8 +1884,8 @@
"execution_count": 12,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-26T22:42:58.187861Z",
- "start_time": "2020-10-26T22:42:58.105923Z"
+ "end_time": "2020-10-27T09:14:12.488369Z",
+ "start_time": "2020-10-27T09:14:12.279976Z"
},
"lines_to_end_of_cell_marker": 2,
"lines_to_next_cell": 0
@@ -1912,8 +1912,8 @@
"execution_count": 13,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-26T22:42:58.237545Z",
- "start_time": "2020-10-26T22:42:58.191360Z"
+ "end_time": "2020-10-27T09:14:12.557082Z",
+ "start_time": "2020-10-27T09:14:12.492776Z"
}
},
"outputs": [],
@@ -1931,8 +1931,8 @@
"execution_count": 14,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-26T22:42:58.291410Z",
- "start_time": "2020-10-26T22:42:58.243410Z"
+ "end_time": "2020-10-27T09:14:12.620178Z",
+ "start_time": "2020-10-27T09:14:12.565956Z"
}
},
"outputs": [],
@@ -1980,8 +1980,8 @@
"execution_count": 15,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-26T22:42:58.653137Z",
- "start_time": "2020-10-26T22:42:58.294747Z"
+ "end_time": "2020-10-27T09:14:13.442957Z",
+ "start_time": "2020-10-27T09:14:12.623794Z"
}
},
"outputs": [
@@ -2012,8 +2012,8 @@
"execution_count": 16,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-26T22:42:58.706378Z",
- "start_time": "2020-10-26T22:42:58.656467Z"
+ "end_time": "2020-10-27T09:14:13.514463Z",
+ "start_time": "2020-10-27T09:14:13.447612Z"
}
},
"outputs": [
@@ -2047,8 +2047,8 @@
"execution_count": 17,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-26T22:42:58.759894Z",
- "start_time": "2020-10-26T22:42:58.709607Z"
+ "end_time": "2020-10-27T09:14:13.586379Z",
+ "start_time": "2020-10-27T09:14:13.520269Z"
}
},
"outputs": [],
@@ -2062,8 +2062,8 @@
"execution_count": 18,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-10-26T22:42:58.814430Z",
- "start_time": "2020-10-26T22:42:58.763436Z"
+ "end_time": "2020-10-27T09:14:13.653466Z",
+ "start_time": "2020-10-27T09:14:13.590276Z"
}
},
"outputs": [],
@@ -2073,14 +2073,55 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 19,
"metadata": {
"ExecuteTime": {
- "start_time": "2020-10-26T22:42:52.500Z"
+ "end_time": "2020-10-27T09:14:32.606892Z",
+ "start_time": "2020-10-27T09:14:13.657687Z"
+ },
+ "lines_to_next_cell": 2
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Using downloaded and verified file: ../data/processed/gas-sensor-array-temperature-modulation.zip\n"
+ ]
+ }
+ ],
+ "source": [
+ "for Dataset in datasets:\n",
+ " dataset_name = Dataset.__name__\n",
+ " dataset = Dataset(datasets_root)\n",
+ " ds_train, ds_test = dataset.to_datasets(window_past=window_past,\n",
+ " window_future=window_future)\n",
+ "\n",
+ " # Init data\n",
+ " x_past, y_past, x_future, y_future = ds_train.get_rows(10)\n",
+ " input_size = x_past.shape[-1]\n",
+ " output_size = y_future.shape[-1]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-10-27T13:01:42.356787Z",
+ "start_time": "2020-10-27T09:14:32.611289Z"
},
"scrolled": true
},
"outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel/ipkernel.py:287: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
+ " and should_run_async(code)\n"
+ ]
+ },
{
"name": "stdout",
"output_type": "stream",
@@ -2121,7 +2162,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "9a8acd574c114904bec713587ac683e9",
+ "model_id": "e24c84a5e5124d0480671ce971e2e2f2",
"version_major": 2,
"version_minor": 0
},
@@ -2131,6 +2172,11795 @@
},
"metadata": {},
"output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 4: reducing learning rate of group 0 to 3.0000e-05.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='predict'), FloatProgress(value=0.0, max=80.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "IMOSCurrentsVel BaselineLast\n",
+ "mean_NLL 1.63\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " nll \n",
+ " 1.633936 \n",
+ " \n",
+ " \n",
+ " rmse \n",
+ " 0.257789 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.285106 \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast\n",
+ "IMOSCurrentsVel nll 1.633936\n",
+ " rmse 0.257789\n",
+ " smape 0.285106"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "--------------------------------\n",
+ "0 | _model | RANP | 109 K \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "IMOSCurrentsVel RANP\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "5351bdc826d54a5e851616e3b22b3c8a",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 4: reducing learning rate of group 0 to 3.0000e-05.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='predict'), FloatProgress(value=0.0, max=80.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "IMOSCurrentsVel RANP\n",
+ "mean_NLL 23.31\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " RANP \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " rmse \n",
+ " 0.257789 \n",
+ " 0.538145 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.285106 \n",
+ " 0.707343 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.633936 \n",
+ " 23.305540 \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast RANP\n",
+ "IMOSCurrentsVel rmse 0.257789 0.538145\n",
+ " smape 0.285106 0.707343\n",
+ " nll 1.633936 23.305540"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "--------------------------------\n",
+ "0 | _model | LSTM | 139 K \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "IMOSCurrentsVel LSTM\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "f915900dd6af45a4a425735e54fca888",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 4: reducing learning rate of group 0 to 3.0000e-05.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='predict'), FloatProgress(value=0.0, max=80.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "IMOSCurrentsVel LSTM\n",
+ "mean_NLL 19.44\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " RANP \n",
+ " LSTM \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " rmse \n",
+ " 0.257789 \n",
+ " 0.538145 \n",
+ " 0.546141 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.285106 \n",
+ " 0.707343 \n",
+ " 0.719748 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.633936 \n",
+ " 23.305540 \n",
+ " 19.444048 \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast RANP LSTM\n",
+ "IMOSCurrentsVel rmse 0.257789 0.538145 0.546141\n",
+ " smape 0.285106 0.707343 0.719748\n",
+ " nll 1.633936 23.305540 19.444048"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "---------------------------------------\n",
+ "0 | _model | LSTMSeq2Seq | 114 K \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "IMOSCurrentsVel LSTMSeq2Seq\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "3e08216015f94a93b4da7512a2abd5ec",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 4: reducing learning rate of group 0 to 3.0000e-05.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='predict'), FloatProgress(value=0.0, max=80.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "IMOSCurrentsVel LSTMSeq2Seq\n",
+ "mean_NLL 14.52\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " RANP \n",
+ " LSTM \n",
+ " LSTMSeq2Seq \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " rmse \n",
+ " 0.257789 \n",
+ " 0.538145 \n",
+ " 0.546141 \n",
+ " 0.548482 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.285106 \n",
+ " 0.707343 \n",
+ " 0.719748 \n",
+ " 0.716110 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.633936 \n",
+ " 23.305540 \n",
+ " 19.444048 \n",
+ " 14.519804 \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast RANP LSTM LSTMSeq2Seq\n",
+ "IMOSCurrentsVel rmse 0.257789 0.538145 0.546141 0.548482\n",
+ " smape 0.285106 0.707343 0.719748 0.716110\n",
+ " nll 1.633936 23.305540 19.444048 14.519804"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "----------------------------------------------\n",
+ "0 | _model | TransformerSeq2Seq | 2 M \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "IMOSCurrentsVel TransformerSeq2Seq\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "f5756cbf9b1e464fa4208e9b7e1d27d1",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 4: reducing learning rate of group 0 to 3.0000e-05.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='predict'), FloatProgress(value=0.0, max=80.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "IMOSCurrentsVel TransformerSeq2Seq\n",
+ "mean_NLL 46.98\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " RANP \n",
+ " LSTM \n",
+ " LSTMSeq2Seq \n",
+ " TransformerSeq2Seq \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " rmse \n",
+ " 0.257789 \n",
+ " 0.538145 \n",
+ " 0.546141 \n",
+ " 0.548482 \n",
+ " 0.579640 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.285106 \n",
+ " 0.707343 \n",
+ " 0.719748 \n",
+ " 0.716110 \n",
+ " 0.769928 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.633936 \n",
+ " 23.305540 \n",
+ " 19.444048 \n",
+ " 14.519804 \n",
+ " 46.982059 \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast RANP LSTM LSTMSeq2Seq \\\n",
+ "IMOSCurrentsVel rmse 0.257789 0.538145 0.546141 0.548482 \n",
+ " smape 0.285106 0.707343 0.719748 0.716110 \n",
+ " nll 1.633936 23.305540 19.444048 14.519804 \n",
+ "\n",
+ " TransformerSeq2Seq \n",
+ "IMOSCurrentsVel rmse 0.579640 \n",
+ " smape 0.769928 \n",
+ " nll 46.982059 "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "---------------------------------------\n",
+ "0 | _model | Transformer | 2 M \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "IMOSCurrentsVel Transformer\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "e204c6959c234995848d7edb4ec4bad4",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "ERROR:root:failed to run model\n",
+ "Traceback (most recent call last):\n",
+ " File \"\", line 52, in \n",
+ " trainer.fit(model, dl_train, dl_test)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\", line 440, in fit\n",
+ " results = self.accelerator_backend.train()\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\", line 54, in train\n",
+ " results = self.train_or_test()\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py\", line 66, in train_or_test\n",
+ " results = self.trainer.train()\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\", line 462, in train\n",
+ " self.run_sanity_check(self.get_model())\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\", line 648, in run_sanity_check\n",
+ " _, eval_results = self.run_evaluation(test_mode=False, max_batches=self.num_sanity_val_batches)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\", line 568, in run_evaluation\n",
+ " output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/evaluation_loop.py\", line 171, in evaluation_step\n",
+ " output = self.trainer.accelerator_backend.validation_step(args)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\", line 76, in validation_step\n",
+ " output = self.__validation_step(args)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\", line 86, in __validation_step\n",
+ " output = self.trainer.model.validation_step(*args)\n",
+ " File \"\", line 30, in validation_step\n",
+ " return self.training_step(batch, batch_idx, phase='val')\n",
+ " File \"\", line 19, in training_step\n",
+ " y_dist, extra = self.forward(*batch)\n",
+ " File \"\", line 13, in forward\n",
+ " y_dist, extra = self._model(x_past, y_past, x_future, y_future)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 722, in _call_impl\n",
+ " result = self.forward(*input, **kwargs)\n",
+ " File \"/media/wassname/Storage5/projects2/3ST/seq2seq-time/seq2seq_time/models/transformer.py\", line 54, in forward\n",
+ " outputs = self.encoder(x, mask=mask#, src_key_padding_mask=x_key_padding_mask\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 722, in _call_impl\n",
+ " result = self.forward(*input, **kwargs)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/transformer.py\", line 181, in forward\n",
+ " output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 722, in _call_impl\n",
+ " result = self.forward(*input, **kwargs)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/transformer.py\", line 294, in forward\n",
+ " key_padding_mask=src_key_padding_mask)[0]\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 722, in _call_impl\n",
+ " result = self.forward(*input, **kwargs)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/activation.py\", line 927, in forward\n",
+ " attn_mask=attn_mask)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/functional.py\", line 4049, in multi_head_attention_forward\n",
+ " raise RuntimeError('The size of the 2D attn_mask is not correct.')\n",
+ "RuntimeError: The size of the 2D attn_mask is not correct.\n",
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "INFO:lightning:EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "INFO:lightning:EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "INFO:lightning:GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning:TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "INFO:lightning:Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "----------------------------------------------\n",
+ "0 | _model | TransformerProcess | 49 K \n",
+ "INFO:lightning:\n",
+ " | Name | Type | Params\n",
+ "----------------------------------------------\n",
+ "0 | _model | TransformerProcess | 49 K \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "IMOSCurrentsVel TransformerProcess\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "f87dbe3e2abb41438b236e72ea670946",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 5: reducing learning rate of group 0 to 3.0000e-05.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='predict'), FloatProgress(value=0.0, max=80.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "IMOSCurrentsVel TransformerProcess\n",
+ "mean_NLL 7.35\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " RANP \n",
+ " LSTM \n",
+ " LSTMSeq2Seq \n",
+ " TransformerSeq2Seq \n",
+ " TransformerProcess \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " rmse \n",
+ " 0.257789 \n",
+ " 0.538145 \n",
+ " 0.546141 \n",
+ " 0.548482 \n",
+ " 0.579640 \n",
+ " 0.541376 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.285106 \n",
+ " 0.707343 \n",
+ " 0.719748 \n",
+ " 0.716110 \n",
+ " 0.769928 \n",
+ " 0.703052 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.633936 \n",
+ " 23.305540 \n",
+ " 19.444048 \n",
+ " 14.519804 \n",
+ " 46.982059 \n",
+ " 7.354424 \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast RANP LSTM LSTMSeq2Seq \\\n",
+ "IMOSCurrentsVel rmse 0.257789 0.538145 0.546141 0.548482 \n",
+ " smape 0.285106 0.707343 0.719748 0.716110 \n",
+ " nll 1.633936 23.305540 19.444048 14.519804 \n",
+ "\n",
+ " TransformerSeq2Seq TransformerProcess \n",
+ "IMOSCurrentsVel rmse 0.579640 0.541376 \n",
+ " smape 0.769928 0.703052 \n",
+ " nll 46.982059 7.354424 "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "INFO:lightning:EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "INFO:lightning:EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "INFO:lightning:GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning:TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "INFO:lightning:Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "----------------------------------------\n",
+ "0 | _model | BaselineLast | 1 \n",
+ "INFO:lightning:\n",
+ " | Name | Type | Params\n",
+ "----------------------------------------\n",
+ "0 | _model | BaselineLast | 1 \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "BejingPM25 BaselineLast\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "94926921708c42beb06ac4df4014d9e0",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 14: reducing learning rate of group 0 to 3.0000e-05.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 17: reducing learning rate of group 0 to 3.0000e-06.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='predict'), FloatProgress(value=0.0, max=34.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "BejingPM25 BaselineLast\n",
+ "mean_NLL 1.71\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " RANP \n",
+ " LSTM \n",
+ " LSTMSeq2Seq \n",
+ " TransformerSeq2Seq \n",
+ " TransformerProcess \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " rmse \n",
+ " 0.257789 \n",
+ " 0.538145 \n",
+ " 0.546141 \n",
+ " 0.548482 \n",
+ " 0.579640 \n",
+ " 0.541376 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.285106 \n",
+ " 0.707343 \n",
+ " 0.719748 \n",
+ " 0.716110 \n",
+ " 0.769928 \n",
+ " 0.703052 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.633936 \n",
+ " 23.305540 \n",
+ " 19.444048 \n",
+ " 14.519804 \n",
+ " 46.982059 \n",
+ " 7.354424 \n",
+ " \n",
+ " \n",
+ " BejingPM25 \n",
+ " nll \n",
+ " 1.707181 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " rmse \n",
+ " 1.383076 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.278300 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast RANP LSTM LSTMSeq2Seq \\\n",
+ "IMOSCurrentsVel rmse 0.257789 0.538145 0.546141 0.548482 \n",
+ " smape 0.285106 0.707343 0.719748 0.716110 \n",
+ " nll 1.633936 23.305540 19.444048 14.519804 \n",
+ "BejingPM25 nll 1.707181 NaN NaN NaN \n",
+ " rmse 1.383076 NaN NaN NaN \n",
+ " smape 0.278300 NaN NaN NaN \n",
+ "\n",
+ " TransformerSeq2Seq TransformerProcess \n",
+ "IMOSCurrentsVel rmse 0.579640 0.541376 \n",
+ " smape 0.769928 0.703052 \n",
+ " nll 46.982059 7.354424 \n",
+ "BejingPM25 nll NaN NaN \n",
+ " rmse NaN NaN \n",
+ " smape NaN NaN "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "INFO:lightning:EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "INFO:lightning:EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "INFO:lightning:GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning:TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "INFO:lightning:Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "--------------------------------\n",
+ "0 | _model | RANP | 107 K \n",
+ "INFO:lightning:\n",
+ " | Name | Type | Params\n",
+ "--------------------------------\n",
+ "0 | _model | RANP | 107 K \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "BejingPM25 RANP\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "e66acd8f065c48cd931774bb3f223d7c",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 4: reducing learning rate of group 0 to 3.0000e-05.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='predict'), FloatProgress(value=0.0, max=34.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "BejingPM25 RANP\n",
+ "mean_NLL 1.48\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " RANP \n",
+ " LSTM \n",
+ " LSTMSeq2Seq \n",
+ " TransformerSeq2Seq \n",
+ " TransformerProcess \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " rmse \n",
+ " 0.257789 \n",
+ " 0.538145 \n",
+ " 0.546141 \n",
+ " 0.548482 \n",
+ " 0.579640 \n",
+ " 0.541376 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.285106 \n",
+ " 0.707343 \n",
+ " 0.719748 \n",
+ " 0.716110 \n",
+ " 0.769928 \n",
+ " 0.703052 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.633936 \n",
+ " 23.305540 \n",
+ " 19.444048 \n",
+ " 14.519804 \n",
+ " 46.982059 \n",
+ " 7.354424 \n",
+ " \n",
+ " \n",
+ " BejingPM25 \n",
+ " rmse \n",
+ " 1.383076 \n",
+ " 1.087366 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.278300 \n",
+ " 0.226529 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.707181 \n",
+ " 1.478179 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast RANP LSTM LSTMSeq2Seq \\\n",
+ "IMOSCurrentsVel rmse 0.257789 0.538145 0.546141 0.548482 \n",
+ " smape 0.285106 0.707343 0.719748 0.716110 \n",
+ " nll 1.633936 23.305540 19.444048 14.519804 \n",
+ "BejingPM25 rmse 1.383076 1.087366 NaN NaN \n",
+ " smape 0.278300 0.226529 NaN NaN \n",
+ " nll 1.707181 1.478179 NaN NaN \n",
+ "\n",
+ " TransformerSeq2Seq TransformerProcess \n",
+ "IMOSCurrentsVel rmse 0.579640 0.541376 \n",
+ " smape 0.769928 0.703052 \n",
+ " nll 46.982059 7.354424 \n",
+ "BejingPM25 rmse NaN NaN \n",
+ " smape NaN NaN \n",
+ " nll NaN NaN "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "INFO:lightning:EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "INFO:lightning:EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "INFO:lightning:GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning:TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "INFO:lightning:Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "--------------------------------\n",
+ "0 | _model | LSTM | 135 K \n",
+ "INFO:lightning:\n",
+ " | Name | Type | Params\n",
+ "--------------------------------\n",
+ "0 | _model | LSTM | 135 K \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "BejingPM25 LSTM\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "210cbc1f0e574455853c30ff4c593869",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 4: reducing learning rate of group 0 to 3.0000e-05.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='predict'), FloatProgress(value=0.0, max=34.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "BejingPM25 LSTM\n",
+ "mean_NLL 1.41\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " RANP \n",
+ " LSTM \n",
+ " LSTMSeq2Seq \n",
+ " TransformerSeq2Seq \n",
+ " TransformerProcess \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " rmse \n",
+ " 0.257789 \n",
+ " 0.538145 \n",
+ " 0.546141 \n",
+ " 0.548482 \n",
+ " 0.579640 \n",
+ " 0.541376 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.285106 \n",
+ " 0.707343 \n",
+ " 0.719748 \n",
+ " 0.716110 \n",
+ " 0.769928 \n",
+ " 0.703052 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.633936 \n",
+ " 23.305540 \n",
+ " 19.444048 \n",
+ " 14.519804 \n",
+ " 46.982059 \n",
+ " 7.354424 \n",
+ " \n",
+ " \n",
+ " BejingPM25 \n",
+ " rmse \n",
+ " 1.383076 \n",
+ " 1.087366 \n",
+ " 1.032753 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.278300 \n",
+ " 0.226529 \n",
+ " 0.212408 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.707181 \n",
+ " 1.478179 \n",
+ " 1.407227 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast RANP LSTM LSTMSeq2Seq \\\n",
+ "IMOSCurrentsVel rmse 0.257789 0.538145 0.546141 0.548482 \n",
+ " smape 0.285106 0.707343 0.719748 0.716110 \n",
+ " nll 1.633936 23.305540 19.444048 14.519804 \n",
+ "BejingPM25 rmse 1.383076 1.087366 1.032753 NaN \n",
+ " smape 0.278300 0.226529 0.212408 NaN \n",
+ " nll 1.707181 1.478179 1.407227 NaN \n",
+ "\n",
+ " TransformerSeq2Seq TransformerProcess \n",
+ "IMOSCurrentsVel rmse 0.579640 0.541376 \n",
+ " smape 0.769928 0.703052 \n",
+ " nll 46.982059 7.354424 \n",
+ "BejingPM25 rmse NaN NaN \n",
+ " smape NaN NaN \n",
+ " nll NaN NaN "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "INFO:lightning:EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "INFO:lightning:EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "INFO:lightning:GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning:TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "INFO:lightning:Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "---------------------------------------\n",
+ "0 | _model | LSTMSeq2Seq | 108 K \n",
+ "INFO:lightning:\n",
+ " | Name | Type | Params\n",
+ "---------------------------------------\n",
+ "0 | _model | LSTMSeq2Seq | 108 K \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "BejingPM25 LSTMSeq2Seq\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "0fe558affae4417a8e618dd3e0f61faa",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 4: reducing learning rate of group 0 to 3.0000e-05.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='predict'), FloatProgress(value=0.0, max=34.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "BejingPM25 LSTMSeq2Seq\n",
+ "mean_NLL 1.39\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " RANP \n",
+ " LSTM \n",
+ " LSTMSeq2Seq \n",
+ " TransformerSeq2Seq \n",
+ " TransformerProcess \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " rmse \n",
+ " 0.257789 \n",
+ " 0.538145 \n",
+ " 0.546141 \n",
+ " 0.548482 \n",
+ " 0.579640 \n",
+ " 0.541376 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.285106 \n",
+ " 0.707343 \n",
+ " 0.719748 \n",
+ " 0.716110 \n",
+ " 0.769928 \n",
+ " 0.703052 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.633936 \n",
+ " 23.305540 \n",
+ " 19.444048 \n",
+ " 14.519804 \n",
+ " 46.982059 \n",
+ " 7.354424 \n",
+ " \n",
+ " \n",
+ " BejingPM25 \n",
+ " rmse \n",
+ " 1.383076 \n",
+ " 1.087366 \n",
+ " 1.032753 \n",
+ " 1.026793 \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.278300 \n",
+ " 0.226529 \n",
+ " 0.212408 \n",
+ " 0.210056 \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.707181 \n",
+ " 1.478179 \n",
+ " 1.407227 \n",
+ " 1.386666 \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast RANP LSTM LSTMSeq2Seq \\\n",
+ "IMOSCurrentsVel rmse 0.257789 0.538145 0.546141 0.548482 \n",
+ " smape 0.285106 0.707343 0.719748 0.716110 \n",
+ " nll 1.633936 23.305540 19.444048 14.519804 \n",
+ "BejingPM25 rmse 1.383076 1.087366 1.032753 1.026793 \n",
+ " smape 0.278300 0.226529 0.212408 0.210056 \n",
+ " nll 1.707181 1.478179 1.407227 1.386666 \n",
+ "\n",
+ " TransformerSeq2Seq TransformerProcess \n",
+ "IMOSCurrentsVel rmse 0.579640 0.541376 \n",
+ " smape 0.769928 0.703052 \n",
+ " nll 46.982059 7.354424 \n",
+ "BejingPM25 rmse NaN NaN \n",
+ " smape NaN NaN \n",
+ " nll NaN NaN "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "INFO:lightning:EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "INFO:lightning:EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "INFO:lightning:GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning:TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "INFO:lightning:Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "----------------------------------------------\n",
+ "0 | _model | TransformerSeq2Seq | 2 M \n",
+ "INFO:lightning:\n",
+ " | Name | Type | Params\n",
+ "----------------------------------------------\n",
+ "0 | _model | TransformerSeq2Seq | 2 M \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "BejingPM25 TransformerSeq2Seq\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "0b7ac22c7eb2417aa7cc7e5da0ad885f",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 4: reducing learning rate of group 0 to 3.0000e-05.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='predict'), FloatProgress(value=0.0, max=34.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "BejingPM25 TransformerSeq2Seq\n",
+ "mean_NLL 2.86\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " RANP \n",
+ " LSTM \n",
+ " LSTMSeq2Seq \n",
+ " TransformerSeq2Seq \n",
+ " TransformerProcess \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " rmse \n",
+ " 0.257789 \n",
+ " 0.538145 \n",
+ " 0.546141 \n",
+ " 0.548482 \n",
+ " 0.579640 \n",
+ " 0.541376 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.285106 \n",
+ " 0.707343 \n",
+ " 0.719748 \n",
+ " 0.716110 \n",
+ " 0.769928 \n",
+ " 0.703052 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.633936 \n",
+ " 23.305540 \n",
+ " 19.444048 \n",
+ " 14.519804 \n",
+ " 46.982059 \n",
+ " 7.354424 \n",
+ " \n",
+ " \n",
+ " BejingPM25 \n",
+ " rmse \n",
+ " 1.383076 \n",
+ " 1.087366 \n",
+ " 1.032753 \n",
+ " 1.026793 \n",
+ " 1.165701 \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.278300 \n",
+ " 0.226529 \n",
+ " 0.212408 \n",
+ " 0.210056 \n",
+ " 0.232141 \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.707181 \n",
+ " 1.478179 \n",
+ " 1.407227 \n",
+ " 1.386666 \n",
+ " 2.859907 \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast RANP LSTM LSTMSeq2Seq \\\n",
+ "IMOSCurrentsVel rmse 0.257789 0.538145 0.546141 0.548482 \n",
+ " smape 0.285106 0.707343 0.719748 0.716110 \n",
+ " nll 1.633936 23.305540 19.444048 14.519804 \n",
+ "BejingPM25 rmse 1.383076 1.087366 1.032753 1.026793 \n",
+ " smape 0.278300 0.226529 0.212408 0.210056 \n",
+ " nll 1.707181 1.478179 1.407227 1.386666 \n",
+ "\n",
+ " TransformerSeq2Seq TransformerProcess \n",
+ "IMOSCurrentsVel rmse 0.579640 0.541376 \n",
+ " smape 0.769928 0.703052 \n",
+ " nll 46.982059 7.354424 \n",
+ "BejingPM25 rmse 1.165701 NaN \n",
+ " smape 0.232141 NaN \n",
+ " nll 2.859907 NaN "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "INFO:lightning:EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "INFO:lightning:EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "INFO:lightning:GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning:TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "INFO:lightning:Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "---------------------------------------\n",
+ "0 | _model | Transformer | 2 M \n",
+ "INFO:lightning:\n",
+ " | Name | Type | Params\n",
+ "---------------------------------------\n",
+ "0 | _model | Transformer | 2 M \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "BejingPM25 Transformer\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "3345032e42cb46829a394cd274fbe103",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "ERROR:root:failed to run model\n",
+ "Traceback (most recent call last):\n",
+ " File \"\", line 52, in \n",
+ " trainer.fit(model, dl_train, dl_test)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\", line 440, in fit\n",
+ " results = self.accelerator_backend.train()\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\", line 54, in train\n",
+ " results = self.train_or_test()\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py\", line 66, in train_or_test\n",
+ " results = self.trainer.train()\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\", line 462, in train\n",
+ " self.run_sanity_check(self.get_model())\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\", line 648, in run_sanity_check\n",
+ " _, eval_results = self.run_evaluation(test_mode=False, max_batches=self.num_sanity_val_batches)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\", line 568, in run_evaluation\n",
+ " output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/evaluation_loop.py\", line 171, in evaluation_step\n",
+ " output = self.trainer.accelerator_backend.validation_step(args)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\", line 76, in validation_step\n",
+ " output = self.__validation_step(args)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\", line 86, in __validation_step\n",
+ " output = self.trainer.model.validation_step(*args)\n",
+ " File \"\", line 30, in validation_step\n",
+ " return self.training_step(batch, batch_idx, phase='val')\n",
+ " File \"\", line 19, in training_step\n",
+ " y_dist, extra = self.forward(*batch)\n",
+ " File \"\", line 13, in forward\n",
+ " y_dist, extra = self._model(x_past, y_past, x_future, y_future)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 722, in _call_impl\n",
+ " result = self.forward(*input, **kwargs)\n",
+ " File \"/media/wassname/Storage5/projects2/3ST/seq2seq-time/seq2seq_time/models/transformer.py\", line 54, in forward\n",
+ " outputs = self.encoder(x, mask=mask#, src_key_padding_mask=x_key_padding_mask\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 722, in _call_impl\n",
+ " result = self.forward(*input, **kwargs)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/transformer.py\", line 181, in forward\n",
+ " output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 722, in _call_impl\n",
+ " result = self.forward(*input, **kwargs)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/transformer.py\", line 294, in forward\n",
+ " key_padding_mask=src_key_padding_mask)[0]\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 722, in _call_impl\n",
+ " result = self.forward(*input, **kwargs)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/activation.py\", line 927, in forward\n",
+ " attn_mask=attn_mask)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/functional.py\", line 4049, in multi_head_attention_forward\n",
+ " raise RuntimeError('The size of the 2D attn_mask is not correct.')\n",
+ "RuntimeError: The size of the 2D attn_mask is not correct.\n",
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "INFO:lightning:EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "INFO:lightning:EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "INFO:lightning:GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning:TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "INFO:lightning:Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "----------------------------------------------\n",
+ "0 | _model | TransformerProcess | 49 K \n",
+ "INFO:lightning:\n",
+ " | Name | Type | Params\n",
+ "----------------------------------------------\n",
+ "0 | _model | TransformerProcess | 49 K \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "BejingPM25 TransformerProcess\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "44213affb81f49a0b565a0d57d98e57b",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 4: reducing learning rate of group 0 to 3.0000e-05.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='predict'), FloatProgress(value=0.0, max=34.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "BejingPM25 TransformerProcess\n",
+ "mean_NLL 1.44\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " RANP \n",
+ " LSTM \n",
+ " LSTMSeq2Seq \n",
+ " TransformerSeq2Seq \n",
+ " TransformerProcess \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " rmse \n",
+ " 0.257789 \n",
+ " 0.538145 \n",
+ " 0.546141 \n",
+ " 0.548482 \n",
+ " 0.579640 \n",
+ " 0.541376 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.285106 \n",
+ " 0.707343 \n",
+ " 0.719748 \n",
+ " 0.716110 \n",
+ " 0.769928 \n",
+ " 0.703052 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.633936 \n",
+ " 23.305540 \n",
+ " 19.444048 \n",
+ " 14.519804 \n",
+ " 46.982059 \n",
+ " 7.354424 \n",
+ " \n",
+ " \n",
+ " BejingPM25 \n",
+ " rmse \n",
+ " 1.383076 \n",
+ " 1.087366 \n",
+ " 1.032753 \n",
+ " 1.026793 \n",
+ " 1.165701 \n",
+ " 1.074965 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.278300 \n",
+ " 0.226529 \n",
+ " 0.212408 \n",
+ " 0.210056 \n",
+ " 0.232141 \n",
+ " 0.224733 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.707181 \n",
+ " 1.478179 \n",
+ " 1.407227 \n",
+ " 1.386666 \n",
+ " 2.859907 \n",
+ " 1.438450 \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast RANP LSTM LSTMSeq2Seq \\\n",
+ "IMOSCurrentsVel rmse 0.257789 0.538145 0.546141 0.548482 \n",
+ " smape 0.285106 0.707343 0.719748 0.716110 \n",
+ " nll 1.633936 23.305540 19.444048 14.519804 \n",
+ "BejingPM25 rmse 1.383076 1.087366 1.032753 1.026793 \n",
+ " smape 0.278300 0.226529 0.212408 0.210056 \n",
+ " nll 1.707181 1.478179 1.407227 1.386666 \n",
+ "\n",
+ " TransformerSeq2Seq TransformerProcess \n",
+ "IMOSCurrentsVel rmse 0.579640 0.541376 \n",
+ " smape 0.769928 0.703052 \n",
+ " nll 46.982059 7.354424 \n",
+ "BejingPM25 rmse 1.165701 1.074965 \n",
+ " smape 0.232141 0.224733 \n",
+ " nll 2.859907 1.438450 "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Using downloaded and verified file: ../data/processed/gas-sensor-array-temperature-modulation.zip\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "INFO:lightning:EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "INFO:lightning:EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "INFO:lightning:GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning:TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "INFO:lightning:Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "----------------------------------------\n",
+ "0 | _model | BaselineLast | 1 \n",
+ "INFO:lightning:\n",
+ " | Name | Type | Params\n",
+ "----------------------------------------\n",
+ "0 | _model | BaselineLast | 1 \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "GasSensor BaselineLast\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "4d75180d4dfa4927b1334b85ab358881",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 8: reducing learning rate of group 0 to 3.0000e-05.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='predict'), FloatProgress(value=0.0, max=231.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "GasSensor BaselineLast\n",
+ "mean_NLL 1.88\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " RANP \n",
+ " LSTM \n",
+ " LSTMSeq2Seq \n",
+ " TransformerSeq2Seq \n",
+ " TransformerProcess \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " rmse \n",
+ " 0.257789 \n",
+ " 0.538145 \n",
+ " 0.546141 \n",
+ " 0.548482 \n",
+ " 0.579640 \n",
+ " 0.541376 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.285106 \n",
+ " 0.707343 \n",
+ " 0.719748 \n",
+ " 0.716110 \n",
+ " 0.769928 \n",
+ " 0.703052 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.633936 \n",
+ " 23.305540 \n",
+ " 19.444048 \n",
+ " 14.519804 \n",
+ " 46.982059 \n",
+ " 7.354424 \n",
+ " \n",
+ " \n",
+ " BejingPM25 \n",
+ " rmse \n",
+ " 1.383076 \n",
+ " 1.087366 \n",
+ " 1.032753 \n",
+ " 1.026793 \n",
+ " 1.165701 \n",
+ " 1.074965 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.278300 \n",
+ " 0.226529 \n",
+ " 0.212408 \n",
+ " 0.210056 \n",
+ " 0.232141 \n",
+ " 0.224733 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.707181 \n",
+ " 1.478179 \n",
+ " 1.407227 \n",
+ " 1.386666 \n",
+ " 2.859907 \n",
+ " 1.438450 \n",
+ " \n",
+ " \n",
+ " GasSensor \n",
+ " nll \n",
+ " 1.879850 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " rmse \n",
+ " 41.511017 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 1.292102 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast RANP LSTM LSTMSeq2Seq \\\n",
+ "IMOSCurrentsVel rmse 0.257789 0.538145 0.546141 0.548482 \n",
+ " smape 0.285106 0.707343 0.719748 0.716110 \n",
+ " nll 1.633936 23.305540 19.444048 14.519804 \n",
+ "BejingPM25 rmse 1.383076 1.087366 1.032753 1.026793 \n",
+ " smape 0.278300 0.226529 0.212408 0.210056 \n",
+ " nll 1.707181 1.478179 1.407227 1.386666 \n",
+ "GasSensor nll 1.879850 NaN NaN NaN \n",
+ " rmse 41.511017 NaN NaN NaN \n",
+ " smape 1.292102 NaN NaN NaN \n",
+ "\n",
+ " TransformerSeq2Seq TransformerProcess \n",
+ "IMOSCurrentsVel rmse 0.579640 0.541376 \n",
+ " smape 0.769928 0.703052 \n",
+ " nll 46.982059 7.354424 \n",
+ "BejingPM25 rmse 1.165701 1.074965 \n",
+ " smape 0.232141 0.224733 \n",
+ " nll 2.859907 1.438450 \n",
+ "GasSensor nll NaN NaN \n",
+ " rmse NaN NaN \n",
+ " smape NaN NaN "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "INFO:lightning:EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "INFO:lightning:EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "INFO:lightning:GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning:TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "INFO:lightning:Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "--------------------------------\n",
+ "0 | _model | RANP | 106 K \n",
+ "INFO:lightning:\n",
+ " | Name | Type | Params\n",
+ "--------------------------------\n",
+ "0 | _model | RANP | 106 K \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "GasSensor RANP\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "3bca54f128c640188261cdb4ca42bb8f",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='predict'), FloatProgress(value=0.0, max=231.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "GasSensor RANP\n",
+ "mean_NLL -2.24\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " RANP \n",
+ " LSTM \n",
+ " LSTMSeq2Seq \n",
+ " TransformerSeq2Seq \n",
+ " TransformerProcess \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " rmse \n",
+ " 0.257789 \n",
+ " 0.538145 \n",
+ " 0.546141 \n",
+ " 0.548482 \n",
+ " 0.579640 \n",
+ " 0.541376 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.285106 \n",
+ " 0.707343 \n",
+ " 0.719748 \n",
+ " 0.716110 \n",
+ " 0.769928 \n",
+ " 0.703052 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.633936 \n",
+ " 23.305540 \n",
+ " 19.444048 \n",
+ " 14.519804 \n",
+ " 46.982059 \n",
+ " 7.354424 \n",
+ " \n",
+ " \n",
+ " BejingPM25 \n",
+ " rmse \n",
+ " 1.383076 \n",
+ " 1.087366 \n",
+ " 1.032753 \n",
+ " 1.026793 \n",
+ " 1.165701 \n",
+ " 1.074965 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.278300 \n",
+ " 0.226529 \n",
+ " 0.212408 \n",
+ " 0.210056 \n",
+ " 0.232141 \n",
+ " 0.224733 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.707181 \n",
+ " 1.478179 \n",
+ " 1.407227 \n",
+ " 1.386666 \n",
+ " 2.859907 \n",
+ " 1.438450 \n",
+ " \n",
+ " \n",
+ " GasSensor \n",
+ " rmse \n",
+ " 41.511017 \n",
+ " 2.036231 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 1.292102 \n",
+ " 0.300914 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.879850 \n",
+ " -2.236591 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast RANP LSTM LSTMSeq2Seq \\\n",
+ "IMOSCurrentsVel rmse 0.257789 0.538145 0.546141 0.548482 \n",
+ " smape 0.285106 0.707343 0.719748 0.716110 \n",
+ " nll 1.633936 23.305540 19.444048 14.519804 \n",
+ "BejingPM25 rmse 1.383076 1.087366 1.032753 1.026793 \n",
+ " smape 0.278300 0.226529 0.212408 0.210056 \n",
+ " nll 1.707181 1.478179 1.407227 1.386666 \n",
+ "GasSensor rmse 41.511017 2.036231 NaN NaN \n",
+ " smape 1.292102 0.300914 NaN NaN \n",
+ " nll 1.879850 -2.236591 NaN NaN \n",
+ "\n",
+ " TransformerSeq2Seq TransformerProcess \n",
+ "IMOSCurrentsVel rmse 0.579640 0.541376 \n",
+ " smape 0.769928 0.703052 \n",
+ " nll 46.982059 7.354424 \n",
+ "BejingPM25 rmse 1.165701 1.074965 \n",
+ " smape 0.232141 0.224733 \n",
+ " nll 2.859907 1.438450 \n",
+ "GasSensor rmse NaN NaN \n",
+ " smape NaN NaN \n",
+ " nll NaN NaN "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "INFO:lightning:EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "INFO:lightning:EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "INFO:lightning:GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning:TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "INFO:lightning:Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "--------------------------------\n",
+ "0 | _model | LSTM | 132 K \n",
+ "INFO:lightning:\n",
+ " | Name | Type | Params\n",
+ "--------------------------------\n",
+ "0 | _model | LSTM | 132 K \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "GasSensor LSTM\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "4dac078eb6c7444daf6f96e0001d899a",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 5: reducing learning rate of group 0 to 3.0000e-05.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='predict'), FloatProgress(value=0.0, max=231.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "GasSensor LSTM\n",
+ "mean_NLL 16.40\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " RANP \n",
+ " LSTM \n",
+ " LSTMSeq2Seq \n",
+ " TransformerSeq2Seq \n",
+ " TransformerProcess \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " rmse \n",
+ " 0.257789 \n",
+ " 0.538145 \n",
+ " 0.546141 \n",
+ " 0.548482 \n",
+ " 0.579640 \n",
+ " 0.541376 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.285106 \n",
+ " 0.707343 \n",
+ " 0.719748 \n",
+ " 0.716110 \n",
+ " 0.769928 \n",
+ " 0.703052 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.633936 \n",
+ " 23.305540 \n",
+ " 19.444048 \n",
+ " 14.519804 \n",
+ " 46.982059 \n",
+ " 7.354424 \n",
+ " \n",
+ " \n",
+ " BejingPM25 \n",
+ " rmse \n",
+ " 1.383076 \n",
+ " 1.087366 \n",
+ " 1.032753 \n",
+ " 1.026793 \n",
+ " 1.165701 \n",
+ " 1.074965 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.278300 \n",
+ " 0.226529 \n",
+ " 0.212408 \n",
+ " 0.210056 \n",
+ " 0.232141 \n",
+ " 0.224733 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.707181 \n",
+ " 1.478179 \n",
+ " 1.407227 \n",
+ " 1.386666 \n",
+ " 2.859907 \n",
+ " 1.438450 \n",
+ " \n",
+ " \n",
+ " GasSensor \n",
+ " rmse \n",
+ " 41.511017 \n",
+ " 2.036231 \n",
+ " 10.486523 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 1.292102 \n",
+ " 0.300914 \n",
+ " 0.562990 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.879850 \n",
+ " -2.236591 \n",
+ " 16.395840 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast RANP LSTM LSTMSeq2Seq \\\n",
+ "IMOSCurrentsVel rmse 0.257789 0.538145 0.546141 0.548482 \n",
+ " smape 0.285106 0.707343 0.719748 0.716110 \n",
+ " nll 1.633936 23.305540 19.444048 14.519804 \n",
+ "BejingPM25 rmse 1.383076 1.087366 1.032753 1.026793 \n",
+ " smape 0.278300 0.226529 0.212408 0.210056 \n",
+ " nll 1.707181 1.478179 1.407227 1.386666 \n",
+ "GasSensor rmse 41.511017 2.036231 10.486523 NaN \n",
+ " smape 1.292102 0.300914 0.562990 NaN \n",
+ " nll 1.879850 -2.236591 16.395840 NaN \n",
+ "\n",
+ " TransformerSeq2Seq TransformerProcess \n",
+ "IMOSCurrentsVel rmse 0.579640 0.541376 \n",
+ " smape 0.769928 0.703052 \n",
+ " nll 46.982059 7.354424 \n",
+ "BejingPM25 rmse 1.165701 1.074965 \n",
+ " smape 0.232141 0.224733 \n",
+ " nll 2.859907 1.438450 \n",
+ "GasSensor rmse NaN NaN \n",
+ " smape NaN NaN \n",
+ " nll NaN NaN "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "INFO:lightning:EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "INFO:lightning:EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "INFO:lightning:GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning:TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "INFO:lightning:Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "---------------------------------------\n",
+ "0 | _model | LSTMSeq2Seq | 104 K \n",
+ "INFO:lightning:\n",
+ " | Name | Type | Params\n",
+ "---------------------------------------\n",
+ "0 | _model | LSTMSeq2Seq | 104 K \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "GasSensor LSTMSeq2Seq\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "a2469764109149f58852f5f133141b94",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='predict'), FloatProgress(value=0.0, max=231.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "GasSensor LSTMSeq2Seq\n",
+ "mean_NLL -1.53\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " RANP \n",
+ " LSTM \n",
+ " LSTMSeq2Seq \n",
+ " TransformerSeq2Seq \n",
+ " TransformerProcess \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " rmse \n",
+ " 0.257789 \n",
+ " 0.538145 \n",
+ " 0.546141 \n",
+ " 0.548482 \n",
+ " 0.579640 \n",
+ " 0.541376 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.285106 \n",
+ " 0.707343 \n",
+ " 0.719748 \n",
+ " 0.716110 \n",
+ " 0.769928 \n",
+ " 0.703052 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.633936 \n",
+ " 23.305540 \n",
+ " 19.444048 \n",
+ " 14.519804 \n",
+ " 46.982059 \n",
+ " 7.354424 \n",
+ " \n",
+ " \n",
+ " BejingPM25 \n",
+ " rmse \n",
+ " 1.383076 \n",
+ " 1.087366 \n",
+ " 1.032753 \n",
+ " 1.026793 \n",
+ " 1.165701 \n",
+ " 1.074965 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.278300 \n",
+ " 0.226529 \n",
+ " 0.212408 \n",
+ " 0.210056 \n",
+ " 0.232141 \n",
+ " 0.224733 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.707181 \n",
+ " 1.478179 \n",
+ " 1.407227 \n",
+ " 1.386666 \n",
+ " 2.859907 \n",
+ " 1.438450 \n",
+ " \n",
+ " \n",
+ " GasSensor \n",
+ " rmse \n",
+ " 41.511017 \n",
+ " 2.036231 \n",
+ " 10.486523 \n",
+ " 2.139018 \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 1.292102 \n",
+ " 0.300914 \n",
+ " 0.562990 \n",
+ " 0.337733 \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.879850 \n",
+ " -2.236591 \n",
+ " 16.395840 \n",
+ " -1.527772 \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast RANP LSTM LSTMSeq2Seq \\\n",
+ "IMOSCurrentsVel rmse 0.257789 0.538145 0.546141 0.548482 \n",
+ " smape 0.285106 0.707343 0.719748 0.716110 \n",
+ " nll 1.633936 23.305540 19.444048 14.519804 \n",
+ "BejingPM25 rmse 1.383076 1.087366 1.032753 1.026793 \n",
+ " smape 0.278300 0.226529 0.212408 0.210056 \n",
+ " nll 1.707181 1.478179 1.407227 1.386666 \n",
+ "GasSensor rmse 41.511017 2.036231 10.486523 2.139018 \n",
+ " smape 1.292102 0.300914 0.562990 0.337733 \n",
+ " nll 1.879850 -2.236591 16.395840 -1.527772 \n",
+ "\n",
+ " TransformerSeq2Seq TransformerProcess \n",
+ "IMOSCurrentsVel rmse 0.579640 0.541376 \n",
+ " smape 0.769928 0.703052 \n",
+ " nll 46.982059 7.354424 \n",
+ "BejingPM25 rmse 1.165701 1.074965 \n",
+ " smape 0.232141 0.224733 \n",
+ " nll 2.859907 1.438450 \n",
+ "GasSensor rmse NaN NaN \n",
+ " smape NaN NaN \n",
+ " nll NaN NaN "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "INFO:lightning:EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "INFO:lightning:EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "INFO:lightning:GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning:TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "INFO:lightning:Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "----------------------------------------------\n",
+ "0 | _model | TransformerSeq2Seq | 2 M \n",
+ "INFO:lightning:\n",
+ " | Name | Type | Params\n",
+ "----------------------------------------------\n",
+ "0 | _model | TransformerSeq2Seq | 2 M \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "GasSensor TransformerSeq2Seq\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "d017e77dace74577af317d1aff46e876",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "ERROR:root:failed to run model\n",
+ "Traceback (most recent call last):\n",
+ " File \"\", line 52, in \n",
+ " trainer.fit(model, dl_train, dl_test)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\", line 440, in fit\n",
+ " results = self.accelerator_backend.train()\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\", line 54, in train\n",
+ " results = self.train_or_test()\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py\", line 66, in train_or_test\n",
+ " results = self.trainer.train()\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\", line 483, in train\n",
+ " self.train_loop.run_training_epoch()\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py\", line 541, in run_training_epoch\n",
+ " batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py\", line 678, in run_training_batch\n",
+ " self.trainer.hiddens\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py\", line 760, in training_step_and_backward\n",
+ " result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py\", line 304, in training_step\n",
+ " training_step_output = self.trainer.accelerator_backend.training_step(args)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\", line 60, in training_step\n",
+ " output = self.__training_step(args)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\", line 70, in __training_step\n",
+ " output = self.trainer.model.training_step(*args)\n",
+ " File \"\", line 19, in training_step\n",
+ " y_dist, extra = self.forward(*batch)\n",
+ " File \"\", line 14, in forward\n",
+ " assert torch.isfinite(y_dist.loc).all(), 'output should be finite'\n",
+ "AssertionError: output should be finite\n",
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "INFO:lightning:EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "INFO:lightning:EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "INFO:lightning:GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning:TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "INFO:lightning:Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "---------------------------------------\n",
+ "0 | _model | Transformer | 2 M \n",
+ "INFO:lightning:\n",
+ " | Name | Type | Params\n",
+ "---------------------------------------\n",
+ "0 | _model | Transformer | 2 M \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "GasSensor Transformer\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "3e1d064b079d4666ae30d36ba9f3b696",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "ERROR:root:failed to run model\n",
+ "Traceback (most recent call last):\n",
+ " File \"\", line 52, in \n",
+ " trainer.fit(model, dl_train, dl_test)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\", line 440, in fit\n",
+ " results = self.accelerator_backend.train()\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\", line 54, in train\n",
+ " results = self.train_or_test()\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py\", line 66, in train_or_test\n",
+ " results = self.trainer.train()\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\", line 462, in train\n",
+ " self.run_sanity_check(self.get_model())\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\", line 648, in run_sanity_check\n",
+ " _, eval_results = self.run_evaluation(test_mode=False, max_batches=self.num_sanity_val_batches)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\", line 568, in run_evaluation\n",
+ " output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/evaluation_loop.py\", line 171, in evaluation_step\n",
+ " output = self.trainer.accelerator_backend.validation_step(args)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\", line 76, in validation_step\n",
+ " output = self.__validation_step(args)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\", line 86, in __validation_step\n",
+ " output = self.trainer.model.validation_step(*args)\n",
+ " File \"\", line 30, in validation_step\n",
+ " return self.training_step(batch, batch_idx, phase='val')\n",
+ " File \"\", line 19, in training_step\n",
+ " y_dist, extra = self.forward(*batch)\n",
+ " File \"\", line 13, in forward\n",
+ " y_dist, extra = self._model(x_past, y_past, x_future, y_future)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 722, in _call_impl\n",
+ " result = self.forward(*input, **kwargs)\n",
+ " File \"/media/wassname/Storage5/projects2/3ST/seq2seq-time/seq2seq_time/models/transformer.py\", line 54, in forward\n",
+ " outputs = self.encoder(x, mask=mask#, src_key_padding_mask=x_key_padding_mask\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 722, in _call_impl\n",
+ " result = self.forward(*input, **kwargs)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/transformer.py\", line 181, in forward\n",
+ " output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 722, in _call_impl\n",
+ " result = self.forward(*input, **kwargs)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/transformer.py\", line 294, in forward\n",
+ " key_padding_mask=src_key_padding_mask)[0]\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 722, in _call_impl\n",
+ " result = self.forward(*input, **kwargs)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/activation.py\", line 927, in forward\n",
+ " attn_mask=attn_mask)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/functional.py\", line 4049, in multi_head_attention_forward\n",
+ " raise RuntimeError('The size of the 2D attn_mask is not correct.')\n",
+ "RuntimeError: The size of the 2D attn_mask is not correct.\n",
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "INFO:lightning:EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "INFO:lightning:EarlyStopping mode set to min for monitoring loss/val.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "GasSensor TransformerProcess\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "GPU available: True, used: True\n",
+ "INFO:lightning:GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning:TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "INFO:lightning:Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "----------------------------------------------\n",
+ "0 | _model | TransformerProcess | 48 K \n",
+ "INFO:lightning:\n",
+ " | Name | Type | Params\n",
+ "----------------------------------------------\n",
+ "0 | _model | TransformerProcess | 48 K \n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "28b3d26c7e0a43fdbca9ca48d385d47d",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 12: reducing learning rate of group 0 to 3.0000e-05.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='predict'), FloatProgress(value=0.0, max=231.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "GasSensor TransformerProcess\n",
+ "mean_NLL 0.63\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " RANP \n",
+ " LSTM \n",
+ " LSTMSeq2Seq \n",
+ " TransformerSeq2Seq \n",
+ " TransformerProcess \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " rmse \n",
+ " 0.257789 \n",
+ " 0.538145 \n",
+ " 0.546141 \n",
+ " 0.548482 \n",
+ " 0.579640 \n",
+ " 0.541376 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.285106 \n",
+ " 0.707343 \n",
+ " 0.719748 \n",
+ " 0.716110 \n",
+ " 0.769928 \n",
+ " 0.703052 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.633936 \n",
+ " 23.305540 \n",
+ " 19.444048 \n",
+ " 14.519804 \n",
+ " 46.982059 \n",
+ " 7.354424 \n",
+ " \n",
+ " \n",
+ " BejingPM25 \n",
+ " rmse \n",
+ " 1.383076 \n",
+ " 1.087366 \n",
+ " 1.032753 \n",
+ " 1.026793 \n",
+ " 1.165701 \n",
+ " 1.074965 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.278300 \n",
+ " 0.226529 \n",
+ " 0.212408 \n",
+ " 0.210056 \n",
+ " 0.232141 \n",
+ " 0.224733 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.707181 \n",
+ " 1.478179 \n",
+ " 1.407227 \n",
+ " 1.386666 \n",
+ " 2.859907 \n",
+ " 1.438450 \n",
+ " \n",
+ " \n",
+ " GasSensor \n",
+ " rmse \n",
+ " 41.511017 \n",
+ " 2.036231 \n",
+ " 10.486523 \n",
+ " 2.139018 \n",
+ " NaN \n",
+ " 20.294767 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 1.292102 \n",
+ " 0.300914 \n",
+ " 0.562990 \n",
+ " 0.337733 \n",
+ " NaN \n",
+ " 0.626682 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.879850 \n",
+ " -2.236591 \n",
+ " 16.395840 \n",
+ " -1.527772 \n",
+ " NaN \n",
+ " 0.632958 \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast RANP LSTM LSTMSeq2Seq \\\n",
+ "IMOSCurrentsVel rmse 0.257789 0.538145 0.546141 0.548482 \n",
+ " smape 0.285106 0.707343 0.719748 0.716110 \n",
+ " nll 1.633936 23.305540 19.444048 14.519804 \n",
+ "BejingPM25 rmse 1.383076 1.087366 1.032753 1.026793 \n",
+ " smape 0.278300 0.226529 0.212408 0.210056 \n",
+ " nll 1.707181 1.478179 1.407227 1.386666 \n",
+ "GasSensor rmse 41.511017 2.036231 10.486523 2.139018 \n",
+ " smape 1.292102 0.300914 0.562990 0.337733 \n",
+ " nll 1.879850 -2.236591 16.395840 -1.527772 \n",
+ "\n",
+ " TransformerSeq2Seq TransformerProcess \n",
+ "IMOSCurrentsVel rmse 0.579640 0.541376 \n",
+ " smape 0.769928 0.703052 \n",
+ " nll 46.982059 7.354424 \n",
+ "BejingPM25 rmse 1.165701 1.074965 \n",
+ " smape 0.232141 0.224733 \n",
+ " nll 2.859907 1.438450 \n",
+ "GasSensor rmse NaN 20.294767 \n",
+ " smape NaN 0.626682 \n",
+ " nll NaN 0.632958 "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "INFO:lightning:EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "INFO:lightning:EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "INFO:lightning:GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning:TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "INFO:lightning:Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "----------------------------------------\n",
+ "0 | _model | BaselineLast | 1 \n",
+ "INFO:lightning:\n",
+ " | Name | Type | Params\n",
+ "----------------------------------------\n",
+ "0 | _model | BaselineLast | 1 \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "AppliancesEnergyPrediction BaselineLast\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "ce16f59cf56445dbb7254e74a2da3766",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 8: reducing learning rate of group 0 to 3.0000e-05.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='predict'), FloatProgress(value=0.0, max=15.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "AppliancesEnergyPrediction BaselineLast\n",
+ "mean_NLL 1.56\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " RANP \n",
+ " LSTM \n",
+ " LSTMSeq2Seq \n",
+ " TransformerSeq2Seq \n",
+ " TransformerProcess \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " rmse \n",
+ " 0.257789 \n",
+ " 0.538145 \n",
+ " 0.546141 \n",
+ " 0.548482 \n",
+ " 0.579640 \n",
+ " 0.541376 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.285106 \n",
+ " 0.707343 \n",
+ " 0.719748 \n",
+ " 0.716110 \n",
+ " 0.769928 \n",
+ " 0.703052 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.633936 \n",
+ " 23.305540 \n",
+ " 19.444048 \n",
+ " 14.519804 \n",
+ " 46.982059 \n",
+ " 7.354424 \n",
+ " \n",
+ " \n",
+ " BejingPM25 \n",
+ " rmse \n",
+ " 1.383076 \n",
+ " 1.087366 \n",
+ " 1.032753 \n",
+ " 1.026793 \n",
+ " 1.165701 \n",
+ " 1.074965 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.278300 \n",
+ " 0.226529 \n",
+ " 0.212408 \n",
+ " 0.210056 \n",
+ " 0.232141 \n",
+ " 0.224733 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.707181 \n",
+ " 1.478179 \n",
+ " 1.407227 \n",
+ " 1.386666 \n",
+ " 2.859907 \n",
+ " 1.438450 \n",
+ " \n",
+ " \n",
+ " GasSensor \n",
+ " rmse \n",
+ " 41.511017 \n",
+ " 2.036231 \n",
+ " 10.486523 \n",
+ " 2.139018 \n",
+ " NaN \n",
+ " 20.294767 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 1.292102 \n",
+ " 0.300914 \n",
+ " 0.562990 \n",
+ " 0.337733 \n",
+ " NaN \n",
+ " 0.626682 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.879850 \n",
+ " -2.236591 \n",
+ " 16.395840 \n",
+ " -1.527772 \n",
+ " NaN \n",
+ " 0.632958 \n",
+ " \n",
+ " \n",
+ " AppliancesEnergyPrediction \n",
+ " nll \n",
+ " 1.556087 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " rmse \n",
+ " 0.749960 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.118430 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast RANP LSTM \\\n",
+ "IMOSCurrentsVel rmse 0.257789 0.538145 0.546141 \n",
+ " smape 0.285106 0.707343 0.719748 \n",
+ " nll 1.633936 23.305540 19.444048 \n",
+ "BejingPM25 rmse 1.383076 1.087366 1.032753 \n",
+ " smape 0.278300 0.226529 0.212408 \n",
+ " nll 1.707181 1.478179 1.407227 \n",
+ "GasSensor rmse 41.511017 2.036231 10.486523 \n",
+ " smape 1.292102 0.300914 0.562990 \n",
+ " nll 1.879850 -2.236591 16.395840 \n",
+ "AppliancesEnergyPrediction nll 1.556087 NaN NaN \n",
+ " rmse 0.749960 NaN NaN \n",
+ " smape 0.118430 NaN NaN \n",
+ "\n",
+ " LSTMSeq2Seq TransformerSeq2Seq \\\n",
+ "IMOSCurrentsVel rmse 0.548482 0.579640 \n",
+ " smape 0.716110 0.769928 \n",
+ " nll 14.519804 46.982059 \n",
+ "BejingPM25 rmse 1.026793 1.165701 \n",
+ " smape 0.210056 0.232141 \n",
+ " nll 1.386666 2.859907 \n",
+ "GasSensor rmse 2.139018 NaN \n",
+ " smape 0.337733 NaN \n",
+ " nll -1.527772 NaN \n",
+ "AppliancesEnergyPrediction nll NaN NaN \n",
+ " rmse NaN NaN \n",
+ " smape NaN NaN \n",
+ "\n",
+ " TransformerProcess \n",
+ "IMOSCurrentsVel rmse 0.541376 \n",
+ " smape 0.703052 \n",
+ " nll 7.354424 \n",
+ "BejingPM25 rmse 1.074965 \n",
+ " smape 0.224733 \n",
+ " nll 1.438450 \n",
+ "GasSensor rmse 20.294767 \n",
+ " smape 0.626682 \n",
+ " nll 0.632958 \n",
+ "AppliancesEnergyPrediction nll NaN \n",
+ " rmse NaN \n",
+ " smape NaN "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "INFO:lightning:EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "INFO:lightning:EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "INFO:lightning:GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning:TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "INFO:lightning:Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "--------------------------------\n",
+ "0 | _model | RANP | 110 K \n",
+ "INFO:lightning:\n",
+ " | Name | Type | Params\n",
+ "--------------------------------\n",
+ "0 | _model | RANP | 110 K \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "AppliancesEnergyPrediction RANP\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "8a20668730674d8b8fb61026424dde6c",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 7: reducing learning rate of group 0 to 3.0000e-05.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='predict'), FloatProgress(value=0.0, max=15.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "AppliancesEnergyPrediction RANP\n",
+ "mean_NLL 1.31\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " RANP \n",
+ " LSTM \n",
+ " LSTMSeq2Seq \n",
+ " TransformerSeq2Seq \n",
+ " TransformerProcess \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " rmse \n",
+ " 0.257789 \n",
+ " 0.538145 \n",
+ " 0.546141 \n",
+ " 0.548482 \n",
+ " 0.579640 \n",
+ " 0.541376 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.285106 \n",
+ " 0.707343 \n",
+ " 0.719748 \n",
+ " 0.716110 \n",
+ " 0.769928 \n",
+ " 0.703052 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.633936 \n",
+ " 23.305540 \n",
+ " 19.444048 \n",
+ " 14.519804 \n",
+ " 46.982059 \n",
+ " 7.354424 \n",
+ " \n",
+ " \n",
+ " BejingPM25 \n",
+ " rmse \n",
+ " 1.383076 \n",
+ " 1.087366 \n",
+ " 1.032753 \n",
+ " 1.026793 \n",
+ " 1.165701 \n",
+ " 1.074965 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.278300 \n",
+ " 0.226529 \n",
+ " 0.212408 \n",
+ " 0.210056 \n",
+ " 0.232141 \n",
+ " 0.224733 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.707181 \n",
+ " 1.478179 \n",
+ " 1.407227 \n",
+ " 1.386666 \n",
+ " 2.859907 \n",
+ " 1.438450 \n",
+ " \n",
+ " \n",
+ " GasSensor \n",
+ " rmse \n",
+ " 41.511017 \n",
+ " 2.036231 \n",
+ " 10.486523 \n",
+ " 2.139018 \n",
+ " NaN \n",
+ " 20.294767 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 1.292102 \n",
+ " 0.300914 \n",
+ " 0.562990 \n",
+ " 0.337733 \n",
+ " NaN \n",
+ " 0.626682 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.879850 \n",
+ " -2.236591 \n",
+ " 16.395840 \n",
+ " -1.527772 \n",
+ " NaN \n",
+ " 0.632958 \n",
+ " \n",
+ " \n",
+ " AppliancesEnergyPrediction \n",
+ " rmse \n",
+ " 0.749960 \n",
+ " 0.562423 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.118430 \n",
+ " 0.089899 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.556087 \n",
+ " 1.306983 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast RANP LSTM \\\n",
+ "IMOSCurrentsVel rmse 0.257789 0.538145 0.546141 \n",
+ " smape 0.285106 0.707343 0.719748 \n",
+ " nll 1.633936 23.305540 19.444048 \n",
+ "BejingPM25 rmse 1.383076 1.087366 1.032753 \n",
+ " smape 0.278300 0.226529 0.212408 \n",
+ " nll 1.707181 1.478179 1.407227 \n",
+ "GasSensor rmse 41.511017 2.036231 10.486523 \n",
+ " smape 1.292102 0.300914 0.562990 \n",
+ " nll 1.879850 -2.236591 16.395840 \n",
+ "AppliancesEnergyPrediction rmse 0.749960 0.562423 NaN \n",
+ " smape 0.118430 0.089899 NaN \n",
+ " nll 1.556087 1.306983 NaN \n",
+ "\n",
+ " LSTMSeq2Seq TransformerSeq2Seq \\\n",
+ "IMOSCurrentsVel rmse 0.548482 0.579640 \n",
+ " smape 0.716110 0.769928 \n",
+ " nll 14.519804 46.982059 \n",
+ "BejingPM25 rmse 1.026793 1.165701 \n",
+ " smape 0.210056 0.232141 \n",
+ " nll 1.386666 2.859907 \n",
+ "GasSensor rmse 2.139018 NaN \n",
+ " smape 0.337733 NaN \n",
+ " nll -1.527772 NaN \n",
+ "AppliancesEnergyPrediction rmse NaN NaN \n",
+ " smape NaN NaN \n",
+ " nll NaN NaN \n",
+ "\n",
+ " TransformerProcess \n",
+ "IMOSCurrentsVel rmse 0.541376 \n",
+ " smape 0.703052 \n",
+ " nll 7.354424 \n",
+ "BejingPM25 rmse 1.074965 \n",
+ " smape 0.224733 \n",
+ " nll 1.438450 \n",
+ "GasSensor rmse 20.294767 \n",
+ " smape 0.626682 \n",
+ " nll 0.632958 \n",
+ "AppliancesEnergyPrediction rmse NaN \n",
+ " smape NaN \n",
+ " nll NaN "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "INFO:lightning:EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "INFO:lightning:EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "INFO:lightning:GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning:TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "INFO:lightning:Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "--------------------------------\n",
+ "0 | _model | LSTM | 141 K \n",
+ "INFO:lightning:\n",
+ " | Name | Type | Params\n",
+ "--------------------------------\n",
+ "0 | _model | LSTM | 141 K \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "AppliancesEnergyPrediction LSTM\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "7177f658e01a4d85be3d9af4839f030c",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 6: reducing learning rate of group 0 to 3.0000e-05.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='predict'), FloatProgress(value=0.0, max=15.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "AppliancesEnergyPrediction LSTM\n",
+ "mean_NLL 1.94\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " RANP \n",
+ " LSTM \n",
+ " LSTMSeq2Seq \n",
+ " TransformerSeq2Seq \n",
+ " TransformerProcess \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " rmse \n",
+ " 0.257789 \n",
+ " 0.538145 \n",
+ " 0.546141 \n",
+ " 0.548482 \n",
+ " 0.579640 \n",
+ " 0.541376 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.285106 \n",
+ " 0.707343 \n",
+ " 0.719748 \n",
+ " 0.716110 \n",
+ " 0.769928 \n",
+ " 0.703052 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.633936 \n",
+ " 23.305540 \n",
+ " 19.444048 \n",
+ " 14.519804 \n",
+ " 46.982059 \n",
+ " 7.354424 \n",
+ " \n",
+ " \n",
+ " BejingPM25 \n",
+ " rmse \n",
+ " 1.383076 \n",
+ " 1.087366 \n",
+ " 1.032753 \n",
+ " 1.026793 \n",
+ " 1.165701 \n",
+ " 1.074965 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.278300 \n",
+ " 0.226529 \n",
+ " 0.212408 \n",
+ " 0.210056 \n",
+ " 0.232141 \n",
+ " 0.224733 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.707181 \n",
+ " 1.478179 \n",
+ " 1.407227 \n",
+ " 1.386666 \n",
+ " 2.859907 \n",
+ " 1.438450 \n",
+ " \n",
+ " \n",
+ " GasSensor \n",
+ " rmse \n",
+ " 41.511017 \n",
+ " 2.036231 \n",
+ " 10.486523 \n",
+ " 2.139018 \n",
+ " NaN \n",
+ " 20.294767 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 1.292102 \n",
+ " 0.300914 \n",
+ " 0.562990 \n",
+ " 0.337733 \n",
+ " NaN \n",
+ " 0.626682 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.879850 \n",
+ " -2.236591 \n",
+ " 16.395840 \n",
+ " -1.527772 \n",
+ " NaN \n",
+ " 0.632958 \n",
+ " \n",
+ " \n",
+ " AppliancesEnergyPrediction \n",
+ " rmse \n",
+ " 0.749960 \n",
+ " 0.562423 \n",
+ " 0.569487 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.118430 \n",
+ " 0.089899 \n",
+ " 0.088473 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.556087 \n",
+ " 1.306983 \n",
+ " 1.939970 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast RANP LSTM \\\n",
+ "IMOSCurrentsVel rmse 0.257789 0.538145 0.546141 \n",
+ " smape 0.285106 0.707343 0.719748 \n",
+ " nll 1.633936 23.305540 19.444048 \n",
+ "BejingPM25 rmse 1.383076 1.087366 1.032753 \n",
+ " smape 0.278300 0.226529 0.212408 \n",
+ " nll 1.707181 1.478179 1.407227 \n",
+ "GasSensor rmse 41.511017 2.036231 10.486523 \n",
+ " smape 1.292102 0.300914 0.562990 \n",
+ " nll 1.879850 -2.236591 16.395840 \n",
+ "AppliancesEnergyPrediction rmse 0.749960 0.562423 0.569487 \n",
+ " smape 0.118430 0.089899 0.088473 \n",
+ " nll 1.556087 1.306983 1.939970 \n",
+ "\n",
+ " LSTMSeq2Seq TransformerSeq2Seq \\\n",
+ "IMOSCurrentsVel rmse 0.548482 0.579640 \n",
+ " smape 0.716110 0.769928 \n",
+ " nll 14.519804 46.982059 \n",
+ "BejingPM25 rmse 1.026793 1.165701 \n",
+ " smape 0.210056 0.232141 \n",
+ " nll 1.386666 2.859907 \n",
+ "GasSensor rmse 2.139018 NaN \n",
+ " smape 0.337733 NaN \n",
+ " nll -1.527772 NaN \n",
+ "AppliancesEnergyPrediction rmse NaN NaN \n",
+ " smape NaN NaN \n",
+ " nll NaN NaN \n",
+ "\n",
+ " TransformerProcess \n",
+ "IMOSCurrentsVel rmse 0.541376 \n",
+ " smape 0.703052 \n",
+ " nll 7.354424 \n",
+ "BejingPM25 rmse 1.074965 \n",
+ " smape 0.224733 \n",
+ " nll 1.438450 \n",
+ "GasSensor rmse 20.294767 \n",
+ " smape 0.626682 \n",
+ " nll 0.632958 \n",
+ "AppliancesEnergyPrediction rmse NaN \n",
+ " smape NaN \n",
+ " nll NaN "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "INFO:lightning:EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "INFO:lightning:EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "INFO:lightning:GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning:TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "INFO:lightning:Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "---------------------------------------\n",
+ "0 | _model | LSTMSeq2Seq | 118 K \n",
+ "INFO:lightning:\n",
+ " | Name | Type | Params\n",
+ "---------------------------------------\n",
+ "0 | _model | LSTMSeq2Seq | 118 K \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "AppliancesEnergyPrediction LSTMSeq2Seq\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "52bddda129194ffa84dbe179d32a54e3",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 5: reducing learning rate of group 0 to 3.0000e-05.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='predict'), FloatProgress(value=0.0, max=15.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "AppliancesEnergyPrediction LSTMSeq2Seq\n",
+ "mean_NLL 1.57\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " RANP \n",
+ " LSTM \n",
+ " LSTMSeq2Seq \n",
+ " TransformerSeq2Seq \n",
+ " TransformerProcess \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " rmse \n",
+ " 0.257789 \n",
+ " 0.538145 \n",
+ " 0.546141 \n",
+ " 0.548482 \n",
+ " 0.579640 \n",
+ " 0.541376 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.285106 \n",
+ " 0.707343 \n",
+ " 0.719748 \n",
+ " 0.716110 \n",
+ " 0.769928 \n",
+ " 0.703052 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.633936 \n",
+ " 23.305540 \n",
+ " 19.444048 \n",
+ " 14.519804 \n",
+ " 46.982059 \n",
+ " 7.354424 \n",
+ " \n",
+ " \n",
+ " BejingPM25 \n",
+ " rmse \n",
+ " 1.383076 \n",
+ " 1.087366 \n",
+ " 1.032753 \n",
+ " 1.026793 \n",
+ " 1.165701 \n",
+ " 1.074965 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.278300 \n",
+ " 0.226529 \n",
+ " 0.212408 \n",
+ " 0.210056 \n",
+ " 0.232141 \n",
+ " 0.224733 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.707181 \n",
+ " 1.478179 \n",
+ " 1.407227 \n",
+ " 1.386666 \n",
+ " 2.859907 \n",
+ " 1.438450 \n",
+ " \n",
+ " \n",
+ " GasSensor \n",
+ " rmse \n",
+ " 41.511017 \n",
+ " 2.036231 \n",
+ " 10.486523 \n",
+ " 2.139018 \n",
+ " NaN \n",
+ " 20.294767 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 1.292102 \n",
+ " 0.300914 \n",
+ " 0.562990 \n",
+ " 0.337733 \n",
+ " NaN \n",
+ " 0.626682 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.879850 \n",
+ " -2.236591 \n",
+ " 16.395840 \n",
+ " -1.527772 \n",
+ " NaN \n",
+ " 0.632958 \n",
+ " \n",
+ " \n",
+ " AppliancesEnergyPrediction \n",
+ " rmse \n",
+ " 0.749960 \n",
+ " 0.562423 \n",
+ " 0.569487 \n",
+ " 0.543795 \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.118430 \n",
+ " 0.089899 \n",
+ " 0.088473 \n",
+ " 0.084181 \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.556087 \n",
+ " 1.306983 \n",
+ " 1.939970 \n",
+ " 1.566483 \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast RANP LSTM \\\n",
+ "IMOSCurrentsVel rmse 0.257789 0.538145 0.546141 \n",
+ " smape 0.285106 0.707343 0.719748 \n",
+ " nll 1.633936 23.305540 19.444048 \n",
+ "BejingPM25 rmse 1.383076 1.087366 1.032753 \n",
+ " smape 0.278300 0.226529 0.212408 \n",
+ " nll 1.707181 1.478179 1.407227 \n",
+ "GasSensor rmse 41.511017 2.036231 10.486523 \n",
+ " smape 1.292102 0.300914 0.562990 \n",
+ " nll 1.879850 -2.236591 16.395840 \n",
+ "AppliancesEnergyPrediction rmse 0.749960 0.562423 0.569487 \n",
+ " smape 0.118430 0.089899 0.088473 \n",
+ " nll 1.556087 1.306983 1.939970 \n",
+ "\n",
+ " LSTMSeq2Seq TransformerSeq2Seq \\\n",
+ "IMOSCurrentsVel rmse 0.548482 0.579640 \n",
+ " smape 0.716110 0.769928 \n",
+ " nll 14.519804 46.982059 \n",
+ "BejingPM25 rmse 1.026793 1.165701 \n",
+ " smape 0.210056 0.232141 \n",
+ " nll 1.386666 2.859907 \n",
+ "GasSensor rmse 2.139018 NaN \n",
+ " smape 0.337733 NaN \n",
+ " nll -1.527772 NaN \n",
+ "AppliancesEnergyPrediction rmse 0.543795 NaN \n",
+ " smape 0.084181 NaN \n",
+ " nll 1.566483 NaN \n",
+ "\n",
+ " TransformerProcess \n",
+ "IMOSCurrentsVel rmse 0.541376 \n",
+ " smape 0.703052 \n",
+ " nll 7.354424 \n",
+ "BejingPM25 rmse 1.074965 \n",
+ " smape 0.224733 \n",
+ " nll 1.438450 \n",
+ "GasSensor rmse 20.294767 \n",
+ " smape 0.626682 \n",
+ " nll 0.632958 \n",
+ "AppliancesEnergyPrediction rmse NaN \n",
+ " smape NaN \n",
+ " nll NaN "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "INFO:lightning:EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "INFO:lightning:EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "INFO:lightning:GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning:TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "INFO:lightning:Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "----------------------------------------------\n",
+ "0 | _model | TransformerSeq2Seq | 2 M \n",
+ "INFO:lightning:\n",
+ " | Name | Type | Params\n",
+ "----------------------------------------------\n",
+ "0 | _model | TransformerSeq2Seq | 2 M \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "AppliancesEnergyPrediction TransformerSeq2Seq\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "8d1332c9d2064912a47a039346c20bcf",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 4: reducing learning rate of group 0 to 3.0000e-05.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='predict'), FloatProgress(value=0.0, max=15.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "AppliancesEnergyPrediction TransformerSeq2Seq\n",
+ "mean_NLL 2.33\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " RANP \n",
+ " LSTM \n",
+ " LSTMSeq2Seq \n",
+ " TransformerSeq2Seq \n",
+ " TransformerProcess \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " rmse \n",
+ " 0.257789 \n",
+ " 0.538145 \n",
+ " 0.546141 \n",
+ " 0.548482 \n",
+ " 0.579640 \n",
+ " 0.541376 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.285106 \n",
+ " 0.707343 \n",
+ " 0.719748 \n",
+ " 0.716110 \n",
+ " 0.769928 \n",
+ " 0.703052 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.633936 \n",
+ " 23.305540 \n",
+ " 19.444048 \n",
+ " 14.519804 \n",
+ " 46.982059 \n",
+ " 7.354424 \n",
+ " \n",
+ " \n",
+ " BejingPM25 \n",
+ " rmse \n",
+ " 1.383076 \n",
+ " 1.087366 \n",
+ " 1.032753 \n",
+ " 1.026793 \n",
+ " 1.165701 \n",
+ " 1.074965 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.278300 \n",
+ " 0.226529 \n",
+ " 0.212408 \n",
+ " 0.210056 \n",
+ " 0.232141 \n",
+ " 0.224733 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.707181 \n",
+ " 1.478179 \n",
+ " 1.407227 \n",
+ " 1.386666 \n",
+ " 2.859907 \n",
+ " 1.438450 \n",
+ " \n",
+ " \n",
+ " GasSensor \n",
+ " rmse \n",
+ " 41.511017 \n",
+ " 2.036231 \n",
+ " 10.486523 \n",
+ " 2.139018 \n",
+ " NaN \n",
+ " 20.294767 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 1.292102 \n",
+ " 0.300914 \n",
+ " 0.562990 \n",
+ " 0.337733 \n",
+ " NaN \n",
+ " 0.626682 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.879850 \n",
+ " -2.236591 \n",
+ " 16.395840 \n",
+ " -1.527772 \n",
+ " NaN \n",
+ " 0.632958 \n",
+ " \n",
+ " \n",
+ " AppliancesEnergyPrediction \n",
+ " rmse \n",
+ " 0.749960 \n",
+ " 0.562423 \n",
+ " 0.569487 \n",
+ " 0.543795 \n",
+ " 0.566450 \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.118430 \n",
+ " 0.089899 \n",
+ " 0.088473 \n",
+ " 0.084181 \n",
+ " 0.091806 \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.556087 \n",
+ " 1.306983 \n",
+ " 1.939970 \n",
+ " 1.566483 \n",
+ " 2.329622 \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast RANP LSTM \\\n",
+ "IMOSCurrentsVel rmse 0.257789 0.538145 0.546141 \n",
+ " smape 0.285106 0.707343 0.719748 \n",
+ " nll 1.633936 23.305540 19.444048 \n",
+ "BejingPM25 rmse 1.383076 1.087366 1.032753 \n",
+ " smape 0.278300 0.226529 0.212408 \n",
+ " nll 1.707181 1.478179 1.407227 \n",
+ "GasSensor rmse 41.511017 2.036231 10.486523 \n",
+ " smape 1.292102 0.300914 0.562990 \n",
+ " nll 1.879850 -2.236591 16.395840 \n",
+ "AppliancesEnergyPrediction rmse 0.749960 0.562423 0.569487 \n",
+ " smape 0.118430 0.089899 0.088473 \n",
+ " nll 1.556087 1.306983 1.939970 \n",
+ "\n",
+ " LSTMSeq2Seq TransformerSeq2Seq \\\n",
+ "IMOSCurrentsVel rmse 0.548482 0.579640 \n",
+ " smape 0.716110 0.769928 \n",
+ " nll 14.519804 46.982059 \n",
+ "BejingPM25 rmse 1.026793 1.165701 \n",
+ " smape 0.210056 0.232141 \n",
+ " nll 1.386666 2.859907 \n",
+ "GasSensor rmse 2.139018 NaN \n",
+ " smape 0.337733 NaN \n",
+ " nll -1.527772 NaN \n",
+ "AppliancesEnergyPrediction rmse 0.543795 0.566450 \n",
+ " smape 0.084181 0.091806 \n",
+ " nll 1.566483 2.329622 \n",
+ "\n",
+ " TransformerProcess \n",
+ "IMOSCurrentsVel rmse 0.541376 \n",
+ " smape 0.703052 \n",
+ " nll 7.354424 \n",
+ "BejingPM25 rmse 1.074965 \n",
+ " smape 0.224733 \n",
+ " nll 1.438450 \n",
+ "GasSensor rmse 20.294767 \n",
+ " smape 0.626682 \n",
+ " nll 0.632958 \n",
+ "AppliancesEnergyPrediction rmse NaN \n",
+ " smape NaN \n",
+ " nll NaN "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "INFO:lightning:EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "INFO:lightning:EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "INFO:lightning:GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning:TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "INFO:lightning:Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "---------------------------------------\n",
+ "0 | _model | Transformer | 2 M \n",
+ "INFO:lightning:\n",
+ " | Name | Type | Params\n",
+ "---------------------------------------\n",
+ "0 | _model | Transformer | 2 M \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "AppliancesEnergyPrediction Transformer\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "0d64490ed26f489da06d0c665430b5f8",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "ERROR:root:failed to run model\n",
+ "Traceback (most recent call last):\n",
+ " File \"\", line 52, in \n",
+ " trainer.fit(model, dl_train, dl_test)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\", line 440, in fit\n",
+ " results = self.accelerator_backend.train()\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\", line 54, in train\n",
+ " results = self.train_or_test()\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py\", line 66, in train_or_test\n",
+ " results = self.trainer.train()\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\", line 462, in train\n",
+ " self.run_sanity_check(self.get_model())\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\", line 648, in run_sanity_check\n",
+ " _, eval_results = self.run_evaluation(test_mode=False, max_batches=self.num_sanity_val_batches)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\", line 568, in run_evaluation\n",
+ " output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/evaluation_loop.py\", line 171, in evaluation_step\n",
+ " output = self.trainer.accelerator_backend.validation_step(args)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\", line 76, in validation_step\n",
+ " output = self.__validation_step(args)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\", line 86, in __validation_step\n",
+ " output = self.trainer.model.validation_step(*args)\n",
+ " File \"\", line 30, in validation_step\n",
+ " return self.training_step(batch, batch_idx, phase='val')\n",
+ " File \"\", line 19, in training_step\n",
+ " y_dist, extra = self.forward(*batch)\n",
+ " File \"\", line 13, in forward\n",
+ " y_dist, extra = self._model(x_past, y_past, x_future, y_future)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 722, in _call_impl\n",
+ " result = self.forward(*input, **kwargs)\n",
+ " File \"/media/wassname/Storage5/projects2/3ST/seq2seq-time/seq2seq_time/models/transformer.py\", line 54, in forward\n",
+ " outputs = self.encoder(x, mask=mask#, src_key_padding_mask=x_key_padding_mask\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 722, in _call_impl\n",
+ " result = self.forward(*input, **kwargs)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/transformer.py\", line 181, in forward\n",
+ " output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 722, in _call_impl\n",
+ " result = self.forward(*input, **kwargs)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/transformer.py\", line 294, in forward\n",
+ " key_padding_mask=src_key_padding_mask)[0]\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 722, in _call_impl\n",
+ " result = self.forward(*input, **kwargs)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/activation.py\", line 927, in forward\n",
+ " attn_mask=attn_mask)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/functional.py\", line 4049, in multi_head_attention_forward\n",
+ " raise RuntimeError('The size of the 2D attn_mask is not correct.')\n",
+ "RuntimeError: The size of the 2D attn_mask is not correct.\n",
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "INFO:lightning:EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "INFO:lightning:EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "INFO:lightning:GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning:TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "INFO:lightning:Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "----------------------------------------------\n",
+ "0 | _model | TransformerProcess | 49 K \n",
+ "INFO:lightning:\n",
+ " | Name | Type | Params\n",
+ "----------------------------------------------\n",
+ "0 | _model | TransformerProcess | 49 K \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "AppliancesEnergyPrediction TransformerProcess\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "b3610174160a4dcd953e8601f8ea3c59",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 9: reducing learning rate of group 0 to 3.0000e-05.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='predict'), FloatProgress(value=0.0, max=15.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "AppliancesEnergyPrediction TransformerProcess\n",
+ "mean_NLL 1.08\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " RANP \n",
+ " LSTM \n",
+ " LSTMSeq2Seq \n",
+ " TransformerSeq2Seq \n",
+ " TransformerProcess \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " rmse \n",
+ " 0.257789 \n",
+ " 0.538145 \n",
+ " 0.546141 \n",
+ " 0.548482 \n",
+ " 0.579640 \n",
+ " 0.541376 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.285106 \n",
+ " 0.707343 \n",
+ " 0.719748 \n",
+ " 0.716110 \n",
+ " 0.769928 \n",
+ " 0.703052 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.633936 \n",
+ " 23.305540 \n",
+ " 19.444048 \n",
+ " 14.519804 \n",
+ " 46.982059 \n",
+ " 7.354424 \n",
+ " \n",
+ " \n",
+ " BejingPM25 \n",
+ " rmse \n",
+ " 1.383076 \n",
+ " 1.087366 \n",
+ " 1.032753 \n",
+ " 1.026793 \n",
+ " 1.165701 \n",
+ " 1.074965 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.278300 \n",
+ " 0.226529 \n",
+ " 0.212408 \n",
+ " 0.210056 \n",
+ " 0.232141 \n",
+ " 0.224733 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.707181 \n",
+ " 1.478179 \n",
+ " 1.407227 \n",
+ " 1.386666 \n",
+ " 2.859907 \n",
+ " 1.438450 \n",
+ " \n",
+ " \n",
+ " GasSensor \n",
+ " rmse \n",
+ " 41.511017 \n",
+ " 2.036231 \n",
+ " 10.486523 \n",
+ " 2.139018 \n",
+ " NaN \n",
+ " 20.294767 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 1.292102 \n",
+ " 0.300914 \n",
+ " 0.562990 \n",
+ " 0.337733 \n",
+ " NaN \n",
+ " 0.626682 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.879850 \n",
+ " -2.236591 \n",
+ " 16.395840 \n",
+ " -1.527772 \n",
+ " NaN \n",
+ " 0.632958 \n",
+ " \n",
+ " \n",
+ " AppliancesEnergyPrediction \n",
+ " rmse \n",
+ " 0.749960 \n",
+ " 0.562423 \n",
+ " 0.569487 \n",
+ " 0.543795 \n",
+ " 0.566450 \n",
+ " 0.536980 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.118430 \n",
+ " 0.089899 \n",
+ " 0.088473 \n",
+ " 0.084181 \n",
+ " 0.091806 \n",
+ " 0.088910 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.556087 \n",
+ " 1.306983 \n",
+ " 1.939970 \n",
+ " 1.566483 \n",
+ " 2.329622 \n",
+ " 1.080932 \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast RANP LSTM \\\n",
+ "IMOSCurrentsVel rmse 0.257789 0.538145 0.546141 \n",
+ " smape 0.285106 0.707343 0.719748 \n",
+ " nll 1.633936 23.305540 19.444048 \n",
+ "BejingPM25 rmse 1.383076 1.087366 1.032753 \n",
+ " smape 0.278300 0.226529 0.212408 \n",
+ " nll 1.707181 1.478179 1.407227 \n",
+ "GasSensor rmse 41.511017 2.036231 10.486523 \n",
+ " smape 1.292102 0.300914 0.562990 \n",
+ " nll 1.879850 -2.236591 16.395840 \n",
+ "AppliancesEnergyPrediction rmse 0.749960 0.562423 0.569487 \n",
+ " smape 0.118430 0.089899 0.088473 \n",
+ " nll 1.556087 1.306983 1.939970 \n",
+ "\n",
+ " LSTMSeq2Seq TransformerSeq2Seq \\\n",
+ "IMOSCurrentsVel rmse 0.548482 0.579640 \n",
+ " smape 0.716110 0.769928 \n",
+ " nll 14.519804 46.982059 \n",
+ "BejingPM25 rmse 1.026793 1.165701 \n",
+ " smape 0.210056 0.232141 \n",
+ " nll 1.386666 2.859907 \n",
+ "GasSensor rmse 2.139018 NaN \n",
+ " smape 0.337733 NaN \n",
+ " nll -1.527772 NaN \n",
+ "AppliancesEnergyPrediction rmse 0.543795 0.566450 \n",
+ " smape 0.084181 0.091806 \n",
+ " nll 1.566483 2.329622 \n",
+ "\n",
+ " TransformerProcess \n",
+ "IMOSCurrentsVel rmse 0.541376 \n",
+ " smape 0.703052 \n",
+ " nll 7.354424 \n",
+ "BejingPM25 rmse 1.074965 \n",
+ " smape 0.224733 \n",
+ " nll 1.438450 \n",
+ "GasSensor rmse 20.294767 \n",
+ " smape 0.626682 \n",
+ " nll 0.632958 \n",
+ "AppliancesEnergyPrediction rmse 0.536980 \n",
+ " smape 0.088910 \n",
+ " nll 1.080932 "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "INFO:lightning:EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "INFO:lightning:EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "INFO:lightning:GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning:TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "INFO:lightning:Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "----------------------------------------\n",
+ "0 | _model | BaselineLast | 1 \n",
+ "INFO:lightning:\n",
+ " | Name | Type | Params\n",
+ "----------------------------------------\n",
+ "0 | _model | BaselineLast | 1 \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "MetroInterstateTraffic BaselineLast\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "2cdbe4db46e342c4a4c66b7ab957da60",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 16: reducing learning rate of group 0 to 3.0000e-05.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 20: reducing learning rate of group 0 to 3.0000e-06.\n",
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='predict'), FloatProgress(value=0.0, max=41.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "MetroInterstateTraffic BaselineLast\n",
+ "mean_NLL 1.76\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " RANP \n",
+ " LSTM \n",
+ " LSTMSeq2Seq \n",
+ " TransformerSeq2Seq \n",
+ " TransformerProcess \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " rmse \n",
+ " 0.257789 \n",
+ " 0.538145 \n",
+ " 0.546141 \n",
+ " 0.548482 \n",
+ " 0.579640 \n",
+ " 0.541376 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.285106 \n",
+ " 0.707343 \n",
+ " 0.719748 \n",
+ " 0.716110 \n",
+ " 0.769928 \n",
+ " 0.703052 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.633936 \n",
+ " 23.305540 \n",
+ " 19.444048 \n",
+ " 14.519804 \n",
+ " 46.982059 \n",
+ " 7.354424 \n",
+ " \n",
+ " \n",
+ " BejingPM25 \n",
+ " rmse \n",
+ " 1.383076 \n",
+ " 1.087366 \n",
+ " 1.032753 \n",
+ " 1.026793 \n",
+ " 1.165701 \n",
+ " 1.074965 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.278300 \n",
+ " 0.226529 \n",
+ " 0.212408 \n",
+ " 0.210056 \n",
+ " 0.232141 \n",
+ " 0.224733 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.707181 \n",
+ " 1.478179 \n",
+ " 1.407227 \n",
+ " 1.386666 \n",
+ " 2.859907 \n",
+ " 1.438450 \n",
+ " \n",
+ " \n",
+ " GasSensor \n",
+ " rmse \n",
+ " 41.511017 \n",
+ " 2.036231 \n",
+ " 10.486523 \n",
+ " 2.139018 \n",
+ " NaN \n",
+ " 20.294767 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 1.292102 \n",
+ " 0.300914 \n",
+ " 0.562990 \n",
+ " 0.337733 \n",
+ " NaN \n",
+ " 0.626682 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.879850 \n",
+ " -2.236591 \n",
+ " 16.395840 \n",
+ " -1.527772 \n",
+ " NaN \n",
+ " 0.632958 \n",
+ " \n",
+ " \n",
+ " AppliancesEnergyPrediction \n",
+ " rmse \n",
+ " 0.749960 \n",
+ " 0.562423 \n",
+ " 0.569487 \n",
+ " 0.543795 \n",
+ " 0.566450 \n",
+ " 0.536980 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.118430 \n",
+ " 0.089899 \n",
+ " 0.088473 \n",
+ " 0.084181 \n",
+ " 0.091806 \n",
+ " 0.088910 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.556087 \n",
+ " 1.306983 \n",
+ " 1.939970 \n",
+ " 1.566483 \n",
+ " 2.329622 \n",
+ " 1.080932 \n",
+ " \n",
+ " \n",
+ " MetroInterstateTraffic \n",
+ " nll \n",
+ " 1.763111 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " rmse \n",
+ " 2800.038818 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.799299 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast RANP LSTM \\\n",
+ "IMOSCurrentsVel rmse 0.257789 0.538145 0.546141 \n",
+ " smape 0.285106 0.707343 0.719748 \n",
+ " nll 1.633936 23.305540 19.444048 \n",
+ "BejingPM25 rmse 1.383076 1.087366 1.032753 \n",
+ " smape 0.278300 0.226529 0.212408 \n",
+ " nll 1.707181 1.478179 1.407227 \n",
+ "GasSensor rmse 41.511017 2.036231 10.486523 \n",
+ " smape 1.292102 0.300914 0.562990 \n",
+ " nll 1.879850 -2.236591 16.395840 \n",
+ "AppliancesEnergyPrediction rmse 0.749960 0.562423 0.569487 \n",
+ " smape 0.118430 0.089899 0.088473 \n",
+ " nll 1.556087 1.306983 1.939970 \n",
+ "MetroInterstateTraffic nll 1.763111 NaN NaN \n",
+ " rmse 2800.038818 NaN NaN \n",
+ " smape 0.799299 NaN NaN \n",
+ "\n",
+ " LSTMSeq2Seq TransformerSeq2Seq \\\n",
+ "IMOSCurrentsVel rmse 0.548482 0.579640 \n",
+ " smape 0.716110 0.769928 \n",
+ " nll 14.519804 46.982059 \n",
+ "BejingPM25 rmse 1.026793 1.165701 \n",
+ " smape 0.210056 0.232141 \n",
+ " nll 1.386666 2.859907 \n",
+ "GasSensor rmse 2.139018 NaN \n",
+ " smape 0.337733 NaN \n",
+ " nll -1.527772 NaN \n",
+ "AppliancesEnergyPrediction rmse 0.543795 0.566450 \n",
+ " smape 0.084181 0.091806 \n",
+ " nll 1.566483 2.329622 \n",
+ "MetroInterstateTraffic nll NaN NaN \n",
+ " rmse NaN NaN \n",
+ " smape NaN NaN \n",
+ "\n",
+ " TransformerProcess \n",
+ "IMOSCurrentsVel rmse 0.541376 \n",
+ " smape 0.703052 \n",
+ " nll 7.354424 \n",
+ "BejingPM25 rmse 1.074965 \n",
+ " smape 0.224733 \n",
+ " nll 1.438450 \n",
+ "GasSensor rmse 20.294767 \n",
+ " smape 0.626682 \n",
+ " nll 0.632958 \n",
+ "AppliancesEnergyPrediction rmse 0.536980 \n",
+ " smape 0.088910 \n",
+ " nll 1.080932 \n",
+ "MetroInterstateTraffic nll NaN \n",
+ " rmse NaN \n",
+ " smape NaN "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "INFO:lightning:EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "INFO:lightning:EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "INFO:lightning:GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning:TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "INFO:lightning:Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "--------------------------------\n",
+ "0 | _model | RANP | 107 K \n",
+ "INFO:lightning:\n",
+ " | Name | Type | Params\n",
+ "--------------------------------\n",
+ "0 | _model | RANP | 107 K \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "MetroInterstateTraffic RANP\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "80304c3dda1e476986c59a869d2c3b97",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 8: reducing learning rate of group 0 to 3.0000e-05.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 12: reducing learning rate of group 0 to 3.0000e-06.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='predict'), FloatProgress(value=0.0, max=41.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "MetroInterstateTraffic RANP\n",
+ "mean_NLL -0.27\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " RANP \n",
+ " LSTM \n",
+ " LSTMSeq2Seq \n",
+ " TransformerSeq2Seq \n",
+ " TransformerProcess \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " rmse \n",
+ " 0.257789 \n",
+ " 0.538145 \n",
+ " 0.546141 \n",
+ " 0.548482 \n",
+ " 0.579640 \n",
+ " 0.541376 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.285106 \n",
+ " 0.707343 \n",
+ " 0.719748 \n",
+ " 0.716110 \n",
+ " 0.769928 \n",
+ " 0.703052 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.633936 \n",
+ " 23.305540 \n",
+ " 19.444048 \n",
+ " 14.519804 \n",
+ " 46.982059 \n",
+ " 7.354424 \n",
+ " \n",
+ " \n",
+ " BejingPM25 \n",
+ " rmse \n",
+ " 1.383076 \n",
+ " 1.087366 \n",
+ " 1.032753 \n",
+ " 1.026793 \n",
+ " 1.165701 \n",
+ " 1.074965 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.278300 \n",
+ " 0.226529 \n",
+ " 0.212408 \n",
+ " 0.210056 \n",
+ " 0.232141 \n",
+ " 0.224733 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.707181 \n",
+ " 1.478179 \n",
+ " 1.407227 \n",
+ " 1.386666 \n",
+ " 2.859907 \n",
+ " 1.438450 \n",
+ " \n",
+ " \n",
+ " GasSensor \n",
+ " rmse \n",
+ " 41.511017 \n",
+ " 2.036231 \n",
+ " 10.486523 \n",
+ " 2.139018 \n",
+ " NaN \n",
+ " 20.294767 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 1.292102 \n",
+ " 0.300914 \n",
+ " 0.562990 \n",
+ " 0.337733 \n",
+ " NaN \n",
+ " 0.626682 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.879850 \n",
+ " -2.236591 \n",
+ " 16.395840 \n",
+ " -1.527772 \n",
+ " NaN \n",
+ " 0.632958 \n",
+ " \n",
+ " \n",
+ " AppliancesEnergyPrediction \n",
+ " rmse \n",
+ " 0.749960 \n",
+ " 0.562423 \n",
+ " 0.569487 \n",
+ " 0.543795 \n",
+ " 0.566450 \n",
+ " 0.536980 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.118430 \n",
+ " 0.089899 \n",
+ " 0.088473 \n",
+ " 0.084181 \n",
+ " 0.091806 \n",
+ " 0.088910 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.556087 \n",
+ " 1.306983 \n",
+ " 1.939970 \n",
+ " 1.566483 \n",
+ " 2.329622 \n",
+ " 1.080932 \n",
+ " \n",
+ " \n",
+ " MetroInterstateTraffic \n",
+ " rmse \n",
+ " 2800.038818 \n",
+ " 486.489319 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.799299 \n",
+ " 0.103380 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.763111 \n",
+ " -0.266901 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast RANP LSTM \\\n",
+ "IMOSCurrentsVel rmse 0.257789 0.538145 0.546141 \n",
+ " smape 0.285106 0.707343 0.719748 \n",
+ " nll 1.633936 23.305540 19.444048 \n",
+ "BejingPM25 rmse 1.383076 1.087366 1.032753 \n",
+ " smape 0.278300 0.226529 0.212408 \n",
+ " nll 1.707181 1.478179 1.407227 \n",
+ "GasSensor rmse 41.511017 2.036231 10.486523 \n",
+ " smape 1.292102 0.300914 0.562990 \n",
+ " nll 1.879850 -2.236591 16.395840 \n",
+ "AppliancesEnergyPrediction rmse 0.749960 0.562423 0.569487 \n",
+ " smape 0.118430 0.089899 0.088473 \n",
+ " nll 1.556087 1.306983 1.939970 \n",
+ "MetroInterstateTraffic rmse 2800.038818 486.489319 NaN \n",
+ " smape 0.799299 0.103380 NaN \n",
+ " nll 1.763111 -0.266901 NaN \n",
+ "\n",
+ " LSTMSeq2Seq TransformerSeq2Seq \\\n",
+ "IMOSCurrentsVel rmse 0.548482 0.579640 \n",
+ " smape 0.716110 0.769928 \n",
+ " nll 14.519804 46.982059 \n",
+ "BejingPM25 rmse 1.026793 1.165701 \n",
+ " smape 0.210056 0.232141 \n",
+ " nll 1.386666 2.859907 \n",
+ "GasSensor rmse 2.139018 NaN \n",
+ " smape 0.337733 NaN \n",
+ " nll -1.527772 NaN \n",
+ "AppliancesEnergyPrediction rmse 0.543795 0.566450 \n",
+ " smape 0.084181 0.091806 \n",
+ " nll 1.566483 2.329622 \n",
+ "MetroInterstateTraffic rmse NaN NaN \n",
+ " smape NaN NaN \n",
+ " nll NaN NaN \n",
+ "\n",
+ " TransformerProcess \n",
+ "IMOSCurrentsVel rmse 0.541376 \n",
+ " smape 0.703052 \n",
+ " nll 7.354424 \n",
+ "BejingPM25 rmse 1.074965 \n",
+ " smape 0.224733 \n",
+ " nll 1.438450 \n",
+ "GasSensor rmse 20.294767 \n",
+ " smape 0.626682 \n",
+ " nll 0.632958 \n",
+ "AppliancesEnergyPrediction rmse 0.536980 \n",
+ " smape 0.088910 \n",
+ " nll 1.080932 \n",
+ "MetroInterstateTraffic rmse NaN \n",
+ " smape NaN \n",
+ " nll NaN "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "INFO:lightning:EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "INFO:lightning:EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "INFO:lightning:GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning:TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "INFO:lightning:Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "--------------------------------\n",
+ "0 | _model | LSTM | 135 K \n",
+ "INFO:lightning:\n",
+ " | Name | Type | Params\n",
+ "--------------------------------\n",
+ "0 | _model | LSTM | 135 K \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "MetroInterstateTraffic LSTM\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "e192e7e8aeb64054afd6f0d73f3cbfa0",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 16: reducing learning rate of group 0 to 3.0000e-05.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='predict'), FloatProgress(value=0.0, max=41.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "MetroInterstateTraffic LSTM\n",
+ "mean_NLL -0.17\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " RANP \n",
+ " LSTM \n",
+ " LSTMSeq2Seq \n",
+ " TransformerSeq2Seq \n",
+ " TransformerProcess \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " rmse \n",
+ " 0.257789 \n",
+ " 0.538145 \n",
+ " 0.546141 \n",
+ " 0.548482 \n",
+ " 0.579640 \n",
+ " 0.541376 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.285106 \n",
+ " 0.707343 \n",
+ " 0.719748 \n",
+ " 0.716110 \n",
+ " 0.769928 \n",
+ " 0.703052 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.633936 \n",
+ " 23.305540 \n",
+ " 19.444048 \n",
+ " 14.519804 \n",
+ " 46.982059 \n",
+ " 7.354424 \n",
+ " \n",
+ " \n",
+ " BejingPM25 \n",
+ " rmse \n",
+ " 1.383076 \n",
+ " 1.087366 \n",
+ " 1.032753 \n",
+ " 1.026793 \n",
+ " 1.165701 \n",
+ " 1.074965 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.278300 \n",
+ " 0.226529 \n",
+ " 0.212408 \n",
+ " 0.210056 \n",
+ " 0.232141 \n",
+ " 0.224733 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.707181 \n",
+ " 1.478179 \n",
+ " 1.407227 \n",
+ " 1.386666 \n",
+ " 2.859907 \n",
+ " 1.438450 \n",
+ " \n",
+ " \n",
+ " GasSensor \n",
+ " rmse \n",
+ " 41.511017 \n",
+ " 2.036231 \n",
+ " 10.486523 \n",
+ " 2.139018 \n",
+ " NaN \n",
+ " 20.294767 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 1.292102 \n",
+ " 0.300914 \n",
+ " 0.562990 \n",
+ " 0.337733 \n",
+ " NaN \n",
+ " 0.626682 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.879850 \n",
+ " -2.236591 \n",
+ " 16.395840 \n",
+ " -1.527772 \n",
+ " NaN \n",
+ " 0.632958 \n",
+ " \n",
+ " \n",
+ " AppliancesEnergyPrediction \n",
+ " rmse \n",
+ " 0.749960 \n",
+ " 0.562423 \n",
+ " 0.569487 \n",
+ " 0.543795 \n",
+ " 0.566450 \n",
+ " 0.536980 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.118430 \n",
+ " 0.089899 \n",
+ " 0.088473 \n",
+ " 0.084181 \n",
+ " 0.091806 \n",
+ " 0.088910 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.556087 \n",
+ " 1.306983 \n",
+ " 1.939970 \n",
+ " 1.566483 \n",
+ " 2.329622 \n",
+ " 1.080932 \n",
+ " \n",
+ " \n",
+ " MetroInterstateTraffic \n",
+ " rmse \n",
+ " 2800.038818 \n",
+ " 486.489319 \n",
+ " 487.994751 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.799299 \n",
+ " 0.103380 \n",
+ " 0.106410 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.763111 \n",
+ " -0.266901 \n",
+ " -0.173646 \n",
+ " NaN \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast RANP LSTM \\\n",
+ "IMOSCurrentsVel rmse 0.257789 0.538145 0.546141 \n",
+ " smape 0.285106 0.707343 0.719748 \n",
+ " nll 1.633936 23.305540 19.444048 \n",
+ "BejingPM25 rmse 1.383076 1.087366 1.032753 \n",
+ " smape 0.278300 0.226529 0.212408 \n",
+ " nll 1.707181 1.478179 1.407227 \n",
+ "GasSensor rmse 41.511017 2.036231 10.486523 \n",
+ " smape 1.292102 0.300914 0.562990 \n",
+ " nll 1.879850 -2.236591 16.395840 \n",
+ "AppliancesEnergyPrediction rmse 0.749960 0.562423 0.569487 \n",
+ " smape 0.118430 0.089899 0.088473 \n",
+ " nll 1.556087 1.306983 1.939970 \n",
+ "MetroInterstateTraffic rmse 2800.038818 486.489319 487.994751 \n",
+ " smape 0.799299 0.103380 0.106410 \n",
+ " nll 1.763111 -0.266901 -0.173646 \n",
+ "\n",
+ " LSTMSeq2Seq TransformerSeq2Seq \\\n",
+ "IMOSCurrentsVel rmse 0.548482 0.579640 \n",
+ " smape 0.716110 0.769928 \n",
+ " nll 14.519804 46.982059 \n",
+ "BejingPM25 rmse 1.026793 1.165701 \n",
+ " smape 0.210056 0.232141 \n",
+ " nll 1.386666 2.859907 \n",
+ "GasSensor rmse 2.139018 NaN \n",
+ " smape 0.337733 NaN \n",
+ " nll -1.527772 NaN \n",
+ "AppliancesEnergyPrediction rmse 0.543795 0.566450 \n",
+ " smape 0.084181 0.091806 \n",
+ " nll 1.566483 2.329622 \n",
+ "MetroInterstateTraffic rmse NaN NaN \n",
+ " smape NaN NaN \n",
+ " nll NaN NaN \n",
+ "\n",
+ " TransformerProcess \n",
+ "IMOSCurrentsVel rmse 0.541376 \n",
+ " smape 0.703052 \n",
+ " nll 7.354424 \n",
+ "BejingPM25 rmse 1.074965 \n",
+ " smape 0.224733 \n",
+ " nll 1.438450 \n",
+ "GasSensor rmse 20.294767 \n",
+ " smape 0.626682 \n",
+ " nll 0.632958 \n",
+ "AppliancesEnergyPrediction rmse 0.536980 \n",
+ " smape 0.088910 \n",
+ " nll 1.080932 \n",
+ "MetroInterstateTraffic rmse NaN \n",
+ " smape NaN \n",
+ " nll NaN "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "INFO:lightning:EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "INFO:lightning:EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "INFO:lightning:GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning:TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "INFO:lightning:Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "---------------------------------------\n",
+ "0 | _model | LSTMSeq2Seq | 108 K \n",
+ "INFO:lightning:\n",
+ " | Name | Type | Params\n",
+ "---------------------------------------\n",
+ "0 | _model | LSTMSeq2Seq | 108 K \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "MetroInterstateTraffic LSTMSeq2Seq\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "d8d5b45ce28a42beb46f366a236dde85",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 17: reducing learning rate of group 0 to 3.0000e-05.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='predict'), FloatProgress(value=0.0, max=41.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "MetroInterstateTraffic LSTMSeq2Seq\n",
+ "mean_NLL -0.25\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " RANP \n",
+ " LSTM \n",
+ " LSTMSeq2Seq \n",
+ " TransformerSeq2Seq \n",
+ " TransformerProcess \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " rmse \n",
+ " 0.257789 \n",
+ " 0.538145 \n",
+ " 0.546141 \n",
+ " 0.548482 \n",
+ " 0.579640 \n",
+ " 0.541376 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.285106 \n",
+ " 0.707343 \n",
+ " 0.719748 \n",
+ " 0.716110 \n",
+ " 0.769928 \n",
+ " 0.703052 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.633936 \n",
+ " 23.305540 \n",
+ " 19.444048 \n",
+ " 14.519804 \n",
+ " 46.982059 \n",
+ " 7.354424 \n",
+ " \n",
+ " \n",
+ " BejingPM25 \n",
+ " rmse \n",
+ " 1.383076 \n",
+ " 1.087366 \n",
+ " 1.032753 \n",
+ " 1.026793 \n",
+ " 1.165701 \n",
+ " 1.074965 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.278300 \n",
+ " 0.226529 \n",
+ " 0.212408 \n",
+ " 0.210056 \n",
+ " 0.232141 \n",
+ " 0.224733 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.707181 \n",
+ " 1.478179 \n",
+ " 1.407227 \n",
+ " 1.386666 \n",
+ " 2.859907 \n",
+ " 1.438450 \n",
+ " \n",
+ " \n",
+ " GasSensor \n",
+ " rmse \n",
+ " 41.511017 \n",
+ " 2.036231 \n",
+ " 10.486523 \n",
+ " 2.139018 \n",
+ " NaN \n",
+ " 20.294767 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 1.292102 \n",
+ " 0.300914 \n",
+ " 0.562990 \n",
+ " 0.337733 \n",
+ " NaN \n",
+ " 0.626682 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.879850 \n",
+ " -2.236591 \n",
+ " 16.395840 \n",
+ " -1.527772 \n",
+ " NaN \n",
+ " 0.632958 \n",
+ " \n",
+ " \n",
+ " AppliancesEnergyPrediction \n",
+ " rmse \n",
+ " 0.749960 \n",
+ " 0.562423 \n",
+ " 0.569487 \n",
+ " 0.543795 \n",
+ " 0.566450 \n",
+ " 0.536980 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.118430 \n",
+ " 0.089899 \n",
+ " 0.088473 \n",
+ " 0.084181 \n",
+ " 0.091806 \n",
+ " 0.088910 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.556087 \n",
+ " 1.306983 \n",
+ " 1.939970 \n",
+ " 1.566483 \n",
+ " 2.329622 \n",
+ " 1.080932 \n",
+ " \n",
+ " \n",
+ " MetroInterstateTraffic \n",
+ " rmse \n",
+ " 2800.038818 \n",
+ " 486.489319 \n",
+ " 487.994751 \n",
+ " 494.512085 \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.799299 \n",
+ " 0.103380 \n",
+ " 0.106410 \n",
+ " 0.105370 \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.763111 \n",
+ " -0.266901 \n",
+ " -0.173646 \n",
+ " -0.246719 \n",
+ " NaN \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast RANP LSTM \\\n",
+ "IMOSCurrentsVel rmse 0.257789 0.538145 0.546141 \n",
+ " smape 0.285106 0.707343 0.719748 \n",
+ " nll 1.633936 23.305540 19.444048 \n",
+ "BejingPM25 rmse 1.383076 1.087366 1.032753 \n",
+ " smape 0.278300 0.226529 0.212408 \n",
+ " nll 1.707181 1.478179 1.407227 \n",
+ "GasSensor rmse 41.511017 2.036231 10.486523 \n",
+ " smape 1.292102 0.300914 0.562990 \n",
+ " nll 1.879850 -2.236591 16.395840 \n",
+ "AppliancesEnergyPrediction rmse 0.749960 0.562423 0.569487 \n",
+ " smape 0.118430 0.089899 0.088473 \n",
+ " nll 1.556087 1.306983 1.939970 \n",
+ "MetroInterstateTraffic rmse 2800.038818 486.489319 487.994751 \n",
+ " smape 0.799299 0.103380 0.106410 \n",
+ " nll 1.763111 -0.266901 -0.173646 \n",
+ "\n",
+ " LSTMSeq2Seq TransformerSeq2Seq \\\n",
+ "IMOSCurrentsVel rmse 0.548482 0.579640 \n",
+ " smape 0.716110 0.769928 \n",
+ " nll 14.519804 46.982059 \n",
+ "BejingPM25 rmse 1.026793 1.165701 \n",
+ " smape 0.210056 0.232141 \n",
+ " nll 1.386666 2.859907 \n",
+ "GasSensor rmse 2.139018 NaN \n",
+ " smape 0.337733 NaN \n",
+ " nll -1.527772 NaN \n",
+ "AppliancesEnergyPrediction rmse 0.543795 0.566450 \n",
+ " smape 0.084181 0.091806 \n",
+ " nll 1.566483 2.329622 \n",
+ "MetroInterstateTraffic rmse 494.512085 NaN \n",
+ " smape 0.105370 NaN \n",
+ " nll -0.246719 NaN \n",
+ "\n",
+ " TransformerProcess \n",
+ "IMOSCurrentsVel rmse 0.541376 \n",
+ " smape 0.703052 \n",
+ " nll 7.354424 \n",
+ "BejingPM25 rmse 1.074965 \n",
+ " smape 0.224733 \n",
+ " nll 1.438450 \n",
+ "GasSensor rmse 20.294767 \n",
+ " smape 0.626682 \n",
+ " nll 0.632958 \n",
+ "AppliancesEnergyPrediction rmse 0.536980 \n",
+ " smape 0.088910 \n",
+ " nll 1.080932 \n",
+ "MetroInterstateTraffic rmse NaN \n",
+ " smape NaN \n",
+ " nll NaN "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "INFO:lightning:EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "INFO:lightning:EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "INFO:lightning:GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning:TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "INFO:lightning:Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "----------------------------------------------\n",
+ "0 | _model | TransformerSeq2Seq | 2 M \n",
+ "INFO:lightning:\n",
+ " | Name | Type | Params\n",
+ "----------------------------------------------\n",
+ "0 | _model | TransformerSeq2Seq | 2 M \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "MetroInterstateTraffic TransformerSeq2Seq\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "3aa7f450548f46b4951081c155893b35",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 4: reducing learning rate of group 0 to 3.0000e-05.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='predict'), FloatProgress(value=0.0, max=41.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "MetroInterstateTraffic TransformerSeq2Seq\n",
+ "mean_NLL 4.15\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " RANP \n",
+ " LSTM \n",
+ " LSTMSeq2Seq \n",
+ " TransformerSeq2Seq \n",
+ " TransformerProcess \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " rmse \n",
+ " 0.257789 \n",
+ " 0.538145 \n",
+ " 0.546141 \n",
+ " 0.548482 \n",
+ " 0.579640 \n",
+ " 0.541376 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.285106 \n",
+ " 0.707343 \n",
+ " 0.719748 \n",
+ " 0.716110 \n",
+ " 0.769928 \n",
+ " 0.703052 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.633936 \n",
+ " 23.305540 \n",
+ " 19.444048 \n",
+ " 14.519804 \n",
+ " 46.982059 \n",
+ " 7.354424 \n",
+ " \n",
+ " \n",
+ " BejingPM25 \n",
+ " rmse \n",
+ " 1.383076 \n",
+ " 1.087366 \n",
+ " 1.032753 \n",
+ " 1.026793 \n",
+ " 1.165701 \n",
+ " 1.074965 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.278300 \n",
+ " 0.226529 \n",
+ " 0.212408 \n",
+ " 0.210056 \n",
+ " 0.232141 \n",
+ " 0.224733 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.707181 \n",
+ " 1.478179 \n",
+ " 1.407227 \n",
+ " 1.386666 \n",
+ " 2.859907 \n",
+ " 1.438450 \n",
+ " \n",
+ " \n",
+ " GasSensor \n",
+ " rmse \n",
+ " 41.511017 \n",
+ " 2.036231 \n",
+ " 10.486523 \n",
+ " 2.139018 \n",
+ " NaN \n",
+ " 20.294767 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 1.292102 \n",
+ " 0.300914 \n",
+ " 0.562990 \n",
+ " 0.337733 \n",
+ " NaN \n",
+ " 0.626682 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.879850 \n",
+ " -2.236591 \n",
+ " 16.395840 \n",
+ " -1.527772 \n",
+ " NaN \n",
+ " 0.632958 \n",
+ " \n",
+ " \n",
+ " AppliancesEnergyPrediction \n",
+ " rmse \n",
+ " 0.749960 \n",
+ " 0.562423 \n",
+ " 0.569487 \n",
+ " 0.543795 \n",
+ " 0.566450 \n",
+ " 0.536980 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.118430 \n",
+ " 0.089899 \n",
+ " 0.088473 \n",
+ " 0.084181 \n",
+ " 0.091806 \n",
+ " 0.088910 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.556087 \n",
+ " 1.306983 \n",
+ " 1.939970 \n",
+ " 1.566483 \n",
+ " 2.329622 \n",
+ " 1.080932 \n",
+ " \n",
+ " \n",
+ " MetroInterstateTraffic \n",
+ " rmse \n",
+ " 2800.038818 \n",
+ " 486.489319 \n",
+ " 487.994751 \n",
+ " 494.512085 \n",
+ " 1762.025879 \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.799299 \n",
+ " 0.103380 \n",
+ " 0.106410 \n",
+ " 0.105370 \n",
+ " 0.492156 \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.763111 \n",
+ " -0.266901 \n",
+ " -0.173646 \n",
+ " -0.246719 \n",
+ " 4.152735 \n",
+ " NaN \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast RANP LSTM \\\n",
+ "IMOSCurrentsVel rmse 0.257789 0.538145 0.546141 \n",
+ " smape 0.285106 0.707343 0.719748 \n",
+ " nll 1.633936 23.305540 19.444048 \n",
+ "BejingPM25 rmse 1.383076 1.087366 1.032753 \n",
+ " smape 0.278300 0.226529 0.212408 \n",
+ " nll 1.707181 1.478179 1.407227 \n",
+ "GasSensor rmse 41.511017 2.036231 10.486523 \n",
+ " smape 1.292102 0.300914 0.562990 \n",
+ " nll 1.879850 -2.236591 16.395840 \n",
+ "AppliancesEnergyPrediction rmse 0.749960 0.562423 0.569487 \n",
+ " smape 0.118430 0.089899 0.088473 \n",
+ " nll 1.556087 1.306983 1.939970 \n",
+ "MetroInterstateTraffic rmse 2800.038818 486.489319 487.994751 \n",
+ " smape 0.799299 0.103380 0.106410 \n",
+ " nll 1.763111 -0.266901 -0.173646 \n",
+ "\n",
+ " LSTMSeq2Seq TransformerSeq2Seq \\\n",
+ "IMOSCurrentsVel rmse 0.548482 0.579640 \n",
+ " smape 0.716110 0.769928 \n",
+ " nll 14.519804 46.982059 \n",
+ "BejingPM25 rmse 1.026793 1.165701 \n",
+ " smape 0.210056 0.232141 \n",
+ " nll 1.386666 2.859907 \n",
+ "GasSensor rmse 2.139018 NaN \n",
+ " smape 0.337733 NaN \n",
+ " nll -1.527772 NaN \n",
+ "AppliancesEnergyPrediction rmse 0.543795 0.566450 \n",
+ " smape 0.084181 0.091806 \n",
+ " nll 1.566483 2.329622 \n",
+ "MetroInterstateTraffic rmse 494.512085 1762.025879 \n",
+ " smape 0.105370 0.492156 \n",
+ " nll -0.246719 4.152735 \n",
+ "\n",
+ " TransformerProcess \n",
+ "IMOSCurrentsVel rmse 0.541376 \n",
+ " smape 0.703052 \n",
+ " nll 7.354424 \n",
+ "BejingPM25 rmse 1.074965 \n",
+ " smape 0.224733 \n",
+ " nll 1.438450 \n",
+ "GasSensor rmse 20.294767 \n",
+ " smape 0.626682 \n",
+ " nll 0.632958 \n",
+ "AppliancesEnergyPrediction rmse 0.536980 \n",
+ " smape 0.088910 \n",
+ " nll 1.080932 \n",
+ "MetroInterstateTraffic rmse NaN \n",
+ " smape NaN \n",
+ " nll NaN "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "INFO:lightning:EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "INFO:lightning:EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "INFO:lightning:GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning:TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "INFO:lightning:Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "---------------------------------------\n",
+ "0 | _model | Transformer | 2 M \n",
+ "INFO:lightning:\n",
+ " | Name | Type | Params\n",
+ "---------------------------------------\n",
+ "0 | _model | Transformer | 2 M \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "MetroInterstateTraffic Transformer\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "259e3550902f473fbd8381f6b29f1ddd",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "ERROR:root:failed to run model\n",
+ "Traceback (most recent call last):\n",
+ " File \"\", line 52, in \n",
+ " trainer.fit(model, dl_train, dl_test)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\", line 440, in fit\n",
+ " results = self.accelerator_backend.train()\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\", line 54, in train\n",
+ " results = self.train_or_test()\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py\", line 66, in train_or_test\n",
+ " results = self.trainer.train()\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\", line 462, in train\n",
+ " self.run_sanity_check(self.get_model())\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\", line 648, in run_sanity_check\n",
+ " _, eval_results = self.run_evaluation(test_mode=False, max_batches=self.num_sanity_val_batches)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\", line 568, in run_evaluation\n",
+ " output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/trainer/evaluation_loop.py\", line 171, in evaluation_step\n",
+ " output = self.trainer.accelerator_backend.validation_step(args)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\", line 76, in validation_step\n",
+ " output = self.__validation_step(args)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py\", line 86, in __validation_step\n",
+ " output = self.trainer.model.validation_step(*args)\n",
+ " File \"\", line 30, in validation_step\n",
+ " return self.training_step(batch, batch_idx, phase='val')\n",
+ " File \"\", line 19, in training_step\n",
+ " y_dist, extra = self.forward(*batch)\n",
+ " File \"\", line 13, in forward\n",
+ " y_dist, extra = self._model(x_past, y_past, x_future, y_future)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 722, in _call_impl\n",
+ " result = self.forward(*input, **kwargs)\n",
+ " File \"/media/wassname/Storage5/projects2/3ST/seq2seq-time/seq2seq_time/models/transformer.py\", line 54, in forward\n",
+ " outputs = self.encoder(x, mask=mask#, src_key_padding_mask=x_key_padding_mask\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 722, in _call_impl\n",
+ " result = self.forward(*input, **kwargs)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/transformer.py\", line 181, in forward\n",
+ " output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 722, in _call_impl\n",
+ " result = self.forward(*input, **kwargs)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/transformer.py\", line 294, in forward\n",
+ " key_padding_mask=src_key_padding_mask)[0]\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 722, in _call_impl\n",
+ " result = self.forward(*input, **kwargs)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/modules/activation.py\", line 927, in forward\n",
+ " attn_mask=attn_mask)\n",
+ " File \"/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/torch/nn/functional.py\", line 4049, in multi_head_attention_forward\n",
+ " raise RuntimeError('The size of the 2D attn_mask is not correct.')\n",
+ "RuntimeError: The size of the 2D attn_mask is not correct.\n",
+ "EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "INFO:lightning:EarlyStopping mode auto is unknown, fallback to auto mode.\n",
+ "EarlyStopping mode set to min for monitoring loss/val.\n",
+ "INFO:lightning:EarlyStopping mode set to min for monitoring loss/val.\n",
+ "GPU available: True, used: True\n",
+ "INFO:lightning:GPU available: True, used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning:TPU available: False, using: 0 TPU cores\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "Using native 16bit precision.\n",
+ "INFO:lightning:Using native 16bit precision.\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "----------------------------------------------\n",
+ "0 | _model | TransformerProcess | 49 K \n",
+ "INFO:lightning:\n",
+ " | Name | Type | Params\n",
+ "----------------------------------------------\n",
+ "0 | _model | TransformerProcess | 49 K \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "MetroInterstateTraffic TransformerProcess\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "9c25d90977ce4307898c0ba77f1efbfe",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HBox(children=(HTML(value='predict'), FloatProgress(value=0.0, max=41.0), HTML(value='')))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "MetroInterstateTraffic TransformerProcess\n",
+ "mean_NLL -0.27\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " RANP \n",
+ " LSTM \n",
+ " LSTMSeq2Seq \n",
+ " TransformerSeq2Seq \n",
+ " TransformerProcess \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " rmse \n",
+ " 0.257789 \n",
+ " 0.538145 \n",
+ " 0.546141 \n",
+ " 0.548482 \n",
+ " 0.579640 \n",
+ " 0.541376 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.285106 \n",
+ " 0.707343 \n",
+ " 0.719748 \n",
+ " 0.716110 \n",
+ " 0.769928 \n",
+ " 0.703052 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.633936 \n",
+ " 23.305540 \n",
+ " 19.444048 \n",
+ " 14.519804 \n",
+ " 46.982059 \n",
+ " 7.354424 \n",
+ " \n",
+ " \n",
+ " BejingPM25 \n",
+ " rmse \n",
+ " 1.383076 \n",
+ " 1.087366 \n",
+ " 1.032753 \n",
+ " 1.026793 \n",
+ " 1.165701 \n",
+ " 1.074965 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.278300 \n",
+ " 0.226529 \n",
+ " 0.212408 \n",
+ " 0.210056 \n",
+ " 0.232141 \n",
+ " 0.224733 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.707181 \n",
+ " 1.478179 \n",
+ " 1.407227 \n",
+ " 1.386666 \n",
+ " 2.859907 \n",
+ " 1.438450 \n",
+ " \n",
+ " \n",
+ " GasSensor \n",
+ " rmse \n",
+ " 41.511017 \n",
+ " 2.036231 \n",
+ " 10.486523 \n",
+ " 2.139018 \n",
+ " NaN \n",
+ " 20.294767 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 1.292102 \n",
+ " 0.300914 \n",
+ " 0.562990 \n",
+ " 0.337733 \n",
+ " NaN \n",
+ " 0.626682 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.879850 \n",
+ " -2.236591 \n",
+ " 16.395840 \n",
+ " -1.527772 \n",
+ " NaN \n",
+ " 0.632958 \n",
+ " \n",
+ " \n",
+ " AppliancesEnergyPrediction \n",
+ " rmse \n",
+ " 0.749960 \n",
+ " 0.562423 \n",
+ " 0.569487 \n",
+ " 0.543795 \n",
+ " 0.566450 \n",
+ " 0.536980 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.118430 \n",
+ " 0.089899 \n",
+ " 0.088473 \n",
+ " 0.084181 \n",
+ " 0.091806 \n",
+ " 0.088910 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.556087 \n",
+ " 1.306983 \n",
+ " 1.939970 \n",
+ " 1.566483 \n",
+ " 2.329622 \n",
+ " 1.080932 \n",
+ " \n",
+ " \n",
+ " MetroInterstateTraffic \n",
+ " rmse \n",
+ " 2800.038818 \n",
+ " 486.489319 \n",
+ " 487.994751 \n",
+ " 494.512085 \n",
+ " 1762.025879 \n",
+ " 501.883423 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.799299 \n",
+ " 0.103380 \n",
+ " 0.106410 \n",
+ " 0.105370 \n",
+ " 0.492156 \n",
+ " 0.112702 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.763111 \n",
+ " -0.266901 \n",
+ " -0.173646 \n",
+ " -0.246719 \n",
+ " 4.152735 \n",
+ " -0.269533 \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast RANP LSTM \\\n",
+ "IMOSCurrentsVel rmse 0.257789 0.538145 0.546141 \n",
+ " smape 0.285106 0.707343 0.719748 \n",
+ " nll 1.633936 23.305540 19.444048 \n",
+ "BejingPM25 rmse 1.383076 1.087366 1.032753 \n",
+ " smape 0.278300 0.226529 0.212408 \n",
+ " nll 1.707181 1.478179 1.407227 \n",
+ "GasSensor rmse 41.511017 2.036231 10.486523 \n",
+ " smape 1.292102 0.300914 0.562990 \n",
+ " nll 1.879850 -2.236591 16.395840 \n",
+ "AppliancesEnergyPrediction rmse 0.749960 0.562423 0.569487 \n",
+ " smape 0.118430 0.089899 0.088473 \n",
+ " nll 1.556087 1.306983 1.939970 \n",
+ "MetroInterstateTraffic rmse 2800.038818 486.489319 487.994751 \n",
+ " smape 0.799299 0.103380 0.106410 \n",
+ " nll 1.763111 -0.266901 -0.173646 \n",
+ "\n",
+ " LSTMSeq2Seq TransformerSeq2Seq \\\n",
+ "IMOSCurrentsVel rmse 0.548482 0.579640 \n",
+ " smape 0.716110 0.769928 \n",
+ " nll 14.519804 46.982059 \n",
+ "BejingPM25 rmse 1.026793 1.165701 \n",
+ " smape 0.210056 0.232141 \n",
+ " nll 1.386666 2.859907 \n",
+ "GasSensor rmse 2.139018 NaN \n",
+ " smape 0.337733 NaN \n",
+ " nll -1.527772 NaN \n",
+ "AppliancesEnergyPrediction rmse 0.543795 0.566450 \n",
+ " smape 0.084181 0.091806 \n",
+ " nll 1.566483 2.329622 \n",
+ "MetroInterstateTraffic rmse 494.512085 1762.025879 \n",
+ " smape 0.105370 0.492156 \n",
+ " nll -0.246719 4.152735 \n",
+ "\n",
+ " TransformerProcess \n",
+ "IMOSCurrentsVel rmse 0.541376 \n",
+ " smape 0.703052 \n",
+ " nll 7.354424 \n",
+ "BejingPM25 rmse 1.074965 \n",
+ " smape 0.224733 \n",
+ " nll 1.438450 \n",
+ "GasSensor rmse 20.294767 \n",
+ " smape 0.626682 \n",
+ " nll 0.632958 \n",
+ "AppliancesEnergyPrediction rmse 0.536980 \n",
+ " smape 0.088910 \n",
+ " nll 1.080932 \n",
+ "MetroInterstateTraffic rmse 501.883423 \n",
+ " smape 0.112702 \n",
+ " nll -0.269533 "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " RANP \n",
+ " LSTM \n",
+ " LSTMSeq2Seq \n",
+ " TransformerSeq2Seq \n",
+ " TransformerProcess \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " rmse \n",
+ " 0.257789 \n",
+ " 0.538145 \n",
+ " 0.546141 \n",
+ " 0.548482 \n",
+ " 0.579640 \n",
+ " 0.541376 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.285106 \n",
+ " 0.707343 \n",
+ " 0.719748 \n",
+ " 0.716110 \n",
+ " 0.769928 \n",
+ " 0.703052 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.633936 \n",
+ " 23.305540 \n",
+ " 19.444048 \n",
+ " 14.519804 \n",
+ " 46.982059 \n",
+ " 7.354424 \n",
+ " \n",
+ " \n",
+ " BejingPM25 \n",
+ " rmse \n",
+ " 1.383076 \n",
+ " 1.087366 \n",
+ " 1.032753 \n",
+ " 1.026793 \n",
+ " 1.165701 \n",
+ " 1.074965 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.278300 \n",
+ " 0.226529 \n",
+ " 0.212408 \n",
+ " 0.210056 \n",
+ " 0.232141 \n",
+ " 0.224733 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.707181 \n",
+ " 1.478179 \n",
+ " 1.407227 \n",
+ " 1.386666 \n",
+ " 2.859907 \n",
+ " 1.438450 \n",
+ " \n",
+ " \n",
+ " GasSensor \n",
+ " rmse \n",
+ " 41.511017 \n",
+ " 2.036231 \n",
+ " 10.486523 \n",
+ " 2.139018 \n",
+ " NaN \n",
+ " 20.294767 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 1.292102 \n",
+ " 0.300914 \n",
+ " 0.562990 \n",
+ " 0.337733 \n",
+ " NaN \n",
+ " 0.626682 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.879850 \n",
+ " -2.236591 \n",
+ " 16.395840 \n",
+ " -1.527772 \n",
+ " NaN \n",
+ " 0.632958 \n",
+ " \n",
+ " \n",
+ " AppliancesEnergyPrediction \n",
+ " rmse \n",
+ " 0.749960 \n",
+ " 0.562423 \n",
+ " 0.569487 \n",
+ " 0.543795 \n",
+ " 0.566450 \n",
+ " 0.536980 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.118430 \n",
+ " 0.089899 \n",
+ " 0.088473 \n",
+ " 0.084181 \n",
+ " 0.091806 \n",
+ " 0.088910 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.556087 \n",
+ " 1.306983 \n",
+ " 1.939970 \n",
+ " 1.566483 \n",
+ " 2.329622 \n",
+ " 1.080932 \n",
+ " \n",
+ " \n",
+ " MetroInterstateTraffic \n",
+ " rmse \n",
+ " 2800.038818 \n",
+ " 486.489319 \n",
+ " 487.994751 \n",
+ " 494.512085 \n",
+ " 1762.025879 \n",
+ " 501.883423 \n",
+ " \n",
+ " \n",
+ " smape \n",
+ " 0.799299 \n",
+ " 0.103380 \n",
+ " 0.106410 \n",
+ " 0.105370 \n",
+ " 0.492156 \n",
+ " 0.112702 \n",
+ " \n",
+ " \n",
+ " nll \n",
+ " 1.763111 \n",
+ " -0.266901 \n",
+ " -0.173646 \n",
+ " -0.246719 \n",
+ " 4.152735 \n",
+ " -0.269533 \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast RANP LSTM \\\n",
+ "IMOSCurrentsVel rmse 0.257789 0.538145 0.546141 \n",
+ " smape 0.285106 0.707343 0.719748 \n",
+ " nll 1.633936 23.305540 19.444048 \n",
+ "BejingPM25 rmse 1.383076 1.087366 1.032753 \n",
+ " smape 0.278300 0.226529 0.212408 \n",
+ " nll 1.707181 1.478179 1.407227 \n",
+ "GasSensor rmse 41.511017 2.036231 10.486523 \n",
+ " smape 1.292102 0.300914 0.562990 \n",
+ " nll 1.879850 -2.236591 16.395840 \n",
+ "AppliancesEnergyPrediction rmse 0.749960 0.562423 0.569487 \n",
+ " smape 0.118430 0.089899 0.088473 \n",
+ " nll 1.556087 1.306983 1.939970 \n",
+ "MetroInterstateTraffic rmse 2800.038818 486.489319 487.994751 \n",
+ " smape 0.799299 0.103380 0.106410 \n",
+ " nll 1.763111 -0.266901 -0.173646 \n",
+ "\n",
+ " LSTMSeq2Seq TransformerSeq2Seq \\\n",
+ "IMOSCurrentsVel rmse 0.548482 0.579640 \n",
+ " smape 0.716110 0.769928 \n",
+ " nll 14.519804 46.982059 \n",
+ "BejingPM25 rmse 1.026793 1.165701 \n",
+ " smape 0.210056 0.232141 \n",
+ " nll 1.386666 2.859907 \n",
+ "GasSensor rmse 2.139018 NaN \n",
+ " smape 0.337733 NaN \n",
+ " nll -1.527772 NaN \n",
+ "AppliancesEnergyPrediction rmse 0.543795 0.566450 \n",
+ " smape 0.084181 0.091806 \n",
+ " nll 1.566483 2.329622 \n",
+ "MetroInterstateTraffic rmse 494.512085 1762.025879 \n",
+ " smape 0.105370 0.492156 \n",
+ " nll -0.246719 4.152735 \n",
+ "\n",
+ " TransformerProcess \n",
+ "IMOSCurrentsVel rmse 0.541376 \n",
+ " smape 0.703052 \n",
+ " nll 7.354424 \n",
+ "BejingPM25 rmse 1.074965 \n",
+ " smape 0.224733 \n",
+ " nll 1.438450 \n",
+ "GasSensor rmse 20.294767 \n",
+ " smape 0.626682 \n",
+ " nll 0.632958 \n",
+ "AppliancesEnergyPrediction rmse 0.536980 \n",
+ " smape 0.088910 \n",
+ " nll 1.080932 \n",
+ "MetroInterstateTraffic rmse 501.883423 \n",
+ " smape 0.112702 \n",
+ " nll -0.269533 "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
}
],
"source": [
@@ -2221,26 +14051,172 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 35,
"metadata": {
"ExecuteTime": {
- "start_time": "2020-10-26T22:42:52.500Z"
+ "end_time": "2020-10-27T22:06:50.235348Z",
+ "start_time": "2020-10-27T22:06:50.168117Z"
}
},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel/ipkernel.py:287: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
+ " and should_run_async(code)\n"
+ ]
+ }
+ ],
"source": [
- "EarlyStopping?"
+ "# File \"/media/wassname/Storage5/projects2/3ST/seq2seq-time/seq2seq_time/models/transformer.py\", line 54, in forward\n",
+ "# outputs = self.encoder(x, mask=mask#, src_key_padding_mask=x_key_padding_mask\n",
+ "# File \"/media/wassname/Storage5/projects2/3ST/seq2seq-time/seq2seq_time/models/transformer.py\", line 54, in forward\n",
+ "# outputs = self.encoder(x, mask=mask#, src_key_padding_mask=x_key_padding_mask"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 34,
"metadata": {
"ExecuteTime": {
- "start_time": "2020-10-26T22:42:52.500Z"
+ "end_time": "2020-10-27T22:05:27.824869Z",
+ "start_time": "2020-10-27T22:05:27.734831Z"
}
},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel/ipkernel.py:287: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
+ " and should_run_async(code)\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " BaselineLast \n",
+ " RANP \n",
+ " LSTM \n",
+ " LSTMSeq2Seq \n",
+ " TransformerSeq2Seq \n",
+ " TransformerProcess \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " IMOSCurrentsVel \n",
+ " 1.63 \n",
+ " 23.31 \n",
+ " 19.44 \n",
+ " 14.52 \n",
+ " 46.98 \n",
+ " 7.35 \n",
+ " \n",
+ " \n",
+ " BejingPM25 \n",
+ " 1.71 \n",
+ " 1.48 \n",
+ " 1.41 \n",
+ " 1.39 \n",
+ " 2.86 \n",
+ " 1.44 \n",
+ " \n",
+ " \n",
+ " GasSensor \n",
+ " 1.88 \n",
+ " -2.24 \n",
+ " 16.40 \n",
+ " -1.53 \n",
+ " NaN \n",
+ " 0.63 \n",
+ " \n",
+ " \n",
+ " AppliancesEnergyPrediction \n",
+ " 1.56 \n",
+ " 1.31 \n",
+ " 1.94 \n",
+ " 1.57 \n",
+ " 2.33 \n",
+ " 1.08 \n",
+ " \n",
+ " \n",
+ " MetroInterstateTraffic \n",
+ " 1.76 \n",
+ " -0.27 \n",
+ " -0.17 \n",
+ " -0.25 \n",
+ " 4.15 \n",
+ " -0.27 \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " BaselineLast RANP LSTM LSTMSeq2Seq \\\n",
+ "IMOSCurrentsVel 1.63 23.31 19.44 14.52 \n",
+ "BejingPM25 1.71 1.48 1.41 1.39 \n",
+ "GasSensor 1.88 -2.24 16.40 -1.53 \n",
+ "AppliancesEnergyPrediction 1.56 1.31 1.94 1.57 \n",
+ "MetroInterstateTraffic 1.76 -0.27 -0.17 -0.25 \n",
+ "\n",
+ " TransformerSeq2Seq TransformerProcess \n",
+ "IMOSCurrentsVel 46.98 7.35 \n",
+ "BejingPM25 2.86 1.44 \n",
+ "GasSensor NaN 0.63 \n",
+ "AppliancesEnergyPrediction 2.33 1.08 \n",
+ "MetroInterstateTraffic 4.15 -0.27 "
+ ]
+ },
+ "execution_count": 34,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df_results.xs('nll', level=1).round(2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-10-27T13:01:42.420773Z",
+ "start_time": "2020-10-27T13:01:42.362236Z"
+ },
+ "lines_to_next_cell": 0
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/wassname/anaconda/envs/seq2seq-time/lib/python3.7/site-packages/ipykernel/ipkernel.py:287: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
+ " and should_run_async(code)\n"
+ ]
+ }
+ ],
"source": [
"# ds_preds.to_netcdf(trainer.logger.experiment.log_dir+'/ds_preds2.nc')"
]
@@ -2252,7 +14228,8 @@
"ExecuteTime": {
"end_time": "2020-10-26T12:37:48.346152Z",
"start_time": "2020-10-26T12:37:48.248720Z"
- }
+ },
+ "lines_to_next_cell": 2
},
"outputs": [],
"source": []
@@ -2302,7 +14279,7 @@
"height": "calc(100% - 180px)",
"left": "10px",
"top": "150px",
- "width": "307.2px"
+ "width": "209.208px"
},
"toc_section_display": true,
"toc_window_display": true
diff --git a/notebooks/05.0-mc-leaderboard.py b/notebooks/05.0-mc-leaderboard.py
index 96030fa..c387f36 100644
--- a/notebooks/05.0-mc-leaderboard.py
+++ b/notebooks/05.0-mc-leaderboard.py
@@ -311,6 +311,18 @@ results = defaultdict(dict)
from seq2seq_time.metrics import rmse, smape
+for Dataset in datasets:
+ dataset_name = Dataset.__name__
+ dataset = Dataset(datasets_root)
+ ds_train, ds_test = dataset.to_datasets(window_past=window_past,
+ window_future=window_future)
+
+ # Init data
+ x_past, y_past, x_future, y_future = ds_train.get_rows(10)
+ input_size = x_past.shape[-1]
+ output_size = y_future.shape[-1]
+
+
# +
for Dataset in datasets:
dataset_name = Dataset.__name__
@@ -397,7 +409,13 @@ df_results = pd.concat({k:pd.DataFrame(v) for k,v in results.items()})
display(df_results)
# +
-# EarlyStopping?
+# File "/media/wassname/Storage5/projects2/3ST/seq2seq-time/seq2seq_time/models/transformer.py", line 54, in forward
+# outputs = self.encoder(x, mask=mask#, src_key_padding_mask=x_key_padding_mask
+# File "/media/wassname/Storage5/projects2/3ST/seq2seq-time/seq2seq_time/models/transformer.py", line 54, in forward
+# outputs = self.encoder(x, mask=mask#, src_key_padding_mask=x_key_padding_mask
+# -
+
+df_results.xs('nll', level=1).round(2)
# +
# ds_preds.to_netcdf(trainer.logger.experiment.log_dir+'/ds_preds2.nc')
diff --git a/references/.gitkeep b/references/.gitkeep
deleted file mode 100644
index e69de29..0000000
diff --git a/seq2seq_time/data/data.py b/seq2seq_time/data/data.py
index dfb77c2..6a56ec1 100644
--- a/seq2seq_time/data/data.py
+++ b/seq2seq_time/data/data.py
@@ -87,7 +87,7 @@ class GasSensor(RegressionForecastData):
dfs=[]
for f in zf.namelist():
if f.endswith('.csv'):
- now = pd.to_datetime(Pdset_to_ncath(f).stem, format='%Y%m%d_%H%M%S')
+ now = pd.to_datetime(Path(f).stem, format='%Y%m%d_%H%M%S')
df = pd.read_csv(zf.open(f))
df.index = pd.to_timedelta(df['Time (s)'], unit='s') + now
dfs.append(df)