# File: Run-Skeleton-Run/baselines/baselines_common/mpi_moments.py
# Commit 7401266fe7 by Kolesnikov Sergey ("pytorch version"), 2017-11-15 22:18:46 +03:00
# 53 lines, 1.6 KiB, Python

from mpi4py import MPI
import numpy as np
from baselines.baselines_common import zipsame
def mpi_moments(x, axis=0, comm=None):
    """Compute the mean and standard deviation of ``x`` along ``axis`` across MPI ranks.

    Each rank contributes its local ``x``. The per-element sums, sums of squares
    and sample counts from all ranks are combined with a single ``Allreduce``,
    so every rank receives the same global statistics.

    Args:
        x: Local data (anything ``np.asarray`` accepts), converted to float64.
            All ranks must pass arrays whose shapes agree on every dimension
            except ``axis``: the reduced buffers must have identical sizes on
            every rank.
        axis: Dimension to reduce over. Negative values count from the end.
        comm: MPI communicator to reduce over. Defaults to ``MPI.COMM_WORLD``.

    Returns:
        ``(mean, std, count)``.
        ``mean`` and ``std`` have the shape of ``x`` with ``axis`` removed.
        ``std`` is the population (ddof=0) standard deviation.
        ``count`` is the total number of samples along ``axis`` summed over
        all ranks, as a float.
        If ``count`` is 0, ``mean`` and ``std`` are filled with NaN.
    """
    if comm is None:
        comm = MPI.COMM_WORLD
    x = np.asarray(x, dtype='float64')
    newshape = list(x.shape)
    newshape.pop(axis)
    n = np.prod(newshape, dtype=int)

    # Pack [sum (n), sum of squares (n), count (1)] into one buffer so that a
    # single collective call reduces everything at once.
    local = np.concatenate([
        x.sum(axis=axis).ravel(),
        np.square(x).sum(axis=axis).ravel(),
        np.array([x.shape[axis]], dtype='float64'),
    ])
    total = np.zeros(n * 2 + 1, 'float64')
    comm.Allreduce(local, total, op=MPI.SUM)

    total_sum = total[:n]
    total_sumsq = total[n:2 * n]
    count = total[2 * n]

    if count == 0:
        # No samples on any rank, so the moments are undefined.
        mean = np.full(newshape, np.nan)
        std = np.full(newshape, np.nan)
    else:
        # Restore the non-reduced shape. Previously only the count == 0 branch
        # did this, so inputs with 3+ dims came back flattened.
        mean = total_sum.reshape(newshape) / count
        mean_sq = total_sumsq.reshape(newshape) / count
        # Var = E[x^2] - E[x]^2. Rounding can push this slightly below zero,
        # so clamp it before taking the square root.
        std = np.sqrt(np.maximum(mean_sq - np.square(mean), 0))
    return mean, std, count
def test_runningmeanstd():
    """Check ``mpi_moments`` against a single-process NumPy reference.

    Every rank seeds NumPy identically, so all ranks build the same three-way
    splits of the data. Rank ``r`` passes piece ``r`` to ``mpi_moments``, and
    the reduced moments must match the moments of the concatenated data.
    Must be launched with exactly three ranks:
    ``mpirun -np 3 python <script>``.
    """
    comm = MPI.COMM_WORLD
    # Each case below is split into exactly three pieces, one per rank. With
    # any other world size, triple[rank] raises IndexError (more ranks) or the
    # comparison fails for no obvious reason (fewer ranks).
    nranks = comm.Get_size()
    assert nranks == 3, \
        "test_runningmeanstd must be run with exactly 3 MPI ranks, got %d" % nranks

    np.random.seed(0)
    for (triple, axis) in [
        ((np.random.randn(3), np.random.randn(4), np.random.randn(5)), 0),
        ((np.random.randn(3, 2), np.random.randn(4, 2), np.random.randn(5, 2)), 0),
        ((np.random.randn(2, 3), np.random.randn(2, 4), np.random.randn(2, 4)), 1),
    ]:
        x = np.concatenate(triple, axis=axis)
        # Reference moments, computed locally over the full data.
        ms1 = [x.mean(axis=axis), x.std(axis=axis), x.shape[axis]]
        # Distributed moments: each rank contributes only its own piece.
        ms2 = mpi_moments(triple[comm.Get_rank()], axis=axis)
        for (a1, a2) in zipsame(ms1, ms2):
            print(a1, a2)
            assert np.allclose(a1, a2)
        print("ok!")
if __name__ == '__main__':
    # Script entry point: runs the self-test. It must be started under MPI
    # with three processes, for example:
    #     mpirun -np 3 python <script>
    test_runningmeanstd()