mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
346 lines
10 KiB
Python
346 lines
10 KiB
Python
import logging
|
|
import threading
|
|
|
|
import numpy as np
|
|
import tree # pip install dm_tree
|
|
|
|
from ray.rllib.utils.annotations import DeveloperAPI
|
|
from ray.rllib.utils.deprecation import Deprecated
|
|
from ray.rllib.utils.numpy import SMALL_NUMBER
|
|
from ray.rllib.utils.typing import TensorStructType
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# TODO(jungong) : Add Adapters to use these filters as agent connectors.
|
|
|
|
|
|
@DeveloperAPI
class Filter:
    """Base interface for callables that process inputs, possibly statefully.

    Every method below raises NotImplementedError here (except the
    deprecated `clear_buffer` alias); concrete filters override them.
    """

    def apply_changes(self, other: "Filter", *args, **kwargs) -> None:
        """Updates self with "new state" taken from `other`.

        Args:
            other: The filter to take the new state from.
            *args: Extra positional arguments for subclass implementations.
            **kwargs: Extra keyword arguments for subclass implementations.
        """
        raise NotImplementedError

    def copy(self) -> "Filter":
        """Creates a new object holding the same state as self.

        Returns:
            A copy of self.
        """
        raise NotImplementedError

    def sync(self, other: "Filter") -> None:
        """Copies all state from `other` into self.

        Args:
            other: The filter whose state should be copied.
        """
        raise NotImplementedError

    def reset_buffer(self) -> None:
        """Creates copy of current state and resets accumulated state."""
        raise NotImplementedError

    def as_serializable(self) -> "Filter":
        """Returns a filter object that is suitable for serialization."""
        raise NotImplementedError

    @Deprecated(new="Filter.reset_buffer()", error=False)
    def clear_buffer(self):
        """Deprecated alias that forwards to `reset_buffer()`."""
        return self.reset_buffer()
|
|
|
|
|
|
@DeveloperAPI
class NoFilter(Filter):
    """Filter that keeps no state and does not normalize its input.

    The only processing done is converting inputs that are not already an
    np.ndarray, dict, or tuple into an np.ndarray.
    """

    is_concurrent = True

    def __call__(self, x: TensorStructType, update=True):
        """Returns `x` unchanged, or converted to an array if needed.

        Args:
            x: The input to pass through.
            update: Accepted for interface compatibility; not used.

        Returns:
            `x` itself if it is an np.ndarray, dict, or tuple; otherwise
            `np.asarray(x)`.

        Raises:
            ValueError: If the conversion to an array fails.
        """
        passthrough_types = (np.ndarray, dict, tuple)
        if not isinstance(x, passthrough_types):
            try:
                x = np.asarray(x)
            except Exception:
                raise ValueError("Failed to convert to array", x)
        return x

    def apply_changes(self, other: "NoFilter", *args, **kwargs) -> None:
        """Does nothing: there is no state to update."""
        pass

    def copy(self) -> "NoFilter":
        """Returns self (no new object is created)."""
        return self

    def sync(self, other: "NoFilter") -> None:
        """Does nothing: there is no state to copy."""
        pass

    def reset_buffer(self) -> None:
        """Does nothing: there is no buffer to reset."""
        pass

    def as_serializable(self) -> "NoFilter":
        """Returns self."""
        return self
|
|
|
|
|
|
# http://www.johndcook.com/blog/standard_deviation/
|
|
@DeveloperAPI
|
|
class RunningStat:
    """Running count, mean, and sum of squared deviations of pushed arrays.

    Samples are folded in one at a time via `push()`, or in bulk by merging
    another accumulator via `update()`. Internally `_n` is the sample count,
    `_M` the running mean and `_S` the running sum of squared deviations
    from the mean; `_M` and `_S` are always kept as np.ndarrays so that the
    in-place updates in `push()` keep working.
    """

    def __init__(self, shape=None):
        """Initializes empty statistics.

        Args:
            shape: Shape of a single sample. None means scalar samples and
                is treated as `()`.
        """
        # Spell "scalar" as () explicitly rather than handing None to NumPy,
        # which has deprecated None as an alias for an empty shape.
        if shape is None:
            shape = ()
        self._n = 0
        self._M = np.zeros(shape)
        self._S = np.zeros(shape)

    def copy(self):
        """Returns a new RunningStat with copies of this object's state."""
        other = RunningStat()
        other._n = self._n
        other._M = np.copy(self._M)
        other._S = np.copy(self._S)
        return other

    def push(self, x):
        """Folds one sample into the statistics (unvectorized update).

        Args:
            x: A single sample; converted with `np.asarray` and required to
                have exactly the shape of the tracked mean.

        Raises:
            ValueError: If the shape of `x` differs from the tracked shape.
        """
        x = np.asarray(x)
        if x.shape != self._M.shape:
            raise ValueError(
                "Unexpected input shape {}, expected {}, value = {}".format(
                    x.shape, self._M.shape, x
                )
            )
        n1 = self._n
        self._n += 1
        if self._n == 1:
            # First sample: the mean is the sample itself, _S stays zero.
            self._M[...] = x
        else:
            delta = x - self._M
            self._M[...] += delta / self._n
            self._S[...] += delta * delta * n1 / self._n

    def update(self, other):
        """Merges the statistics of `other` into self.

        Args:
            other: Another RunningStat whose samples should be accounted
                for in self. `other` is not modified.
        """
        n1 = self._n
        n2 = other._n
        n = n1 + n2
        if n == 0:
            # Avoid divide by zero, which creates nans
            return
        delta = self._M - other._M
        delta2 = delta * delta
        M = (n1 * self._M + n2 * other._M) / n
        S = self._S + other._S + delta2 * n1 * n2 / n
        self._n = n
        # Arithmetic on 0-d arrays yields NumPy scalars, which do not
        # support the item assignment `push()` relies on. Re-wrap so the
        # state is always an ndarray (a no-op for >= 1-d results).
        self._M = np.asarray(M)
        self._S = np.asarray(S)

    def __repr__(self):
        return "(n={}, mean_mean={}, mean_std={})".format(
            self.n, np.mean(self.mean), np.mean(self.std)
        )

    @property
    def n(self):
        """Number of samples accounted for so far."""
        return self._n

    @property
    def mean(self):
        """The running mean (the internal array, not a copy)."""
        return self._M

    @property
    def var(self):
        """`_S / (n - 1)` if n > 1, else the element-wise square of the mean."""
        return self._S / (self._n - 1) if self._n > 1 else np.square(self._M)

    @property
    def std(self):
        """Element-wise square root of `var`."""
        return np.sqrt(self.var)

    @property
    def shape(self):
        """Shape of a single sample (the shape of the mean array)."""
        return self._M.shape
|
|
|
|
|
|
@DeveloperAPI
class MeanStdFilter(Filter):
    """Keeps track of a running mean/std for seen states and normalizes by it.

    Two statistics structures shaped like `shape` are maintained: `rs`
    receives every pushed sample and every merge done via
    `apply_changes()`, while `buffer` receives pushed samples only and can
    be reset with `reset_buffer()`.
    """

    is_concurrent = False

    def __init__(self, shape, demean=True, destd=True, clip=10.0):
        """Initializes the filter.

        Args:
            shape: Shape of the inputs, or a dict/tuple of np.ndarray shapes
                for complex inputs, or None (no normalization).
            demean: Whether to subtract the running mean.
            destd: Whether to divide by the running std.
            clip: If truthy, outputs are clipped to [-clip, clip].
        """
        self.shape = shape
        # We don't have a preprocessor, if shape is None (Discrete) or
        # flat_shape is Tuple[np.ndarray] or Dict[str, np.ndarray]
        # (complex inputs).
        leaves = tree.flatten(self.shape)
        is_complex = (
            isinstance(self.shape, (dict, tuple))
            and len(leaves) > 0
            and isinstance(leaves[0], np.ndarray)
        )
        self.no_preprocessor = shape is None or is_complex
        # If preprocessing (flattening dicts/tuples), make sure shape
        # is an np.ndarray so we don't confuse it with a complex Tuple
        # space's shape structure (which is a Tuple[np.ndarray]).
        if not self.no_preprocessor:
            self.shape = np.array(self.shape)
        self.demean = demean
        self.destd = destd
        self.clip = clip
        # Running stats.
        self.rs = tree.map_structure(RunningStat, self.shape)

        # In distributed rollouts, each worker sees different states.
        # The buffer is used to keep track of deltas amongst all the
        # observation filters.
        self.buffer = None
        self.reset_buffer()

    def reset_buffer(self) -> None:
        """Replaces `self.buffer` with fresh, empty running stats."""
        self.buffer = tree.map_structure(RunningStat, self.shape)

    def apply_changes(
        self, other: "MeanStdFilter", with_buffer: bool = False, *args, **kwargs
    ) -> None:
        """Applies updates from the buffer of another filter.

        Args:
            other: Other filter to apply info from
            with_buffer: Flag for specifying if the buffer should be
                copied from other.

        Examples:
            >>> a = MeanStdFilter(())
            >>> _ = a(1)
            >>> _ = a(2)
            >>> print([a.rs.n, float(a.rs.mean), a.buffer.n])
            [2, 1.5, 2]
            >>> b = MeanStdFilter(())
            >>> _ = b(10)
            >>> a.apply_changes(b, with_buffer=False)
            >>> print([a.rs.n, float(a.rs.mean), a.buffer.n])
            [3, 4.333333333333333, 2]
            >>> a.apply_changes(b, with_buffer=True)
            >>> print([a.rs.n, float(a.rs.mean), a.buffer.n])
            [4, 5.75, 1]
        """

        def _merge(own_stat, incoming_stat):
            own_stat.update(incoming_stat)

        tree.map_structure(_merge, self.rs, other.buffer)
        if with_buffer:
            self.buffer = tree.map_structure(
                lambda stat: stat.copy(), other.buffer
            )

    def copy(self) -> "MeanStdFilter":
        """Returns a copy of `self`."""
        duplicate = MeanStdFilter(self.shape)
        duplicate.sync(self)
        return duplicate

    def as_serializable(self) -> "MeanStdFilter":
        """Returns a copy of `self` (see `copy()`)."""
        return self.copy()

    def sync(self, other: "MeanStdFilter") -> None:
        """Syncs all fields together from other filter.

        Copies `demean`, `destd`, `clip`, and copies of `rs` and `buffer`.

        Examples:
            >>> a = MeanStdFilter(())
            >>> _ = a(1)
            >>> _ = a(2)
            >>> print([a.rs.n, float(a.rs.mean), a.buffer.n])
            [2, 1.5, 2]
            >>> b = MeanStdFilter(())
            >>> _ = b(10)
            >>> print([b.rs.n, float(b.rs.mean), b.buffer.n])
            [1, 10.0, 1]
            >>> a.sync(b)
            >>> print([a.rs.n, float(a.rs.mean), a.buffer.n])
            [1, 10.0, 1]
        """
        self.demean = other.demean
        self.destd = other.destd
        self.clip = other.clip
        self.rs = tree.map_structure(lambda stat: stat.copy(), other.rs)
        self.buffer = tree.map_structure(lambda stat: stat.copy(), other.buffer)

    def __call__(self, x: TensorStructType, update: bool = True) -> TensorStructType:
        """Optionally records `x` in the running stats, then normalizes it.

        Args:
            x: The input (or structure of inputs, for complex shapes).
            update: Whether to push `x` into `rs` and `buffer` first.

        Returns:
            The processed input, cast back to the dtype it came in with.
        """
        if self.no_preprocessor:
            x = tree.map_structure(np.asarray, x)
        else:
            x = np.asarray(x)

        def _normalize(value, stats, delta_stats, shape):
            # Discrete|MultiDiscrete spaces -> No normalization.
            if shape is None:
                return value

            # Keep dtype as is throughout this filter.
            input_dtype = value.dtype

            if update:
                # One more dim than the stats' shape -> the vectorized case.
                is_batched = len(value.shape) == len(stats.shape) + 1
                if is_batched:
                    for row in value:
                        stats.push(row)
                        delta_stats.push(row)
                else:
                    # The unvectorized case.
                    stats.push(value)
                    delta_stats.push(value)
            if self.demean:
                value = value - stats.mean
            if self.destd:
                value = value / (stats.std + SMALL_NUMBER)
            if self.clip:
                value = np.clip(value, -self.clip, self.clip)
            return value.astype(input_dtype)

        if not self.no_preprocessor:
            return _normalize(x, self.rs, self.buffer, self.shape)
        return tree.map_structure_up_to(
            x, _normalize, x, self.rs, self.buffer, self.shape
        )

    def __repr__(self) -> str:
        return "MeanStdFilter({}, {}, {}, {}, {}, {})".format(
            self.shape, self.demean, self.destd, self.clip, self.rs, self.buffer
        )
|
|
|
|
|
|
@DeveloperAPI
class ConcurrentMeanStdFilter(MeanStdFilter):
    """MeanStdFilter whose state-touching methods hold a reentrant lock.

    Each method that reads or writes the running stats acquires
    `self._lock` for its whole duration. Only this instance's lock is
    taken: the `other` filter passed to `apply_changes()` / `sync()` is
    read without acquiring its lock.
    """

    is_concurrent = True

    def __init__(self, *args, **kwargs):
        """Creates the lock, then initializes like MeanStdFilter."""
        # The lock has to exist before the parent constructor runs, because
        # that constructor calls `self.reset_buffer()`, which is overridden
        # below to acquire it.
        #
        # NOTE: This replaces an earlier `self.__getattribute__ = ...`
        # assignment. Special methods are looked up on the type, not on the
        # instance, so that assignment never intercepted attribute access
        # and the lock was never actually taken.
        self._lock = threading.RLock()
        super().__init__(*args, **kwargs)

    def __call__(self, x: TensorStructType, update: bool = True) -> TensorStructType:
        """Same as `MeanStdFilter.__call__`, executed under the lock."""
        with self._lock:
            return super().__call__(x, update=update)

    def reset_buffer(self) -> None:
        """Same as `MeanStdFilter.reset_buffer`, executed under the lock."""
        with self._lock:
            super().reset_buffer()

    def apply_changes(
        self, other: "MeanStdFilter", with_buffer: bool = False, *args, **kwargs
    ) -> None:
        """Same as `MeanStdFilter.apply_changes`, executed under the lock."""
        with self._lock:
            super().apply_changes(other, with_buffer, *args, **kwargs)

    def sync(self, other: "MeanStdFilter") -> None:
        """Same as `MeanStdFilter.sync`, executed under the lock."""
        with self._lock:
            super().sync(other)

    def as_serializable(self) -> "MeanStdFilter":
        """Returns non-concurrent version of current class"""
        # Hold the lock while snapshotting so the copy is consistent.
        with self._lock:
            other = MeanStdFilter(self.shape)
            other.sync(self)
        return other

    def copy(self) -> "ConcurrentMeanStdFilter":
        """Returns a copy of Filter."""
        # Hold the lock while snapshotting so the copy is consistent.
        with self._lock:
            other = ConcurrentMeanStdFilter(self.shape)
            other.sync(self)
        return other

    def __repr__(self) -> str:
        with self._lock:
            return "ConcurrentMeanStdFilter({}, {}, {}, {}, {}, {})".format(
                self.shape, self.demean, self.destd, self.clip, self.rs, self.buffer
            )
|
|
|
|
|
|
@DeveloperAPI
def get_filter(filter_config, shape):
    """Builds a filter object from a config value.

    Args:
        filter_config: One of the strings "MeanStdFilter",
            "ConcurrentMeanStdFilter", "NoFilter", or a callable that is
            invoked as `filter_config(shape)`.
        shape: The shape handed to the constructed filter (unused for
            "NoFilter").

    Returns:
        The constructed filter. The two mean/std filters are created with
        `clip=None`.

    Raises:
        Exception: If `filter_config` is none of the supported values.
    """
    # TODO(rliaw): move this into filter manager
    if filter_config == "MeanStdFilter":
        return MeanStdFilter(shape, clip=None)
    if filter_config == "ConcurrentMeanStdFilter":
        return ConcurrentMeanStdFilter(shape, clip=None)
    if filter_config == "NoFilter":
        return NoFilter()
    if callable(filter_config):
        return filter_config(shape)
    raise Exception("Unknown observation_filter: " + str(filter_config))
|