diff --git a/pts/model/estimator.py b/pts/model/estimator.py index a10429c..85863b9 100644 --- a/pts/model/estimator.py +++ b/pts/model/estimator.py @@ -117,6 +117,7 @@ class PyTorchEstimator(Estimator): batch_size=self.trainer.batch_size, num_workers=num_workers, prefetch_factor=prefetch_factor, + pin_memory=True, **kwargs, ) @@ -136,6 +137,7 @@ class PyTorchEstimator(Estimator): batch_size=self.trainer.batch_size, num_workers=num_workers, prefetch_factor=prefetch_factor, + pin_memory=True, **kwargs, )