Files
Kolesnikov Sergey 7401266fe7 pytorch version
2017-11-15 22:18:46 +03:00

113 lines
3.8 KiB
Python

from mpi4py import MPI
import tensorflow as tf
import baselines.baselines_common.tf_util as U
import numpy as np
class RunningMeanStd(object):
    """Running mean and standard deviation aggregated across MPI workers.

    Three non-trainable float64 TensorFlow variables hold the running sum,
    the running sum of squares and the sample count.  ``mean`` and ``std``
    are float32 tensors derived from those accumulators.

    Reference:
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm

    Args:
        epsilon: initial value of both the sum-of-squares and the count
            accumulators (the sum accumulator starts at 0).
        shape: shape of the per-element statistics, i.e. the shape of ``x``
            passed to ``update`` without its leading axis.
    """

    def __init__(self, epsilon=1e-2, shape=()):
        self._sum = tf.get_variable(
            dtype=tf.float64,
            shape=shape,
            initializer=tf.constant_initializer(0.0),
            name="runningsum", trainable=False)
        self._sumsq = tf.get_variable(
            dtype=tf.float64,
            shape=shape,
            initializer=tf.constant_initializer(epsilon),
            name="runningsumsq", trainable=False)
        self._count = tf.get_variable(
            dtype=tf.float64,
            shape=(),
            initializer=tf.constant_initializer(epsilon),
            name="count", trainable=False)
        self.shape = shape

        # Form E[x^2] - E[x]^2 in float64 and only then cast to float32.
        # Subtracting two already-rounded float32 values loses most of the
        # significant digits when |mean| is large compared to the std.
        mean64 = self._sum / self._count
        var64 = self._sumsq / self._count - tf.square(mean64)
        self.mean = tf.to_float(mean64)
        # The variance is floored at 1e-2 before the square root.
        self.std = tf.sqrt(tf.maximum(tf.to_float(var64), 1e-2))

        newsum = tf.placeholder(shape=self.shape, dtype=tf.float64, name='sum')
        newsumsq = tf.placeholder(shape=self.shape, dtype=tf.float64, name='var')
        newcount = tf.placeholder(shape=[], dtype=tf.float64, name='count')
        # Callable that adds the given increments to the three accumulators.
        self.incfiltparams = U.function(
            [newsum, newsumsq, newcount], [],
            updates=[tf.assign_add(self._sum, newsum),
                     tf.assign_add(self._sumsq, newsumsq),
                     tf.assign_add(self._count, newcount)])

    def update(self, x):
        """Fold a batch into the running statistics of all MPI workers.

        Args:
            x: array-like batch whose leading axis indexes samples; it is
                converted to a float64 array.

        The local sum, sum of squares and sample count are packed into one
        vector, summed over ``MPI.COMM_WORLD`` with ``Allreduce`` and the
        totals are added to the accumulators.
        """
        x = np.asarray(x, dtype='float64')
        n = int(np.prod(self.shape))
        # Layout: [sum (n values) | sum of squares (n values) | count (1)].
        totalvec = np.zeros(n * 2 + 1, 'float64')
        addvec = np.concatenate([
            x.sum(axis=0).ravel(),
            np.square(x).sum(axis=0).ravel(),
            np.array([len(x)], dtype='float64'),
        ])
        MPI.COMM_WORLD.Allreduce(addvec, totalvec, op=MPI.SUM)
        self.incfiltparams(
            totalvec[0:n].reshape(self.shape),
            totalvec[n:2 * n].reshape(self.shape),
            totalvec[2 * n])
@U.in_session
def test_runningmeanstd():
    """Compare RunningMeanStd with NumPy's mean/std over concatenated batches.

    Two cases are checked: 1-D batches (scalar statistics) and batches of
    2-element rows (per-column statistics).
    """
    cases = [
        (np.random.randn(3), np.random.randn(4), np.random.randn(5)),
        (np.random.randn(3, 2), np.random.randn(4, 2), np.random.randn(5, 2)),
    ]
    for case_idx, (x1, x2, x3) in enumerate(cases):
        # RunningMeanStd creates its variables via tf.get_variable under
        # fixed names, and the two cases use different shapes.  A separate
        # variable scope per case keeps the second instance from colliding
        # with the variables of the first one.
        with tf.variable_scope("test_runningmeanstd_%d" % case_idx):
            rms = RunningMeanStd(epsilon=0.0, shape=x1.shape[1:])
        U.initialize()

        x = np.concatenate([x1, x2, x3], axis=0)
        ms1 = [x.mean(axis=0), x.std(axis=0)]
        rms.update(x1)
        rms.update(x2)
        rms.update(x3)
        ms2 = U.eval([rms.mean, rms.std])

        assert np.allclose(ms1, ms2)
@U.in_session
def test_dist():
    """Two-process check that RunningMeanStd pools statistics across ranks.

    Both ranks seed NumPy identically and draw the same six batches; rank 0
    feeds the first three, rank 1 the last three.  The resulting mean/std
    must match NumPy's values over all six batches concatenated.
    """
    np.random.seed(0)
    # Draw order matters: rank-0 batches first, then rank-1 batches.
    rank0_batches = tuple(np.random.randn(rows, 1) for rows in (3, 4, 5))
    rank1_batches = tuple(np.random.randn(rows, 1) for rows in (6, 7, 8))

    world = MPI.COMM_WORLD
    assert world.Get_size() == 2
    rank = world.Get_rank()
    if rank == 0:
        local_batches = rank0_batches
    elif rank == 1:
        local_batches = rank1_batches
    else:
        assert False

    rms = RunningMeanStd(epsilon=0.0, shape=(1,))
    U.initialize()
    for batch in local_batches:
        rms.update(batch)

    pooled = np.concatenate(rank0_batches + rank1_batches)

    def checkallclose(x, y):
        # Print both operands so a failing run shows the mismatch.
        print(x, y)
        return np.allclose(x, y)

    expected_mean = pooled.mean(axis=0)
    assert checkallclose(expected_mean, U.eval(rms.mean))
    expected_std = pooled.std(axis=0)
    assert checkallclose(expected_std, U.eval(rms.std))
if __name__ == "__main__":
    # test_dist asserts an MPI world size of 2, so launch it as:
    #   mpirun -np 2 python <filename>
    test_dist()