convinced the PersistentDataStructure to save np.ndarrays in a separate file

This commit is contained in:
Richard Hartmann 2015-12-03 17:21:14 +01:00
parent b60f7870a2
commit 988ed36d18
2 changed files with 141 additions and 26 deletions

View file

@ -8,6 +8,13 @@ import shutil
import traceback import traceback
import pickle import pickle
import warnings import warnings
try:
import numpy as np
_NP = True
except ImportError:
warnings.warn("could not import 'numpy', I can not treat np.ndarray separately!")
_NP = False
if sys.version_info[0] == 2: if sys.version_info[0] == 2:
# fixes keyword problems with python 2.x # fixes keyword problems with python 2.x
@ -22,11 +29,14 @@ if sys.version_info[0] == 2:
os.rmdir = new_rmdir os.rmdir = new_rmdir
MAGIC_SIGN = 0xff4a87 MAGIC_SIGN = 0xff4a87
MAGIC_SIGN_NPARRAY = 0xee4a87
KEY_COUNTER = '0' KEY_COUNTER = '0'
KEY_SUB_DATA_KEYS = '1' KEY_SUB_DATA_KEYS = '1'
KEY_NPARRAY_COUNTER = 'np0'
KEY_NPARRAY_KEYS = 'np1'
RESERVED_KEYS = (KEY_COUNTER, KEY_SUB_DATA_KEYS) RESERVED_KEYS = (KEY_COUNTER, KEY_SUB_DATA_KEYS, KEY_NPARRAY_COUNTER, KEY_NPARRAY_KEYS)
def key_to_str(key, max_len = 255): def key_to_str(key, max_len = 255):
if isinstance(key, (bytearray, bytes)): if isinstance(key, (bytearray, bytes)):
@ -80,11 +90,19 @@ class PersistentDataStructure(object):
if KEY_SUB_DATA_KEYS not in self.db: if KEY_SUB_DATA_KEYS not in self.db:
self.db[KEY_SUB_DATA_KEYS] = set() self.db[KEY_SUB_DATA_KEYS] = set()
if KEY_NPARRAY_COUNTER not in self.db:
self.db[KEY_NPARRAY_COUNTER] = 0
if KEY_NPARRAY_KEYS not in self.db:
self.db[KEY_NPARRAY_KEYS] = set()
self.db.commit() self.db.commit()
self.counter = self.db[KEY_COUNTER] self.counter = self.db[KEY_COUNTER]
self.sub_data_keys = self.db[KEY_SUB_DATA_KEYS] self.sub_data_keys = self.db[KEY_SUB_DATA_KEYS]
self.nparray_counter = self.db[KEY_NPARRAY_COUNTER]
self.nparray_keys = self.db[KEY_NPARRAY_KEYS]
def _repair(self): def _repair(self):
self.need_open() self.need_open()
c = 0 c = 0
@ -118,6 +136,10 @@ class PersistentDataStructure(object):
assert len(self.sub_data_keys) == c assert len(self.sub_data_keys) == c
def _nparray_file_name(self, i):
    """Return the path of the ``.npy`` file that belongs to index ``i``.

    The path is formed by joining this structure's directory
    (``self._dirname``) with the file name ``<i>.npy``.
    """
    return join(self._dirname, "{}.npy".format(i))
def __enter__(self): def __enter__(self):
return self return self
@ -180,6 +202,10 @@ class PersistentDataStructure(object):
print("call erase for key:", key, "on file", self._filename) print("call erase for key:", key, "on file", self._filename)
sub_data = self.getData(key) sub_data = self.getData(key)
sub_data.erase() sub_data.erase()
for k in self.nparray_keys:
fname = self.getNPA_filename(k)
os.remove(self._nparray_file_name(fname))
except: except:
traceback.print_exc() traceback.print_exc()
finally: finally:
@ -204,15 +230,25 @@ class PersistentDataStructure(object):
for k in self.sub_data_keys: for k in self.sub_data_keys:
with self[k] as sub_data: with self[k] as sub_data:
sub_data.erase() sub_data.erase()
for i in range(self.nparray_counter):
try:
os.remove(self._nparray_file_name(i))
except FileNotFoundError:
pass
self.db.clear() self.db.clear()
self.db[KEY_COUNTER] = 0 self.db[KEY_COUNTER] = 0
self.db[KEY_SUB_DATA_KEYS] = set() self.db[KEY_SUB_DATA_KEYS] = set()
self.db[KEY_NPARRAY_COUNTER] = 0
self.db[KEY_NPARRAY_KEYS] = set()
self.db.commit() self.db.commit()
self.sub_data_keys = set() self.sub_data_keys = set()
self.counter = 0 self.counter = 0
self.nparray_counter = 0
self.nparray_keys = set()
def show_stat(self, recursive = False, prepend = ""): def show_stat(self, recursive = False, prepend = ""):
prepend += self._name prepend += self._name
@ -239,6 +275,7 @@ class PersistentDataStructure(object):
if oth_key > 0: if oth_key > 0:
print("{}: number of other keys: {}".format(prepend, oth_key)) print("{}: number of other keys: {}".format(prepend, oth_key))
print("{}: number of subdata: {}".format(prepend, len(self.sub_data_keys))) print("{}: number of subdata: {}".format(prepend, len(self.sub_data_keys)))
print("{}: nparray counter: {}".format(prepend, self.nparray_counter))
print() print()
sys.stdout.flush() sys.stdout.flush()
if recursive: if recursive:
@ -258,8 +295,6 @@ class PersistentDataStructure(object):
""" """
if str(key) in RESERVED_KEYS: if str(key) in RESERVED_KEYS:
raise RuntimeError("key must not be in {} (reserved key)".format(RESERVED_KEYS)) raise RuntimeError("key must not be in {} (reserved key)".format(RESERVED_KEYS))
return True
def __is_sub_data(self, value): def __is_sub_data(self, value):
""" """
@ -275,6 +310,20 @@ class PersistentDataStructure(object):
return True return True
except: except:
return False return False
def __is_nparray(self, value):
    """
    Determine whether a value read from the sqlitedict refers to a numpy
    array that is stored in a separate file.

    This is considered the case if the value itself has an index 'magic'
    whose value matches the magic sign defined by MAGIC_SIGN_NPARRAY.

    Returns True on a match and False otherwise, including when the value
    can not be indexed with 'magic' at all.
    """
    try:
        # Use a plain comparison rather than ``assert``: assert statements
        # are removed when Python runs with -O, which would make every
        # value look like a match.
        return bool(value['magic'] == MAGIC_SIGN_NPARRAY)
    except Exception:
        # value is not subscriptable, has no 'magic' entry, or the
        # comparison result has no single truth value
        return False
def has_key(self, key): def has_key(self, key):
self.need_open() self.need_open()
@ -284,6 +333,10 @@ class PersistentDataStructure(object):
return key in self.sub_data_keys return key in self.sub_data_keys
# return self.__is_sub_data(self.db[key]) # return self.__is_sub_data(self.db[key])
def is_NPA(self, key):
    """Return True if ``key`` is registered in ``self.nparray_keys``.

    Only the in-memory key set is consulted; the stored value is not
    inspected.
    """
    registered_keys = self.nparray_keys
    return key in registered_keys
def setData(self, key, value, overwrite=False): def setData(self, key, value, overwrite=False):
""" """
write the key value pair to the data base write the key value pair to the data base
@ -293,15 +346,53 @@ class PersistentDataStructure(object):
that key in the database that key in the database
""" """
self.need_open() self.need_open()
if not self.__check_key(key): self.__check_key(key)
return False
if overwrite or (not key in self.db): if overwrite:
self.db[key] = value if self.verbose > 1:
self.db.commit() print("overwrite True: del key")
return True if key in self.db:
self.__delitem__(key)
if not key in self.db:
if _NP and isinstance(value, np.ndarray):
if self.verbose > 1:
print("set nparray")
return self._setNPA(key, nparray=value)
else:
if self.verbose > 1:
print("set normal value")
self.db[key] = value
self.db.commit()
return True
else:
if overwrite:
raise RuntimeError("this can not happen -> if so, pls check code!")
raise KeyError("could not set data, key exists, and overwrite is False")
def _setNPA(self, key, nparray):
    """Store ``nparray`` in its own ``.npy`` file and register it under ``key``.

    The array is written to the file named by
    ``self._nparray_file_name(self.nparray_counter)``. The database entry
    for ``key`` is a small dict holding that file name ('fname') and the
    marker MAGIC_SIGN_NPARRAY ('magic').

    The file is written BEFORE any bookkeeping is changed, so a failing
    ``np.save`` (its exception propagates to the caller) leaves neither a
    database entry nor an advanced counter pointing at a missing file.

    Returns True on success.
    """
    fname = self._nparray_file_name(self.nparray_counter)
    # write the data first; only register the key once the file exists
    np.save(fname, nparray)

    d = {'fname': fname,
         'magic': MAGIC_SIGN_NPARRAY}
    self.nparray_keys.add(key)
    self.nparray_counter += 1
    # persist the updated bookkeeping together with the new entry
    self.db[KEY_NPARRAY_COUNTER] = self.nparray_counter
    self.db[KEY_NPARRAY_KEYS] = self.nparray_keys
    self.db[key] = d
    self.db.commit()
    return True
def _getNPA(self, key):
    """Load and return the numpy array registered under ``key``.

    The database entry for ``key`` must be a dict whose 'magic' item
    equals MAGIC_SIGN_NPARRAY; its 'fname' item names the ``.npy`` file
    that is read with ``np.load``.
    """
    entry = self.db[key]
    # sanity check: the entry must be an nparray reference
    assert entry['magic'] == MAGIC_SIGN_NPARRAY
    return np.load(entry['fname'])
return False
def newSubData(self, key): def newSubData(self, key):
""" """
@ -347,6 +438,10 @@ class PersistentDataStructure(object):
if self.verbose > 1: if self.verbose > 1:
print("return subData stored as key", key, "using name", sub_db_name) print("return subData stored as key", key, "using name", sub_db_name)
return PersistentDataStructure(name = sub_db_name, path = os.path.join(self._dirname) , verbose = self.verbose) return PersistentDataStructure(name = sub_db_name, path = os.path.join(self._dirname) , verbose = self.verbose)
elif self.is_NPA(key):
if self.verbose > 1:
print("return nparray value")
return self._getNPA(key)
else: else:
if self.verbose > 1: if self.verbose > 1:
print("return normal value") print("return normal value")
@ -384,7 +479,7 @@ class PersistentDataStructure(object):
def __len__(self): def __len__(self):
self.need_open() self.need_open()
return len(self.db) - 2 return len(self.db) - len(RESERVED_KEYS)
# implements the iterator # implements the iterator
def __iter__(self): def __iter__(self):
@ -409,21 +504,13 @@ class PersistentDataStructure(object):
# implements '[]' operator setter # implements '[]' operator setter
def __setitem__(self, key, value): def __setitem__(self, key, value):
self.need_open()
self.__check_key(key)
# if key in self.db:
# if self.__is_sub_data(self.db[key]):
# raise RuntimeWarning("values which hold sub_data structures can not be overwritten!")
# return None
if self.verbose > 1:
print("set", key, "to", value, "in", self._filename)
if isinstance(value, PersistentDataStructure): if isinstance(value, PersistentDataStructure):
self.setDataFromSubData(key, value) self.setDataFromSubData(key, value)
else: else:
self.db[key] = value self.setData(key, value, overwrite=True)
self.db.commit()
if self.verbose > 1:
print("set", key, "to", value, "in", self._filename)
# implements '[]' operator deletion # implements '[]' operator deletion
@ -437,6 +524,11 @@ class PersistentDataStructure(object):
self.sub_data_keys.remove(key) self.sub_data_keys.remove(key)
self.db[KEY_SUB_DATA_KEYS] = self.sub_data_keys self.db[KEY_SUB_DATA_KEYS] = self.sub_data_keys
elif self.is_NPA(key):
d = self.db[key]
assert d['magic'] == MAGIC_SIGN_NPARRAY
fname = d['fname']
os.remove(fname)
del self.db[key] del self.db[key]
self.db.commit() self.db.commit()

View file

@ -8,6 +8,8 @@ import os
from os.path import abspath, dirname, split, exists from os.path import abspath, dirname, split, exists
from shutil import rmtree from shutil import rmtree
import numpy as np
import warnings import warnings
warnings.filterwarnings('error') warnings.filterwarnings('error')
@ -264,7 +266,8 @@ def test_remove_sub_data_and_check_len():
with sub_data.getData(key = 'subsub1', create_sub_data = True) as sub_sub_data: with sub_data.getData(key = 'subsub1', create_sub_data = True) as sub_sub_data:
sub_sub_data['t'] = 'hallo Welt' sub_sub_data['t'] = 'hallo Welt'
assert len(sub_data) == 3
assert len(sub_data) == 3, "len = {}".format(len(sub_data))
@ -400,7 +403,26 @@ def test_not_in():
finally: finally:
data.erase() data.erase()
def test_npa():
# Round-trip test for storing a numpy array through the '[]' operators.
# NOTE(review): PDS / VERBOSE are defined elsewhere in the test module;
# PDS is presumably the PersistentDataStructure class -- confirm.
# reference data: 100 evenly spaced values in [0, 1], shaped 10 x 10
a = np.linspace(0, 1, 100).reshape(10,10)
# start from a cleared store and write the array under key 'a'
with PDS(name='data_npa', verbose=VERBOSE) as data:
data.clear()
data['a'] = a
# reopen the store and check that the array read back equals the original
with PDS(name='data_npa', verbose=VERBOSE) as data:
b = data['a']
assert np.all(b == a)
# the first array written is expected in the file with index 0
assert os.path.exists('__data_npa/0.npy')
# delete the key and store the array again
with PDS(name='data_npa', verbose=VERBOSE) as data:
del data['a']
data['a'] = a
# the old file must be gone and the re-stored array must use index 1
assert not os.path.exists('__data_npa/0.npy')
assert os.path.exists('__data_npa/1.npy')
if __name__ == "__main__": if __name__ == "__main__":
test_reserved_key_catch() test_reserved_key_catch()
@ -414,3 +436,4 @@ if __name__ == "__main__":
test_len() test_len()
test_clear() test_clear()
test_not_in() test_not_in()
test_npa()