Average aggregated gradients before put in plasma store (#3631)

This commit is contained in:
Stan Wang
2018-12-26 19:03:11 +08:00
committed by Eric Liang
parent 4cde971916
commit 4ce3818be5
@@ -67,6 +67,7 @@ class ParameterServer(object):
client = ray.worker.global_worker.plasma_client
assert self.acc_counter == self.num_sgd_workers, self.acc_counter
oid = ray.pyarrow.plasma.ObjectID(object_id)
self.accumulated /= self.acc_counter
client.put(self.accumulated.flatten(), object_id=oid)
self.accumulated = np.zeros_like(self.accumulated)
self.acc_counter = 0