mirror of
https://github.com/vale981/stocproc
synced 2025-03-06 02:01:41 -05:00
implement stocproc Class dump load, cache result for x(t) via decorator, may need some own implementation within the class definition
This commit is contained in:
parent
8a247107ae
commit
8d71a2ee03
2 changed files with 146 additions and 27 deletions
70
stocproc.py
70
stocproc.py
|
@ -17,6 +17,7 @@
|
||||||
# along with this program; if not, write to the Free Software
|
# along with this program; if not, write to the Free Software
|
||||||
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
|
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
|
||||||
# MA 02110-1301, USA.
|
# MA 02110-1301, USA.
|
||||||
|
from copyreg import pickle
|
||||||
"""
|
"""
|
||||||
**Stochastic Process Module**
|
**Stochastic Process Module**
|
||||||
|
|
||||||
|
@ -58,6 +59,8 @@ solutions of the time discrete version.
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import time as tm
|
import time as tm
|
||||||
|
import functools
|
||||||
|
import pickle
|
||||||
|
|
||||||
class StocProc(object):
|
class StocProc(object):
|
||||||
r"""Simulate Stochastic Process using Karhunen-Loève expansion
|
r"""Simulate Stochastic Process using Karhunen-Loève expansion
|
||||||
|
@ -124,16 +127,30 @@ class StocProc(object):
|
||||||
Auflage: 3. ed. Cambridge University Press, Cambridge, UK ; New York. (pp. 989)
|
Auflage: 3. ed. Cambridge University Press, Cambridge, UK ; New York. (pp. 989)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def __init__(self, r_tau, t, w, seed = None, sig_min = 1e-4):
|
|
||||||
|
def __init__(self, r_tau=None, t=None, w=None, seed = None, sig_min = 1e-4, fname=None):
|
||||||
|
self.__dump_members = ['_r_tau',
|
||||||
|
'_s',
|
||||||
|
'_w',
|
||||||
|
'_eig_val',
|
||||||
|
'_eig_vec']
|
||||||
|
if fname == None:
|
||||||
|
|
||||||
|
assert r_tau != None
|
||||||
self._r_tau = r_tau
|
self._r_tau = r_tau
|
||||||
self._num_gp = len(t)
|
|
||||||
|
assert t != None
|
||||||
self._s = t
|
self._s = t
|
||||||
|
self._num_gp = len(self._s)
|
||||||
|
|
||||||
|
assert w != None
|
||||||
self._w = w
|
self._w = w
|
||||||
t_row = t.reshape(1, self._num_gp)
|
|
||||||
t_col = t.reshape(self._num_gp, 1)
|
t_row = self._s.reshape(1, self._num_gp)
|
||||||
|
t_col = self._s.reshape(self._num_gp, 1)
|
||||||
# correlation matrix
|
# correlation matrix
|
||||||
# r_tau(t-s) -> integral/sum over s -> s must be row in EV equation
|
# r_tau(t-s) -> integral/sum over s -> s must be row in EV equation
|
||||||
r = r_tau(t_col-t_row)
|
r = self._r_tau(t_col-t_row)
|
||||||
|
|
||||||
# solve discrete Fredholm equation
|
# solve discrete Fredholm equation
|
||||||
# eig_val = lambda
|
# eig_val = lambda
|
||||||
|
@ -141,10 +158,44 @@ class StocProc(object):
|
||||||
self._eig_val, self._eig_vec = solve_hom_fredholm(r, w, sig_min**2)
|
self._eig_val, self._eig_vec = solve_hom_fredholm(r, w, sig_min**2)
|
||||||
self._sqrt_eig_val = np.sqrt(self._eig_val)
|
self._sqrt_eig_val = np.sqrt(self._eig_val)
|
||||||
self._num_ev = len(self._eig_val)
|
self._num_ev = len(self._eig_val)
|
||||||
self._A = w.reshape(self._num_gp,1) * self._eig_vec / self._sqrt_eig_val.reshape(1, self._num_ev)
|
self._A = self._w.reshape(self._num_gp,1) * self._eig_vec / self._sqrt_eig_val.reshape(1, self._num_ev)
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.__load(fname)
|
||||||
|
self.__calc_missing()
|
||||||
|
|
||||||
self.new_process(seed)
|
self.new_process(seed)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def new_instance_with_trapezoidal_weights(_StocProc, r_tau, t_max, ng, seed, sig_min):
|
||||||
|
t, w = get_trapezoidal_weights_times(t_max, ng)
|
||||||
|
return _StocProc(r_tau, t, w, seed, sig_min)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def new_instance_with_mid_point_weights(_StocProc, r_tau, t_max, ng, seed, sig_min):
|
||||||
|
t, w = get_mid_point_weights(t_max, ng)
|
||||||
|
return _StocProc(r_tau, t, w, seed, sig_min)
|
||||||
|
|
||||||
|
def __load(self, fname):
|
||||||
|
with open(fname, 'rb') as f:
|
||||||
|
for m in self.__dump_members:
|
||||||
|
setattr(self, m, pickle.load(f))
|
||||||
|
|
||||||
|
def __calc_missing(self):
|
||||||
|
self._num_gp = len(self._s)
|
||||||
|
self._sqrt_eig_val = np.sqrt(self._eig_val)
|
||||||
|
self._num_ev = len(self._eig_val)
|
||||||
|
self._A = self._w.reshape(self._num_gp,1) * self._eig_vec / self._sqrt_eig_val.reshape(1, self._num_ev)
|
||||||
|
|
||||||
|
|
||||||
|
def __dump(self, fname):
|
||||||
|
with open(fname, 'wb') as f:
|
||||||
|
for m in self.__dump_members:
|
||||||
|
pickle.dump(getattr(self, m), f, pickle.HIGHEST_PROTOCOL)
|
||||||
|
|
||||||
|
def save_to_file(self, fname):
|
||||||
|
self.__dump(fname)
|
||||||
|
|
||||||
def new_process(self, seed = None):
|
def new_process(self, seed = None):
|
||||||
r"""setup new process
|
r"""setup new process
|
||||||
|
|
||||||
|
@ -171,13 +222,18 @@ class StocProc(object):
|
||||||
tmp = self._Y * self._sqrt_eig_val.reshape(self._num_ev,1)
|
tmp = self._Y * self._sqrt_eig_val.reshape(self._num_ev,1)
|
||||||
return np.tensordot(tmp, self._eig_vec, axes=([0],[1])).flatten()
|
return np.tensordot(tmp, self._eig_vec, axes=([0],[1])).flatten()
|
||||||
|
|
||||||
|
@functools.lru_cache(maxsize=1024, typed=False)
|
||||||
def x(self, t):
|
def x(self, t):
|
||||||
# self._Y # (N_ev, 1 )
|
# _Y # (N_ev, 1 )
|
||||||
tmp = self._Y*self._r_tau(t-self._s.reshape(1, self._num_gp))
|
tmp = self._Y*self._r_tau(t-self._s.reshape(1, self._num_gp))
|
||||||
# (N_ev, N_gp)
|
# (N_ev, N_gp)
|
||||||
# A # (N_gp, N_ev)
|
# A # (N_gp, N_ev)
|
||||||
return np.tensordot(tmp, self._A, axes=([1,0],[0,1]))
|
return np.tensordot(tmp, self._A, axes=([1,0],[0,1]))
|
||||||
|
|
||||||
|
def get_cache_info(self):
|
||||||
|
return self.x.cache_info()
|
||||||
|
|
||||||
|
|
||||||
def x_t_array(self, t_array):
|
def x_t_array(self, t_array):
|
||||||
t_array = t_array.reshape(1,1,len(t_array)) # (1 , 1 , N_t)
|
t_array = t_array.reshape(1,1,len(t_array)) # (1 , 1 , N_t)
|
||||||
tmp = (self._Y.reshape(self._num_ev,1,1) *
|
tmp = (self._Y.reshape(self._num_ev,1,1) *
|
||||||
|
|
|
@ -20,16 +20,21 @@
|
||||||
"""Test Suite for Stochastic Process Module stocproc.py
|
"""Test Suite for Stochastic Process Module stocproc.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import division
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import stocproc as sp
|
|
||||||
from scipy.special import gamma
|
from scipy.special import gamma
|
||||||
from scipy.interpolate import interp1d
|
from scipy.interpolate import interp1d
|
||||||
import time as tm
|
import time as tm
|
||||||
import pickle as pc
|
import pickle as pc
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
import functools
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
path = os.path.dirname(__file__)
|
||||||
|
sys.path.append(path)
|
||||||
|
|
||||||
|
import stocproc as sp
|
||||||
|
|
||||||
def corr(tau, s, gamma_s_plus_1):
|
def corr(tau, s, gamma_s_plus_1):
|
||||||
"""ohmic bath correlation function"""
|
"""ohmic bath correlation function"""
|
||||||
|
@ -186,9 +191,9 @@ def test_orthonomality():
|
||||||
idx1, idx2 = np.where(diff > diff_assert)
|
idx1, idx2 = np.where(diff > diff_assert)
|
||||||
|
|
||||||
if len(idx1) > 0:
|
if len(idx1) > 0:
|
||||||
print "orthonomality test FAILED at:"
|
print("orthonomality test FAILED at:")
|
||||||
for i in range(len(idx1)):
|
for i in range(len(idx1)):
|
||||||
print " ({}, {}) diff to unity matrix: {}".format(idx1[i],idx2[i], diff[idx1[i],idx2[i]])
|
print(" ({}, {}) diff to unity matrix: {}".format(idx1[i],idx2[i], diff[idx1[i],idx2[i]]))
|
||||||
raise Exception("test_orthonomality FAILED!")
|
raise Exception("test_orthonomality FAILED!")
|
||||||
|
|
||||||
def test_auto_grid_points():
|
def test_auto_grid_points():
|
||||||
|
@ -202,7 +207,63 @@ def test_auto_grid_points():
|
||||||
tol = 1e-16
|
tol = 1e-16
|
||||||
|
|
||||||
ng = sp.auto_grid_points(r_tau, t_max, ng_interpolation, tol)
|
ng = sp.auto_grid_points(r_tau, t_max, ng_interpolation, tol)
|
||||||
print ng
|
print(ng)
|
||||||
|
|
||||||
|
def test_chache():
|
||||||
|
s_param = 1
|
||||||
|
gamma_s_plus_1 = gamma(s_param+1)
|
||||||
|
r_tau = lambda tau : corr(tau, s_param, gamma_s_plus_1)
|
||||||
|
|
||||||
|
t_max = 10
|
||||||
|
ng = 50
|
||||||
|
seed = 0
|
||||||
|
sig_min = 1e-8
|
||||||
|
|
||||||
|
stocproc = sp.StocProc.new_instance_with_trapezoidal_weights(r_tau, t_max, ng, seed, sig_min)
|
||||||
|
|
||||||
|
t = {}
|
||||||
|
t[1] = 3
|
||||||
|
t[2] = 4
|
||||||
|
t[3] = 5
|
||||||
|
|
||||||
|
total = 0
|
||||||
|
misses = len(t.keys())
|
||||||
|
for t_i in t.keys():
|
||||||
|
for i in range(t[t_i]):
|
||||||
|
total += 1
|
||||||
|
stocproc.x(t_i)
|
||||||
|
|
||||||
|
ci = stocproc.get_cache_info()
|
||||||
|
assert ci.hits == total - misses
|
||||||
|
assert ci.misses == misses
|
||||||
|
|
||||||
|
def test_dump_load():
|
||||||
|
s_param = 1
|
||||||
|
gamma_s_plus_1 = gamma(s_param+1)
|
||||||
|
r_tau = functools.partial(corr, s=s_param, gamma_s_plus_1=gamma_s_plus_1)
|
||||||
|
|
||||||
|
t_max = 10
|
||||||
|
ng = 50
|
||||||
|
seed = 0
|
||||||
|
sig_min = 1e-8
|
||||||
|
|
||||||
|
stocproc = sp.StocProc.new_instance_with_trapezoidal_weights(r_tau, t_max, ng, seed, sig_min)
|
||||||
|
|
||||||
|
t = np.linspace(0,4,30)
|
||||||
|
|
||||||
|
x_t = stocproc.x_t_array(t)
|
||||||
|
|
||||||
|
fname = 'test_stocproc.dump'
|
||||||
|
|
||||||
|
stocproc.save_to_file(fname)
|
||||||
|
|
||||||
|
stocproc_2 = sp.StocProc(seed = seed, fname = fname)
|
||||||
|
x_t_2 = stocproc_2.x_t_array(t)
|
||||||
|
|
||||||
|
assert np.all(x_t == x_t_2)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def show_auto_grid_points_result():
|
def show_auto_grid_points_result():
|
||||||
s_param = 1
|
s_param = 1
|
||||||
|
@ -242,5 +303,7 @@ if __name__ == "__main__":
|
||||||
# test_stocProc_eigenfunction_extraction()
|
# test_stocProc_eigenfunction_extraction()
|
||||||
# test_orthonomality()
|
# test_orthonomality()
|
||||||
# test_auto_grid_points()
|
# test_auto_grid_points()
|
||||||
show_auto_grid_points_result()
|
# show_auto_grid_points_result()
|
||||||
|
# test_chache()
|
||||||
|
test_dump_load()
|
||||||
pass
|
pass
|
Loading…
Add table
Reference in a new issue