mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 18:06:14 +08:00
@@ -221,7 +221,7 @@
|
||||
" weight_decay: float = 1e-8,\n",
|
||||
" ) -> None:\n",
|
||||
" super().__init__()\n",
|
||||
" self.save_hyperparameters(ignore=['loss', 'model'])\n",
|
||||
" self.save_hyperparameters()\n",
|
||||
" self.model = model\n",
|
||||
" self.loss = loss\n",
|
||||
" self.lr = lr\n",
|
||||
|
||||
@@ -15,7 +15,7 @@ class TFTLightningModule(pl.LightningModule):
|
||||
weight_decay: float = 1e-8,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.save_hyperparameters(ignore=["loss", "model"])
|
||||
self.save_hyperparameters()
|
||||
self.model = model
|
||||
self.loss = loss
|
||||
self.lr = lr
|
||||
|
||||
@@ -15,7 +15,7 @@ class TransformerLightningModule(pl.LightningModule):
|
||||
weight_decay: float = 1e-8,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.save_hyperparameters(ignore=["loss", "model"])
|
||||
self.save_hyperparameters()
|
||||
self.model = model
|
||||
self.loss = loss
|
||||
self.lr = lr
|
||||
|
||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user