mirror of
https://github.com/vale981/stocproc
synced 2025-03-05 09:41:42 -05:00
try to speed up fft init
This commit is contained in:
parent
e0a1aa6519
commit
c03ac939ad
2 changed files with 89 additions and 30 deletions
|
@ -164,7 +164,7 @@ def _absDiff(xRef, x):
|
|||
return np.max(np.abs(xRef - x))
|
||||
|
||||
|
||||
def _f_opt(x, integrand, a, b, N, t_max, ft_ref, diff_method, _f_opt_cache, b_only):
|
||||
def _f_opt_for_SLSQP_minimizer(x, integrand, a, b, N, t_max, ft_ref, diff_method, _f_opt_cache, b_only):
|
||||
key = float(x[0])
|
||||
if key in _f_opt_cache:
|
||||
d, a_, b_ = _f_opt_cache[key]
|
||||
|
@ -196,7 +196,7 @@ def _f_opt(x, integrand, a, b, N, t_max, ft_ref, diff_method, _f_opt_cache, b_on
|
|||
ft_ref_tau = ft_ref(tau[idx])
|
||||
d = diff_method(ft_ref_tau, ft_tau[idx])
|
||||
_f_opt_cache[key] = d, a_, b_
|
||||
#log.debug("f_opt tol {} -> d {}".format(10**x, d))
|
||||
log.info("f_opt tol {} -> d {}".format(tol, d))
|
||||
return np.log10(d)
|
||||
|
||||
def _lower_contrs(x, integrand, a, b, N, t_max, ft_ref, diff_method, _f_opt_cache, b_only):
|
||||
|
@ -214,21 +214,71 @@ def _upper_contrs(x):
|
|||
return -x
|
||||
|
||||
|
||||
def opt_integral_boundaries(integrand, a, b, t_max, ft_ref, opt_b_only, N, diff_method):
|
||||
log.debug("optimize integral boundary N:{} [{:.3e},{:.3e}]".format(N, a, b))
|
||||
def _f_opt(x, integrand, a, b, N, t_max, ft_ref, diff_method, b_only):
|
||||
tol = x
|
||||
|
||||
if b_only:
|
||||
a_ = a
|
||||
b_ = find_integral_boundary(integrand, tol=tol, ref_val=b, max_val=1e6, x0=1)
|
||||
else:
|
||||
a_ = find_integral_boundary(integrand, tol=tol, ref_val=a, max_val=1e6, x0=-1)
|
||||
b_ = find_integral_boundary(integrand, tol=tol, ref_val=b, max_val=1e6, x0=1)
|
||||
|
||||
tau, ft_tau = fourier_integral_midpoint(integrand, a_, b_, N)
|
||||
idx = np.where(tau <= t_max)
|
||||
ft_ref_tau = ft_ref(tau[idx])
|
||||
d = diff_method(ft_ref_tau, ft_tau[idx])
|
||||
log.info("f_opt interval [{:.3e},{:.3e}] -> d {}".format(a_, b_, d))
|
||||
return d, a_, b_
|
||||
|
||||
|
||||
|
||||
def opt_integral_boundaries_use_SLSQP_minimizer(integrand, a, b, t_max, ft_ref, opt_b_only, N, diff_method):
|
||||
"""
|
||||
this is very slow
|
||||
"""
|
||||
log.info("optimize integral boundary N:{} [{:.3e},{:.3e}], please wait ...".format(N, a, b))
|
||||
|
||||
_f_opt_cache = dict()
|
||||
args = (integrand, a, b, N, t_max, ft_ref, diff_method, _f_opt_cache, opt_b_only)
|
||||
x0 = np.log10(0.1*integrand(b))
|
||||
r = minimize(_f_opt, x0 = x0, args = args,
|
||||
r = minimize(_f_opt_for_SLSQP_minimizer, x0 = x0, args = args,
|
||||
method='SLSQP',
|
||||
constraints=[{"type": "ineq", "fun": _lower_contrs, "args": args},
|
||||
{"type": "ineq", "fun": _upper_contrs}])
|
||||
d, a_, b_ = _f_opt_cache[float(r.x)]
|
||||
if a_ is None or b_ is None:
|
||||
log.info("optimization with N {} failed".format(N))
|
||||
return d, a, b
|
||||
|
||||
log.info("optimization with N {} yields max rd {:.3e} and new boundaries [{:.2e},{:.2e}]".format(N, d, a_, b_))
|
||||
return d, a_, b_
|
||||
|
||||
def get_N_a_b_for_accurate_fourier_integral(integrand, a, b, t_max, tol, ft_ref, opt_b_only, N_max = 2**20,
|
||||
def opt_integral_boundaries(integrand, a, b, t_max, ft_ref, tol, opt_b_only, N, diff_method):
|
||||
log.info("optimize integral boundary N:{} [{:.3e},{:.3e}], please wait ...".format(N, a, b))
|
||||
|
||||
|
||||
args = (integrand, a, b, N, t_max, ft_ref, diff_method, opt_b_only)
|
||||
x0 = integrand(b)
|
||||
d1 = np.inf, None, None
|
||||
while True:
|
||||
d = _f_opt(x0, *args)
|
||||
log.info("opt int: J(w) min:{:.3e} and N:{} -> tol:{:.3e}".format(x0, N, d[0]))
|
||||
if d[0] < tol:
|
||||
log.info("return, cause tol of {} was reached".format(tol))
|
||||
return d
|
||||
x0 *= 0.1
|
||||
if d[0] > d1[0]:
|
||||
log.info("return cause further decrease of 'J(w) min' does not improove accuracy")
|
||||
return d
|
||||
if x0 < 1e-12:
|
||||
log.info("return cause 'J(w) min' < 1e-6")
|
||||
return d
|
||||
d1 = d
|
||||
|
||||
|
||||
|
||||
def get_N_a_b_for_accurate_fourier_integral(integrand, a, b, N_start, t_max, tol, ft_ref, opt_b_only, N_max = 2**20,
|
||||
diff_method=_absDiff):
|
||||
"""
|
||||
chooses N such that the approximated Fourier integral
|
||||
|
@ -245,25 +295,24 @@ def get_N_a_b_for_accurate_fourier_integral(integrand, a, b, t_max, tol, ft_ref,
|
|||
log.debug("ft_ref check yields rd {:.3e}".format(rd))
|
||||
if rd > 1e-6:
|
||||
raise FTReferenceError("it seems that 'ft_ref' is not the fourier transform of 'integrand'")
|
||||
|
||||
i = 10
|
||||
|
||||
N = N_start
|
||||
while True:
|
||||
N = 2**i
|
||||
rd, a_new, b_new = opt_integral_boundaries(integrand=integrand, a=a, b=b, t_max=t_max, ft_ref=ft_ref,
|
||||
rd, a_new, b_new = opt_integral_boundaries(integrand=integrand, a=a, b=b, t_max=t_max, ft_ref=ft_ref, tol=tol,
|
||||
opt_b_only=opt_b_only, N=N, diff_method=diff_method)
|
||||
a = a_new
|
||||
b = b_new
|
||||
#a = a_new
|
||||
#b = b_new
|
||||
|
||||
if rd < tol:
|
||||
log.info("reached rd ({:.3e}) < tol ({:.3e}), return N={}".format(rd, tol, N))
|
||||
return N, a, b
|
||||
return N, a_new, b_new
|
||||
if N > N_max:
|
||||
raise RuntimeError("maximum number of points for Fourier Transform reached")
|
||||
i += 1
|
||||
N *= 2
|
||||
|
||||
def get_dt_for_accurate_interpolation(t_max, tol, ft_ref, diff_method=_absDiff):
|
||||
N = 32
|
||||
sub_sampl = 8
|
||||
sub_sampl = 2
|
||||
|
||||
while True:
|
||||
tau = np.linspace(0, t_max, N+1)
|
||||
|
@ -272,24 +321,34 @@ def get_dt_for_accurate_interpolation(t_max, tol, ft_ref, diff_method=_absDiff):
|
|||
ft_intp_n = ft_intp(tau)
|
||||
|
||||
d = diff_method(ft_intp_n, ft_ref_n)
|
||||
log.debug("acc interp N {} dt {:.3e} {:.3e} -> d {:.3e}".format(N, sub_sampl*tau[1], t_max/(N/sub_sampl), d))
|
||||
log.info("acc interp N {} dt {:.3e} {:.3e} -> d {:.3e}".format(N, sub_sampl*tau[1], t_max/(N/sub_sampl), d))
|
||||
if d < tol:
|
||||
return t_max/(N/sub_sampl)
|
||||
N*=2
|
||||
|
||||
|
||||
def calc_ab_N_dx_dt(integrand, intgr_tol, intpl_tol, t_max, a, b, ft_ref, opt_b_only, N_max = 2**20, diff_method=_absDiff):
|
||||
log.info("get_dt_for_accurate_interpolation, please wait ...")
|
||||
c = find_integral_boundary(lambda tau: np.abs(ft_ref(tau)) / np.abs(ft_ref(0)),
|
||||
intgr_tol, 1, 1e6, 1)
|
||||
dt_tol = get_dt_for_accurate_interpolation(t_max=c,
|
||||
tol=intpl_tol,
|
||||
ft_ref=ft_ref,
|
||||
diff_method=diff_method)
|
||||
|
||||
N_start = t_max / dt_tol
|
||||
N_start = 2 ** int(np.ceil(np.log2(N_start)))
|
||||
|
||||
log.info("get_N_a_b_for_accurate_fourier_integral, please wait ...")
|
||||
N, a, b = get_N_a_b_for_accurate_fourier_integral(integrand, a, b,
|
||||
N_start=N_start,
|
||||
t_max = t_max,
|
||||
tol = intgr_tol,
|
||||
ft_ref = ft_ref,
|
||||
opt_b_only=opt_b_only,
|
||||
N_max = N_max,
|
||||
diff_method=diff_method)
|
||||
dt_tol = get_dt_for_accurate_interpolation(t_max = t_max,
|
||||
tol = intpl_tol,
|
||||
ft_ref = ft_ref,
|
||||
diff_method=diff_method)
|
||||
|
||||
dx = (b-a)/N
|
||||
dt = 2*np.pi/dx/N
|
||||
if dt <= dt_tol:
|
||||
|
|
|
@ -379,24 +379,24 @@ class StocProc_FFT(_absStocProc):
|
|||
# assume the spectral_density is non zero also for w<0
|
||||
# but decays fast for large |w|
|
||||
b = method_fft.find_integral_boundary(integrand = spectral_density,
|
||||
tol = intgr_tol**2,
|
||||
tol = intgr_tol,
|
||||
ref_val = 1,
|
||||
max_val = 1e6,
|
||||
x0 = 1)
|
||||
a = method_fft.find_integral_boundary(integrand = spectral_density,
|
||||
tol = intgr_tol**2,
|
||||
tol = intgr_tol,
|
||||
ref_val = -1,
|
||||
max_val = 1e6,
|
||||
x0 = -1)
|
||||
a, b, N, dx, dt = method_fft.calc_ab_N_dx_dt(integrand = spectral_density,
|
||||
intgr_tol = intgr_tol,
|
||||
intpl_tol = intpl_tol,
|
||||
t_max = t_max,
|
||||
a = a,
|
||||
b = b,
|
||||
ft_ref = lambda tau:bcf_ref(tau)*np.pi,
|
||||
opt_b_only= False,
|
||||
N_max = 2**24)
|
||||
intgr_tol = intgr_tol,
|
||||
intpl_tol = intpl_tol,
|
||||
t_max = t_max,
|
||||
a = a,
|
||||
b = b,
|
||||
ft_ref = lambda tau:bcf_ref(tau)*np.pi,
|
||||
opt_b_only= False,
|
||||
N_max = 2**24)
|
||||
log.info("required tol result in N {}".format(N))
|
||||
|
||||
assert abs(2*np.pi - N*dx*dt) < 1e-12
|
||||
|
|
Loading…
Add table
Reference in a new issue