master-thesis/python/richard_hops/ode_wrapper.py

278 lines
8.8 KiB
Python
Raw Normal View History

2021-10-15 16:18:03 +02:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
import warnings
from time import time
import logging
import sys
import traceback
log = logging.getLogger(__name__)

# scipy is only needed when integrating: if it cannot be imported we merely
# warn here, and the integrate_* functions will fail once they try to build
# an `ode` object.
try:
    from scipy.integrate import ode
except ImportError as import_error:
    warnings.warn("Submodule 'ode_wrapper' will not work. Reason: {}.".format(import_error))
class Dummy_c(object):
    """Minimal mutable holder exposing a single ``value`` attribute.

    An instance satisfies the ``c`` argument of ``integrate_cplx`` and
    ``integrate_real``, which assign the index of the last completed time
    step to ``c.value``. ``value`` starts at 0.
    """

    def __init__(self):
        self.value = 0  # nothing completed yet
def complex_to_real(vc):
    """Stack the real and imaginary parts of ``vc`` into one real array.

    For a 1-d complex vector ``[c_1, ..., c_n]`` the result is
    ``[Re(c_1), ..., Re(c_n), Im(c_1), ..., Im(c_n)]`` (length ``2n``).
    Stacking is done with ``np.hstack``, so inputs with ``ndim >= 2`` are
    joined along axis 1.
    """
    real_part, imag_part = np.real(vc), np.imag(vc)
    return np.hstack((real_part, imag_part))
def real_to_complex(vr):
    """Rebuild a complex vector from its stacked real representation.

    Inverse of ``complex_to_real`` for 1-d input: the first half of ``vr``
    holds the real parts and the second half the imaginary parts, i.e.
    ``[Re(c_1), ..., Re(c_n), Im(c_1), ..., Im(c_n)] -> [c_1, ..., c_n]``.

    Parameters
    ----------
    vr : array_like
        Real values of even length (lists are accepted as well).

    Returns
    -------
    numpy.ndarray
        Complex array of length ``len(vr) // 2``.

    Raises
    ------
    ValueError
        If ``len(vr)`` is odd. Without this check numpy broadcasting would
        silently combine mismatched halves into a wrong result.
    """
    vr = np.asarray(vr)
    n, remainder = divmod(len(vr), 2)
    if remainder:
        raise ValueError("real_to_complex expects an even-length vector, "
                         "got length {}".format(len(vr)))
    return vr[:n] + 1j * vr[n:]
def wrap_complex_intgeration(f_complex):
    """Turn a complex right-hand side into an equivalent real one.

    Given ``f_complex(t, y)`` mapping R x C^n -> C^n, return
    ``f_real(t, yr)`` mapping R x R^n x R^n -> R^n x R^n. The real state is
    converted with ``real_to_complex``, passed through ``f_complex``, and the
    result is flattened again with ``complex_to_real``. A complex vector
    ``[c_1, ..., c_n]`` thus corresponds to
    ``[Re(c_1), ..., Re(c_n), Im(c_1), ..., Im(c_n)]``.
    """
    def _real_rhs(t, y_real):
        y_cplx = real_to_complex(y_real)
        dy_cplx = f_complex(t, y_cplx)
        return complex_to_real(dy_cplx)

    return _real_rhs
def timed_f(f, time_as_list):
    """Wrap ``f(t, x)`` so that the wall time of each call is accumulated.

    The elapsed seconds of every call (measured with ``time()``) are added to
    ``time_as_list[0]``. The one-element list acts as a mutable accumulator
    the caller reads afterwards. A call that raises is not counted.
    """
    def wrapped(t, x):
        started = time()
        result = f(t, x)
        time_as_list[0] += time() - started
        return result

    return wrapped
def integrate_cplx(c, t0, t1, N, f, args, x0, integrator, res_dim=None, x_to_res=None, **kwargs):
    """Integrate a complex-valued ODE and sample it on an equidistant grid.

    Parameters
    ----------
    c : object
        Progress holder; ``c.value`` is set to the index of the last
        completed time step.
    t0, t1, N :
        The solution is sampled at ``np.linspace(t0, t1, N)``.
    f : callable
        Right-hand side, called as ``f(t, x, *args)`` with complex ``x``.
    args : sequence
        Extra positional arguments forwarded to ``f``.
    x0 : array_like
        Complex initial state.
    integrator : str
        ``'zvode'`` integrates the complex system directly. ``'vode'``,
        ``'lsoda'``, ``'dopri5'`` and ``'dop853'`` integrate the equivalent
        real system of twice the size (slower; a warning is logged).
    res_dim : tuple or sequence of tuples, optional
        Shape of the result stored per time step; defaults to
        ``(len(x0),)``. A sequence of shapes together with a sequence
        ``x_to_res`` of equal length requests several result types.
    x_to_res : callable or sequence of callables, optional
        ``x_to_res(t, x)`` maps the complex state to the stored result;
        defaults to the identity.
    **kwargs
        Forwarded to ``ode.set_integrator``. Entries whose value is ``None``
        are dropped, and ``order`` is dropped for ``'dopri5'``/``'dop853'``.

    Returns
    -------
    (t, x, err)
        ``t`` are the sampled times and ``x`` the results: a complex array,
        or a list of complex arrays when several result types were
        requested. ``err`` is ``None`` on success. Otherwise it is
        ``(exception, traceback_str)``, and ``t``/``x`` are truncated to the
        steps that completed. Any warning raised while stepping is escalated
        to an exception and reported this way.

    Raises
    ------
    RuntimeError
        If ``integrator`` is not one of the names listed above.
    """
    def f_partial_complex(t, x):
        return f(t, x, *args)

    if integrator == 'zvode':
        # zvode works on complex states directly
        f_ = f_partial_complex
        x0_ = x0
    elif integrator in ('vode', 'lsoda', 'dopri5', 'dop853'):
        # Real integrators: the state is split into [Re, Im] halves. Build the
        # wrapper once rather than on every right-hand-side evaluation.
        f_ = wrap_complex_intgeration(f_partial_complex)
        x0_ = complex_to_real(x0)
        log.warning("PERFORMANCE WARNING, avoid using 'vode' or 'lsoda' for complex ode's")
    else:
        raise RuntimeError("unknown integrator '{}'".format(integrator))

    # time_as_list[0] accumulates the wall time spent inside f
    time_as_list = [0.]
    r = ode(timed_f(f_, time_as_list))
    if integrator in ('dopri5', 'dop853'):
        # scipy's Dormand-Prince integrators take no 'order' option
        kwargs.pop('order', None)
    kwargs = {kw: val for kw, val in kwargs.items() if val is not None}
    r.set_integrator(integrator, **kwargs)
    # x0_ might be the mapping from C to R^2
    r.set_initial_value(x0_, t0)

    t = np.linspace(t0, t1, N)

    if res_dim is None:
        res_dim = (len(x0), )
        res_list_len = None
    else:
        try:
            res_list_len = len(res_dim)
            assert res_list_len == len(x_to_res)
        except TypeError:
            # a callable (or None) x_to_res has no len(): single result type
            res_list_len = None
    if x_to_res is None:
        x_to_res = lambda t_, x_: x_

    # Handle the single-result case as a one-element list so that one loop
    # serves both layouts; _pack restores the caller-facing layout.
    single = res_list_len is None
    if single:
        dims, converters = [res_dim], [x_to_res]
    else:
        dims, converters = list(res_dim), list(x_to_res)

    def _pack(arrays):
        # bare array for the single-result case, list of arrays otherwise
        return arrays[0] if single else arrays

    # complex arrays for the results, first row from the initial state
    x = []
    for dim, conv in zip(dims, converters):
        xa = np.empty(shape=(N,) + dim, dtype=np.complex128)
        xa[0] = conv(t0, x0)
        x.append(xa)

    t_int = 0
    t_conv = 0
    i = 1
    with warnings.catch_warnings():
        # any warning raised in here becomes an exception -> reported via err
        warnings.filterwarnings('error')
        try:
            while r.successful() and i < N:
                _t = time()
                r.integrate(t[i])
                t_int += (time() - _t)
                _t = time()
                if integrator == 'zvode':
                    # complex integration -> yields complex values
                    y = r.y
                else:
                    # real integration -> mapping from R^2 to C needed
                    y = real_to_complex(r.y)
                for xa, conv in zip(x, converters):
                    xa[i] = conv(r.t, y)
                t_conv += (time() - _t)
                if abs(t[i] - r.t) > 1e-13:
                    msg = "INTEGRATION WARNING: time mismatch (diff at step {}: {:.3e})".format(i, abs(t[i]-r.t))
                    log.warning(msg)
                    raise Warning(msg)
                t[i] = r.t
                c.value = i
                i += 1
            if not r.successful():
                msg = "INTEGRATION WARNING: NOT successful!"
                log.warning(msg)
                raise Warning(msg)
        except Exception as e:
            trb = traceback.format_exc()
            return t[:i], _pack([xa[:i] for xa in x]), (e, trb)

    total = t_int + t_conv

    def share(part):
        # With N == 1 (no steps) or a coarse timer the total can be 0, which
        # would otherwise raise ZeroDivisionError after a successful run.
        return part / total if total > 0 else float('nan')

    log.info("integration summary\n" +
             "integration time {:.2g}s ({:.2%})\n".format(t_int, share(t_int)) +
             " f_dot eval {:.2g}s ({:.2%})\n".format(time_as_list[0], share(time_as_list[0])) +
             "data conversion time {:.2g}s ({:.2%})\n".format(t_conv, share(t_conv)))
    return t, _pack(x), None
def integrate_real(c, t0, t1, N, f, args, x0, integrator, verbose=0, res_dim=None, x_to_res=None, **kwargs):
    """Integrate a real-valued ODE and sample it on an equidistant grid.

    Parameters
    ----------
    c : object
        Progress holder; ``c.value`` is set to the index of the last
        completed time step.
    t0, t1, N :
        The solution is sampled at ``np.linspace(t0, t1, N)``.
    f : callable
        Right-hand side, called as ``f(t, x, *args)``.
    args : sequence
        Extra positional arguments forwarded to ``f``.
    x0 : array_like
        Real initial state.
    integrator : str
        One of ``'vode'``, ``'lsoda'``, ``'dopri5'``, ``'dop853'``.
    verbose : int, optional
        Not used; kept so that the signature stays backward compatible.
    res_dim : tuple, optional
        Shape of the result stored per time step; defaults to
        ``(len(x0),)``.
    x_to_res : callable, optional
        ``x_to_res(t, x)`` maps the state to the stored result; defaults to
        the identity.
    **kwargs
        Forwarded to ``ode.set_integrator``. Entries whose value is ``None``
        are dropped, and ``order`` is dropped for ``'dopri5'``/``'dop853'``.

    Returns
    -------
    (t, x, err)
        ``t`` are the sampled times and ``x`` the float64 results. ``err`` is
        ``None`` on success. Otherwise it is ``(exception, traceback_str)``,
        and ``t``/``x`` are truncated to the steps that completed. Any
        warning raised while stepping is escalated to an exception and
        reported this way.

    Raises
    ------
    RuntimeError
        For ``'zvode'`` (complex only) or an unknown integrator name.
    """
    def f_partial(t, x):
        return f(t, x, *args)

    if integrator == 'zvode':
        raise RuntimeError("'zvode' can not be used for real integration")
    if integrator not in ('vode', 'lsoda', 'dopri5', 'dop853'):
        raise RuntimeError("unknown integrator '{}'".format(integrator))

    r = ode(f_partial)
    if integrator in ('dopri5', 'dop853'):
        # scipy's Dormand-Prince integrators take no 'order' option
        kwargs.pop('order', None)
    kwargs = {kw: val for kw, val in kwargs.items() if val is not None}
    r.set_integrator(integrator, **kwargs)
    r.set_initial_value(x0, t0)

    t = np.linspace(t0, t1, N)
    if res_dim is None:
        res_dim = (len(x0), )
    if x_to_res is None:
        x_to_res = lambda t_, x_: x_

    # float array for result, first row from the initial state
    x = np.empty(shape=(N,) + res_dim, dtype=np.float64)
    x[0] = x_to_res(t0, x0)

    t_int = 0
    t_conv = 0
    i = 1
    with warnings.catch_warnings():
        # any warning raised in here becomes an exception -> reported via err
        warnings.filterwarnings('error')
        try:
            while r.successful() and i < N:
                _t = time()
                r.integrate(t[i])
                t_int += (time() - _t)
                _t = time()
                x[i] = x_to_res(r.t, r.y)
                t_conv += (time() - _t)
                if abs(t[i] - r.t) > 1e-13:
                    msg = "INTEGRATION WARNING: time mismatch (diff at step {}: {:.3e})".format(i, abs(t[i]-r.t))
                    log.warning(msg)
                    raise Warning(msg)
                t[i] = r.t
                c.value = i
                i += 1
            if not r.successful():
                msg = "INTEGRATION WARNING: NOT successful!"
                log.warning(msg)
                raise Warning(msg)
        except Exception as e:
            trb = traceback.format_exc()
            return t[:i], x[:i], (e, trb)

    total = t_int + t_conv

    def share(part):
        # With N == 1 (no steps) or a coarse timer the total can be 0, which
        # would otherwise raise ZeroDivisionError after a successful run.
        return part / total if total > 0 else float('nan')

    log.info("integration summary\n" +
             "integration time {:.2g}s ({:.2%})\n".format(t_int, share(t_int)) +
             "data conversion time {:.2g}s ({:.2%})\n".format(t_conv, share(t_conv)))
    return t, x, None