diff --git a/arb.h b/arb.h index acacd0d2..e84b2ad6 100644 --- a/arb.h +++ b/arb.h @@ -425,6 +425,10 @@ void arb_submul_si(arb_t z, const arb_t x, slong y, slong prec); void arb_submul_ui(arb_t z, const arb_t x, ulong y, slong prec); void arb_submul_fmpz(arb_t z, const arb_t x, const fmpz_t y, slong prec); +void arb_fma(arb_t res, const arb_t x, const arb_t y, const arb_t z, slong prec); +void arb_fma_arf(arb_t res, const arb_t x, const arf_t y, const arb_t z, slong prec); +void arb_fma_ui(arb_t res, const arb_t x, ulong y, const arb_t z, slong prec); + void arb_dot_simple(arb_t res, const arb_t initial, int subtract, arb_srcptr x, slong xstep, arb_srcptr y, slong ystep, slong len, slong prec); void arb_dot_precise(arb_t res, const arb_t initial, int subtract, diff --git a/arb/fma.c b/arb/fma.c new file mode 100644 index 00000000..0c9156f5 --- /dev/null +++ b/arb/fma.c @@ -0,0 +1,125 @@ +/* + Copyright (C) 2021 Fredrik Johansson + + This file is part of Arb. + + Arb is free software: you can redistribute it and/or modify it under + the terms of the GNU Lesser General Public License (LGPL) as published + by the Free Software Foundation; either version 2.1 of the License, or + (at your option) any later version. See . +*/ + +#include "arb.h" + +void +arb_fma_arf(arb_t res, const arb_t x, const arf_t y, const arb_t z, slong prec) +{ + mag_t ym; + int inexact; + + if (arb_is_exact(x)) + { + inexact = arf_fma(arb_midref(res), arb_midref(x), y, arb_midref(z), prec, ARB_RND); + + if (inexact) + arf_mag_add_ulp(arb_radref(res), arb_radref(z), arb_midref(res), prec); + else + mag_set(arb_radref(res), arb_radref(z)); + } + else if (ARB_IS_LAGOM(res) && ARB_IS_LAGOM(x) && ARF_IS_LAGOM(y) && ARB_IS_LAGOM(z)) + { + mag_t tm; + + mag_fast_init_set_arf(ym, y); + *tm = *arb_radref(z); + mag_fast_addmul(tm, ym, arb_radref(x)); + *arb_radref(res) = *tm; + + inexact = arf_fma(arb_midref(res), arb_midref(x), y, arb_midref(z), prec, ARB_RND); + if (inexact) + arf_mag_fast_add_ulp(arb_radref(res), arb_radref(res), arb_midref(res), prec); + } + else + { + mag_t tm; + mag_init(tm); + + mag_init_set_arf(ym, y); + mag_set(tm, arb_radref(z)); + + mag_addmul(tm, ym, arb_radref(x)); + mag_set(arb_radref(res), tm); + + inexact = arf_fma(arb_midref(res), arb_midref(x), y, arb_midref(z), prec, ARB_RND); + if (inexact) + arf_mag_add_ulp(arb_radref(res), arb_radref(res), arb_midref(res), prec); + + mag_clear(tm); + mag_clear(ym); + } +} + +void +arb_fma(arb_t res, const arb_t x, const arb_t y, const arb_t z, slong prec) +{ + mag_t zr, xm, ym; + int inexact; + + if (arb_is_exact(y)) + { + arb_fma_arf(res, x, arb_midref(y), z, prec); + } + else if (arb_is_exact(x)) + { + arb_fma_arf(res, y, arb_midref(x), z, prec); + } + else if (ARB_IS_LAGOM(res) && ARB_IS_LAGOM(x) && ARB_IS_LAGOM(y) && ARB_IS_LAGOM(z)) + { + mag_fast_init_set_arf(xm, arb_midref(x)); + mag_fast_init_set_arf(ym, arb_midref(y)); + + mag_fast_init_set(zr, arb_radref(z)); + mag_fast_addmul(zr, xm, arb_radref(y)); + mag_fast_addmul(zr, ym, arb_radref(x)); + mag_fast_addmul(zr, arb_radref(x), arb_radref(y)); + + inexact = arf_fma(arb_midref(res), arb_midref(x), arb_midref(y), arb_midref(z), + prec, ARF_RND_DOWN); + + if (inexact) + arf_mag_fast_add_ulp(zr, zr, arb_midref(res), prec); + + *arb_radref(res) = *zr; + } + else + { + mag_init_set_arf(xm, arb_midref(x)); + mag_init_set_arf(ym, arb_midref(y)); + + mag_init_set(zr, arb_radref(z)); + mag_addmul(zr, xm, arb_radref(y)); + mag_addmul(zr, ym, arb_radref(x)); + mag_addmul(zr, arb_radref(x), arb_radref(y)); + + inexact = arf_fma(arb_midref(res), arb_midref(x), arb_midref(y), arb_midref(z), + prec, ARF_RND_DOWN); + + if (inexact) + arf_mag_add_ulp(arb_radref(res), zr, arb_midref(res), prec); + else + mag_set(arb_radref(res), zr); + + mag_clear(zr); + mag_clear(xm); + mag_clear(ym); + } +} + +void +arb_fma_ui(arb_t res, const arb_t x, ulong y, const arb_t z, slong prec) +{ + arf_t t; + arf_init_set_ui(t, y); /* no need to free */ + arb_fma_arf(res, x, t, z, prec); +} + diff --git a/arb/test/t-fma.c b/arb/test/t-fma.c new file mode 100644 index 00000000..6561673a --- /dev/null +++ b/arb/test/t-fma.c @@ -0,0 +1,125 @@ +/* + Copyright (C) 2012 Fredrik Johansson + + This file is part of Arb. + + Arb is free software: you can redistribute it and/or modify it under + the terms of the GNU Lesser General Public License (LGPL) as published + by the Free Software Foundation; either version 2.1 of the License, or + (at your option) any later version. See . +*/ + +#include "arb.h" + +void +arb_fma_naive(arb_t res, const arb_t x, const arb_t y, const arb_t z, slong prec) +{ + arb_t t; + arb_init(t); + arb_set(t, z); + arb_addmul(t, x, y, prec); + arb_set(res, t); + arb_clear(t); +} + +int main() +{ + slong iter; + flint_rand_t state; + + flint_printf("fma...."); + fflush(stdout); + + flint_randinit(state); + + for (iter = 0; iter < 10000 * arb_test_multiplier(); iter++) + { + arb_t x, y, z, res1, res2; + slong prec; + int aliasing; + + arb_init(x); + arb_init(y); + arb_init(z); + arb_init(res1); + arb_init(res2); + + prec = 2 + n_randint(state, 200); + + arb_randtest_special(x, state, 200, 100); + arb_randtest_special(y, state, 200, 100); + arb_randtest_special(z, state, 200, 100); + arb_randtest_special(res1, state, 200, 100); + arb_randtest_special(res2, state, 200, 100); + + if (n_randint(state, 10) == 0 && + fmpz_bits(ARF_EXPREF(arb_midref(x))) < 10 && + fmpz_bits(ARF_EXPREF(arb_midref(y))) < 10 && + fmpz_bits(ARF_EXPREF(arb_midref(z))) < 10) + { + prec = ARF_PREC_EXACT; + } + + aliasing = n_randint(state, 7); + + switch (aliasing) + { + case 0: + arb_fma(res1, x, y, z, prec); + arb_fma_naive(res2, x, y, z, prec); + break; + case 1: + arb_set(res1, z); + arb_fma(res1, x, y, res1, prec); + arb_fma_naive(res2, x, y, z, prec); + break; + case 2: + arb_set(res1, x); + arb_fma(res1, res1, y, z, prec); + arb_fma_naive(res2, x, y, z, prec); + break; + case 3: + arb_set(res1, y); + arb_fma(res1, x, res1, z, prec); + arb_fma_naive(res2, x, y, z, prec); + break; + case 4: + arb_fma(res1, x, x, z, prec); + arb_fma_naive(res2, x, x, z, prec); + break; + case 5: + arb_set(res1, x); + arb_fma(res1, res1, res1, z, prec); + arb_fma_naive(res2, x, x, z, prec); + break; + default: + arb_set(res1, x); + arb_fma(res1, res1, res1, res1, prec); + arb_fma_naive(res2, x, x, x, prec); + break; + } + + if (!arb_equal(res1, res2)) + { + flint_printf("FAIL!\n"); + flint_printf("prec = %wd, aliasing = %d\n\n", prec, aliasing); + flint_printf("x = "); arb_printd(x, 30); flint_printf("\n\n"); + flint_printf("y = "); arb_printd(y, 30); flint_printf("\n\n"); + flint_printf("z = "); arb_printd(z, 30); flint_printf("\n\n"); + flint_printf("res1 = "); arb_printd(res1, 30); flint_printf("\n\n"); + flint_printf("res2 = "); arb_printd(res2, 30); flint_printf("\n\n"); + flint_abort(); + } + + arb_clear(x); + arb_clear(y); + arb_clear(z); + arb_clear(res1); + arb_clear(res2); + } + + flint_randclear(state); + flint_cleanup(); + flint_printf("PASS\n"); + return EXIT_SUCCESS; +} diff --git a/arf.h b/arf.h index c9e62f32..495293ce 100644 --- a/arf.h +++ b/arf.h @@ -1071,6 +1071,8 @@ arf_submul_fmpz(arf_ptr z, arf_srcptr x, const fmpz_t y, slong prec, arf_rnd_t r return arf_submul_mpz(z, x, COEFF_TO_PTR(*y), prec, rnd); } +int arf_fma(arf_ptr res, arf_srcptr x, arf_srcptr y, arf_srcptr z, slong prec, arf_rnd_t rnd); + int arf_sosq(arf_t z, const arf_t x, const arf_t y, slong prec, arf_rnd_t rnd); int arf_div(arf_ptr z, arf_srcptr x, arf_srcptr y, slong prec, arf_rnd_t rnd); diff --git a/arf/fma.c b/arf/fma.c new file mode 100644 index 00000000..36afc0bc --- /dev/null +++ b/arf/fma.c @@ -0,0 +1,78 @@ +/* + Copyright (C) 2021 Fredrik Johansson + + This file is part of Arb. + + Arb is free software: you can redistribute it and/or modify it under + the terms of the GNU Lesser General Public License (LGPL) as published + by the Free Software Foundation; either version 2.1 of the License, or + (at your option) any later version. See . +*/ + +#include "arf.h" + +int +arf_fma(arf_ptr res, arf_srcptr x, arf_srcptr y, arf_srcptr z, slong prec, arf_rnd_t rnd) +{ + mp_size_t xn, yn, zn, tn, alloc; + mp_srcptr xptr, yptr, zptr; + mp_ptr tptr, tptr2; + fmpz_t texp; + slong shift; + int tsgnbit, inexact; + ARF_MUL_TMP_DECL + + if (arf_is_special(x) || arf_is_special(y) || arf_is_special(z)) + { + if (arf_is_zero(z)) + { + return arf_mul(res, x, y, prec, rnd); + } + else if (arf_is_finite(x) && arf_is_finite(y)) + { + return arf_set_round(res, z, prec, rnd); + } + else + { + /* todo: speed up */ + arf_t t; + arf_init(t); + arf_mul(t, x, y, ARF_PREC_EXACT, ARF_RND_DOWN); + inexact = arf_add(res, z, t, prec, rnd); + arf_clear(t); + return inexact; + } + } + + tsgnbit = ARF_SGNBIT(x) ^ ARF_SGNBIT(y); + ARF_GET_MPN_READONLY(xptr, xn, x); + ARF_GET_MPN_READONLY(yptr, yn, y); + ARF_GET_MPN_READONLY(zptr, zn, z); + + fmpz_init(texp); + + _fmpz_add2_fast(texp, ARF_EXPREF(x), ARF_EXPREF(y), 0); + shift = _fmpz_sub_small(ARF_EXPREF(z), texp); + + alloc = tn = xn + yn; + ARF_MUL_TMP_ALLOC(tptr2, alloc) + tptr = tptr2; + + ARF_MPN_MUL(tptr, xptr, xn, yptr, yn); + + tn -= (tptr[0] == 0); + tptr += (tptr[0] == 0); + + if (shift >= 0) + inexact = _arf_add_mpn(res, zptr, zn, ARF_SGNBIT(z), ARF_EXPREF(z), + tptr, tn, tsgnbit, shift, prec, rnd); + else + inexact = _arf_add_mpn(res, tptr, tn, tsgnbit, texp, + zptr, zn, ARF_SGNBIT(z), -shift, prec, rnd); + + ARF_MUL_TMP_FREE(tptr2, alloc) + fmpz_clear(texp); + + return inexact; +} + diff --git a/arf/test/t-fma.c b/arf/test/t-fma.c new file mode 100644 index 00000000..9cefaac3 --- /dev/null +++ b/arf/test/t-fma.c @@ -0,0 +1,141 @@ +/* + Copyright (C) 2012 Fredrik Johansson + + This file is part of Arb. + + Arb is free software: you can redistribute it and/or modify it under + the terms of the GNU Lesser General Public License (LGPL) as published + by the Free Software Foundation; either version 2.1 of the License, or + (at your option) any later version. See . +*/ + +#include "arf.h" + +int +arf_fma_naive(arf_t res, const arf_t x, const arf_t y, const arf_t z, slong prec, arf_rnd_t rnd) +{ + arf_t t; + int inexact; + + arf_init(t); + arf_mul(t, x, y, ARF_PREC_EXACT, ARF_RND_DOWN); + + inexact = arf_add(res, z, t, prec, rnd); + + arf_clear(t); + + return inexact; +} + +int main() +{ + slong iter; + flint_rand_t state; + + flint_printf("fma...."); + fflush(stdout); + + flint_randinit(state); + + for (iter = 0; iter < 10000 * arb_test_multiplier(); iter++) + { + arf_t x, y, z, res1, res2; + slong prec, r1, r2; + arf_rnd_t rnd; + int aliasing; + + arf_init(x); + arf_init(y); + arf_init(z); + arf_init(res1); + arf_init(res2); + + prec = 2 + n_randint(state, 200); + + arf_randtest_special(x, state, 200, 100); + arf_randtest_special(y, state, 200, 100); + arf_randtest_special(z, state, 200, 100); + arf_randtest_special(res1, state, 200, 100); + arf_randtest_special(res2, state, 200, 100); + + if (n_randint(state, 10) == 0 && + fmpz_bits(ARF_EXPREF(x)) < 10 && + fmpz_bits(ARF_EXPREF(y)) < 10 && + fmpz_bits(ARF_EXPREF(z)) < 10) + { + prec = ARF_PREC_EXACT; + } + + switch (n_randint(state, 5)) + { + case 0: rnd = ARF_RND_DOWN; break; + case 1: rnd = ARF_RND_UP; break; + case 2: rnd = ARF_RND_FLOOR; break; + case 3: rnd = ARF_RND_CEIL; break; + default: rnd = ARF_RND_NEAR; break; + } + + aliasing = n_randint(state, 7); + + switch (aliasing) + { + case 0: + r1 = arf_fma(res1, x, y, z, prec, rnd); + r2 = arf_fma_naive(res2, x, y, z, prec, rnd); + break; + case 1: + arf_set(res1, z); + r1 = arf_fma(res1, x, y, res1, prec, rnd); + r2 = arf_fma_naive(res2, x, y, z, prec, rnd); + break; + case 2: + arf_set(res1, x); + r1 = arf_fma(res1, res1, y, z, prec, rnd); + r2 = arf_fma_naive(res2, x, y, z, prec, rnd); + break; + case 3: + arf_set(res1, y); + r1 = arf_fma(res1, x, res1, z, prec, rnd); + r2 = arf_fma_naive(res2, x, y, z, prec, rnd); + break; + case 4: + r1 = arf_fma(res1, x, x, z, prec, rnd); + r2 = arf_fma_naive(res2, x, x, z, prec, rnd); + break; + case 5: + arf_set(res1, x); + r1 = arf_fma(res1, res1, res1, z, prec, rnd); + r2 = arf_fma_naive(res2, x, x, z, prec, rnd); + break; + default: + arf_set(res1, x); + r1 = arf_fma(res1, res1, res1, res1, prec, rnd); + r2 = arf_fma_naive(res2, x, x, x, prec, rnd); + break; + } + + if (!arf_equal(res1, res2) || r1 != r2) + { + flint_printf("FAIL!\n"); + flint_printf("prec = %wd, rnd = %d, aliasing = %d\n\n", prec, rnd, aliasing); + flint_printf("x = "); arf_print(x); flint_printf("\n\n"); + flint_printf("y = "); arf_print(y); flint_printf("\n\n"); + flint_printf("z = "); arf_print(z); flint_printf("\n\n"); + flint_printf("res1 = "); arf_print(res1); flint_printf("\n\n"); + flint_printf("res2 = "); arf_print(res2); flint_printf("\n\n"); + flint_printf("r1 = %wd, r2 = %wd\n", r1, r2); + flint_abort(); + } + + arf_clear(x); + arf_clear(y); + arf_clear(z); + arf_clear(res1); + arf_clear(res2); + } + + flint_randclear(state); + flint_cleanup(); + flint_printf("PASS\n"); + return EXIT_SUCCESS; +} diff --git a/doc/source/arb.rst b/doc/source/arb.rst index 31ba1f90..4161d06a 100644 --- a/doc/source/arb.rst +++ b/doc/source/arb.rst @@ -828,6 +828,13 @@ Arithmetic Sets `z = z - x \cdot y`, rounded to prec bits. The precision can be *ARF_PREC_EXACT* provided that the result fits in memory. +.. function:: void arb_fma(arb_t res, const arb_t x, const arb_t y, const arb_t z, slong prec) + void arb_fma_arf(arb_t res, const arb_t x, const arf_t y, const arb_t z, slong prec) + void arb_fma_ui(arb_t res, const arb_t x, ulong y, const arb_t z, slong prec) + + Sets *res* to `x \cdot y + z`. This is equivalent to an *addmul* except + that *res* and *z* can be separate variables. + .. function:: void arb_inv(arb_t z, const arb_t x, slong prec) Sets *z* to `1 / x`. diff --git a/doc/source/arf.rst b/doc/source/arf.rst index 603cb90c..0d8ba4fd 100644 --- a/doc/source/arf.rst +++ b/doc/source/arf.rst @@ -621,6 +621,11 @@ Addition and multiplication Performs a fused multiply-subtract `z = z - x \cdot y`, updating *z* in-place. +.. function:: int arf_fma(arf_t res, const arf_t x, const arf_t y, const arf_t z, slong prec, arf_rnd_t rnd) + + Sets *res* to `x \cdot y + z`. This is equivalent to an *addmul* except + that *res* and *z* can be separate variables. + .. function:: int arf_sosq(arf_t res, const arf_t x, const arf_t y, slong prec, arf_rnd_t rnd) Sets *res* to `x^2 + y^2`, rounded to *prec* bits in the direction specified by *rnd*.