mirror of
https://github.com/vale981/stocproc
synced 2025-03-05 09:41:42 -05:00
minor mods in stocproc.py
This commit is contained in:
parent
e0fecb7712
commit
f2da55ca65
1 changed files with 21 additions and 10 deletions
31
stocproc.py
31
stocproc.py
|
@ -135,6 +135,12 @@ class StocProc(object):
|
|||
|
||||
"""
|
||||
|
||||
_dump_members = ['_r_tau',
|
||||
'_s',
|
||||
'_w',
|
||||
'_eig_val',
|
||||
'_eig_vec']
|
||||
|
||||
def __init__(self,
|
||||
r_tau=None,
|
||||
t=None,
|
||||
|
@ -143,11 +149,7 @@ class StocProc(object):
|
|||
sig_min=1e-4,
|
||||
fname=None,
|
||||
cache_size=1024):
|
||||
self.__dump_members = ['_r_tau',
|
||||
'_s',
|
||||
'_w',
|
||||
'_eig_val',
|
||||
'_eig_vec']
|
||||
|
||||
if fname is None:
|
||||
|
||||
assert r_tau is not None
|
||||
|
@ -214,7 +216,7 @@ class StocProc(object):
|
|||
|
||||
def __load(self, fname):
|
||||
with open(fname, 'rb') as f:
|
||||
for m in self.__dump_members:
|
||||
for m in self._dump_members:
|
||||
setattr(self, m, pickle.load(f))
|
||||
|
||||
def __calc_missing(self):
|
||||
|
@ -226,8 +228,16 @@ class StocProc(object):
|
|||
|
||||
def __dump(self, fname):
|
||||
with open(fname, 'wb') as f:
|
||||
for m in self.__dump_members:
|
||||
for m in self._dump_members:
|
||||
pickle.dump(getattr(self, m), f, pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
def __getstate__(self):
|
||||
return [getattr(self, atr) for atr in self._dump_members]
|
||||
|
||||
def __setstate__(self, state):
|
||||
for i, atr_value in enumerate(state):
|
||||
setattr(self, self._dump_members[i], atr_value)
|
||||
self.__calc_missing()
|
||||
|
||||
def save_to_file(self, fname):
|
||||
self.__dump(fname)
|
||||
|
@ -420,15 +430,16 @@ def _mean_error(r_t_s, r_t_s_exact):
|
|||
|
||||
def _max_error(r_t_s, r_t_s_exact):
|
||||
return np.max(np.abs(r_t_s - r_t_s_exact))
|
||||
|
||||
def _max_rel_error(r_t_s, r_t_s_exact):
|
||||
return np.max(np.abs(r_t_s - r_t_s_exact) / np.abs(r_t_s_exact))
|
||||
|
||||
def auto_grid_points(r_tau, t_max, ng_interpolation, tol = 1e-8, err_method = _max_error, name = 'mid_point'):
|
||||
def auto_grid_points(r_tau, t_max, ng_interpolation, tol = 1e-8, err_method = _max_error, name = 'mid_point', sig_min=1e-4):
|
||||
err = 1
|
||||
ng = 1
|
||||
seed = None
|
||||
sig_min = 0
|
||||
t_large = np.linspace(0, t_max, ng_interpolation)
|
||||
print("start auto_grid_points, determine ng ...")
|
||||
|
||||
#exponential increase to get below error threshold
|
||||
while err > tol:
|
||||
ng *= 2
|
||||
|
|
Loading…
Add table
Reference in a new issue