Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions ext/bigdecimal/bigdecimal.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@

#define BIGDECIMAL_VERSION "4.1.0"

#define NTT_MULTIPLICATION_THRESHOLD 350
/* Make sure VPMULT_BATCH_SIZE*BASE*BASE does not overflow DECDIG_DBL */
#define VPMULT_BATCH_SIZE 16
#define NTT_MULTIPLICATION_THRESHOLD 450
#define NEWTON_RAPHSON_DIVISION_THRESHOLD 100
#define SIGNED_VALUE_MAX INTPTR_MAX
#define SIGNED_VALUE_MIN INTPTR_MIN
Expand Down Expand Up @@ -4842,7 +4844,7 @@ VP_EXPORT size_t
VpMult(Real *c, Real *a, Real *b)
{
ssize_t a_batch_max, b_batch_max;
DECDIG_DBL batch[15];
DECDIG_DBL batch[VPMULT_BATCH_SIZE * 2 - 1];

if (!VpIsDefOP(c, a, b, OP_SW_MULT)) return 0; /* No significant digit */

Expand Down Expand Up @@ -4882,27 +4884,28 @@ VpMult(Real *c, Real *a, Real *b)
c->Prec = a->Prec + b->Prec; /* set precision */
memset(c->frac, 0, c->Prec * sizeof(DECDIG)); /* Initialize c */

// Process 8 decdigits at a time to reduce the number of carry operations.
a_batch_max = (a->Prec - 1) / 8;
b_batch_max = (b->Prec - 1) / 8;
// Process VPMULT_BATCH_SIZE decdigits at a time to reduce the number of carry operations.
a_batch_max = (a->Prec - 1) / VPMULT_BATCH_SIZE;
b_batch_max = (b->Prec - 1) / VPMULT_BATCH_SIZE;
for (ssize_t ibatch = a_batch_max; ibatch >= 0; ibatch--) {
int isize = ibatch == a_batch_max ? (a->Prec - 1) % 8 + 1 : 8;
int isize = ibatch == a_batch_max ? (a->Prec - 1) % VPMULT_BATCH_SIZE + 1 : VPMULT_BATCH_SIZE;
for (ssize_t jbatch = b_batch_max; jbatch >= 0; jbatch--) {
int jsize = jbatch == b_batch_max ? (b->Prec - 1) % 8 + 1 : 8;
int jsize = jbatch == b_batch_max ? (b->Prec - 1) % VPMULT_BATCH_SIZE + 1 : VPMULT_BATCH_SIZE;
memset(batch, 0, (isize + jsize - 1) * sizeof(DECDIG_DBL));

// Perform multiplication without carry calculation.
// 999999999 * 999999999 * 8 < 2**63 - 1, so DECDIG_DBL can hold the intermediate sum without overflow.
// BASE * BASE * VPMULT_BATCH_SIZE < 2**64 should be satisfied so that
// DECDIG_DBL can hold the intermediate sum without overflow.
for (int i = 0; i < isize; i++) {
for (int j = 0; j < jsize; j++) {
batch[i + j] += (DECDIG_DBL)a->frac[ibatch * 8 + i] * b->frac[jbatch * 8 + j];
batch[i + j] += (DECDIG_DBL)a->frac[ibatch * VPMULT_BATCH_SIZE + i] * b->frac[jbatch * VPMULT_BATCH_SIZE + j];
}
}

// Add the batch result to c with carry calculation.
DECDIG_DBL carry = 0;
for (int k = isize + jsize - 2; k >= 0; k--) {
size_t l = (ibatch + jbatch) * 8 + k + 1;
size_t l = (ibatch + jbatch) * VPMULT_BATCH_SIZE + k + 1;
DECDIG_DBL s = c->frac[l] + batch[k] + carry;
c->frac[l] = (DECDIG)(s % BASE);
carry = (DECDIG_DBL)(s / BASE);
Expand All @@ -4911,7 +4914,7 @@ VpMult(Real *c, Real *a, Real *b)
// Adding carry may exceed BASE, but it won't cause overflow of DECDIG.
// Exceeded value will be resolved in the carry operation of next (ibatch + jbatch - 1) batch.
// WARNING: This safety strongly relies on the current nested loop execution order.
c->frac[(ibatch + jbatch) * 8] += (DECDIG)carry;
c->frac[(ibatch + jbatch) * VPMULT_BATCH_SIZE] += (DECDIG)carry;
}
}

Expand Down
4 changes: 2 additions & 2 deletions test/bigdecimal/test_vp_operation.rb
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def setup

def test_vpmult
# Max carry case
[*32...40].repeated_permutation(2) do |n, m|
[*32...48].repeated_permutation(2) do |n, m|
x = BigDecimal('9' * BASE_FIG * n)
y = BigDecimal('9' * BASE_FIG * m)
assert_equal(x.to_i * y.to_i, x.vpmult(y))
Expand All @@ -30,7 +30,7 @@ def test_vpmult

def test_nttmult
# Max carry case
[*32...40].repeated_permutation(2) do |n, m|
[*32...48].repeated_permutation(2) do |n, m|
x = BigDecimal('9' * BASE_FIG * n)
y = BigDecimal('9' * BASE_FIG * m)
assert_equal(x.to_i * y.to_i, x.nttmult(y))
Expand Down