mirror of
https://github.com/vale981/stocproc
synced 2025-03-06 02:01:41 -05:00
minor mods in stocproc.py
This commit is contained in:
parent
e0fecb7712
commit
f2da55ca65
1 changed files with 21 additions and 10 deletions
31
stocproc.py
31
stocproc.py
|
@ -135,6 +135,12 @@ class StocProc(object):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_dump_members = ['_r_tau',
|
||||||
|
'_s',
|
||||||
|
'_w',
|
||||||
|
'_eig_val',
|
||||||
|
'_eig_vec']
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
r_tau=None,
|
r_tau=None,
|
||||||
t=None,
|
t=None,
|
||||||
|
@ -143,11 +149,7 @@ class StocProc(object):
|
||||||
sig_min=1e-4,
|
sig_min=1e-4,
|
||||||
fname=None,
|
fname=None,
|
||||||
cache_size=1024):
|
cache_size=1024):
|
||||||
self.__dump_members = ['_r_tau',
|
|
||||||
'_s',
|
|
||||||
'_w',
|
|
||||||
'_eig_val',
|
|
||||||
'_eig_vec']
|
|
||||||
if fname is None:
|
if fname is None:
|
||||||
|
|
||||||
assert r_tau is not None
|
assert r_tau is not None
|
||||||
|
@ -214,7 +216,7 @@ class StocProc(object):
|
||||||
|
|
||||||
def __load(self, fname):
|
def __load(self, fname):
|
||||||
with open(fname, 'rb') as f:
|
with open(fname, 'rb') as f:
|
||||||
for m in self.__dump_members:
|
for m in self._dump_members:
|
||||||
setattr(self, m, pickle.load(f))
|
setattr(self, m, pickle.load(f))
|
||||||
|
|
||||||
def __calc_missing(self):
|
def __calc_missing(self):
|
||||||
|
@ -226,9 +228,17 @@ class StocProc(object):
|
||||||
|
|
||||||
def __dump(self, fname):
|
def __dump(self, fname):
|
||||||
with open(fname, 'wb') as f:
|
with open(fname, 'wb') as f:
|
||||||
for m in self.__dump_members:
|
for m in self._dump_members:
|
||||||
pickle.dump(getattr(self, m), f, pickle.HIGHEST_PROTOCOL)
|
pickle.dump(getattr(self, m), f, pickle.HIGHEST_PROTOCOL)
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
return [getattr(self, atr) for atr in self._dump_members]
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
for i, atr_value in enumerate(state):
|
||||||
|
setattr(self, self._dump_members[i], atr_value)
|
||||||
|
self.__calc_missing()
|
||||||
|
|
||||||
def save_to_file(self, fname):
|
def save_to_file(self, fname):
|
||||||
self.__dump(fname)
|
self.__dump(fname)
|
||||||
|
|
||||||
|
@ -421,14 +431,15 @@ def _mean_error(r_t_s, r_t_s_exact):
|
||||||
def _max_error(r_t_s, r_t_s_exact):
|
def _max_error(r_t_s, r_t_s_exact):
|
||||||
return np.max(np.abs(r_t_s - r_t_s_exact))
|
return np.max(np.abs(r_t_s - r_t_s_exact))
|
||||||
|
|
||||||
def auto_grid_points(r_tau, t_max, ng_interpolation, tol = 1e-8, err_method = _max_error, name = 'mid_point'):
|
def _max_rel_error(r_t_s, r_t_s_exact):
|
||||||
|
return np.max(np.abs(r_t_s - r_t_s_exact) / np.abs(r_t_s_exact))
|
||||||
|
|
||||||
|
def auto_grid_points(r_tau, t_max, ng_interpolation, tol = 1e-8, err_method = _max_error, name = 'mid_point', sig_min=1e-4):
|
||||||
err = 1
|
err = 1
|
||||||
ng = 1
|
ng = 1
|
||||||
seed = None
|
seed = None
|
||||||
sig_min = 0
|
|
||||||
t_large = np.linspace(0, t_max, ng_interpolation)
|
t_large = np.linspace(0, t_max, ng_interpolation)
|
||||||
print("start auto_grid_points, determine ng ...")
|
print("start auto_grid_points, determine ng ...")
|
||||||
|
|
||||||
#exponential increase to get below error threshold
|
#exponential increase to get below error threshold
|
||||||
while err > tol:
|
while err > tol:
|
||||||
ng *= 2
|
ng *= 2
|
||||||
|
|
Loading…
Add table
Reference in a new issue