/* Mirror of https://github.com/vale981/arb (synced 2025-03-05 09:21:38 -05:00);
   549 lines, 16 KiB, C. */
/*
    Copyright (C) 2018 Fredrik Johansson

    This file is part of Arb.

    Arb is free software: you can redistribute it and/or modify it under
    the terms of the GNU Lesser General Public License (LGPL) as published
    by the Free Software Foundation; either version 2.1 of the License, or
    (at your option) any later version.  See <http://www.gnu.org/licenses/>.
*/

#include "arb_mat.h"
int arb_mat_is_lagom(const arb_mat_t A)
|
|
{
|
|
slong i, j, M, N;
|
|
|
|
M = arb_mat_nrows(A);
|
|
N = arb_mat_ncols(A);
|
|
|
|
for (i = 0; i < M; i++)
|
|
{
|
|
for (j = 0; j < N; j++)
|
|
{
|
|
if (!ARB_IS_LAGOM(arb_mat_entry(A, i, j)))
|
|
return 0;
|
|
}
|
|
}
|
|
|
|
return 1;
|
|
}
/* allow changing this from the test code */
|
|
ARB_DLL slong arb_mat_mul_block_min_block_size = 0;
void
|
|
arb_mat_mid_addmul_block_fallback(arb_mat_t C,
|
|
const arb_mat_t A, const arb_mat_t B,
|
|
slong block_start,
|
|
slong block_end,
|
|
slong prec)
|
|
{
|
|
slong M, P, n;
|
|
slong i, j;
|
|
arb_ptr tmpA, tmpB;
|
|
|
|
M = arb_mat_nrows(A);
|
|
P = arb_mat_ncols(B);
|
|
|
|
n = block_end - block_start;
|
|
|
|
tmpA = flint_malloc(sizeof(arb_struct) * (M * n + P * n));
|
|
tmpB = tmpA + M * n;
|
|
|
|
for (i = 0; i < M; i++)
|
|
{
|
|
for (j = 0; j < n; j++)
|
|
{
|
|
*arb_midref(tmpA + i * n + j) = *arb_midref(arb_mat_entry(A, i, block_start + j));
|
|
mag_init(arb_radref(tmpA + i * n + j));
|
|
}
|
|
}
|
|
|
|
for (i = 0; i < P; i++)
|
|
{
|
|
for (j = 0; j < n; j++)
|
|
{
|
|
*arb_midref(tmpB + i * n + j) = *arb_midref(arb_mat_entry(B, block_start + j, i));
|
|
mag_init(arb_radref(tmpB + i * n + j));
|
|
}
|
|
}
|
|
|
|
for (i = 0; i < M; i++)
|
|
{
|
|
for (j = 0; j < P; j++)
|
|
{
|
|
arb_dot(arb_mat_entry(C, i, j),
|
|
(block_start == 0) ? NULL : arb_mat_entry(C, i, j), 0,
|
|
tmpA + i * n, 1, tmpB + j * n, 1, n, prec);
|
|
}
|
|
}
|
|
|
|
flint_free(tmpA);
|
|
}
void
|
|
arb_mat_mid_addmul_block_prescaled(arb_mat_t C,
|
|
const arb_mat_t A, const arb_mat_t B,
|
|
slong block_start,
|
|
slong block_end,
|
|
const slong * A_min, /* A per-row bottom exponent */
|
|
const slong * B_min, /* B per-row bottom exponent */
|
|
slong prec)
|
|
{
|
|
slong M, P, n;
|
|
slong i, j;
|
|
slong M0, M1, P0, P1, Mstep, Pstep;
|
|
int inexact;
|
|
|
|
/* flint_printf("block mul from %wd to %wd\n", block_start, block_end); */
|
|
|
|
M = arb_mat_nrows(A);
|
|
P = arb_mat_ncols(B);
|
|
|
|
n = block_end - block_start;
|
|
|
|
/* Create sub-blocks to keep matrices nearly square. Necessary? */
|
|
#if 1
|
|
Mstep = (M < 2 * n) ? M : n;
|
|
Pstep = (P < 2 * n) ? P : n;
|
|
#else
|
|
Mstep = M;
|
|
Pstep = P;
|
|
#endif
|
|
|
|
for (M0 = 0; M0 < M; M0 += Mstep)
|
|
{
|
|
for (P0 = 0; P0 < P; P0 += Pstep)
|
|
{
|
|
fmpz_mat_t AA, BB, CC;
|
|
arb_t t;
|
|
fmpz_t e;
|
|
|
|
M1 = FLINT_MIN(M0 + Mstep, M);
|
|
P1 = FLINT_MIN(P0 + Pstep, P);
|
|
|
|
fmpz_mat_init(AA, M1 - M0, n);
|
|
fmpz_mat_init(BB, n, P1 - P0);
|
|
fmpz_mat_init(CC, M1 - M0, P1 - P0);
|
|
|
|
/* Convert to fixed-point matrices. */
|
|
for (i = M0; i < M1; i++)
|
|
{
|
|
if (A_min[i] == WORD_MIN) /* only zeros in this row */
|
|
continue;
|
|
|
|
for (j = 0; j < n; j++)
|
|
{
|
|
inexact = arf_get_fmpz_fixed_si(fmpz_mat_entry(AA, i - M0, j),
|
|
arb_midref(arb_mat_entry(A, i, block_start + j)), A_min[i]);
|
|
|
|
if (inexact)
|
|
{
|
|
flint_printf("matrix multiplication: bad exponent!\n");
|
|
flint_abort();
|
|
}
|
|
}
|
|
}
|
|
|
|
for (i = P0; i < P1; i++)
|
|
{
|
|
if (B_min[i] == WORD_MIN) /* only zeros in this column */
|
|
continue;
|
|
|
|
for (j = 0; j < n; j++)
|
|
{
|
|
inexact = arf_get_fmpz_fixed_si(fmpz_mat_entry(BB, j, i - P0),
|
|
arb_midref(arb_mat_entry(B, block_start + j, i)), B_min[i]);
|
|
|
|
if (inexact)
|
|
{
|
|
flint_printf("matrix multiplication: bad exponent!\n");
|
|
flint_abort();
|
|
}
|
|
}
|
|
}
|
|
|
|
/* The main multiplication */
|
|
fmpz_mat_mul(CC, AA, BB);
|
|
/* flint_printf("bits %wd %wd %wd\n", fmpz_mat_max_bits(CC),
|
|
fmpz_mat_max_bits(AA), fmpz_mat_max_bits(BB)); */
|
|
|
|
fmpz_mat_clear(AA);
|
|
fmpz_mat_clear(BB);
|
|
|
|
arb_init(t);
|
|
|
|
/* Add to the result matrix */
|
|
for (i = M0; i < M1; i++)
|
|
{
|
|
for (j = P0; j < P1; j++)
|
|
{
|
|
*e = A_min[i] + B_min[j];
|
|
|
|
/* The first time we write this Cij */
|
|
if (block_start == 0)
|
|
{
|
|
arb_set_round_fmpz_2exp(arb_mat_entry(C, i, j),
|
|
fmpz_mat_entry(CC, i - M0, j - P0), e, prec);
|
|
}
|
|
else
|
|
{
|
|
arb_set_round_fmpz_2exp(t, fmpz_mat_entry(CC, i - M0, j - P0), e, prec);
|
|
arb_add(arb_mat_entry(C, i, j), arb_mat_entry(C, i, j), t, prec);
|
|
}
|
|
}
|
|
}
|
|
arb_clear(t);
|
|
|
|
fmpz_mat_clear(CC);
|
|
}
|
|
}
|
|
}

/* todo: squaring optimizations */
void
|
|
arb_mat_mul_block(arb_mat_t C, const arb_mat_t A, const arb_mat_t B, slong prec)
|
|
{
|
|
slong M, N, P;
|
|
slong *A_min, *A_max, *B_min, *B_max;
|
|
short *A_bits, *B_bits;
|
|
slong *A_bot, *B_bot;
|
|
slong block_start, block_end, i, j, bot, top, max_height;
|
|
slong b, A_max_bits, B_max_bits;
|
|
slong min_block_size;
|
|
arb_srcptr t;
|
|
int A_exact, B_exact;
|
|
double A_density, B_density;
|
|
|
|
M = arb_mat_nrows(A);
|
|
N = arb_mat_ncols(A);
|
|
P = arb_mat_ncols(B);
|
|
|
|
if (N != arb_mat_nrows(B) || M != arb_mat_nrows(C) || P != arb_mat_ncols(C))
|
|
{
|
|
flint_printf("arb_mat_mul_block: incompatible dimensions\n");
|
|
flint_abort();
|
|
}
|
|
|
|
if (M == 0 || N == 0 || P == 0 || arb_mat_is_zero(A) || arb_mat_is_zero(B))
|
|
{
|
|
arb_mat_zero(C);
|
|
return;
|
|
}
|
|
|
|
if (A == C || B == C)
|
|
{
|
|
arb_mat_t T;
|
|
arb_mat_init(T, M, P);
|
|
arb_mat_mul_block(T, A, B, prec);
|
|
arb_mat_swap_entrywise(T, C);
|
|
arb_mat_clear(T);
|
|
return;
|
|
}
|
|
|
|
/* We assume everywhere below that exponents cannot overflow/underflow
|
|
the small fmpz value range. */
|
|
if (!arb_mat_is_lagom(A) || !arb_mat_is_lagom(B))
|
|
{
|
|
arb_mat_mul_classical(C, A, B, prec);
|
|
return;
|
|
}
|
|
|
|
/* bottom exponents of A */
|
|
A_bot = flint_malloc(sizeof(slong) * M * N);
|
|
/* minimum bottom exponent in current row */
|
|
A_min = flint_malloc(sizeof(slong) * M);
|
|
/* maximum top exponent in current row */
|
|
A_max = flint_malloc(sizeof(slong) * M);
|
|
|
|
B_bot = flint_malloc(sizeof(slong) * N * P);
|
|
B_min = flint_malloc(sizeof(slong) * P);
|
|
B_max = flint_malloc(sizeof(slong) * P);
|
|
|
|
/* save space using shorts to store the bit sizes temporarily;
|
|
the block algorithm will not be used at extremely high precision */
|
|
A_bits = flint_malloc(sizeof(short) * M * N);
|
|
B_bits = flint_malloc(sizeof(short) * N * P);
|
|
|
|
A_exact = B_exact = 1;
|
|
A_max_bits = B_max_bits = 0;
|
|
A_density = B_density = 0;
|
|
|
|
/* Build table of bottom exponents (WORD_MIN signifies a zero),
|
|
and also collect some statistics. */
|
|
for (i = 0; i < M; i++)
|
|
{
|
|
for (j = 0; j < N; j++)
|
|
{
|
|
t = arb_mat_entry(A, i, j);
|
|
if (arf_is_zero(arb_midref(t)))
|
|
{
|
|
A_bot[i * N + j] = WORD_MIN;
|
|
A_bits[i * N + j] = 0;
|
|
}
|
|
else
|
|
{
|
|
b = arf_bits(arb_midref(t));
|
|
A_bot[i * N + j] = ARF_EXP(arb_midref(t)) - b;
|
|
A_bits[i * N + j] = b;
|
|
A_max_bits = FLINT_MAX(A_max_bits, b);
|
|
A_density++;
|
|
}
|
|
A_exact = A_exact && mag_is_zero(arb_radref(t));
|
|
}
|
|
}
|
|
|
|
for (i = 0; i < N; i++)
|
|
{
|
|
for (j = 0; j < P; j++)
|
|
{
|
|
t = arb_mat_entry(B, i, j);
|
|
if (arf_is_zero(arb_midref(t)))
|
|
{
|
|
B_bot[i * P + j] = WORD_MIN;
|
|
B_bits[i * P + j] = 0;
|
|
}
|
|
else
|
|
{
|
|
b = arf_bits(arb_midref(t));
|
|
B_bot[i * P + j] = ARF_EXP(arb_midref(t)) - b;
|
|
B_bits[i * P + j] = b;
|
|
B_max_bits = FLINT_MAX(B_max_bits, b);
|
|
B_density++;
|
|
}
|
|
B_exact = B_exact && mag_is_zero(arb_radref(t));
|
|
}
|
|
}
|
|
|
|
A_density = A_density / (M * N);
|
|
B_density = B_density / (N * P);
|
|
|
|
/* Don't shift too far when creating integer block matrices. */
|
|
max_height = 1.25 * FLINT_MIN(prec, FLINT_MAX(A_max_bits, B_max_bits)) + 192;
|
|
|
|
/* Avoid block algorithm for extremely high-precision matrices? */
|
|
/* Warning: these cutoffs are completely bogus... */
|
|
if (A_max_bits > 8000 || B_max_bits > 8000 ||
|
|
(A_density < 0.1 && B_density < 0.1 && max_height > 1024))
|
|
{
|
|
flint_free(A_bot);
|
|
flint_free(A_max);
|
|
flint_free(A_min);
|
|
flint_free(B_bot);
|
|
flint_free(B_max);
|
|
flint_free(B_min);
|
|
flint_free(A_bits);
|
|
flint_free(B_bits);
|
|
arb_mat_mul_classical(C, A, B, prec);
|
|
return;
|
|
}
|
|
|
|
if (arb_mat_mul_block_min_block_size != 0)
|
|
min_block_size = arb_mat_mul_block_min_block_size;
|
|
else
|
|
min_block_size = 30;
|
|
|
|
block_start = 0;
|
|
while (block_start < N)
|
|
{
|
|
/* Find a run of columns of A and rows of B such that the
|
|
bottom exponents differ by at most max_height. */
|
|
|
|
block_end = block_start + 1; /* index is exclusive block_end */
|
|
|
|
/* begin with this column of A and row of B */
|
|
for (i = 0; i < M; i++)
|
|
{
|
|
A_max[i] = A_min[i] = A_bot[i * N + block_start];
|
|
A_max[i] += (slong) A_bits[i * N + block_start];
|
|
}
|
|
|
|
for (i = 0; i < P; i++)
|
|
{
|
|
B_max[i] = B_min[i] = B_bot[block_start * P + i];
|
|
B_max[i] += (slong) B_bits[block_start * P + i];
|
|
}
|
|
|
|
while (block_end < N)
|
|
{
|
|
double size;
|
|
|
|
/* End block if memory would be excessive. */
|
|
/* Necessary? */
|
|
/* Should also do initial check above, if C alone is too large. */
|
|
size = (block_end - block_start) * M * (double) A_max_bits;
|
|
size += (block_end - block_start) * P * (double) B_max_bits;
|
|
size += (M * P) * (double) (A_max_bits + B_max_bits);
|
|
size /= 8.0;
|
|
if (size > 2e9)
|
|
goto blocks_built;
|
|
|
|
/* check if we can extend with column [block_end] of A */
|
|
for (i = 0; i < M; i++)
|
|
{
|
|
bot = A_bot[i * N + block_end];
|
|
/* zeros are irrelevant */
|
|
if (bot == WORD_MIN || A_max[i] == WORD_MIN)
|
|
continue;
|
|
top = bot + (slong) A_bits[i * N + block_end];
|
|
/* jump will be too big */
|
|
if (top > A_min[i] + max_height || bot < A_max[i] - max_height)
|
|
goto blocks_built;
|
|
}
|
|
|
|
/* check if we can extend with row [block_end] of B */
|
|
for (i = 0; i < P; i++)
|
|
{
|
|
bot = B_bot[block_end * P + i];
|
|
if (bot == WORD_MIN || B_max[i] == WORD_MIN)
|
|
continue;
|
|
top = bot + (slong) B_bits[block_end * P + i];
|
|
if (top > B_min[i] + max_height || bot < B_max[i] - max_height)
|
|
goto blocks_built;
|
|
}
|
|
|
|
/* second pass to update the extreme values */
|
|
for (i = 0; i < M; i++)
|
|
{
|
|
bot = A_bot[i * N + block_end];
|
|
top = bot + (slong) A_bits[i * N + block_end];
|
|
if (A_max[i] == WORD_MIN)
|
|
{
|
|
A_max[i] = top;
|
|
A_min[i] = bot;
|
|
}
|
|
else if (bot != WORD_MIN)
|
|
{
|
|
if (bot < A_min[i]) A_min[i] = bot;
|
|
if (top > A_max[i]) A_max[i] = top;
|
|
}
|
|
}
|
|
|
|
for (i = 0; i < P; i++)
|
|
{
|
|
bot = B_bot[block_end * P + i];
|
|
top = bot + (slong) B_bits[block_end * P + i];
|
|
if (B_max[i] == WORD_MIN)
|
|
{
|
|
B_max[i] = top;
|
|
B_min[i] = bot;
|
|
}
|
|
else if (bot != WORD_MIN)
|
|
{
|
|
if (bot < B_min[i]) B_min[i] = bot;
|
|
if (top > B_max[i]) B_max[i] = top;
|
|
}
|
|
}
|
|
|
|
block_end++;
|
|
}
|
|
|
|
blocks_built:
|
|
if (block_end - block_start < min_block_size)
|
|
{
|
|
block_end = FLINT_MIN(N, block_start + min_block_size);
|
|
|
|
arb_mat_mid_addmul_block_fallback(C, A, B,
|
|
block_start, block_end, prec);
|
|
}
|
|
else
|
|
{
|
|
arb_mat_mid_addmul_block_prescaled(C, A, B,
|
|
block_start, block_end, A_min, B_min, prec);
|
|
}
|
|
|
|
block_start = block_end;
|
|
}
|
|
|
|
flint_free(A_bot);
|
|
flint_free(A_max);
|
|
flint_free(A_min);
|
|
flint_free(B_bot);
|
|
flint_free(B_max);
|
|
flint_free(B_min);
|
|
flint_free(A_bits);
|
|
flint_free(B_bits);
|
|
|
|
/* Radius multiplications */
|
|
if (!A_exact || !B_exact)
|
|
{
|
|
mag_ptr AA, BB;
|
|
|
|
/* Shallow (since exponents are small!) mag_struct matrices
|
|
represented by linear arrays; B is transposed to improve locality. */
|
|
AA = flint_malloc(M * N * sizeof(mag_struct));
|
|
BB = flint_malloc(P * N * sizeof(mag_struct));
|
|
|
|
if (!A_exact && !B_exact)
|
|
{
|
|
/* (A+ar)(B+br) = AB + (A+ar)br + ar B
|
|
= AB + A br + ar (B + br) */
|
|
|
|
/* A + ar */
|
|
for (i = 0; i < M; i++)
|
|
for (j = 0; j < N; j++)
|
|
{
|
|
mag_fast_init_set_arf(AA + i * N + j,
|
|
arb_midref(arb_mat_entry(A, i, j)));
|
|
mag_add(AA + i * N + j, AA + i * N + j,
|
|
arb_radref(arb_mat_entry(A, i, j)));
|
|
}
|
|
|
|
/* br */
|
|
for (i = 0; i < N; i++)
|
|
for (j = 0; j < P; j++)
|
|
BB[j * N + i] = *arb_radref(arb_mat_entry(B, i, j));
|
|
|
|
_arb_mat_addmul_rad_mag_fast(C, AA, BB, M, N, P);
|
|
|
|
/* ar */
|
|
for (i = 0; i < M; i++)
|
|
for (j = 0; j < N; j++)
|
|
AA[i * N + j] = *arb_radref(arb_mat_entry(A, i, j));
|
|
|
|
/* B */
|
|
for (i = 0; i < N; i++)
|
|
for (j = 0; j < P; j++)
|
|
mag_fast_init_set_arf(BB + j * N + i,
|
|
arb_midref(arb_mat_entry(B, i, j)));
|
|
|
|
_arb_mat_addmul_rad_mag_fast(C, AA, BB, M, N, P);
|
|
}
|
|
else if (A_exact)
|
|
{
|
|
/* A(B+br) = AB + A br */
|
|
|
|
for (i = 0; i < M; i++)
|
|
for (j = 0; j < N; j++)
|
|
mag_fast_init_set_arf(AA + i * N + j,
|
|
arb_midref(arb_mat_entry(A, i, j)));
|
|
|
|
for (i = 0; i < N; i++)
|
|
for (j = 0; j < P; j++)
|
|
BB[j * N + i] = *arb_radref(arb_mat_entry(B, i, j));
|
|
|
|
_arb_mat_addmul_rad_mag_fast(C, AA, BB, M, N, P);
|
|
}
|
|
else
|
|
{
|
|
/* (A+ar)B = AB + ar B */
|
|
|
|
for (i = 0; i < M; i++)
|
|
for (j = 0; j < N; j++)
|
|
AA[i * N + j] = *arb_radref(arb_mat_entry(A, i, j));
|
|
|
|
for (i = 0; i < N; i++)
|
|
for (j = 0; j < P; j++)
|
|
mag_fast_init_set_arf(BB + j * N + i,
|
|
arb_midref(arb_mat_entry(B, i, j)));
|
|
|
|
_arb_mat_addmul_rad_mag_fast(C, AA, BB, M, N, P);
|
|
}
|
|
|
|
flint_free(AA);
|
|
flint_free(BB);
|
|
}
|
|
}