From 499dae83f2b78c905c348406049941a6099e56af Mon Sep 17 00:00:00 2001
From: "Dr. Kashif Rasul"
Date: Sun, 11 Apr 2021 09:52:16 +0200
Subject: [PATCH] set numpy random state based on worker_id

https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/
---
 pts/model/estimator.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/pts/model/estimator.py b/pts/model/estimator.py
index 85863b9..5044893 100644
--- a/pts/model/estimator.py
+++ b/pts/model/estimator.py
@@ -118,6 +118,7 @@ class PyTorchEstimator(Estimator):
             num_workers=num_workers,
             prefetch_factor=prefetch_factor,
             pin_memory=True,
+            worker_init_fn=self._worker_init_fn,
             **kwargs,
         )
 
@@ -138,6 +139,7 @@ class PyTorchEstimator(Estimator):
             num_workers=num_workers,
             prefetch_factor=prefetch_factor,
             pin_memory=True,
+            worker_init_fn=self._worker_init_fn,
             **kwargs,
         )
 
@@ -155,6 +157,18 @@ class PyTorchEstimator(Estimator):
             ),
         )
 
+    @staticmethod
+    def _worker_init_fn(worker_id):
+        """Re-seed NumPy's global RNG using its current state and worker_id.
+
+        The new seed is the first word of the current NumPy RNG state plus
+        ``worker_id``, so each worker id yields a different seed.
+        """
+        # np.random.seed only accepts integers in [0, 2**32 - 1]; wrap the sum
+        # so a large state word plus worker_id can neither raise nor overflow.
+        base_seed = int(np.random.get_state()[1][0])
+        np.random.seed((base_seed + worker_id) % 2**32)
+
     def train(
         self,
         training_data: Dataset,