use cloudpickle

This commit is contained in:
Valentin Boettcher 2022-11-30 14:24:24 -05:00
parent 9f816d8a1b
commit 76c4878d04
No known key found for this signature in database
GPG key ID: E034E12B7AF56ACE
3 changed files with 19 additions and 11 deletions

View file

@ -32,6 +32,7 @@ import pickle
from hops.core.hierarchy_data import HIData
from sortedcontainers import SortedList
import portalocker
import cloudpickle
Aggregate = tuple[int, np.ndarray, np.ndarray]
EnsembleReturn = Union[Aggregate, list[Aggregate]]
@ -758,15 +759,9 @@ def _ensemble_remote_function(function, chunk: tuple, index: int):
return res, index
def get_online_data_path(save: str):
return Path("results") / Path(f"online_{save}.pickle")
def load_online_cache(save: str):
path = get_online_data_path(save)
with portalocker.Lock(path, "rb") as agg_file:
aggregate = pickle.load(agg_file)
with portalocker.Lock(save, "rb") as agg_file:
aggregate = cloudpickle.load(agg_file)
return aggregate.ensemble_value
@ -774,7 +769,7 @@ def load_online_cache(save: str):
def ensemble_mean_online(
args: Any, save: str, function: Callable[..., np.ndarray], i: Optional[int] = None
) -> Optional[EnsembleValue]:
path = get_online_data_path(save)
path = Path(save)
if args is None:
result = None
@ -797,7 +792,7 @@ def ensemble_mean_online(
aggregate = WelfordAggregator(result, i)
with portalocker.Lock(path, "wb") as agg_file:
pickle.dump(aggregate, agg_file)
cloudpickle.dump(aggregate, agg_file)
return aggregate.ensemble_value

14
poetry.lock generated
View file

@ -182,6 +182,14 @@ python-versions = "*"
[package.extras]
test = ["click", "pytest", "six"]
[[package]]
name = "cloudpickle"
version = "2.2.0"
description = "Extended pickling support for Python objects"
category = "main"
optional = false
python-versions = ">=3.6"
[[package]]
name = "colorama"
version = "0.4.6"
@ -1167,7 +1175,7 @@ testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools"
[metadata]
lock-version = "1.1"
python-versions = ">=3.9,<3.11"
content-hash = "627b0aae7fea72536c5575cdbec9f518b7dfee6fbae680a08c3357c902af61ea"
content-hash = "13e674887514171239994d6be6ec47437dee15f8667b30a0265b911fda2817f1"
[metadata.files]
aiosignal = [
@ -1224,6 +1232,10 @@ click-spinner = [
{file = "click-spinner-0.1.10.tar.gz", hash = "sha256:87eacf9d7298973a25d7615ef57d4782aebf913a532bba4b28a37e366e975daf"},
{file = "click_spinner-0.1.10-py2.py3-none-any.whl", hash = "sha256:d1ffcff1fdad9882396367f15fb957bcf7f5c64ab91927dee2127e0d2991ee84"},
]
cloudpickle = [
{file = "cloudpickle-2.2.0-py3-none-any.whl", hash = "sha256:7428798d5926d8fcbfd092d18d01a2a03daf8237d8fcdc8095d256b8490796f0"},
{file = "cloudpickle-2.2.0.tar.gz", hash = "sha256:3f4219469c55453cfe4737e564b67c2a149109dabf7f242478948b895f61106f"},
]
colorama = [
{file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},

View file

@ -18,6 +18,7 @@ hops = { git = "git@gitlab.hrz.tu-chemnitz.de:s8896854--tu-dresden.de/hops.git",
opt-einsum = "^3.3.0"
sortedcontainers = "^2.4.0"
portalocker = "^2.6.0"
cloudpickle = "^2.2.0"
[tool.poetry.dev-dependencies]
mypy = "^0.910"