Files
pytorch-a2c-ppo-acktr/running_stat.py
T
Ilya Kostrikov 09e75e26ae Add MuJoCo
2017-09-27 08:20:19 -04:00

45 lines
1.2 KiB
Python

import random
import torch
class ObsNorm(object):
def __init__(self, shape, demean=True, destd=True, clip=10.0):
self.demean = demean
self.destd = destd
self.clip = clip
self.count = torch.zeros(1).double() + 1e-2
self.sum = torch.zeros(shape).double()
self.sum_sqr = torch.zeros(shape).double() + 1e-2
self.mean = torch.zeros(shape)
self.std = torch.ones(shape)
def cuda(self):
self.count = self.count.cuda()
self.sum = self.sum.cuda()
self.sum_sqr = self.sum_sqr.cuda()
self.mean = self.mean.cuda()
self.std = self.std.cuda()
def update(self, x):
self.count += x.size(0)
self.sum += x.sum(0, keepdim=True).double()
self.sum_sqr += x.pow(2).sum(0, keepdim=True).double()
self.mean = self.sum / self.count
self.std = (self.sum_sqr / self.count - self.mean.pow(2)).clamp(1e-2, 1e9).sqrt()
self.mean = self.mean.float()
self.std = self.std.float()
def __call__(self, x):
if self.demean:
x = x - self.mean
if self.destd:
x = x / self.std
if self.clip:
x = x.clamp(-self.clip, self.clip)
return x