remove freq from PyTorchPredictor

This commit is contained in:
Kashif Rasul
2022-06-17 19:22:59 +02:00
parent 296f9c360a
commit ab0c038700
12 changed files with 2 additions and 14 deletions
-1
View File
@@ -278,7 +278,6 @@ class AutoformerEstimator(PyTorchLightningEstimator):
input_names=PREDICTION_INPUT_NAMES,
prediction_net=module.model,
batch_size=self.batch_size,
freq=self.freq,
prediction_length=self.prediction_length,
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
-1
View File
@@ -273,7 +273,6 @@ class ETSformerEstimator(PyTorchLightningEstimator):
input_names=PREDICTION_INPUT_NAMES,
prediction_net=module.model,
batch_size=self.batch_size,
freq=self.freq,
prediction_length=self.prediction_length,
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
-1
View File
@@ -274,7 +274,6 @@ class HopfieldEstimator(PyTorchLightningEstimator):
input_names=PREDICTION_INPUT_NAMES,
prediction_net=module.model,
batch_size=self.batch_size,
freq=self.freq,
prediction_length=self.prediction_length,
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
-1
View File
@@ -256,7 +256,6 @@ class HopfieldEstimator(PyTorchEstimator):
input_names=input_names,
prediction_net=prediction_network,
batch_size=self.trainer.batch_size,
freq=self.freq,
prediction_length=self.prediction_length,
device=device,
)
-1
View File
@@ -281,7 +281,6 @@ class InformerEstimator(PyTorchLightningEstimator):
input_names=PREDICTION_INPUT_NAMES,
prediction_net=module.model,
batch_size=self.batch_size,
freq=self.freq,
prediction_length=self.prediction_length,
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
-1
View File
@@ -325,7 +325,6 @@ class PyraformerEstimator(PyTorchLightningEstimator):
input_names=PREDICTION_INPUT_NAMES,
prediction_net=module.model,
batch_size=self.batch_size,
freq=self.freq,
prediction_length=self.prediction_length,
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
-1
View File
@@ -270,7 +270,6 @@ class ReformerEstimator(PyTorchLightningEstimator):
input_names=PREDICTION_INPUT_NAMES,
prediction_net=module.model,
batch_size=self.batch_size,
freq=self.freq,
prediction_length=self.prediction_length,
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
+1 -2
View File
File diff suppressed because one or more lines are too long
-1
View File
@@ -284,7 +284,6 @@ class SwitchTransformerEstimator(PyTorchLightningEstimator):
input_names=PREDICTION_INPUT_NAMES,
prediction_net=module.model,
batch_size=self.batch_size,
freq=self.freq,
prediction_length=self.prediction_length,
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
-1
View File
@@ -544,7 +544,6 @@
" input_names=PREDICTION_INPUT_NAMES,\n",
" prediction_net=module.model,\n",
" batch_size=self.batch_size,\n",
" freq=self.freq,\n",
" prediction_length=self.prediction_length,\n",
" )\n",
"\n",
-1
View File
@@ -274,7 +274,6 @@ class TransformerEstimator(PyTorchLightningEstimator):
input_names=PREDICTION_INPUT_NAMES,
prediction_net=module.model,
batch_size=self.batch_size,
freq=self.freq,
prediction_length=self.prediction_length,
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
-1
View File
@@ -821,7 +821,6 @@
" input_names=PREDICTION_INPUT_NAMES,\n",
" prediction_net=module.model,\n",
" batch_size=self.batch_size,\n",
" freq=self.freq,\n",
" prediction_length=self.prediction_length,\n",
" device=torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\"),\n",
" )\n",