diff --git a/ext/bigdecimal/bigdecimal.c b/ext/bigdecimal/bigdecimal.c index 6561d9ed..7756ec63 100644 --- a/ext/bigdecimal/bigdecimal.c +++ b/ext/bigdecimal/bigdecimal.c @@ -35,7 +35,9 @@ #define BIGDECIMAL_VERSION "4.1.0" -#define NTT_MULTIPLICATION_THRESHOLD 350 +/* Make sure VPMULT_BATCH_SIZE*BASE*BASE does not overflow DECDIG_DBL */ +#define VPMULT_BATCH_SIZE 16 +#define NTT_MULTIPLICATION_THRESHOLD 450 #define NEWTON_RAPHSON_DIVISION_THRESHOLD 100 #define SIGNED_VALUE_MAX INTPTR_MAX #define SIGNED_VALUE_MIN INTPTR_MIN @@ -4842,7 +4844,7 @@ VP_EXPORT size_t VpMult(Real *c, Real *a, Real *b) { ssize_t a_batch_max, b_batch_max; - DECDIG_DBL batch[15]; + DECDIG_DBL batch[VPMULT_BATCH_SIZE * 2 - 1]; if (!VpIsDefOP(c, a, b, OP_SW_MULT)) return 0; /* No significant digit */ @@ -4882,27 +4884,28 @@ VpMult(Real *c, Real *a, Real *b) c->Prec = a->Prec + b->Prec; /* set precision */ memset(c->frac, 0, c->Prec * sizeof(DECDIG)); /* Initialize c */ - // Process 8 decdigits at a time to reduce the number of carry operations. - a_batch_max = (a->Prec - 1) / 8; - b_batch_max = (b->Prec - 1) / 8; + // Process VPMULT_BATCH_SIZE decdigits at a time to reduce the number of carry operations. + a_batch_max = (a->Prec - 1) / VPMULT_BATCH_SIZE; + b_batch_max = (b->Prec - 1) / VPMULT_BATCH_SIZE; for (ssize_t ibatch = a_batch_max; ibatch >= 0; ibatch--) { - int isize = ibatch == a_batch_max ? (a->Prec - 1) % 8 + 1 : 8; + int isize = ibatch == a_batch_max ? (a->Prec - 1) % VPMULT_BATCH_SIZE + 1 : VPMULT_BATCH_SIZE; for (ssize_t jbatch = b_batch_max; jbatch >= 0; jbatch--) { - int jsize = jbatch == b_batch_max ? (b->Prec - 1) % 8 + 1 : 8; + int jsize = jbatch == b_batch_max ? (b->Prec - 1) % VPMULT_BATCH_SIZE + 1 : VPMULT_BATCH_SIZE; memset(batch, 0, (isize + jsize - 1) * sizeof(DECDIG_DBL)); // Perform multiplication without carry calculation. - // 999999999 * 999999999 * 8 < 2**63 - 1, so DECDIG_DBL can hold the intermediate sum without overflow. + // BASE * BASE * VPMULT_BATCH_SIZE < 2**64 should be satisfied so that + // DECDIG_DBL can hold the intermediate sum without overflow. for (int i = 0; i < isize; i++) { for (int j = 0; j < jsize; j++) { - batch[i + j] += (DECDIG_DBL)a->frac[ibatch * 8 + i] * b->frac[jbatch * 8 + j]; + batch[i + j] += (DECDIG_DBL)a->frac[ibatch * VPMULT_BATCH_SIZE + i] * b->frac[jbatch * VPMULT_BATCH_SIZE + j]; } } // Add the batch result to c with carry calculation. DECDIG_DBL carry = 0; for (int k = isize + jsize - 2; k >= 0; k--) { - size_t l = (ibatch + jbatch) * 8 + k + 1; + size_t l = (ibatch + jbatch) * VPMULT_BATCH_SIZE + k + 1; DECDIG_DBL s = c->frac[l] + batch[k] + carry; c->frac[l] = (DECDIG)(s % BASE); carry = (DECDIG_DBL)(s / BASE); @@ -4911,7 +4914,7 @@ VpMult(Real *c, Real *a, Real *b) // Adding carry may exceed BASE, but it won't cause overflow of DECDIG. // Exceeded value will be resolved in the carry operation of next (ibatch + jbatch - 1) batch. // WARNING: This safety strongly relies on the current nested loop execution order. - c->frac[(ibatch + jbatch) * 8] += (DECDIG)carry; + c->frac[(ibatch + jbatch) * VPMULT_BATCH_SIZE] += (DECDIG)carry; } } diff --git a/test/bigdecimal/test_vp_operation.rb b/test/bigdecimal/test_vp_operation.rb index 5b5dab65..ce690aee 100644 --- a/test/bigdecimal/test_vp_operation.rb +++ b/test/bigdecimal/test_vp_operation.rb @@ -15,7 +15,7 @@ def setup def test_vpmult # Max carry case - [*32...40].repeated_permutation(2) do |n, m| + [*32...48].repeated_permutation(2) do |n, m| x = BigDecimal('9' * BASE_FIG * n) y = BigDecimal('9' * BASE_FIG * m) assert_equal(x.to_i * y.to_i, x.vpmult(y)) @@ -30,7 +30,7 @@ def test_vpmult def test_nttmult # Max carry case - [*32...40].repeated_permutation(2) do |n, m| + [*32...48].repeated_permutation(2) do |n, m| x = BigDecimal('9' * BASE_FIG * n) y = BigDecimal('9' * BASE_FIG * m) assert_equal(x.to_i * y.to_i, x.nttmult(y))