first try at cholesky

This commit is contained in:
Valentin Boettcher 2022-02-08 10:02:18 +01:00
parent 004e831dc0
commit 01d5542a49
2 changed files with 163 additions and 0 deletions

View file

@ -3,3 +3,4 @@ from .stocproc import StocProc # for typing
from .stocproc import StocProc_FFT
from .stocproc import StocProc_KLE
from .stocproc import StocProc_TanhSinh
from .stocproc import Cholesky

View file

@ -2,6 +2,10 @@ import abc
from functools import partial
from typing import Optional
import numpy as np
from numpy.typing import NDArray
from collections.abc import Callable
import scipy.linalg
import scipy.optimize
import time
from . import method_kle
@ -821,6 +825,164 @@ class StocProc_TanhSinh(StocProc):
return len(self.fl)
BCF = Callable[[NDArray[np.floating]], NDArray[np.complex128]]
class Cholesky(StocProc):
r"""Generate Stochastic Processes using the cholesky decomposition."""
def __init__(
self,
t_max: float,
alpha: BCF,
intgr_tol=1e-2,
intpl_tol=1e-2,
chol_tol=1e-2,
correlation_cutoff=1e-3,
seed=None,
scale=1,
calc_deriv: bool = False,
max_iterations: int = 10,
):
del intgr_tol # not used for now
self.key = (
"chol",
alpha,
t_max,
intpl_tol,
chol_tol,
correlation_cutoff,
calc_deriv,
max_iterations,
)
steps = int(t_max / intpl_tol) + 1
self.t: NDArray[np.float128] = np.linspace(0, t_max, steps, dtype=np.float128)
"""The times at which the stochastic process will be sampled."""
super().__init__(
t_max=t_max,
num_grid_points=len(self.t),
seed=seed,
scale=scale,
calc_deriv=calc_deriv,
)
cutoff_sol = scipy.optimize.root(
lambda t: np.abs(alpha(t)) - correlation_cutoff, x0=[0.001]
)
if not cutoff_sol.success:
raise RuntimeError(
f"Could not find a suitable cutoff time. Scipy says '{cutoff_sol.message}."
)
self.t_chol = np.linspace(
0,
cutoff_sol.x[0] * 2,
int(((cutoff_sol.x[0]) / intpl_tol) + 1) * 2 + 1,
dtype=np.float128,
)
mat, tol = self.stable_cholesky(alpha, self.t_chol, max_iterations)
if mat is None or tol > chol_tol:
raise RuntimeError(
f"The tolerance of {chol_tol} could not be reached. We got as far as {tol}."
)
self.chol_matrix: NDArray[np.complex128] = mat[1:, 1:]
if calc_deriv:
self.chol_deriv = np.gradient(mat, self.t, axis=0)[1:, 1:]
self.t_chol = self.t_chol[1:]
self.chunk_size = len(self.t_chol) // 2
self.patch_matrix: NDArray[np.complex128] = scipy.linalg.inv(
self.chol_matrix[: self.chunk_size, : self.chunk_size]
)
self.num_chunks = int(len(self.t) / self.chunk_size) + 1
breakpoint()
@staticmethod
def stable_cholesky(
α: BCF, t, max_iterations: int = 100, starteps: Optional[float] = None
):
t = np.asarray(t)
tt, ss = np.meshgrid(t, t, sparse=False)
Σ = α(np.array(tt - ss))
eye = np.eye(len(t))
eps: float = 0.0
L = None
reversed = False
for _ in range(max_iterations):
log.info(f"Trying ε={eps}.")
try:
L = scipy.linalg.cholesky(Σ + eps * eye, lower=True, check_finite=False)
if eps == 0 or reversed:
break
eps /= 2
except scipy.linalg.LinAlgError as _:
if eps == 0:
eps = starteps or np.finfo(np.float64).eps * 4
else:
eps = eps * 2
reverse = True
return L, (np.max(np.abs((L @ L.T.conj() - Σ) / Σ)) if L is not None else -1)
def calc_z(self, y: NDArray[np.complex128]):
assert y.shape == (self.get_num_y(),)
res = np.empty(self.chunk_size * self.num_chunks, dtype=np.complex128)
y = np.pad(y, (0, len(res) - len(y)), "constant")
offset = len(self.t_chol)
res[0:offset] = self.chol_matrix @ y[:offset]
y_curr = np.empty(self.chunk_size * 2, dtype=np.complex128)
last_values = res[offset // 2 : offset]
for i in range(self.num_chunks - 2):
next_offset = offset + self.chunk_size
y_curr[0 : self.chunk_size] = self.patch_matrix @ last_values
y_curr[self.chunk_size : self.chunk_size * 2] = y[offset:next_offset]
res[offset:next_offset] = (self.chol_matrix @ y_curr)[
self.chunk_size : self.chunk_size * 2
]
last_values = res[offset:next_offset]
offset = next_offset
return res[0 : len(self.t)]
def calc_z_dot(self, y: np.ndarray) -> np.ndarray:
r"""Calculate the discrete time stochastic process derivative using FFT algorithm
and return values :math:`\dot{z}_n` with :math:`t_n <= t_\mathrm{max}`.
"""
z_dot_fft = np.fft.fft(-1j * self.omega_k * self.yl * y)
z_dot = z_dot_fft[0 : self.num_grid_points] * self.omega_min_correction
return z_dot
def get_num_y(self):
r"""The number of independent random variables :math:`Y_m` is given by the number of discrete nodes
used by the Fast Fourier Transform algorithm.
"""
return len(self.t)
def alpha_times_pi(tau, alpha):
return alpha(tau) * np.pi