mirror of
https://github.com/vale981/binfootprint
synced 2025-03-05 09:11:40 -05:00
added scipy csc matrix support
This commit is contained in:
parent
d79b766d8e
commit
fae9b47ae8
1 changed files with 45 additions and 2 deletions
|
@ -38,6 +38,11 @@ from math import ceil
|
|||
import numpy as np
|
||||
import struct
|
||||
from sys import version_info
|
||||
try:
|
||||
import scipy
|
||||
from scipy.sparse import csc_matrix
|
||||
except ImportError:
|
||||
scipy = None
|
||||
|
||||
_spec_types = (bool, type(None))
|
||||
|
||||
|
@ -56,6 +61,7 @@ _GETSTATE = 0x0b # only used when __bfkey__ is not present
|
|||
_DICT = 0x0c
|
||||
_INT_NEG = 0x0d
|
||||
_BFKEY = 0x0e # a special BF-Key member __bfkey__ is used if implemented, uses __getstate__ as fallback
|
||||
_SP_CSC_MAT = 0x0f # scipy csc sparse matrix
|
||||
|
||||
_VERS = 0x80
|
||||
def getVersion():
|
||||
|
@ -367,7 +373,11 @@ def _dump_dict(ob):
|
|||
keys = ob.keys()
|
||||
bin_keys = []
|
||||
for k in keys:
|
||||
bin_keys.append( (_dump(k), _dump(ob[k])) )
|
||||
try:
|
||||
bin_keys.append( (_dump(k), _dump(ob[k])) )
|
||||
except:
|
||||
print("failed to dump key '{}'".format(k))
|
||||
raise
|
||||
b += _dump_list(sorted(bin_keys))
|
||||
return b
|
||||
|
||||
|
@ -382,6 +392,35 @@ def _load_dict(b, classes):
|
|||
|
||||
return res_dict, l+1
|
||||
|
||||
def _dump_scipy_csc_matrix(ob):
|
||||
b = init_BYTES([_SP_CSC_MAT])
|
||||
|
||||
b += _dump_np_array(ob.data)
|
||||
b += _dump_np_array(ob.indices)
|
||||
b += _dump_np_array(ob.indptr)
|
||||
b += _dump_tuple(ob.shape)
|
||||
|
||||
return b
|
||||
|
||||
def _load_scipy_csc_matrix(b):
|
||||
assert comp_id(b[0], _SP_CSC_MAT)
|
||||
l = 0
|
||||
data, _l = _load_np_array(b[1:])
|
||||
l += _l
|
||||
indices, _l = _load_np_array(b[1 + l:])
|
||||
l += _l
|
||||
indptr, _l = _load_np_array(b[1 + l:])
|
||||
l += _l
|
||||
shape, _l = _load_tuple(b[1 + l:], classes={})
|
||||
l += _l
|
||||
|
||||
return csc_matrix((data, indices, indptr), shape=shape), l+1
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def _dump(ob):
|
||||
if isinstance(ob, _spec_types):
|
||||
return _dump_spec(ob)
|
||||
|
@ -413,8 +452,10 @@ def _dump(ob):
|
|||
return _dump_bfkey(ob)
|
||||
elif hasattr(ob, '__getstate__'):
|
||||
return _dump_getstate(ob)
|
||||
elif scipy and scipy.sparse.isspmatrix_csc(ob):
|
||||
return _dump_scipy_csc_matrix(ob)
|
||||
else:
|
||||
raise TypeError("unsupported type for dump '{}'".format(type(ob)))
|
||||
raise TypeError("unsupported type for dump '{}' ({})".format(type(ob), ob))
|
||||
|
||||
def _load(b, classes):
|
||||
identifier = b[0]
|
||||
|
@ -448,6 +489,8 @@ def _load(b, classes):
|
|||
return _load_bfkey(b, classes)
|
||||
elif identifier == _GETSTATE:
|
||||
return _load_getstate(b, classes)
|
||||
elif identifier == _SP_CSC_MAT:
|
||||
return _load_scipy_csc_matrix(b)
|
||||
else:
|
||||
raise BFLoadError("internal error (unknown identifier '{}')".format(hex(identifier)))
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue