mirror of
https://github.com/vale981/hopsflow
synced 2025-03-04 16:31:38 -05:00
Add some docs to the util module.
This commit is contained in:
parent
21ecb60ecf
commit
5d5b561f0f
1 changed files with 155 additions and 38 deletions
193
hopsflow/util.py
193
hopsflow/util.py
|
@ -36,15 +36,23 @@ from numpy.typing import NDArray
|
|||
Aggregate = tuple[int, np.ndarray, np.ndarray]
|
||||
EnsembleReturn = Union[Aggregate, list[Aggregate]]
|
||||
|
||||
|
||||
class EnsembleValue:
|
||||
"""A data container to hold data that results from monte-carlo
|
||||
simulations.
|
||||
"""A container to hold the values of an ensemble and perform
|
||||
arithmetic operations on and between them.
|
||||
|
||||
It supports saving multiple snapshots at different sample sizes,
|
||||
recording the variances accordingly.
|
||||
The ensembles are stored as a list of aggregates (snapshots).
|
||||
Each aggregate is of the shape (sample number, value, standard
|
||||
deviation). The values and standard deviations are stored as
|
||||
numpy arrays. The aggregates are sorted by sample number.
|
||||
|
||||
Standard arithmetic operations, as well as integration is supported.
|
||||
Addition, subtraction, multiplication, division and integration
|
||||
are defined. The latter is performed using spline interpolation.
|
||||
|
||||
:param value: The value of the ensemble. Can be a single
|
||||
aggregate, or a list of aggregates, or a tuple of two numpy
|
||||
arrays. In the latter case the first array is the value and
|
||||
the second the standard deviation, where the sample size is
|
||||
set to 0.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -64,46 +72,61 @@ class EnsembleValue:
|
|||
else [value]
|
||||
)
|
||||
|
||||
self._value.sort(key=lambda x: x[0])
|
||||
|
||||
@property
|
||||
def final_aggregate(self):
|
||||
"""The last aggregate."""
|
||||
return self._value[-1]
|
||||
|
||||
@property
|
||||
def N(self):
|
||||
"""The number of samples."""
|
||||
return self.final_aggregate[0]
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
"""The the values of the last aggregate (snapshot)."""
|
||||
return self.final_aggregate[1]
|
||||
|
||||
@property
|
||||
def σ(self):
|
||||
"""The standard deviation of the last aggregate (snapshot)."""
|
||||
return self.final_aggregate[2]
|
||||
|
||||
@property
|
||||
def Ns(self):
|
||||
"""The number of samples for each aggregate (snapshot)."""
|
||||
return [N for N, _, _ in self._value]
|
||||
|
||||
@property
|
||||
def values(self):
|
||||
"""The values of each aggregate (snapshot)."""
|
||||
return [val for _, val, _ in self._value]
|
||||
|
||||
@property
|
||||
def σs(self):
|
||||
"""The standard deviation of each aggregate (snapshot)."""
|
||||
return [σ for _, _, σ in self._value]
|
||||
|
||||
@property
|
||||
def aggregate_iterator(self):
|
||||
"""Iterates over all aggregates (snapshots)."""
|
||||
for agg in self._value:
|
||||
yield agg
|
||||
|
||||
@property
|
||||
def ensemble_value_iterator(self):
|
||||
"""Iterates over all values in the aggregates (snapshots)."""
|
||||
for agg in self._value:
|
||||
yield EnsembleValue(agg)
|
||||
|
||||
@property
|
||||
def mean(self):
|
||||
def mean(self) -> EnsembleValue:
|
||||
"""
|
||||
Returns the mean of the ensemble as a new EnsembleValue. The
|
||||
standard deviation is correctly propagated.
|
||||
"""
|
||||
values = []
|
||||
|
||||
for N, val, σ in self.aggregate_iterator:
|
||||
|
@ -114,23 +137,30 @@ class EnsembleValue:
|
|||
return EnsembleValue(values)
|
||||
|
||||
@property
|
||||
def max(self):
|
||||
def max(self) -> EnsembleValue:
|
||||
"""Returns the maximum value of the EnsembleValue as a new EnsembleValue."""
|
||||
N, val, σ = self.final_aggregate
|
||||
max_index = np.argmax(val)
|
||||
|
||||
return EnsembleValue([(N, val[max_index].copy(), σ[max_index].copy())])
|
||||
|
||||
@property
|
||||
def min(self):
|
||||
def min(self) -> EnsembleValue:
|
||||
"""Returns the minimum value of the EnsembleValue as a new EnsembleValue."""
|
||||
N, val, σ = self.final_aggregate
|
||||
min_index = np.argmin(val)
|
||||
|
||||
return EnsembleValue([(N, val[min_index].copy(), σ[min_index].copy())])
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""Returns the aggregate (snapshot) at ``index``."""
|
||||
return EnsembleValue([self._value[index]])
|
||||
|
||||
def slice(self, slc: Union[np.ndarray, slice]):
|
||||
def slice(self, slc: Union[np.ndarray, slice]) -> EnsembleValue:
|
||||
"""
|
||||
Returns a new EnsembleValue with the values and standard
|
||||
deviations in the aggregates (snapshots) sliced by ``slc``.
|
||||
"""
|
||||
results = []
|
||||
for N, val, σ in self.aggregate_iterator:
|
||||
results.append((N, val[slc], σ[slc]))
|
||||
|
@ -138,9 +168,17 @@ class EnsembleValue:
|
|||
return EnsembleValue(results)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Returns the number of aggregates (snapshots)."""
|
||||
return len(self._value)
|
||||
|
||||
def for_bath(self, bath: int):
|
||||
def for_bath(self, bath: int) -> EnsembleValue:
|
||||
"""
|
||||
Returns a new EnsembleValue with the values and standard
|
||||
deviations for the bath ``bath``.
|
||||
|
||||
This is specific to values of the form ``[for bath 1, for bath
|
||||
2, ...]``.
|
||||
"""
|
||||
if self.num_baths == 1 and len(self.value.shape) in [0, 1]:
|
||||
return self
|
||||
|
||||
|
@ -148,10 +186,21 @@ class EnsembleValue:
|
|||
|
||||
@property
|
||||
def num_baths(self) -> int:
|
||||
"""The number of baths.
|
||||
|
||||
This is specific to values of the form ``[for bath 1, for
|
||||
bath, ...]``.
|
||||
"""
|
||||
shape = self.value.shape
|
||||
return self.value.shape[0] if len(shape) > 1 else 1
|
||||
|
||||
def sum_baths(self) -> EnsembleValue:
|
||||
"""Returns a new EnsembleValue where the values and standard
|
||||
deviations are summed over the baths.
|
||||
|
||||
This is specific to values of the form ``[for bath 1, for
|
||||
bath, ...]``.
|
||||
"""
|
||||
final = self.for_bath(0)
|
||||
for i in range(1, self.num_baths):
|
||||
final = final + self.for_bath(i)
|
||||
|
@ -159,6 +208,7 @@ class EnsembleValue:
|
|||
return final
|
||||
|
||||
def insert(self, value: Aggregate):
|
||||
"""Inserts a new aggregate (snapshot) so that the aggregate remains sorted by sample count."""
|
||||
where = len(self._value)
|
||||
for i, (N, _, _) in enumerate(self._value):
|
||||
if N > value[0]:
|
||||
|
@ -168,10 +218,20 @@ class EnsembleValue:
|
|||
self._value.insert(where, value)
|
||||
|
||||
def insert_multi(self, values: list[Aggregate]):
|
||||
"""Inserts multiple aggregates (snapshots) so that the
|
||||
aggregates remain sorted by sample count.
|
||||
|
||||
See :any:`insert` for details.
|
||||
"""
|
||||
for value in values:
|
||||
self.insert(value)
|
||||
|
||||
def consistency(self, other: Union[EnsembleValue, np.ndarray]) -> float:
|
||||
"""
|
||||
Determines weather two EnsembleValues are consistent by
|
||||
checking whether their last values are within the standard
|
||||
deviation of each other.
|
||||
"""
|
||||
diff = abs(
|
||||
self[-1] - (other[-1] if isinstance(other, self.__class__) else other)
|
||||
)
|
||||
|
@ -184,7 +244,10 @@ class EnsembleValue:
|
|||
)
|
||||
|
||||
def integrate(self, τ: np.ndarray) -> EnsembleValue:
|
||||
"""Calculate the antiderivative along a 'time axis' ``τ``."""
|
||||
"""
|
||||
Calculate the integral of the value and standard deviation
|
||||
along a 'time axis' ``τ``.
|
||||
"""
|
||||
|
||||
results = []
|
||||
for N, val, σ in self.aggregate_iterator:
|
||||
|
@ -193,6 +256,10 @@ class EnsembleValue:
|
|||
return EnsembleValue(results)
|
||||
|
||||
def __abs__(self) -> "EnsembleValue":
|
||||
"""
|
||||
Returns a new EnsembleValue where the values are replaced by
|
||||
their absolute value.
|
||||
"""
|
||||
out = []
|
||||
|
||||
for N, value, σ in self._value:
|
||||
|
@ -203,6 +270,10 @@ class EnsembleValue:
|
|||
def __add__(
|
||||
self, other: Union["EnsembleValue", float, int, np.ndarray]
|
||||
) -> EnsembleValue:
|
||||
"""
|
||||
Add two EnsembleValues or an ensemble value and a number,
|
||||
another ensemble value or an array.
|
||||
"""
|
||||
if isinstance(other, EnsembleValue):
|
||||
if len(self) != len(other):
|
||||
logging.warn(
|
||||
|
@ -267,6 +338,7 @@ class EnsembleValue:
|
|||
__radd__ = __add__
|
||||
|
||||
def __mul__(self, other: Union["EnsembleValue", float, int, np.ndarray]):
|
||||
"""Multiply two EnsembleValues or an EnsembleValue and a number or array."""
|
||||
if (
|
||||
isinstance(other, float)
|
||||
or isinstance(other, int)
|
||||
|
@ -311,6 +383,7 @@ class EnsembleValue:
|
|||
__rmul__ = __mul__
|
||||
|
||||
def __truediv__(self, other: Union["EnsembleValue", float, int, np.ndarray]):
|
||||
"""Divide two EnsembleValues or an EnsembleValue and a number or array."""
|
||||
if (
|
||||
isinstance(other, float)
|
||||
or isinstance(other, int)
|
||||
|
@ -360,6 +433,7 @@ class EnsembleValue:
|
|||
def __sub__(
|
||||
self, other: Union["EnsembleValue", float, int, np.ndarray]
|
||||
) -> "EnsembleValue":
|
||||
"""Subtract two EnsembleValues or an EnsembleValue and a number or array."""
|
||||
if (
|
||||
type(self) == type(other)
|
||||
or isinstance(other, float)
|
||||
|
@ -371,6 +445,7 @@ class EnsembleValue:
|
|||
return NotImplemented
|
||||
|
||||
def __rsub__(self, other: Union["EnsembleValue", float, int]) -> "EnsembleValue":
|
||||
"""Subtract two EnsembleValues or an EnsembleValue and a number or array."""
|
||||
if (
|
||||
type(self) == type(other)
|
||||
or isinstance(other, float)
|
||||
|
@ -714,21 +789,22 @@ def integrate_array(
|
|||
|
||||
|
||||
class WelfordAggregator:
|
||||
"""
|
||||
A helper class to calculate means and variances on datasets
|
||||
incremenally (online) in a numerically stable fashion.
|
||||
"""A class to aggregate values using the Welford algorithm.
|
||||
|
||||
The Welford algorithm is an online algorithm to calculate the mean
|
||||
and variance of a series of values.
|
||||
|
||||
The aggregator keeps track of the number of samples the mean and
|
||||
the variance. Aggregation of identical values is prevented by
|
||||
checking the sample index. Tracking can be disabled by setting
|
||||
the initial index to ``None``.
|
||||
|
||||
See also the `Wikipedia article
|
||||
<https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm>`_.
|
||||
|
||||
This implementation features serialization to numpy saves and
|
||||
optionally keeps track of sample indices to avoid accumulating the
|
||||
same value twice.
|
||||
|
||||
:param first_value: the first value to aggregate
|
||||
:param i: The sample index associated with the value. If no sample
|
||||
index is provided, sample tracking is _disabled_ for the
|
||||
instance.
|
||||
:param first_value: The first value to aggregate.
|
||||
:param i: The index of the first value. If ``None`` tracking is
|
||||
disabled.
|
||||
"""
|
||||
|
||||
__slots__ = ["n", "mean", "_m_2", "_tracker"]
|
||||
|
@ -745,10 +821,10 @@ class WelfordAggregator:
|
|||
self._tracker[i] = True
|
||||
|
||||
def dump(self, path: str):
|
||||
"""
|
||||
Serialize the instance into a file under ``path`` using
|
||||
:any:`numpy.save`.
|
||||
"""
|
||||
"""Dumps the aggregator to a file at ``path``.
|
||||
|
||||
See also :any:`from_dump`."""
|
||||
|
||||
save = dict(
|
||||
n=self.n, mean=self.mean, m_2=self._m_2, variance=self.sample_variance
|
||||
)
|
||||
|
@ -762,7 +838,9 @@ class WelfordAggregator:
|
|||
|
||||
@classmethod
|
||||
def from_dump(cls, path: str):
|
||||
"""Resurrect the instance from the dump under ``path``."""
|
||||
"""Loads the aggregator from a file at ``path``.
|
||||
|
||||
See also :any:`dump`."""
|
||||
|
||||
instance = cls(np.empty(1))
|
||||
with portalocker.Lock(path, "rb", flags=portalocker.LockFlags.EXCLUSIVE) as f:
|
||||
|
@ -781,13 +859,12 @@ class WelfordAggregator:
|
|||
return instance
|
||||
|
||||
def update(self, new_value: np.ndarray, i: Optional[int] = None):
|
||||
"""Update mean and variance with a ``new_value`` with the
|
||||
sample index ``i``.
|
||||
"""Updates the aggregator with a new value.
|
||||
|
||||
Attempting to update without a sample index if tracking is
|
||||
enabled results in an error.
|
||||
If ``i`` is given, the aggregator will check if the value was
|
||||
already added. Note that the index has to be supplied if
|
||||
tracking is enabled.
|
||||
"""
|
||||
|
||||
if self._tracker is not None:
|
||||
if i is None:
|
||||
raise ValueError("Tracking is enabled but no index was supplied.")
|
||||
|
@ -810,8 +887,9 @@ class WelfordAggregator:
|
|||
self._m_2 += np.abs(delta) * np.abs(delta2)
|
||||
|
||||
def has_sample(self, i: int) -> bool:
|
||||
"""Whether the sample with index ``i`` has been accumulated so far."""
|
||||
|
||||
"""Returns whether the aggregator has already seen the sample
|
||||
with index ``i``.
|
||||
"""
|
||||
if self._tracker is None:
|
||||
return False # don't know
|
||||
|
||||
|
@ -831,8 +909,7 @@ class WelfordAggregator:
|
|||
|
||||
@property
|
||||
def ensemble_variance(self) -> np.ndarray:
|
||||
"""The sample variance."""
|
||||
|
||||
"""The ensemble variance."""
|
||||
return self.sample_variance / self.n
|
||||
|
||||
@property
|
||||
|
@ -842,7 +919,7 @@ class WelfordAggregator:
|
|||
|
||||
@property
|
||||
def ensemble_value(self) -> EnsembleValue:
|
||||
"""Convert the instance to an :any:`EnsembleValue`."""
|
||||
"""Constructs an :any:`EnsembleValue` from the aggregator."""
|
||||
return EnsembleValue([(self.n, self.mean, self.ensemble_std)])
|
||||
|
||||
|
||||
|
@ -910,6 +987,28 @@ def ensemble_mean_online(
|
|||
every: Optional[Union[int, Callable[[int], bool]]] = None,
|
||||
aggregator: Optional[WelfordAggregator] = None,
|
||||
) -> Optional[EnsembleValue]:
|
||||
"""Calculates the ensemble mean of ``function`` applied to
|
||||
``args``.
|
||||
|
||||
The result is aggregated using the Welford algorithm. If ``save``
|
||||
is given, the aggregator will loaded from and be dumped to
|
||||
``save``. Alternatively a WellfordAggregator can be passed in
|
||||
``aggregator``.
|
||||
|
||||
:param args: The arguments to pass to ``function``.
|
||||
:param function: The function to apply to ``args``.
|
||||
:param save: The path to save the aggregator to.
|
||||
:param i: The index of the sample. If ``None`` tracking is
|
||||
:param every: If ``None`` the aggregator will be dumped after
|
||||
every update. If ``int`` the aggregator will be
|
||||
dumped after every ``every`` updates. If a function
|
||||
the aggregator will be dumped after every update
|
||||
where ``every(n)`` returns :any:`True`.
|
||||
:param aggregator: The aggregator to use. If ``None`` a new
|
||||
aggregator will be created.
|
||||
:returns: The aggregator.
|
||||
"""
|
||||
|
||||
if args is None:
|
||||
result = None
|
||||
else:
|
||||
|
@ -972,6 +1071,24 @@ def ensemble_mean(
|
|||
in_flight: Optional[int] = None,
|
||||
gc_sleep: float = 0,
|
||||
) -> EnsembleValue:
|
||||
"""Calculates the ensemble mean of ``function`` applied to
|
||||
``args``. The result is aggregated using the Welford algorithm.
|
||||
|
||||
:param arg_iter: An iterator over the arguments to pass to
|
||||
``function``.
|
||||
:param function: The function to apply to ``args``.
|
||||
:param N: The number of samples to take.
|
||||
:param every: If ``None`` the aggregator will be dumped after
|
||||
every update. If ``int`` the aggregator will be
|
||||
dumped after every ``every`` updates.
|
||||
:param save: The path to save the aggregator to.
|
||||
:param overwrite_cache: Whether to overwrite the cache if it
|
||||
exists.
|
||||
:param chunk_size: The size of the chunks to send to the workers.
|
||||
:param in_flight: The number of chunks to keep in flight.
|
||||
:param gc_sleep: The time to sleep after each chunk to allow the
|
||||
garbage collector to catch up.
|
||||
:returns: The aggregator."""
|
||||
results = []
|
||||
|
||||
path = None
|
||||
|
|
Loading…
Add table
Reference in a new issue