Merge pull request #344 from p15-git-acc/j-loop-dot

In Platt multieval, break the j loop into blocks and use a dot product for each block.
This commit is contained in:
Fredrik Johansson 2020-09-27 21:55:56 +02:00 committed by GitHub
commit adba6336ad
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -13,6 +13,36 @@
#include "arb_hypgeom.h"
#include "acb_dft.h"
/*
 * Dot product of a complex vector with a real vector.
 *
 * Sets res to
 *     initial + (-1)^subtract * sum_{i=0}^{len-1} x[i*xstep] * y[i*ystep]
 * by running arb_dot once over the real parts of x and once over the
 * imaginary parts.
 *
 * res      output; the calls in this file pass the same pointer for res
 *          and initial
 * initial  starting value, or NULL to start from zero (arb_dot treats a
 *          NULL initial value as zero)
 * subtract forwarded unchanged to arb_dot
 * x, xstep complex vector and its stride, counted in acb entries
 * y, ystep real vector and its stride, counted in arb entries
 * len      number of terms
 * prec     working precision in bits
 */
static void
_acb_dot_arb(acb_t res, const acb_t initial, int subtract,
    acb_srcptr x, slong xstep, arb_srcptr y, slong ystep,
    slong len, slong prec)
{
    arb_srcptr init_re = NULL;
    arb_srcptr init_im = NULL;

    /*
     * The real (resp. imaginary) parts of an acb vector are walked as an
     * arb vector with twice the stride.  That is only meaningful when an
     * acb_struct occupies exactly the space of two arb_structs, so refuse
     * to continue otherwise.
     */
    if (sizeof(acb_struct) != 2*sizeof(arb_struct))
    {
        /* sizeof yields size_t; cast to slong and print with FLINT's %wd */
        flint_printf("expected sizeof(acb_struct)=%wd "
                "to be twice sizeof(arb_struct)=%wd\n",
                (slong) sizeof(acb_struct), (slong) sizeof(arb_struct));
        flint_abort();
    }

    /* Only take the component addresses of a non-NULL initial value. */
    if (initial != NULL)
    {
        init_re = acb_realref(initial);
        init_im = acb_imagref(initial);
    }

    arb_dot(acb_realref(res), init_re, subtract,
            acb_realref(x), 2*xstep, y, ystep, len, prec);
    arb_dot(acb_imagref(res), init_im, subtract,
            acb_imagref(x), 2*xstep, y, ystep, len, prec);
}
static void
_arb_add_d(arb_t z, const arb_t x, double d, slong prec)
@ -24,6 +54,13 @@ _arb_add_d(arb_t z, const arb_t x, double d, slong prec)
arb_clear(u);
}
/*
 * Sets z to the quotient a / b of two machine integers, with the division
 * carried out at precision prec.
 */
static void
_arb_div_si_si(arb_t z, slong a, slong b, slong prec)
{
    /* load the numerator, then divide in place by the denominator */
    arb_set_si(z, a);
    arb_div_si(z, z, b, prec);
}
static void
_arb_inv_si(arb_t z, slong n, slong prec)
{
@ -256,6 +293,67 @@ platt_get_smk_index(slong B, slong j, slong prec)
return m;
}
/*
 * Buffer that collects up to bmax (weight, row) pairs so that they can be
 * folded into an accumulator with one dot product per column instead of
 * one addmul per entry.
 */
typedef struct
{
/* capacity: maximum number of rows the block can hold */
slong bmax;
/* number of rows currently stored, 0 <= b <= bmax */
slong b;
/* number of real entries per row */
slong K;
/* row-major real matrix; row r occupies M[r*K] .. M[r*K + K - 1];
   K*bmax entries are allocated, the first b rows are in use */
arb_ptr M; /* (b, K) */
/* one complex weight per stored row; bmax entries are allocated,
   the first b are in use */
acb_ptr v; /* (b, ) */
}
smk_block_struct;
/* one-element array typedef so a block declared as a local can be passed
   to functions as a pointer */
typedef smk_block_struct smk_block_t[1];
/*
 * Initializes an empty block able to hold bmax rows of K real entries
 * each, together with one complex weight per row.  Release the storage
 * with smk_block_clear.
 */
static void
smk_block_init(smk_block_t p, slong K, slong bmax)
{
    /* backing storage: a K-by-bmax real matrix and bmax complex weights */
    p->M = _arb_vec_init(K*bmax);
    p->v = _acb_vec_init(bmax);

    /* bookkeeping: dimensions are fixed, the block starts out empty */
    p->K = K;
    p->bmax = bmax;
    p->b = 0;
}
/*
 * Frees the matrix and weight vector owned by the block.  The sizes passed
 * to the vector destructors are the ones used at allocation time.
 */
static void
smk_block_clear(smk_block_t p)
{
    const slong rows = p->bmax;
    const slong cols = p->K;

    _arb_vec_clear(p->M, cols * rows);
    _acb_vec_clear(p->v, rows);
}
/*
 * Returns 1 when the number of stored rows equals the block capacity,
 * and 0 otherwise.
 */
static int
smk_block_is_full(smk_block_t p)
{
    if (p->bmax == p->b)
    {
        return 1;
    }
    return 0;
}
/*
 * Marks the block as empty.  Only the row counter is rewound; the stored
 * entries are left in place and get overwritten by later increments.
 */
static void
smk_block_reset(smk_block_t p)
{
    p->b = 0;
}
/*
 * Appends one (weight, row) pair to the block: z is copied into the next
 * free weight slot and the K entries of v are copied into the matching
 * matrix row.  Aborts if the block is already full.
 */
static void
smk_block_increment(smk_block_t p, const acb_t z, arb_srcptr v)
{
    slong row;

    if (smk_block_is_full(p))
    {
        flint_printf("trying to increment a full block\n");
        flint_abort();
    }

    /* index of the first unused row */
    row = p->b;

    acb_set(p->v + row, z);
    _arb_vec_set(p->M + row * p->K, v, p->K);

    p->b = row + 1;
}
/*
 * Folds the stored rows into res, which must have K entries.  For each
 * column, one dot product is taken between the b stored weights (stride 1)
 * and that column of the row-major matrix (stride K), with the current
 * value of res[col] as the initial term of the sum.
 */
static void
smk_block_accumulate(smk_block_t p, acb_ptr res, slong prec)
{
    slong col = 0;

    while (col < p->K)
    {
        acb_ptr out = res + col;

        /* out serves as both the initial value and the destination */
        _acb_dot_arb(out, out, 0, p->v, 1, p->M + col, p->K, p->b, prec);
        col++;
    }
}
void
_platt_smk(acb_ptr table, acb_ptr startvec, acb_ptr stopvec,
const slong * smk_points, const arb_t t0, slong A, slong B,
@ -264,28 +362,35 @@ _platt_smk(acb_ptr table, acb_ptr startvec, acb_ptr stopvec,
{
slong j, k, m;
slong N = A * B;
smk_block_t block;
acb_ptr accum;
arb_ptr diff_powers;
arb_t rpi, rsqrtj, um, a, base;
arb_t rpi, logsqrtpi, rsqrtj, um, a, base;
acb_t z;
arb_init(rpi);
arb_init(logsqrtpi);
arb_init(rsqrtj);
arb_init(um);
arb_init(a);
arb_init(base);
acb_init(z);
smk_block_init(block, K, 32);
diff_powers = _arb_vec_init(K);
accum = _acb_vec_init(K);
arb_const_pi(rpi, prec);
arb_inv(rpi, rpi, prec);
arb_const_sqrt_pi(logsqrtpi, prec);
arb_log(logsqrtpi, logsqrtpi, prec);
m = platt_get_smk_index(B, jstart, prec);
_arb_div_si_si(um, m, B, prec);
for (j = jstart; j <= jstop; j++)
{
logjsqrtpi(a, j, prec);
arb_log_ui(a, (ulong) j, prec);
arb_add(a, a, logsqrtpi, prec);
arb_mul(a, a, rpi, prec);
arb_rsqrt_ui(rsqrtj, (ulong) j, prec);
@ -297,7 +402,10 @@ _platt_smk(acb_ptr table, acb_ptr startvec, acb_ptr stopvec,
acb_mul_arb(z, z, rsqrtj, prec);
while (m < N - 1 && smk_points[m + 1] <= j)
{
m += 1;
_arb_div_si_si(um, m, B, prec);
}
if (m < mstart || m > mstop)
{
@ -306,42 +414,48 @@ _platt_smk(acb_ptr table, acb_ptr startvec, acb_ptr stopvec,
flint_abort();
}
arb_set_si(um, m);
arb_div_si(um, um, B, prec);
arb_mul_2exp_si(base, a, -1);
arb_sub(base, base, um, prec);
_arb_vec_set_powers(diff_powers, base, K, prec);
smk_block_increment(block, z, diff_powers);
for (k = 0; k < K; k++)
acb_addmul_arb(accum + k, z, diff_powers + k, prec);
if (j == jstop || (m < N - 1 && smk_points[m + 1] <= j + 1))
{
if (startvec && m == mstart)
int j_stops = j == jstop;
int m_increases = m < N - 1 && smk_points[m + 1] <= j + 1;
if (j_stops || m_increases || smk_block_is_full(block))
{
_acb_vec_set(startvec, accum, K);
smk_block_accumulate(block, accum, prec);
smk_block_reset(block);
}
else if (stopvec && m == mstop)
if (j_stops || m_increases)
{
_acb_vec_set(stopvec, accum, K);
if (startvec && m == mstart)
{
_acb_vec_set(startvec, accum, K);
}
else if (stopvec && m == mstop)
{
_acb_vec_set(stopvec, accum, K);
}
else
{
for (k = 0; k < K; k++)
acb_set(table + N*k + m, accum + k);
}
_acb_vec_zero(accum, K);
}
else
{
for (k = 0; k < K; k++)
acb_set(table + N*k + m, accum + k);
}
_acb_vec_zero(accum, K);
}
}
arb_clear(rpi);
arb_clear(logsqrtpi);
arb_clear(rsqrtj);
arb_clear(um);
arb_clear(a);
arb_clear(base);
acb_clear(z);
smk_block_clear(block);
_arb_vec_clear(diff_powers, K);
_acb_vec_clear(accum, K);
}