mirror of
https://github.com/vale981/two_qubit_model
synced 2025-03-05 09:41:41 -05:00
implement model diff dict
This commit is contained in:
parent
b15aa6d77d
commit
930ed00b6e
1 changed files with 32 additions and 0 deletions
|
@ -4,6 +4,7 @@ from typing import Any, Optional
|
|||
|
||||
from hops.core.hierarchy_data import HIData
|
||||
from qutip.steadystate import _default_steadystate_args
|
||||
from typing import Any
|
||||
from .model_base import Model
|
||||
from hops.core.integration import HOPSSupervisor
|
||||
from contextlib import contextmanager
|
||||
|
@ -17,6 +18,7 @@ import shutil
|
|||
import logging
|
||||
import copy
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
@ -290,3 +292,33 @@ def migrate_db_to_new_hashes(
|
|||
os.path.join(results_path, result),
|
||||
os.path.join(results_path, result.replace(old_hash, new_hash)),
|
||||
)
|
||||
|
||||
|
||||
def model_diff_dict(models: Iterable[Model], **kwargs) -> dict[str, Any]:
|
||||
"""
|
||||
Generate a which only contains paramaters that differ from between
|
||||
the instances in ``models``.
|
||||
|
||||
The ``kwargs`` are passed to :any:`Model.to_dict`.
|
||||
"""
|
||||
|
||||
keys = set()
|
||||
dicts = [model.to_dict(**kwargs) for model in models]
|
||||
model_type = dicts[0]["__model__"]
|
||||
|
||||
for model_dict in dicts:
|
||||
if model_dict["__model__"] != model_type:
|
||||
raise ValueError("All compared models must be of the same type.")
|
||||
|
||||
for key, value in dicts[0].items():
|
||||
last_value = value
|
||||
for model_dict in dicts[1:]:
|
||||
value = model_dict[key]
|
||||
comp = last_value != value
|
||||
if comp.all() if isinstance(value, np.ndarray) else comp:
|
||||
keys.add(key)
|
||||
break
|
||||
|
||||
last_value = value
|
||||
|
||||
return {key: [dct[key] for dct in dicts] for key in keys}
|
||||
|
|
Loading…
Add table
Reference in a new issue