mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 15:16:27 +08:00
remove freq from PyTorchPredictor
This commit is contained in:
@@ -278,7 +278,6 @@ class AutoformerEstimator(PyTorchLightningEstimator):
|
||||
input_names=PREDICTION_INPUT_NAMES,
|
||||
prediction_net=module.model,
|
||||
batch_size=self.batch_size,
|
||||
freq=self.freq,
|
||||
prediction_length=self.prediction_length,
|
||||
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
||||
)
|
||||
|
||||
@@ -273,7 +273,6 @@ class ETSformerEstimator(PyTorchLightningEstimator):
|
||||
input_names=PREDICTION_INPUT_NAMES,
|
||||
prediction_net=module.model,
|
||||
batch_size=self.batch_size,
|
||||
freq=self.freq,
|
||||
prediction_length=self.prediction_length,
|
||||
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
||||
)
|
||||
|
||||
@@ -274,7 +274,6 @@ class HopfieldEstimator(PyTorchLightningEstimator):
|
||||
input_names=PREDICTION_INPUT_NAMES,
|
||||
prediction_net=module.model,
|
||||
batch_size=self.batch_size,
|
||||
freq=self.freq,
|
||||
prediction_length=self.prediction_length,
|
||||
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
||||
)
|
||||
|
||||
@@ -256,7 +256,6 @@ class HopfieldEstimator(PyTorchEstimator):
|
||||
input_names=input_names,
|
||||
prediction_net=prediction_network,
|
||||
batch_size=self.trainer.batch_size,
|
||||
freq=self.freq,
|
||||
prediction_length=self.prediction_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
@@ -281,7 +281,6 @@ class InformerEstimator(PyTorchLightningEstimator):
|
||||
input_names=PREDICTION_INPUT_NAMES,
|
||||
prediction_net=module.model,
|
||||
batch_size=self.batch_size,
|
||||
freq=self.freq,
|
||||
prediction_length=self.prediction_length,
|
||||
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
||||
)
|
||||
|
||||
@@ -325,7 +325,6 @@ class PyraformerEstimator(PyTorchLightningEstimator):
|
||||
input_names=PREDICTION_INPUT_NAMES,
|
||||
prediction_net=module.model,
|
||||
batch_size=self.batch_size,
|
||||
freq=self.freq,
|
||||
prediction_length=self.prediction_length,
|
||||
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
||||
)
|
||||
|
||||
@@ -270,7 +270,6 @@ class ReformerEstimator(PyTorchLightningEstimator):
|
||||
input_names=PREDICTION_INPUT_NAMES,
|
||||
prediction_net=module.model,
|
||||
batch_size=self.batch_size,
|
||||
freq=self.freq,
|
||||
prediction_length=self.prediction_length,
|
||||
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
||||
)
|
||||
|
||||
+2
-3
File diff suppressed because one or more lines are too long
@@ -284,7 +284,6 @@ class SwitchTransformerEstimator(PyTorchLightningEstimator):
|
||||
input_names=PREDICTION_INPUT_NAMES,
|
||||
prediction_net=module.model,
|
||||
batch_size=self.batch_size,
|
||||
freq=self.freq,
|
||||
prediction_length=self.prediction_length,
|
||||
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
||||
)
|
||||
|
||||
@@ -544,7 +544,6 @@
|
||||
" input_names=PREDICTION_INPUT_NAMES,\n",
|
||||
" prediction_net=module.model,\n",
|
||||
" batch_size=self.batch_size,\n",
|
||||
" freq=self.freq,\n",
|
||||
" prediction_length=self.prediction_length,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
|
||||
@@ -274,7 +274,6 @@ class TransformerEstimator(PyTorchLightningEstimator):
|
||||
input_names=PREDICTION_INPUT_NAMES,
|
||||
prediction_net=module.model,
|
||||
batch_size=self.batch_size,
|
||||
freq=self.freq,
|
||||
prediction_length=self.prediction_length,
|
||||
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
||||
)
|
||||
|
||||
@@ -821,7 +821,6 @@
|
||||
" input_names=PREDICTION_INPUT_NAMES,\n",
|
||||
" prediction_net=module.model,\n",
|
||||
" batch_size=self.batch_size,\n",
|
||||
" freq=self.freq,\n",
|
||||
" prediction_length=self.prediction_length,\n",
|
||||
" device=torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\"),\n",
|
||||
" )\n",
|
||||
|
||||
Reference in New Issue
Block a user