convinced the PersistentDataStructure to save np.ndarrays in sparate file

This commit is contained in:
Richard Hartmann 2015-12-03 17:21:14 +01:00
parent b60f7870a2
commit 988ed36d18
2 changed files with 141 additions and 26 deletions

View file

@ -8,6 +8,13 @@ import shutil
import traceback
import pickle
import warnings
try:
import numpy as np
_NP = True
except ImportError:
warnings.warn("could not import 'numpy', I can not treat np.ndarray separately!")
_NP = False
if sys.version_info[0] == 2:
# fixes keyword problems with python 2.x
@ -22,11 +29,14 @@ if sys.version_info[0] == 2:
os.rmdir = new_rmdir
MAGIC_SIGN = 0xff4a87
MAGIC_SIGN_NPARRAY = 0xee4a87
KEY_COUNTER = '0'
KEY_SUB_DATA_KEYS = '1'
KEY_NPARRAY_COUNTER = 'np0'
KEY_NPARRAY_KEYS = 'np1'
RESERVED_KEYS = (KEY_COUNTER, KEY_SUB_DATA_KEYS)
RESERVED_KEYS = (KEY_COUNTER, KEY_SUB_DATA_KEYS, KEY_NPARRAY_COUNTER, KEY_NPARRAY_KEYS)
def key_to_str(key, max_len = 255):
if isinstance(key, (bytearray, bytes)):
@ -80,10 +90,18 @@ class PersistentDataStructure(object):
if KEY_SUB_DATA_KEYS not in self.db:
self.db[KEY_SUB_DATA_KEYS] = set()
if KEY_NPARRAY_COUNTER not in self.db:
self.db[KEY_NPARRAY_COUNTER] = 0
if KEY_NPARRAY_KEYS not in self.db:
self.db[KEY_NPARRAY_KEYS] = set()
self.db.commit()
self.counter = self.db[KEY_COUNTER]
self.sub_data_keys = self.db[KEY_SUB_DATA_KEYS]
self.nparray_counter = self.db[KEY_NPARRAY_COUNTER]
self.nparray_keys = self.db[KEY_NPARRAY_KEYS]
def _repair(self):
self.need_open()
@ -118,6 +136,10 @@ class PersistentDataStructure(object):
assert len(self.sub_data_keys) == c
def _nparray_file_name(self, i):
fname = join(self._dirname, "{}.npy".format(i))
return fname
def __enter__(self):
return self
@ -180,6 +202,10 @@ class PersistentDataStructure(object):
print("call erase for key:", key, "on file", self._filename)
sub_data = self.getData(key)
sub_data.erase()
for k in self.nparray_keys:
fname = self.getNPA_filename(k)
os.remove(self._nparray_file_name(fname))
except:
traceback.print_exc()
finally:
@ -205,14 +231,24 @@ class PersistentDataStructure(object):
with self[k] as sub_data:
sub_data.erase()
for i in range(self.nparray_counter):
try:
os.remove(self._nparray_file_name(i))
except FileNotFoundError:
pass
self.db.clear()
self.db[KEY_COUNTER] = 0
self.db[KEY_SUB_DATA_KEYS] = set()
self.db[KEY_NPARRAY_COUNTER] = 0
self.db[KEY_NPARRAY_KEYS] = set()
self.db.commit()
self.sub_data_keys = set()
self.counter = 0
self.nparray_counter = 0
self.nparray_keys = set()
def show_stat(self, recursive = False, prepend = ""):
prepend += self._name
@ -239,6 +275,7 @@ class PersistentDataStructure(object):
if oth_key > 0:
print("{}: number of other keys: {}".format(prepend, oth_key))
print("{}: number of subdata: {}".format(prepend, len(self.sub_data_keys)))
print("{}: nparray counter: {}".format(prepend, self.nparray_counter))
print()
sys.stdout.flush()
if recursive:
@ -259,8 +296,6 @@ class PersistentDataStructure(object):
if str(key) in RESERVED_KEYS:
raise RuntimeError("key must not be in {} (reserved key)".format(RESERVED_KEYS))
return True
def __is_sub_data(self, value):
"""
determine if the value gotten from the sqlitedict refers
@ -276,6 +311,20 @@ class PersistentDataStructure(object):
except:
return False
def __is_nparray(self, value):
"""
determine if the value gotten from the sqlitedict refers
to a numpy array which is stored in a seperate file
this is considered the case if the value itself has an index 'magic'
whose value matches a magic sign defined by MAGIC_SIGN_NPARRAY
"""
try:
assert value['magic'] == MAGIC_SIGN_NPARRAY
return True
except:
return False
def has_key(self, key):
self.need_open()
return (key in self.db)
@ -284,6 +333,10 @@ class PersistentDataStructure(object):
return key in self.sub_data_keys
# return self.__is_sub_data(self.db[key])
def is_NPA(self, key):
return key in self.nparray_keys
# return self.__is_sub_data(self.db[key])
def setData(self, key, value, overwrite=False):
"""
write the key value pair to the data base
@ -293,15 +346,53 @@ class PersistentDataStructure(object):
that key in the database
"""
self.need_open()
if not self.__check_key(key):
return False
self.__check_key(key)
if overwrite or (not key in self.db):
if overwrite:
if self.verbose > 1:
print("overwrite True: del key")
if key in self.db:
self.__delitem__(key)
if not key in self.db:
if _NP and isinstance(value, np.ndarray):
if self.verbose > 1:
print("set nparray")
return self._setNPA(key, nparray=value)
else:
if self.verbose > 1:
print("set normal value")
self.db[key] = value
self.db.commit()
return True
else:
if overwrite:
raise RuntimeError("this can not happen -> if so, pls check code!")
raise KeyError("could not set data, key exists, and overwrite is False")
def _setNPA(self, key, nparray):
d = {'fname': self._nparray_file_name(self.nparray_counter),
'magic': MAGIC_SIGN_NPARRAY}
self.nparray_keys.add(key)
self.nparray_counter += 1
self.db[KEY_NPARRAY_COUNTER] = self.nparray_counter
self.db[KEY_NPARRAY_KEYS] = self.nparray_keys
self.db[key] = d
self.db.commit()
np.save(d['fname'], nparray)
return True
def _getNPA(self, key):
d = self.db[key]
assert d['magic'] == MAGIC_SIGN_NPARRAY
fname = d['fname']
return np.load(fname)
return False
def newSubData(self, key):
"""
@ -347,6 +438,10 @@ class PersistentDataStructure(object):
if self.verbose > 1:
print("return subData stored as key", key, "using name", sub_db_name)
return PersistentDataStructure(name = sub_db_name, path = os.path.join(self._dirname) , verbose = self.verbose)
elif self.is_NPA(key):
if self.verbose > 1:
print("return nparray value")
return self._getNPA(key)
else:
if self.verbose > 1:
print("return normal value")
@ -384,7 +479,7 @@ class PersistentDataStructure(object):
def __len__(self):
self.need_open()
return len(self.db) - 2
return len(self.db) - len(RESERVED_KEYS)
# implements the iterator
def __iter__(self):
@ -409,21 +504,13 @@ class PersistentDataStructure(object):
# implements '[]' operator setter
def __setitem__(self, key, value):
self.need_open()
self.__check_key(key)
# if key in self.db:
# if self.__is_sub_data(self.db[key]):
# raise RuntimeWarning("values which hold sub_data structures can not be overwritten!")
# return None
if self.verbose > 1:
print("set", key, "to", value, "in", self._filename)
if isinstance(value, PersistentDataStructure):
self.setDataFromSubData(key, value)
else:
self.db[key] = value
self.db.commit()
self.setData(key, value, overwrite=True)
if self.verbose > 1:
print("set", key, "to", value, "in", self._filename)
# implements '[]' operator deletion
@ -437,6 +524,11 @@ class PersistentDataStructure(object):
self.sub_data_keys.remove(key)
self.db[KEY_SUB_DATA_KEYS] = self.sub_data_keys
elif self.is_NPA(key):
d = self.db[key]
assert d['magic'] == MAGIC_SIGN_NPARRAY
fname = d['fname']
os.remove(fname)
del self.db[key]
self.db.commit()

View file

@ -8,6 +8,8 @@ import os
from os.path import abspath, dirname, split, exists
from shutil import rmtree
import numpy as np
import warnings
warnings.filterwarnings('error')
@ -264,7 +266,8 @@ def test_remove_sub_data_and_check_len():
with sub_data.getData(key = 'subsub1', create_sub_data = True) as sub_sub_data:
sub_sub_data['t'] = 'hallo Welt'
assert len(sub_data) == 3
assert len(sub_data) == 3, "len = {}".format(len(sub_data))
@ -401,6 +404,25 @@ def test_not_in():
finally:
data.erase()
def test_npa():
a = np.linspace(0, 1, 100).reshape(10,10)
with PDS(name='data_npa', verbose=VERBOSE) as data:
data.clear()
data['a'] = a
with PDS(name='data_npa', verbose=VERBOSE) as data:
b = data['a']
assert np.all(b == a)
assert os.path.exists('__data_npa/0.npy')
with PDS(name='data_npa', verbose=VERBOSE) as data:
del data['a']
data['a'] = a
assert not os.path.exists('__data_npa/0.npy')
assert os.path.exists('__data_npa/1.npy')
if __name__ == "__main__":
test_reserved_key_catch()
@ -414,3 +436,4 @@ if __name__ == "__main__":
test_len()
test_clear()
test_not_in()
test_npa()