implement model diff dict

This commit is contained in:
Valentin Boettcher 2022-07-11 14:21:57 +02:00
parent b15aa6d77d
commit 930ed00b6e

View file

@ -4,6 +4,7 @@ from typing import Any, Optional
from hops.core.hierarchy_data import HIData
from qutip.steadystate import _default_steadystate_args
from typing import Any
from .model_base import Model
from hops.core.integration import HOPSSupervisor
from contextlib import contextmanager
@ -17,6 +18,7 @@ import shutil
import logging
import copy
import os
import numpy as np
@contextmanager
@ -290,3 +292,33 @@ def migrate_db_to_new_hashes(
os.path.join(results_path, result),
os.path.join(results_path, result.replace(old_hash, new_hash)),
)
def model_diff_dict(models: Iterable[Model], **kwargs) -> dict[str, Any]:
"""
Generate a which only contains paramaters that differ from between
the instances in ``models``.
The ``kwargs`` are passed to :any:`Model.to_dict`.
"""
keys = set()
dicts = [model.to_dict(**kwargs) for model in models]
model_type = dicts[0]["__model__"]
for model_dict in dicts:
if model_dict["__model__"] != model_type:
raise ValueError("All compared models must be of the same type.")
for key, value in dicts[0].items():
last_value = value
for model_dict in dicts[1:]:
value = model_dict[key]
comp = last_value != value
if comp.all() if isinstance(value, np.ndarray) else comp:
keys.add(key)
break
last_value = value
return {key: [dct[key] for dct in dicts] for key in keys}