implement dumpy to np

This commit is contained in:
Valentin Boettcher 2022-11-30 18:14:42 -05:00
parent 76c4878d04
commit 5afd2b6e08
No known key found for this signature in database
GPG key ID: E034E12B7AF56ACE
3 changed files with 48 additions and 46 deletions

View file

@ -32,7 +32,8 @@ import pickle
from hops.core.hierarchy_data import HIData
from sortedcontainers import SortedList
import portalocker
import cloudpickle
import os
from numpy.typing import NDArray
Aggregate = tuple[int, np.ndarray, np.ndarray]
EnsembleReturn = Union[Aggregate, list[Aggregate]]
@ -651,15 +652,48 @@ def integrate_array(
class WelfordAggregator:
__slots__ = ["n", "mean", "_m_2", "_tracker"]
_chunk_size = 1
def __init__(self, first_value: np.ndarray, i: Optional[int] = None):
self.n = 1
self.mean = first_value
self._m_2 = np.zeros_like(first_value)
self._tracker: Optional[SortedList] = None
self._tracker: Optional[NDArray] = None
if i is not None:
self._tracker = SortedList([i])
self._tracker = np.zeros(i + 100, dtype=bool)
self._tracker[i] = True
def dump(self, path: str):
save = dict(
n=self.n, mean=self.mean, m_2=self._m_2, variance=self.sample_variance
)
if self._tracker is not None:
save["tracker"] = self._tracker
with open(path, "wb") as f:
portalocker.lock(f, portalocker.LockFlags.EXCLUSIVE)
portalocker.lock(f, portalocker.LockFlags.EXCLUSIVE)
np.savez(f, **save)
portalocker.unlock(f)
@classmethod
def from_dump(cls, path: str):
instance = cls(np.empty(1))
with portalocker.Lock(path, "rb", flags=portalocker.LockFlags.EXCLUSIVE) as f:
dump_file = np.load(f)
instance.n = dump_file["n"]
instance.mean = dump_file["mean"]
instance._m_2 = dump_file["m_2"]
if "tracker" in dump_file:
instance._tracker = dump_file["tracker"]
else:
instance._tracker = None
return instance
def update(self, new_value: np.ndarray, i: Optional[int] = None):
if self._tracker is not None:
@ -669,7 +703,12 @@ class WelfordAggregator:
if self.has_sample(i):
return
self._tracker.add(i)
if self._tracker.size <= i:
self._tracker = np.pad(
self._tracker, (0, self._chunk_size), constant_values=False
)
self._tracker[i] = True
self.n += 1
delta = new_value - self.mean
@ -681,7 +720,7 @@ class WelfordAggregator:
if self._tracker is None:
return False # don't know
return i in self._tracker
return self._tracker.size > i and self._tracker[i]
@property
def sample_variance(self) -> np.ndarray:
@ -759,13 +798,6 @@ def _ensemble_remote_function(function, chunk: tuple, index: int):
return res, index
def load_online_cache(save: str):
with portalocker.Lock(save, "rb") as agg_file:
aggregate = cloudpickle.load(agg_file)
return aggregate.ensemble_value
def ensemble_mean_online(
args: Any, save: str, function: Callable[..., np.ndarray], i: Optional[int] = None
) -> Optional[EnsembleValue]:
@ -780,9 +812,7 @@ def ensemble_mean_online(
result = None
if path.exists():
with portalocker.Lock(path, "rb") as agg_file:
aggregate: WelfordAggregator = pickle.load(agg_file)
if result is not None:
aggregate = WelfordAggregator.from_dump(str(path))
aggregate.update(result, i)
else:
@ -791,9 +821,7 @@ def ensemble_mean_online(
aggregate = WelfordAggregator(result, i)
with portalocker.Lock(path, "wb") as agg_file:
cloudpickle.dump(aggregate, agg_file)
aggregate.dump(str(path))
return aggregate.ensemble_value

26
poetry.lock generated
View file

@ -182,14 +182,6 @@ python-versions = "*"
[package.extras]
test = ["click", "pytest", "six"]
[[package]]
name = "cloudpickle"
version = "2.2.0"
description = "Extended pickling support for Python objects"
category = "main"
optional = false
python-versions = ">=3.6"
[[package]]
name = "colorama"
version = "0.4.6"
@ -878,14 +870,6 @@ category = "dev"
optional = false
python-versions = "*"
[[package]]
name = "sortedcontainers"
version = "2.4.0"
description = "Sorted Containers -- Sorted List, Sorted Dict, Sorted Set"
category = "main"
optional = false
python-versions = "*"
[[package]]
name = "soupsieve"
version = "2.3.2.post1"
@ -1175,7 +1159,7 @@ testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools"
[metadata]
lock-version = "1.1"
python-versions = ">=3.9,<3.11"
content-hash = "13e674887514171239994d6be6ec47437dee15f8667b30a0265b911fda2817f1"
content-hash = "62e10216bcfba15c0cc64bf97c0d4dbd3112c0c454b8b89b5597cf4a547e7e23"
[metadata.files]
aiosignal = [
@ -1232,10 +1216,6 @@ click-spinner = [
{file = "click-spinner-0.1.10.tar.gz", hash = "sha256:87eacf9d7298973a25d7615ef57d4782aebf913a532bba4b28a37e366e975daf"},
{file = "click_spinner-0.1.10-py2.py3-none-any.whl", hash = "sha256:d1ffcff1fdad9882396367f15fb957bcf7f5c64ab91927dee2127e0d2991ee84"},
]
cloudpickle = [
{file = "cloudpickle-2.2.0-py3-none-any.whl", hash = "sha256:7428798d5926d8fcbfd092d18d01a2a03daf8237d8fcdc8095d256b8490796f0"},
{file = "cloudpickle-2.2.0.tar.gz", hash = "sha256:3f4219469c55453cfe4737e564b67c2a149109dabf7f242478948b895f61106f"},
]
colorama = [
{file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
@ -1838,10 +1818,6 @@ snowballstemmer = [
{file = "snowballstemmer-2.2.0-py2.py3-none-any.whl", hash = "sha256:c8e1716e83cc398ae16824e5572ae04e0d9fc2c6b985fb0f900f5f0c96ecba1a"},
{file = "snowballstemmer-2.2.0.tar.gz", hash = "sha256:09b16deb8547d3412ad7b590689584cd0fe25ec8db3be37788be3810cbf19cb1"},
]
sortedcontainers = [
{file = "sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0"},
{file = "sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88"},
]
soupsieve = [
{file = "soupsieve-2.3.2.post1-py3-none-any.whl", hash = "sha256:3b2503d3c7084a42b1ebd08116e5f81aadfaea95863628c80a3b774a11b7c759"},
{file = "soupsieve-2.3.2.post1.tar.gz", hash = "sha256:fc53893b3da2c33de295667a0e19f078c14bf86544af307354de5fcf12a3f30d"},

View file

@ -16,9 +16,7 @@ lmfit = "=1.0.2"
ray = "^1.11.0"
hops = { git = "git@gitlab.hrz.tu-chemnitz.de:s8896854--tu-dresden.de/hops.git", branch="main" }
opt-einsum = "^3.3.0"
sortedcontainers = "^2.4.0"
portalocker = "^2.6.0"
cloudpickle = "^2.2.0"
[tool.poetry.dev-dependencies]
mypy = "^0.910"