mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
parent
46f6c163dc
commit
379b0604b4
1 changed file with 16 additions and 1 deletion
|
@ -86,15 +86,27 @@ class RunningStat(object):
|
|||
|
||||
|
||||
class MeanStdFilter(object):
|
||||
"""Keeps track of a running mean for seen states"""
|
||||
|
||||
def __init__(self, shape, demean=True, destd=True, clip=10.0):
    """Store the filter configuration and create its running statistics.

    Args:
        shape: Kept as ``self.shape`` and handed to ``RunningStat``.
        demean: Kept as ``self.demean``.
        destd: Kept as ``self.destd``.
        clip: Kept as ``self.clip``.
    """
    self.shape = shape
    self.demean, self.destd, self.clip = demean, destd, clip
    self.rs = RunningStat(shape)
    # Each worker in a distributed rollout sees different states, so a
    # second statistic is kept as a buffer to track the deltas amongst
    # all the observation filters.
    self.buffer = RunningStat(shape)
|
||||
|
||||
def clear_buffer(self):
    """Replace ``self.buffer`` with a new ``RunningStat`` of ``self.shape``."""
    fresh_buffer = RunningStat(self.shape)
    self.buffer = fresh_buffer
|
||||
|
||||
def update(self, other):
    """Merge another filter's buffered statistics into this filter.

    Only ``other.buffer`` is applied to ``self.rs``. ``other.rs`` is
    intentionally not merged as well: ``__call__`` pushes every
    observation into both ``rs`` and ``buffer``, so applying both would
    count the other filter's observations twice.

    Args:
        other: Another filter whose ``buffer`` attribute is accepted by
            ``self.rs.update``.
    """
    self.rs.update(other.buffer)
|
||||
|
||||
def copy(self):
|
||||
other = MeanStdFilter(self.shape)
|
||||
|
@ -102,6 +114,7 @@ class MeanStdFilter(object):
|
|||
other.destd = self.destd
|
||||
other.clip = self.clip
|
||||
other.rs = self.rs.copy()
|
||||
other.buffer = self.buffer.copy()
|
||||
return other
|
||||
|
||||
def __call__(self, x, update=True):
|
||||
|
@ -111,9 +124,11 @@ class MeanStdFilter(object):
|
|||
# The vectorized case.
|
||||
for i in range(x.shape[0]):
|
||||
self.rs.push(x[i])
|
||||
self.buffer.push(x[i])
|
||||
else:
|
||||
# The unvectorized case.
|
||||
self.rs.push(x)
|
||||
# Unvectorized case: push the observation itself; no loop index exists here.
self.buffer.push(x)
|
||||
if self.demean:
|
||||
x = x - self.rs.mean
|
||||
if self.destd:
|
||||
|
|
Loading…
Add table
Reference in a new issue