Add some docs to the util module.

This commit is contained in:
Valentin Boettcher 2024-01-27 19:48:08 -05:00
parent 21ecb60ecf
commit 5d5b561f0f
No known key found for this signature in database
GPG key ID: E034E12B7AF56ACE

View file

@ -36,15 +36,23 @@ from numpy.typing import NDArray
Aggregate = tuple[int, np.ndarray, np.ndarray]
EnsembleReturn = Union[Aggregate, list[Aggregate]]
class EnsembleValue:
"""A data container to hold data that results from monte-carlo
simulations.
"""A container to hold the values of an ensemble and perform
arithmetic operations on and between them.
It supports saving multiple snapshots at different sample sizes,
recording the variances accordingly.
The ensembles are stored as a list of aggregates (snapshots).
Each aggregate is of the shape (sample number, value, standard
deviation). The values and standard deviations are stored as
numpy arrays. The aggregates are sorted by sample number.
Standard arithmetic operations, as well as integration is supported.
Addition, subtraction, multiplication, division and integration
are defined. The latter is performed using spline interpolation.
:param value: The value of the ensemble. Can be a single
aggregate, or a list of aggregates, or a tuple of two numpy
arrays. In the latter case the first array is the value and
the second the standard deviation, where the sample size is
set to 0.
"""
def __init__(
@ -64,46 +72,61 @@ class EnsembleValue:
else [value]
)
self._value.sort(key=lambda x: x[0])
@property
def final_aggregate(self):
    """The last entry of the internal list of aggregates (snapshots)."""
    snapshots = self._value
    return snapshots[-1]
@property
def N(self):
    """The sample count, i.e. the first entry of the last aggregate."""
    last = self.final_aggregate
    return last[0]
@property
def value(self):
    """The values of the last aggregate (snapshot)."""
    last = self.final_aggregate
    return last[1]
@property
def σ(self):
    """The standard deviation stored in the last aggregate (snapshot)."""
    last = self.final_aggregate
    return last[2]
@property
def Ns(self):
    """The sample counts of all aggregates (snapshots), in stored order."""
    counts = []
    for count, _val, _sigma in self._value:
        counts.append(count)
    return counts
@property
def values(self):
    """The values of all aggregates (snapshots), in stored order."""
    collected = []
    for _count, val, _sigma in self._value:
        collected.append(val)
    return collected
@property
def σs(self):
    """The standard deviations of all aggregates (snapshots), in stored order."""
    deviations = []
    for _count, _val, sigma in self._value:
        deviations.append(sigma)
    return deviations
@property
def aggregate_iterator(self):
    """Generator over the stored aggregates (snapshots), in stored order."""
    yield from self._value
@property
def ensemble_value_iterator(self):
    """Generator that yields each aggregate (snapshot) wrapped in its own
    EnsembleValue.
    """
    yield from (EnsembleValue(snapshot) for snapshot in self._value)
@property
def mean(self):
def mean(self) -> EnsembleValue:
"""
Returns the mean of the ensemble as a new EnsembleValue. The
standard deviation is correctly propagated.
"""
values = []
for N, val, σ in self.aggregate_iterator:
@ -114,23 +137,30 @@ class EnsembleValue:
return EnsembleValue(values)
@property
def max(self):
def max(self) -> EnsembleValue:
"""Returns the maximum value of the EnsembleValue as a new EnsembleValue."""
N, val, σ = self.final_aggregate
max_index = np.argmax(val)
return EnsembleValue([(N, val[max_index].copy(), σ[max_index].copy())])
@property
def min(self):
def min(self) -> EnsembleValue:
"""Returns the minimum value of the EnsembleValue as a new EnsembleValue."""
N, val, σ = self.final_aggregate
min_index = np.argmin(val)
return EnsembleValue([(N, val[min_index].copy(), σ[min_index].copy())])
def __getitem__(self, index):
    """Returns a new EnsembleValue that wraps the entry of the aggregate
    (snapshot) list at ``index``.
    """
    selected = self._value[index]
    return EnsembleValue([selected])
def slice(self, slc: Union[np.ndarray, slice]):
def slice(self, slc: Union[np.ndarray, slice]) -> EnsembleValue:
"""
Returns a new EnsembleValue with the values and standard
deviations in the aggregates (snapshots) sliced by ``slc``.
"""
results = []
for N, val, σ in self.aggregate_iterator:
results.append((N, val[slc], σ[slc]))
@ -138,9 +168,17 @@ class EnsembleValue:
return EnsembleValue(results)
def __len__(self) -> int:
    """The number of stored aggregates (snapshots)."""
    snapshots = self._value
    return len(snapshots)
def for_bath(self, bath: int):
def for_bath(self, bath: int) -> EnsembleValue:
"""
Returns a new EnsembleValue with the values and standard
deviations for the bath ``bath``.
This is specific to values of the form ``[for bath 1, for bath
2, ...]``.
"""
if self.num_baths == 1 and len(self.value.shape) in [0, 1]:
return self
@ -148,10 +186,21 @@ class EnsembleValue:
@property
def num_baths(self) -> int:
    """The number of baths.

    This is the length of the first axis of the last aggregate's value
    when that value has more than one dimension, and ``1`` otherwise.

    This is specific to values of the form ``[for bath 1, for
    bath 2, ...]``.
    """
    shape = self.value.shape
    if len(shape) > 1:
        return shape[0]
    return 1
def sum_baths(self) -> EnsembleValue:
"""Returns a new EnsembleValue where the values and standard
deviations are summed over the baths.
This is specific to values of the form ``[for bath 1, for
bath 2, ...]``.
"""
final = self.for_bath(0)
for i in range(1, self.num_baths):
final = final + self.for_bath(i)
@ -159,6 +208,7 @@ class EnsembleValue:
return final
def insert(self, value: Aggregate):
"""Inserts a new aggregate (snapshot) so that the aggregate remains sorted by sample count."""
where = len(self._value)
for i, (N, _, _) in enumerate(self._value):
if N > value[0]:
@ -168,10 +218,20 @@ class EnsembleValue:
self._value.insert(where, value)
def insert_multi(self, values: list[Aggregate]):
    """Inserts several aggregates (snapshots), one after another.

    Each element of ``values`` is handed to :any:`insert` in the
    order given; see there for how the insertion position is chosen.
    """
    for aggregate in values:
        self.insert(aggregate)
def consistency(self, other: Union[EnsembleValue, np.ndarray]) -> float:
"""
Determines whether two EnsembleValues are consistent by
checking whether their last values are within the standard
deviation of each other.
"""
diff = abs(
self[-1] - (other[-1] if isinstance(other, self.__class__) else other)
)
@ -184,7 +244,10 @@ class EnsembleValue:
)
def integrate(self, τ: np.ndarray) -> EnsembleValue:
"""Calculate the antiderivative along a 'time axis' ``τ``."""
"""
Calculate the integral of the value and standard deviation
along a 'time axis' ``τ``.
"""
results = []
for N, val, σ in self.aggregate_iterator:
@ -193,6 +256,10 @@ class EnsembleValue:
return EnsembleValue(results)
def __abs__(self) -> "EnsembleValue":
"""
Returns a new EnsembleValue where the values are replaced by
their absolute value.
"""
out = []
for N, value, σ in self._value:
@ -203,6 +270,10 @@ class EnsembleValue:
def __add__(
self, other: Union["EnsembleValue", float, int, np.ndarray]
) -> EnsembleValue:
"""
Add two EnsembleValues, or an EnsembleValue and a number or
array.
"""
if isinstance(other, EnsembleValue):
if len(self) != len(other):
logging.warn(
@ -267,6 +338,7 @@ class EnsembleValue:
__radd__ = __add__
def __mul__(self, other: Union["EnsembleValue", float, int, np.ndarray]):
"""Multiply two EnsembleValues or an EnsembleValue and a number or array."""
if (
isinstance(other, float)
or isinstance(other, int)
@ -311,6 +383,7 @@ class EnsembleValue:
__rmul__ = __mul__
def __truediv__(self, other: Union["EnsembleValue", float, int, np.ndarray]):
"""Divide two EnsembleValues or an EnsembleValue and a number or array."""
if (
isinstance(other, float)
or isinstance(other, int)
@ -360,6 +433,7 @@ class EnsembleValue:
def __sub__(
self, other: Union["EnsembleValue", float, int, np.ndarray]
) -> "EnsembleValue":
"""Subtract two EnsembleValues or an EnsembleValue and a number or array."""
if (
type(self) == type(other)
or isinstance(other, float)
@ -371,6 +445,7 @@ class EnsembleValue:
return NotImplemented
def __rsub__(self, other: Union["EnsembleValue", float, int]) -> "EnsembleValue":
"""Reflected subtraction: subtract this EnsembleValue from another EnsembleValue or a number."""
if (
type(self) == type(other)
or isinstance(other, float)
@ -714,21 +789,22 @@ def integrate_array(
class WelfordAggregator:
"""
A helper class to calculate means and variances on datasets
incrementally (online) in a numerically stable fashion.
"""A class to aggregate values using the Welford algorithm.
The Welford algorithm is an online algorithm to calculate the mean
and variance of a series of values.
The aggregator keeps track of the number of samples, the mean and
the variance. Aggregation of identical values is prevented by
checking the sample index. Tracking can be disabled by setting
the initial index to ``None``.
See also the `Wikipedia article
<https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm>`_.
This implementation features serialization to numpy saves and
optionally keeps track of sample indices to avoid accumulating the
same value twice.
:param first_value: the first value to aggregate
:param i: The sample index associated with the value. If no sample
index is provided, sample tracking is _disabled_ for the
instance.
:param first_value: The first value to aggregate.
:param i: The index of the first value. If ``None`` tracking is
disabled.
"""
__slots__ = ["n", "mean", "_m_2", "_tracker"]
@ -745,10 +821,10 @@ class WelfordAggregator:
self._tracker[i] = True
def dump(self, path: str):
"""
Serialize the instance into a file under ``path`` using
:any:`numpy.save`.
"""
"""Dumps the aggregator to a file at ``path``.
See also :any:`from_dump`."""
save = dict(
n=self.n, mean=self.mean, m_2=self._m_2, variance=self.sample_variance
)
@ -762,7 +838,9 @@ class WelfordAggregator:
@classmethod
def from_dump(cls, path: str):
"""Resurrect the instance from the dump under ``path``."""
"""Loads the aggregator from a file at ``path``.
See also :any:`dump`."""
instance = cls(np.empty(1))
with portalocker.Lock(path, "rb", flags=portalocker.LockFlags.EXCLUSIVE) as f:
@ -781,13 +859,12 @@ class WelfordAggregator:
return instance
def update(self, new_value: np.ndarray, i: Optional[int] = None):
"""Update mean and variance with a ``new_value`` with the
sample index ``i``.
"""Updates the aggregator with a new value.
Attempting to update without a sample index if tracking is
enabled results in an error.
If ``i`` is given, the aggregator will check if the value was
already added. Note that the index has to be supplied if
tracking is enabled.
"""
if self._tracker is not None:
if i is None:
raise ValueError("Tracking is enabled but no index was supplied.")
@ -810,8 +887,9 @@ class WelfordAggregator:
self._m_2 += np.abs(delta) * np.abs(delta2)
def has_sample(self, i: int) -> bool:
"""Whether the sample with index ``i`` has been accumulated so far."""
"""Returns whether the aggregator has already seen the sample
with index ``i``.
"""
if self._tracker is None:
return False # don't know
@ -831,8 +909,7 @@ class WelfordAggregator:
@property
def ensemble_variance(self) -> np.ndarray:
"""The sample variance."""
"""The ensemble variance."""
return self.sample_variance / self.n
@property
@ -842,7 +919,7 @@ class WelfordAggregator:
@property
def ensemble_value(self) -> EnsembleValue:
"""Convert the instance to an :any:`EnsembleValue`."""
"""Constructs an :any:`EnsembleValue` from the aggregator."""
return EnsembleValue([(self.n, self.mean, self.ensemble_std)])
@ -910,6 +987,28 @@ def ensemble_mean_online(
every: Optional[Union[int, Callable[[int], bool]]] = None,
aggregator: Optional[WelfordAggregator] = None,
) -> Optional[EnsembleValue]:
"""Calculates the ensemble mean of ``function`` applied to
``args``.
The result is aggregated using the Welford algorithm. If ``save``
is given, the aggregator will be loaded from and dumped to
``save``. Alternatively a WelfordAggregator can be passed in
``aggregator``.
:param args: The arguments to pass to ``function``.
:param function: The function to apply to ``args``.
:param save: The path to save the aggregator to.
:param i: The index of the sample. If ``None`` tracking is disabled.
:param every: If ``None`` the aggregator will be dumped after
every update. If ``int`` the aggregator will be
dumped after every ``every`` updates. If a function
the aggregator will be dumped after every update
where ``every(n)`` returns :any:`True`.
:param aggregator: The aggregator to use. If ``None`` a new
aggregator will be created.
:returns: An :any:`EnsembleValue` or ``None``, as declared by the return annotation.
"""
if args is None:
result = None
else:
@ -972,6 +1071,24 @@ def ensemble_mean(
in_flight: Optional[int] = None,
gc_sleep: float = 0,
) -> EnsembleValue:
"""Calculates the ensemble mean of ``function`` applied to
``args``. The result is aggregated using the Welford algorithm.
:param arg_iter: An iterator over the arguments to pass to
``function``.
:param function: The function to apply to ``args``.
:param N: The number of samples to take.
:param every: If ``None`` the aggregator will be dumped after
every update. If ``int`` the aggregator will be
dumped after every ``every`` updates.
:param save: The path to save the aggregator to.
:param overwrite_cache: Whether to overwrite the cache if it
exists.
:param chunk_size: The size of the chunks to send to the workers.
:param in_flight: The number of chunks to keep in flight.
:param gc_sleep: The time to sleep after each chunk to allow the
garbage collector to catch up.
:returns: The ensemble mean as an :any:`EnsembleValue`."""
results = []
path = None