mirror of
https://github.com/vale981/stocproc
synced 2025-03-05 09:41:42 -05:00
implement stocproc Class dump load, cache result for x(t) via decorator, may need some own implementation within the class definition
This commit is contained in:
parent
8a247107ae
commit
8d71a2ee03
2 changed files with 146 additions and 27 deletions
96
stocproc.py
96
stocproc.py
|
@ -17,6 +17,7 @@
|
|||
# along with this program; if not, write to the Free Software
|
||||
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
|
||||
# MA 02110-1301, USA.
|
||||
from copyreg import pickle
|
||||
"""
|
||||
**Stochastic Process Module**
|
||||
|
||||
|
@ -58,6 +59,8 @@ solutions of the time discrete version.
|
|||
|
||||
import numpy as np
|
||||
import time as tm
|
||||
import functools
|
||||
import pickle
|
||||
|
||||
class StocProc(object):
|
||||
r"""Simulate Stochastic Process using Karhunen-Loève expansion
|
||||
|
@ -124,26 +127,74 @@ class StocProc(object):
|
|||
Auflage: 3. ed. Cambridge University Press, Cambridge, UK ; New York. (pp. 989)
|
||||
|
||||
"""
|
||||
def __init__(self, r_tau, t, w, seed = None, sig_min = 1e-4):
|
||||
self._r_tau = r_tau
|
||||
self._num_gp = len(t)
|
||||
self._s = t
|
||||
self._w = w
|
||||
t_row = t.reshape(1, self._num_gp)
|
||||
t_col = t.reshape(self._num_gp, 1)
|
||||
# correlation matrix
|
||||
# r_tau(t-s) -> integral/sum over s -> s must be row in EV equation
|
||||
r = r_tau(t_col-t_row)
|
||||
|
||||
def __init__(self, r_tau=None, t=None, w=None, seed = None, sig_min = 1e-4, fname=None):
|
||||
self.__dump_members = ['_r_tau',
|
||||
'_s',
|
||||
'_w',
|
||||
'_eig_val',
|
||||
'_eig_vec']
|
||||
if fname == None:
|
||||
|
||||
assert r_tau != None
|
||||
self._r_tau = r_tau
|
||||
|
||||
assert t != None
|
||||
self._s = t
|
||||
self._num_gp = len(self._s)
|
||||
|
||||
assert w != None
|
||||
self._w = w
|
||||
|
||||
t_row = self._s.reshape(1, self._num_gp)
|
||||
t_col = self._s.reshape(self._num_gp, 1)
|
||||
# correlation matrix
|
||||
# r_tau(t-s) -> integral/sum over s -> s must be row in EV equation
|
||||
r = self._r_tau(t_col-t_row)
|
||||
|
||||
# solve discrete Fredholm equation
|
||||
# eig_val = lambda
|
||||
# eig_vec = u(t)
|
||||
self._eig_val, self._eig_vec = solve_hom_fredholm(r, w, sig_min**2)
|
||||
self._sqrt_eig_val = np.sqrt(self._eig_val)
|
||||
self._num_ev = len(self._eig_val)
|
||||
self._A = self._w.reshape(self._num_gp,1) * self._eig_vec / self._sqrt_eig_val.reshape(1, self._num_ev)
|
||||
|
||||
# solve discrete Fredholm equation
|
||||
# eig_val = lambda
|
||||
# eig_vec = u(t)
|
||||
self._eig_val, self._eig_vec = solve_hom_fredholm(r, w, sig_min**2)
|
||||
else:
|
||||
self.__load(fname)
|
||||
self.__calc_missing()
|
||||
|
||||
self.new_process(seed)
|
||||
|
||||
@classmethod
|
||||
def new_instance_with_trapezoidal_weights(_StocProc, r_tau, t_max, ng, seed, sig_min):
|
||||
t, w = get_trapezoidal_weights_times(t_max, ng)
|
||||
return _StocProc(r_tau, t, w, seed, sig_min)
|
||||
|
||||
@classmethod
|
||||
def new_instance_with_mid_point_weights(_StocProc, r_tau, t_max, ng, seed, sig_min):
|
||||
t, w = get_mid_point_weights(t_max, ng)
|
||||
return _StocProc(r_tau, t, w, seed, sig_min)
|
||||
|
||||
def __load(self, fname):
|
||||
with open(fname, 'rb') as f:
|
||||
for m in self.__dump_members:
|
||||
setattr(self, m, pickle.load(f))
|
||||
|
||||
def __calc_missing(self):
|
||||
self._num_gp = len(self._s)
|
||||
self._sqrt_eig_val = np.sqrt(self._eig_val)
|
||||
self._num_ev = len(self._eig_val)
|
||||
self._A = w.reshape(self._num_gp,1) * self._eig_vec / self._sqrt_eig_val.reshape(1, self._num_ev)
|
||||
self._A = self._w.reshape(self._num_gp,1) * self._eig_vec / self._sqrt_eig_val.reshape(1, self._num_ev)
|
||||
|
||||
self.new_process(seed)
|
||||
|
||||
def __dump(self, fname):
|
||||
with open(fname, 'wb') as f:
|
||||
for m in self.__dump_members:
|
||||
pickle.dump(getattr(self, m), f, pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
def save_to_file(self, fname):
|
||||
self.__dump(fname)
|
||||
|
||||
def new_process(self, seed = None):
|
||||
r"""setup new process
|
||||
|
@ -170,13 +221,18 @@ class StocProc(object):
|
|||
"""
|
||||
tmp = self._Y * self._sqrt_eig_val.reshape(self._num_ev,1)
|
||||
return np.tensordot(tmp, self._eig_vec, axes=([0],[1])).flatten()
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1024, typed=False)
|
||||
def x(self, t):
|
||||
# self._Y # (N_ev, 1 )
|
||||
# _Y # (N_ev, 1 )
|
||||
tmp = self._Y*self._r_tau(t-self._s.reshape(1, self._num_gp))
|
||||
# (N_ev, N_gp)
|
||||
# A # (N_gp, N_ev)
|
||||
# (N_ev, N_gp)
|
||||
# A # (N_gp, N_ev)
|
||||
return np.tensordot(tmp, self._A, axes=([1,0],[0,1]))
|
||||
|
||||
def get_cache_info(self):
|
||||
return self.x.cache_info()
|
||||
|
||||
|
||||
def x_t_array(self, t_array):
|
||||
t_array = t_array.reshape(1,1,len(t_array)) # (1 , 1 , N_t)
|
||||
|
|
|
@ -20,16 +20,21 @@
|
|||
"""Test Suite for Stochastic Process Module stocproc.py
|
||||
"""
|
||||
|
||||
from __future__ import division
|
||||
import numpy as np
|
||||
import stocproc as sp
|
||||
from scipy.special import gamma
|
||||
from scipy.interpolate import interp1d
|
||||
import time as tm
|
||||
import pickle as pc
|
||||
import matplotlib.pyplot as plt
|
||||
import functools
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import sys
|
||||
import os
|
||||
|
||||
path = os.path.dirname(__file__)
|
||||
sys.path.append(path)
|
||||
|
||||
import stocproc as sp
|
||||
|
||||
def corr(tau, s, gamma_s_plus_1):
|
||||
"""ohmic bath correlation function"""
|
||||
|
@ -186,9 +191,9 @@ def test_orthonomality():
|
|||
idx1, idx2 = np.where(diff > diff_assert)
|
||||
|
||||
if len(idx1) > 0:
|
||||
print "orthonomality test FAILED at:"
|
||||
print("orthonomality test FAILED at:")
|
||||
for i in range(len(idx1)):
|
||||
print " ({}, {}) diff to unity matrix: {}".format(idx1[i],idx2[i], diff[idx1[i],idx2[i]])
|
||||
print(" ({}, {}) diff to unity matrix: {}".format(idx1[i],idx2[i], diff[idx1[i],idx2[i]]))
|
||||
raise Exception("test_orthonomality FAILED!")
|
||||
|
||||
def test_auto_grid_points():
|
||||
|
@ -202,7 +207,63 @@ def test_auto_grid_points():
|
|||
tol = 1e-16
|
||||
|
||||
ng = sp.auto_grid_points(r_tau, t_max, ng_interpolation, tol)
|
||||
print ng
|
||||
print(ng)
|
||||
|
||||
def test_chache():
|
||||
s_param = 1
|
||||
gamma_s_plus_1 = gamma(s_param+1)
|
||||
r_tau = lambda tau : corr(tau, s_param, gamma_s_plus_1)
|
||||
|
||||
t_max = 10
|
||||
ng = 50
|
||||
seed = 0
|
||||
sig_min = 1e-8
|
||||
|
||||
stocproc = sp.StocProc.new_instance_with_trapezoidal_weights(r_tau, t_max, ng, seed, sig_min)
|
||||
|
||||
t = {}
|
||||
t[1] = 3
|
||||
t[2] = 4
|
||||
t[3] = 5
|
||||
|
||||
total = 0
|
||||
misses = len(t.keys())
|
||||
for t_i in t.keys():
|
||||
for i in range(t[t_i]):
|
||||
total += 1
|
||||
stocproc.x(t_i)
|
||||
|
||||
ci = stocproc.get_cache_info()
|
||||
assert ci.hits == total - misses
|
||||
assert ci.misses == misses
|
||||
|
||||
def test_dump_load():
|
||||
s_param = 1
|
||||
gamma_s_plus_1 = gamma(s_param+1)
|
||||
r_tau = functools.partial(corr, s=s_param, gamma_s_plus_1=gamma_s_plus_1)
|
||||
|
||||
t_max = 10
|
||||
ng = 50
|
||||
seed = 0
|
||||
sig_min = 1e-8
|
||||
|
||||
stocproc = sp.StocProc.new_instance_with_trapezoidal_weights(r_tau, t_max, ng, seed, sig_min)
|
||||
|
||||
t = np.linspace(0,4,30)
|
||||
|
||||
x_t = stocproc.x_t_array(t)
|
||||
|
||||
fname = 'test_stocproc.dump'
|
||||
|
||||
stocproc.save_to_file(fname)
|
||||
|
||||
stocproc_2 = sp.StocProc(seed = seed, fname = fname)
|
||||
x_t_2 = stocproc_2.x_t_array(t)
|
||||
|
||||
assert np.all(x_t == x_t_2)
|
||||
|
||||
|
||||
|
||||
|
||||
def show_auto_grid_points_result():
|
||||
s_param = 1
|
||||
|
@ -242,5 +303,7 @@ if __name__ == "__main__":
|
|||
# test_stocProc_eigenfunction_extraction()
|
||||
# test_orthonomality()
|
||||
# test_auto_grid_points()
|
||||
show_auto_grid_points_result()
|
||||
# show_auto_grid_points_result()
|
||||
# test_chache()
|
||||
test_dump_load()
|
||||
pass
|
Loading…
Add table
Reference in a new issue