diff --git a/acb.h b/acb.h index 7ea59648..e30dd6fc 100644 --- a/acb.h +++ b/acb.h @@ -644,6 +644,14 @@ acb_submul_arb(acb_t z, const acb_t x, const arb_t y, slong prec) arb_submul(acb_imagref(z), acb_imagref(x), y, prec); } +void acb_dot_simple(acb_t res, const acb_t initial, int subtract, + acb_srcptr x, slong xstep, acb_srcptr y, slong ystep, slong len, slong prec); +void acb_dot_precise(acb_t res, const acb_t initial, int subtract, + acb_srcptr x, slong xstep, acb_srcptr y, slong ystep, slong len, slong prec); +void acb_dot(acb_t res, const acb_t initial, int subtract, + acb_srcptr x, slong xstep, acb_srcptr y, slong ystep, slong len, slong prec); + + void acb_inv(acb_t z, const acb_t x, slong prec); void acb_div(acb_t z, const acb_t x, const acb_t y, slong prec); diff --git a/acb/dot.c b/acb/dot.c new file mode 100644 index 00000000..27b0d971 --- /dev/null +++ b/acb/dot.c @@ -0,0 +1,1008 @@ +/* + Copyright (C) 2018 Fredrik Johansson + + This file is part of Arb. + + Arb is free software: you can redistribute it and/or modify it under + the terms of the GNU Lesser General Public License (LGPL) as published + by the Free Software Foundation; either version 2.1 of the License, or + (at your option) any later version. See . +*/ + +#include "acb.h" + +/* We need uint64_t instead of mp_limb_t on 32-bit systems for + safe summation of 30-bit error bounds. */ +#include + +/* The following macros are found in FLINT's longlong.h, but + the release version is out of date. */ + +/* x86 : 64 bit */ +#if (GMP_LIMB_BITS == 64 && defined (__amd64__)) + +#define add_sssaaaaaa2(sh, sm, sl, ah, am, al, bh, bm, bl) \ + __asm__ ("addq %8,%q2\n\tadcq %6,%q1\n\tadcq %4,%q0" \ + : "=r" (sh), "=&r" (sm), "=&r" (sl) \ + : "0" ((mp_limb_t)(ah)), "rme" ((mp_limb_t)(bh)), \ + "1" ((mp_limb_t)(am)), "rme" ((mp_limb_t)(bm)), \ + "2" ((mp_limb_t)(al)), "rme" ((mp_limb_t)(bl))) \ + +#define sub_dddmmmsss2(dh, dm, dl, mh, mm, ml, sh, sm, sl) \ + __asm__ ("subq %8,%q2\n\tsbbq %6,%q1\n\tsbbq %4,%q0" \ + : "=r" (dh), "=&r" (dm), "=&r" (dl) \ + : "0" ((mp_limb_t)(mh)), "rme" ((mp_limb_t)(sh)), \ + "1" ((mp_limb_t)(mm)), "rme" ((mp_limb_t)(sm)), \ +"2" ((mp_limb_t)(ml)), "rme" ((mp_limb_t)(sl))) \ + +#endif /* x86_64 */ + +/* x86 : 32 bit */ +#if (GMP_LIMB_BITS == 32 && (defined (__i386__) \ + || defined (__i486__) || defined(__amd64__))) + +#define add_sssaaaaaa2(sh, sm, sl, ah, am, al, bh, bm, bl) \ + __asm__ ("addl %8,%k2\n\tadcl %6,%k1\n\tadcl %4,%k0" \ + : "=r" (sh), "=r" (sm), "=&r" (sl) \ + : "0" ((mp_limb_t)(ah)), "g" ((mp_limb_t)(bh)), \ + "1" ((mp_limb_t)(am)), "g" ((mp_limb_t)(bm)), \ + "2" ((mp_limb_t)(al)), "g" ((mp_limb_t)(bl))) \ + +#define sub_dddmmmsss2(dh, dm, dl, mh, mm, ml, sh, sm, sl) \ + __asm__ ("subl %8,%k2\n\tsbbl %6,%k1\n\tsbbl %4,%k0" \ + : "=r" (dh), "=r" (dm), "=&r" (dl) \ + : "0" ((mp_limb_t)(mh)), "g" ((mp_limb_t)(sh)), \ + "1" ((mp_limb_t)(mm)), "g" ((mp_limb_t)(sm)), \ + "2" ((mp_limb_t)(ml)), "g" ((mp_limb_t)(sl))) \ + +#endif /* x86 */ + + +#if !defined(add_sssaaaaaa2) + +#define add_sssaaaaaa2(sh, sm, sl, ah, am, al, bh, bm, bl) \ + do { \ + mp_limb_t __t, __u; \ + add_ssaaaa(__t, sl, (mp_limb_t) 0, al, (mp_limb_t) 0, bl); \ + add_ssaaaa(__u, sm, (mp_limb_t) 0, am, (mp_limb_t) 0, bm); \ + add_ssaaaa(sh, sm, ah + bh, sm, __u, __t); \ +} while (0) + +#define sub_dddmmmsss2(dh, dm, dl, mh, mm, ml, sh, sm, sl) \ + do { \ + mp_limb_t __t, __u; \ + sub_ddmmss(__t, dl, (mp_limb_t) 0, ml, (mp_limb_t) 0, sl); \ + sub_ddmmss(__u, dm, (mp_limb_t) 0, mm, (mp_limb_t) 0, sm); \ + sub_ddmmss(dh, dm, mh 
- sh, dm, __u, __t); \ + } while (0) + +#endif + + +/* Add ((a * b) / 2^MAG_BITS) * 2^exp into srad*2^srad_exp. + Assumes that srad_exp >= exp and that overflow cannot occur. */ +#define RAD_ADDMUL(srad, srad_exp, a, b, exp) \ + do { \ + uint64_t __a, __b; \ + slong __shift; \ + __a = (a); \ + __b = (b); \ + __shift = (srad_exp) - (exp); \ + if (__shift < MAG_BITS) \ + (srad) += (((__a) * (__b)) >> (MAG_BITS + __shift)) + 1; \ + else \ + (srad) += 1; \ + } while (0) + +void mag_set_ui_2exp_small(mag_t z, ulong x, slong e); + +static void +add_errors(mag_t rad, uint64_t Aerr, slong Aexp, uint64_t Berr, slong Bexp, uint64_t Cerr, slong Cexp) +{ + slong shift; + + if (Aerr && Berr) + { + if (Aexp >= Bexp) + { + shift = Aexp - Bexp; + if (shift < 64) + Aerr = Aerr + (Berr >> shift) + 1; + else + Aerr = Aerr + 1; + } + else + { + shift = Bexp - Aexp; + if (shift < 64) + Aerr = Berr + (Aerr >> shift) + 1; + else + Aerr = Berr + 1; + Aexp = Bexp; + } + } + else if (Berr) + { + Aerr = Berr; + Aexp = Bexp; + } + + if (Aerr && Cerr) + { + if (Aexp >= Cexp) + { + shift = Aexp - Cexp; + if (shift < 64) + Aerr = Aerr + (Cerr >> shift) + 1; + else + Aerr = Aerr + 1; + } + else + { + shift = Cexp - Aexp; + if (shift < 64) + Aerr = Cerr + (Aerr >> shift) + 1; + else + Aerr = Cerr + 1; + Aexp = Cexp; + } + } + else if (Cerr) + { + Aerr = Cerr; + Aexp = Cexp; + } + +#if FLINT_BITS == 64 + mag_set_ui_2exp_small(rad, Aerr, Aexp - MAG_BITS); +#else + mag_set_d(rad, Aerr * (1.0 + 1e-14)); + mag_mul_2exp_si(rad, rad, Aexp - MAG_BITS); +#endif +} + +void +_arb_dot_addmul_generic(mp_ptr sum, mp_ptr serr, mp_ptr tmp, mp_size_t sn, + mp_srcptr xptr, mp_size_t xn, mp_srcptr yptr, mp_size_t yn, + int negative, mp_bitcnt_t shift); + +void +_arb_dot_add_generic(mp_ptr sum, mp_ptr serr, mp_ptr tmp, mp_size_t sn, + mp_srcptr xptr, mp_size_t xn, + int negative, mp_bitcnt_t shift); + +static void +_arb_dot_output(arb_t res, mp_ptr sum, mp_size_t sn, int negative, + uint64_t serr, slong sum_exp, uint64_t srad, slong srad_exp, slong prec) +{ + slong exp_fix; + int inexact; + + if (sum[sn - 1] >= LIMB_TOP) + { + mpn_neg(sum, sum, sn); + negative ^= 1; + } + + exp_fix = 0; + + if (sum[sn - 1] == 0) + { + slong sum_exp2; + mp_size_t sn2; + + sn2 = sn; + sum_exp2 = sum_exp; + + while (sn2 > 0 && sum[sn2 - 1] == 0) + { + sum_exp2 -= FLINT_BITS; + sn2--; + } + + if (sn2 == 0) + { + arf_zero(arb_midref(res)); + inexact = 0; + } + else + { + inexact = _arf_set_round_mpn(arb_midref(res), &exp_fix, sum, sn2, negative, prec, ARF_RND_DOWN); + _fmpz_set_si_small(ARF_EXPREF(arb_midref(res)), exp_fix + sum_exp2); + } + } + else + { + if (sn == 2) /* unnecessary? */ + inexact = _arf_set_round_uiui(arb_midref(res), &exp_fix, sum[1], sum[0], negative, prec, ARF_RND_DOWN); + else + inexact = _arf_set_round_mpn(arb_midref(res), &exp_fix, sum, sn, negative, prec, ARF_RND_DOWN); + + _fmpz_set_si_small(ARF_EXPREF(arb_midref(res)), exp_fix + sum_exp); + } + + add_errors(arb_radref(res), + inexact << MAG_BITS, + exp_fix + sum_exp - prec, + ((uint64_t) serr) << MAG_BITS, + sum_exp - sn * FLINT_BITS, + srad, + srad_exp); +} + +/* xxx: don't use surrounding variables */ +#define ARB_DOT_ADD(s_sum, s_serr, s_sn, s_sum_exp, s_subtract, xm) \ + if (!arf_is_special(xm)) \ + { \ + mp_srcptr xptr; \ + xexp = ARF_EXP(xm); \ + xn = ARF_SIZE(xm); \ + xnegative = ARF_SGNBIT(xm); \ + shift = s_sum_exp - xexp; \ + if (shift >= s_sn * FLINT_BITS) \ + { \ + s_serr++; \ + } \ + else \ + { \ + xptr = (xn <= ARF_NOPTR_LIMBS) ? 
ARF_NOPTR_D(xm) : ARF_PTR_D(xm); \ + _arb_dot_add_generic(s_sum, &s_serr, tmp, s_sn, xptr, xn, xnegative ^ s_subtract, shift); \ + } \ + } \ + +/* xxx: don't use surrounding variables */ +#define ARB_DOT_ADD_RAD(s_srad, s_srad_exp, xr) \ + if (!mag_is_special(xr)) \ + { \ + xrad = MAG_MAN(xr); \ + xrexp = MAG_EXP(xr); \ + shift = s_srad_exp - xrexp; \ + if (shift < 64) \ + s_srad += (xrad >> shift) + 1; \ + else \ + s_srad++; \ + } + +static void +_arf_complex_mul_gauss(arf_t e, arf_t f, const arf_t a, const arf_t b, + const arf_t c, const arf_t d) +{ + mp_srcptr ap, bp, cp, dp; + int asgn, bsgn, csgn, dsgn; + mp_size_t an, bn, cn, dn; + slong aexp, bexp, cexp, dexp; + fmpz texp, uexp; + + fmpz_t za, zb, zc, zd, t, u, v; + slong abot, bbot, cbot, dbot; + + ARF_GET_MPN_READONLY(ap, an, a); + asgn = ARF_SGNBIT(a); + aexp = ARF_EXP(a); + + ARF_GET_MPN_READONLY(bp, bn, b); + bsgn = ARF_SGNBIT(b); + bexp = ARF_EXP(b); + + ARF_GET_MPN_READONLY(cp, cn, c); + csgn = ARF_SGNBIT(c); + cexp = ARF_EXP(c); + + ARF_GET_MPN_READONLY(dp, dn, d); + dsgn = ARF_SGNBIT(d); + dexp = ARF_EXP(d); + + /* Gauss multiplication + e = ac - bd + f = (a+b)(c+d) - ac - bd */ + + abot = aexp - an * FLINT_BITS; + bbot = bexp - bn * FLINT_BITS; + cbot = cexp - cn * FLINT_BITS; + dbot = dexp - dn * FLINT_BITS; + + texp = FLINT_MIN(abot, bbot); + uexp = FLINT_MIN(cbot, dbot); + + fmpz_init(za); + fmpz_init(zb); + fmpz_init(zc); + fmpz_init(zd); + fmpz_init(t); + fmpz_init(u); + fmpz_init(v); + + fmpz_lshift_mpn(za, ap, an, asgn, abot - texp); + fmpz_lshift_mpn(zb, bp, bn, bsgn, bbot - texp); + fmpz_lshift_mpn(zc, cp, cn, csgn, cbot - uexp); + fmpz_lshift_mpn(zd, dp, dn, dsgn, dbot - uexp); + + fmpz_add(t, za, zb); + fmpz_add(v, zc, zd); + fmpz_mul(u, t, v); + fmpz_mul(t, za, zc); + fmpz_mul(v, zb, zd); + fmpz_sub(u, u, t); + fmpz_sub(u, u, v); + fmpz_sub(t, t, v); + + texp += uexp; + arf_set_fmpz_2exp(e, t, &texp); + arf_set_fmpz_2exp(f, u, &texp); + + fmpz_clear(za); + fmpz_clear(zb); + fmpz_clear(zc); + fmpz_clear(zd); + fmpz_clear(t); + fmpz_clear(u); + fmpz_clear(v); +} + +/* TODO: this could be much lower, but it's currently competing + against mulhigh in the Karatsuba range. */ +ARB_DLL slong acb_dot_gauss_dot_cutoff = 128; +#define GAUSS_CUTOFF acb_dot_gauss_dot_cutoff + +void +acb_dot(acb_t res, const acb_t initial, int subtract, acb_srcptr x, slong xstep, acb_srcptr y, slong ystep, slong len, slong prec) +{ + slong i, j, padding, extend; + slong xexp, yexp, exp; + slong re_nonzero, im_nonzero; + slong re_max_exp, re_min_exp, re_sum_exp; + slong im_max_exp, im_min_exp, im_sum_exp; + slong re_srad_exp, re_max_rad_exp; + slong im_srad_exp, im_max_rad_exp; + slong re_prec, im_prec; + slong xrexp, yrexp; + int xnegative, ynegative; + mp_size_t xn, yn, re_sn, im_sn, alloc; + mp_bitcnt_t shift; + arb_srcptr xi, yi; + arf_srcptr xm, ym; + mag_srcptr xr, yr; + mp_limb_t xtop, ytop; + mp_limb_t xrad, yrad; + mp_limb_t re_serr, im_serr; /* Sum over arithmetic errors */ + uint64_t re_srad, im_srad; /* Sum over propagated errors */ + mp_ptr tmp, re_sum, im_sum; /* Workspace */ + slong xoff, yoff; + char * use_gauss; + ARF_ADD_TMP_DECL; + + /* todo: fast fma and fmma (len=2) code */ + if (len <= 1) + { + if (initial == NULL) + { + if (len <= 0) + acb_zero(res); + else + { + acb_mul(res, x, y, prec); + if (subtract) + acb_neg(res, res); + } + return; + } + else if (len <= 0) + { + acb_set_round(res, initial, prec); + return; + } + } + + /* Number of nonzero midpoint terms in sum. 
*/ + re_nonzero = 0; + im_nonzero = 0; + + /* Terms are bounded by 2^max_exp (with WORD_MIN = -infty) */ + re_max_exp = WORD_MIN; + im_max_exp = WORD_MIN; + + /* Propagated error terms are bounded by 2^max_rad_exp */ + re_max_rad_exp = WORD_MIN; + im_max_rad_exp = WORD_MIN; + + /* Used to reduce the precision. */ + re_min_exp = WORD_MAX; + im_min_exp = WORD_MAX; + + /* Account for the initial term. */ + if (initial != NULL) + { + if (!ARB_IS_LAGOM(acb_realref(initial)) || !ARB_IS_LAGOM(acb_imagref(initial))) + { + acb_dot_simple(res, initial, subtract, x, xstep, y, ystep, len, prec); + return; + } + + xm = arb_midref(acb_realref(initial)); + xr = arb_radref(acb_realref(initial)); + + if (!arf_is_special(xm)) + { + re_max_exp = ARF_EXP(xm); + re_nonzero++; + + if (prec > 2 * FLINT_BITS) + re_min_exp = ARF_EXP(xm) - ARF_SIZE(xm) * FLINT_BITS; + } + + if (!mag_is_special(xr)) + re_max_rad_exp = MAG_EXP(xr); + + xm = arb_midref(acb_imagref(initial)); + xr = arb_radref(acb_imagref(initial)); + + if (!arf_is_special(xm)) + { + im_max_exp = ARF_EXP(xm); + im_nonzero++; + + if (prec > 2 * FLINT_BITS) + im_min_exp = ARF_EXP(xm) - ARF_SIZE(xm) * FLINT_BITS; + } + + if (!mag_is_special(xr)) + im_max_rad_exp = MAG_EXP(xr); + } + + for (xoff = 0; xoff < 2; xoff++) + { + for (yoff = 0; yoff < 2; yoff++) + { + slong nonzero, max_exp, min_exp, max_rad_exp; + + if (xoff == yoff) + { + nonzero = re_nonzero; + max_exp = re_max_exp; + min_exp = re_min_exp; + max_rad_exp = re_max_rad_exp; + } + else + { + nonzero = im_nonzero; + max_exp = im_max_exp; + min_exp = im_min_exp; + max_rad_exp = im_max_rad_exp; + } + + /* Determine maximum exponents for the main sum and the radius sum. */ + for (i = 0; i < len; i++) + { + xi = ((arb_srcptr) x) + 2 * i * xstep + xoff; + yi = ((arb_srcptr) y) + 2 * i * ystep + yoff; + + /* Fallback for huge exponents or non-finite values. 
*/
+                if (!ARB_IS_LAGOM(xi) || !ARB_IS_LAGOM(yi))
+                {
+                    acb_dot_simple(res, initial, subtract, x, xstep, y, ystep, len, prec);
+                    return;
+                }
+
+                xm = arb_midref(xi);
+                ym = arb_midref(yi);
+                xr = arb_radref(xi);
+                yr = arb_radref(yi);
+
+                /* (xm+xr)(ym+yr) = xm ym + [xr ym + xm yr + xr yr] */
+                if (!arf_is_special(xm))
+                {
+                    xexp = ARF_EXP(xm);
+
+                    if (!arf_is_special(ym))
+                    {
+                        yexp = ARF_EXP(ym);
+
+                        max_exp = FLINT_MAX(max_exp, xexp + yexp);
+                        nonzero++;
+
+                        if (prec > 2 * FLINT_BITS)
+                        {
+                            slong bot;
+                            bot = (xexp + yexp) - (ARF_SIZE(xm) + ARF_SIZE(ym)) * FLINT_BITS;
+                            min_exp = FLINT_MIN(min_exp, bot);
+                        }
+
+                        if (!mag_is_special(xr))
+                        {
+                            xrexp = MAG_EXP(xr);
+                            max_rad_exp = FLINT_MAX(max_rad_exp, yexp + xrexp);
+
+                            if (!mag_is_special(yr))
+                            {
+                                yrexp = MAG_EXP(yr);
+                                max_rad_exp = FLINT_MAX(max_rad_exp, xexp + yrexp);
+                                max_rad_exp = FLINT_MAX(max_rad_exp, xrexp + yrexp);
+                            }
+                        }
+                        else
+                        {
+                            if (!mag_is_special(yr))
+                            {
+                                yrexp = MAG_EXP(yr);
+                                max_rad_exp = FLINT_MAX(max_rad_exp, xexp + yrexp);
+                            }
+                        }
+                    }
+                    else /* if y = 0, something can happen only if yr != 0 */
+                    {
+                        if (!mag_is_special(yr))
+                        {
+                            yrexp = MAG_EXP(yr);
+                            max_rad_exp = FLINT_MAX(max_rad_exp, xexp + yrexp);
+
+                            if (!mag_is_special(xr))
+                            {
+                                xrexp = MAG_EXP(xr);
+                                max_rad_exp = FLINT_MAX(max_rad_exp, xrexp + yrexp);
+                            }
+                        }
+                    }
+                }
+                else /* if x = 0, something can happen only if xr != 0 */
+                {
+                    if (!mag_is_special(xr))
+                    {
+                        xrexp = MAG_EXP(xr);
+
+                        if (!arf_is_special(ym))
+                        {
+                            yexp = ARF_EXP(ym);
+                            max_rad_exp = FLINT_MAX(max_rad_exp, xrexp + yexp);
+                        }
+
+                        if (!mag_is_special(yr))
+                        {
+                            yrexp = MAG_EXP(yr);
+                            max_rad_exp = FLINT_MAX(max_rad_exp, xrexp + yrexp);
+                        }
+                    }
+                }
+            }
+
+            if (xoff == yoff)
+            {
+                re_nonzero = nonzero;
+                re_max_exp = max_exp;
+                re_min_exp = min_exp;
+                re_max_rad_exp = max_rad_exp;
+            }
+            else
+            {
+                im_nonzero = nonzero;
+                im_max_exp = max_exp;
+                im_min_exp = min_exp;
+                im_max_rad_exp = max_rad_exp;
+            }
+        }
+    }
+
+    re_prec = prec;
+    im_prec = prec;
+
+    if (re_max_exp == WORD_MIN && re_max_rad_exp == WORD_MIN &&
+        im_max_exp == WORD_MIN && im_max_rad_exp == WORD_MIN)
+    {
+        acb_zero(res);
+        return;
+    }
+
+    /* The midpoint sum is zero. */
+    if (re_max_exp == WORD_MIN)
+    {
+        re_prec = 2;
+    }
+    else
+    {
+        if (re_max_rad_exp != WORD_MIN)
+            re_prec = FLINT_MIN(re_prec, re_max_exp - re_max_rad_exp + MAG_BITS);
+        if (re_min_exp != WORD_MAX)
+            re_prec = FLINT_MIN(re_prec, re_max_exp - re_min_exp + MAG_BITS);
+        re_prec = FLINT_MAX(re_prec, 2);
+    }
+
+    if (im_max_exp == WORD_MIN)
+    {
+        im_prec = 2;
+    }
+    else
+    {
+        if (im_max_rad_exp != WORD_MIN)
+            im_prec = FLINT_MIN(im_prec, im_max_exp - im_max_rad_exp + MAG_BITS);
+        if (im_min_exp != WORD_MAX)
+            im_prec = FLINT_MIN(im_prec, im_max_exp - im_min_exp + MAG_BITS);
+        im_prec = FLINT_MAX(im_prec, 2);
+    }
+
+    extend = FLINT_BIT_COUNT(re_nonzero) + 1;
+    padding = 4 + FLINT_BIT_COUNT(len);
+    re_sn = (re_prec + extend + padding + FLINT_BITS - 1) / FLINT_BITS;
+    re_sn = FLINT_MAX(re_sn, 2);
+    re_sum_exp = re_max_exp + extend;
+
+    extend = FLINT_BIT_COUNT(im_nonzero) + 1;
+    padding = 4 + FLINT_BIT_COUNT(len);
+    im_sn = (im_prec + extend + padding + FLINT_BITS - 1) / FLINT_BITS;
+    im_sn = FLINT_MAX(im_sn, 2);
+    im_sum_exp = im_max_exp + extend;
+
+    /* We need sn + 1 limbs for the sum (sn limbs + 1 dummy limb
+       for carry or borrow that avoids an extra branch). We need
+       2 * (sn + 2) limbs to store the product of two numbers
+       with up to (sn + 2) limbs, plus 1 extra limb for shifting
+       the product. */
+    alloc = (re_sn + 1) + (im_sn + 1) + 2 * (FLINT_MAX(re_sn, im_sn) + 2) + 1;
+    ARF_ADD_TMP_ALLOC(re_sum, alloc)
+    im_sum = re_sum + (re_sn + 1);
+    tmp = im_sum + (im_sn + 1);
+
+    /* Sum of propagated errors. */
+    re_srad_exp = re_max_rad_exp;
+    re_srad = 0;
+    im_srad_exp = im_max_rad_exp;
+    im_srad = 0;
+
+    /* Set sum to 0 */
+    re_serr = 0;
+    for (j = 0; j < re_sn + 1; j++)
+        re_sum[j] = 0;
+    im_serr = 0;
+    for (j = 0; j < im_sn + 1; j++)
+        im_sum[j] = 0;
+
+    if (initial != NULL)
+    {
+        xm = arb_midref(acb_realref(initial));
+        xr = arb_radref(acb_realref(initial));
+
+        ARB_DOT_ADD(re_sum, re_serr, re_sn, re_sum_exp, subtract, xm);
+        ARB_DOT_ADD_RAD(re_srad, re_srad_exp, xr);
+
+        xm = arb_midref(acb_imagref(initial));
+        xr = arb_radref(acb_imagref(initial));
+
+        ARB_DOT_ADD(im_sum, im_serr, im_sn, im_sum_exp, subtract, xm);
+        ARB_DOT_ADD_RAD(im_srad, im_srad_exp, xr);
+    }
+
+    /*
+        Look for terms to process using the Gauss multiplication formula.
+        If any such terms are found, we mask the ith entry in use_gauss
+        so that they will be skipped in the main loop.
+        Important: the cutoffs must be such that the fast case
+        (xn <= 2, yn <= 2, sn <= 3) is not hit below and the mask
+        check is done.
+
+        The cutoffs below are not optimal in the generic case; also, it
+        would be nicer to have both mulhigh and Gauss here. A more elegant
+        solution would be to write a fallback version of acb_dot_simple
+        where acb_addmul does the right thing.
+    */
+    use_gauss = NULL;
+
+    if (re_prec >= GAUSS_CUTOFF * FLINT_BITS &&
+        im_prec >= GAUSS_CUTOFF * FLINT_BITS)
+    {
+        arf_t e, f;
+
+        for (i = 0; i < len; i++)
+        {
+            arb_srcptr ai, bi, ci, di;
+            mp_size_t an, bn, cn, dn;
+            slong aexp, bexp, cexp, dexp;
+
+            ai = ((arb_srcptr) x) + 2 * i * xstep;
+            bi = ((arb_srcptr) x) + 2 * i * xstep + 1;
+            ci = ((arb_srcptr) y) + 2 * i * ystep;
+            di = ((arb_srcptr) y) + 2 * i * ystep + 1;
+
+            an = ARF_SIZE(arb_midref(ai));
+            bn = ARF_SIZE(arb_midref(bi));
+            cn = ARF_SIZE(arb_midref(ci));
+            dn = ARF_SIZE(arb_midref(di));
+
+            aexp = ARF_EXP(arb_midref(ai));
+            bexp = ARF_EXP(arb_midref(bi));
+            cexp = ARF_EXP(arb_midref(ci));
+            dexp = ARF_EXP(arb_midref(di));
+
+            if (an >= GAUSS_CUTOFF && bn >= GAUSS_CUTOFF &&
+                cn >= GAUSS_CUTOFF && dn >= GAUSS_CUTOFF &&
+                FLINT_ABS(an - bn) <= 2 &&
+                FLINT_ABS(cn - dn) <= 2 &&
+                FLINT_ABS(aexp - bexp) <= 64 &&
+                FLINT_ABS(cexp - dexp) <= 64 &&
+                re_sum_exp - (aexp + cexp) < 0.1 * re_prec &&
+                im_sum_exp - (aexp + dexp) < 0.1 * im_prec &&
+                an + cn < 2.2 * re_sn && an + dn < 2.2 * im_sn)
+            {
+                if (use_gauss == NULL)
+                {
+                    use_gauss = flint_calloc(len, sizeof(char));
+                    arf_init(e);
+                    arf_init(f);
+                }
+
+                use_gauss[i] = 1;
+                _arf_complex_mul_gauss(e, f, arb_midref(ai), arb_midref(bi), arb_midref(ci), arb_midref(di));
+                ARB_DOT_ADD(re_sum, re_serr, re_sn, re_sum_exp, 0, e);
+                ARB_DOT_ADD(im_sum, im_serr, im_sn, im_sum_exp, 0, f);
+            }
+        }
+
+        if (use_gauss != NULL)
+        {
+            arf_clear(e);
+            arf_clear(f);
+        }
+    }
+
+    for (xoff = 0; xoff < 2; xoff++)
+    {
+        for (yoff = 0; yoff < 2; yoff++)
+        {
+            slong sum_exp, srad_exp;
+            mp_ptr sum;
+            mp_size_t sn;
+            mp_limb_t serr;
+            uint64_t srad;
+            int flipsign;
+
+            if (xoff == yoff)
+            {
+                sum_exp = re_sum_exp;
+                srad_exp = re_srad_exp;
+                sum = re_sum;
+                sn = re_sn;
+                if (re_max_exp == WORD_MIN && re_max_rad_exp == WORD_MIN)
+                    continue;
+            }
+            else
+            {
+                sum_exp = im_sum_exp;
+                srad_exp = im_srad_exp;
+                sum = im_sum;
+                sn = im_sn;
+                if (im_max_exp == WORD_MIN && im_max_rad_exp == WORD_MIN)
+                    continue;
+            }
+
+            serr = 0;
+            srad = 0;
+            flipsign = (xoff + yoff == 2);
+
+            for (i = 0; i <
len; i++) + { + xi = ((arb_srcptr) x) + 2 * i * xstep + xoff; + yi = ((arb_srcptr) y) + 2 * i * ystep + yoff; + + xm = arb_midref(xi); + ym = arb_midref(yi); + xr = arb_radref(xi); + yr = arb_radref(yi); + + /* The midpoints of x[i] and y[i] are both nonzero. */ + if (!arf_is_special(xm) && !arf_is_special(ym)) + { + xexp = ARF_EXP(xm); + xn = ARF_SIZE(xm); + xnegative = ARF_SGNBIT(xm); + + yexp = ARF_EXP(ym); + yn = ARF_SIZE(ym); + ynegative = ARF_SGNBIT(ym); + + exp = xexp + yexp; + shift = sum_exp - exp; + + if (shift >= sn * FLINT_BITS) + { + /* We may yet need the top limbs for bounds. */ + ARF_GET_TOP_LIMB(xtop, xm); + ARF_GET_TOP_LIMB(ytop, ym); + serr++; + } + else if (xn <= 2 && yn <= 2 && sn <= 3) + { + mp_limb_t x1, x0, y1, y0; + mp_limb_t u3, u2, u1, u0; + + if (xn == 1 && yn == 1) + { + xtop = ARF_NOPTR_D(xm)[0]; + ytop = ARF_NOPTR_D(ym)[0]; + umul_ppmm(u3, u2, xtop, ytop); + u1 = u0 = 0; + } + else if (xn == 2 && yn == 2) + { + x0 = ARF_NOPTR_D(xm)[0]; + x1 = ARF_NOPTR_D(xm)[1]; + y0 = ARF_NOPTR_D(ym)[0]; + y1 = ARF_NOPTR_D(ym)[1]; + xtop = x1; + ytop = y1; + nn_mul_2x2(u3, u2, u1, u0, x1, x0, y1, y0); + } + else if (xn == 1) + { + x0 = ARF_NOPTR_D(xm)[0]; + y0 = ARF_NOPTR_D(ym)[0]; + y1 = ARF_NOPTR_D(ym)[1]; + xtop = x0; + ytop = y1; + nn_mul_2x1(u3, u2, u1, y1, y0, x0); + u0 = 0; + } + else + { + x0 = ARF_NOPTR_D(xm)[0]; + x1 = ARF_NOPTR_D(xm)[1]; + y0 = ARF_NOPTR_D(ym)[0]; + xtop = x1; + ytop = y0; + nn_mul_2x1(u3, u2, u1, x1, x0, y0); + u0 = 0; + } + + if (sn == 2) + { + if (shift < FLINT_BITS) + { + serr += ((u2 << (FLINT_BITS - shift)) != 0) || (u1 != 0) || (u0 != 0); + u2 = (u2 >> shift) | (u3 << (FLINT_BITS - shift)); + u3 = (u3 >> shift); + } + else if (shift == FLINT_BITS) + { + serr += (u2 != 0) || (u1 != 0) || (u0 != 0); + u2 = u3; + u3 = 0; + } + else /* FLINT_BITS < shift < 2 * FLINT_BITS */ + { + serr += ((u3 << (2 * FLINT_BITS - shift)) != 0) || (u2 != 0) || (u1 != 0) || (u0 != 0); + u2 = (u3 >> (shift - FLINT_BITS)); + u3 = 0; + } + + if (xnegative ^ ynegative ^ flipsign) + sub_ddmmss(sum[1], sum[0], sum[1], sum[0], u3, u2); + else + add_ssaaaa(sum[1], sum[0], sum[1], sum[0], u3, u2); + } + else if (sn == 3) + { + if (shift < FLINT_BITS) + { + serr += ((u1 << (FLINT_BITS - shift)) != 0) || (u0 != 0); + u1 = (u1 >> shift) | (u2 << (FLINT_BITS - shift)); + u2 = (u2 >> shift) | (u3 << (FLINT_BITS - shift)); + u3 = (u3 >> shift); + } + else if (shift == FLINT_BITS) + { + serr += (u1 != 0) || (u0 != 0); + u1 = u2; + u2 = u3; + u3 = 0; + } + else if (shift < 2 * FLINT_BITS) + { + serr += ((u2 << (2 * FLINT_BITS - shift)) != 0) || (u1 != 0) || (u0 != 0); + u1 = (u3 << (2 * FLINT_BITS - shift)) | (u2 >> (shift - FLINT_BITS)); + u2 = (u3 >> (shift - FLINT_BITS)); + u3 = 0; + } + else if (shift == 2 * FLINT_BITS) + { + serr += (u2 != 0) || (u1 != 0) || (u0 != 0); + u1 = u3; + u2 = 0; + u3 = 0; + } + else /* 2 * FLINT_BITS < shift < 3 * FLINT_BITS */ + { + serr += ((u3 << (3 * FLINT_BITS - shift)) != 0) || (u2 != 0) || (u1 != 0) || (u0 != 0); + u1 = (u3 >> (shift - 2 * FLINT_BITS)); + u2 = 0; + u3 = 0; + } + + if (xnegative ^ ynegative ^ flipsign) + sub_dddmmmsss2(sum[2], sum[1], sum[0], sum[2], sum[1], sum[0], u3, u2, u1); + else + add_sssaaaaaa2(sum[2], sum[1], sum[0], sum[2], sum[1], sum[0], u3, u2, u1); + } + } + else + { + mp_srcptr xptr, yptr; + + xptr = (xn <= ARF_NOPTR_LIMBS) ? ARF_NOPTR_D(xm) : ARF_PTR_D(xm); + yptr = (yn <= ARF_NOPTR_LIMBS) ? 
ARF_NOPTR_D(ym) : ARF_PTR_D(ym); + + xtop = xptr[xn - 1]; + ytop = yptr[yn - 1]; + + if (use_gauss == NULL || use_gauss[i] == 0) + _arb_dot_addmul_generic(sum, &serr, tmp, sn, xptr, xn, yptr, yn, xnegative ^ ynegative ^ flipsign, shift); + } + + xrad = MAG_MAN(xr); + yrad = MAG_MAN(yr); + + if (xrad != 0 && yrad != 0) + { + xrexp = MAG_EXP(xr); + yrexp = MAG_EXP(yr); + + RAD_ADDMUL(srad, srad_exp, (xtop >> (FLINT_BITS - MAG_BITS)) + 1, yrad, xexp + yrexp); + RAD_ADDMUL(srad, srad_exp, (ytop >> (FLINT_BITS - MAG_BITS)) + 1, xrad, yexp + xrexp); + RAD_ADDMUL(srad, srad_exp, xrad, yrad, xrexp + yrexp); + } + else if (xrad != 0) + { + xrexp = MAG_EXP(xr); + RAD_ADDMUL(srad, srad_exp, (ytop >> (FLINT_BITS - MAG_BITS)) + 1, xrad, yexp + xrexp); + } + else if (yrad != 0) + { + yrexp = MAG_EXP(yr); + RAD_ADDMUL(srad, srad_exp, (xtop >> (FLINT_BITS - MAG_BITS)) + 1, yrad, xexp + yrexp); + } + } + else + { + xrad = MAG_MAN(xr); + yrad = MAG_MAN(yr); + + xexp = ARF_EXP(xm); + yexp = ARF_EXP(ym); + + xrexp = MAG_EXP(xr); + yrexp = MAG_EXP(yr); + + /* (xm+xr)(ym+yr) = xm ym + [xm yr + ym xr + xr yr] */ + if (yrad && !arf_is_special(xm)) + { + ARF_GET_TOP_LIMB(xtop, xm); + RAD_ADDMUL(srad, srad_exp, (xtop >> (FLINT_BITS - MAG_BITS)) + 1, yrad, xexp + yrexp); + } + + if (xrad && !arf_is_special(ym)) + { + ARF_GET_TOP_LIMB(ytop, ym); + RAD_ADDMUL(srad, srad_exp, (ytop >> (FLINT_BITS - MAG_BITS)) + 1, xrad, yexp + xrexp); + } + + if (xrad && yrad) + { + RAD_ADDMUL(srad, srad_exp, xrad, yrad, xrexp + yrexp); + } + } + } + + if (xoff == yoff) + { + re_serr += serr; + re_srad += srad; + } + else + { + im_serr += serr; + im_srad += srad; + } + } + } + + _arb_dot_output(acb_realref(res), re_sum, re_sn, subtract, re_serr, re_sum_exp, re_srad, re_srad_exp, re_prec); + _arb_dot_output(acb_imagref(res), im_sum, im_sn, subtract, im_serr, im_sum_exp, im_srad, im_srad_exp, im_prec); + + ARF_ADD_TMP_FREE(re_sum, alloc); + if (use_gauss != NULL) + flint_free(use_gauss); +} diff --git a/acb/dot_precise.c b/acb/dot_precise.c new file mode 100644 index 00000000..78eef5e4 --- /dev/null +++ b/acb/dot_precise.c @@ -0,0 +1,50 @@ +/* + Copyright (C) 2018 Fredrik Johansson + + This file is part of Arb. + + Arb is free software: you can redistribute it and/or modify it under + the terms of the GNU Lesser General Public License (LGPL) as published + by the Free Software Foundation; either version 2.1 of the License, or + (at your option) any later version. See . +*/ + +#include "acb.h" + +void +acb_dot_precise(acb_t res, const acb_t initial, int subtract, acb_srcptr x, slong xstep, + acb_srcptr y, slong ystep, slong len, slong prec) +{ + arb_ptr tmp; + slong i; + + tmp = flint_malloc(sizeof(arb_struct) * (4 * len)); + + for (i = 0; i < len; i++) + { + tmp[0 * len + i] = *acb_realref(x + i * xstep); + tmp[1 * len + i] = *acb_imagref(x + i * xstep); + tmp[2 * len + i] = *acb_realref(y + i * ystep); + arb_init(tmp + 3 * len + i); + arb_neg(tmp + 3 * len + i, acb_imagref(y + i * ystep)); + } + + arb_dot_precise(acb_realref(res), initial == NULL ? NULL : acb_realref(initial), subtract, + tmp, 1, tmp + 2 * len, 1, 2 * len, prec); + + for (i = 0; i < len; i++) + arb_clear(tmp + 3 * len + i); + + for (i = 0; i < len; i++) + { + tmp[0 * len + i] = *acb_realref(x + i * xstep); + tmp[1 * len + i] = *acb_imagref(x + i * xstep); + tmp[2 * len + i] = *acb_imagref(y + i * ystep); + tmp[3 * len + i] = *acb_realref(y + i * ystep); + } + + arb_dot_precise(acb_imagref(res), initial == NULL ? 
NULL : acb_imagref(initial), subtract, + tmp, 1, tmp + 2 * len, 1, 2 * len, prec); + + flint_free(tmp); +} diff --git a/acb/dot_simple.c b/acb/dot_simple.c new file mode 100644 index 00000000..215596e8 --- /dev/null +++ b/acb/dot_simple.c @@ -0,0 +1,47 @@ +/* + Copyright (C) 2018 Fredrik Johansson + + This file is part of Arb. + + Arb is free software: you can redistribute it and/or modify it under + the terms of the GNU Lesser General Public License (LGPL) as published + by the Free Software Foundation; either version 2.1 of the License, or + (at your option) any later version. See . +*/ + +#include "acb.h" + +void +acb_dot_simple(acb_t res, const acb_t initial, int subtract, + acb_srcptr x, slong xstep, acb_srcptr y, slong ystep, slong len, slong prec) +{ + slong i; + + if (len <= 0) + { + if (initial == NULL) + acb_zero(res); + else + acb_set_round(res, initial, prec); + return; + } + + if (initial == NULL) + { + acb_mul(res, x, y, prec); + } + else + { + if (subtract) + acb_neg(res, initial); + else + acb_set(res, initial); + acb_addmul(res, x, y, prec); + } + + for (i = 1; i < len; i++) + acb_addmul(res, x + i * xstep, y + i * ystep, prec); + + if (subtract) + acb_neg(res, res); +} diff --git a/acb/test/t-dot.c b/acb/test/t-dot.c new file mode 100644 index 00000000..0cd25848 --- /dev/null +++ b/acb/test/t-dot.c @@ -0,0 +1,200 @@ +/* + Copyright (C) 2018 Fredrik Johansson + + This file is part of Arb. + + Arb is free software: you can redistribute it and/or modify it under + the terms of the GNU Lesser General Public License (LGPL) as published + by the Free Software Foundation; either version 2.1 of the License, or + (at your option) any later version. See . +*/ + +#include "acb.h" + +ARB_DLL extern slong acb_dot_gauss_dot_cutoff; + +int main() +{ + slong iter; + flint_rand_t state; + + flint_printf("dot...."); + fflush(stdout); + + flint_randinit(state); + + for (iter = 0; iter < 1000000 * arb_test_multiplier(); iter++) + { + acb_ptr x, y; + acb_t s1, s2, z; + slong i, len, prec, xbits, ybits, ebits; + int ok, initial, subtract, revx, revy; + + if (n_randint(state, 100) == 0) + len = n_randint(state, 50); + else if (n_randint(state, 10) == 0) + len = n_randint(state, 5); + else + len = n_randint(state, 3); + + acb_dot_gauss_dot_cutoff = 3 + n_randint(state, 30); + + if (n_randint(state, 10) != 0 || len > 10) + { + prec = 2 + n_randint(state, 500); + xbits = 2 + n_randint(state, 500); + ybits = 2 + n_randint(state, 500); + } + else + { + prec = 2 + n_randint(state, 5000); + xbits = 2 + n_randint(state, 5000); + ybits = 2 + n_randint(state, 5000); + } + + if (n_randint(state, 100) == 0) + ebits = 2 + n_randint(state, 100); + else + ebits = 2 + n_randint(state, 10); + + initial = n_randint(state, 2); + subtract = n_randint(state, 2); + revx = n_randint(state, 2); + revy = n_randint(state, 2); + + x = _acb_vec_init(len); + y = _acb_vec_init(len); + acb_init(s1); + acb_init(s2); + acb_init(z); + + switch (n_randint(state, 3)) + { + case 0: + for (i = 0; i < len; i++) + { + acb_randtest(x + i, state, xbits, ebits); + acb_randtest(y + i, state, ybits, ebits); + } + break; + + /* Test with cancellation */ + case 1: + for (i = 0; i < len; i++) + { + if (i <= len / 2) + { + acb_randtest(x + i, state, xbits, ebits); + acb_randtest(y + i, state, ybits, ebits); + } + else + { + acb_neg(x + i, x + len - i - 1); + acb_set(y + i, y + len - i - 1); + } + } + break; + + default: + for (i = 0; i < len; i++) + { + if (i <= len / 2) + { + acb_randtest(x + i, state, xbits, ebits); + acb_randtest(y + i, 
state, ybits, ebits); + } + else + { + acb_neg_round(x + i, x + len - i - 1, 2 + n_randint(state, 500)); + acb_set_round(y + i, y + len - i - 1, 2 + n_randint(state, 500)); + } + } + break; + } + + acb_randtest(s1, state, 200, 100); + acb_randtest(s2, state, 200, 100); + acb_randtest(z, state, xbits, ebits); + + acb_dot(s1, initial ? z : NULL, subtract, + revx ? (x + len - 1) : x, revx ? -1 : 1, + revy ? (y + len - 1) : y, revy ? -1 : 1, + len, prec); + + acb_dot_precise(s2, initial ? z : NULL, subtract, + revx ? (x + len - 1) : x, revx ? -1 : 1, + revy ? (y + len - 1) : y, revy ? -1 : 1, + len, ebits <= 12 ? ARF_PREC_EXACT : 2 * prec + 100); + + if (ebits <= 12) + ok = acb_contains(s1, s2); + else + ok = acb_overlaps(s1, s2); + + if (!ok) + { + flint_printf("FAIL\n\n"); + flint_printf("iter = %wd, len = %wd, prec = %wd, ebits = %wd, subtract = %d\n\n", iter, len, prec, ebits, subtract); + + if (initial) + { + flint_printf("z = ", i); acb_printn(z, 100, ARB_STR_MORE); flint_printf(" (%wd)\n\n", acb_bits(z)); + } + + for (i = 0; i < len; i++) + { + flint_printf("x[%wd] = ", i); acb_printn(x + i, 100, ARB_STR_MORE); flint_printf(" (%wd)\n", acb_bits(x + i)); + flint_printf("y[%wd] = ", i); acb_printn(y + i, 100, ARB_STR_MORE); flint_printf(" (%wd)\n", acb_bits(y + i)); + } + flint_printf("\n\n"); + flint_printf("s1 = "); acb_printn(s1, 100, ARB_STR_MORE); flint_printf("\n\n"); + flint_printf("s2 = "); acb_printn(s2, 100, ARB_STR_MORE); flint_printf("\n\n"); + flint_abort(); + } + + /* With the fast algorithm, we expect identical results when + reversing the vectors. */ + if (ebits <= 12) + { + revx ^= 1; + revy ^= 1; + + acb_dot(s2, initial ? z : NULL, subtract, + revx ? (x + len - 1) : x, revx ? -1 : 1, + revy ? (y + len - 1) : y, revy ? -1 : 1, + len, prec); + + if (!acb_equal(s1, s2)) + { + flint_printf("FAIL (reversal)\n\n"); + flint_printf("iter = %wd, len = %wd, prec = %wd, ebits = %wd\n\n", iter, len, prec, ebits); + + if (initial) + { + flint_printf("z = ", i); acb_printn(z, 100, ARB_STR_MORE); flint_printf(" (%wd)\n\n", acb_bits(z)); + } + + for (i = 0; i < len; i++) + { + flint_printf("x[%wd] = ", i); acb_printn(x + i, 100, ARB_STR_MORE); flint_printf(" (%wd)\n", acb_bits(x + i)); + flint_printf("y[%wd] = ", i); acb_printn(y + i, 100, ARB_STR_MORE); flint_printf(" (%wd)\n", acb_bits(y + i)); + } + flint_printf("\n\n"); + flint_printf("s1 = "); acb_printn(s1, 100, ARB_STR_MORE); flint_printf("\n\n"); + flint_printf("s2 = "); acb_printn(s2, 100, ARB_STR_MORE); flint_printf("\n\n"); + flint_abort(); + } + } + + acb_clear(s1); + acb_clear(s2); + acb_clear(z); + _acb_vec_clear(x, len); + _acb_vec_clear(y, len); + } + + flint_randclear(state); + flint_cleanup(); + flint_printf("PASS\n"); + return EXIT_SUCCESS; +} diff --git a/acb_poly/div_series.c b/acb_poly/div_series.c index 81e6d58b..419cef9a 100644 --- a/acb_poly/div_series.c +++ b/acb_poly/div_series.c @@ -44,7 +44,7 @@ _acb_poly_div_series(acb_ptr Q, acb_srcptr A, slong Alen, { /* The basecase algorithm is faster for much larger Blen and n than this, but unfortunately has worse numerical stability. */ - slong i, j; + slong i; acb_t q; acb_init(q); @@ -54,16 +54,8 @@ _acb_poly_div_series(acb_ptr Q, acb_srcptr A, slong Alen, for (i = 1; i < n; i++) { - acb_mul(Q + i, B + 1, Q + i - 1, prec); - - for (j = 2; j < FLINT_MIN(i + 1, Blen); j++) - acb_addmul(Q + i, B + j, Q + i - j, prec); - - if (i < Alen) - acb_sub(Q + i, A + i, Q + i, prec); - else - acb_neg(Q + i, Q + i); - + acb_dot(Q + i, (i < Alen) ? 
A + i : NULL, 1, + B + 1, 1, Q + i - 1, -1, FLINT_MIN(i, Blen - 1), prec); if (!acb_is_one(q)) acb_mul(Q + i, Q + i, q, prec); } diff --git a/acb_poly/exp_series.c b/acb_poly/exp_series.c index c69d99f9..414f0a05 100644 --- a/acb_poly/exp_series.c +++ b/acb_poly/exp_series.c @@ -12,7 +12,7 @@ #include "acb_poly.h" /* allow changing this from the test code */ -ARB_DLL slong acb_poly_newton_exp_cutoff = 120; +ARB_DLL slong acb_poly_newton_exp_cutoff = 0; /* with inverse=1 simultaneously computes g = exp(-x) to length n with inverse=0 uses g as scratch space, computing @@ -112,39 +112,51 @@ _acb_poly_exp_series(acb_ptr f, acb_srcptr h, slong hlen, slong n, slong prec) _acb_vec_zero(f + j - d + 1, n - (j - d + 1)); acb_clear(t); } - else if (hlen <= acb_poly_newton_exp_cutoff) - { - _acb_poly_exp_series_basecase(f, h, hlen, n, prec); - } else { - acb_ptr g, t; - acb_t u; - int fix; + slong cutoff; - g = _acb_vec_init((n + 1) / 2); - fix = (hlen < n || h == f || !acb_is_zero(h)); + if (acb_poly_newton_exp_cutoff != 0) + cutoff = acb_poly_newton_exp_cutoff; + else if (prec <= 256) + cutoff = 750; + else + cutoff = 1e5 / pow(log(prec), 3); - if (fix) + if (hlen <= cutoff) { - t = _acb_vec_init(n); - _acb_vec_set(t + 1, h + 1, hlen - 1); + _acb_poly_exp_series_basecase(f, h, hlen, n, prec); } else - t = (acb_ptr) h; + { + acb_ptr g, t; + acb_t u; + int fix; - acb_init(u); - acb_exp(u, h, prec); + g = _acb_vec_init((n + 1) / 2); + fix = (hlen < n || h == f || !acb_is_zero(h)); - _acb_poly_exp_series_newton(f, g, t, n, prec, 0, acb_poly_newton_exp_cutoff); + if (fix) + { + t = _acb_vec_init(n); + _acb_vec_set(t + 1, h + 1, hlen - 1); + } + else + t = (acb_ptr) h; - if (!acb_is_one(u)) - _acb_vec_scalar_mul(f, f, n, u, prec); + acb_init(u); + acb_exp(u, h, prec); - _acb_vec_clear(g, (n + 1) / 2); - if (fix) - _acb_vec_clear(t, n); - acb_clear(u); + _acb_poly_exp_series_newton(f, g, t, n, prec, 0, cutoff); + + if (!acb_is_one(u)) + _acb_vec_scalar_mul(f, f, n, u, prec); + + _acb_vec_clear(g, (n + 1) / 2); + if (fix) + _acb_vec_clear(t, n); + acb_clear(u); + } } } diff --git a/acb_poly/exp_series_basecase.c b/acb_poly/exp_series_basecase.c index e6eefb53..1de27418 100644 --- a/acb_poly/exp_series_basecase.c +++ b/acb_poly/exp_series_basecase.c @@ -11,13 +11,11 @@ #include "acb_poly.h" -#define MUL_CUTOFF 24 - static void _acb_poly_exp_series_basecase_rec(acb_ptr f, acb_ptr a, acb_srcptr h, slong hlen, slong n, slong prec) { - slong j, k; + slong k; acb_t s; acb_init(s); @@ -29,10 +27,7 @@ _acb_poly_exp_series_basecase_rec(acb_ptr f, acb_ptr a, for (k = 1; k < n; k++) { - acb_zero(s); - for (j = 1; j < FLINT_MIN(k + 1, hlen); j++) - acb_addmul(s, a + j, f + k - j, prec); - + acb_dot(s, NULL, 0, a + 1, 1, f + k - 1, -1, FLINT_MIN(k, hlen - 1), prec); acb_div_ui(f + k, s, k, prec); } @@ -45,7 +40,7 @@ _acb_poly_exp_series_basecase(acb_ptr f, { hlen = FLINT_MIN(n, hlen); - if (n < MUL_CUTOFF || hlen < 0.9 * n) + if (n < 20 || hlen < 0.9 * n || prec <= 2 * FLINT_BITS || n < 1000.0 / log(prec + 10) - 70) { acb_ptr t = _acb_vec_init(hlen); _acb_poly_exp_series_basecase_rec(f, t, h, hlen, n, prec); diff --git a/acb_poly/inv_series.c b/acb_poly/inv_series.c index 9fbb4d92..1872e5a5 100644 --- a/acb_poly/inv_series.c +++ b/acb_poly/inv_series.c @@ -37,7 +37,7 @@ _acb_poly_inv_series(acb_ptr Qinv, } else { - slong i, j, blen; + slong i, blen; /* The basecase algorithm is faster for much larger Qlen or len than this, but unfortunately also much less numerically stable. 
*/ @@ -48,15 +48,10 @@ _acb_poly_inv_series(acb_ptr Qinv, for (i = 1; i < blen; i++) { - acb_mul(Qinv + i, Q + 1, Qinv + i - 1, prec); - - for (j = 2; j < FLINT_MIN(i + 1, Qlen); j++) - acb_addmul(Qinv + i, Q + j, Qinv + i - j, prec); - + acb_dot(Qinv + i, NULL, 1, + Q + 1, 1, Qinv + i - 1, -1, FLINT_MIN(i, Qlen - 1), prec); if (!acb_is_one(Qinv)) acb_mul(Qinv + i, Qinv + i, Qinv, prec); - - acb_neg(Qinv + i, Qinv + i); } if (len > blen) @@ -116,4 +111,3 @@ acb_poly_inv_series(acb_poly_t Qinv, const acb_poly_t Q, slong n, slong prec) _acb_poly_set_length(Qinv, n); _acb_poly_normalise(Qinv); } - diff --git a/acb_poly/mullow.c b/acb_poly/mullow.c index 75002552..ac034c0e 100644 --- a/acb_poly/mullow.c +++ b/acb_poly/mullow.c @@ -11,17 +11,46 @@ #include "acb_poly.h" -#define CUTOFF 4 - void _acb_poly_mullow(acb_ptr res, acb_srcptr poly1, slong len1, acb_srcptr poly2, slong len2, slong n, slong prec) { - if (n < CUTOFF || len1 < CUTOFF || len2 < CUTOFF) + if (n == 1) + { + acb_mul(res, poly1, poly2, prec); + } + else if (n <= 7 || len1 <= 7 || len2 <= 7) + { _acb_poly_mullow_classical(res, poly1, len1, poly2, len2, n, prec); + } else - _acb_poly_mullow_transpose(res, poly1, len1, poly2, len2, n, prec); + { + slong cutoff; + double p; + + if (prec <= 2 * FLINT_BITS) + { + cutoff = 110; + } + else + { + p = log(prec); + + cutoff = 10000.0 / (p * p * p); + cutoff = FLINT_MIN(cutoff, 60); + if (poly1 == poly2 && prec >= 256) + cutoff *= 1.25; + if (poly1 == poly2 && prec >= 4096) + cutoff *= 1.25; + cutoff = FLINT_MAX(cutoff, 8); + } + + if (2 * FLINT_MIN(len1, len2) <= cutoff || n <= cutoff) + _acb_poly_mullow_classical(res, poly1, len1, poly2, len2, n, prec); + else + _acb_poly_mullow_transpose(res, poly1, len1, poly2, len2, n, prec); + } } void diff --git a/acb_poly/mullow_classical.c b/acb_poly/mullow_classical.c index 316415bb..11465bf0 100644 --- a/acb_poly/mullow_classical.c +++ b/acb_poly/mullow_classical.c @@ -26,35 +26,55 @@ _acb_poly_mullow_classical(acb_ptr res, } else if (poly1 == poly2 && len1 == len2) { - slong i; + slong i, start, stop; - _acb_vec_scalar_mul(res, poly1, FLINT_MIN(len1, n), poly1, prec); - _acb_vec_scalar_mul(res + len1, poly1 + 1, n - len1, poly1 + len1 - 1, prec); + acb_sqr(res, poly1, prec); + acb_mul(res + 1, poly1, poly1 + 1, prec); + acb_mul_2exp_si(res + 1, res + 1, 1); - for (i = 1; i < len1 - 1; i++) - _acb_vec_scalar_addmul(res + i + 1, poly1 + 1, - FLINT_MIN(i - 1, n - (i + 1)), poly1 + i, prec); + for (i = 2; i < FLINT_MIN(n, 2 * len1 - 3); i++) + { + start = FLINT_MAX(0, i - len1 + 1); + stop = FLINT_MIN(len1 - 1, (i + 1) / 2 - 1); - for (i = 1; i < FLINT_MIN(2 * len1 - 2, n); i++) + acb_dot(res + i, NULL, 0, poly1 + start, 1, + poly1 + i - start, -1, stop - start + 1, prec); acb_mul_2exp_si(res + i, res + i, 1); + if (i % 2 == 0 && i / 2 < len1) + acb_addmul(res + i, poly1 + i / 2, poly1 + i / 2, prec); + } - for (i = 1; i < FLINT_MIN(len1 - 1, (n + 1) / 2); i++) - acb_addmul(res + 2 * i, poly1 + i, poly1 + i, prec); + if (len1 > 2 && n >= 2 * len1 - 2) + { + acb_mul(res + 2 * len1 - 3, poly1 + len1 - 1, poly1 + len1 - 2, prec); + acb_mul_2exp_si(res + 2 * len1 - 3, res + 2 * len1 - 3, 1); + } + + if (n >= 2 * len1 - 1) + acb_sqr(res + 2 * len1 - 2, poly1 + len1 - 1, prec); + } + else if (len1 == 1) + { + _acb_vec_scalar_mul(res, poly2, n, poly1, prec); + } + else if (len2 == 1) + { + _acb_vec_scalar_mul(res, poly1, n, poly2, prec); } else { - slong i; + slong i, top1, top2; - _acb_vec_scalar_mul(res, poly1, FLINT_MIN(len1, n), poly2, prec); + 
acb_mul(res, poly1, poly2, prec); - if (n > len1) - _acb_vec_scalar_mul(res + len1, poly2 + 1, n - len1, - poly1 + len1 - 1, prec); + for (i = 1; i < n; i++) + { + top1 = FLINT_MIN(len1 - 1, i); + top2 = FLINT_MIN(len2 - 1, i); - for (i = 0; i < FLINT_MIN(len1, n) - 1; i++) - _acb_vec_scalar_addmul(res + i + 1, poly2 + 1, - FLINT_MIN(len2, n - i) - 1, - poly1 + i, prec); + acb_dot(res + i, NULL, 0, poly1 + i - top2, 1, + poly2 + top2, -1, top1 + top2 - i + 1, prec); + } } } @@ -94,4 +114,3 @@ acb_poly_mullow_classical(acb_poly_t res, const acb_poly_t poly1, _acb_poly_set_length(res, n); _acb_poly_normalise(res); } - diff --git a/acb_poly/sin_cos_pi_series.c b/acb_poly/sin_cos_pi_series.c index 846cd263..58d48dce 100644 --- a/acb_poly/sin_cos_pi_series.c +++ b/acb_poly/sin_cos_pi_series.c @@ -11,10 +11,8 @@ #include "acb_poly.h" -#define TANGENT_CUTOFF 80 - void -_acb_poly_sin_cos_pi_series(acb_ptr s, acb_ptr c, const acb_srcptr h, slong hlen, slong n, slong prec) +_acb_poly_sin_cos_pi_series(acb_ptr s, acb_ptr c, acb_srcptr h, slong hlen, slong n, slong prec) { hlen = FLINT_MIN(hlen, n); @@ -36,10 +34,25 @@ _acb_poly_sin_cos_pi_series(acb_ptr s, acb_ptr c, const acb_srcptr h, slong hlen acb_mul(c + 1, s, t, prec); acb_clear(t); } - else if (hlen < TANGENT_CUTOFF) - _acb_poly_sin_cos_series_basecase(s, c, h, hlen, n, prec, 1); else - _acb_poly_sin_cos_series_tangent(s, c, h, hlen, n, prec, 1); + { + slong cutoff; + + if (prec <= 128) + { + cutoff = 1400; + } + else + { + cutoff = 100000 / pow(log(prec), 3); + cutoff = FLINT_MIN(cutoff, 700); + } + + if (hlen < cutoff) + _acb_poly_sin_cos_series_basecase(s, c, h, hlen, n, prec, 1); + else + _acb_poly_sin_cos_series_tangent(s, c, h, hlen, n, prec, 1); + } } void @@ -73,4 +86,3 @@ acb_poly_sin_cos_pi_series(acb_poly_t s, acb_poly_t c, _acb_poly_set_length(c, n); _acb_poly_normalise(c); } - diff --git a/acb_poly/sin_cos_series.c b/acb_poly/sin_cos_series.c index 42b5c733..6557ab74 100644 --- a/acb_poly/sin_cos_series.c +++ b/acb_poly/sin_cos_series.c @@ -11,10 +11,8 @@ #include "acb_poly.h" -#define TANGENT_CUTOFF 80 - void -_acb_poly_sin_cos_series(acb_ptr s, acb_ptr c, const acb_srcptr h, slong hlen, slong n, slong prec) +_acb_poly_sin_cos_series(acb_ptr s, acb_ptr c, acb_srcptr h, slong hlen, slong n, slong prec) { hlen = FLINT_MIN(hlen, n); @@ -35,10 +33,25 @@ _acb_poly_sin_cos_series(acb_ptr s, acb_ptr c, const acb_srcptr h, slong hlen, s acb_mul(c + 1, s, t, prec); acb_clear(t); } - else if (hlen < TANGENT_CUTOFF) - _acb_poly_sin_cos_series_basecase(s, c, h, hlen, n, prec, 0); else - _acb_poly_sin_cos_series_tangent(s, c, h, hlen, n, prec, 0); + { + slong cutoff; + + if (prec <= 128) + { + cutoff = 1400; + } + else + { + cutoff = 100000 / pow(log(prec), 3); + cutoff = FLINT_MIN(cutoff, 700); + } + + if (hlen < cutoff) + _acb_poly_sin_cos_series_basecase(s, c, h, hlen, n, prec, 0); + else + _acb_poly_sin_cos_series_tangent(s, c, h, hlen, n, prec, 0); + } } void diff --git a/acb_poly/sin_cos_series_basecase.c b/acb_poly/sin_cos_series_basecase.c index 9af63a9b..3b44649a 100644 --- a/acb_poly/sin_cos_series_basecase.c +++ b/acb_poly/sin_cos_series_basecase.c @@ -15,7 +15,7 @@ void _acb_poly_sin_cos_series_basecase(acb_ptr s, acb_ptr c, acb_srcptr h, slong hlen, slong n, slong prec, int times_pi) { - slong j, k, alen = FLINT_MIN(n, hlen); + slong k, alen = FLINT_MIN(n, hlen); acb_ptr a; acb_t t, u; @@ -46,15 +46,8 @@ _acb_poly_sin_cos_series_basecase(acb_ptr s, acb_ptr c, acb_srcptr h, slong hlen for (k = 1; k < n; k++) { - acb_zero(t); - 
acb_zero(u); - - for (j = 1; j < FLINT_MIN(k + 1, hlen); j++) - { - acb_submul(t, a + j, s + k - j, prec); - acb_addmul(u, a + j, c + k - j, prec); - } - + acb_dot(t, NULL, 1, a + 1, 1, s + k - 1, -1, FLINT_MIN(k, hlen - 1), prec); + acb_dot(u, NULL, 0, a + 1, 1, c + k - 1, -1, FLINT_MIN(k, hlen - 1), prec); acb_div_ui(c + k, t, k, prec); acb_div_ui(s + k, u, k, prec); } @@ -92,4 +85,3 @@ acb_poly_sin_cos_series_basecase(acb_poly_t s, acb_poly_t c, _acb_poly_set_length(c, n); _acb_poly_normalise(c); } - diff --git a/acb_poly/sinh_cosh_series.c b/acb_poly/sinh_cosh_series.c index ff3accce..d2d82233 100644 --- a/acb_poly/sinh_cosh_series.c +++ b/acb_poly/sinh_cosh_series.c @@ -12,7 +12,7 @@ #include "acb_poly.h" void -_acb_poly_sinh_cosh_series(acb_ptr s, acb_ptr c, const acb_srcptr h, slong hlen, slong n, slong prec) +_acb_poly_sinh_cosh_series(acb_ptr s, acb_ptr c, acb_srcptr h, slong hlen, slong n, slong prec) { hlen = FLINT_MIN(hlen, n); @@ -32,10 +32,20 @@ _acb_poly_sinh_cosh_series(acb_ptr s, acb_ptr c, const acb_srcptr h, slong hlen, acb_mul(c + 1, s, t, prec); acb_clear(t); } - else if (hlen < 60 || n < 120) - _acb_poly_sinh_cosh_series_basecase(s, c, h, hlen, n, prec); else - _acb_poly_sinh_cosh_series_exponential(s, c, h, hlen, n, prec); + { + slong cutoff; + + if (prec <= 128) + cutoff = 400; + else + cutoff = 30000 / pow(log(prec), 3); + + if (hlen < cutoff) + _acb_poly_sinh_cosh_series_basecase(s, c, h, hlen, n, prec); + else + _acb_poly_sinh_cosh_series_exponential(s, c, h, hlen, n, prec); + } } void @@ -69,4 +79,3 @@ acb_poly_sinh_cosh_series(acb_poly_t s, acb_poly_t c, _acb_poly_set_length(c, n); _acb_poly_normalise(c); } - diff --git a/acb_poly/sinh_cosh_series_basecase.c b/acb_poly/sinh_cosh_series_basecase.c index 0e3b393b..9e98f28a 100644 --- a/acb_poly/sinh_cosh_series_basecase.c +++ b/acb_poly/sinh_cosh_series_basecase.c @@ -15,7 +15,7 @@ void _acb_poly_sinh_cosh_series_basecase(acb_ptr s, acb_ptr c, acb_srcptr h, slong hlen, slong n, slong prec) { - slong j, k, alen = FLINT_MIN(n, hlen); + slong k, alen = FLINT_MIN(n, hlen); acb_ptr a; acb_t t, u; @@ -37,15 +37,8 @@ _acb_poly_sinh_cosh_series_basecase(acb_ptr s, acb_ptr c, acb_srcptr h, slong hl for (k = 1; k < n; k++) { - acb_zero(t); - acb_zero(u); - - for (j = 1; j < FLINT_MIN(k + 1, hlen); j++) - { - acb_addmul(t, a + j, s + k - j, prec); - acb_addmul(u, a + j, c + k - j, prec); - } - + acb_dot(t, NULL, 0, a + 1, 1, s + k - 1, -1, FLINT_MIN(k, hlen - 1), prec); + acb_dot(u, NULL, 0, a + 1, 1, c + k - 1, -1, FLINT_MIN(k, hlen - 1), prec); acb_div_ui(c + k, t, k, prec); acb_div_ui(s + k, u, k, prec); } @@ -83,4 +76,3 @@ acb_poly_sinh_cosh_series_basecase(acb_poly_t s, acb_poly_t c, _acb_poly_set_length(c, n); _acb_poly_normalise(c); } - diff --git a/arb/test/t-dot.c b/arb/test/t-dot.c index 9c5e2c67..58230145 100644 --- a/arb/test/t-dot.c +++ b/arb/test/t-dot.c @@ -35,7 +35,7 @@ int main() else len = n_randint(state, 3); - if (n_randint(state, 100) == 0 || len > 10) + if (n_randint(state, 10) != 0 || len > 10) { prec = 2 + n_randint(state, 500); xbits = 2 + n_randint(state, 500); diff --git a/doc/source/acb.rst b/doc/source/acb.rst index ecb8760e..f7f0136e 100644 --- a/doc/source/acb.rst +++ b/doc/source/acb.rst @@ -499,6 +499,46 @@ Arithmetic Sets *z* to the quotient of *x* and *y*. +Dot product +------------------------------------------------------------------------------- + +.. 
function:: void acb_dot_precise(acb_t res, const acb_t s, int subtract, acb_srcptr x, slong xstep, acb_srcptr y, slong ystep, slong len, slong prec)
+
+.. function:: void acb_dot_simple(acb_t res, const acb_t s, int subtract, acb_srcptr x, slong xstep, acb_srcptr y, slong ystep, slong len, slong prec)
+
+.. function:: void acb_dot(acb_t res, const acb_t s, int subtract, acb_srcptr x, slong xstep, acb_srcptr y, slong ystep, slong len, slong prec)
+
+    Computes the dot product of the vectors *x* and *y*, setting
+    *res* to `s + (-1)^{subtract} \sum_{i=0}^{len-1} x_i y_i`.
+
+    The initial term *s* is optional and can be
+    omitted by passing *NULL* (equivalently, `s = 0`).
+    The parameter *subtract* must be 0 or 1.
+    The length *len* is allowed to be negative, which is equivalent
+    to a length of zero.
+    The parameters *xstep* and *ystep* specify step lengths for
+    traversing subsequences of the vectors *x* and *y*; either can be
+    negative to step in the reverse direction starting from
+    the initial pointer.
+    Aliasing is allowed between *res* and *s* but not between
+    *res* and the entries of *x* and *y*.
+
+    The default version determines the optimal precision for each term
+    and performs all internal calculations using mpn arithmetic
+    with minimal overhead. This is the preferred way to compute a
+    dot product; it is generally much faster and more precise
+    than a simple loop.
+
+    The *simple* version performs fused multiply-add operations in
+    a simple loop. This can be used for
+    testing purposes and is also used as a fallback by the
+    default version when the exponents are out of range
+    for the optimized code.
+
+    The *precise* version computes the dot product exactly up to the
+    final rounding. This can be extremely slow and is only intended
+    for testing.
+
 Mathematical constants
 -------------------------------------------------------------------------------
diff --git a/doc/source/arb.rst b/doc/source/arb.rst
index 44064dd4..c3a87bb2 100644
--- a/doc/source/arb.rst
+++ b/doc/source/arb.rst
@@ -782,7 +782,7 @@ Arithmetic
     Sets `z = x / (2^n - 1)`, rounded to *prec* bits.
 
-Sum and dot product
+Dot product
 -------------------------------------------------------------------------------
 
 .. function:: void arb_dot_precise(arb_t res, const arb_t s, int subtract, arb_srcptr x, slong xstep, arb_srcptr y, slong ystep, slong len, slong prec)
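As a usage sketch of the new interface (an editor's illustration, not part of the patch: the helper name `poly_eval_by_dot` is hypothetical, while every function it calls is existing Arb API), the following evaluates `p(z) = \sum_{i=0}^{len-1} c_i z^i` by materializing the powers of *z* and collapsing the sum into a single `acb_dot` call, so that the whole sum is accumulated with one rounding/error pass instead of one per `acb_addmul`::

    #include "acb.h"

    /* Hypothetical helper: res = sum_{i=0}^{len-1} c[i] * z^i. */
    void
    poly_eval_by_dot(acb_t res, acb_srcptr c, slong len, const acb_t z, slong prec)
    {
        acb_ptr zpow;
        slong i;

        if (len <= 0)
        {
            acb_zero(res);
            return;
        }

        zpow = _acb_vec_init(len);

        acb_one(zpow);                                  /* z^0 */
        for (i = 1; i < len; i++)
            acb_mul(zpow + i, zpow + i - 1, z, prec);   /* z^i */

        /* initial = NULL (s = 0), subtract = 0, unit strides */
        acb_dot(res, NULL, 0, c, 1, zpow, 1, len, prec);

        _acb_vec_clear(zpow, len);
    }

The negative-stride form used throughout this patch follows the same pattern: a call such as `acb_dot(s, NULL, 0, a + 1, 1, f + k - 1, -1, FLINT_MIN(k, hlen - 1), prec)` (as in *exp_series_basecase.c* above) walks *a* forward and *f* backward, forming the convolution term `\sum_j a_{1+j} f_{k-1-j}` in a single call.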