diff --git a/pts/model/estimator.py b/pts/model/estimator.py index 85863b9..5044893 100644 --- a/pts/model/estimator.py +++ b/pts/model/estimator.py @@ -118,6 +118,7 @@ class PyTorchEstimator(Estimator): num_workers=num_workers, prefetch_factor=prefetch_factor, pin_memory=True, + worker_init_fn=self._worker_init_fn, **kwargs, ) @@ -138,6 +139,7 @@ class PyTorchEstimator(Estimator): num_workers=num_workers, prefetch_factor=prefetch_factor, pin_memory=True, + worker_init_fn=self._worker_init_fn, **kwargs, ) @@ -155,6 +157,10 @@ class PyTorchEstimator(Estimator): ), ) + @staticmethod + def _worker_init_fn(worker_id): + np.random.seed(np.random.get_state()[1][0] + worker_id) + def train( self, training_data: Dataset,