mirror of
https://github.com/vale981/hopsflow
synced 2025-03-04 16:31:38 -05:00
add sample tracking to ensemblevalue
This commit is contained in:
parent
f8b554fcfe
commit
d75ca67126
3 changed files with 46 additions and 6 deletions
|
@ -30,6 +30,7 @@ import math
|
|||
import time
|
||||
import pickle
|
||||
from hops.core.hierarchy_data import HIData
|
||||
from sortedcontainers import SortedList
|
||||
|
||||
Aggregate = tuple[int, np.ndarray, np.ndarray]
|
||||
EnsembleReturn = Union[Aggregate, list[Aggregate]]
|
||||
|
@ -37,7 +38,9 @@ EnsembleReturn = Union[Aggregate, list[Aggregate]]
|
|||
|
||||
class EnsembleValue:
|
||||
def __init__(
|
||||
self, value: Union[Aggregate, list[Aggregate], tuple[np.ndarray, np.ndarray]]
|
||||
self,
|
||||
value: Union[Aggregate, list[Aggregate], tuple[np.ndarray, np.ndarray]],
|
||||
track=False,
|
||||
):
|
||||
if (
|
||||
isinstance(value, tuple)
|
||||
|
@ -53,6 +56,17 @@ class EnsembleValue:
|
|||
else [value]
|
||||
)
|
||||
|
||||
self._tracker: Optional[SortedList] = None
|
||||
|
||||
if track:
|
||||
self._tracker = SortedList()
|
||||
|
||||
def has_sample(self, i: int) -> bool:
|
||||
if self._tracker is None:
|
||||
return False # don't know
|
||||
|
||||
return i in self._tracker
|
||||
|
||||
@property
|
||||
def final_aggregate(self):
|
||||
return self._value[-1]
|
||||
|
@ -144,7 +158,16 @@ class EnsembleValue:
|
|||
|
||||
return final
|
||||
|
||||
def insert(self, value: Aggregate):
|
||||
def insert(self, value: Aggregate, i: Optional[int] = None):
|
||||
if self._tracker is not None:
|
||||
if i is None:
|
||||
raise ValueError("Tracking is enabled but no index was supplied.")
|
||||
|
||||
if self.has_sample(i):
|
||||
return
|
||||
|
||||
self._tracker.add(i)
|
||||
|
||||
where = len(self._value)
|
||||
for i, (N, _, _) in enumerate(self._value):
|
||||
if N > value[0]:
|
||||
|
@ -153,9 +176,13 @@ class EnsembleValue:
|
|||
|
||||
self._value.insert(where, value)
|
||||
|
||||
def insert_multi(self, values: list[Aggregate]):
|
||||
for value in values:
|
||||
self.insert(value)
|
||||
def insert_multi(self, values: list[Aggregate], i_list: Optional[list[int]] = None):
|
||||
if self._tracker is not None:
|
||||
if i_list is None:
|
||||
raise ValueError("Tracking is enabled but no indices were supplied.")
|
||||
|
||||
for value, i in zip(values, i_list if i_list else itertools.repeat(None)):
|
||||
self.insert(value, i)
|
||||
|
||||
def consistency(self, other: Union[EnsembleValue, np.ndarray]) -> float:
|
||||
diff = abs(
|
||||
|
|
14
poetry.lock
generated
14
poetry.lock
generated
|
@ -846,6 +846,14 @@ category = "dev"
|
|||
optional = false
|
||||
python-versions = "*"
|
||||
|
||||
[[package]]
|
||||
name = "sortedcontainers"
|
||||
version = "2.4.0"
|
||||
description = "Sorted Containers -- Sorted List, Sorted Dict, Sorted Set"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
|
||||
[[package]]
|
||||
name = "soupsieve"
|
||||
version = "2.3.2.post1"
|
||||
|
@ -1135,7 +1143,7 @@ testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools"
|
|||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = ">=3.9,<3.11"
|
||||
content-hash = "7ad847d5cf00cb2fe16647abf6431d8cc46dfde7bee71e01fc288101065fe897"
|
||||
content-hash = "53d4c1d9dfae83142adfad9d1bde9f04ad2443ba9a23075ab355559532ab94af"
|
||||
|
||||
[metadata.files]
|
||||
aiosignal = [
|
||||
|
@ -1774,6 +1782,10 @@ snowballstemmer = [
|
|||
{file = "snowballstemmer-2.2.0-py2.py3-none-any.whl", hash = "sha256:c8e1716e83cc398ae16824e5572ae04e0d9fc2c6b985fb0f900f5f0c96ecba1a"},
|
||||
{file = "snowballstemmer-2.2.0.tar.gz", hash = "sha256:09b16deb8547d3412ad7b590689584cd0fe25ec8db3be37788be3810cbf19cb1"},
|
||||
]
|
||||
sortedcontainers = [
|
||||
{file = "sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0"},
|
||||
{file = "sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88"},
|
||||
]
|
||||
soupsieve = [
|
||||
{file = "soupsieve-2.3.2.post1-py3-none-any.whl", hash = "sha256:3b2503d3c7084a42b1ebd08116e5f81aadfaea95863628c80a3b774a11b7c759"},
|
||||
{file = "soupsieve-2.3.2.post1.tar.gz", hash = "sha256:fc53893b3da2c33de295667a0e19f078c14bf86544af307354de5fcf12a3f30d"},
|
||||
|
|
|
@ -16,6 +16,7 @@ lmfit = "=1.0.2"
|
|||
ray = "^1.11.0"
|
||||
hops = { git = "git@gitlab.hrz.tu-chemnitz.de:s8896854--tu-dresden.de/hops.git", branch="main" }
|
||||
opt-einsum = "^3.3.0"
|
||||
sortedcontainers = "^2.4.0"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
mypy = "^0.910"
|
||||
|
|
Loading…
Add table
Reference in a new issue