From ebbdfe36b326bd4d9f20719b0b746fd19a68378d Mon Sep 17 00:00:00 2001
From: Valentin Boettcher
Date: Thu, 7 Apr 2022 14:30:30 +0200
Subject: [PATCH] make import more convenient

---
 hiro_models/model_auxiliary.py | 113 +++++++++++++++++++++++++--------
 1 file changed, 86 insertions(+), 27 deletions(-)

diff --git a/hiro_models/model_auxiliary.py b/hiro_models/model_auxiliary.py
index 40fc781..e7e10f7 100644
--- a/hiro_models/model_auxiliary.py
+++ b/hiro_models/model_auxiliary.py
@@ -1,6 +1,6 @@
 """Functionality to integrate :any:`Model` instances and analyze the results."""
 
-from typing import Any
+from typing import Any, Optional
 
 from hops.core.hierarchy_data import HIData
 from qutip.steadystate import _default_steadystate_args
@@ -12,7 +12,7 @@ from filelock import FileLock
 from pathlib import Path
 from .one_qubit_model import QubitModel
 from .two_qubit_model import TwoQubitModel
-from collections.abc import Sequence, Iterator
+from collections.abc import Sequence, Iterator, Iterable
 import shutil
 import logging
 
@@ -38,7 +38,9 @@ def model_db(data_path: str = "./.data"):
     with FileLock(db_lock):
         with db_path.open("r+") as f:
             data = f.read()
-            db = JSONEncoder.loads(data) if len(data) > 0 else {}
+            db = (
+                JSONEncoder.loads(data, object_hook=model_hook) if len(data) > 0 else {}
+            )
 
             yield db
 
@@ -53,11 +55,16 @@ def model_hook(dct: dict[str, Any]):
     if "__model__" in dct:
         model = dct["__model__"]
 
+        treated_vals = {
+            key: object_hook(val) if isinstance(val, dict) else val
+            for key, val in dct.items()
+        }
+
         if model == "QubitModel":
-            return QubitModel.from_dict(dct)
+            return QubitModel.from_dict(treated_vals)
 
         if model == "TwoQubitModel":
-            return TwoQubitModel.from_dict(dct)
+            return TwoQubitModel.from_dict(treated_vals)
 
     return object_hook(dct)
 
@@ -125,7 +132,7 @@ def get_data(
 
 
 def model_data_iterator(
-    models: Model, *args, **kwargs
+    models: Iterable[Model], *args, **kwargs
 ) -> Iterator[tuple[Model, HIData]]:
     """
     Yields tuples of ``model, data``, where ``data`` is already opened
@@ -139,52 +146,98 @@ def model_data_iterator(
         yield model, data
 
 
-def import_results(data_path: str = "./.data", other_data_path: str = "./.data_other"):
+def is_smaller(first: Path, second: Path) -> bool:
+    """
+    :returns: Whether the file ``first`` is smaller than ``second``.
+    """
+
+    if not first.exists():
+        return True
+
+    return first.stat().st_size < second.stat().st_size
+
+
+def import_results(
+    data_path: str = "./.data",
+    other_data_path: str = "./.data_other",
+    interactive: bool = False,
+    models_to_import: Optional[Iterable[Model]] = None,
+):
     """
     Imports results from the ``other_data_path`` into the
-    ``other_data_path`` if the files are newer.
+    ``data_path`` if the local files are missing or smaller.
+
+    If ``interactive`` is :any:`True`, the routine will ask before
+    copying.
+
+    If ``models_to_import`` is specified, only data of models matching
+    those in ``models_to_import`` will be imported.
""" - with model_db(data_path) as db: - with model_db(other_data_path) as other_db: - for hash, data in other_db.items(): + hashes_to_import = ( + [model.hexhash for model in models_to_import] if models_to_import else [] + ) + + with model_db(other_data_path) as other_db: + for current_hash, data in other_db.items(): + with model_db(data_path) as db: if "data_path" not in data: continue do_import = False - if hash not in db: + if hashes_to_import and current_hash not in hashes_to_import: + logging.info(f"Skipping {current_hash}.") + continue + + if current_hash not in db: do_import = True - elif "data_path" not in db[hash]: + elif "data_path" not in db[current_hash]: do_import = True - elif (Path(data_path) / db[hash]["data_path"]).stat().st_size < ( - Path(other_data_path) / data["data_path"] - ).stat().st_size: + elif is_smaller( + Path(data_path) / db[current_hash]["data_path"], + Path(other_data_path) / data["data_path"], + ): do_import = True if do_import: this_path = Path(data_path) / data["data_path"] + this_path_tmp = this_path.with_suffix(".part") other_path = Path(other_data_path) / data["data_path"] - this_path.parents[0].mkdir(exist_ok=True, parents=True) + config = data["model_config"] logging.info(f"Importing {other_path} to {this_path}.") - logging.info( - f"The model description is '{data.model_config.description}'." - ) - shutil.copy(other_path, this_path) + logging.info(f"The model description is '{config.description}'.") - db[hash] = data + if ( + interactive + and input(f"Import {other_path}?\n[Y/N]: ").upper() != "Y" + ): + continue + + this_path.parents[0].mkdir(exist_ok=True, parents=True) + + if is_smaller(this_path, other_path): + shutil.copy(other_path, this_path_tmp) + shutil.move(this_path_tmp, this_path) + + db[current_hash] = data -def cleanup(models_to_keep: list[Model], data_path: str = "./.data"): - """Delete all model data except ``models_to_keep`` from ``data_path``.""" +def cleanup( + models_to_keep: list[Model], data_path: str = "./.data", preview: bool = True +): + """Delete all model data except ``models_to_keep`` from + ``data_path``. If ``preview`` is :any:`True`, only warning + messages about which files would be deleted will be printed. + """ hashes_to_keep = [model.hexhash for model in models_to_keep] data_path_resolved = Path(data_path) with model_db(data_path) as db: for hash in list(db.keys()): if hash not in hashes_to_keep: - logging.info(f"Deleting model '{hash}'.") + logging.warning(f"Deleting model '{hash}'.") info = db[hash] if "data_path" in info: this_path = data_path_resolved / info["data_path"] @@ -192,5 +245,11 @@ def cleanup(models_to_keep: list[Model], data_path: str = "./.data"): while this_path.parent != data_path_resolved: this_path = this_path.parent - logging.debug(f"Removing '{this_path}'.") - # this_path.unlink() + logging.warning(f"Removing '{this_path}'.") + + if not preview: + this_path.unlink() + logging.warning(f"Done.") + + if not preview: + del db[hash]