diff --git a/hiro_models/model_auxiliary.py b/hiro_models/model_auxiliary.py index 0962649..d87e1e5 100644 --- a/hiro_models/model_auxiliary.py +++ b/hiro_models/model_auxiliary.py @@ -4,6 +4,7 @@ from typing import Any, Optional from hops.core.hierarchy_data import HIData from qutip.steadystate import _default_steadystate_args +from typing import Any from .model_base import Model from hops.core.integration import HOPSSupervisor from contextlib import contextmanager @@ -17,6 +18,7 @@ import shutil import logging import copy import os +import numpy as np @contextmanager @@ -290,3 +292,33 @@ def migrate_db_to_new_hashes( os.path.join(results_path, result), os.path.join(results_path, result.replace(old_hash, new_hash)), ) + + +def model_diff_dict(models: Iterable[Model], **kwargs) -> dict[str, Any]: + """ + Generate a which only contains paramaters that differ from between + the instances in ``models``. + + The ``kwargs`` are passed to :any:`Model.to_dict`. + """ + + keys = set() + dicts = [model.to_dict(**kwargs) for model in models] + model_type = dicts[0]["__model__"] + + for model_dict in dicts: + if model_dict["__model__"] != model_type: + raise ValueError("All compared models must be of the same type.") + + for key, value in dicts[0].items(): + last_value = value + for model_dict in dicts[1:]: + value = model_dict[key] + comp = last_value != value + if comp.all() if isinstance(value, np.ndarray) else comp: + keys.add(key) + break + + last_value = value + + return {key: [dct[key] for dct in dicts] for key in keys}