allow consistency calculation with an ndarray

This commit is contained in:
Valentin Boettcher 2022-07-22 17:03:42 +02:00
parent 913dde147d
commit 9995fc2c35

View file

@ -158,9 +158,8 @@ class EnsembleValue:
for value in values:
self.insert(value)
def consistency(self, other: EnsembleValue) -> float:
def consistency(self, other: Union[EnsembleValue, np.ndarray]) -> float:
diff = abs(self[-1] - other[-1])
return (diff.value < diff.σ).sum() / len(diff.for_bath(0).value) * 100
def integrate(self, τ: np.ndarray) -> EnsembleValue:
@ -180,7 +179,9 @@ class EnsembleValue:
return EnsembleValue(out)
def __add__(self, other: Any) -> EnsembleValue:
def __add__(
self, other: Union["EnsembleValue", float, int, np.ndarray]
) -> EnsembleValue:
if type(self) == type(other):
if len(self) != len(other):
raise RuntimeError("Can only add values of equal length.")
@ -224,7 +225,7 @@ class EnsembleValue:
new.insert_multi(other)
return new
if isinstance(other, numbers.Number):
if isinstance(other, numbers.Number) or isinstance(other, np.ndarray):
out = []
for N, value, σ in self.aggregate_iterator:
@ -236,8 +237,12 @@ class EnsembleValue:
__radd__ = __add__
def __mul__(self, other):
if isinstance(other, float) or isinstance(other, int):
def __mul__(self, other: Union["EnsembleValue", float, int, np.ndarray]):
if (
isinstance(other, float)
or isinstance(other, int)
or isinstance(other, np.ndarray)
):
return EnsembleValue(
[(N, val * other, np.abs(σ * other)) for N, val, σ in self._value]
)
@ -272,11 +277,14 @@ class EnsembleValue:
__rmul__ = __mul__
def __sub__(self, other: Union["EnsembleValue", float, int]) -> "EnsembleValue":
def __sub__(
self, other: Union["EnsembleValue", float, int, np.ndarray]
) -> "EnsembleValue":
if (
type(self) == type(other)
or isinstance(other, float)
or isinstance(other, int)
or isinstance(other, np.ndarray)
):
return self + (-1 * other)