make import more convenient

This commit is contained in:
Valentin Boettcher 2022-04-07 14:30:30 +02:00
parent da01ad7cfa
commit ebbdfe36b3

View file

@ -1,6 +1,6 @@
"""Functionality to integrate :any:`Model` instances and analyze the results.""" """Functionality to integrate :any:`Model` instances and analyze the results."""
from typing import Any from typing import Any, Optional
from hops.core.hierarchy_data import HIData from hops.core.hierarchy_data import HIData
from qutip.steadystate import _default_steadystate_args from qutip.steadystate import _default_steadystate_args
@ -12,7 +12,7 @@ from filelock import FileLock
from pathlib import Path from pathlib import Path
from .one_qubit_model import QubitModel from .one_qubit_model import QubitModel
from .two_qubit_model import TwoQubitModel from .two_qubit_model import TwoQubitModel
from collections.abc import Sequence, Iterator from collections.abc import Sequence, Iterator, Iterable
import shutil import shutil
import logging import logging
@ -38,7 +38,9 @@ def model_db(data_path: str = "./.data"):
with FileLock(db_lock): with FileLock(db_lock):
with db_path.open("r+") as f: with db_path.open("r+") as f:
data = f.read() data = f.read()
db = JSONEncoder.loads(data) if len(data) > 0 else {} db = (
JSONEncoder.loads(data, object_hook=model_hook) if len(data) > 0 else {}
)
yield db yield db
@ -53,11 +55,16 @@ def model_hook(dct: dict[str, Any]):
if "__model__" in dct: if "__model__" in dct:
model = dct["__model__"] model = dct["__model__"]
treated_vals = {
key: object_hook(val) if isinstance(val, dict) else val
for key, val in dct.items()
}
if model == "QubitModel": if model == "QubitModel":
return QubitModel.from_dict(dct) return QubitModel.from_dict(treated_vals)
if model == "TwoQubitModel": if model == "TwoQubitModel":
return TwoQubitModel.from_dict(dct) return TwoQubitModel.from_dict(treated_vals)
return object_hook(dct) return object_hook(dct)
@ -125,7 +132,7 @@ def get_data(
def model_data_iterator( def model_data_iterator(
models: Model, *args, **kwargs models: Iterable[Model], *args, **kwargs
) -> Iterator[tuple[Model, HIData]]: ) -> Iterator[tuple[Model, HIData]]:
""" """
Yields tuples of ``model, data``, where ``data`` is already opened Yields tuples of ``model, data``, where ``data`` is already opened
@ -139,52 +146,98 @@ def model_data_iterator(
yield model, data yield model, data
def import_results(data_path: str = "./.data", other_data_path: str = "./.data_other"): def is_smaller(first: Path, second: Path) -> bool:
"""
:returns: Wether the file ``first`` is smaller that ``second``.
"""
if not first.exists():
return True
return first.stat().st_size < second.stat().st_size
def import_results(
data_path: str = "./.data",
other_data_path: str = "./.data_other",
interactive: bool = False,
models_to_import: Optional[Iterable[Model]] = None,
):
""" """
Imports results from the ``other_data_path`` into the Imports results from the ``other_data_path`` into the
``other_data_path`` if the files are newer. ``other_data_path`` if the files are newer.
If ``interactive`` is any :any:`True`, the routine will ask before
copying.
If ``models_to_import`` is specified, only data of models matching
those in ``models_to_import`` will be imported.
""" """
with model_db(data_path) as db: hashes_to_import = (
with model_db(other_data_path) as other_db: [model.hexhash for model in models_to_import] if models_to_import else []
for hash, data in other_db.items(): )
with model_db(other_data_path) as other_db:
for current_hash, data in other_db.items():
with model_db(data_path) as db:
if "data_path" not in data: if "data_path" not in data:
continue continue
do_import = False do_import = False
if hash not in db: if hashes_to_import and current_hash not in hashes_to_import:
logging.info(f"Skipping {current_hash}.")
continue
if current_hash not in db:
do_import = True do_import = True
elif "data_path" not in db[hash]: elif "data_path" not in db[current_hash]:
do_import = True do_import = True
elif (Path(data_path) / db[hash]["data_path"]).stat().st_size < ( elif is_smaller(
Path(other_data_path) / data["data_path"] Path(data_path) / db[current_hash]["data_path"],
).stat().st_size: Path(other_data_path) / data["data_path"],
):
do_import = True do_import = True
if do_import: if do_import:
this_path = Path(data_path) / data["data_path"] this_path = Path(data_path) / data["data_path"]
this_path_tmp = this_path.with_suffix(".part")
other_path = Path(other_data_path) / data["data_path"] other_path = Path(other_data_path) / data["data_path"]
this_path.parents[0].mkdir(exist_ok=True, parents=True) config = data["model_config"]
logging.info(f"Importing {other_path} to {this_path}.") logging.info(f"Importing {other_path} to {this_path}.")
logging.info( logging.info(f"The model description is '{config.description}'.")
f"The model description is '{data.model_config.description}'."
)
shutil.copy(other_path, this_path)
db[hash] = data if (
interactive
and input(f"Import {other_path}?\n[Y/N]: ").upper() != "Y"
):
continue
this_path.parents[0].mkdir(exist_ok=True, parents=True)
if is_smaller(this_path, other_path):
shutil.copy(other_path, this_path_tmp)
shutil.move(this_path_tmp, this_path)
db[current_hash] = data
def cleanup(models_to_keep: list[Model], data_path: str = "./.data"): def cleanup(
"""Delete all model data except ``models_to_keep`` from ``data_path``.""" models_to_keep: list[Model], data_path: str = "./.data", preview: bool = True
):
"""Delete all model data except ``models_to_keep`` from
``data_path``. If ``preview`` is :any:`True`, only warning
messages about which files would be deleted will be printed.
"""
hashes_to_keep = [model.hexhash for model in models_to_keep] hashes_to_keep = [model.hexhash for model in models_to_keep]
data_path_resolved = Path(data_path) data_path_resolved = Path(data_path)
with model_db(data_path) as db: with model_db(data_path) as db:
for hash in list(db.keys()): for hash in list(db.keys()):
if hash not in hashes_to_keep: if hash not in hashes_to_keep:
logging.info(f"Deleting model '{hash}'.") logging.warning(f"Deleting model '{hash}'.")
info = db[hash] info = db[hash]
if "data_path" in info: if "data_path" in info:
this_path = data_path_resolved / info["data_path"] this_path = data_path_resolved / info["data_path"]
@ -192,5 +245,11 @@ def cleanup(models_to_keep: list[Model], data_path: str = "./.data"):
while this_path.parent != data_path_resolved: while this_path.parent != data_path_resolved:
this_path = this_path.parent this_path = this_path.parent
logging.debug(f"Removing '{this_path}'.") logging.warning(f"Removing '{this_path}'.")
# this_path.unlink()
if not preview:
this_path.unlink()
logging.warning(f"Done.")
if not preview:
del db[hash]