mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 17:01:48 +08:00
remove freq from PyTorchPredictor
This commit is contained in:
@@ -278,7 +278,6 @@ class AutoformerEstimator(PyTorchLightningEstimator):
|
|||||||
input_names=PREDICTION_INPUT_NAMES,
|
input_names=PREDICTION_INPUT_NAMES,
|
||||||
prediction_net=module.model,
|
prediction_net=module.model,
|
||||||
batch_size=self.batch_size,
|
batch_size=self.batch_size,
|
||||||
freq=self.freq,
|
|
||||||
prediction_length=self.prediction_length,
|
prediction_length=self.prediction_length,
|
||||||
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -273,7 +273,6 @@ class ETSformerEstimator(PyTorchLightningEstimator):
|
|||||||
input_names=PREDICTION_INPUT_NAMES,
|
input_names=PREDICTION_INPUT_NAMES,
|
||||||
prediction_net=module.model,
|
prediction_net=module.model,
|
||||||
batch_size=self.batch_size,
|
batch_size=self.batch_size,
|
||||||
freq=self.freq,
|
|
||||||
prediction_length=self.prediction_length,
|
prediction_length=self.prediction_length,
|
||||||
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -274,7 +274,6 @@ class HopfieldEstimator(PyTorchLightningEstimator):
|
|||||||
input_names=PREDICTION_INPUT_NAMES,
|
input_names=PREDICTION_INPUT_NAMES,
|
||||||
prediction_net=module.model,
|
prediction_net=module.model,
|
||||||
batch_size=self.batch_size,
|
batch_size=self.batch_size,
|
||||||
freq=self.freq,
|
|
||||||
prediction_length=self.prediction_length,
|
prediction_length=self.prediction_length,
|
||||||
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -256,7 +256,6 @@ class HopfieldEstimator(PyTorchEstimator):
|
|||||||
input_names=input_names,
|
input_names=input_names,
|
||||||
prediction_net=prediction_network,
|
prediction_net=prediction_network,
|
||||||
batch_size=self.trainer.batch_size,
|
batch_size=self.trainer.batch_size,
|
||||||
freq=self.freq,
|
|
||||||
prediction_length=self.prediction_length,
|
prediction_length=self.prediction_length,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -281,7 +281,6 @@ class InformerEstimator(PyTorchLightningEstimator):
|
|||||||
input_names=PREDICTION_INPUT_NAMES,
|
input_names=PREDICTION_INPUT_NAMES,
|
||||||
prediction_net=module.model,
|
prediction_net=module.model,
|
||||||
batch_size=self.batch_size,
|
batch_size=self.batch_size,
|
||||||
freq=self.freq,
|
|
||||||
prediction_length=self.prediction_length,
|
prediction_length=self.prediction_length,
|
||||||
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -325,7 +325,6 @@ class PyraformerEstimator(PyTorchLightningEstimator):
|
|||||||
input_names=PREDICTION_INPUT_NAMES,
|
input_names=PREDICTION_INPUT_NAMES,
|
||||||
prediction_net=module.model,
|
prediction_net=module.model,
|
||||||
batch_size=self.batch_size,
|
batch_size=self.batch_size,
|
||||||
freq=self.freq,
|
|
||||||
prediction_length=self.prediction_length,
|
prediction_length=self.prediction_length,
|
||||||
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -270,7 +270,6 @@ class ReformerEstimator(PyTorchLightningEstimator):
|
|||||||
input_names=PREDICTION_INPUT_NAMES,
|
input_names=PREDICTION_INPUT_NAMES,
|
||||||
prediction_net=module.model,
|
prediction_net=module.model,
|
||||||
batch_size=self.batch_size,
|
batch_size=self.batch_size,
|
||||||
freq=self.freq,
|
|
||||||
prediction_length=self.prediction_length,
|
prediction_length=self.prediction_length,
|
||||||
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
||||||
)
|
)
|
||||||
|
|||||||
+2
-3
File diff suppressed because one or more lines are too long
@@ -284,7 +284,6 @@ class SwitchTransformerEstimator(PyTorchLightningEstimator):
|
|||||||
input_names=PREDICTION_INPUT_NAMES,
|
input_names=PREDICTION_INPUT_NAMES,
|
||||||
prediction_net=module.model,
|
prediction_net=module.model,
|
||||||
batch_size=self.batch_size,
|
batch_size=self.batch_size,
|
||||||
freq=self.freq,
|
|
||||||
prediction_length=self.prediction_length,
|
prediction_length=self.prediction_length,
|
||||||
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -544,7 +544,6 @@
|
|||||||
" input_names=PREDICTION_INPUT_NAMES,\n",
|
" input_names=PREDICTION_INPUT_NAMES,\n",
|
||||||
" prediction_net=module.model,\n",
|
" prediction_net=module.model,\n",
|
||||||
" batch_size=self.batch_size,\n",
|
" batch_size=self.batch_size,\n",
|
||||||
" freq=self.freq,\n",
|
|
||||||
" prediction_length=self.prediction_length,\n",
|
" prediction_length=self.prediction_length,\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|||||||
@@ -274,7 +274,6 @@ class TransformerEstimator(PyTorchLightningEstimator):
|
|||||||
input_names=PREDICTION_INPUT_NAMES,
|
input_names=PREDICTION_INPUT_NAMES,
|
||||||
prediction_net=module.model,
|
prediction_net=module.model,
|
||||||
batch_size=self.batch_size,
|
batch_size=self.batch_size,
|
||||||
freq=self.freq,
|
|
||||||
prediction_length=self.prediction_length,
|
prediction_length=self.prediction_length,
|
||||||
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -821,7 +821,6 @@
|
|||||||
" input_names=PREDICTION_INPUT_NAMES,\n",
|
" input_names=PREDICTION_INPUT_NAMES,\n",
|
||||||
" prediction_net=module.model,\n",
|
" prediction_net=module.model,\n",
|
||||||
" batch_size=self.batch_size,\n",
|
" batch_size=self.batch_size,\n",
|
||||||
" freq=self.freq,\n",
|
|
||||||
" prediction_length=self.prediction_length,\n",
|
" prediction_length=self.prediction_length,\n",
|
||||||
" device=torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\"),\n",
|
" device=torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\"),\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
|
|||||||
Reference in New Issue
Block a user