two_qubit_model/hiro_models/model_auxiliary.py

197 lines
5.9 KiB
Python
Raw Normal View History

2022-03-21 14:58:57 +01:00
"""Functionality to integrate :any:`Model` instances and analyze the results."""
from typing import Any
2022-03-22 09:48:08 +01:00
from hops.core.hierarchy_data import HIData
2022-04-03 17:25:22 +02:00
from qutip.steadystate import _default_steadystate_args
2022-03-22 10:25:58 +01:00
from .model_base import Model
2022-03-21 14:58:57 +01:00
from hops.core.integration import HOPSSupervisor
from contextlib import contextmanager
2022-03-23 14:59:02 +01:00
from .utility import JSONEncoder, object_hook
2022-03-21 14:58:57 +01:00
from filelock import FileLock
from pathlib import Path
2022-03-22 10:25:58 +01:00
from .one_qubit_model import QubitModel
from .two_qubit_model import TwoQubitModel
2022-04-03 17:25:22 +02:00
from collections.abc import Sequence, Iterator
2022-03-31 17:47:54 +02:00
import shutil
import logging
2022-03-21 14:58:57 +01:00
@contextmanager
def model_db(data_path: str = "./.data"):
    """
    Opens the model database json file in the folder ``data_path`` as
    a dictionary.

    Mutations will be synchronized to the file. Access is managed via
    a lock file.

    :param data_path: Directory containing ``model_data.json`` and its
        lock file ``model_data.json.lock``; it is created (with parents)
        if it does not exist.
    :yields: The decoded database; an empty :any:`dict` if the file is
        empty.

    The database is only written back if the ``with`` body finishes
    without raising; on an exception the file is left untouched.
    """

    path = Path(data_path)
    path.mkdir(exist_ok=True, parents=True)

    db_path = path / "model_data.json"
    db_lock = path / "model_data.json.lock"
    db_path.touch(exist_ok=True)

    with FileLock(db_lock):
        with db_path.open("r+") as f:
            data = f.read()
            db = JSONEncoder.loads(data) if data else {}

            yield db

            # Serialize *before* truncating: if encoding fails (e.g. a
            # non-serializable value was stored), the existing database
            # file must not be left empty.
            serialized = JSONEncoder.dumps(db)
            f.truncate(0)
            f.seek(0)
            f.write(serialized)
def model_hook(dct: dict[str, Any]):
    """Decode model objects while deserializing JSON.

    Dictionaries carrying a ``"__model__"`` tag equal to ``"QubitModel"``
    or ``"TwoQubitModel"`` are rebuilt through that class's ``from_dict``.
    Any other dictionary is passed on to the generic ``object_hook``.
    """
    if "__model__" not in dct:
        return object_hook(dct)

    tag = dct["__model__"]
    for class_name, model_class in (
        ("QubitModel", QubitModel),
        ("TwoQubitModel", TwoQubitModel),
    ):
        if tag == class_name:
            return model_class.from_dict(dct)

    # Tagged, but not a model type handled here.
    return object_hook(dct)
2022-03-24 16:34:42 +01:00
def integrate_multi(models: Sequence[Model], *args, **kwargs):
    """Integrate the hops equations for every model in ``models``.

    This simply calls :any:`integrate` once per model, in order,
    forwarding ``args`` and ``kwargs`` unchanged to each call.

    A call to :any:`ray.init` may be required.
    """

    for current_model in models:
        integrate(current_model, *args, **kwargs)
2022-03-24 16:34:42 +01:00
def integrate(
    model: Model, n: int, data_path: str = "./.data", clear_pd: bool = False
):
    """Integrate the hops equations for the model.

    A call to :any:`ray.init` may be required.

    :param model: The model whose ``hops_config`` is integrated; its
        ``hexhash`` is used as the data name and database key.
    :param n: The number of samples to be integrated.
    :param data_path: The directory holding the integration data and the
        model database.
    :param clear_pd: Whether to clear the data file and redo the integration.
    """

    model_hash = model.hexhash

    supervisor = HOPSSupervisor(
        model.hops_config,
        n,
        data_path=data_path,
        data_name=model_hash,
    )

    supervisor.integrate(clear_pd)

    # Register the model and the location of its data file (relative to
    # ``data_path``) in the model database, keyed by the model hash.
    # NOTE(review): the positional ``True`` is presumably ``read_only`` —
    # confirm against ``HOPSSupervisor.get_data``.
    with supervisor.get_data(True) as data:
        with model_db(data_path) as db:
            db[model_hash] = {
                "model_config": model.to_dict(),
                "data_path": str(Path(data.hdf5_name).relative_to(data_path)),
            }
2022-03-22 09:48:08 +01:00
def get_data(
    model: Model, data_path: str = "./.data", read_only: bool = True, **kwargs
) -> HIData:
    """
    Open the stored integration data of ``model``.

    The data file is looked up in the model database under ``data_path``
    by the model's hash and resolved relative to ``data_path``.

    :param read_only: Whether to open the file in read-only mode.
    :param kwargs: Passed on to :any:`HIData`.
    :raises RuntimeError: If the database holds no data entry for the model.
    """

    model_hash = model.hexhash

    with model_db(data_path) as db:
        if model_hash not in db or "data_path" not in db[model_hash]:
            raise RuntimeError(f"No data found for model with hash '{model_hash}'.")

        data_file = Path(data_path) / db[model_hash]["data_path"]
        return HIData(data_file, read_only=read_only, **kwargs)
2022-03-31 17:47:54 +02:00
2022-04-03 17:25:22 +02:00
def model_data_iterator(
    models: Sequence[Model], *args, **kwargs
) -> Iterator[tuple[Model, HIData]]:
    """
    Yields tuples of ``model, data``, where ``data`` is already opened
    and will be closed automatically.

    :param models: The models whose data is opened, one after another.

    For the rest of the arguments see :any:`get_data`.

    Each ``data`` is closed as soon as the next tuple is requested (or
    when the generator is closed), so it must not be used after that.
    """

    for model in models:
        # The ``with`` is only left once the consumer resumes the
        # generator, which keeps ``data`` open while it is being used.
        with get_data(model, *args, **kwargs) as data:
            yield model, data
2022-03-31 17:47:54 +02:00
def import_results(data_path: str = "./.data", other_data_path: str = "./.data_other"):
    """
    Imports results from the ``other_data_path`` into the ``data_path``.

    A result (data file plus database entry) is imported if the model is
    unknown in ``data_path``, has no data there, its local data file is
    missing, or the file in ``other_data_path`` is larger than the local
    one. Entries in ``other_data_path`` without a ``data_path`` are
    ignored.
    """

    local_root = Path(data_path)
    other_root = Path(other_data_path)

    with model_db(data_path) as db:
        with model_db(other_data_path) as other_db:
            for model_hash, data in other_db.items():
                if "data_path" not in data:
                    continue

                other_path = other_root / data["data_path"]

                if model_hash not in db or "data_path" not in db[model_hash]:
                    do_import = True
                else:
                    local_file = local_root / db[model_hash]["data_path"]
                    # A missing local file is treated as outdated instead
                    # of crashing in ``stat``.
                    do_import = (
                        not local_file.exists()
                        or local_file.stat().st_size < other_path.stat().st_size
                    )

                if not do_import:
                    continue

                this_path = local_root / data["data_path"]
                this_path.parent.mkdir(exist_ok=True, parents=True)
                logging.info(f"Importing {other_path} to {this_path}.")

                # ``data`` is a database entry (a dict), so the config must
                # be looked up by key. Depending on the JSON decoder the
                # config may be a plain dict or a decoded model object.
                config = data.get("model_config")
                if isinstance(config, dict):
                    description = config.get("description")
                else:
                    description = getattr(config, "description", None)
                logging.info(f"The model description is '{description}'.")

                shutil.copy(other_path, this_path)

                db[model_hash] = data
2022-04-06 18:28:53 +02:00
def cleanup(models_to_keep: list[Model], data_path: str = "./.data"):
    """Delete all model data except ``models_to_keep`` from ``data_path``.

    Every database entry whose hash does not belong to one of
    ``models_to_keep`` is logged. For entries with a ``data_path``, the
    top-level directory (or file) directly below ``data_path`` that
    contains the data is determined.

    NOTE(review): in the code shown here the actual removal
    (``this_path.unlink()``) is commented out, and database entries are
    not deleted, so only log messages are produced — confirm whether
    further code performs the deletion.
    """
    # Hashes of the models whose data must survive the cleanup.
    hashes_to_keep = [model.hexhash for model in models_to_keep]
    data_path_resolved = Path(data_path)
    with model_db(data_path) as db:
        # Iterate over a snapshot of the keys, presumably so entries can be
        # deleted from ``db`` inside the loop.
        for hash in list(db.keys()):
            if hash not in hashes_to_keep:
                logging.info(f"Deleting model '{hash}'.")
                info = db[hash]
                if "data_path" in info:
                    this_path = data_path_resolved / info["data_path"]
                    # Walk up to the entry directly below ``data_path``.
                    # NOTE(review): if ``this_path`` is not strictly below
                    # ``data_path_resolved`` (e.g. an empty or absolute stored
                    # ``data_path``) this loop never terminates.
                    while this_path.parent != data_path_resolved:
                        this_path = this_path.parent
                    logging.debug(f"Removing '{this_path}'.")
                    # NOTE(review): ``unlink`` cannot remove a directory,
                    # which ``this_path`` usually is after walking up;
                    # ``shutil.rmtree`` may be what was intended.
                    # this_path.unlink()