mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-27 20:22:32 +08:00
stick with one cycle for now
This commit is contained in:
+2
-15
@@ -23,10 +23,7 @@ class Trainer:
|
||||
num_batches_per_epoch: int = 50,
|
||||
learning_rate: float = 1e-3,
|
||||
weight_decay: float = 1e-6,
|
||||
learning_rate_decay_factor: float = 0.5,
|
||||
patience: int = 10,
|
||||
minimum_learning_rate: float = 5e-5,
|
||||
maximum_learning_rate: float = 0.01,
|
||||
maximum_learning_rate: float = 1e-2,
|
||||
clip_gradient: Optional[float] = None,
|
||||
device: Optional[Union[torch.device, str]] = None,
|
||||
**kwargs,
|
||||
@@ -36,9 +33,6 @@ class Trainer:
|
||||
self.num_batches_per_epoch = num_batches_per_epoch
|
||||
self.learning_rate = learning_rate
|
||||
self.weight_decay = weight_decay
|
||||
self.learning_rate_decay_factor = learning_rate_decay_factor
|
||||
self.patience = patience
|
||||
self.minimum_learning_rate = minimum_learning_rate
|
||||
self.maximum_learning_rate = maximum_learning_rate
|
||||
self.clip_gradient = clip_gradient
|
||||
self.device = device
|
||||
@@ -64,13 +58,6 @@ class Trainer:
|
||||
steps_per_epoch=self.num_batches_per_epoch,
|
||||
epochs=self.epochs,
|
||||
)
|
||||
# lr_scheduler = ReduceLROnPlateau(
|
||||
# optimizer,
|
||||
# mode='min',
|
||||
# factor=self.learning_rate_decay_factor,
|
||||
# patience=self.patience,
|
||||
# min_lr=self.minimum_learning_rate,
|
||||
# )
|
||||
|
||||
for epoch_no in range(self.epochs):
|
||||
# mark epoch start time
|
||||
@@ -101,11 +88,11 @@ class Trainer:
|
||||
loss.backward()
|
||||
if self.clip_gradient is not None:
|
||||
nn.utils.clip_grad_norm_(net.parameters(), self.clip_gradient)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
|
||||
if self.num_batches_per_epoch == batch_no:
|
||||
# lr_scheduler.step(avg_epoch_loss / batch_no)
|
||||
break
|
||||
|
||||
# mark epoch end time and log time cost of current epoch
|
||||
|
||||
Reference in New Issue
Block a user