diff --git a/hopsflow/util.py b/hopsflow/util.py index 9581751..28b1edd 100644 --- a/hopsflow/util.py +++ b/hopsflow/util.py @@ -158,9 +158,8 @@ class EnsembleValue: for value in values: self.insert(value) - def consistency(self, other: EnsembleValue) -> float: + def consistency(self, other: Union[EnsembleValue, np.ndarray]) -> float: diff = abs(self[-1] - other[-1]) - return (diff.value < diff.σ).sum() / len(diff.for_bath(0).value) * 100 def integrate(self, τ: np.ndarray) -> EnsembleValue: @@ -180,7 +179,9 @@ class EnsembleValue: return EnsembleValue(out) - def __add__(self, other: Any) -> EnsembleValue: + def __add__( + self, other: Union["EnsembleValue", float, int, np.ndarray] + ) -> EnsembleValue: if type(self) == type(other): if len(self) != len(other): raise RuntimeError("Can only add values of equal length.") @@ -224,7 +225,7 @@ class EnsembleValue: new.insert_multi(other) return new - if isinstance(other, numbers.Number): + if isinstance(other, numbers.Number) or isinstance(other, np.ndarray): out = [] for N, value, σ in self.aggregate_iterator: @@ -236,8 +237,12 @@ class EnsembleValue: __radd__ = __add__ - def __mul__(self, other): - if isinstance(other, float) or isinstance(other, int): + def __mul__(self, other: Union["EnsembleValue", float, int, np.ndarray]): + if ( + isinstance(other, float) + or isinstance(other, int) + or isinstance(other, np.ndarray) + ): return EnsembleValue( [(N, val * other, np.abs(σ * other)) for N, val, σ in self._value] ) @@ -272,11 +277,14 @@ class EnsembleValue: __rmul__ = __mul__ - def __sub__(self, other: Union["EnsembleValue", float, int]) -> "EnsembleValue": + def __sub__( + self, other: Union["EnsembleValue", float, int, np.ndarray] + ) -> "EnsembleValue": if ( type(self) == type(other) or isinstance(other, float) or isinstance(other, int) + or isinstance(other, np.ndarray) ): return self + (-1 * other)