make stocproc classes pickable with scale (resolved the scale error)

This commit is contained in:
Richard Hartmann 2016-12-09 15:31:37 +01:00
parent b90644f694
commit db40aa2279
2 changed files with 56 additions and 17 deletions

View file

@ -62,7 +62,7 @@ class _absStocProc(abc.ABC):
"""
def __init__(self, t_max=None, num_grid_points=None, seed=None, t_axis=None):
def __init__(self, t_max=None, num_grid_points=None, seed=None, t_axis=None, scale=1):
r"""
:param t_max: specify time axis as [0, t_max], if None, the times must be explicitly
given by t_axis
@ -85,7 +85,8 @@ class _absStocProc(abc.ABC):
np.random.seed(seed)
self._one_over_sqrt_2 = 1/np.sqrt(2)
self._proc_cnt = 0
self.sqrt_scale = 1.
self.scale = scale
self.sqrt_scale = np.sqrt(self.scale)
log.debug("init StocProc with t_max {} and {} grid points".format(t_max, num_grid_points))
def __call__(self, t=None):
@ -161,6 +162,7 @@ class _absStocProc(abc.ABC):
log.debug("created interpolator [{:.2e}s]".format(time.time() - t0))
def set_scale(self, scale):
self.scale = scale
self.sqrt_scale = np.sqrt(scale)
@ -198,7 +200,7 @@ class StocProc_KLE(_absStocProc):
"""
def __init__(self, r_tau, t_max, tol=1e-2, ng_fac=4, meth='fourpoint', diff_method='full', dm_random_samples=10**4,
seed=None, align_eig_vec=False):
seed=None, align_eig_vec=False, scale=1):
"""
:param r_tau: the idesired auto correlation function of a single parameter tau
:param t_max: specifies the time interval [0, t_max] for which the processes in generated
@ -241,23 +243,26 @@ class StocProc_KLE(_absStocProc):
if align_eig_vec:
method_kle.align_eig_vec(sqrt_lambda_ui_fine)
state = sqrt_lambda_ui_fine, t, seed
state = sqrt_lambda_ui_fine, t, seed, scale
self.__setstate__(state)
def get_key(self):
"""Returns the tuple (r_tau, t_max, tol) which should suffice to identify the process in order to load/dump
the StocProc class.
"""
# def get_key(self):
# """Returns the tuple (r_tau, t_max, tol) which should suffice to identify the process in order to load/dump
# the StocProc class.
# """
# return self.key
def __bfkey__(self):
return self.key
def __getstate__(self):
return self.sqrt_lambda_ui_fine, self.t, self._seed
return self.sqrt_lambda_ui_fine, self.t, self._seed, self.scale
def __setstate__(self, state):
sqrt_lambda_ui_fine, t, seed = state
sqrt_lambda_ui_fine, t, seed, scale = state
num_ev, ng = sqrt_lambda_ui_fine.shape
super().__init__(t_axis=t, seed=seed)
super().__init__(t_axis=t, seed=seed, scale=scale)
self.num_ev = num_ev
self.sqrt_lambda_ui_fine = sqrt_lambda_ui_fine
@ -342,7 +347,7 @@ class StocProc_FFT(_absStocProc):
"""
def __init__(self, spectral_density, t_max, bcf_ref, intgr_tol=1e-2, intpl_tol=1e-2,
seed=None, negative_frequencies=False):
seed=None, negative_frequencies=False, scale=1):
self.key = bcf_ref, t_max, intgr_tol, intpl_tol
if not negative_frequencies:
@ -396,7 +401,8 @@ class StocProc_FFT(_absStocProc):
super().__init__(t_max = t_max,
num_grid_points = num_grid_points,
seed = seed)
seed = seed,
scale = scale)
omega = dx*np.arange(N)
self.yl = spectral_density(omega + a + dx/2) * dx / np.pi
@ -404,13 +410,14 @@ class StocProc_FFT(_absStocProc):
self.omega_min_correction = np.exp(-1j*(a+dx/2)*self.t) #self.t is from the parent class
def __getstate__(self):
return self.yl, self.num_grid_points, self.omega_min_correction, self.t_max, self._seed
return self.yl, self.num_grid_points, self.omega_min_correction, self.t_max, self._seed, self.scale
def __setstate__(self, state):
self.yl, num_grid_points, self.omega_min_correction, t_max, seed = state
self.yl, num_grid_points, self.omega_min_correction, t_max, seed, scale = state
super().__init__(t_max = t_max,
num_grid_points = num_grid_points,
seed = seed)
seed = seed,
scale = scale)
def calc_z(self, y):
r"""calculate

View file

@ -236,6 +236,36 @@ def test_many(plot=False):
stp = sp.StocProc_KLE(tol=5e-3, r_tau=ac, t_max=t_max, ng_fac=1, seed=0, diff_method='random', meth='fp')
stocproc_metatest(stp, num_samples, tol, ac, plot)
def test_pickle_scale():
t_max = 1
tol = 0.1
sd = lsd
ac = lac
stp = sp.StocProc_FFT(sd, t_max, ac, negative_frequencies=True, seed=0, intgr_tol=tol, intpl_tol=tol)
stp.set_scale(0.56)
stp_dump = pickle.dumps(stp)
stp_prime = pickle.loads(stp_dump)
assert stp_prime.scale == 0.56
stp = sp.StocProc_FFT(sd, t_max, ac, negative_frequencies=True, seed=0, intgr_tol=tol, intpl_tol=tol, scale=0.56)
stp_dump = pickle.dumps(stp)
stp_prime = pickle.loads(stp_dump)
assert stp_prime.scale == 0.56
stp = sp.StocProc_KLE(ac, t_max, tol=tol)
stp.set_scale(0.56)
stp_dump = pickle.dumps(stp)
stp_prime = pickle.loads(stp_dump)
assert stp_prime.scale == 0.56
stp = sp.StocProc_KLE(ac, t_max, tol=tol, scale=0.56)
stp_dump = pickle.dumps(stp)
stp_prime = pickle.loads(stp_dump)
assert stp_prime.scale == 0.56
if __name__ == "__main__":
import logging
@ -244,5 +274,7 @@ if __name__ == "__main__":
# test_stochastic_process_FFT_correlation_function(plot=False)
# test_stocproc_dump_load()
test_many(plot=False)
# test_many(plot=False)
test_pickle_scale()
pass