mirror of https://github.com/vale981/two_qubit_model

commit ebbdfe36b3 ("make import more convenient")
parent da01ad7cfa

1 changed file with 85 additions and 26 deletions
@@ -1,6 +1,6 @@
 """Functionality to integrate :any:`Model` instances and analyze the results."""
 
-from typing import Any
+from typing import Any, Optional
 
 from hops.core.hierarchy_data import HIData
 from qutip.steadystate import _default_steadystate_args
@@ -12,7 +12,7 @@ from filelock import FileLock
 from pathlib import Path
 from .one_qubit_model import QubitModel
 from .two_qubit_model import TwoQubitModel
-from collections.abc import Sequence, Iterator
+from collections.abc import Sequence, Iterator, Iterable
 import shutil
 import logging
@@ -38,7 +38,9 @@ def model_db(data_path: str = "./.data"):
     with FileLock(db_lock):
         with db_path.open("r+") as f:
             data = f.read()
-            db = JSONEncoder.loads(data) if len(data) > 0 else {}
+            db = (
+                JSONEncoder.loads(data, object_hook=model_hook) if len(data) > 0 else {}
+            )
 
             yield db
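The change above wires ``model_hook`` into the database decoder. For context, ``model_db`` performs a locked read-modify-write cycle on a JSON file. The sketch below shows that pattern with the standard-library ``json`` module standing in for the project's ``JSONEncoder``; the file names are placeholders, not from the repository:

    # Minimal sketch of a model_db-style locked read/write cycle.
    import json
    from contextlib import contextmanager
    from pathlib import Path

    from filelock import FileLock

    @contextmanager
    def toy_db(path: str = "toy_db.json"):
        db_path = Path(path)
        db_path.touch(exist_ok=True)
        with FileLock(path + ".lock"):  # serialize access across processes
            with db_path.open("r+") as f:
                data = f.read()
                db = json.loads(data) if len(data) > 0 else {}
                yield db  # the caller may mutate the dictionary
                f.seek(0)  # write the (possibly modified) db back
                f.truncate()
                f.write(json.dumps(db))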
@@ -53,11 +55,16 @@ def model_hook(dct: dict[str, Any]):
     if "__model__" in dct:
         model = dct["__model__"]
 
+        treated_vals = {
+            key: object_hook(val) if isinstance(val, dict) else val
+            for key, val in dct.items()
+        }
+
         if model == "QubitModel":
-            return QubitModel.from_dict(dct)
+            return QubitModel.from_dict(treated_vals)
 
         if model == "TwoQubitModel":
-            return TwoQubitModel.from_dict(dct)
+            return TwoQubitModel.from_dict(treated_vals)
 
     return object_hook(dct)
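The new ``treated_vals`` comprehension converts nested dictionaries by hand before passing them to ``from_dict``, which suggests the project's decoder does not recurse on its own. The standard-library ``json`` module does: its ``object_hook`` runs on every decoded object, innermost first. A toy round trip illustrating the same tag-dispatch idea, with ``ToyModel`` as a hypothetical stand-in for ``QubitModel``/``TwoQubitModel``:

    import json

    class ToyModel:
        def __init__(self, omega: float):
            self.omega = omega

        @classmethod
        def from_dict(cls, dct: dict) -> "ToyModel":
            return cls(dct["omega"])

    def toy_hook(dct: dict):
        # dispatch on the "__model__" type tag, as model_hook does
        if dct.get("__model__") == "ToyModel":
            return ToyModel.from_dict(dct)
        return dct

    model = json.loads(
        '{"__model__": "ToyModel", "omega": 1.5}', object_hook=toy_hook
    )
    assert isinstance(model, ToyModel) and model.omega == 1.5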
@@ -125,7 +132,7 @@ def get_data(
 
 
 def model_data_iterator(
-    models: Model, *args, **kwargs
+    models: Iterable[Model], *args, **kwargs
 ) -> Iterator[tuple[Model, HIData]]:
     """
     Yields tuples of ``model, data``, where ``data`` is already opened
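With the corrected annotation, the call site takes any iterable of models. A usage sketch; the model variables are placeholders:

    # Iterate several models' results; ``data`` arrives already opened.
    for model, data in model_data_iterator([model_a, model_b]):
        print(model.hexhash, data)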
@@ -139,52 +146,98 @@ def model_data_iterator(
         yield model, data
 
 
-def import_results(data_path: str = "./.data", other_data_path: str = "./.data_other"):
+def is_smaller(first: Path, second: Path) -> bool:
+    """
+    :returns: Whether the file ``first`` is smaller than ``second``.
+    """
+
+    if not first.exists():
+        return True
+
+    return first.stat().st_size < second.stat().st_size
+
+
+def import_results(
+    data_path: str = "./.data",
+    other_data_path: str = "./.data_other",
+    interactive: bool = False,
+    models_to_import: Optional[Iterable[Model]] = None,
+):
     """
     Imports results from the ``other_data_path`` into the
     ``data_path`` if the files are newer.
+
+    If ``interactive`` is :any:`True`, the routine will ask before
+    copying.
+
+    If ``models_to_import`` is specified, only data of models matching
+    those in ``models_to_import`` will be imported.
     """
 
-    with model_db(data_path) as db:
-        with model_db(other_data_path) as other_db:
-            for hash, data in other_db.items():
+    hashes_to_import = (
+        [model.hexhash for model in models_to_import] if models_to_import else []
+    )
+
+    with model_db(other_data_path) as other_db:
+        for current_hash, data in other_db.items():
+            with model_db(data_path) as db:
                 if "data_path" not in data:
                     continue
 
+                if hashes_to_import and current_hash not in hashes_to_import:
+                    logging.info(f"Skipping {current_hash}.")
+                    continue
+
                 do_import = False
 
-                if hash not in db:
+                if current_hash not in db:
                     do_import = True
-                elif "data_path" not in db[hash]:
+                elif "data_path" not in db[current_hash]:
                     do_import = True
-                elif (Path(data_path) / db[hash]["data_path"]).stat().st_size < (
-                    Path(other_data_path) / data["data_path"]
-                ).stat().st_size:
+                elif is_smaller(
+                    Path(data_path) / db[current_hash]["data_path"],
+                    Path(other_data_path) / data["data_path"],
+                ):
                     do_import = True
 
                 if do_import:
                     this_path = Path(data_path) / data["data_path"]
+                    this_path_tmp = this_path.with_suffix(".part")
                     other_path = Path(other_data_path) / data["data_path"]
 
-                    this_path.parents[0].mkdir(exist_ok=True, parents=True)
+                    config = data["model_config"]
                     logging.info(f"Importing {other_path} to {this_path}.")
-                    logging.info(
-                        f"The model description is '{data.model_config.description}'."
-                    )
-                    shutil.copy(other_path, this_path)
+                    logging.info(f"The model description is '{config.description}'.")
+
+                    if (
+                        interactive
+                        and input(f"Import {other_path}?\n[Y/N]: ").upper() != "Y"
+                    ):
+                        continue
+
+                    this_path.parents[0].mkdir(exist_ok=True, parents=True)
+
+                    if is_smaller(this_path, other_path):
+                        shutil.copy(other_path, this_path_tmp)
+                        shutil.move(this_path_tmp, this_path)
 
-                    db[hash] = data
+                    db[current_hash] = data
 
 
-def cleanup(models_to_keep: list[Model], data_path: str = "./.data"):
-    """Delete all model data except ``models_to_keep`` from ``data_path``."""
+def cleanup(
+    models_to_keep: list[Model], data_path: str = "./.data", preview: bool = True
+):
+    """Delete all model data except ``models_to_keep`` from
+    ``data_path``. If ``preview`` is :any:`True`, only warning
+    messages about which files would be deleted will be printed.
+    """
 
     hashes_to_keep = [model.hexhash for model in models_to_keep]
     data_path_resolved = Path(data_path)
     with model_db(data_path) as db:
         for hash in list(db.keys()):
             if hash not in hashes_to_keep:
-                logging.info(f"Deleting model '{hash}'.")
+                logging.warning(f"Deleting model '{hash}'.")
                 info = db[hash]
                 if "data_path" in info:
                     this_path = data_path_resolved / info["data_path"]
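Two details in this hunk are worth noting. The copy now goes to a ``.part`` file and is renamed into place afterwards, so a crash mid-copy can never leave a truncated file under the final data path (the rename is atomic when source and destination share a filesystem). And the new keyword arguments allow selective, confirmed imports. A usage sketch with placeholder arguments:

    # Import only the listed models, asking before each copy.
    import_results(
        data_path="./.data",
        other_data_path="./.data_other",
        interactive=True,
        models_to_import=[model_a],
    )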
@@ -192,5 +245,11 @@ def cleanup(models_to_keep: list[Model], data_path: str = "./.data"):
                     while this_path.parent != data_path_resolved:
                         this_path = this_path.parent
 
-                    logging.debug(f"Removing '{this_path}'.")
-                    # this_path.unlink()
+                    logging.warning(f"Removing '{this_path}'.")
+
+                    if not preview:
+                        this_path.unlink()
+                        logging.warning(f"Done.")
 
-                del db[hash]
+                if not preview:
+                    del db[hash]
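Since ``preview`` defaults to :any:`True`, ``cleanup`` is now a dry run unless deletion is requested explicitly, e.g. (``keep`` is a placeholder list of models):

    cleanup(keep)                 # only warns about what would be deleted
    cleanup(keep, preview=False)  # actually unlinks files and db entries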