From 8da5cb49d94440b625e9c224922421f75050f54b Mon Sep 17 00:00:00 2001 From: Hstellar Date: Wed, 20 Apr 2022 00:57:39 -0400 Subject: [PATCH] cleaned up(still not working) --- pyraformer/pyraformer/.DS_Store => .DS_Store | Bin 6148 -> 6148 bytes pyraformer/.DS_Store | Bin 6148 -> 6148 bytes pyraformer/estimator.py | 122 +++-- pyraformer/lightning_module.py | 125 ++--- pyraformer/module.py | 538 ++++++++++++++++--- pyraformer/pyraformer.ipynb | 2 +- pyraformer/pyraformer/Layers.py | 4 +- 7 files changed, 613 insertions(+), 178 deletions(-) rename pyraformer/pyraformer/.DS_Store => .DS_Store (76%) diff --git a/pyraformer/pyraformer/.DS_Store b/.DS_Store similarity index 76% rename from pyraformer/pyraformer/.DS_Store rename to .DS_Store index 109e4f3061517a841ebdfb9fb2038d2ba6ff5409..b51cf7ad146094638b7f6c85e57113d9fd8a9c87 100644 GIT binary patch literal 6148 zcmeHKQA@)x5WZ~FbsNGS6!b0N>%eX@M0_cA{sAlcpfX!Jwb(Ud?Hn=&eb+zaAMy8i zmt^4F20n?*9Nc|LF4uA&B$opK=8nQ9Kpg-qRKijf%@QF$>4FsWN(qt4Z`^}`A@re# z)`(%iF!0wIptDdg%EKt8~ed!)Ejr}+ovMQ{3z)SWk3}6aLdhAl!RjJicu1#GS*WLt728U z^}WgDuzB3zt#)hL;FG4^wi~?NI+{)^*3SOH$$9rF?kD0^vCHuCk+P+625&f6n%uKD zOk$Bdpr2xEkjeBDNY{~q73qB=VMdz4|H6PiHdgJwBy93D3>XG}%mCdF9F@?~m@AZ5 z2O8M|Aks5J3XZ9lpct;v(U>bl4+>MEh$@un7K5p9T<3b9qcK;g!hz}LgXupr-Jvk~ zcljWrAy2Id(ks%f3-|JmyPe?G{}3zHwyo)LYeVhh}j>cReuArEQfTY0~!@#dH@Cl!Ikb(dJ literal 6148 zcmeHKy-ve05WZ_GQPia)qhFzK5UTJ5ojL`QB8rqMSs2*-EZw>x@d`W`-}z2~HVqOB zLe-sg{_Oj)ozIdlCL&(GY!*awBAP-4N7rc1i1bBSK!;!lt&&#?g5h3%f4_6qx zOf@}IdAz^9tZbfv{dC^5>zns@y*Eti;tV(g&cMYnfSN5bohf?j3^)VMK*xaG4*@Ee z8Aio)b)ZR40N@1XBB`mu~2NHo}wrV0|Nsi1A_nqLmERqLlHwRLn=ej#KPr_ERzM8B?NNP z4TF>Oa|?i~7#NuN<^u(#CjVpNWvrac!>kP^W%vs+kQEhdE@Ym@vblkIALC|r4t@@x ZT|klV%#-;=EIB~N0Ig%%93ZlV831zH9_9c5 delta 314 zcmZoMXfc=|#>B)qu~2NHo}wrR0|Nsi1A_nqLo!1m5N9x?GQ@8_$he%b9wf!h5D$a} z43)^z4DmVXhQZ1CxdlKKK#-*jBygzA&3AE0%E?axssNhT5WbTAh$BALDR9+Bgq0U$ zAUmUhVR8bagaC@$7{AR13QA4(W8`IwoE*cb4JKvyQFJl|06n~fWivYmKL;>mfSm8l Vlles~Ie-BJL<|g@14Onk0{}1VLzn;n diff --git a/pyraformer/estimator.py b/pyraformer/estimator.py index efeab40..bb23ca6 100644 --- a/pyraformer/estimator.py +++ b/pyraformer/estimator.py @@ -35,6 +35,7 @@ from module import PyraformerLRModel from torch.utils.data import DataLoader from tools import SingleStepLoss as LossFactory from torch.utils.data.sampler import RandomSampler + PREDICTION_INPUT_NAMES = [ "feat_static_cat", "feat_static_real", @@ -56,30 +57,33 @@ class PyraformerEstimator(PyTorchLightningEstimator): self, freq: str, prediction_length: int, - #Train parameters + # Train parameters inner_batch: int = 8, lr: float = 1e-5, visualize_fre: int = 2000, pretrain: bool = True, - hard_sample_mining:bool=True, + hard_sample_mining: bool = True, covariate_size: int = 3, - - # Model parameters - num_seq: int = 370,# - decoder: str = 'FC',# selection: [FC, attention] + # Model parameters + num_seq: int = 370, # + decoder: str = "FC", # selection: [FC, attention] context_length: Optional[int] = None, input_size: int = 1, dropout: float = 0.1, d_model: int = 512, d_inner_hid: int = 512, d_k: int = 128, - d_v:int = 128, + d_v: int = 128, + d_bottleneck: int = 128, num_heads: int = 4, n_layer: int = 4, + enc_in: int = 1, # depends on dataset used + CSCM: str = "Bottleneck_Construct", # [Bottleneck_Construct, Conv_Construct, MaxPooling_Construct, AvgPooling_Construct] + embed_type: str = "CustomEmbedding", #[DataEmbedding, CustomEmbedding] + truncate: bool = False, # loss: DistributionLoss = LossFactory, ignore_zero: bool = True, - single_step: bool = True,#if False, Multistep=True - + single_step: bool = True, # if False, Multistep=True inner_size: int = 3, use_tvm: bool = False, num_feat_dynamic_real: int = 0, @@ -98,7 +102,7 @@ class PyraformerEstimator(PyTorchLightningEstimator): trainer_kwargs: Optional[Dict[str, Any]] = dict(), train_sampler: Optional[InstanceSampler] = None, validation_sampler: Optional[InstanceSampler] = None, - window_size: int = [4, 4, 4] + window_size: int = [4, 4, 4], ) -> None: trainer_kwargs = { "max_epochs": 10, @@ -121,15 +125,25 @@ class PyraformerEstimator(PyTorchLightningEstimator): self.d_inner_hid = d_inner_hid self.d_k = d_k self.d_v = d_v + self.d_bottleneck = d_bottleneck self.num_heads = num_heads self.n_layer = n_layer self.single_step = single_step self.ignore_zero = ignore_zero - self.loss = LossFactory(self.ignore_zero) if self.single_step==True else torch.nn.MSELoss(reduction='none') + self.decoder = decoder + self.enc_in = enc_in + self.CSCM = CSCM + self.embed_type = embed_type + self.truncate = truncate + self.loss = ( + LossFactory(self.ignore_zero) + if self.single_step == True + else torch.nn.MSELoss(reduction="none") + ) self.batch_size = batch_size self.distr_output = distr_output - self.window_size = window_size#[4,4,4]#window_size + self.window_size = window_size # [4,4,4]#window_size self.inner_size = inner_size self.use_tvm = use_tvm self.prediction_length = prediction_length @@ -319,25 +333,67 @@ class PyraformerEstimator(PyTorchLightningEstimator): def create_lightning_module(self) -> PyraformerLightningModule: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if self.single_step: - model = PyraformerSSModel(freq= self.freq, covariate_size = self.covariate_size, - num_seq=self.num_seq, input_size = self.input_size, dropout = self.dropout, d_model = self.d_model, - d_inner_hid = self.d_inner_hid, d_k = self.d_k, d_v = self.d_v, - num_heads = self.num_heads, n_layer = self.n_layer, loss = self.loss, - window_size = self.window_size, inner_size = self.inner_size, - use_tvm = self.use_tvm, prediction_length = self.prediction_length,context_length = self.context_length, lags_seq = self.lags_seq, num_feat_dynamic_real= self.num_feat_dynamic_real, - num_feat_static_cat = self.num_feat_static_cat, - num_feat_static_real = self.num_feat_static_real, - cardinality = self.cardinality, - embedding_dimension = self.embedding_dimension, - distr_output=self.distr_output, - scaling=self.scaling,num_parallel_samples=self.num_parallel_samples, device=device) - # else: - # model = PyraformerLRModel(freq= self.freq, covariate_size = self.covariate_size, - # num_seq=self.num_seq, input_size = self.input_size, dropout = self.dropout, d_model = self.d_model, - # d_inner_hid = self.d_inner_hid, d_k = self.d_k, d_v = self.d_v, - # num_heads = self.num_heads, n_layer = self.n_layer, loss = self.loss, - # window_size = self.window_size, inner_size = self.inner_size, - # use_tvm = self.use_tvm, prediction_length = self.prediction_length,context_length = self.context_length, lags_seq = self.lags_seq, device=device) + model = PyraformerSSModel( + freq=self.freq, + covariate_size=self.covariate_size, + num_seq=self.num_seq, + input_size=self.input_size, + dropout=self.dropout, + d_model = self.d_model, + d_inner_hid=self.d_inner_hid, + d_k=self.d_k, + d_v=self.d_v, + num_heads=self.num_heads, + n_layer=self.n_layer, + loss=self.loss, + window_size=self.window_size, + inner_size=self.inner_size, + use_tvm=self.use_tvm, + prediction_length=self.prediction_length, + context_length=self.context_length, + lags_seq=self.lags_seq, + num_feat_dynamic_real=self.num_feat_dynamic_real, + num_feat_static_cat=self.num_feat_static_cat, + num_feat_static_real=self.num_feat_static_real, + cardinality=self.cardinality, + embedding_dimension=self.embedding_dimension, + distr_output=self.distr_output, + scaling=self.scaling, + num_parallel_samples=self.num_parallel_samples, + device=device, + ) + else: + model = PyraformerLRModel( + predict_step=self.prediction_length, + d_model=self.d_model, + input_size=self.input_size, + decoder=self.decoder, + window_size=self.window_size, + truncate=self.truncate, + d_inner_hid=self.d_inner_hid, + d_k=self.d_k, + d_v=self.d_v, + dropout=self.dropout, + enc_in=self.enc_in, + covariate_size=self.covariate_size, + seq_num=self.num_seq, + CSCM=self.CSCM, + d_bottleneck=self.d_bottleneck, + num_head=self.num_heads, + n_layer=self.n_layer, + inner_size=self.inner_size, + use_tvm=self.use_tvm, + prediction_length=self.prediction_length, + context_length=self.context_length, + lags_seq=self.lags_seq, + num_feat_dynamic_real=self.num_feat_dynamic_real, + num_feat_static_cat=self.num_feat_static_cat, + num_feat_static_real=self.num_feat_static_real, + cardinality=self.cardinality, + embedding_dimension=self.embedding_dimension, + num_parallel_samples=self.num_parallel_samples, + embed_type = self.embed_type, + distr_output= self.distr_output, + device=device, + ) return PyraformerLightningModule(model=model, loss=self.loss) - - diff --git a/pyraformer/lightning_module.py b/pyraformer/lightning_module.py index 5753cca..658bdec 100644 --- a/pyraformer/lightning_module.py +++ b/pyraformer/lightning_module.py @@ -6,77 +6,80 @@ from module import PyraformerSSModel from module import PyraformerLRModel from tools import SingleStepLoss as LossFactory from tools import AE_loss + # from module import PyraformerModel class PyraformerLightningModule(pl.LightningModule): - def __init__(self, model: PyraformerSSModel, loss: DistributionLoss = LossFactory, lr: float = 1e-5, weight_decay: float = 1e-8,) -> None: - super().__init__() - self.save_hyperparameters() - self.model = model - self.loss = loss - self.lr = lr - self.weight_decay = weight_decay + def __init__( + self, + model: PyraformerSSModel, + loss: DistributionLoss = LossFactory, + lr: float = 1e-5, + weight_decay: float = 1e-8, + ) -> None: + super().__init__() + self.save_hyperparameters() + self.model = model + self.loss = loss + self.lr = lr + self.weight_decay = weight_decay - def training_step(self, batch, batch_idx: int): + def training_step(self, batch, batch_idx: int): - """Execute training step""" - train_loss = self(batch) - self.log( - "train_loss", - train_loss, - on_epoch=True, - on_step=False, - prog_bar=True, - ) - return train_loss + """Execute training step""" + train_loss = self(batch) + self.log( + "train_loss", + train_loss, + on_epoch=True, + on_step=False, + prog_bar=True, + ) + return train_loss - def validation_step(self, batch, batch_idx: int): - """Execute validation step""" - with torch.inference_mode(): - val_loss = self(batch) - self.log("val_loss", val_loss, on_epoch=True, on_step=False, prog_bar=True) - return val_loss + def validation_step(self, batch, batch_idx: int): + """Execute validation step""" + with torch.inference_mode(): + val_loss = self(batch) + self.log("val_loss", val_loss, on_epoch=True, on_step=False, prog_bar=True) + return val_loss - def configure_optimizers(self): - """Returns the optimizer to use""" - return torch.optim.Adam( - self.model.parameters(), - lr=self.lr, - weight_decay=self.weight_decay, - ) + def configure_optimizers(self): + """Returns the optimizer to use""" + return torch.optim.Adam( + self.model.parameters(), + lr=self.lr, + weight_decay=self.weight_decay, + ) - def forward(self, batch): - feat_static_cat = batch["feat_static_cat"] - feat_static_real = batch["feat_static_real"] - past_time_feat = batch["past_time_feat"] - past_target = batch["past_target"] - future_time_feat = batch["future_time_feat"] - future_target = batch["future_target"] - past_observed_values = batch["past_observed_values"] - future_observed_values = batch["future_observed_values"] - - Pyraformer_inputs, scale, _ = self.model.create_network_inputs( - feat_static_cat, - feat_static_real, - past_time_feat, - past_target, - past_observed_values, - future_time_feat, - future_target, - ) - params = self.model.output_params(Pyraformer_inputs) - distr = self.model.output_distribution(params, scale) - - loss_values = self.loss(distr, future_target) - - if len(self.model.target_shape) == 0: - loss_weights = future_observed_values - else: - loss_weights = future_observed_values.min(dim=-1, keepdim=False) - - return weighted_average(loss_values, weights=loss_weights) + def forward(self, batch): + feat_static_cat = batch["feat_static_cat"] + feat_static_real = batch["feat_static_real"] + past_time_feat = batch["past_time_feat"] + past_target = batch["past_target"] + future_time_feat = batch["future_time_feat"] + future_target = batch["future_target"] + past_observed_values = batch["past_observed_values"] + future_observed_values = batch["future_observed_values"] + Pyraformer_inputs, scale, _ = self.model.create_network_inputs( + feat_static_cat, + feat_static_real, + past_time_feat, + past_target, + past_observed_values, + future_time_feat, + future_target, + ) + params = self.model.output_params(Pyraformer_inputs) + distr = self.model.output_distribution(params, scale) + loss_values = self.loss(distr, future_target) + if len(self.model.target_shape) == 0: + loss_weights = future_observed_values + else: + loss_weights = future_observed_values.min(dim=-1, keepdim=False) + return weighted_average(loss_values, weights=loss_weights) diff --git a/pyraformer/module.py b/pyraformer/module.py index bdd0f6b..a30d2eb 100644 --- a/pyraformer/module.py +++ b/pyraformer/module.py @@ -8,19 +8,43 @@ from gluonts.torch.modules.feature import FeatureEmbedder from gluonts.torch.modules.scaler import MeanScaler, NOPScaler from pyraformer.Layers import EncoderLayer, Predictor, Decoder -from pyraformer.Layers import Bottleneck_Construct, Conv_Construct, MaxPooling_Construct, AvgPooling_Construct -from pyraformer.Layers import get_mask, refer_points, get_k_q, get_q_k, get_subsequent_mask +from pyraformer.Layers import ( + Bottleneck_Construct, + Conv_Construct, + MaxPooling_Construct, + AvgPooling_Construct, +) +from pyraformer.Layers import ( + get_mask, + refer_points, + get_k_q, + get_q_k, + get_subsequent_mask, +) from pyraformer.embed import SingleStepEmbedding, DataEmbedding, CustomEmbedding class EncoderSS(nn.Module): @validated() - def __init__(self, covariate_size, - num_seq, input_size ,dropout , d_model, - d_inner_hid, d_k, d_v , - num_heads , n_layer, loss , - window_size , inner_size, - use_tvm, prediction_length, device): + def __init__( + self, + covariate_size, + num_seq, + input_size, + dropout, + d_model, + d_inner_hid, + d_k, + d_v, + num_heads, + n_layer, + loss, + window_size, + inner_size, + use_tvm, + prediction_length, + device, + ): super().__init__() self.d_model = d_model @@ -30,21 +54,49 @@ class EncoderSS(nn.Module): self.indexes = refer_points(self.all_size, window_size, device) if use_tvm: - - assert len(set(self.window_size)) == 1, "Only constant window size is supported." + + assert ( + len(set(self.window_size)) == 1 + ), "Only constant window size is supported." q_k_mask = get_q_k(input_size, inner_size, window_size[0], device) k_q_mask = get_k_q(q_k_mask) - self.layers = nn.ModuleList([ - EncoderLayer(d_model, d_inner_hid, num_heads, d_k, d_v, dropout=dropout, \ - normalize_before=False, use_tvm=True, q_k_mask=q_k_mask, k_q_mask=k_q_mask) for i in range(n_layer) - ]) + self.layers = nn.ModuleList( + [ + EncoderLayer( + d_model, + d_inner_hid, + num_heads, + d_k, + d_v, + dropout=dropout, + normalize_before=False, + use_tvm=True, + q_k_mask=q_k_mask, + k_q_mask=k_q_mask, + ) + for i in range(n_layer) + ] + ) else: - self.layers = nn.ModuleList([ - EncoderLayer(d_model, d_inner_hid, num_heads, d_k, d_v, dropout=dropout, \ - normalize_before=False) for i in range(n_layer) - ]) + self.layers = nn.ModuleList( + [ + EncoderLayer( + d_model, + d_inner_hid, + num_heads, + d_k, + d_v, + dropout=dropout, + normalize_before=False, + ) + for i in range(n_layer) + ] + ) + + self.embedding = SingleStepEmbedding( + covariate_size, num_seq, d_model, input_size, device + ) - self.embedding = SingleStepEmbedding(covariate_size, num_seq, d_model, input_size, device) self.conv_layers = Bottleneck_Construct(d_model, window_size, d_k) def forward(self, sequence): @@ -57,21 +109,38 @@ class EncoderSS(nn.Module): for i in range(len(self.layers)): seq_enc, _ = self.layers[i](seq_enc, mask) - indexes = self.indexes.repeat(seq_enc.size(0), 1, 1, seq_enc.size(2)).to(seq_enc.device) + indexes = self.indexes.repeat(seq_enc.size(0), 1, 1, seq_enc.size(2)).to( + seq_enc.device + ) indexes = indexes.view(seq_enc.size(0), -1, seq_enc.size(2)) all_enc = torch.gather(seq_enc, 1, indexes) all_enc = all_enc.view(seq_enc.size(0), self.all_size[0], -1) return all_enc + class PyraformerSSModel(nn.Module): @validated() - def __init__(self, freq, covariate_size, - num_seq, input_size ,dropout, d_model, - d_inner_hid, d_k, d_v, - num_heads , n_layer, loss , - window_size , inner_size, - use_tvm, prediction_length,context_length,lags_seq, + def __init__( + self, + freq, + covariate_size, + num_seq, + input_size, + dropout, + d_model, + d_inner_hid, + d_k, + d_v, + num_heads, + n_layer, + loss, + window_size, + inner_size, + use_tvm, + prediction_length, + context_length, + lags_seq, num_feat_dynamic_real, num_feat_static_cat, num_feat_static_real, @@ -79,25 +148,41 @@ class PyraformerSSModel(nn.Module): embedding_dimension, distr_output, # loss: DistributionLoss = NegativeLogLikelihood(), - scaling,num_parallel_samples,device): - + scaling, + num_parallel_samples, + device, + + ): super().__init__() self.context_length = context_length self.lags_seq = lags_seq or get_lags_for_frequency(freq_str=freq) - self.encoder = EncoderSS(covariate_size, - num_seq, input_size ,dropout , d_model, - d_inner_hid, d_k, d_v , - num_heads , n_layer, loss , - window_size , inner_size, - use_tvm, prediction_length, device) + self.encoder = EncoderSS( + covariate_size, + num_seq, + input_size, + dropout, + d_model, + d_inner_hid, + d_k, + d_v, + num_heads, + n_layer, + loss, + window_size, + inner_size, + use_tvm, + prediction_length, + device, + ) - # convert hidden vectors into two scalar + # convert hidden vectors into two scalar self.mean_hidden = Predictor(4 * d_model, 1) self.var_hidden = Predictor(4 * d_model, 1) - + self.softplus = nn.Softplus() - + self.distr_output = distr_output + def forward(self, data): enc_output = self.encoder(data) @@ -114,10 +199,11 @@ class PyraformerSSModel(nn.Module): sample_mu = mu[:, -1] * v sample_sigma = sigma[:, -1] * v return sample_mu, sample_sigma - + @property def _past_length(self) -> int: return self.context_length + max(self.lags_seq) + @property def _number_of_features(self) -> int: return ( @@ -126,6 +212,7 @@ class PyraformerSSModel(nn.Module): + self.num_feat_static_real + 1 # the log(scale) ) + def get_lagged_subsequences( self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0 ) -> torch.Tensor: @@ -178,7 +265,7 @@ class PyraformerSSModel(nn.Module): assert ( features is None or features.shape[2] == self._number_of_features ), f"{features.shape[2]}, expected {self._number_of_features}" - + def create_network_inputs( self, feat_static_cat: torch.Tensor, @@ -274,70 +361,159 @@ class PyraformerSSModel(nn.Module): if trailing_n is not None: sliced_params = [p[:, -trailing_n:] for p in params] return self.distr_output.distribution(sliced_params, scale=scale) - + + class Encoder(nn.Module): @validated() - def __init__(self, model, window_size, truncate, input_size, inner_size,decoder,d_model, d_k,d_v,d_inner_hid, dropout,n_layer ,enc_in,covariate_size,seq_num,CSCM,d_bottleneck,num_head,use_tvm,device): + def __init__( + self, + # model, + window_size, + truncate, + input_size, + inner_size, + decoder, + d_model, + d_k, + d_v, + d_inner_hid, + dropout, + n_layer, + enc_in, + covariate_size, + seq_num, + CSCM, + d_bottleneck, + num_head, + use_tvm, + embed_type, + device, + ): super().__init__() self.d_model = d_model - self.model_type = model + # self.model_type = model self.window_size = window_size self.truncate = truncate - if decoder == 'attention': - self.mask, self.all_size = get_mask(input_size, window_size, inner_size, device) + if decoder == "attention": + self.mask, self.all_size = get_mask( + input_size, window_size, inner_size, device + ) else: - self.mask, self.all_size = get_mask(input_size+1, window_size, inner_size, device) + self.mask, self.all_size = get_mask( + input_size + 1, window_size, inner_size, device + ) self.decoder_type = decoder - if decoder == 'FC': + if decoder == "FC": self.indexes = refer_points(self.all_size, window_size, device) if use_tvm: - assert len(set(self.window_size)) == 1, "Only constant window size is supported." - padding = 1 if decoder == 'FC' else 0 - q_k_mask = get_q_k(input_size + padding, inner_size, window_size[0],device) + assert ( + len(set(self.window_size)) == 1 + ), "Only constant window size is supported." + padding = 1 if decoder == "FC" else 0 + q_k_mask = get_q_k(input_size + padding, inner_size, window_size[0], device) k_q_mask = get_k_q(q_k_mask) - self.layers = nn.ModuleList([ - EncoderLayer(d_model, d_inner_hid, num_head, d_k, d_v, dropout=dropout, \ - normalize_before=False, use_tvm=True, q_k_mask=q_k_mask, k_q_mask=k_q_mask) for i in range(n_layer) - ]) + self.layers = nn.ModuleList( + [ + EncoderLayer( + d_model, + d_inner_hid, + num_head, + d_k, + d_v, + dropout=dropout, + normalize_before=False, + use_tvm=True, + q_k_mask=q_k_mask, + k_q_mask=k_q_mask, + ) + for i in range(n_layer) + ] + ) else: - self.layers = nn.ModuleList([ - EncoderLayer(d_model, d_inner_hid, num_head, d_k, d_v, dropout=dropout, \ - normalize_before=False) for i in range(n_layer) - ]) - - if opt.embed_type == 'CustomEmbedding': - self.enc_embedding = CustomEmbedding(enc_in, d_model, covariate_size, seq_num, dropout) + self.layers = nn.ModuleList( + [ + EncoderLayer( + d_model, + d_inner_hid, + num_head, + d_k, + d_v, + dropout=dropout, + normalize_before=False, + ) + for i in range(n_layer) + ] + ) + + if embed_type == "CustomEmbedding": + self.enc_embedding = CustomEmbedding( + enc_in, d_model, covariate_size, seq_num, dropout + ) else: self.enc_embedding = DataEmbedding(enc_in, d_model, dropout) - + self.conv_layers = eval(CSCM)(d_model, window_size, d_bottleneck) - - + def forward(self, x_enc, x_mark_enc): seq_enc = self.enc_embedding(x_enc, x_mark_enc) - + mask = self.mask.repeat(len(seq_enc), 1, 1).to(x_enc.device) seq_enc = self.conv_layers(seq_enc) for i in range(len(self.layers)): seq_enc, _ = self.layers[i](seq_enc, mask) - - if self.decoder_type == 'FC': - indexes = self.indexes.repeat(seq_enc.size(0), 1, 1, seq_enc.size(2)).to(seq_enc.device) + + if self.decoder_type == "FC": + indexes = self.indexes.repeat(seq_enc.size(0), 1, 1, seq_enc.size(2)).to( + seq_enc.device + ) indexes = indexes.view(seq_enc.size(0), -1, seq_enc.size(2)) all_enc = torch.gather(seq_enc, 1, indexes) seq_enc = all_enc.view(seq_enc.size(0), self.all_size[0], -1) - elif self.decoder_type == 'attention' and self.truncate: - seq_enc = seq_enc[:, :self.all_size[0]] - + elif self.decoder_type == "attention" and self.truncate: + seq_enc = seq_enc[:, : self.all_size[0]] + return seq_enc - + class PyraformerLRModel(nn.Module): @validated() - def __init__(self, predict_step, d_model, input_size, decoder, window_size, truncate, model,d_inner_hid,d_k,d_v,dropout,enc_in,covariate_size,seq_num,CSCM,d_bottleneck,num_head,use_tvm,device): + def __init__( + self, + predict_step, + d_model, + input_size, + decoder, + window_size, + truncate, + d_inner_hid, + d_k, + d_v, + dropout, + enc_in, + covariate_size, + seq_num, + CSCM, + d_bottleneck, + num_head, + n_layer, + inner_size, + use_tvm, + prediction_length, + context_length, + lags_seq, + num_feat_dynamic_real, + num_feat_static_cat, + num_feat_static_real, + cardinality, + embedding_dimension, + num_parallel_samples, + embed_type, + distr_output, + device + ): super().__init__() self.predict_step = predict_step @@ -345,13 +521,50 @@ class PyraformerLRModel(nn.Module): self.input_size = input_size self.decoder_type = decoder self.channels = enc_in - - self.encoder = Encoder(model, window_size, truncate, input_size, inner_size,decoder,d_model, d_k,d_v,d_inner_hid, dropout,n_layer, enc_in,covariate_size,seq_num,CSCM,d_bottleneck,num_head,use_tvm,device) - if decoder == 'attention': + self.distr_output = distr_output + self.context_length = context_length + self.lags_seq = lags_seq + + self.encoder = Encoder( + # model, + window_size, + truncate, + input_size, + inner_size, + decoder, + d_model, + d_k, + d_v, + d_inner_hid, + dropout, + n_layer, + enc_in, + covariate_size, + seq_num, + CSCM, + d_bottleneck, + num_head, + use_tvm, + embed_type, + device, + ) + if decoder == "attention": mask = get_subsequent_mask(input_size, window_size, predict_step, truncate) - self.decoder = Decoder(model,d_model,d_inner_hid,num_head,d_k,d_v,dropout,enc_in,covariate_size,seq_num, mask) + self.decoder = Decoder( + # model, + d_model, + d_inner_hid, + num_head, + d_k, + d_v, + dropout, + enc_in, + covariate_size, + seq_num, + mask, + ) self.predictor = Predictor(d_model, enc_in) - elif opt.decoder == 'FC': + elif decoder == "FC": self.predictor = Predictor(4 * d_model, predict_step * enc_in) def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, pretrain): @@ -364,18 +577,181 @@ class PyraformerLRModel(nn.Module): type_prediction: batch*seq_len*num_classes (not normalized); time_prediction: batch*seq_len. """ - if self.decoder_type == 'attention': + if self.decoder_type == "attention": enc_output = self.encoder(x_enc, x_mark_enc) dec_enc = self.decoder(x_dec, x_mark_dec, enc_output) if pretrain: - dec_enc = torch.cat([enc_output[:, :self.input_size], dec_enc], dim=1) + dec_enc = torch.cat([enc_output[:, : self.input_size], dec_enc], dim=1) pred = self.predictor(dec_enc) else: pred = self.predictor(dec_enc) - elif self.decoder_type == 'FC': + elif self.decoder_type == "FC": enc_output = self.encoder(x_enc, x_mark_enc)[:, -1, :] - pred = self.predictor(enc_output).view(enc_output.size(0), self.predict_step, -1) + pred = self.predictor(enc_output).view( + enc_output.size(0), self.predict_step, -1 + ) return pred + + @property + def _past_length(self) -> int: + return self.predict_step #+ max(0,self.lags_seq) + @property + def _number_of_features(self) -> int: + return ( + sum(self.embedding_dimension) + + self.num_feat_dynamic_real + + self.num_feat_static_real + + 1 # the log(scale) + ) + + def get_lagged_subsequences( + self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0 + ) -> torch.Tensor: + """ + Returns lagged subsequences of a given sequence. + Parameters + ---------- + sequence : Tensor + the sequence from which lagged subsequences should be extracted. + Shape: (N, T, C). + subsequences_length : int + length of the subsequences to be extracted. + shift: int + shift the lags by this amount back. + Returns + -------- + lagged : Tensor + a tensor of shape (N, S, C, I), where S = subsequences_length and + I = len(indices), containing lagged subsequences. Specifically, + lagged[i, j, :, k] = sequence[i, -indices[k]-S+j, :]. + """ + sequence_length = sequence.shape[1] + indices = [lag - shift for lag in self.lags_seq] + + assert max(indices) + subsequences_length <= sequence_length, ( + f"lags cannot go further than history length, found lag {max(indices)} " + f"while history length is only {sequence_length}" + ) + + lagged_values = [] + for lag_index in indices: + begin_index = -lag_index - subsequences_length + end_index = -lag_index if lag_index > 0 else None + lagged_values.append(sequence[:, begin_index:end_index, ...]) + return torch.stack(lagged_values, dim=-1) + + def _check_shapes( + self, + prior_input: torch.Tensor, + inputs: torch.Tensor, + features: Optional[torch.Tensor], + ) -> None: + assert len(prior_input.shape) == len(inputs.shape) + assert ( + len(prior_input.shape) == 2 and self.input_size == 1 + ) or prior_input.shape[2] == self.input_size + assert (len(inputs.shape) == 2 and self.input_size == 1) or inputs.shape[ + -1 + ] == self.input_size + assert ( + features is None or features.shape[2] == self._number_of_features + ), f"{features.shape[2]}, expected {self._number_of_features}" + + def create_network_inputs( + self, + feat_static_cat: torch.Tensor, + feat_static_real: torch.Tensor, + past_time_feat: torch.Tensor, + past_target: torch.Tensor, + past_observed_values: torch.Tensor, + future_time_feat: Optional[torch.Tensor] = None, + future_target: Optional[torch.Tensor] = None, + ): + # time feature + time_feat = ( + torch.cat( + ( + past_time_feat[:, self._past_length - self.context_length :, ...], + future_time_feat, + ), + dim=1, + ) + if future_target is not None + else past_time_feat[:, self._past_length - self.context_length :, ...] + ) + + # target + context = past_target[:, -self.context_length :] + observed_context = past_observed_values[:, -self.context_length :] + _, scale = self.scaler(context, observed_context) + + inputs = ( + torch.cat((past_target, future_target), dim=1) / scale + if future_target is not None + else past_target / scale + ) + + inputs_length = ( + self._past_length + self.prediction_length + if future_target is not None + else self._past_length + ) + assert inputs.shape[1] == inputs_length + + subsequences_length = ( + self.context_length + self.prediction_length + if future_target is not None + else self.context_length + ) + + # embeddings + embedded_cat = self.embedder(feat_static_cat) + static_feat = torch.cat( + (embedded_cat, feat_static_real, scale.log()), + dim=1, + ) + expanded_static_feat = static_feat.unsqueeze(1).expand( + -1, time_feat.shape[1], -1 + ) + + features = torch.cat((expanded_static_feat, time_feat), dim=-1) + + # self._check_shapes(prior_input, inputs, features) + + # sequence = torch.cat((prior_input, inputs), dim=1) + lagged_sequence = self.get_lagged_subsequences( + sequence=inputs, + subsequences_length=subsequences_length, + ) + + lags_shape = lagged_sequence.shape + reshaped_lagged_sequence = lagged_sequence.reshape( + lags_shape[0], lags_shape[1], -1 + ) + + transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1) + + return transformer_inputs, scale, static_feat + + def output_params(self, transformer_inputs): + enc_input = transformer_inputs[:, : self.context_length, ...] + dec_input = transformer_inputs[:, self.context_length :, ...] + + enc_out = self.transformer.encoder(enc_input) + dec_output = self.transformer.decoder( + dec_input, enc_out, tgt_mask=self.tgt_mask + ) + + return self.param_proj(dec_output) + + @torch.jit.ignore + def output_distribution( + self, params, scale=None, trailing_n=None + ) -> torch.distributions.Distribution: + sliced_params = params + if trailing_n is not None: + sliced_params = [p[:, -trailing_n:] for p in params] + return self.distr_output.distribution(sliced_params, scale=scale) diff --git a/pyraformer/pyraformer.ipynb b/pyraformer/pyraformer.ipynb index 20cf2ad..ff9d08b 100644 --- a/pyraformer/pyraformer.ipynb +++ b/pyraformer/pyraformer.ipynb @@ -1 +1 @@ -{"cells":[{"cell_type":"code","execution_count":1,"id":"b19f0e22","metadata":{"id":"b19f0e22","executionInfo":{"status":"ok","timestamp":1650399031126,"user_tz":240,"elapsed":6,"user":{"displayName":"Hena Ghonia","userId":"03246241722682988409"}}},"outputs":[],"source":["%matplotlib inline"]},{"cell_type":"code","source":["from google.colab import drive\n","drive.mount('/content/drive/')\n","%cd /content/drive/MyDrive/Udem/Sem2/Representation_Learning/IFT6135_Programming/Pyraformer/transformer"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"M6fjjCc2w6rX","executionInfo":{"status":"ok","timestamp":1650399045066,"user_tz":240,"elapsed":13945,"user":{"displayName":"Hena Ghonia","userId":"03246241722682988409"}},"outputId":"f157bf37-2049-4cd7-ce9a-87817bf67d35"},"id":"M6fjjCc2w6rX","execution_count":2,"outputs":[{"output_type":"stream","name":"stdout","text":["Mounted at /content/drive/\n","/content/drive/MyDrive/Udem/Sem2/Representation_Learning/IFT6135_Programming/Pyraformer/transformer\n"]}]},{"cell_type":"code","source":["!pip install pytorch-lightning"],"metadata":{"id":"_a4EOr95gtxR","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1650399053671,"user_tz":240,"elapsed":8611,"user":{"displayName":"Hena Ghonia","userId":"03246241722682988409"}},"outputId":"6d2343ce-108d-4b4a-ede7-1b68fe0cdf03"},"id":"_a4EOr95gtxR","execution_count":3,"outputs":[{"output_type":"stream","name":"stdout","text":["Collecting pytorch-lightning\n"," Downloading pytorch_lightning-1.6.1-py3-none-any.whl (582 kB)\n","\u001b[?25l\r\u001b[K |▋ | 10 kB 31.9 MB/s eta 0:00:01\r\u001b[K |█▏ | 20 kB 36.1 MB/s eta 0:00:01\r\u001b[K |█▊ | 30 kB 39.6 MB/s eta 0:00:01\r\u001b[K |██▎ | 40 kB 31.7 MB/s eta 0:00:01\r\u001b[K |██▉ | 51 kB 22.9 MB/s eta 0:00:01\r\u001b[K |███▍ | 61 kB 26.0 MB/s eta 0:00:01\r\u001b[K |████ | 71 kB 27.1 MB/s eta 0:00:01\r\u001b[K |████▌ | 81 kB 28.9 MB/s eta 0:00:01\r\u001b[K |█████ | 92 kB 31.1 MB/s eta 0:00:01\r\u001b[K |█████▋ | 102 kB 33.0 MB/s eta 0:00:01\r\u001b[K |██████▏ | 112 kB 33.0 MB/s eta 0:00:01\r\u001b[K |██████▊ | 122 kB 33.0 MB/s eta 0:00:01\r\u001b[K |███████▎ | 133 kB 33.0 MB/s eta 0:00:01\r\u001b[K |███████▉ | 143 kB 33.0 MB/s eta 0:00:01\r\u001b[K |████████▍ | 153 kB 33.0 MB/s eta 0:00:01\r\u001b[K |█████████ | 163 kB 33.0 MB/s eta 0:00:01\r\u001b[K |█████████▋ | 174 kB 33.0 MB/s eta 0:00:01\r\u001b[K |██████████▏ | 184 kB 33.0 MB/s eta 0:00:01\r\u001b[K |██████████▊ | 194 kB 33.0 MB/s eta 0:00:01\r\u001b[K |███████████▎ | 204 kB 33.0 MB/s eta 0:00:01\r\u001b[K |███████████▉ | 215 kB 33.0 MB/s eta 0:00:01\r\u001b[K |████████████▍ | 225 kB 33.0 MB/s eta 0:00:01\r\u001b[K |█████████████ | 235 kB 33.0 MB/s eta 0:00:01\r\u001b[K |█████████████▌ | 245 kB 33.0 MB/s eta 0:00:01\r\u001b[K |██████████████ | 256 kB 33.0 MB/s eta 0:00:01\r\u001b[K |██████████████▋ | 266 kB 33.0 MB/s eta 0:00:01\r\u001b[K |███████████████▏ | 276 kB 33.0 MB/s eta 0:00:01\r\u001b[K |███████████████▊ | 286 kB 33.0 MB/s eta 0:00:01\r\u001b[K |████████████████▎ | 296 kB 33.0 MB/s eta 0:00:01\r\u001b[K |████████████████▉ | 307 kB 33.0 MB/s eta 0:00:01\r\u001b[K |█████████████████▍ | 317 kB 33.0 MB/s eta 0:00:01\r\u001b[K |██████████████████ | 327 kB 33.0 MB/s eta 0:00:01\r\u001b[K |██████████████████▋ | 337 kB 33.0 MB/s eta 0:00:01\r\u001b[K |███████████████████▏ | 348 kB 33.0 MB/s eta 0:00:01\r\u001b[K |███████████████████▊ | 358 kB 33.0 MB/s eta 0:00:01\r\u001b[K |████████████████████▎ | 368 kB 33.0 MB/s eta 0:00:01\r\u001b[K |████████████████████▉ | 378 kB 33.0 MB/s eta 0:00:01\r\u001b[K |█████████████████████▍ | 389 kB 33.0 MB/s eta 0:00:01\r\u001b[K |██████████████████████ | 399 kB 33.0 MB/s eta 0:00:01\r\u001b[K |██████████████████████▌ | 409 kB 33.0 MB/s eta 0:00:01\r\u001b[K |███████████████████████ | 419 kB 33.0 MB/s eta 0:00:01\r\u001b[K |███████████████████████▋ | 430 kB 33.0 MB/s eta 0:00:01\r\u001b[K |████████████████████████▏ | 440 kB 33.0 MB/s eta 0:00:01\r\u001b[K |████████████████████████▊ | 450 kB 33.0 MB/s eta 0:00:01\r\u001b[K |█████████████████████████▎ | 460 kB 33.0 MB/s eta 0:00:01\r\u001b[K |█████████████████████████▉ | 471 kB 33.0 MB/s eta 0:00:01\r\u001b[K |██████████████████████████▍ | 481 kB 33.0 MB/s eta 0:00:01\r\u001b[K |███████████████████████████ | 491 kB 33.0 MB/s eta 0:00:01\r\u001b[K |███████████████████████████▋ | 501 kB 33.0 MB/s eta 0:00:01\r\u001b[K |████████████████████████████▏ | 512 kB 33.0 MB/s eta 0:00:01\r\u001b[K |████████████████████████████▊ | 522 kB 33.0 MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▎ | 532 kB 33.0 MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▉ | 542 kB 33.0 MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▍ | 552 kB 33.0 MB/s eta 0:00:01\r\u001b[K |███████████████████████████████ | 563 kB 33.0 MB/s eta 0:00:01\r\u001b[K |███████████████████████████████▌| 573 kB 33.0 MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 582 kB 33.0 MB/s \n","\u001b[?25hRequirement already satisfied: numpy>=1.17.2 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (1.21.6)\n","Collecting pyDeprecate<0.4.0,>=0.3.1\n"," Downloading pyDeprecate-0.3.2-py3-none-any.whl (10 kB)\n","Requirement already satisfied: packaging>=17.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (21.3)\n","Requirement already satisfied: torch>=1.8.* in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (1.10.0+cu111)\n","Requirement already satisfied: tqdm>=4.41.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (4.64.0)\n","Collecting fsspec[http]!=2021.06.0,>=2021.05.0\n"," Downloading fsspec-2022.3.0-py3-none-any.whl (136 kB)\n","\u001b[K |████████████████████████████████| 136 kB 69.3 MB/s \n","\u001b[?25hCollecting torchmetrics>=0.4.1\n"," Downloading torchmetrics-0.8.0-py3-none-any.whl (408 kB)\n","\u001b[K |████████████████████████████████| 408 kB 74.3 MB/s \n","\u001b[?25hRequirement already satisfied: tensorboard>=2.2.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (2.8.0)\n","Collecting PyYAML>=5.4\n"," Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)\n","\u001b[K |████████████████████████████████| 596 kB 63.0 MB/s \n","\u001b[?25hRequirement already satisfied: typing-extensions>=4.0.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (4.1.1)\n","Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (2.23.0)\n","Collecting aiohttp\n"," Downloading aiohttp-3.8.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.1 MB)\n","\u001b[K |████████████████████████████████| 1.1 MB 52.5 MB/s \n","\u001b[?25hRequirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=17.0->pytorch-lightning) (3.0.8)\n","Requirement already satisfied: grpcio>=1.24.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.44.0)\n","Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.35.0)\n","Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (57.4.0)\n","Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (0.4.6)\n","Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (3.3.6)\n","Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.0.0)\n","Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (0.6.1)\n","Requirement already satisfied: protobuf>=3.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (3.17.3)\n","Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.8.1)\n","Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.0.1)\n","Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (0.37.1)\n","Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py>=0.4->tensorboard>=2.2.0->pytorch-lightning) (1.15.0)\n","Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning) (4.8)\n","Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning) (4.2.4)\n","Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning) (0.2.8)\n","Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.7/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.2.0->pytorch-lightning) (1.3.1)\n","Requirement already satisfied: importlib-metadata>=4.4 in /usr/local/lib/python3.7/dist-packages (from markdown>=2.6.8->tensorboard>=2.2.0->pytorch-lightning) (4.11.3)\n","Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard>=2.2.0->pytorch-lightning) (3.8.0)\n","Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.7/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning) (0.4.8)\n","Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (2.10)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (2021.10.8)\n","Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (3.0.4)\n","Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (1.24.3)\n","Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.2.0->pytorch-lightning) (3.2.0)\n","Collecting asynctest==0.13.0\n"," Downloading asynctest-0.13.0-py3-none-any.whl (26 kB)\n","Collecting frozenlist>=1.1.1\n"," Downloading frozenlist-1.3.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (144 kB)\n","\u001b[K |████████████████████████████████| 144 kB 72.4 MB/s \n","\u001b[?25hRequirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (2.0.12)\n","Collecting async-timeout<5.0,>=4.0.0a3\n"," Downloading async_timeout-4.0.2-py3-none-any.whl (5.8 kB)\n","Collecting yarl<2.0,>=1.0\n"," Downloading yarl-1.7.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (271 kB)\n","\u001b[K |████████████████████████████████| 271 kB 72.1 MB/s \n","\u001b[?25hRequirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (21.4.0)\n","Collecting multidict<7.0,>=4.5\n"," Downloading multidict-6.0.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (94 kB)\n","\u001b[K |████████████████████████████████| 94 kB 4.3 MB/s \n","\u001b[?25hCollecting aiosignal>=1.1.2\n"," Downloading aiosignal-1.2.0-py3-none-any.whl (8.2 kB)\n","Installing collected packages: multidict, frozenlist, yarl, asynctest, async-timeout, aiosignal, pyDeprecate, fsspec, aiohttp, torchmetrics, PyYAML, pytorch-lightning\n"," Attempting uninstall: PyYAML\n"," Found existing installation: PyYAML 3.13\n"," Uninstalling PyYAML-3.13:\n"," Successfully uninstalled PyYAML-3.13\n","Successfully installed PyYAML-6.0 aiohttp-3.8.1 aiosignal-1.2.0 async-timeout-4.0.2 asynctest-0.13.0 frozenlist-1.3.0 fsspec-2022.3.0 multidict-6.0.2 pyDeprecate-0.3.2 pytorch-lightning-1.6.1 torchmetrics-0.8.0 yarl-1.7.2\n"]}]},{"cell_type":"code","source":["!pip install gluonts\n","!pip install datasets"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"k8r9pwTMhpuZ","executionInfo":{"status":"ok","timestamp":1650399244319,"user_tz":240,"elapsed":9776,"user":{"displayName":"Hena Ghonia","userId":"03246241722682988409"}},"outputId":"87f9a7f3-720f-40e4-81de-e773df9b7d1e"},"id":"k8r9pwTMhpuZ","execution_count":6,"outputs":[{"output_type":"stream","name":"stdout","text":["Collecting gluonts\n"," Downloading gluonts-0.9.3-py3-none-any.whl (2.8 MB)\n","\u001b[K |████████████████████████████████| 2.8 MB 24.4 MB/s \n","\u001b[?25hRequirement already satisfied: tqdm~=4.23 in /usr/local/lib/python3.7/dist-packages (from gluonts) (4.64.0)\n","Requirement already satisfied: numpy~=1.16 in /usr/local/lib/python3.7/dist-packages (from gluonts) (1.21.6)\n","Requirement already satisfied: holidays>=0.9 in /usr/local/lib/python3.7/dist-packages (from gluonts) (0.10.5.2)\n","Requirement already satisfied: toolz~=0.10 in /usr/local/lib/python3.7/dist-packages (from gluonts) (0.11.2)\n","Requirement already satisfied: typing-extensions~=4.0 in /usr/local/lib/python3.7/dist-packages (from gluonts) (4.1.1)\n","Requirement already satisfied: matplotlib~=3.0 in /usr/local/lib/python3.7/dist-packages (from gluonts) (3.2.2)\n","Requirement already satisfied: pandas~=1.0 in /usr/local/lib/python3.7/dist-packages (from gluonts) (1.3.5)\n","Collecting pydantic~=1.1\n"," Downloading pydantic-1.9.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (10.9 MB)\n","\u001b[K |████████████████████████████████| 10.9 MB 49.6 MB/s \n","\u001b[?25hRequirement already satisfied: convertdate>=2.3.0 in /usr/local/lib/python3.7/dist-packages (from holidays>=0.9->gluonts) (2.4.0)\n","Requirement already satisfied: korean-lunar-calendar in /usr/local/lib/python3.7/dist-packages (from holidays>=0.9->gluonts) (0.2.1)\n","Requirement already satisfied: hijri-converter in /usr/local/lib/python3.7/dist-packages (from holidays>=0.9->gluonts) (2.2.3)\n","Requirement already satisfied: python-dateutil in /usr/local/lib/python3.7/dist-packages (from holidays>=0.9->gluonts) (2.8.2)\n","Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from holidays>=0.9->gluonts) (1.15.0)\n","Requirement already satisfied: pymeeus<=1,>=0.3.13 in /usr/local/lib/python3.7/dist-packages (from convertdate>=2.3.0->holidays>=0.9->gluonts) (0.5.11)\n","Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib~=3.0->gluonts) (1.4.2)\n","Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib~=3.0->gluonts) (3.0.8)\n","Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib~=3.0->gluonts) (0.11.0)\n","Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas~=1.0->gluonts) (2022.1)\n","Installing collected packages: pydantic, gluonts\n","Successfully installed gluonts-0.9.3 pydantic-1.9.0\n","Collecting datasets\n"," Downloading datasets-2.1.0-py3-none-any.whl (325 kB)\n","\u001b[K |████████████████████████████████| 325 kB 22.3 MB/s \n","\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from datasets) (1.3.5)\n","Collecting xxhash\n"," Downloading xxhash-3.0.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)\n","\u001b[K |████████████████████████████████| 212 kB 62.1 MB/s \n","\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from datasets) (1.21.6)\n","Requirement already satisfied: aiohttp in /usr/local/lib/python3.7/dist-packages (from datasets) (3.8.1)\n","Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.7/dist-packages (from datasets) (4.64.0)\n","Requirement already satisfied: pyarrow>=5.0.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (6.0.1)\n","Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (2.23.0)\n","Collecting huggingface-hub<1.0.0,>=0.1.0\n"," Downloading huggingface_hub-0.5.1-py3-none-any.whl (77 kB)\n","\u001b[K |████████████████████████████████| 77 kB 8.6 MB/s \n","\u001b[?25hRequirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from datasets) (4.11.3)\n","Requirement already satisfied: multiprocess in /usr/local/lib/python3.7/dist-packages (from datasets) (0.70.12.2)\n","Requirement already satisfied: fsspec[http]>=2021.05.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (2022.3.0)\n","Collecting responses<0.19\n"," Downloading responses-0.18.0-py3-none-any.whl (38 kB)\n","Requirement already satisfied: dill in /usr/local/lib/python3.7/dist-packages (from datasets) (0.3.4)\n","Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from datasets) (21.3)\n","Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets) (4.1.1)\n","Requirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets) (6.0)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets) (3.6.0)\n","Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->datasets) (3.0.8)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (2021.10.8)\n","Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (1.24.3)\n","Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (3.0.4)\n","Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (2.10)\n","Collecting urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1\n"," Downloading urllib3-1.25.11-py2.py3-none-any.whl (127 kB)\n","\u001b[K |████████████████████████████████| 127 kB 74.8 MB/s \n","\u001b[?25hRequirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (1.7.2)\n","Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (6.0.2)\n","Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (1.3.0)\n","Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (1.2.0)\n","Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (2.0.12)\n","Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (21.4.0)\n","Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (4.0.2)\n","Requirement already satisfied: asynctest==0.13.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (0.13.0)\n","Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->datasets) (3.8.0)\n","Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->datasets) (2.8.2)\n","Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas->datasets) (2022.1)\n","Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas->datasets) (1.15.0)\n","Installing collected packages: urllib3, xxhash, responses, huggingface-hub, datasets\n"," Attempting uninstall: urllib3\n"," Found existing installation: urllib3 1.24.3\n"," Uninstalling urllib3-1.24.3:\n"," Successfully uninstalled urllib3-1.24.3\n","\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n","datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.\u001b[0m\n","Successfully installed datasets-2.1.0 huggingface-hub-0.5.1 responses-0.18.0 urllib3-1.25.11 xxhash-3.0.0\n"]}]},{"cell_type":"code","source":["%matplotlib inline\n","from matplotlib import pyplot as plt\n","import matplotlib.dates as mdates\n","\n","from itertools import islice"],"metadata":{"id":"1XLYCBAswBhQ","executionInfo":{"status":"ok","timestamp":1650399247178,"user_tz":240,"elapsed":4,"user":{"displayName":"Hena Ghonia","userId":"03246241722682988409"}}},"id":"1XLYCBAswBhQ","execution_count":7,"outputs":[]},{"cell_type":"code","source":["from gluonts.evaluation import make_evaluation_predictions, Evaluator\n","from gluonts.dataset.repository.datasets import get_dataset\n","\n","from estimator import PyraformerEstimator"],"metadata":{"id":"n0nOWRF-wFl2","executionInfo":{"status":"ok","timestamp":1650399257364,"user_tz":240,"elapsed":9132,"user":{"displayName":"Hena Ghonia","userId":"03246241722682988409"}},"colab":{"base_uri":"https://localhost:8080/"},"outputId":"61e8bd0f-f120-4518-d12f-d54e173a57c9"},"id":"n0nOWRF-wFl2","execution_count":8,"outputs":[{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.7/dist-packages/gluonts/json.py:102: UserWarning: Using `json`-module for json-handling. Consider installing one of `orjson`, `ujson` to speed up serialization and deserialization.\n"," \"Using `json`-module for json-handling. \"\n"]}]},{"cell_type":"code","source":["dataset = get_dataset(\"electricity\")"],"metadata":{"id":"Qzi9eE6q7x5y","executionInfo":{"status":"ok","timestamp":1650399428973,"user_tz":240,"elapsed":145255,"user":{"displayName":"Hena Ghonia","userId":"03246241722682988409"}},"colab":{"base_uri":"https://localhost:8080/"},"outputId":"1a4f1d8e-ef3d-466a-9dd7-e11edf30619e"},"id":"Qzi9eE6q7x5y","execution_count":10,"outputs":[{"output_type":"stream","name":"stdout","text":["saving time-series into /root/.mxnet/gluon-ts/datasets/electricity/train/data.json\n"]},{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.7/dist-packages/gluonts/dataset/util.py:32: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," return ts + ts.freq * amount\n"]},{"output_type":"stream","name":"stdout","text":["saving time-series into /root/.mxnet/gluon-ts/datasets/electricity/test/data.json\n"]}]},{"cell_type":"code","source":["estimator = PyraformerEstimator(\n"," freq=dataset.metadata.freq,\n"," prediction_length=dataset.metadata.prediction_length,\n"," num_feat_static_cat=1,\n"," cardinality=[321],\n"," batch_size=1,\n"," num_batches_per_epoch=100,\n"," trainer_kwargs=dict(max_epochs=1, accelerator='gpu', gpus=1),)"],"metadata":{"id":"i7AV93A07sQa","executionInfo":{"status":"ok","timestamp":1650399441551,"user_tz":240,"elapsed":142,"user":{"displayName":"Hena Ghonia","userId":"03246241722682988409"}}},"id":"i7AV93A07sQa","execution_count":11,"outputs":[]},{"cell_type":"code","source":["predictor = estimator.train(training_data=dataset.train,num_workers=8)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":1000},"id":"hXjDU6rhK9H_","executionInfo":{"status":"error","timestamp":1650399455047,"user_tz":240,"elapsed":10667,"user":{"displayName":"Hena Ghonia","userId":"03246241722682988409"}},"outputId":"03fcbe32-f967-4260-c2e2-a6f1741fe328"},"id":"hXjDU6rhK9H_","execution_count":12,"outputs":[{"output_type":"stream","name":"stdout","text":["axis=-1 min_past=0 min_future=24 num_instances=1.0 total_length=0 n=0\n"]},{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.7/dist-packages/pytorch_lightning/utilities/parsing.py:262: UserWarning: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.\n"," f\"Attribute {k!r} is an instance of `nn.Module` and is already saved during checkpointing.\"\n","/usr/local/lib/python3.7/dist-packages/pytorch_lightning/utilities/parsing.py:262: UserWarning: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.\n"," f\"Attribute {k!r} is an instance of `nn.Module` and is already saved during checkpointing.\"\n","/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py:481: UserWarning: This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n"," cpuset_checked))\n","GPU available: True, used: True\n","TPU available: False, using: 0 TPU cores\n","IPU available: False, using: 0 IPUs\n","HPU available: False, using: 0 HPUs\n","/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/configuration_validator.py:133: UserWarning: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n"," rank_zero_warn(\"You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\")\n","Missing logger folder: /content/drive/MyDrive/Udem/Sem2/Representation_Learning/IFT6135_Programming/Pyraformer/transformer/lightning_logs\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:324: FutureWarning: The 'freq' argument in Timestamp is deprecated and will be removed in a future version.\n"," timestamp = pd.Timestamp(timestamp_input, freq=freq)\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:324: FutureWarning: The 'freq' argument in Timestamp is deprecated and will be removed in a future version.\n"," timestamp = pd.Timestamp(timestamp_input, freq=freq)\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:327: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," if isinstance(timestamp.freq, Tick):\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:327: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," if isinstance(timestamp.freq, Tick):\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:329: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," timestamp.floor(timestamp.freq), timestamp.freq\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:329: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," timestamp.floor(timestamp.freq), timestamp.freq\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:329: FutureWarning: The 'freq' argument in Timestamp is deprecated and will be removed in a future version.\n"," timestamp.floor(timestamp.freq), timestamp.freq\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/feature.py:343: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," self._freq_base = start.freq.base\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:329: FutureWarning: The 'freq' argument in Timestamp is deprecated and will be removed in a future version.\n"," timestamp.floor(timestamp.freq), timestamp.freq\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/split.py:36: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," return _shift_timestamp_helper(ts, ts.freq, offset)\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/feature.py:343: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," self._freq_base = start.freq.base\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:324: FutureWarning: The 'freq' argument in Timestamp is deprecated and will be removed in a future version.\n"," timestamp = pd.Timestamp(timestamp_input, freq=freq)\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/split.py:36: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," return _shift_timestamp_helper(ts, ts.freq, offset)\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:324: FutureWarning: The 'freq' argument in Timestamp is deprecated and will be removed in a future version.\n"," timestamp = pd.Timestamp(timestamp_input, freq=freq)\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:327: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," if isinstance(timestamp.freq, Tick):\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:327: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," if isinstance(timestamp.freq, Tick):\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:324: FutureWarning: The 'freq' argument in Timestamp is deprecated and will be removed in a future version.\n"," timestamp = pd.Timestamp(timestamp_input, freq=freq)\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:329: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," timestamp.floor(timestamp.freq), timestamp.freq\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:329: FutureWarning: The 'freq' argument in Timestamp is deprecated and will be removed in a future version.\n"," timestamp.floor(timestamp.freq), timestamp.freq\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:329: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," timestamp.floor(timestamp.freq), timestamp.freq\n","LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/feature.py:386: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," if self._full_range_date_features is not None\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:327: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," if isinstance(timestamp.freq, Tick):\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/feature.py:343: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," self._freq_base = start.freq.base\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:329: FutureWarning: The 'freq' argument in Timestamp is deprecated and will be removed in a future version.\n"," timestamp.floor(timestamp.freq), timestamp.freq\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/split.py:36: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," return _shift_timestamp_helper(ts, ts.freq, offset)\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/feature.py:343: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," self._freq_base = start.freq.base\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:329: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," timestamp.floor(timestamp.freq), timestamp.freq\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/split.py:36: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," return _shift_timestamp_helper(ts, ts.freq, offset)\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:329: FutureWarning: The 'freq' argument in Timestamp is deprecated and will be removed in a future version.\n"," timestamp.floor(timestamp.freq), timestamp.freq\n","\n"," | Name | Type | Params\n","--------------------------------------------\n","0 | model | PyraformerSSModel | 6.6 M \n","1 | loss | SingleStepLoss | 0 \n","--------------------------------------------\n","6.6 M Trainable params\n","0 Non-trainable params\n","6.6 M Total params\n","26.580 Total estimated model params size (MB)\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:324: FutureWarning: The 'freq' argument in Timestamp is deprecated and will be removed in a future version.\n"," timestamp = pd.Timestamp(timestamp_input, freq=freq)\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/feature.py:386: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," if self._full_range_date_features is not None\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/feature.py:343: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," self._freq_base = start.freq.base\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/split.py:36: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," return _shift_timestamp_helper(ts, ts.freq, offset)\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:327: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," if isinstance(timestamp.freq, Tick):\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:324: FutureWarning: The 'freq' argument in Timestamp is deprecated and will be removed in a future version.\n"," timestamp = pd.Timestamp(timestamp_input, freq=freq)\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/feature.py:340: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," self._freq_base is None or self._freq_base == start.freq.base\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:329: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," timestamp.floor(timestamp.freq), timestamp.freq\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:327: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," if isinstance(timestamp.freq, Tick):\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:329: FutureWarning: The 'freq' argument in Timestamp is deprecated and will be removed in a future version.\n"," timestamp.floor(timestamp.freq), timestamp.freq\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/feature.py:340: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," self._freq_base is None or self._freq_base == start.freq.base\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:324: FutureWarning: The 'freq' argument in Timestamp is deprecated and will be removed in a future version.\n"," timestamp = pd.Timestamp(timestamp_input, freq=freq)\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:329: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," timestamp.floor(timestamp.freq), timestamp.freq\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/feature.py:343: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," self._freq_base = start.freq.base\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/split.py:36: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," return _shift_timestamp_helper(ts, ts.freq, offset)\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:329: FutureWarning: The 'freq' argument in Timestamp is deprecated and will be removed in a future version.\n"," timestamp.floor(timestamp.freq), timestamp.freq\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:327: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," if isinstance(timestamp.freq, Tick):\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/feature.py:386: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," if self._full_range_date_features is not None\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/feature.py:386: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," if self._full_range_date_features is not None\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/feature.py:343: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," self._freq_base = start.freq.base\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:329: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," timestamp.floor(timestamp.freq), timestamp.freq\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/split.py:36: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," return _shift_timestamp_helper(ts, ts.freq, offset)\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:329: FutureWarning: The 'freq' argument in Timestamp is deprecated and will be removed in a future version.\n"," timestamp.floor(timestamp.freq), timestamp.freq\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/feature.py:343: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," self._freq_base = start.freq.base\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/feature.py:340: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," self._freq_base is None or self._freq_base == start.freq.base\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/feature.py:386: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," if self._full_range_date_features is not None\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/split.py:36: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," return _shift_timestamp_helper(ts, ts.freq, offset)\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/feature.py:386: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," if self._full_range_date_features is not None\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/feature.py:386: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," if self._full_range_date_features is not None\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/feature.py:340: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," self._freq_base is None or self._freq_base == start.freq.base\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/feature.py:386: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," if self._full_range_date_features is not None\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/feature.py:340: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," self._freq_base is None or self._freq_base == start.freq.base\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/feature.py:340: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," self._freq_base is None or self._freq_base == start.freq.base\n"]},{"output_type":"error","ename":"RuntimeError","evalue":"ignored","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)","\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mpredictor\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mestimator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtraining_data\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mnum_workers\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m8\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/gluonts/torch/model/estimator.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, training_data, validation_data, num_workers, shuffle_buffer_length, cache_data, ckpt_path, **kwargs)\u001b[0m\n\u001b[1;32m 195\u001b[0m \u001b[0mcache_data\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcache_data\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 196\u001b[0m \u001b[0mckpt_path\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mckpt_path\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 197\u001b[0;31m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 198\u001b[0m ).predictor\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/gluonts/torch/model/estimator.py\u001b[0m in \u001b[0;36mtrain_model\u001b[0;34m(self, training_data, validation_data, num_workers, shuffle_buffer_length, cache_data, ckpt_path, **kwargs)\u001b[0m\n\u001b[1;32m 159\u001b[0m \u001b[0mtrain_dataloaders\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtraining_data_loader\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[0mval_dataloaders\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mvalidation_data_loader\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 161\u001b[0;31m \u001b[0mckpt_path\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mckpt_path\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 162\u001b[0m )\n\u001b[1;32m 163\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 767\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstrategy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 768\u001b[0m self._call_and_handle_interrupt(\n\u001b[0;32m--> 769\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fit_impl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_dataloaders\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_dataloaders\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdatamodule\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mckpt_path\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 770\u001b[0m )\n\u001b[1;32m 771\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(self, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 719\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstrategy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlauncher\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlaunch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainer_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 720\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 721\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtrainer_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 722\u001b[0m \u001b[0;31m# TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 723\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mexception\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 807\u001b[0m \u001b[0mckpt_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_provided\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_connected\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlightning_module\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 808\u001b[0m )\n\u001b[0;32m--> 809\u001b[0;31m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mckpt_path\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mckpt_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 810\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 811\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstopped\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 1220\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_lightning_module_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"on_fit_start\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1221\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1222\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_log_hyperparams\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1223\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1224\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstrategy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrestore_checkpoint_after_setup\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_log_hyperparams\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1288\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mlogger\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloggers\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1289\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhparams_initial\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1290\u001b[0;31m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog_hyperparams\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhparams_initial\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1291\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog_graph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlightning_module\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1292\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pytorch_lightning/utilities/rank_zero.py\u001b[0m in \u001b[0;36mwrapped_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mwrapped_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mAny\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrank_zero_only\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrank\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 33\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loggers/tensorboard.py\u001b[0m in \u001b[0;36mlog_hyperparams\u001b[0;34m(self, params, metrics)\u001b[0m\n\u001b[1;32m 200\u001b[0m \u001b[0;31m# format params into the suitable for tensorboard\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 201\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_flatten_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 202\u001b[0;31m \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_sanitize_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 203\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 204\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmetrics\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loggers/tensorboard.py\u001b[0m in \u001b[0;36m_sanitize_params\u001b[0;34m(params)\u001b[0m\n\u001b[1;32m 315\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mstaticmethod\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 316\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_sanitize_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mDict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mDict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 317\u001b[0;31m \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_utils_sanitize_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 318\u001b[0m \u001b[0;31m# logging of arrays with dimension > 1 is not supported, sanitize as string\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 319\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mv\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pytorch_lightning/utilities/logger.py\u001b[0m in \u001b[0;36m_sanitize_params\u001b[0;34m(params)\u001b[0m\n\u001b[1;32m 128\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 129\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mbool\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfloat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 130\u001b[0;31m \u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 131\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 132\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/gluonts/core/component.py\u001b[0m in \u001b[0;36mvalidated_repr\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 307\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 308\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mvalidated_repr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 309\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mdump_code\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 310\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 311\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mvalidated_getnewargs_ex\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/gluonts/core/serde/_repr.py\u001b[0m in \u001b[0;36mdump_code\u001b[0;34m(o)\u001b[0m\n\u001b[1;32m 113\u001b[0m \"\"\"\n\u001b[1;32m 114\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 115\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mas_repr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mencode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 116\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 117\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/lib/python3.7/functools.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kw)\u001b[0m\n\u001b[1;32m 838\u001b[0m '1 positional argument')\n\u001b[1;32m 839\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 840\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mdispatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__class__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkw\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 841\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 842\u001b[0m \u001b[0mfuncname\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'__name__'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'singledispatch function'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/gluonts/core/serde/_base.py\u001b[0m in \u001b[0;36mencode\u001b[0;34m(v)\u001b[0m\n\u001b[1;32m 218\u001b[0m \u001b[0;31m# args need to be a list, since we encode tuples explicitly\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 219\u001b[0m \u001b[0;34m\"args\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mencode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 220\u001b[0;31m \u001b[0;34m\"kwargs\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mencode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 221\u001b[0m }\n\u001b[1;32m 222\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/lib/python3.7/functools.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kw)\u001b[0m\n\u001b[1;32m 838\u001b[0m '1 positional argument')\n\u001b[1;32m 839\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 840\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mdispatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__class__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkw\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 841\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 842\u001b[0m \u001b[0mfuncname\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'__name__'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'singledispatch function'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/gluonts/core/serde/_base.py\u001b[0m in \u001b[0;36mencode\u001b[0;34m(v)\u001b[0m\n\u001b[1;32m 205\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 206\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 207\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mvalmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mencode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 208\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 209\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/toolz/dicttoolz.py\u001b[0m in \u001b[0;36mvalmap\u001b[0;34m(func, d, factory)\u001b[0m\n\u001b[1;32m 81\u001b[0m \"\"\"\n\u001b[1;32m 82\u001b[0m \u001b[0mrv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfactory\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 83\u001b[0;31m \u001b[0mrv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0md\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 84\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mrv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/lib/python3.7/functools.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kw)\u001b[0m\n\u001b[1;32m 838\u001b[0m '1 positional argument')\n\u001b[1;32m 839\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 840\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mdispatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__class__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkw\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 841\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 842\u001b[0m \u001b[0mfuncname\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'__name__'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'singledispatch function'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/gluonts/core/serde/_base.py\u001b[0m in \u001b[0;36mencode\u001b[0;34m(v)\u001b[0m\n\u001b[1;32m 241\u001b[0m \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 242\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 243\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbad_type_msg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfqname_for\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__class__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 244\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 245\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;31mRuntimeError\u001b[0m: Cannot serialize type torch.device. See the documentation of the `encode` and\n`validate` functions at\n\n http://gluon-ts.mxnet.io/api/gluonts/gluonts.html\n\nand the Python documentation of the `__getnewargs_ex__` magic method at\n\n https://docs.python.org/3/library/pickle.html#object.__getnewargs_ex__\n\nfor more information how to make this type serializable.\n"]}]},{"cell_type":"code","execution_count":null,"id":"d61f32ab","metadata":{"id":"d61f32ab"},"outputs":[],"source":["# plt.figure(figsize=(20, 15))\n","# date_formater = mdates.DateFormatter('%b, %d')\n","# plt.rcParams.update({'font.size': 15})\n","\n","# for idx, (forecast, ts) in islice(enumerate(zip(forecasts, tss)), 9):\n","# ax = plt.subplot(3, 3, idx+1)\n","\n","# plt.plot(ts[-4 * dataset.metadata.prediction_length:], label=\"target\", )\n","# forecast.plot( color='g')\n","# plt.xticks(rotation=60)\n","# ax.xaxis.set_major_formatter(date_formater)\n","\n","# plt.gcf().tight_layout()\n","# plt.legend()\n","# plt.show()"]},{"cell_type":"code","execution_count":null,"id":"d494463f","metadata":{"id":"d494463f"},"outputs":[],"source":["# def plot_prob_forecasts(ts_entry, forecast_entry):\n","# plot_length = 70\n","# prediction_intervals = (50.0, 90.0)\n","# legend = [\"observations\", \"median prediction\"] + [f\"{k}% prediction interval\" for k in prediction_intervals][::-1]\n","\n","# fig, ax = plt.subplots(1, 1, figsize=(10, 7))\n","# ts_entry[-plot_length:].plot(ax=ax) # plot the time series\n","# forecast_entry.plot(prediction_intervals=prediction_intervals, color='g')\n","# plt.grid(which=\"both\")\n","# plt.legend(legend, loc=\"best\")\n","# plt.show()"]},{"cell_type":"code","execution_count":null,"id":"5256fde1","metadata":{"id":"5256fde1"},"outputs":[],"source":["# index = 123\n","# plot_prob_forecasts(tss[index], forecasts[index])"]},{"cell_type":"code","execution_count":null,"id":"66a41556","metadata":{"id":"66a41556"},"outputs":[],"source":[""]}],"metadata":{"kernelspec":{"display_name":"Python 3 (ipykernel)","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.9.7"},"colab":{"name":"pyraformer.ipynb","provenance":[],"collapsed_sections":[]},"accelerator":"GPU"},"nbformat":4,"nbformat_minor":5} \ No newline at end of file +{"cells":[{"cell_type":"code","execution_count":1,"id":"b19f0e22","metadata":{"id":"b19f0e22","executionInfo":{"status":"ok","timestamp":1650430416284,"user_tz":240,"elapsed":148,"user":{"displayName":"Hena Ghonia","userId":"03246241722682988409"}}},"outputs":[],"source":["%matplotlib inline"]},{"cell_type":"code","source":["# from google.colab import drive\n","# drive.mount('/content/drive/')\n","# %cd /content/drive/MyDrive/Udem/Sem2/Representation_Learning/IFT6135_Programming/Pyraformer/transformer"],"metadata":{"id":"M6fjjCc2w6rX","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1650430417880,"user_tz":240,"elapsed":939,"user":{"displayName":"Hena Ghonia","userId":"03246241722682988409"}},"outputId":"041886cb-9371-4348-b4a2-752c7005547f"},"id":"M6fjjCc2w6rX","execution_count":2,"outputs":[{"output_type":"stream","name":"stdout","text":["Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount(\"/content/drive/\", force_remount=True).\n","/content/drive/MyDrive/Udem/Sem2/Representation_Learning/IFT6135_Programming/Pyraformer/transformer\n"]}]},{"cell_type":"code","source":["# !pip install black"],"metadata":{"id":"ew8UZJYmnUbU","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1650430422932,"user_tz":240,"elapsed":4326,"user":{"displayName":"Hena Ghonia","userId":"03246241722682988409"}},"outputId":"09dce130-ac4f-4189-cc22-51e89131ff7a"},"id":"ew8UZJYmnUbU","execution_count":3,"outputs":[{"output_type":"stream","name":"stdout","text":["Requirement already satisfied: black in /usr/local/lib/python3.7/dist-packages (22.3.0)\n","Requirement already satisfied: pathspec>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from black) (0.9.0)\n","Requirement already satisfied: typing-extensions>=3.10.0.0 in /usr/local/lib/python3.7/dist-packages (from black) (4.1.1)\n","Requirement already satisfied: mypy-extensions>=0.4.3 in /usr/local/lib/python3.7/dist-packages (from black) (0.4.3)\n","Requirement already satisfied: click>=8.0.0 in /usr/local/lib/python3.7/dist-packages (from black) (8.1.2)\n","Requirement already satisfied: typed-ast>=1.4.2 in /usr/local/lib/python3.7/dist-packages (from black) (1.5.3)\n","Requirement already satisfied: tomli>=1.1.0 in /usr/local/lib/python3.7/dist-packages (from black) (2.0.1)\n","Requirement already satisfied: platformdirs>=2 in /usr/local/lib/python3.7/dist-packages (from black) (2.5.2)\n","Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from click>=8.0.0->black) (4.11.3)\n","Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->click>=8.0.0->black) (3.8.0)\n"]}]},{"cell_type":"code","source":["# !black estimator.py"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Iu_BN-wKnVDw","executionInfo":{"status":"ok","timestamp":1650430245665,"user_tz":240,"elapsed":396,"user":{"displayName":"Hena Ghonia","userId":"03246241722682988409"}},"outputId":"c6906f00-aaf6-4f52-ba22-3f50305fb90f"},"id":"Iu_BN-wKnVDw","execution_count":4,"outputs":[{"output_type":"stream","name":"stdout","text":["\u001b[1mAll done! ✨ 🍰 ✨\u001b[0m\n","\u001b[34m1 file \u001b[0mleft unchanged.\n"]}]},{"cell_type":"code","source":["!pip install pytorch-lightning\n","!pip install gluonts\n","!pip install datasets"],"metadata":{"id":"_a4EOr95gtxR","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1650430437557,"user_tz":240,"elapsed":9242,"user":{"displayName":"Hena Ghonia","userId":"03246241722682988409"}},"outputId":"bb9d16df-3603-41d0-ef65-42bf02a5c018"},"id":"_a4EOr95gtxR","execution_count":4,"outputs":[{"output_type":"stream","name":"stdout","text":["Requirement already satisfied: pytorch-lightning in /usr/local/lib/python3.7/dist-packages (1.6.1)\n","Requirement already satisfied: tensorboard>=2.2.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (2.8.0)\n","Requirement already satisfied: torchmetrics>=0.4.1 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (0.8.0)\n","Requirement already satisfied: packaging>=17.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (21.3)\n","Requirement already satisfied: pyDeprecate<0.4.0,>=0.3.1 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (0.3.2)\n","Requirement already satisfied: fsspec[http]!=2021.06.0,>=2021.05.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (2022.3.0)\n","Requirement already satisfied: torch>=1.8.* in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (1.10.0+cu111)\n","Requirement already satisfied: typing-extensions>=4.0.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (4.1.1)\n","Requirement already satisfied: PyYAML>=5.4 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (6.0)\n","Requirement already satisfied: tqdm>=4.41.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (4.64.0)\n","Requirement already satisfied: numpy>=1.17.2 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (1.21.6)\n","Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (2.23.0)\n","Requirement already satisfied: aiohttp in /usr/local/lib/python3.7/dist-packages (from fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (3.8.1)\n","Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=17.0->pytorch-lightning) (3.0.8)\n","Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (3.3.6)\n","Requirement already satisfied: grpcio>=1.24.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.44.0)\n","Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (57.4.0)\n","Requirement already satisfied: protobuf>=3.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (3.17.3)\n","Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.8.1)\n","Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (0.37.1)\n","Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.0.0)\n","Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (0.6.1)\n","Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.35.0)\n","Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.0.1)\n","Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (0.4.6)\n","Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py>=0.4->tensorboard>=2.2.0->pytorch-lightning) (1.15.0)\n","Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning) (4.2.4)\n","Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning) (0.2.8)\n","Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning) (4.8)\n","Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.7/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.2.0->pytorch-lightning) (1.3.1)\n","Requirement already satisfied: importlib-metadata>=4.4 in /usr/local/lib/python3.7/dist-packages (from markdown>=2.6.8->tensorboard>=2.2.0->pytorch-lightning) (4.11.3)\n","Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard>=2.2.0->pytorch-lightning) (3.8.0)\n","Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.7/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning) (0.4.8)\n","Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (2.10)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (2021.10.8)\n","Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (1.25.11)\n","Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (3.0.4)\n","Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.2.0->pytorch-lightning) (3.2.0)\n","Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (21.4.0)\n","Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (1.3.0)\n","Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (4.0.2)\n","Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (1.7.2)\n","Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (6.0.2)\n","Requirement already satisfied: asynctest==0.13.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (0.13.0)\n","Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (1.2.0)\n","Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (2.0.12)\n","Requirement already satisfied: gluonts in /usr/local/lib/python3.7/dist-packages (0.9.3)\n","Requirement already satisfied: holidays>=0.9 in /usr/local/lib/python3.7/dist-packages (from gluonts) (0.10.5.2)\n","Requirement already satisfied: toolz~=0.10 in /usr/local/lib/python3.7/dist-packages (from gluonts) (0.11.2)\n","Requirement already satisfied: numpy~=1.16 in /usr/local/lib/python3.7/dist-packages (from gluonts) (1.21.6)\n","Requirement already satisfied: typing-extensions~=4.0 in /usr/local/lib/python3.7/dist-packages (from gluonts) (4.1.1)\n","Requirement already satisfied: pandas~=1.0 in /usr/local/lib/python3.7/dist-packages (from gluonts) (1.3.5)\n","Requirement already satisfied: pydantic~=1.1 in /usr/local/lib/python3.7/dist-packages (from gluonts) (1.9.0)\n","Requirement already satisfied: tqdm~=4.23 in /usr/local/lib/python3.7/dist-packages (from gluonts) (4.64.0)\n","Requirement already satisfied: matplotlib~=3.0 in /usr/local/lib/python3.7/dist-packages (from gluonts) (3.2.2)\n","Requirement already satisfied: python-dateutil in /usr/local/lib/python3.7/dist-packages (from holidays>=0.9->gluonts) (2.8.2)\n","Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from holidays>=0.9->gluonts) (1.15.0)\n","Requirement already satisfied: convertdate>=2.3.0 in /usr/local/lib/python3.7/dist-packages (from holidays>=0.9->gluonts) (2.4.0)\n","Requirement already satisfied: hijri-converter in /usr/local/lib/python3.7/dist-packages (from holidays>=0.9->gluonts) (2.2.3)\n","Requirement already satisfied: korean-lunar-calendar in /usr/local/lib/python3.7/dist-packages (from holidays>=0.9->gluonts) (0.2.1)\n","Requirement already satisfied: pymeeus<=1,>=0.3.13 in /usr/local/lib/python3.7/dist-packages (from convertdate>=2.3.0->holidays>=0.9->gluonts) (0.5.11)\n","Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib~=3.0->gluonts) (3.0.8)\n","Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib~=3.0->gluonts) (0.11.0)\n","Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib~=3.0->gluonts) (1.4.2)\n","Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas~=1.0->gluonts) (2022.1)\n","Requirement already satisfied: datasets in /usr/local/lib/python3.7/dist-packages (2.1.0)\n","Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.7/dist-packages (from datasets) (4.64.0)\n","Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from datasets) (4.11.3)\n","Requirement already satisfied: pyarrow>=5.0.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (6.0.1)\n","Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (2.23.0)\n","Requirement already satisfied: xxhash in /usr/local/lib/python3.7/dist-packages (from datasets) (3.0.0)\n","Requirement already satisfied: multiprocess in /usr/local/lib/python3.7/dist-packages (from datasets) (0.70.12.2)\n","Requirement already satisfied: fsspec[http]>=2021.05.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (2022.3.0)\n","Requirement already satisfied: dill in /usr/local/lib/python3.7/dist-packages (from datasets) (0.3.4)\n","Requirement already satisfied: aiohttp in /usr/local/lib/python3.7/dist-packages (from datasets) (3.8.1)\n","Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from datasets) (21.3)\n","Requirement already satisfied: responses<0.19 in /usr/local/lib/python3.7/dist-packages (from datasets) (0.18.0)\n","Requirement already satisfied: huggingface-hub<1.0.0,>=0.1.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (0.5.1)\n","Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from datasets) (1.21.6)\n","Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from datasets) (1.3.5)\n","Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets) (4.1.1)\n","Requirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets) (6.0)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets) (3.6.0)\n","Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->datasets) (3.0.8)\n","Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (3.0.4)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (2021.10.8)\n","Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (2.10)\n","Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (1.25.11)\n","Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (2.0.12)\n","Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (4.0.2)\n","Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (1.3.0)\n","Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (1.7.2)\n","Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (1.2.0)\n","Requirement already satisfied: asynctest==0.13.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (0.13.0)\n","Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (6.0.2)\n","Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (21.4.0)\n","Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->datasets) (3.8.0)\n","Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas->datasets) (2022.1)\n","Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->datasets) (2.8.2)\n","Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas->datasets) (1.15.0)\n"]}]},{"cell_type":"code","source":["%matplotlib inline\n","from matplotlib import pyplot as plt\n","import matplotlib.dates as mdates\n","\n","from itertools import islice"],"metadata":{"id":"1XLYCBAswBhQ","executionInfo":{"status":"ok","timestamp":1650430437558,"user_tz":240,"elapsed":7,"user":{"displayName":"Hena Ghonia","userId":"03246241722682988409"}}},"id":"1XLYCBAswBhQ","execution_count":5,"outputs":[]},{"cell_type":"code","source":["from gluonts.evaluation import make_evaluation_predictions, Evaluator\n","from gluonts.dataset.repository.datasets import get_dataset\n","\n","from estimator import PyraformerEstimator"],"metadata":{"id":"n0nOWRF-wFl2","executionInfo":{"status":"ok","timestamp":1650430444238,"user_tz":240,"elapsed":2207,"user":{"displayName":"Hena Ghonia","userId":"03246241722682988409"}},"colab":{"base_uri":"https://localhost:8080/"},"outputId":"22b11d98-cc25-4ba0-af98-320628468adc"},"id":"n0nOWRF-wFl2","execution_count":7,"outputs":[{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.7/dist-packages/gluonts/json.py:102: UserWarning: Using `json`-module for json-handling. Consider installing one of `orjson`, `ujson` to speed up serialization and deserialization.\n"," \"Using `json`-module for json-handling. \"\n"]}]},{"cell_type":"code","source":["dataset = get_dataset(\"electricity\")"],"metadata":{"id":"Qzi9eE6q7x5y","executionInfo":{"status":"ok","timestamp":1650430444238,"user_tz":240,"elapsed":3,"user":{"displayName":"Hena Ghonia","userId":"03246241722682988409"}}},"id":"Qzi9eE6q7x5y","execution_count":8,"outputs":[]},{"cell_type":"code","source":["estimator = PyraformerEstimator(\n"," freq=dataset.metadata.freq,\n"," prediction_length=dataset.metadata.prediction_length,\n"," num_feat_static_cat=1,\n"," cardinality=[321],\n"," single_step= True,\n"," batch_size=1,\n"," num_batches_per_epoch=100,\n"," trainer_kwargs=dict(max_epochs=1, accelerator='gpu', gpus=1\n","))"],"metadata":{"id":"i7AV93A07sQa","executionInfo":{"status":"ok","timestamp":1650430475324,"user_tz":240,"elapsed":139,"user":{"displayName":"Hena Ghonia","userId":"03246241722682988409"}}},"id":"i7AV93A07sQa","execution_count":11,"outputs":[]},{"cell_type":"code","source":["predictor = estimator.train(training_data=dataset.train,num_workers=2)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":1000},"id":"hXjDU6rhK9H_","executionInfo":{"status":"error","timestamp":1650430476900,"user_tz":240,"elapsed":766,"user":{"displayName":"Hena Ghonia","userId":"03246241722682988409"}},"outputId":"086996d5-ccda-4fb4-8a64-861b8c732bab"},"id":"hXjDU6rhK9H_","execution_count":12,"outputs":[{"output_type":"stream","name":"stdout","text":["axis=-1 min_past=0 min_future=24 num_instances=1.0 total_length=0 n=0\n"]},{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.7/dist-packages/pytorch_lightning/utilities/parsing.py:262: UserWarning: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.\n"," f\"Attribute {k!r} is an instance of `nn.Module` and is already saved during checkpointing.\"\n","/usr/local/lib/python3.7/dist-packages/pytorch_lightning/utilities/parsing.py:262: UserWarning: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.\n"," f\"Attribute {k!r} is an instance of `nn.Module` and is already saved during checkpointing.\"\n","GPU available: True, used: True\n","TPU available: False, using: 0 TPU cores\n","IPU available: False, using: 0 IPUs\n","HPU available: False, using: 0 HPUs\n","/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/configuration_validator.py:133: UserWarning: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n"," rank_zero_warn(\"You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\")\n","Missing logger folder: /content/drive/MyDrive/Udem/Sem2/Representation_Learning/IFT6135_Programming/Pyraformer/transformer/lightning_logs\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:324: FutureWarning: The 'freq' argument in Timestamp is deprecated and will be removed in a future version.\n"," timestamp = pd.Timestamp(timestamp_input, freq=freq)\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:327: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," if isinstance(timestamp.freq, Tick):\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:329: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," timestamp.floor(timestamp.freq), timestamp.freq\n","LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:329: FutureWarning: The 'freq' argument in Timestamp is deprecated and will be removed in a future version.\n"," timestamp.floor(timestamp.freq), timestamp.freq\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/feature.py:343: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," self._freq_base = start.freq.base\n","\n"," | Name | Type | Params\n","--------------------------------------------\n","0 | model | PyraformerSSModel | 6.6 M \n","1 | loss | SingleStepLoss | 0 \n","--------------------------------------------\n","6.6 M Trainable params\n","0 Non-trainable params\n","6.6 M Total params\n","26.580 Total estimated model params size (MB)\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:324: FutureWarning: The 'freq' argument in Timestamp is deprecated and will be removed in a future version.\n"," timestamp = pd.Timestamp(timestamp_input, freq=freq)\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:327: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," if isinstance(timestamp.freq, Tick):\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/split.py:36: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," return _shift_timestamp_helper(ts, ts.freq, offset)\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:329: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," timestamp.floor(timestamp.freq), timestamp.freq\n","/usr/local/lib/python3.7/dist-packages/gluonts/dataset/common.py:329: FutureWarning: The 'freq' argument in Timestamp is deprecated and will be removed in a future version.\n"," timestamp.floor(timestamp.freq), timestamp.freq\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/feature.py:386: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," if self._full_range_date_features is not None\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/feature.py:343: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," self._freq_base = start.freq.base\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/split.py:36: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," return _shift_timestamp_helper(ts, ts.freq, offset)\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/feature.py:340: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," self._freq_base is None or self._freq_base == start.freq.base\n","/usr/local/lib/python3.7/dist-packages/gluonts/transform/feature.py:386: FutureWarning: Timestamp.freq is deprecated and will be removed in a future version\n"," if self._full_range_date_features is not None\n"]},{"output_type":"error","ename":"RuntimeError","evalue":"ignored","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)","\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mpredictor\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mestimator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtraining_data\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mnum_workers\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/gluonts/torch/model/estimator.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, training_data, validation_data, num_workers, shuffle_buffer_length, cache_data, ckpt_path, **kwargs)\u001b[0m\n\u001b[1;32m 195\u001b[0m \u001b[0mcache_data\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcache_data\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 196\u001b[0m \u001b[0mckpt_path\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mckpt_path\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 197\u001b[0;31m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 198\u001b[0m ).predictor\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/gluonts/torch/model/estimator.py\u001b[0m in \u001b[0;36mtrain_model\u001b[0;34m(self, training_data, validation_data, num_workers, shuffle_buffer_length, cache_data, ckpt_path, **kwargs)\u001b[0m\n\u001b[1;32m 159\u001b[0m \u001b[0mtrain_dataloaders\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtraining_data_loader\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[0mval_dataloaders\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mvalidation_data_loader\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 161\u001b[0;31m \u001b[0mckpt_path\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mckpt_path\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 162\u001b[0m )\n\u001b[1;32m 163\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 767\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstrategy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 768\u001b[0m self._call_and_handle_interrupt(\n\u001b[0;32m--> 769\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fit_impl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_dataloaders\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_dataloaders\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdatamodule\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mckpt_path\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 770\u001b[0m )\n\u001b[1;32m 771\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(self, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 719\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstrategy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlauncher\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlaunch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainer_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 720\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 721\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtrainer_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 722\u001b[0m \u001b[0;31m# TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 723\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mexception\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 807\u001b[0m \u001b[0mckpt_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_provided\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_connected\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlightning_module\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 808\u001b[0m )\n\u001b[0;32m--> 809\u001b[0;31m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mckpt_path\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mckpt_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 810\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 811\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstopped\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 1220\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_lightning_module_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"on_fit_start\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1221\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1222\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_log_hyperparams\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1223\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1224\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstrategy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrestore_checkpoint_after_setup\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_log_hyperparams\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1288\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mlogger\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloggers\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1289\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhparams_initial\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1290\u001b[0;31m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog_hyperparams\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhparams_initial\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1291\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog_graph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlightning_module\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1292\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pytorch_lightning/utilities/rank_zero.py\u001b[0m in \u001b[0;36mwrapped_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mwrapped_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mAny\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrank_zero_only\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrank\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 33\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loggers/tensorboard.py\u001b[0m in \u001b[0;36mlog_hyperparams\u001b[0;34m(self, params, metrics)\u001b[0m\n\u001b[1;32m 200\u001b[0m \u001b[0;31m# format params into the suitable for tensorboard\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 201\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_flatten_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 202\u001b[0;31m \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_sanitize_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 203\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 204\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmetrics\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loggers/tensorboard.py\u001b[0m in \u001b[0;36m_sanitize_params\u001b[0;34m(params)\u001b[0m\n\u001b[1;32m 315\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mstaticmethod\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 316\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_sanitize_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mDict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mDict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 317\u001b[0;31m \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_utils_sanitize_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 318\u001b[0m \u001b[0;31m# logging of arrays with dimension > 1 is not supported, sanitize as string\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 319\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mv\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pytorch_lightning/utilities/logger.py\u001b[0m in \u001b[0;36m_sanitize_params\u001b[0;34m(params)\u001b[0m\n\u001b[1;32m 128\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 129\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mbool\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfloat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 130\u001b[0;31m \u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 131\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 132\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/gluonts/core/component.py\u001b[0m in \u001b[0;36mvalidated_repr\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 307\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 308\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mvalidated_repr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 309\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mdump_code\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 310\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 311\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mvalidated_getnewargs_ex\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/gluonts/core/serde/_repr.py\u001b[0m in \u001b[0;36mdump_code\u001b[0;34m(o)\u001b[0m\n\u001b[1;32m 113\u001b[0m \"\"\"\n\u001b[1;32m 114\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 115\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mas_repr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mencode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 116\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 117\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/lib/python3.7/functools.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kw)\u001b[0m\n\u001b[1;32m 838\u001b[0m '1 positional argument')\n\u001b[1;32m 839\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 840\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mdispatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__class__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkw\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 841\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 842\u001b[0m \u001b[0mfuncname\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'__name__'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'singledispatch function'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/gluonts/core/serde/_base.py\u001b[0m in \u001b[0;36mencode\u001b[0;34m(v)\u001b[0m\n\u001b[1;32m 218\u001b[0m \u001b[0;31m# args need to be a list, since we encode tuples explicitly\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 219\u001b[0m \u001b[0;34m\"args\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mencode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 220\u001b[0;31m \u001b[0;34m\"kwargs\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mencode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 221\u001b[0m }\n\u001b[1;32m 222\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/lib/python3.7/functools.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kw)\u001b[0m\n\u001b[1;32m 838\u001b[0m '1 positional argument')\n\u001b[1;32m 839\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 840\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mdispatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__class__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkw\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 841\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 842\u001b[0m \u001b[0mfuncname\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'__name__'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'singledispatch function'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/gluonts/core/serde/_base.py\u001b[0m in \u001b[0;36mencode\u001b[0;34m(v)\u001b[0m\n\u001b[1;32m 205\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 206\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 207\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mvalmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mencode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 208\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 209\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/toolz/dicttoolz.py\u001b[0m in \u001b[0;36mvalmap\u001b[0;34m(func, d, factory)\u001b[0m\n\u001b[1;32m 81\u001b[0m \"\"\"\n\u001b[1;32m 82\u001b[0m \u001b[0mrv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfactory\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 83\u001b[0;31m \u001b[0mrv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0md\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 84\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mrv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/lib/python3.7/functools.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kw)\u001b[0m\n\u001b[1;32m 838\u001b[0m '1 positional argument')\n\u001b[1;32m 839\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 840\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mdispatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__class__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkw\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 841\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 842\u001b[0m \u001b[0mfuncname\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'__name__'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'singledispatch function'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/gluonts/core/serde/_base.py\u001b[0m in \u001b[0;36mencode\u001b[0;34m(v)\u001b[0m\n\u001b[1;32m 241\u001b[0m \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 242\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 243\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbad_type_msg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfqname_for\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__class__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 244\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 245\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;31mRuntimeError\u001b[0m: Cannot serialize type torch.device. See the documentation of the `encode` and\n`validate` functions at\n\n http://gluon-ts.mxnet.io/api/gluonts/gluonts.html\n\nand the Python documentation of the `__getnewargs_ex__` magic method at\n\n https://docs.python.org/3/library/pickle.html#object.__getnewargs_ex__\n\nfor more information how to make this type serializable.\n"]}]},{"cell_type":"code","execution_count":null,"id":"d61f32ab","metadata":{"id":"d61f32ab"},"outputs":[],"source":["# plt.figure(figsize=(20, 15))\n","# date_formater = mdates.DateFormatter('%b, %d')\n","# plt.rcParams.update({'font.size': 15})\n","\n","# for idx, (forecast, ts) in islice(enumerate(zip(forecasts, tss)), 9):\n","# ax = plt.subplot(3, 3, idx+1)\n","\n","# plt.plot(ts[-4 * dataset.metadata.prediction_length:], label=\"target\", )\n","# forecast.plot( color='g')\n","# plt.xticks(rotation=60)\n","# ax.xaxis.set_major_formatter(date_formater)\n","\n","# plt.gcf().tight_layout()\n","# plt.legend()\n","# plt.show()"]},{"cell_type":"code","execution_count":null,"id":"d494463f","metadata":{"id":"d494463f"},"outputs":[],"source":["# def plot_prob_forecasts(ts_entry, forecast_entry):\n","# plot_length = 70\n","# prediction_intervals = (50.0, 90.0)\n","# legend = [\"observations\", \"median prediction\"] + [f\"{k}% prediction interval\" for k in prediction_intervals][::-1]\n","\n","# fig, ax = plt.subplots(1, 1, figsize=(10, 7))\n","# ts_entry[-plot_length:].plot(ax=ax) # plot the time series\n","# forecast_entry.plot(prediction_intervals=prediction_intervals, color='g')\n","# plt.grid(which=\"both\")\n","# plt.legend(legend, loc=\"best\")\n","# plt.show()"]},{"cell_type":"code","execution_count":null,"id":"5256fde1","metadata":{"id":"5256fde1"},"outputs":[],"source":["# index = 123\n","# plot_prob_forecasts(tss[index], forecasts[index])"]},{"cell_type":"code","execution_count":null,"id":"66a41556","metadata":{"id":"66a41556"},"outputs":[],"source":[""]}],"metadata":{"kernelspec":{"display_name":"Python 3 (ipykernel)","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.9.7"},"colab":{"name":"pyraformer.ipynb","provenance":[],"collapsed_sections":[]},"accelerator":"GPU"},"nbformat":4,"nbformat_minor":5} \ No newline at end of file diff --git a/pyraformer/pyraformer/Layers.py b/pyraformer/pyraformer/Layers.py index 2cd96af..dda4314 100644 --- a/pyraformer/pyraformer/Layers.py +++ b/pyraformer/pyraformer/Layers.py @@ -378,10 +378,10 @@ class Predictor(nn.Module): class Decoder(nn.Module): """ A encoder model with self attention mechanism. """ - def __init__(self, model,d_model,d_inner_hid,num_head,d_k,d_v,dropout,enc_in,covariate_size,seq_num, mask): + def __init__(self,d_model,d_inner_hid,num_head,d_k,d_v,dropout,enc_in,covariate_size,seq_num, mask): super().__init__() - self.model_type = model + # self.model_type = model self.mask = mask self.layers = nn.ModuleList([