v8
V8 is Google’s open source high-performance JavaScript and WebAssembly engine, written in C++.
mul-fft.cc
// Copyright 2021 the V8 project authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

// FFT-based multiplication, due to Schönhage and Strassen.
// This implementation mostly follows the description given in:
// Christoph Lüders: Fast Multiplication of Large Integers,
// http://arxiv.org/abs/1503.04955

#include "src/bigint/bigint-internal.h"
#include "src/bigint/digit-arithmetic.h"
#include "src/bigint/util.h"

namespace v8 {
namespace bigint {

namespace {

////////////////////////////////////////////////////////////////////////////
// Part 1: Functions for "mod F_n" arithmetic.
// F_n is of the shape 2^K + 1, and for convenience we use K to count the
// number of digits rather than the number of bits, so F_n (or K) are implicit
// and deduced from the length {len} of the digits array.

// Helper function for {ModFn} below.
void ModFn_Helper(digit_t* x, int len, signed_digit_t high) {
  if (high > 0) {
    digit_t borrow = high;
    x[len - 1] = 0;
    for (int i = 0; i < len; i++) {
      x[i] = digit_sub(x[i], borrow, &borrow);
      if (borrow == 0) break;
    }
  } else {
    digit_t carry = -high;
    x[len - 1] = 0;
    for (int i = 0; i < len; i++) {
      x[i] = digit_add2(x[i], carry, &carry);
      if (carry == 0) break;
    }
  }
}

// {x} := {x} mod F_n, assuming that {x} is "slightly" larger than F_n (e.g.
// after addition of two numbers that were mod-F_n-normalized before).
void ModFn(digit_t* x, int len) {
  int K = len - 1;
  signed_digit_t high = x[K];
  if (high == 0) return;
  ModFn_Helper(x, len, high);
  high = x[K];
  if (high == 0) return;
  DCHECK(high == 1 || high == -1);
  ModFn_Helper(x, len, high);
  high = x[K];
  if (high == -1) ModFn_Helper(x, len, high);
}
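
// A tiny worked example of {ModFn}, pretending kDigitBits == 8 for
// readability (illustrative only; real builds use 32- or 64-bit digits).
// With len == 2 we have K == 1 digit, so F_n = 2^8 + 1 = 257.
// - x = {0x05, 0x01} (the value 261): high == 1, so {ModFn_Helper} subtracts
//   a borrow of 1 starting at x[0], yielding x = {0x04, 0x00}; indeed
//   261 mod 257 == 4.
// - x = {0xFD, 0xFF} (the wrapped encoding of -3, e.g. the result of 2 - 5):
//   high == -1, so {ModFn_Helper} adds a carry of 1, yielding
//   x = {0xFE, 0x00}; indeed -3 mod 257 == 254 == 0xFE.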

// {dest} := {src} mod F_n, assuming that {src} is about twice as long as F_n
// (e.g. after multiplication of two numbers that were mod-F_n-normalized
// before).
// {len} is length of {dest}; {src} is twice as long.
void ModFnDoubleWidth(digit_t* dest, const digit_t* src, int len) {
  int K = len - 1;
  digit_t borrow = 0;
  for (int i = 0; i < K; i++) {
    dest[i] = digit_sub2(src[i], src[i + K], borrow, &borrow);
  }
  dest[K] = digit_sub2(0, src[2 * K], borrow, &borrow);
  // {borrow} may be non-zero here, that's OK as {ModFn} will take care of it.
  ModFn(dest, len);
}
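
// The loop above exploits 2^K == -1 (mod F_n): writing {src} = a * 2^K + b,
// where b is the low K digits, gives src == b - a (mod F_n); src[2 * K] is
// the stray top digit of a. Toy example with 8-bit digits and K == 1:
// src = {0x05, 0x03, 0x00} encodes 773, and 773 mod 257 == 2 == 5 - 3.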

// Sets {sum} := {a} + {b} and {diff} := {a} - {b}, which is more efficient
// than computing sum and difference separately. Applies "mod F_n"
// normalization to both results.
void SumDiff(digit_t* sum, digit_t* diff, const digit_t* a, const digit_t* b,
             int len) {
  digit_t carry = 0;
  digit_t borrow = 0;
  for (int i = 0; i < len; i++) {
    // Read both values first, because inputs and outputs can overlap.
    digit_t ai = a[i];
    digit_t bi = b[i];
    sum[i] = digit_add3(ai, bi, carry, &carry);
    diff[i] = digit_sub2(ai, bi, borrow, &borrow);
  }
  ModFn(sum, len);
  ModFn(diff, len);
}

// {result} := ({input} << shift) mod F_n, where shift >= K.
void ShiftModFn_Large(digit_t* result, const digit_t* input, int digit_shift,
                      int bits_shift, int K) {
  // If {digit_shift} is greater than K, we use the following transformation
  // (where, since everything is mod 2^K + 1, we are allowed to add or
  // subtract any multiple of 2^K + 1 at any time):
  //      x * 2^{K+m}                                 mod 2^K + 1
  //   == x * 2^K * 2^m - (2^K + 1) * (x * 2^m)       mod 2^K + 1
  //   == x * 2^K * 2^m - x * 2^K * 2^m - x * 2^m     mod 2^K + 1
  //   == -x * 2^m                                    mod 2^K + 1
  // So the flow is the same as for m < K, but we invert the subtraction's
  // operands. In order to avoid underflow, we virtually initialize the
  // result to 2^K + 1:
  //   input  = [ iK ][iK-1] ....      ....      [ i1 ][ i0 ]
  //   result = [   1][0000] ....      ....      [0000][0001]
  //          +              [ iK ] .... [ iX ]
  //          -              [iX-1] .... [ i0 ]
  DCHECK(digit_shift >= K);
  digit_shift -= K;
  digit_t borrow = 0;
  if (bits_shift == 0) {
    digit_t carry = 1;
    for (int i = 0; i < digit_shift; i++) {
      result[i] = digit_add2(input[i + K - digit_shift], carry, &carry);
    }
    result[digit_shift] = digit_sub(input[K] + carry, input[0], &borrow);
    for (int i = digit_shift + 1; i < K; i++) {
      digit_t d = input[i - digit_shift];
      result[i] = digit_sub2(0, d, borrow, &borrow);
    }
  } else {
    digit_t add_carry = 1;
    digit_t input_carry =
        input[K - digit_shift - 1] >> (kDigitBits - bits_shift);
    for (int i = 0; i < digit_shift; i++) {
      digit_t d = input[i + K - digit_shift];
      digit_t summand = (d << bits_shift) | input_carry;
      result[i] = digit_add2(summand, add_carry, &add_carry);
      input_carry = d >> (kDigitBits - bits_shift);
    }
    {
      // result[digit_shift] = (add_carry + iK_part) - i0_part
      digit_t d = input[K];
      digit_t iK_part = (d << bits_shift) | input_carry;
      digit_t iK_carry = d >> (kDigitBits - bits_shift);
      digit_t sum = digit_add2(add_carry, iK_part, &add_carry);
      // {iK_carry} is less than a full digit, so we can merge {add_carry}
      // into it without overflow.
      iK_carry += add_carry;
      d = input[0];
      digit_t i0_part = d << bits_shift;
      result[digit_shift] = digit_sub(sum, i0_part, &borrow);
      input_carry = d >> (kDigitBits - bits_shift);
      if (digit_shift + 1 < K) {
        d = input[1];
        digit_t subtrahend = (d << bits_shift) | input_carry;
        result[digit_shift + 1] =
            digit_sub2(iK_carry, subtrahend, borrow, &borrow);
        input_carry = d >> (kDigitBits - bits_shift);
      }
    }
    for (int i = digit_shift + 2; i < K; i++) {
      digit_t d = input[i - digit_shift];
      digit_t subtrahend = (d << bits_shift) | input_carry;
      result[i] = digit_sub2(0, subtrahend, borrow, &borrow);
      input_carry = d >> (kDigitBits - bits_shift);
    }
  }
  // The virtual 1 in result[K] should be eliminated by {borrow}. If there
  // is no borrow, then the virtual initialization was too much. Subtract
  // 2^K + 1.
  result[K] = 0;
  if (borrow != 1) {
    borrow = 1;
    for (int i = 0; i < K; i++) {
      result[i] = digit_sub(result[i], borrow, &borrow);
      if (borrow == 0) break;
    }
    if (borrow != 0) {
      // The result must be 2^K.
      for (int i = 0; i < K; i++) result[i] = 0;
      result[K] = 1;
    }
  }
}

// Sets {result} := {input} * 2^{power_of_two} mod 2^{K} + 1.
// This function is highly relevant for overall performance.
void ShiftModFn(digit_t* result, const digit_t* input, int power_of_two, int K,
                int zero_above = 0x7FFFFFFF) {
  // The modulo-reduction amounts to a subtraction, which we combine
  // with the shift as follows:
  //   input  = [ iK ][iK-1] ....      ....      [ i1 ][ i0 ]
  //   result =       [iX-1] .... [ i0 ]            <-- shift by {power_of_two}
  //          -       [ iK ] .... [ iX ]
  // where "X" is the index "K - digit_shift".
  int digit_shift = power_of_two / kDigitBits;
  int bits_shift = power_of_two % kDigitBits;
  // By an analogous construction to the "digit_shift >= K" case,
  // it turns out that:
  //   x * 2^{2K+m} == x * 2^m  mod 2^K + 1.
  while (digit_shift >= 2 * K) digit_shift -= 2 * K;  // Faster than '%'!
  if (digit_shift >= K) {
    return ShiftModFn_Large(result, input, digit_shift, bits_shift, K);
  }
  digit_t borrow = 0;
  if (bits_shift == 0) {
    // We do a single pass over {input}, starting by copying digits [i1] to
    // [iX-1] to result indices digit_shift+1 to K-1.
    int i = 1;
    // Read input digits unless we know they are zero.
    int cap = std::min(K - digit_shift, zero_above);
    for (; i < cap; i++) {
      result[i + digit_shift] = input[i];
    }
    // Any remaining work can hard-code the knowledge that input[i] == 0.
    for (; i < K - digit_shift; i++) {
      DCHECK(input[i] == 0);
      result[i + digit_shift] = 0;
    }
    // Second phase: subtract input digits [iX] to [iK] from (virtually) zero-
    // initialized result indices 0 to digit_shift-1.
    cap = std::min(K, zero_above);
    for (; i < cap; i++) {
      digit_t d = input[i];
      result[i - K + digit_shift] = digit_sub2(0, d, borrow, &borrow);
    }
    // Any remaining work can hard-code the knowledge that input[i] == 0.
    for (; i < K; i++) {
      DCHECK(input[i] == 0);
      result[i - K + digit_shift] = digit_sub(0, borrow, &borrow);
    }
    // Last step: subtract [iK] from [i0] and store at result index
    // digit_shift.
    result[digit_shift] = digit_sub2(input[0], input[K], borrow, &borrow);
  } else {
    // Same flow as before, but taking bits_shift != 0 into account.
    // First phase: result indices digit_shift+1 to K.
    digit_t carry = 0;
    int i = 0;
    // Read input digits unless we know they are zero.
    int cap = std::min(K - digit_shift, zero_above);
    for (; i < cap; i++) {
      digit_t d = input[i];
      result[i + digit_shift] = (d << bits_shift) | carry;
      carry = d >> (kDigitBits - bits_shift);
    }
    // Any remaining work can hard-code the knowledge that input[i] == 0.
    for (; i < K - digit_shift; i++) {
      DCHECK(input[i] == 0);
      result[i + digit_shift] = carry;
      carry = 0;
    }
    // Second phase: result indices 0 to digit_shift - 1.
    cap = std::min(K, zero_above);
    for (; i < cap; i++) {
      digit_t d = input[i];
      result[i - K + digit_shift] =
          digit_sub2(0, (d << bits_shift) | carry, borrow, &borrow);
      carry = d >> (kDigitBits - bits_shift);
    }
    // Any remaining work can hard-code the knowledge that input[i] == 0.
    if (i < K) {
      DCHECK(input[i] == 0);
      result[i - K + digit_shift] = digit_sub2(0, carry, borrow, &borrow);
      carry = 0;
      i++;
    }
    for (; i < K; i++) {
      DCHECK(input[i] == 0);
      result[i - K + digit_shift] = digit_sub(0, borrow, &borrow);
    }
    // Last step: compute result[digit_shift].
    digit_t d = input[K];
    result[digit_shift] = digit_sub2(
        result[digit_shift], (d << bits_shift) | carry, borrow, &borrow);
    // No carry left.
    DCHECK((d >> (kDigitBits - bits_shift)) == 0);
  }
  result[K] = 0;
  for (int i = digit_shift + 1; i <= K && borrow > 0; i++) {
    result[i] = digit_sub(result[i], borrow, &borrow);
  }
  if (borrow > 0) {
    // Underflow means we subtracted too much. Add 2^K + 1.
    digit_t carry = 1;
    for (int i = 0; i <= K; i++) {
      result[i] = digit_add2(result[i], carry, &carry);
      if (carry == 0) break;
    }
    result[K] = digit_add2(result[K], 1, &carry);
  }
}
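
// A quick sanity check of the wrap-around behavior, again pretending
// kDigitBits == 8 (illustrative only). With K == 2 digits, F_n = 2^16 + 1:
//   ShiftModFn(r, {3, 0, 0}, 4, 2)  gives 3 * 2^4 == 48.
//   ShiftModFn(r, {3, 0, 0}, 17, 2) gives 3 * 2^17 == -6 == 65531 (mod F_n),
//     i.e. r = {0xFB, 0xFF, 0x00}, because 2^16 == -1 (mod F_n).
// {zero_above} is purely an optimization: callers that know that
// input[i] == 0 for all i >= zero_above let the loops above skip reading
// those digits.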

////////////////////////////////////////////////////////////////////////////
// Part 2: FFT-based multiplication is very sensitive to appropriate choice
// of parameters. The following functions choose the parameters that the
// subsequent actual computation will use. This is partially based on formal
// constraints and partially on experimentally-determined heuristics.

struct Parameters {
  // We never use the default values, but skipping zero-initialization
  // of these fields saddens and confuses MSan.
  int m{0};
  int K{0};
  int n{0};
  int s{0};
  int r{0};
};

// Computes parameters for the main calculation, given a bit length {N} and
// an {m}. See the paper for details.
void ComputeParameters(int N, int m, Parameters* params) {
  N *= kDigitBits;
  int n = 1 << m;  // 2^m
  int nhalf = n >> 1;
  int s = (N + n - 1) >> m;  // ceil(N/n)
  s = RoundUp(s, kDigitBits);
  int K = m + 2 * s + 1;  // K must be at least this big...
  K = RoundUp(K, nhalf);  // ...and a multiple of n/2.
  int r = K >> (m - 1);   // Which multiple?

  // We want recursive calls to make progress, so force K to be a multiple
  // of 8 if it's above the recursion threshold. Otherwise, K must be a
  // multiple of kDigitBits.
  const int threshold = (K + 1 >= kFftInnerThreshold * kDigitBits)
                            ? 3 + kLog2DigitBits
                            : kLog2DigitBits;
  int K_tz = CountTrailingZeros(K);
  while (K_tz < threshold) {
    K += (1 << K_tz);
    r = K >> (m - 1);
    K_tz = CountTrailingZeros(K);
  }

  DCHECK(K % kDigitBits == 0);
  DCHECK(s % kDigitBits == 0);
  params->K = K / kDigitBits;
  params->s = s / kDigitBits;
  params->n = n;
  params->r = r;
}
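
// A worked example with 64-bit digits (kLog2DigitBits == 6). For N == 1024
// digits and m == 7:
//   N = 65536 bits, n = 2^7 = 128, s = ceil(65536 / 128) = 512 bits,
//   K = m + 2 * s + 1 = 1032 bits, rounded up to a multiple of n/2 = 64:
//   K = 1088, and r = K >> (m - 1) = 17, i.e. K == 17 * (n/2).
// Assuming this K is below the recursion threshold, the required multiple is
// kDigitBits, and 1088 == 17 * 64 already has 6 trailing zero bits, so the
// adjustment loop is a no-op. Result: K = 17 digits, s = 8 digits, n = 128,
// r = 17.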

// Computes parameters for recursive invocations ("inner layer").
void ComputeParameters_Inner(int N, Parameters* params) {
  int max_m = CountTrailingZeros(N);
  int N_bits = BitLength(N);
  int m = N_bits - 4;  // Don't let s get too small.
  m = std::min(max_m, m);
  N *= kDigitBits;
  int n = 1 << m;  // 2^m
  // We can't round up s in the inner layer, because N = n*s is fixed.
  int s = N >> m;
  DCHECK(N == s * n);
  int K = m + 2 * s + 1;  // K must be at least this big...
  K = RoundUp(K, n);      // ...and a multiple of n and kDigitBits.
  K = RoundUp(K, kDigitBits);
  params->r = K >> m;  // Which multiple?
  DCHECK(K % kDigitBits == 0);
  DCHECK(s % kDigitBits == 0);
  params->K = K / kDigitBits;
  params->s = s / kDigitBits;
  params->n = n;
  params->m = m;
}

int PredictInnerK(int N) {
  Parameters params;
  ComputeParameters_Inner(N, &params);
  return params.K;
}

// Applies heuristics to decide whether {m} should be decremented, by looking
// at what would happen to {K} and {s} if {m} was decremented.
bool ShouldDecrementM(const Parameters& current, const Parameters& next,
                      const Parameters& after_next) {
  // K == 64 seems to work particularly well.
  if (current.K == 64 && next.K >= 112) return false;
  // Small values for s are never efficient.
  if (current.s < 6) return true;
  // The time is roughly determined by K * n. When we decrement m, then
  // n always halves, and K usually gets bigger, by up to 2x.
  // For not-quite-so-small s, look at how much bigger K would get: if
  // the K increase is small enough, making n smaller is worth it.
  // Empirically, it's most meaningful to look at the K *after* next.
  // The specific threshold values have been chosen by running many
  // benchmarks on inputs of many sizes, and manually selecting thresholds
  // that seemed to produce good results.
  double factor = static_cast<double>(after_next.K) / current.K;
  if ((current.s == 6 && factor < 3.85) ||  // --
      (current.s == 7 && factor < 3.73) ||  // --
      (current.s == 8 && factor < 3.55) ||  // --
      (current.s == 9 && factor < 3.50) ||  // --
      factor < 3.4) {
    return true;
  }
  // If K is just below the recursion threshold, make sure we do recurse,
  // unless doing so would be particularly inefficient (large inner_K).
  // If K is just above the recursion threshold, doubling it often makes
  // the inner call more efficient.
  if (current.K >= 160 && current.K < 250 && PredictInnerK(next.K) < 28) {
    return true;
  }
  // If we found no reason to decrement, keep m as large as possible.
  return false;
}

// Decides what parameters to use for a given input bit length {N}.
// Returns the chosen m.
int GetParameters(int N, Parameters* params) {
  int N_bits = BitLength(N);
  int max_m = N_bits - 3;                   // Larger m make s too small.
  max_m = std::max(kLog2DigitBits, max_m);  // Smaller m break the logic below.
  int m = max_m;
  Parameters current;
  ComputeParameters(N, m, &current);
  Parameters next;
  ComputeParameters(N, m - 1, &next);
  while (m > 2) {
    Parameters after_next;
    ComputeParameters(N, m - 2, &after_next);
    if (ShouldDecrementM(current, next, after_next)) {
      m--;
      current = next;
      next = after_next;
    } else {
      break;
    }
  }
  *params = current;
  return m;
}
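
// Summarizing what the chosen parameters mean (in digit units, as stored in
// {Parameters}; {ComputeParameters} works in bits internally):
// - Each input is cut into n = 2^m chunks of s digits.
// - All arithmetic happens mod F_n = 2^(K * kDigitBits) + 1. Every
//   coefficient of the cyclic convolution is < 2^m * 2^(2 * s * kDigitBits),
//   which is why K had to cover at least m + 2 * s * kDigitBits + 1 bits.
// - r is the shift amount such that 2^r is a primitive n-th root of unity
//   mod F_n: in bits, r * n == 2 * K * kDigitBits, and
//   2^(2 * K * kDigitBits) == 1 (mod F_n).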

////////////////////////////////////////////////////////////////////////////
// Part 3: Fast Fourier Transformation.

class FFTContainer {
 public:
  // {n} is the number of chunks, whose length is {K}+1.
  // {K} determines F_n = 2^(K * kDigitBits) + 1.
  FFTContainer(int n, int K, ProcessorImpl* processor)
      : n_(n), K_(K), length_(K + 1), processor_(processor) {
    storage_ = new digit_t[length_ * n_];
    part_ = new digit_t*[n_];
    digit_t* ptr = storage_;
    for (int i = 0; i < n; i++, ptr += length_) {
      part_[i] = ptr;
    }
    temp_ = new digit_t[length_ * 2];
  }
  FFTContainer() = delete;
  FFTContainer(const FFTContainer&) = delete;
  FFTContainer& operator=(const FFTContainer&) = delete;

  ~FFTContainer() {
    delete[] storage_;
    delete[] part_;
    delete[] temp_;
  }

  void Start_Default(Digits X, int chunk_size, int theta, int omega);
  void Start(Digits X, int chunk_size, int theta, int omega);

  void NormalizeAndRecombine(int omega, int m, RWDigits Z, int chunk_size);
  void CounterWeightAndRecombine(int theta, int m, RWDigits Z, int chunk_size);

  void FFT_ReturnShuffledThreadsafe(int start, int len, int omega,
                                    digit_t* temp);
  void FFT_Recurse(int start, int half, int omega, digit_t* temp);

  void BackwardFFT(int start, int len, int omega);
  void BackwardFFT_Threadsafe(int start, int len, int omega, digit_t* temp);

  void PointwiseMultiply(const FFTContainer& other);
  void DoPointwiseMultiplication(const FFTContainer& other, int start, int end,
                                 digit_t* temp);

  int length() const { return length_; }

 private:
  const int n_;       // Number of parts.
  const int K_;       // Always length_ - 1.
  const int length_;  // Length of each part, in digits.
  ProcessorImpl* processor_;
  digit_t* storage_;  // Combined storage of all parts.
  digit_t** part_;    // Pointers to each part.
  digit_t* temp_;     // Temporary storage with size 2 * length_.
};

inline void CopyAndZeroExtend(digit_t* dst, const digit_t* src,
                              int digits_to_copy, size_t total_bytes) {
  size_t bytes_to_copy = digits_to_copy * sizeof(digit_t);
  memcpy(dst, static_cast<const void*>(src), bytes_to_copy);
  memset(dst + digits_to_copy, 0, total_bytes - bytes_to_copy);
}

// Reads {X} into the FFTContainer's internal storage, dividing it into chunks
// while doing so; then performs the forward FFT.
void FFTContainer::Start_Default(Digits X, int chunk_size, int theta,
                                 int omega) {
  int len = X.len();
  const digit_t* pointer = X.digits();
  const size_t part_length_in_bytes = length_ * sizeof(digit_t);
  int current_theta = 0;
  int i = 0;
  for (; i < n_ && len > 0; i++, current_theta += theta) {
    chunk_size = std::min(chunk_size, len);
    // For invocations via MultiplyFFT_Inner, X.len() == n_ * chunk_size + 1,
    // because the outer layer's "K" is passed as the inner layer's "N".
    // Since X is (mod Fn)-normalized on the outer layer, there is the rare
    // corner case where X[n_ * chunk_size] == 1. Detect that case, and handle
    // the extra bit as part of the last chunk; we always have the space.
    if (i == n_ - 1 && len == chunk_size + 1) {
      DCHECK(X[n_ * chunk_size] <= 1);
      DCHECK(length_ >= chunk_size + 1);
      chunk_size++;
    }
    if (current_theta != 0) {
      // Multiply with theta^i, and reduce modulo 2^K + 1.
      // We pass theta as a shift amount; it really means 2^theta.
      CopyAndZeroExtend(temp_, pointer, chunk_size, part_length_in_bytes);
      ShiftModFn(part_[i], temp_, current_theta, K_, chunk_size);
    } else {
      CopyAndZeroExtend(part_[i], pointer, chunk_size, part_length_in_bytes);
    }
    pointer += chunk_size;
    len -= chunk_size;
  }
  DCHECK(len == 0);
  for (; i < n_; i++) {
    memset(part_[i], 0, part_length_in_bytes);
  }
  FFT_ReturnShuffledThreadsafe(0, n_, omega, temp_);
}
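
// The theta-weighting above is what turns the plain cyclic convolution into
// the negacyclic one needed for "mod F_n" multiplication on the inner layer:
// chunk i is premultiplied by 2^(theta * i) before the forward FFT, and
// {CounterWeightAndRecombine} divides the weights back out afterwards. The
// outer layer passes theta == 0 and skips both steps.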

// This version of Start is optimized for the case where ~half of the
// container will be filled with padding zeros.
void FFTContainer::Start(Digits X, int chunk_size, int theta, int omega) {
  int len = X.len();
  if (len > n_ * chunk_size / 2) {
    return Start_Default(X, chunk_size, theta, omega);
  }
  DCHECK(theta == 0);
  const digit_t* pointer = X.digits();
  const size_t part_length_in_bytes = length_ * sizeof(digit_t);
  int nhalf = n_ / 2;
  // Unrolled first iteration.
  CopyAndZeroExtend(part_[0], pointer, chunk_size, part_length_in_bytes);
  CopyAndZeroExtend(part_[nhalf], pointer, chunk_size, part_length_in_bytes);
  pointer += chunk_size;
  len -= chunk_size;
  int i = 1;
  for (; i < nhalf && len > 0; i++) {
    chunk_size = std::min(chunk_size, len);
    CopyAndZeroExtend(part_[i], pointer, chunk_size, part_length_in_bytes);
    int w = omega * i;
    ShiftModFn(part_[i + nhalf], part_[i], w, K_, chunk_size);
    pointer += chunk_size;
    len -= chunk_size;
  }
  for (; i < nhalf; i++) {
    memset(part_[i], 0, part_length_in_bytes);
    memset(part_[i + nhalf], 0, part_length_in_bytes);
  }
  FFT_Recurse(0, nhalf, omega, temp_);
}
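
// Why the special case above is valid: when the upper half of the chunks is
// all zeros, the first DIF level's butterflies degenerate to
//   sum  = part[i] + 0 == part[i]
//   diff = (part[i] - 0) * 2^(omega * i)
// so the loop writes those two values directly and enters {FFT_Recurse},
// skipping the first level of {FFT_ReturnShuffledThreadsafe}.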

// Forward transformation.
// We use the "DIF" aka "decimation in frequency" transform, because it
// leaves the result in "bit reversed" order, which is precisely what we
// need as input for the "DIT" aka "decimation in time" backwards transform.
void FFTContainer::FFT_ReturnShuffledThreadsafe(int start, int len, int omega,
                                                digit_t* temp) {
  DCHECK((len & 1) == 0);  // {len} must be even.
  int half = len / 2;
  SumDiff(part_[start], part_[start + half], part_[start], part_[start + half],
          length_);
  for (int k = 1; k < half; k++) {
    SumDiff(part_[start + k], temp, part_[start + k], part_[start + half + k],
            length_);
    int w = omega * k;
    ShiftModFn(part_[start + half + k], temp, w, K_);
  }
  FFT_Recurse(start, half, omega, temp);
}
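
// The DIF butterfly in equations: for k in [0, len/2), with w = 2^omega
// (a root of unity of order {len} in the ring),
//   a'[k]         = a[k] + a[k + len/2]
//   a'[k + len/2] = (a[k] - a[k + len/2]) * w^k
// {SumDiff} produces the sum/difference pair, and {ShiftModFn} applies the
// twiddle factor w^k as a left-shift by omega * k bits.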

// Recursive step of the above, factored out for additional callers.
void FFTContainer::FFT_Recurse(int start, int half, int omega, digit_t* temp) {
  if (half > 1) {
    FFT_ReturnShuffledThreadsafe(start, half, 2 * omega, temp);
    FFT_ReturnShuffledThreadsafe(start + half, half, 2 * omega, temp);
  }
}

// Backward transformation.
// We use the "DIT" aka "decimation in time" transform here, because it
// turns bit-reversed input into normally sorted output.
void FFTContainer::BackwardFFT(int start, int len, int omega) {
  BackwardFFT_Threadsafe(start, len, omega, temp_);
}

void FFTContainer::BackwardFFT_Threadsafe(int start, int len, int omega,
                                          digit_t* temp) {
  DCHECK((len & 1) == 0);  // {len} must be even.
  int half = len / 2;
  // Don't recurse for half == 2, as PointwiseMultiply already performed
  // the first level of the backwards FFT.
  if (half > 2) {
    BackwardFFT_Threadsafe(start, half, 2 * omega, temp);
    BackwardFFT_Threadsafe(start + half, half, 2 * omega, temp);
  }
  SumDiff(part_[start], part_[start + half], part_[start], part_[start + half],
          length_);
  for (int k = 1; k < half; k++) {
    int w = omega * (len - k);
    ShiftModFn(temp, part_[start + half + k], w, K_);
    SumDiff(part_[start + k], part_[start + half + k], part_[start + k], temp,
            length_);
  }
}
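
// The inverse (DIT) butterfly: for k in [0, len/2), with w = 2^omega,
//   a'[k]         = a[k] + a[k + len/2] * w^(-k)
//   a'[k + len/2] = a[k] - a[k + len/2] * w^(-k)
// The shift by omega * (len - k) realizes w^(-k), because
// 2^(omega * len) == 1 (mod F_n) at every recursion depth ({omega} doubles
// exactly as {len} halves).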

// Recombines the result's parts into {Z}, after backwards FFT.
void FFTContainer::NormalizeAndRecombine(int omega, int m, RWDigits Z,
                                         int chunk_size) {
  Z.Clear();
  int z_index = 0;
  const int shift = n_ * omega - m;
  for (int i = 0; i < n_; i++, z_index += chunk_size) {
    digit_t* part = part_[i];
    ShiftModFn(temp_, part, shift, K_);
    digit_t carry = 0;
    int zi = z_index;
    int j = 0;
    for (; j < length_ && zi < Z.len(); j++, zi++) {
      Z[zi] = digit_add3(Z[zi], temp_[j], carry, &carry);
    }
    for (; j < length_; j++) {
      DCHECK(temp_[j] == 0);
    }
    if (carry != 0) {
      DCHECK(zi < Z.len());
      Z[zi] = carry;
    }
  }
}
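
// The constant shift by n_ * omega - m above folds two normalization steps
// into one multiplication: the backward FFT leaves every part scaled by
// n = 2^m, and dividing by 2^m can be done as multiplying by
// 2^(n * omega - m), since 2^(n * omega) == 1 (mod F_n). The normalized
// parts are then added into {Z} at offsets of chunk_size digits each, which
// re-assembles the product from the overlapping chunk products.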

// Helper function for {CounterWeightAndRecombine} below.
bool ShouldBeNegative(const digit_t* x, int xlen, digit_t threshold, int s) {
  if (x[2 * s] >= threshold) return true;
  for (int i = 2 * s + 1; i < xlen; i++) {
    if (x[i] > 0) return true;
  }
  return false;
}

// Same as {NormalizeAndRecombine} above, but for the needs of the recursive
// invocation ("inner layer") of FFT multiplication, where an additional
// counter-weighting step is required.
void FFTContainer::CounterWeightAndRecombine(int theta, int m, RWDigits Z,
                                             int s) {
  Z.Clear();
  int z_index = 0;
  for (int k = 0; k < n_; k++, z_index += s) {
    int shift = -theta * k - m;
    if (shift < 0) shift += 2 * n_ * theta;
    DCHECK(shift >= 0);
    digit_t* input = part_[k];
    ShiftModFn(temp_, input, shift, K_);
    int remaining_z = Z.len() - z_index;
    if (ShouldBeNegative(temp_, length_, k + 1, s)) {
      // Subtract F_n from input before adding to result. We use the following
      // transformation (knowing that X < F_n):
      //   Z + (X - F_n) == Z - (F_n - X)
      digit_t borrow_z = 0;
      digit_t borrow_Fn = 0;
      {
        // i == 0:
        digit_t d = digit_sub(1, temp_[0], &borrow_Fn);
        Z[z_index] = digit_sub(Z[z_index], d, &borrow_z);
      }
      int i = 1;
      for (; i < K_ && i < remaining_z; i++) {
        digit_t d = digit_sub2(0, temp_[i], borrow_Fn, &borrow_Fn);
        Z[z_index + i] = digit_sub2(Z[z_index + i], d, borrow_z, &borrow_z);
      }
      DCHECK(i == K_ && K_ == length_ - 1);
      for (; i < length_ && i < remaining_z; i++) {
        digit_t d = digit_sub2(1, temp_[i], borrow_Fn, &borrow_Fn);
        Z[z_index + i] = digit_sub2(Z[z_index + i], d, borrow_z, &borrow_z);
      }
      DCHECK(borrow_Fn == 0);
      for (; borrow_z > 0 && i < remaining_z; i++) {
        Z[z_index + i] = digit_sub(Z[z_index + i], borrow_z, &borrow_z);
      }
    } else {
      digit_t carry = 0;
      int i = 0;
      for (; i < length_ && i < remaining_z; i++) {
        Z[z_index + i] = digit_add3(Z[z_index + i], temp_[i], carry, &carry);
      }
      for (; i < length_; i++) {
        DCHECK(temp_[i] == 0);
      }
      for (; carry > 0 && i < remaining_z; i++) {
        Z[z_index + i] = digit_add2(Z[z_index + i], carry, &carry);
      }
      // {carry} might be != 0 here if Z was negative before. That's fine.
    }
  }
}

// Main FFT function for recursive invocations ("inner layer").
void MultiplyFFT_Inner(RWDigits Z, Digits X, Digits Y,
                       const Parameters& params, ProcessorImpl* processor) {
  int omega = 2 * params.r;  // really: 2^(2r)
  int theta = params.r;      // really: 2^r

  FFTContainer a(params.n, params.K, processor);
  a.Start_Default(X, params.s, theta, omega);
  FFTContainer b(params.n, params.K, processor);
  b.Start_Default(Y, params.s, theta, omega);

  a.PointwiseMultiply(b);
  if (processor->should_terminate()) return;

  FFTContainer& c = a;
  c.BackwardFFT(0, params.n, omega);

  c.CounterWeightAndRecombine(theta, params.m, Z, params.s);
}
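
// Exponent bookkeeping for the inner layer: the product here is taken mod
// the outer layer's F_n, i.e. mod 2^(n * s * kDigitBits) + 1, which requires
// a negacyclic convolution. Hence the chunks are weighted with theta^i,
// theta = 2^r, and the transform itself uses omega = theta^2 = 2^(2r) as its
// root of unity. Both are represented as shift amounts (r and 2 * r)
// throughout.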

// Actual implementation of pointwise multiplications.
void FFTContainer::DoPointwiseMultiplication(const FFTContainer& other,
                                             int start, int end,
                                             digit_t* temp) {
  // The (K_ & 3) == 0 condition makes sure that the inner FFT gets
  // to split the work into at least 4 chunks.
  bool use_fft = length_ >= kFftInnerThreshold && (K_ & 3) == 0;
  Parameters params;
  if (use_fft) ComputeParameters_Inner(K_, &params);
  RWDigits result(temp, 2 * length_);
  for (int i = start; i < end; i++) {
    Digits A(part_[i], length_);
    Digits B(other.part_[i], length_);
    if (use_fft) {
      MultiplyFFT_Inner(result, A, B, params, processor_);
    } else {
      processor_->Multiply(result, A, B);
    }
    if (processor_->should_terminate()) return;
    ModFnDoubleWidth(part_[i], result.digits(), length_);
    // To improve cache friendliness, we perform the first level of the
    // backwards FFT here.
    if ((i & 1) == 1) {
      SumDiff(part_[i - 1], part_[i], part_[i - 1], part_[i], length_);
    }
  }
}

// Convenient entry point for pointwise multiplications.
void FFTContainer::PointwiseMultiply(const FFTContainer& other) {
  DCHECK(n_ == other.n_);
  DoPointwiseMultiplication(other, 0, n_, temp_);
}

}  // namespace

////////////////////////////////////////////////////////////////////////////
// Part 4: Tying everything together into a multiplication algorithm.

// TODO(jkummerow): Consider doing a "Mersenne transform" and CRT
// reconstruction of the final result. Might yield a few percent of perf
// improvement.

// TODO(jkummerow): Consider implementing the "sqrt(2) trick".
// Gaudry/Kruppa/Zimmerman report that it saved them around 10%.

void ProcessorImpl::MultiplyFFT(RWDigits Z, Digits X, Digits Y) {
  Parameters params;
  int m = GetParameters(X.len() + Y.len(), &params);
  int omega = params.r;  // really: 2^r

  FFTContainer a(params.n, params.K, this);
  a.Start(X, params.s, 0, omega);
  if (X == Y) {
    // Squaring.
    a.PointwiseMultiply(a);
  } else {
    FFTContainer b(params.n, params.K, this);
    b.Start(Y, params.s, 0, omega);
    a.PointwiseMultiply(b);
  }
  if (should_terminate()) return;

  a.BackwardFFT(0, params.n, omega);
  a.NormalizeAndRecombine(omega, m, Z, params.s);
}
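
// The overall pipeline, in the notation of the functions above:
//   1. {GetParameters} picks m; n = 2^m chunks of s digits, ring size K.
//   2. {Start} / {Start_Default} chop X and Y into chunks and run the
//      forward DIF FFT over the integers mod 2^(K * kDigitBits) + 1, with
//      shifts as twiddle factors.
//   3. {PointwiseMultiply} multiplies chunk pairs, recursing into
//      {MultiplyFFT_Inner} when the chunks are large enough.
//   4. {BackwardFFT} (DIT) transforms back; {NormalizeAndRecombine} divides
//      by n and re-assembles {Z} from the overlapping s-digit chunks.
// Callers do not invoke this directly: ProcessorImpl::Multiply dispatches to
// MultiplyFFT for sufficiently large inputs (see kFftThreshold).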

}  // namespace bigint
}  // namespace v8