Fixing MeanSTDFilter (#1101)

* Fixing MeanSTDFilter

* documentation
This commit is contained in:
Richard Liaw 2017-10-11 18:16:08 -07:00 committed by Robert Nishihara
parent 46f6c163dc
commit 379b0604b4

View file

@ -86,15 +86,27 @@ class RunningStat(object):
class MeanStdFilter(object):
"""Keeps track of a running mean for seen states"""
def __init__(self, shape, demean=True, destd=True, clip=10.0):
    """Store the filter settings and create its two running statistics.

    Args:
        shape: Handed unchanged to ``RunningStat`` for both accumulators
            and kept as ``self.shape``.
        demean: Kept as ``self.demean``; when truthy, ``__call__``
            subtracts ``self.rs.mean`` from its input.
        destd: Kept as ``self.destd``; checked by ``__call__``.
        clip: Kept as ``self.clip``.  NOTE(review): presumably a bound
            applied to the filtered output -- confirm in ``__call__``.
    """
    self.shape, self.clip = shape, clip
    self.demean, self.destd = demean, destd
    # Main running statistics of this filter.
    self.rs = RunningStat(shape)
    # Workers in a distributed rollout each see different states, so a
    # second accumulator records the deltas that have to be exchanged
    # amongst all the observation filters.
    self.buffer = RunningStat(shape)
def clear_buffer(self):
    """Drop the buffered deltas by installing a fresh ``RunningStat``."""
    fresh_stats = RunningStat(self.shape)
    self.buffer = fresh_stats
def update(self, other):
    """Merge the buffered statistics of another filter into this one.

    Args:
        other: A filter exposing a ``buffer`` attribute that
            ``self.rs.update`` accepts.

    Only ``other.buffer`` is applied.  ``__call__`` pushes observations
    into both ``rs`` and ``buffer``, so also merging ``other.rs`` would
    count the buffered observations twice.
    """
    self.rs.update(other.buffer)
def copy(self):
other = MeanStdFilter(self.shape)
@ -102,6 +114,7 @@ class MeanStdFilter(object):
other.destd = self.destd
other.clip = self.clip
other.rs = self.rs.copy()
other.buffer = self.buffer.copy()
return other
def __call__(self, x, update=True):
@ -111,9 +124,11 @@ class MeanStdFilter(object):
# The vectorized case.
for i in range(x.shape[0]):
self.rs.push(x[i])
self.buffer.push(x[i])
else:
# The unvectorized case.
self.rs.push(x)
self.buffer.push(x)
if self.demean:
x = x - self.rs.mean
if self.destd: