two_qubit_model/hiro_models/model_auxiliary.py

331 lines
10 KiB
Python
Raw Normal View History

2022-03-21 14:58:57 +01:00
"""Functionality to integrate :any:`Model` instances and analyze the results."""
2022-04-07 14:30:30 +02:00
from typing import Any, Optional
2022-03-22 09:48:08 +01:00
from hops.core.hierarchy_data import HIData
2022-04-03 17:25:22 +02:00
from qutip.steadystate import _default_steadystate_args
2022-07-11 14:21:57 +02:00
from typing import Any
2022-03-22 10:25:58 +01:00
from .model_base import Model
2022-03-21 14:58:57 +01:00
from hops.core.integration import HOPSSupervisor
from contextlib import contextmanager
2022-03-23 14:59:02 +01:00
from .utility import JSONEncoder, object_hook
2022-03-21 14:58:57 +01:00
from filelock import FileLock
from pathlib import Path
from .one_qubit_model import QubitModel, QubitModelMutliBath
2022-03-22 10:25:58 +01:00
from .two_qubit_model import TwoQubitModel
2022-11-28 19:28:32 +01:00
from .otto_cycle import OttoEngine
2022-04-07 14:30:30 +02:00
from collections.abc import Sequence, Iterator, Iterable
2022-03-31 17:47:54 +02:00
import shutil
import logging
2022-04-11 14:49:02 +02:00
import copy
import os
2022-07-11 14:21:57 +02:00
import numpy as np
2022-03-21 14:58:57 +01:00
@contextmanager
def model_db(data_path: str = "./.data"):
    """
    Context manager that exposes the model database JSON file located
    in the folder ``data_path`` as a dictionary.

    Changes made to the dictionary are written back to the file on
    exit.  Concurrent access is serialized through a lock file.
    """

    base_dir = Path(data_path)
    base_dir.mkdir(exist_ok=True, parents=True)

    json_file = base_dir / "model_data.json"
    lock_file = base_dir / "model_data.json.lock"

    with FileLock(lock_file):
        json_file.touch(exist_ok=True)

        with json_file.open("r+") as handle:
            raw = handle.read()

            # An empty file is treated as an empty database.
            if len(raw) > 0:
                db = JSONEncoder.loads(raw, object_hook=model_hook)
            else:
                db = {}

            yield db

            # Overwrite the file with the (possibly mutated) database.
            handle.truncate(0)
            handle.seek(0)
            handle.write(JSONEncoder.dumps(db))
def model_hook(dct: dict[str, Any]):
    """A custom decoder for the model types."""

    model_name = dct.get("__model__")

    if model_name is not None:
        # Decode nested dictionaries before handing them to the model class.
        decoded = {}
        for key, val in dct.items():
            decoded[key] = object_hook(val) if isinstance(val, dict) else val

        # Dispatch table: model tag -> model class.
        model_classes = {
            "QubitModel": QubitModel,
            "TwoQubitModel": TwoQubitModel,
            "QubitModelMutliBath": QubitModelMutliBath,
            "OttoEngine": OttoEngine,
        }

        if model_name in model_classes:
            return model_classes[model_name].from_dict(decoded)

    return object_hook(dct)
2022-03-24 16:34:42 +01:00
def integrate_multi(models: Sequence[Model], *args, **kwargs):
    """Integrate the hops equations for each model in ``models``.

    This is the many-model counterpart of :any:`integrate`; all extra
    arguments are forwarded to it.

    A call to :any:`ray.init` may be required.
    """

    for current_model in models:
        integrate(current_model, *args, **kwargs)
2022-03-24 16:34:42 +01:00
def integrate(model: Model, n: int, data_path: str = "./.data", clear_pd: bool = False):
    """Integrate the hops equations for the model and register the
    result in the model database under ``data_path``.

    A call to :any:`ray.init` may be required.

    :param model: The model whose HOPS configuration is integrated.
    :param n: The number of samples to be integrated.
    :param data_path: The directory that holds the model database.
    :param clear_pd: Whether to clear the data file and redo the integration.
    """

    # Renamed from ``hash`` so the builtin is not shadowed.
    hexhash = model.hexhash

    supervisor = HOPSSupervisor(
        model.hops_config,
        n,
        data_path=data_path,
        data_name=hexhash,
    )

    supervisor.integrate(clear_pd)

    # Record where the integration data lives (relative to ``data_path``)
    # so that :any:`get_data` can locate it later.
    with supervisor.get_data(True) as data:
        with model_db(data_path) as db:
            db[hexhash] = {
                "model_config": model.to_dict(),
                "data_path": str(Path(data.hdf5_name).relative_to(data_path)),
            }
2022-03-22 09:48:08 +01:00
def get_data(
    model: Model, data_path: str = "./.data", read_only: bool = True, **kwargs
) -> HIData:
    """
    Get the integration data of the model ``model`` based on the
    ``data_path``. If ``read_only`` is :any:`True` the file is opened
    in read-only mode. The ``kwargs`` are passed on to :any:`HIData`.

    :raises RuntimeError: If no data is registered for ``model``.
    """

    hexhash = model.hexhash

    with model_db(data_path) as db:
        if hexhash in db and "data_path" in db[hexhash]:
            path = Path(data_path) / db[hexhash]["data_path"]

            try:
                return HIData(path, read_only=read_only, robust=False, **kwargs)
            except Exception:
                # The stored key may be stale; reopen writable and overwrite
                # it.  A bare ``except`` would also swallow
                # ``KeyboardInterrupt``/``SystemExit`` -> catch ``Exception``.
                logging.warning(f"Could not open '{path}' as-is, overwriting the key.")
                return HIData(
                    path,
                    hi_key=model.hops_config,
                    read_only=False,
                    check_consistency=False,
                    overwrite_key=True,
                    robust=False,
                    **kwargs,
                )

        else:
            raise RuntimeError(f"No data found for model with hash '{hexhash}'.")
2022-03-31 17:47:54 +02:00
2022-04-03 17:25:22 +02:00
def model_data_iterator(
    models: Iterable[Model], *args, **kwargs
) -> Iterator[tuple[Model, HIData]]:
    """
    Iterate over ``models`` yielding ``(model, data)`` pairs, where
    ``data`` is the already opened integration data of the model and
    is closed again automatically.

    All further arguments are forwarded to :any:`get_data`.
    """

    for current in models:
        with get_data(current, *args, **kwargs) as opened_data:
            yield current, opened_data
2022-04-07 14:30:30 +02:00
def is_smaller(first: Path, second: Path) -> bool:
    """
    :returns: Whether the file ``first`` is smaller than ``second``.
        A nonexistent ``first`` counts as smaller.
    """

    # A missing ``first`` short-circuits before ``stat`` is called.
    return (not first.exists()) or first.stat().st_size < second.stat().st_size
def import_results(
    data_path: str = "./.data",
    other_data_path: str = "./.data_other",
    interactive: bool = False,
    models_to_import: Optional[Iterable[Model]] = None,
):
    """
    Imports results from the ``other_data_path`` into the
    ``data_path`` if the other data files are larger.

    If ``interactive`` is :any:`True`, the routine will ask before
    copying.

    If ``models_to_import`` is specified, only data of models matching
    those in ``models_to_import`` will be imported.
    """

    # Selection filter; an empty list means "import everything".
    hashes_to_import = (
        [model.hexhash for model in models_to_import] if models_to_import else []
    )

    with model_db(other_data_path) as other_db:
        for current_hash, data in other_db.items():
            # Open the target DB per entry so each import is written back
            # immediately instead of only at the end of the whole run.
            with model_db(data_path) as db:
                # Entries without integration data cannot be imported.
                if "data_path" not in data:
                    continue

                do_import = False

                if hashes_to_import and current_hash not in hashes_to_import:
                    logging.info(f"Skipping {current_hash}.")
                    continue

                # Import when the model is unknown locally, has no local
                # data, or the local data file is smaller than the remote one.
                if current_hash not in db:
                    do_import = True
                elif "data_path" not in db[current_hash]:
                    do_import = True
                elif is_smaller(
                    Path(data_path) / db[current_hash]["data_path"],
                    Path(other_data_path) / data["data_path"],
                ):
                    do_import = True

                if do_import:
                    this_path = Path(data_path) / data["data_path"]
                    # Copy to a ".part" file first so a crash mid-copy does
                    # not leave a truncated data file at the final location.
                    this_path_tmp = this_path.with_suffix(".part")
                    other_path = Path(other_data_path) / data["data_path"]

                    config = data["model_config"]
                    logging.warning(f"Importing {other_path} to {this_path}.")
                    logging.warning(f"The model description is '{config.description}'.")

                    if (
                        interactive
                        and input(f"Import {other_path}?\n[Y/N]: ").upper() != "Y"
                    ):
                        continue

                    this_path.parents[0].mkdir(exist_ok=True, parents=True)

                    # Re-check the sizes right before copying.
                    if is_smaller(this_path, other_path):
                        shutil.copy2(other_path, this_path_tmp)
                        shutil.move(this_path_tmp, this_path)

                    db[current_hash] = data
def cleanup(
    models_to_keep: list[Model], data_path: str = "./.data", preview: bool = True
):
    """Delete all model data except ``models_to_keep`` from
    ``data_path``. If ``preview`` is :any:`True`, only warning
    messages about which files would be deleted will be printed.

    :param models_to_keep: The models whose data is retained.
    :param data_path: The directory that holds the model database.
    :param preview: Whether to only log instead of actually deleting.
    """

    # A set gives O(1) membership tests (and avoids shadowing ``hash``).
    hashes_to_keep = {model.hexhash for model in models_to_keep}
    data_path_resolved = Path(data_path)

    with model_db(data_path) as db:
        for model_hash in list(db.keys()):
            if model_hash in hashes_to_keep:
                continue

            logging.warning(f"Deleting model '{model_hash}'.")
            info = db[model_hash]

            if "data_path" in info:
                # Walk up to the top-level entry directly under
                # ``data_path`` so the whole per-model tree is removed.
                this_path = data_path_resolved / info["data_path"]
                while this_path.parent != data_path_resolved:
                    this_path = this_path.parent

                logging.warning(f"Removing '{this_path}'.")
                if not preview:
                    # The entry may be a directory (``unlink`` would raise
                    # on those) or may already have been removed.
                    if this_path.is_dir():
                        shutil.rmtree(this_path, ignore_errors=True)
                    elif this_path.exists():
                        this_path.unlink()
                    logging.warning("Done.")

            if not preview:
                del db[model_hash]
2022-04-11 14:49:02 +02:00
def migrate_db_to_new_hashes(
    data_path: str = "./.data", results_path: str = "./results"
):
    """
    Recomputes all the hashes of the models in the database under
    ``data_path`` and updates the database.  Files under
    ``results_path`` whose names contain an old hash are renamed too.
    """

    with model_db(data_path) as db:
        for stale in list(db.keys()):
            entry = copy.deepcopy(db[stale])
            fresh = entry["model_config"].hexhash

            del db[stale]
            db[fresh] = entry

            # Rename any result file whose name embeds the stale hash.
            for file_name in os.listdir(results_path):
                if stale not in file_name:
                    continue

                os.rename(
                    os.path.join(results_path, file_name),
                    os.path.join(results_path, file_name.replace(stale, fresh)),
                )
2022-07-11 14:21:57 +02:00
def model_diff_dict(models: Iterable["Model"], **kwargs) -> dict[str, Any]:
    """
    Generate a dictionary which only contains the parameters that
    differ between the instances in ``models``, mapping each such key
    to the list of per-model values.

    The ``kwargs`` are passed to :any:`Model.to_dict`.

    :raises ValueError: If the models are not all of the same type.
    """

    dicts = [model.to_dict(**kwargs) for model in models]

    if not dicts:
        return {}

    model_type = dicts[0]["__model__"]
    for model_dict in dicts:
        if model_dict["__model__"] != model_type:
            raise ValueError("All compared models must be of the same type.")

    keys = set()
    for key, reference in dicts[0].items():
        for model_dict in dicts[1:]:
            value = model_dict[key]

            # ``np.array_equal`` treats arrays as different when ANY element
            # differs; the previous ``.all()`` check missed partially
            # differing arrays and crashed when only ``reference`` was an
            # ndarray.
            if isinstance(value, np.ndarray) or isinstance(reference, np.ndarray):
                differs = not np.array_equal(reference, value)
            else:
                differs = reference != value

            if differs:
                keys.add(key)
                break

    return {key: [dct[key] for dct in dicts] for key in keys}