fix some serialization issues

This commit is contained in:
Valentin Boettcher 2022-12-08 19:17:16 -05:00
parent 13a684aead
commit a9876fd423
No known key found for this signature in database
GPG key ID: E034E12B7AF56ACE
2 changed files with 27 additions and 17 deletions

View file

@ -63,7 +63,6 @@ def model_hook(dct: dict[str, Any]):
if "__model__" in dct:
model = dct["__model__"]
treated_vals = {
key: object_hook(val) if isinstance(val, dict) else val
for key, val in dct.items()
@ -81,6 +80,10 @@ def model_hook(dct: dict[str, Any]):
if model == "OttoEngine":
return OttoEngine.from_dict(treated_vals)
for key, value in dct.items():
if isinstance(value, dict):
dct[key] = model_hook(value)
return dct
@ -375,7 +378,6 @@ def migrate_db_to_new_hashes(
for old_hash in list(db.keys()):
data = copy.deepcopy(db[old_hash])
new_hash = data["model_config"].hexhash
print(new_hash == old_hash)
del db[old_hash]
db[new_hash] = data

View file

@ -13,6 +13,7 @@ import hops.util.dynamic_matrix as dynamic_matrix
from hops.util.dynamic_matrix import DynamicMatrix, SmoothStep
import scipy.special
from numpy.typing import NDArray
from collections.abc import Iterable
@beartype
@ -30,6 +31,24 @@ class StocProcTolerances:
"""Interpolation tolerance."""
def hint_tuples(item: Any, rec=False):
if hasattr(item, "to_dict"):
item = item.to_dict()
if isinstance(item, dict):
return {key: hint_tuples(value, True) for key, value in item.items()}
if isinstance(item, tuple):
return {
"type": "tuple",
"value": [hint_tuples(i) for i in item],
}
if isinstance(item, list):
return [hint_tuples(e, True) for e in item]
else:
return item
class JSONEncoder(json.JSONEncoder):
"""
A custom encoder to serialize objects occuring in
@ -37,25 +56,14 @@ class JSONEncoder(json.JSONEncoder):
"""
def encode(self, obj: Any):
def hint_tuples(item: Any):
if isinstance(item, tuple):
return {
"type": "tuple",
"value": [
hint_tuples(i) if isinstance(i, tuple) else i for i in item
],
}
if isinstance(item, list):
return [hint_tuples(e) for e in item]
if isinstance(item, dict):
return {key: hint_tuples(value) for key, value in item.items()}
else:
return item
return super().encode(hint_tuples(obj))
@singledispatchmethod
def default(self, obj: Any):
if isinstance(obj, tuple):
import pdb
pdb.set_trace()
if hasattr(obj, "to_dict"):
return obj.to_dict()