From 66516422e062759759e5c6d96445d234f2010536 Mon Sep 17 00:00:00 2001 From: Richard Hartmann Date: Mon, 4 Nov 2019 13:16:09 +0100 Subject: [PATCH] added the possibility to scale the state vector while integrating --- jobmanager/ode_wrapper.py | 58 +++++++++++++++++++++++---------------- 1 file changed, 34 insertions(+), 24 deletions(-) diff --git a/jobmanager/ode_wrapper.py b/jobmanager/ode_wrapper.py index 4528c38..295f610 100644 --- a/jobmanager/ode_wrapper.py +++ b/jobmanager/ode_wrapper.py @@ -1,11 +1,9 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- import numpy as np import warnings from time import time import logging -import sys import traceback +import copy log = logging.getLogger(__name__) @@ -54,7 +52,7 @@ def timed_f(f, time_as_list): return new_f -def integrate_cplx(c, t0, t1, N, f, args, x0, integrator, res_dim=None, x_to_res=None, scale_function = None, **kwargs): +def integrate_cplx(c, t0, t1, N, f, args, x0, integrator, res_dim=None, x_to_res=None, scale_function=None, **kwargs): f_partial_complex = lambda t, x: f(t, x, *args) if integrator == 'zvode': # define complex derivative @@ -120,22 +118,24 @@ def integrate_cplx(c, t0, t1, N, f, args, x0, integrator, res_dim=None, x_to_res with warnings.catch_warnings(): warnings.filterwarnings('error') try: - while r.successful() and i < N: + while i < N: _t = time() r.integrate(t[i]) t_int += (time()-_t) - if scale_function: - sc = scale_function(r.y) - r.y /= sc + if not r.successful(): + msg = "INTEGRATION WARNING: NOT successful!" + log.warning(msg) + raise Warning(msg) _t = time() + r_y = copy.copy(r.y) if integrator == 'zvode': # complex integration -> yields complex values - x[i] = x_to_res(r.t, r.y) + x[i] = x_to_res(r.t, r_y) else: # real integration -> mapping from R^2 to C needed - x[i] = x_to_res(r.t, real_to_complex(r.y)) + x[i] = x_to_res(r.t, real_to_complex(r_y)) t_conv += (time()-_t) if abs(t[i]-r.t) > 1e-13: @@ -144,12 +144,16 @@ def integrate_cplx(c, t0, t1, N, f, args, x0, integrator, res_dim=None, x_to_res raise Warning(msg) t[i] = r.t c.value = i + + if scale_function: + del r + r = ode(f__) + r.set_integrator(integrator, **kwargs) + r.set_initial_value(y=scale_function(copy.copy(r_y)), t=t[i]) + i += 1 - - if not r.successful(): - msg = "INTEGRATION WARNING: NOT successful!" - log.warning(msg) - raise Warning(msg) + + except Exception as e: trb = traceback.format_exc() return t[:i], x[:i], (e, trb) @@ -176,19 +180,21 @@ def integrate_cplx(c, t0, t1, N, f, args, x0, integrator, res_dim=None, x_to_res r.integrate(t[i]) t_int += (time()-_t) - if scale_function: - sc = scale_function(r.y) - r.y /= sc + if not r.successful(): + msg = "INTEGRATION WARNING: NOT successful!" + log.warning(msg) + raise Warning(msg) _t = time() + r_y = copy.copy(r.y) if integrator == 'zvode': # complex integration -> yields complex values for a in range(res_list_len): - x[a][i] = x_to_res[a](r.t, r.y) + x[a][i] = x_to_res[a](r.t, r_y) else: # real integration -> mapping from R^2 to C needed for a in range(res_list_len): - x[a][i] = x_to_res[a](r.t, real_to_complex(r.y)) + x[a][i] = x_to_res[a](r.t, real_to_complex(r_y)) t_conv += (time()-_t) if abs(t[i]-r.t) > 1e-13: msg = "INTEGRATION WARNING: time mismatch (diff at step {}: {:.3e})".format(i, abs(t[i]-r.t)) @@ -196,12 +202,16 @@ def integrate_cplx(c, t0, t1, N, f, args, x0, integrator, res_dim=None, x_to_res raise Warning(msg) t[i] = r.t c.value = i + + if scale_function: + del r + r = ode(f__) + r.set_integrator(integrator, **kwargs) + r.set_initial_value(y=scale_function(copy.copy(r_y)), t=t[i]) + i += 1 - if not r.successful(): - msg = "INTEGRATION WARNING: NOT successful!" - log.warning(msg) - raise Warning(msg) + except Exception as e: trb = traceback.format_exc() return t[:i], [xa[:i] for xa in x], (e, trb)