diff --git a/python/ray/experimental/sgd/param_server.py b/python/ray/experimental/sgd/param_server.py index 517d419c3..a69727722 100644 --- a/python/ray/experimental/sgd/param_server.py +++ b/python/ray/experimental/sgd/param_server.py @@ -67,6 +67,7 @@ class ParameterServer(object): client = ray.worker.global_worker.plasma_client assert self.acc_counter == self.num_sgd_workers, self.acc_counter oid = ray.pyarrow.plasma.ObjectID(object_id) + self.accumulated /= self.acc_counter client.put(self.accumulated.flatten(), object_id=oid) self.accumulated = np.zeros_like(self.accumulated) self.acc_counter = 0