mirror of
https://github.com/vale981/stocproc
synced 2025-03-05 09:41:42 -05:00
make stocproc classes pickable with scale (resolved the scale error)
This commit is contained in:
parent
b90644f694
commit
db40aa2279
2 changed files with 56 additions and 17 deletions
|
@ -62,7 +62,7 @@ class _absStocProc(abc.ABC):
|
|||
|
||||
|
||||
"""
|
||||
def __init__(self, t_max=None, num_grid_points=None, seed=None, t_axis=None):
|
||||
def __init__(self, t_max=None, num_grid_points=None, seed=None, t_axis=None, scale=1):
|
||||
r"""
|
||||
:param t_max: specify time axis as [0, t_max], if None, the times must be explicitly
|
||||
given by t_axis
|
||||
|
@ -85,7 +85,8 @@ class _absStocProc(abc.ABC):
|
|||
np.random.seed(seed)
|
||||
self._one_over_sqrt_2 = 1/np.sqrt(2)
|
||||
self._proc_cnt = 0
|
||||
self.sqrt_scale = 1.
|
||||
self.scale = scale
|
||||
self.sqrt_scale = np.sqrt(self.scale)
|
||||
log.debug("init StocProc with t_max {} and {} grid points".format(t_max, num_grid_points))
|
||||
|
||||
def __call__(self, t=None):
|
||||
|
@ -161,6 +162,7 @@ class _absStocProc(abc.ABC):
|
|||
log.debug("created interpolator [{:.2e}s]".format(time.time() - t0))
|
||||
|
||||
def set_scale(self, scale):
|
||||
self.scale = scale
|
||||
self.sqrt_scale = np.sqrt(scale)
|
||||
|
||||
|
||||
|
@ -198,7 +200,7 @@ class StocProc_KLE(_absStocProc):
|
|||
"""
|
||||
|
||||
def __init__(self, r_tau, t_max, tol=1e-2, ng_fac=4, meth='fourpoint', diff_method='full', dm_random_samples=10**4,
|
||||
seed=None, align_eig_vec=False):
|
||||
seed=None, align_eig_vec=False, scale=1):
|
||||
"""
|
||||
:param r_tau: the idesired auto correlation function of a single parameter tau
|
||||
:param t_max: specifies the time interval [0, t_max] for which the processes in generated
|
||||
|
@ -241,23 +243,26 @@ class StocProc_KLE(_absStocProc):
|
|||
if align_eig_vec:
|
||||
method_kle.align_eig_vec(sqrt_lambda_ui_fine)
|
||||
|
||||
state = sqrt_lambda_ui_fine, t, seed
|
||||
state = sqrt_lambda_ui_fine, t, seed, scale
|
||||
self.__setstate__(state)
|
||||
|
||||
|
||||
def get_key(self):
|
||||
"""Returns the tuple (r_tau, t_max, tol) which should suffice to identify the process in order to load/dump
|
||||
the StocProc class.
|
||||
"""
|
||||
# def get_key(self):
|
||||
# """Returns the tuple (r_tau, t_max, tol) which should suffice to identify the process in order to load/dump
|
||||
# the StocProc class.
|
||||
# """
|
||||
# return self.key
|
||||
|
||||
def __bfkey__(self):
|
||||
return self.key
|
||||
|
||||
def __getstate__(self):
|
||||
return self.sqrt_lambda_ui_fine, self.t, self._seed
|
||||
return self.sqrt_lambda_ui_fine, self.t, self._seed, self.scale
|
||||
|
||||
def __setstate__(self, state):
|
||||
sqrt_lambda_ui_fine, t, seed = state
|
||||
sqrt_lambda_ui_fine, t, seed, scale = state
|
||||
num_ev, ng = sqrt_lambda_ui_fine.shape
|
||||
super().__init__(t_axis=t, seed=seed)
|
||||
super().__init__(t_axis=t, seed=seed, scale=scale)
|
||||
self.num_ev = num_ev
|
||||
self.sqrt_lambda_ui_fine = sqrt_lambda_ui_fine
|
||||
|
||||
|
@ -342,7 +347,7 @@ class StocProc_FFT(_absStocProc):
|
|||
|
||||
"""
|
||||
def __init__(self, spectral_density, t_max, bcf_ref, intgr_tol=1e-2, intpl_tol=1e-2,
|
||||
seed=None, negative_frequencies=False):
|
||||
seed=None, negative_frequencies=False, scale=1):
|
||||
self.key = bcf_ref, t_max, intgr_tol, intpl_tol
|
||||
|
||||
if not negative_frequencies:
|
||||
|
@ -396,7 +401,8 @@ class StocProc_FFT(_absStocProc):
|
|||
|
||||
super().__init__(t_max = t_max,
|
||||
num_grid_points = num_grid_points,
|
||||
seed = seed)
|
||||
seed = seed,
|
||||
scale = scale)
|
||||
|
||||
omega = dx*np.arange(N)
|
||||
self.yl = spectral_density(omega + a + dx/2) * dx / np.pi
|
||||
|
@ -404,13 +410,14 @@ class StocProc_FFT(_absStocProc):
|
|||
self.omega_min_correction = np.exp(-1j*(a+dx/2)*self.t) #self.t is from the parent class
|
||||
|
||||
def __getstate__(self):
|
||||
return self.yl, self.num_grid_points, self.omega_min_correction, self.t_max, self._seed
|
||||
return self.yl, self.num_grid_points, self.omega_min_correction, self.t_max, self._seed, self.scale
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.yl, num_grid_points, self.omega_min_correction, t_max, seed = state
|
||||
self.yl, num_grid_points, self.omega_min_correction, t_max, seed, scale = state
|
||||
super().__init__(t_max = t_max,
|
||||
num_grid_points = num_grid_points,
|
||||
seed = seed)
|
||||
seed = seed,
|
||||
scale = scale)
|
||||
|
||||
def calc_z(self, y):
|
||||
r"""calculate
|
||||
|
|
|
@ -236,6 +236,36 @@ def test_many(plot=False):
|
|||
stp = sp.StocProc_KLE(tol=5e-3, r_tau=ac, t_max=t_max, ng_fac=1, seed=0, diff_method='random', meth='fp')
|
||||
stocproc_metatest(stp, num_samples, tol, ac, plot)
|
||||
|
||||
def test_pickle_scale():
|
||||
t_max = 1
|
||||
tol = 0.1
|
||||
|
||||
sd = lsd
|
||||
ac = lac
|
||||
stp = sp.StocProc_FFT(sd, t_max, ac, negative_frequencies=True, seed=0, intgr_tol=tol, intpl_tol=tol)
|
||||
stp.set_scale(0.56)
|
||||
stp_dump = pickle.dumps(stp)
|
||||
stp_prime = pickle.loads(stp_dump)
|
||||
assert stp_prime.scale == 0.56
|
||||
|
||||
stp = sp.StocProc_FFT(sd, t_max, ac, negative_frequencies=True, seed=0, intgr_tol=tol, intpl_tol=tol, scale=0.56)
|
||||
stp_dump = pickle.dumps(stp)
|
||||
stp_prime = pickle.loads(stp_dump)
|
||||
assert stp_prime.scale == 0.56
|
||||
|
||||
stp = sp.StocProc_KLE(ac, t_max, tol=tol)
|
||||
stp.set_scale(0.56)
|
||||
stp_dump = pickle.dumps(stp)
|
||||
stp_prime = pickle.loads(stp_dump)
|
||||
assert stp_prime.scale == 0.56
|
||||
|
||||
stp = sp.StocProc_KLE(ac, t_max, tol=tol, scale=0.56)
|
||||
stp_dump = pickle.dumps(stp)
|
||||
stp_prime = pickle.loads(stp_dump)
|
||||
assert stp_prime.scale == 0.56
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import logging
|
||||
|
@ -244,5 +274,7 @@ if __name__ == "__main__":
|
|||
# test_stochastic_process_FFT_correlation_function(plot=False)
|
||||
# test_stocproc_dump_load()
|
||||
|
||||
test_many(plot=False)
|
||||
# test_many(plot=False)
|
||||
test_pickle_scale()
|
||||
pass
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue