mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 01:46:10 +08:00
Average aggregated gradients before put in plasma store (#3631)
This commit is contained in:
@@ -67,6 +67,7 @@ class ParameterServer(object):
|
||||
client = ray.worker.global_worker.plasma_client
|
||||
assert self.acc_counter == self.num_sgd_workers, self.acc_counter
|
||||
oid = ray.pyarrow.plasma.ObjectID(object_id)
|
||||
self.accumulated /= self.acc_counter
|
||||
client.put(self.accumulated.flatten(), object_id=oid)
|
||||
self.accumulated = np.zeros_like(self.accumulated)
|
||||
self.acc_counter = 0
|
||||
|
||||
Reference in New Issue
Block a user