multithreaded numerical integration

This commit is contained in:
Fredrik Johansson 2022-05-16 14:20:40 +02:00
parent bccb6ce632
commit 1c618abfae
4 changed files with 161 additions and 28 deletions

View file

@ -63,10 +63,28 @@ void gl_init()
}
/* Compute GL node and weight of index k for n = gl_steps[i]. Cached. */
void
acb_calc_gl_node(arb_t x, arb_t w, slong i, slong k, slong prec)
/* Shared arguments handed to nodes_worker by flint_parallel_do.
   Each worker call writes one entry of nodes and one entry of weights. */
typedef struct
{
    arb_ptr nodes;      /* output array; entry jj receives the node of index jj */
    arb_ptr weights;    /* output array; entry jj receives the matching weight */
    slong n;            /* passed as n to arb_hypgeom_legendre_p_ui_root */
    slong wp;           /* working precision used for the computation */
}
nodes_work_t;
/* Worker callback for flint_parallel_do: computes the node and weight of
   index jj at precision work->wp and stores them in entry jj of
   work->nodes and work->weights respectively. */
static void
nodes_worker(slong jj, nodes_work_t * work)
{
    arb_ptr node = work->nodes + jj;
    arb_ptr weight = work->weights + jj;

    arb_hypgeom_legendre_p_ui_root(node, weight, work->n, jj, work->wp);
}
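/* If k >= 0, compute the node and weight of index k for n = gl_steps[i];
   x and w each receive a single value. */
/* If k < 0, compute the first (n+1)/2 nodes and weights (the others are
   given by symmetry); x and w must each have room for (n+1)/2 entries. */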
void
acb_calc_gl_node(arb_ptr x, arb_ptr w, slong i, slong k, slong prec)
{
slong n, kk, wp;
int all;
if (i < 0 || i >= GL_STEPS || prec < 2)
flint_abort();
@ -76,16 +94,15 @@ acb_calc_gl_node(arb_t x, arb_t w, slong i, slong k, slong prec)
n = gl_steps[i];
if (k < 0 || k >= n)
if (k >= n)
flint_abort();
if (2 * k < n)
kk = k;
else
kk = n - 1 - k;
all = (k < 0);
if (gl_cache->gl_prec[i] < prec)
{
nodes_work_t work;
if (gl_cache->gl_prec[i] == 0)
{
gl_cache->gl_nodes[i] = _arb_vec_init((n + 1) / 2);
@ -94,21 +111,86 @@ acb_calc_gl_node(arb_t x, arb_t w, slong i, slong k, slong prec)
wp = FLINT_MAX(prec, gl_cache->gl_prec[i] * 2 + 30);
for (jj = 0; 2 * jj < n; jj++)
{
arb_hypgeom_legendre_p_ui_root(gl_cache->gl_nodes[i] + jj,
gl_cache->gl_weights[i] + jj, n, jj, wp);
}
work.nodes = gl_cache->gl_nodes[i];
work.weights = gl_cache->gl_weights[i];
work.n = n;
work.wp = wp;
flint_parallel_do((do_func_t) nodes_worker, &work, (n + 1) / 2, -1, FLINT_PARALLEL_STRIDED);
gl_cache->gl_prec[i] = wp;
}
if (2 * k < n)
arb_set_round(x, gl_cache->gl_nodes[i] + kk, prec);
if (all)
{
for (k = 0; k < (n + 1) / 2; k++)
{
arb_set_round(x + k, gl_cache->gl_nodes[i] + k, prec);
arb_set_round(w + k, gl_cache->gl_weights[i] + k, prec);
}
}
else
arb_neg_round(x, gl_cache->gl_nodes[i] + kk, prec);
{
if (2 * k < n)
kk = k;
else
kk = n - 1 - k;
arb_set_round(w, gl_cache->gl_weights[i] + kk, prec);
if (2 * k < n)
arb_set_round(x, gl_cache->gl_nodes[i] + kk, prec);
else
arb_neg_round(x, gl_cache->gl_nodes[i] + kk, prec);
arb_set_round(w, gl_cache->gl_weights[i] + kk, prec);
}
}
/* Shared, read-only arguments (apart from the output array v) handed to
   gl_worker when the quadrature rule is evaluated in parallel. */
typedef struct
{
slong n;            /* number of quadrature nodes; worker indices k run over 0 .. n-1 */
slong prec;         /* precision passed to f and to all arithmetic in the worker */
arb_srcptr x;       /* first (n+1)/2 nodes; the remaining ones are obtained by negation */
arb_srcptr w;       /* first (n+1)/2 weights, indexed like x */
acb_srcptr delta;   /* scale factor: the evaluation point is mid + delta * node */
acb_srcptr mid;     /* offset added to delta * node */
acb_ptr v;          /* output array of length n; v[k] receives weight * f(point k) */
acb_calc_func_t f;  /* integrand callback, invoked as f(out, point, param, 0, prec) */
void * param;       /* opaque pointer forwarded unchanged to f */
}
gl_work_t;
/* Worker callback for flint_parallel_do: evaluates the integrand at the
   quadrature point of index k and stores the weighted value in args->v[k].

   Only the first (n + 1) / 2 nodes and weights are stored in args->x and
   args->w. For k in the upper half, the node is the negation of the stored
   node of index n - 1 - k and the weight is the stored weight of that
   same index.

   Each call writes only to args->v[k]; all other fields of args are read. */
static void
gl_worker(slong k, gl_work_t * args)
{
    acb_t t;
    slong k2;
    slong prec = args->prec;
    slong n = args->n;
    acb_ptr v = args->v;

    acb_init(t);

    /* k2 = index of the stored node/weight that corresponds to k */
    if (2 * k < n)
        k2 = k;
    else
        k2 = n - 1 - k;

    /* t = mid + delta * node_k, with node_k = -node_k2 in the upper half */
    acb_mul_arb(t, args->delta, args->x + k2, prec);
    if (k2 != k)
        acb_neg(t, t);
    acb_add(t, t, args->mid, prec);

    /* v[k] = weight_k * f(t) */
    args->f(v + k, t, args->param, 0, prec);
    acb_mul_arb(v + k, v + k, args->w + k2, prec);

    acb_clear(t);
}
int
@ -236,6 +318,7 @@ acb_calc_integrate_gl_auto_deg(acb_t res, slong * eval_count,
/* Evaluate best found Gauss-Legendre quadrature rule. */
if (status == ARB_CALC_SUCCESS)
{
slong nt;
arb_t x, w;
arb_init(x);
arb_init(w);
@ -259,15 +342,52 @@ acb_calc_integrate_gl_auto_deg(acb_t res, slong * eval_count,
if (gl_steps[i] == best_n)
break;
acb_zero(s);
nt = flint_get_num_threads();
for (k = 0; k < best_n; k++)
if (nt >= 2 && best_n >= 2)
{
acb_calc_gl_node(x, w, i, k, prec);
acb_mul_arb(wide, delta, x, prec);
acb_add(wide, wide, mid, prec);
f(v, wide, param, 0, prec);
acb_addmul_arb(s, v, w, prec);
gl_work_t work;
acb_ptr v;
arb_ptr x, w;
v = _acb_vec_init(best_n);
w = _arb_vec_init((best_n + 1) / 2);
x = _arb_vec_init((best_n + 1) / 2);
acb_calc_gl_node(x, w, i, -1, prec);
work.n = best_n;
work.x = x;
work.w = w;
work.prec = prec;
work.delta = delta;
work.mid = mid;
work.v = v;
work.f = f;
work.param = param;
flint_parallel_do((do_func_t) gl_worker, &work, best_n, -1, FLINT_PARALLEL_STRIDED);
acb_add(s, v, v + 1, prec);
for (k = 2; k < best_n; k++)
acb_add(s, s, v + k, prec);
_acb_vec_clear(v, best_n);
_arb_vec_clear(x, (best_n + 1) / 2);
_arb_vec_clear(w, (best_n + 1) / 2);
}
else
{
acb_zero(s);
for (k = 0; k < best_n; k++)
{
acb_calc_gl_node(x, w, i, k, prec);
acb_mul_arb(wide, delta, x, prec);
acb_add(wide, wide, mid, prec);
f(v, wide, param, 0, prec);
acb_addmul_arb(s, v, w, prec);
}
}
eval_count[0] += best_n;

View file

@ -351,6 +351,8 @@ int main()
acb_calc_integrate_opt_t opt;
int integral;
flint_set_num_threads(1 + n_randint(state, 3));
acb_init(ans);
acb_init(res);
acb_init(a);
@ -712,7 +714,7 @@ int main()
}
flint_randclear(state);
flint_cleanup();
flint_cleanup_master();
flint_printf("PASS\n");
return EXIT_SUCCESS;
}

View file

@ -655,6 +655,7 @@ Invoking the program without parameters shows usage::
-deg n - use quadrature degree up to n
-eval n - limit number of function evaluations to n
-depth n - limit subinterval queue size to n
-threads n - use parallel computation with n threads
Implemented integrals:
I0 = int_0^100 sin(x) dx

View file

@ -780,6 +780,7 @@ int main(int argc, char *argv[])
{
acb_t s, t, a, b;
mag_t tol;
slong num_threads;
slong prec, goal;
slong N;
ulong k;
@ -821,7 +822,8 @@ int main(int argc, char *argv[])
flint_printf("-verbose2 - show more information\n");
flint_printf("-deg n - use quadrature degree up to n\n");
flint_printf("-eval n - limit number of function evaluations to n\n");
flint_printf("-depth n - limit subinterval queue size to n\n\n");
flint_printf("-depth n - limit subinterval queue size to n\n");
flint_printf("-threads n - use parallel computation with n threads\n\n");
flint_printf("Implemented integrals:\n");
for (integral = 0; integral < NUM_INTEGRALS; integral++)
flint_printf("I%d = %s\n", integral, descr[integral]);
@ -835,6 +837,7 @@ int main(int argc, char *argv[])
twice = 0;
goal = 0;
havetol = havegoal = 0;
num_threads = 1;
acb_init(a);
acb_init(b);
@ -895,8 +898,15 @@ int main(int argc, char *argv[])
{
options->use_heap = 1;
}
else if (!strcmp(argv[i], "-threads"))
{
num_threads = atol(argv[i+1]);
}
}
if (num_threads >= 2)
flint_set_num_threads(num_threads);
if (!havegoal)
goal = prec;
@ -1276,7 +1286,7 @@ int main(int argc, char *argv[])
acb_clear(t);
mag_clear(tol);
flint_cleanup();
flint_cleanup_master();
return 0;
}