added scipy csc matrix support

This commit is contained in:
Richard Hartmann 2018-11-23 11:08:02 +01:00
parent d79b766d8e
commit fae9b47ae8

View file

@ -38,6 +38,11 @@ from math import ceil
import numpy as np
import struct
from sys import version_info
try:
import scipy
from scipy.sparse import csc_matrix
except ImportError:
scipy = None
_spec_types = (bool, type(None))
@ -56,6 +61,7 @@ _GETSTATE = 0x0b # only used when __bfkey__ is not present
_DICT = 0x0c
_INT_NEG = 0x0d
_BFKEY = 0x0e # a special BF-Key member __bfkey__ is used if implemented, uses __getstate__ as fallback
_SP_CSC_MAT = 0x0f # scipy csc sparse matrix
_VERS = 0x80
def getVersion():
@ -367,7 +373,11 @@ def _dump_dict(ob):
keys = ob.keys()
bin_keys = []
for k in keys:
bin_keys.append( (_dump(k), _dump(ob[k])) )
try:
bin_keys.append( (_dump(k), _dump(ob[k])) )
except:
print("failed to dump key '{}'".format(k))
raise
b += _dump_list(sorted(bin_keys))
return b
@ -382,6 +392,35 @@ def _load_dict(b, classes):
return res_dict, l+1
def _dump_scipy_csc_matrix(ob):
b = init_BYTES([_SP_CSC_MAT])
b += _dump_np_array(ob.data)
b += _dump_np_array(ob.indices)
b += _dump_np_array(ob.indptr)
b += _dump_tuple(ob.shape)
return b
def _load_scipy_csc_matrix(b):
assert comp_id(b[0], _SP_CSC_MAT)
l = 0
data, _l = _load_np_array(b[1:])
l += _l
indices, _l = _load_np_array(b[1 + l:])
l += _l
indptr, _l = _load_np_array(b[1 + l:])
l += _l
shape, _l = _load_tuple(b[1 + l:], classes={})
l += _l
return csc_matrix((data, indices, indptr), shape=shape), l+1
def _dump(ob):
if isinstance(ob, _spec_types):
return _dump_spec(ob)
@ -413,8 +452,10 @@ def _dump(ob):
return _dump_bfkey(ob)
elif hasattr(ob, '__getstate__'):
return _dump_getstate(ob)
elif scipy and scipy.sparse.isspmatrix_csc(ob):
return _dump_scipy_csc_matrix(ob)
else:
raise TypeError("unsupported type for dump '{}'".format(type(ob)))
raise TypeError("unsupported type for dump '{}' ({})".format(type(ob), ob))
def _load(b, classes):
identifier = b[0]
@ -448,6 +489,8 @@ def _load(b, classes):
return _load_bfkey(b, classes)
elif identifier == _GETSTATE:
return _load_getstate(b, classes)
elif identifier == _SP_CSC_MAT:
return _load_scipy_csc_matrix(b)
else:
raise BFLoadError("internal error (unknown identifier '{}')".format(hex(identifier)))