use glutonts 0.8.0 and remove wandb

This commit is contained in:
Dr. Kashif Rasul
2021-07-06 13:16:48 +02:00
parent 5da3be5abd
commit 0ab82da2cb
15 changed files with 19 additions and 28 deletions
@@ -26,7 +26,7 @@ from gluonts.transform import (
TestSplitSampler,
ExpectedNumInstanceSampler,
)
from gluonts.torch.support.util import copy_parameters
from gluonts.torch.util import copy_parameters
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.modules.distribution_output import DistributionOutput
from gluonts.model.predictor import Predictor
+1 -1
View File
@@ -26,7 +26,7 @@ from gluonts.transform import (
TestSplitSampler,
ExpectedNumInstanceSampler,
)
from gluonts.torch.support.util import copy_parameters
from gluonts.torch.util import copy_parameters
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.modules.distribution_output import DistributionOutput
from gluonts.model.predictor import Predictor
+1 -1
View File
@@ -7,7 +7,7 @@ from gluonts.core.component import validated
from gluonts.dataset.field_names import FieldName
from gluonts.time_feature import TimeFeature
from gluonts.torch.modules.distribution_output import DistributionOutput
from gluonts.torch.support.util import copy_parameters
from gluonts.torch.util import copy_parameters
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.model.predictor import Predictor
from gluonts.transform import (
+2 -2
View File
@@ -6,7 +6,7 @@ import torch.nn as nn
from gluonts.core.component import validated
from gluonts.dataset.field_names import FieldName
from gluonts.torch.support.util import copy_parameters
from gluonts.torch.util import copy_parameters
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.model.predictor import Predictor
from gluonts.transform import (
@@ -32,7 +32,7 @@ class LSTNetEstimator(PyTorchEstimator):
def __init__(
self,
freq: str,
prediction_length: int,
prediction_length: Optional[int],
context_length: int,
num_series: int,
ar_window: int = 24,
+1 -1
View File
@@ -7,7 +7,7 @@ from gluonts.core.component import validated
from gluonts.dataset.field_names import FieldName
from gluonts.model.predictor import Predictor
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.support.util import copy_parameters
from gluonts.torch.util import copy_parameters
from gluonts.transform import (
InstanceSplitter,
ValidationSplitSampler,
@@ -4,7 +4,7 @@ import torch
import torch.nn as nn
from gluonts.core.component import validated
from gluonts.torch.support.util import copy_parameters
from gluonts.torch.util import copy_parameters
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.modules.distribution_output import DistributionOutput
from gluonts.model.predictor import Predictor
+1 -1
View File
@@ -6,7 +6,7 @@ from gluonts.core.component import validated
from gluonts.dataset.field_names import FieldName
from gluonts.time_feature import TimeFeature
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.support.util import copy_parameters
from gluonts.torch.util import copy_parameters
from gluonts.model.predictor import Predictor
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.transform import (
+1 -1
View File
@@ -12,7 +12,7 @@ from gluonts.time_feature import (
time_features_from_frequency_str,
)
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.support.util import copy_parameters
from gluonts.torch.util import copy_parameters
from gluonts.transform import (
Transformation,
Chain,
+1 -1
View File
@@ -5,7 +5,7 @@ import torch
from gluonts.dataset.field_names import FieldName
from gluonts.time_feature import TimeFeature
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.support.util import copy_parameters
from gluonts.torch.util import copy_parameters
from gluonts.model.predictor import Predictor
from gluonts.transform import (
Transformation,
@@ -8,7 +8,7 @@ from gluonts.core.component import validated
from gluonts.dataset.field_names import FieldName
from gluonts.time_feature import TimeFeature
from gluonts.torch.modules.distribution_output import DistributionOutput
from gluonts.torch.support.util import copy_parameters
from gluonts.torch.util import copy_parameters
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.model.predictor import Predictor
from gluonts.transform import (
@@ -5,7 +5,7 @@ import torch
from gluonts.core.component import validated
from gluonts.dataset.field_names import FieldName
from gluonts.time_feature import TimeFeature
from gluonts.torch.support.util import copy_parameters
from gluonts.torch.util import copy_parameters
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.model.predictor import Predictor
from gluonts.transform import (
-8
View File
@@ -2,7 +2,6 @@ import time
from typing import List, Optional, Union
from tqdm.auto import tqdm
import wandb
import torch
import torch.nn as nn
@@ -23,7 +22,6 @@ class Trainer:
learning_rate: float = 1e-3,
weight_decay: float = 1e-6,
maximum_learning_rate: float = 1e-2,
wandb_mode: str = "disabled",
clip_gradient: Optional[float] = None,
device: Optional[Union[torch.device, str]] = None,
**kwargs,
@@ -36,7 +34,6 @@ class Trainer:
self.maximum_learning_rate = maximum_learning_rate
self.clip_gradient = clip_gradient
self.device = device
wandb.init(mode=wandb_mode, **kwargs)
def __call__(
self,
@@ -44,8 +41,6 @@ class Trainer:
train_iter: DataLoader,
validation_iter: Optional[DataLoader] = None,
) -> None:
wandb.watch(net, log="all", log_freq=self.num_batches_per_epoch)
optimizer = Adam(
net.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
)
@@ -86,8 +81,6 @@ class Trainer:
refresh=False,
)
wandb.log({"loss": loss.item()})
loss.backward()
if self.clip_gradient is not None:
nn.utils.clip_grad_norm_(net.parameters(), self.clip_gradient)
@@ -127,7 +120,6 @@ class Trainer:
if self.num_batches_per_epoch == batch_no:
break
wandb.log({"avg_val_loss": avg_epoch_loss_val})
it.close()
# mark epoch end time and log time cost of current epoch
+1 -2
View File
@@ -16,7 +16,7 @@ setup(
python_requires=">=3.6",
install_requires=[
"torch>=1.8.0",
"gluonts>=0.7.0",
"gluonts>=0.8.0",
"holidays",
"numpy~=1.16",
"pandas~=1.1",
@@ -24,7 +24,6 @@ setup(
"tqdm",
"matplotlib",
"tensorboard",
"wandb",
],
test_suite="tests",
tests_require=["flake8", "pytest"],
+1 -1
View File
@@ -37,7 +37,7 @@ def test_distribution():
estimator = DeepAREstimator(
freq=freq,
prediction_length=prediction_length,
input_size=48,
input_size=15,
trainer=Trainer(epochs=1, num_batches_per_epoch=1),
distr_output=StudentTOutput(),
)
@@ -56,7 +56,7 @@ def learn_distribution(
distr = distr_output.distribution(distr_args)
loss = -distr.log_prob(sample_label).mean()
loss.backward()
clip_grad_norm_(arg_proj.parameters(), 10.0)
#clip_grad_norm_(arg_proj.parameters(), 10.0)
optimizer.step()
num_batches += 1
@@ -77,7 +77,7 @@ def learn_distribution(
torch.ones((1, 1, 1)), torch.ones((1, 1)) * 0.1
)
return samples.mean(), samples.std(), percentile_10, percentile_90
return samples.mean(), samples.std(), percentile_10.squeeze(), percentile_90.squeeze()
def test_independent_implicit_quantile() -> None:
@@ -181,7 +181,7 @@ def test_training_with_implicit_quantile_output():
num_batches_per_epoch=3,
batch_size=256,
),
input_size=48,
input_size=15,
)
deepar_predictor = deepar_estimator.train(dataset.train, num_workers=1)
forecast_it, ts_it = make_evaluation_predictions(
@@ -224,7 +224,7 @@ def test_instanciation_of_args_proj():
num_batches_per_epoch=1,
batch_size=256,
),
input_size=48,
input_size=15,
)
assert distr_output.method_calls == 1
deepar_predictor = deepar_estimator.train(dataset.train, num_workers=1)
@@ -258,7 +258,7 @@ def test_instanciation_of_args_proj():
num_batches_per_epoch=1,
batch_size=256,
),
input_size=48,
input_size=15,
)
assert distr_output.method_calls == 3
new_estimator.train(dataset.train, num_workers=1)