mirror of
https://github.com/vale981/two_qubit_model
synced 2025-03-04 17:21:43 -05:00
fix some serialization issues
This commit is contained in:
parent
13a684aead
commit
a9876fd423
2 changed files with 27 additions and 17 deletions
|
@ -63,7 +63,6 @@ def model_hook(dct: dict[str, Any]):
|
|||
|
||||
if "__model__" in dct:
|
||||
model = dct["__model__"]
|
||||
|
||||
treated_vals = {
|
||||
key: object_hook(val) if isinstance(val, dict) else val
|
||||
for key, val in dct.items()
|
||||
|
@ -81,6 +80,10 @@ def model_hook(dct: dict[str, Any]):
|
|||
if model == "OttoEngine":
|
||||
return OttoEngine.from_dict(treated_vals)
|
||||
|
||||
for key, value in dct.items():
|
||||
if isinstance(value, dict):
|
||||
dct[key] = model_hook(value)
|
||||
|
||||
return dct
|
||||
|
||||
|
||||
|
@ -375,7 +378,6 @@ def migrate_db_to_new_hashes(
|
|||
for old_hash in list(db.keys()):
|
||||
data = copy.deepcopy(db[old_hash])
|
||||
new_hash = data["model_config"].hexhash
|
||||
print(new_hash == old_hash)
|
||||
del db[old_hash]
|
||||
db[new_hash] = data
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ import hops.util.dynamic_matrix as dynamic_matrix
|
|||
from hops.util.dynamic_matrix import DynamicMatrix, SmoothStep
|
||||
import scipy.special
|
||||
from numpy.typing import NDArray
|
||||
from collections.abc import Iterable
|
||||
|
||||
|
||||
@beartype
|
||||
|
@ -30,6 +31,24 @@ class StocProcTolerances:
|
|||
"""Interpolation tolerance."""
|
||||
|
||||
|
||||
def hint_tuples(item: Any, rec=False):
|
||||
if hasattr(item, "to_dict"):
|
||||
item = item.to_dict()
|
||||
|
||||
if isinstance(item, dict):
|
||||
return {key: hint_tuples(value, True) for key, value in item.items()}
|
||||
if isinstance(item, tuple):
|
||||
return {
|
||||
"type": "tuple",
|
||||
"value": [hint_tuples(i) for i in item],
|
||||
}
|
||||
|
||||
if isinstance(item, list):
|
||||
return [hint_tuples(e, True) for e in item]
|
||||
else:
|
||||
return item
|
||||
|
||||
|
||||
class JSONEncoder(json.JSONEncoder):
|
||||
"""
|
||||
A custom encoder to serialize objects occuring in
|
||||
|
@ -37,25 +56,14 @@ class JSONEncoder(json.JSONEncoder):
|
|||
"""
|
||||
|
||||
def encode(self, obj: Any):
|
||||
def hint_tuples(item: Any):
|
||||
if isinstance(item, tuple):
|
||||
return {
|
||||
"type": "tuple",
|
||||
"value": [
|
||||
hint_tuples(i) if isinstance(i, tuple) else i for i in item
|
||||
],
|
||||
}
|
||||
if isinstance(item, list):
|
||||
return [hint_tuples(e) for e in item]
|
||||
if isinstance(item, dict):
|
||||
return {key: hint_tuples(value) for key, value in item.items()}
|
||||
else:
|
||||
return item
|
||||
|
||||
return super().encode(hint_tuples(obj))
|
||||
|
||||
@singledispatchmethod
|
||||
def default(self, obj: Any):
|
||||
if isinstance(obj, tuple):
|
||||
import pdb
|
||||
|
||||
pdb.set_trace()
|
||||
if hasattr(obj, "to_dict"):
|
||||
return obj.to_dict()
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue