implement stocproc Class dump load, cache result for x(t) via decorator, may need some own implementation within the class definition

This commit is contained in:
Richard Hartmann 2014-09-17 17:17:17 +02:00
parent 8a247107ae
commit 8d71a2ee03
2 changed files with 146 additions and 27 deletions

View file

@ -17,6 +17,7 @@
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
# MA 02110-1301, USA.
from copyreg import pickle
"""
**Stochastic Process Module**
@ -58,6 +59,8 @@ solutions of the time discrete version.
import numpy as np
import time as tm
import functools
import pickle
class StocProc(object):
r"""Simulate Stochastic Process using Karhunen-Loève expansion
@ -124,26 +127,74 @@ class StocProc(object):
Auflage: 3. ed. Cambridge University Press, Cambridge, UK ; New York. (pp. 989)
"""
def __init__(self, r_tau, t, w, seed = None, sig_min = 1e-4):
self._r_tau = r_tau
self._num_gp = len(t)
self._s = t
self._w = w
t_row = t.reshape(1, self._num_gp)
t_col = t.reshape(self._num_gp, 1)
# correlation matrix
# r_tau(t-s) -> integral/sum over s -> s must be row in EV equation
r = r_tau(t_col-t_row)
def __init__(self, r_tau=None, t=None, w=None, seed = None, sig_min = 1e-4, fname=None):
self.__dump_members = ['_r_tau',
'_s',
'_w',
'_eig_val',
'_eig_vec']
if fname == None:
assert r_tau != None
self._r_tau = r_tau
assert t != None
self._s = t
self._num_gp = len(self._s)
assert w != None
self._w = w
t_row = self._s.reshape(1, self._num_gp)
t_col = self._s.reshape(self._num_gp, 1)
# correlation matrix
# r_tau(t-s) -> integral/sum over s -> s must be row in EV equation
r = self._r_tau(t_col-t_row)
# solve discrete Fredholm equation
# eig_val = lambda
# eig_vec = u(t)
self._eig_val, self._eig_vec = solve_hom_fredholm(r, w, sig_min**2)
self._sqrt_eig_val = np.sqrt(self._eig_val)
self._num_ev = len(self._eig_val)
self._A = self._w.reshape(self._num_gp,1) * self._eig_vec / self._sqrt_eig_val.reshape(1, self._num_ev)
# solve discrete Fredholm equation
# eig_val = lambda
# eig_vec = u(t)
self._eig_val, self._eig_vec = solve_hom_fredholm(r, w, sig_min**2)
else:
self.__load(fname)
self.__calc_missing()
self.new_process(seed)
@classmethod
def new_instance_with_trapezoidal_weights(_StocProc, r_tau, t_max, ng, seed, sig_min):
t, w = get_trapezoidal_weights_times(t_max, ng)
return _StocProc(r_tau, t, w, seed, sig_min)
@classmethod
def new_instance_with_mid_point_weights(_StocProc, r_tau, t_max, ng, seed, sig_min):
t, w = get_mid_point_weights(t_max, ng)
return _StocProc(r_tau, t, w, seed, sig_min)
def __load(self, fname):
with open(fname, 'rb') as f:
for m in self.__dump_members:
setattr(self, m, pickle.load(f))
def __calc_missing(self):
self._num_gp = len(self._s)
self._sqrt_eig_val = np.sqrt(self._eig_val)
self._num_ev = len(self._eig_val)
self._A = w.reshape(self._num_gp,1) * self._eig_vec / self._sqrt_eig_val.reshape(1, self._num_ev)
self._A = self._w.reshape(self._num_gp,1) * self._eig_vec / self._sqrt_eig_val.reshape(1, self._num_ev)
self.new_process(seed)
def __dump(self, fname):
with open(fname, 'wb') as f:
for m in self.__dump_members:
pickle.dump(getattr(self, m), f, pickle.HIGHEST_PROTOCOL)
def save_to_file(self, fname):
self.__dump(fname)
def new_process(self, seed = None):
r"""setup new process
@ -170,13 +221,18 @@ class StocProc(object):
"""
tmp = self._Y * self._sqrt_eig_val.reshape(self._num_ev,1)
return np.tensordot(tmp, self._eig_vec, axes=([0],[1])).flatten()
@functools.lru_cache(maxsize=1024, typed=False)
def x(self, t):
# self._Y # (N_ev, 1 )
# _Y # (N_ev, 1 )
tmp = self._Y*self._r_tau(t-self._s.reshape(1, self._num_gp))
# (N_ev, N_gp)
# A # (N_gp, N_ev)
# (N_ev, N_gp)
# A # (N_gp, N_ev)
return np.tensordot(tmp, self._A, axes=([1,0],[0,1]))
def get_cache_info(self):
return self.x.cache_info()
def x_t_array(self, t_array):
t_array = t_array.reshape(1,1,len(t_array)) # (1 , 1 , N_t)

View file

@ -20,16 +20,21 @@
"""Test Suite for Stochastic Process Module stocproc.py
"""
from __future__ import division
import numpy as np
import stocproc as sp
from scipy.special import gamma
from scipy.interpolate import interp1d
import time as tm
import pickle as pc
import matplotlib.pyplot as plt
import functools
import matplotlib.pyplot as plt
import sys
import os
path = os.path.dirname(__file__)
sys.path.append(path)
import stocproc as sp
def corr(tau, s, gamma_s_plus_1):
"""ohmic bath correlation function"""
@ -186,9 +191,9 @@ def test_orthonomality():
idx1, idx2 = np.where(diff > diff_assert)
if len(idx1) > 0:
print "orthonomality test FAILED at:"
print("orthonomality test FAILED at:")
for i in range(len(idx1)):
print " ({}, {}) diff to unity matrix: {}".format(idx1[i],idx2[i], diff[idx1[i],idx2[i]])
print(" ({}, {}) diff to unity matrix: {}".format(idx1[i],idx2[i], diff[idx1[i],idx2[i]]))
raise Exception("test_orthonomality FAILED!")
def test_auto_grid_points():
@ -202,7 +207,63 @@ def test_auto_grid_points():
tol = 1e-16
ng = sp.auto_grid_points(r_tau, t_max, ng_interpolation, tol)
print ng
print(ng)
def test_chache():
s_param = 1
gamma_s_plus_1 = gamma(s_param+1)
r_tau = lambda tau : corr(tau, s_param, gamma_s_plus_1)
t_max = 10
ng = 50
seed = 0
sig_min = 1e-8
stocproc = sp.StocProc.new_instance_with_trapezoidal_weights(r_tau, t_max, ng, seed, sig_min)
t = {}
t[1] = 3
t[2] = 4
t[3] = 5
total = 0
misses = len(t.keys())
for t_i in t.keys():
for i in range(t[t_i]):
total += 1
stocproc.x(t_i)
ci = stocproc.get_cache_info()
assert ci.hits == total - misses
assert ci.misses == misses
def test_dump_load():
s_param = 1
gamma_s_plus_1 = gamma(s_param+1)
r_tau = functools.partial(corr, s=s_param, gamma_s_plus_1=gamma_s_plus_1)
t_max = 10
ng = 50
seed = 0
sig_min = 1e-8
stocproc = sp.StocProc.new_instance_with_trapezoidal_weights(r_tau, t_max, ng, seed, sig_min)
t = np.linspace(0,4,30)
x_t = stocproc.x_t_array(t)
fname = 'test_stocproc.dump'
stocproc.save_to_file(fname)
stocproc_2 = sp.StocProc(seed = seed, fname = fname)
x_t_2 = stocproc_2.x_t_array(t)
assert np.all(x_t == x_t_2)
def show_auto_grid_points_result():
s_param = 1
@ -242,5 +303,7 @@ if __name__ == "__main__":
# test_stocProc_eigenfunction_extraction()
# test_orthonomality()
# test_auto_grid_points()
show_auto_grid_points_result()
# show_auto_grid_points_result()
# test_chache()
test_dump_load()
pass