set numpy random state based on worker_id

https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/
This commit is contained in:
Dr. Kashif Rasul
2021-04-11 09:52:16 +02:00
parent 4e97483fab
commit 499dae83f2
+6
View File
@@ -118,6 +118,7 @@ class PyTorchEstimator(Estimator):
num_workers=num_workers,
prefetch_factor=prefetch_factor,
pin_memory=True,
worker_init_fn=self._worker_init_fn,
**kwargs,
)
@@ -138,6 +139,7 @@ class PyTorchEstimator(Estimator):
num_workers=num_workers,
prefetch_factor=prefetch_factor,
pin_memory=True,
worker_init_fn=self._worker_init_fn,
**kwargs,
)
@@ -155,6 +157,10 @@ class PyTorchEstimator(Estimator):
),
)
@staticmethod
def _worker_init_fn(worker_id):
np.random.seed(np.random.get_state()[1][0] + worker_id)
def train(
self,
training_data: Dataset,