From 4ce3818be5e291f38ebdccfdd36248aac95fa024 Mon Sep 17 00:00:00 2001 From: Stan Wang Date: Wed, 26 Dec 2018 19:03:11 +0800 Subject: [PATCH] Average aggregated gradients before put in plasma store (#3631) --- python/ray/experimental/sgd/param_server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/ray/experimental/sgd/param_server.py b/python/ray/experimental/sgd/param_server.py index 517d419c3..a69727722 100644 --- a/python/ray/experimental/sgd/param_server.py +++ b/python/ray/experimental/sgd/param_server.py @@ -67,6 +67,7 @@ class ParameterServer(object): client = ray.worker.global_worker.plasma_client assert self.acc_counter == self.num_sgd_workers, self.acc_counter oid = ray.pyarrow.plasma.ObjectID(object_id) + self.accumulated /= self.acc_counter client.put(self.accumulated.flatten(), object_id=oid) self.accumulated = np.zeros_like(self.accumulated) self.acc_counter = 0