diff --git a/ode_wrapper.py b/ode_wrapper.py index cd1a748..083121b 100644 --- a/ode_wrapper.py +++ b/ode_wrapper.py @@ -30,7 +30,7 @@ def wrap_complex_intgeration(f_complex): return f_real -def integrate_cplx(c, t0, t1, N, f, args, x0, integrator, verbose=0, **kwargs): +def integrate_cplx(c, t0, t1, N, f, args, x0, integrator, verbose=0, res_dim=None, x_to_res=None, **kwargs): f_partial_complex = lambda t, x: f(t, x, *args) if integrator == 'zvode': # define complex derivative @@ -53,9 +53,15 @@ def integrate_cplx(c, t0, t1, N, f, args, x0, integrator, verbose=0, **kwargs): t = np.linspace(t0, t1, N) + if res_dim is None: + res_dim = (len(x0), ) + + if x_to_res is None: + x_to_res = lambda x: x + # complex array for result - x = np.empty(shape=(N, len(x0)), dtype=np.complex128) - x[0] = x0 + x = np.empty(shape=(N,) + res_dim, dtype=np.complex128) + x[0] = x_to_res(x0) # print(args.eta._Z) @@ -64,10 +70,10 @@ def integrate_cplx(c, t0, t1, N, f, args, x0, integrator, verbose=0, **kwargs): r.integrate(t[i]) if integrator == 'zvode': # complex integration -> yields complex values - x[i] = r.y + x[i] = x_to_res(r.y) else: # real integration -> mapping from R^2 to C needed - x[i] = real_to_complex(r.y) + x[i] = x_to_res(real_to_complex(r.y)) t[i] = r.t c.value = i @@ -78,7 +84,7 @@ def integrate_cplx(c, t0, t1, N, f, args, x0, integrator, verbose=0, **kwargs): return t, x -def integrate_real(c, t0, t1, N, f, args, x0, integrator, verbose=0, **kwargs): +def integrate_real(c, t0, t1, N, f, args, x0, integrator, verbose=0, res_dim=None, x_to_res=None, **kwargs): f_partial = lambda t, x: f(t, x, *args) if integrator == 'zvode': @@ -97,14 +103,20 @@ def integrate_real(c, t0, t1, N, f, args, x0, integrator, verbose=0, **kwargs): t = np.linspace(t0, t1, N) - # complex array for result - x = np.empty(shape=(N, len(x0)), dtype=np.float64) - x[0] = x0 + if res_dim is None: + res_dim = (len(x0), ) + + if x_to_res is None: + x_to_res = lambda x: x + + # float array for result + x = np.empty(shape=(N,) + res_dim, dtype=np.float64) + x[0] = x_to_res(x0) i = 1 while r.successful() and i < N: r.integrate(t[i]) - x[i] = r.y + x[i] = x_to_res(r.y) t[i] = r.t c.value = i i += 1