mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-27 18:06:19 +08:00
set numpy random state based on worker_id
https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/
This commit is contained in:
@@ -118,6 +118,7 @@ class PyTorchEstimator(Estimator):
|
||||
num_workers=num_workers,
|
||||
prefetch_factor=prefetch_factor,
|
||||
pin_memory=True,
|
||||
worker_init_fn=self._worker_init_fn,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -138,6 +139,7 @@ class PyTorchEstimator(Estimator):
|
||||
num_workers=num_workers,
|
||||
prefetch_factor=prefetch_factor,
|
||||
pin_memory=True,
|
||||
worker_init_fn=self._worker_init_fn,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -155,6 +157,10 @@ class PyTorchEstimator(Estimator):
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _worker_init_fn(worker_id):
|
||||
np.random.seed(np.random.get_state()[1][0] + worker_id)
|
||||
|
||||
def train(
|
||||
self,
|
||||
training_data: Dataset,
|
||||
|
||||
Reference in New Issue
Block a user