diff --git a/jobmanager/binfootprint.py b/jobmanager/binfootprint.py index 0c92cd8..2a92561 100644 --- a/jobmanager/binfootprint.py +++ b/jobmanager/binfootprint.py @@ -244,13 +244,13 @@ def _dump_tuple(t): b += _dump(ti) return b -def _load_tuple(b): +def _load_tuple(b, classes): assert comp_id(b[0], _TUPLE) size = struct.unpack('>I', b[1:5])[0] idx = 5 t = [] for i in range(size): - ob, len_ob = _load(b[idx:]) + ob, len_ob = _load(b[idx:], classes) t.append(ob) idx += len_ob return tuple(t), idx @@ -267,19 +267,19 @@ def _dump_namedtuple(t): b += _dump(t[i]) return b -def _load_namedtuple(b): +def _load_namedtuple(b, classes): assert comp_id(b[0], _NAMEDTUPLE) size = struct.unpack('>I', b[1:5])[0] - class_name, len_ob = _load(b[5:]) + class_name, len_ob = _load_str(b[5:]) idx = 5 + len_ob t = [] fields = [] for i in range(size): - ob, len_ob = _load(b[idx:]) + ob, len_ob = _load(b[idx:], classes) fields.append(ob) idx += len_ob - ob, len_ob = _load(b[idx:]) + ob, len_ob = _load(b[idx:], classes) t.append(ob) idx += len_ob @@ -293,13 +293,13 @@ def _dump_list(t): b += _dump(ti) return b -def _load_list(b): +def _load_list(b, classes): assert comp_id(b[0], _LIST) size = struct.unpack('>I', b[1:5])[0] idx = 5 t = [] for i in range(size): - ob, len_ob = _load(b[idx:]) + ob, len_ob = _load(b[idx:], classes) t.append(ob) idx += len_ob return t, idx @@ -327,28 +327,31 @@ def _dump_getstate(ob): return b -def _load_getstate(b): +def _load_getstate(b, classes): assert comp_id(b[0], _GETSTATE) - obj_type, l_obj_type = _load(b[1:]) - state, l_state = _load(b[l_obj_type+1:]) - return (obj_type, state), l_obj_type+l_state+1 + obj_type, l_obj_type = _load_str(b[1:]) + state, l_state = _load(b[l_obj_type+1:], classes) + cls = classes[obj_type] + obj = cls.__new__(cls) + obj.__setstate__(state) + return obj, l_obj_type+l_state+1 def _dump_dict(ob): b = init_BYTES([_DICT]) keys = ob.keys() bin_keys = [] for k in keys: - bin_keys.append( (dump(k), dump(ob[k])) ) + bin_keys.append( (_dump(k), _dump(ob[k])) ) b += _dump_list(sorted(bin_keys)) return b -def _load_dict(b): +def _load_dict(b, classes): assert comp_id(b[0], _DICT) - sorted_keys_value, l = _load_list(b[1:]) + sorted_keys_value, l = _load_list(b[1:], classes) res_dict = {} for i in range(len(sorted_keys_value)): - key = load(sorted_keys_value[i][0]) - value = load(sorted_keys_value[i][1]) + key = _load(sorted_keys_value[i][0], classes)[0] + value = _load(sorted_keys_value[i][1], classes)[0] res_dict[key] = value return res_dict, l+1 @@ -385,7 +388,7 @@ def _dump(ob): else: raise RuntimeError("unsupported type for dump '{}'".format(type(ob))) -def _load(b): +def _load(b, classes): identifier = b[0] if isinstance(identifier, str): identifier = ord(identifier) @@ -404,17 +407,17 @@ def _load(b): elif identifier == _BYTES: return _load_bytes(b) elif identifier == _TUPLE: - return _load_tuple(b) + return _load_tuple(b, classes) elif identifier == _NAMEDTUPLE: - return _load_namedtuple(b) + return _load_namedtuple(b, classes) elif identifier == _LIST: - return _load_list(b) + return _load_list(b, classes) elif identifier == _NPARRAY: return _load_np_array(b) elif identifier == _DICT: - return _load_dict(b) + return _load_dict(b, classes) elif identifier == _GETSTATE: - return _load_getstate(b) + return _load_getstate(b, classes) else: raise BFLoadError("unknown identifier '{}'".format(hex(identifier))) @@ -433,20 +436,20 @@ def dump(ob, vers=_VERS): return res -def load(b): +def load(b, classes=None): """ reconstruct the object from the binary footprint given an bytes 'ba' """ global _load vers = b[0] if byte_to_ord(vers) == _VERS: - return _load(b[1:])[0] + return _load(b[1:], classes)[0] elif byte_to_ord(vers) < 0x80: # very first version # has not even a version tag __load_tmp = _load _load = _load_00 - res = _load(b)[0] + res = _load(b, classes)[0] _load = __load_tmp return res @@ -463,14 +466,14 @@ def load(b): # so the first two bytes must correspond to an identifier which are assumed # to be < 128 = 0x80 -def _load_namedtuple_00(b): +def _load_namedtuple_00(b, classes): assert comp_id(b[0], _NAMEDTUPLE) size = struct.unpack('>I', b[1:5])[0] - class_name, len_ob = _load(b[5:]) + class_name, len_ob = _load_str(b[5:]) idx = 5 + len_ob t = [] for i in range(size): - ob, len_ob = _load(b[idx:]) + ob, len_ob = _load(b[idx:], classes) t.append(ob) idx += len_ob return (class_name, tuple(t)), idx @@ -486,7 +489,7 @@ def _dump_namedtuple_00(t): return b -def _load_00(b): +def _load_00(b, classes): identifier = b[0] if isinstance(identifier, str): identifier = ord(identifier) @@ -505,17 +508,17 @@ def _load_00(b): elif identifier == _BYTES: return _load_bytes(b) elif identifier == _TUPLE: - return _load_tuple(b) + return _load_tuple(b, classes) elif identifier == _NAMEDTUPLE: - return _load_namedtuple_00(b) + return _load_namedtuple_00(b, classes) elif identifier == _LIST: - return _load_list(b) + return _load_list(b, classes) elif identifier == _NPARRAY: return _load_np_array(b) elif identifier == _DICT: - return _load_dict(b) + return _load_dict(b, classes) elif identifier == _GETSTATE: - return _load_getstate(b) + return _load_getstate(b, classes) else: raise BFLoadError("unknown identifier '{}'".format(hex(identifier))) diff --git a/tests/test_binfootprint.py b/tests/test_binfootprint.py index 48ca9bb..3332045 100644 --- a/tests/test_binfootprint.py +++ b/tests/test_binfootprint.py @@ -38,44 +38,44 @@ def test_atom(): def test_tuple(): t = (12345678, 3.141, 'hallo Welt', 'öäüß', True, False, None, (3, tuple(), (4,5,None), 'test')) - bin_tuple = bfp._dump_tuple(t) + bin_tuple = bfp.dump(t) assert type(bin_tuple) is bfp.BIN_TYPE - t_prime = bfp._load_tuple(bin_tuple)[0] + t_prime = bfp.load(bin_tuple) assert t == t_prime - bin_ob_prime = bfp._dump(t_prime) + bin_ob_prime = bfp.dump(t_prime) assert bin_tuple == bin_ob_prime def test_nparray(): ob = np.random.randn(3,53,2) - bin_ob = bfp._dump(ob) + bin_ob = bfp.dump(ob) assert type(bin_ob) is bfp.BIN_TYPE - ob_prime = bfp._load(bin_ob)[0] + ob_prime = bfp.load(bin_ob) assert np.all(ob == ob_prime) - bin_ob_prime = bfp._dump(ob_prime) + bin_ob_prime = bfp.dump(ob_prime) assert bin_ob == bin_ob_prime ob = np.random.randn(3,53,2) ob = (ob, ob, 4, None) - bin_ob = bfp._dump(ob) - ob_prime = bfp._load(bin_ob)[0] + bin_ob = bfp.dump(ob) + ob_prime = bfp.load(bin_ob) assert np.all(ob[0] == ob_prime[0]) assert np.all(ob[1] == ob_prime[1]) - bin_ob_prime = bfp._dump(ob_prime) + bin_ob_prime = bfp.dump(ob_prime) assert bin_ob == bin_ob_prime def test_list(): ob = [1,2,3] - bin_ob = bfp._dump(ob) + bin_ob = bfp.dump(ob) assert type(bin_ob) is bfp.BIN_TYPE - ob_prime = bfp._load(bin_ob)[0] + ob_prime = bfp.load(bin_ob) assert np.all(ob == ob_prime) - bin_ob_prime = bfp._dump(ob_prime) + bin_ob_prime = bfp.dump(ob_prime) assert bin_ob == bin_ob_prime ob = [1, (2,3), np.array([2j,3j])] - bin_ob = bfp._dump(ob) - ob_prime = bfp._load(bin_ob)[0] - bin_ob_prime = bfp._dump(ob_prime) + bin_ob = bfp.dump(ob) + ob_prime = bfp.load(bin_ob) + bin_ob_prime = bfp.dump(ob_prime) assert bin_ob == bin_ob_prime assert np.all(ob[0] == ob_prime[0]) @@ -92,14 +92,16 @@ def test_getstate(): self.a = state[0] ob = T(4) - bin_ob = bfp._dump(ob) + bin_ob = bfp.dump(ob) assert type(bin_ob) is bfp.BIN_TYPE - ob_prime_state = bfp._load(bin_ob)[0] - ob_prime = T.__new__(T) - ob_prime.__setstate__(ob_prime_state[1]) + + classes = {} + classes['T'] = T + + ob_prime = bfp.load(bin_ob, classes) assert np.all(ob.a == ob_prime.a) - bin_ob_prime = bfp._dump(ob_prime) + bin_ob_prime = bfp.dump(ob_prime) assert bin_ob == bin_ob_prime def test_named_tuple(): @@ -107,9 +109,9 @@ def test_named_tuple(): obj = obj_type(12345678, 3.141, 'hallo Welt') - bin_obj = bfp._dump(obj) + bin_obj = bfp.dump(obj) assert type(bin_obj) is bfp.BIN_TYPE - obj_prime = bfp._load(bin_obj)[0] + obj_prime = bfp.load(bin_obj) obj_prime_name, obj_prime_data_values, obj_prime_data_fields = obj_prime assert obj_prime_name == obj.__class__.__name__ @@ -118,7 +120,7 @@ def test_named_tuple(): assert obj_prime._fields == obj_prime_data_fields assert obj_prime == obj - bin_ob_prime = bfp._dump(obj_prime) + bin_ob_prime = bfp.dump(obj_prime) assert bin_obj == bin_ob_prime def test_complex(): @@ -132,8 +134,7 @@ def test_dict(): a = {'a':1, 5:5, 3+4j:'l', False: b'ab4+#'} bf = bfp.dump(a) assert type(bf) is bfp.BIN_TYPE - a_restored = bfp.load(bf) - + a_restored = bfp.load(bf) for k in a: assert a[k] == a_restored[k]