implement snapshotting for online analysis

This commit is contained in:
Valentin Boettcher 2022-12-09 16:15:20 -05:00
parent 8a587ab02b
commit 098c83ab44
No known key found for this signature in database
GPG key ID: E034E12B7AF56ACE
3 changed files with 41 additions and 13 deletions

View file

@ -23,6 +23,8 @@ import numpy as np
from multiprocessing import Process
import hops.core.signal_delay as signal_delay
import signal
from typing import Union
import hopsflow.util
@contextmanager
@ -55,7 +57,7 @@ def model_db(data_path: str = "./.data"):
f.truncate(0)
f.seek(0)
f.write(JSONEncoder.dumps(db))
f.write(JSONEncoder.dumps(db, indent=4))
def model_hook(dct: dict[str, Any]):
@ -261,6 +263,8 @@ def is_smaller(first: Path, second: Path) -> bool:
def import_results(
data_path: str = "./.data",
other_data_path: str = "./.data_other",
results_path: Union[Path, str] = "./results",
other_results_path: Union[Path, str] = "./results_other",
interactive: bool = False,
models_to_import: Optional[Iterable[Model]] = None,
):
@ -279,6 +283,9 @@ def import_results(
[model.hexhash for model in models_to_import] if models_to_import else []
)
results_path = Path(results_path)
other_results_path = Path(other_results_path)
with model_db(other_data_path) as other_db:
for current_hash, data in other_db.items():
with model_db(data_path) as db:
@ -291,21 +298,23 @@ def import_results(
logging.info(f"Skipping {current_hash}.")
continue
this_path = Path(data_path) / data["data_path"]
this_path_tmp = this_path.with_suffix(".part")
other_path = Path(other_data_path) / data["data_path"]
if current_hash not in db:
do_import = True
elif "data_path" not in db[current_hash]:
do_import = True
elif is_smaller(
Path(data_path) / db[current_hash]["data_path"],
Path(other_data_path) / data["data_path"],
this_path,
other_path,
):
do_import = True
if do_import:
this_path = Path(data_path) / data["data_path"]
this_path_tmp = this_path.with_suffix(".part")
other_path = Path(other_data_path) / data["data_path"]
logging.info(f"Not importing {current_hash}.")
if do_import:
config = data["model_config"]
logging.warning(f"Importing {other_path} to {this_path}.")
logging.warning(f"The model description is '{config.description}'.")
@ -317,11 +326,30 @@ def import_results(
continue
this_path.parents[0].mkdir(exist_ok=True, parents=True)
if is_smaller(this_path, other_path):
shutil.copy2(other_path, this_path_tmp)
os.system("sync")
shutil.move(this_path_tmp, this_path)
if "analysis_files" in data:
for fname in data["analysis_files"].values():
other_path = other_results_path / fname
for (
other_sub_path
) in hopsflow.util.get_all_snaphot_paths(other_path):
this_path = results_path / other_sub_path.name
this_path_tmp = this_path.with_suffix(".tmp")
logging.warning(
f"Importing {other_path} to {this_path}."
)
if other_sub_path.exists():
shutil.copy2(other_sub_path, this_path_tmp)
os.system("sync")
shutil.move(this_path_tmp, this_path)
db[current_hash] = data

View file

@ -296,7 +296,7 @@ class Model(ABC):
if not os.path.exists(file_path):
raise RuntimeError(f"No data found under '{file_path}'.")
return hopsflow.util.WelfordAggregator.from_dump(file_path).ensemble_value
return hopsflow.util.get_online_values_from_cache(file_path)
def system_energy(
self, data: Optional[HIData] = None, results_path: str = "results", **kwargs
@ -589,7 +589,7 @@ class Model(ABC):
return self.interaction_power(data, **kwargs).integrate(self.t)
def bath_energy(self, data: Optional[HIData], **kwargs) -> EnsembleValue:
def bath_energy(self, data: Optional[HIData] = None, **kwargs) -> EnsembleValue:
"""Calculates bath energy by integrating the bath energy flow
calculated from the ``data`` or, if not supplied, tries to load
the online results from ``results_path``.

6
poetry.lock generated
View file

@ -87,7 +87,7 @@ test-tox-coverage = ["coverage (>=5.5)"]
type = "git"
url = "https://github.com/beartype/beartype"
reference = "main"
resolved_reference = "f536570a1b8dc1d8f5cb3c07e93ce7915eabb899"
resolved_reference = "b48a56fb497b36aa9640c107caa6bc85e02f5782"
[[package]]
name = "binfootprint"
@ -362,7 +362,7 @@ plotting = ["matplotlib (>=3.5.0,<4.0.0)"]
type = "git"
url = "git@gitlab.hrz.tu-chemnitz.de:s8896854--tu-dresden.de/hops.git"
reference = "main"
resolved_reference = "ef9c3a500f9b2aa954a7f5b228c81d5363630b0d"
resolved_reference = "573274ec04be0a65f0e035885ce387247887b1e8"
[[package]]
name = "hopsflow"
@ -388,7 +388,7 @@ tqdm = "^4.62.3"
type = "git"
url = "https://github.com/vale981/hopsflow"
reference = "main"
resolved_reference = "9c3fc669f6a103e70af7b04e7f5f057448be66d2"
resolved_reference = "e27e38b656b18e0bf3066d72e76ed1d467f26a5c"
[[package]]
name = "humanfriendly"