1use super::RECURSION_THRESHOLD;
2use crate::fastdiv::Div64;
3use core::{fmt::Debug, iter::zip};
4
5#[allow(unused_imports)]
6use pulp::*;
7
/// Scalar modular arithmetic over a prime modulus.
///
/// `self` carries the modulus value itself (used by `add`/`sub`), while the
/// associated `Div` type carries whatever precomputed data `mul` needs for
/// fast division-free reduction. Implementations assume operands are already
/// reduced (`< p`) — TODO confirm against callers.
pub(crate) trait PrimeModulus: Debug + Copy {
    /// Precomputed reduction data consumed by `mul`.
    type Div: Debug + Copy;

    /// Returns `(a + b) mod p`.
    fn add(self, a: u64, b: u64) -> u64;
    /// Returns `(a - b) mod p`.
    fn sub(self, a: u64, b: u64) -> u64;
    /// Returns `(a * b) mod p`, using the precomputed data in `p`.
    fn mul(p: Self::Div, a: u64, b: u64) -> u64;
}
15
/// Vectorized (AVX2, 4 x u64 lanes) counterpart of [`PrimeModulus`].
///
/// Same contract as the scalar trait, applied lanewise; `simd` witnesses
/// that the required CPU features are available.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub(crate) trait PrimeModulusV3: Debug + Copy {
    /// Precomputed reduction data consumed by `mul`.
    type Div: Debug + Copy;

    /// Lanewise `(a + b) mod p`.
    fn add(self, simd: crate::V3, a: u64x4, b: u64x4) -> u64x4;
    /// Lanewise `(a - b) mod p`.
    fn sub(self, simd: crate::V3, a: u64x4, b: u64x4) -> u64x4;
    /// Lanewise `(a * b) mod p`.
    fn mul(p: Self::Div, simd: crate::V3, a: u64x4, b: u64x4) -> u64x4;
}
24
/// Vectorized (AVX512, 8 x u64 lanes) counterpart of [`PrimeModulus`].
///
/// Same contract as the scalar trait, applied lanewise; nightly-only since
/// the AVX512 intrinsics require unstable features.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
pub(crate) trait PrimeModulusV4: Debug + Copy {
    /// Precomputed reduction data consumed by `mul`.
    type Div: Debug + Copy;

    /// Lanewise `(a + b) mod p`.
    fn add(self, simd: crate::V4, a: u64x8, b: u64x8) -> u64x8;
    /// Lanewise `(a - b) mod p`.
    fn sub(self, simd: crate::V4, a: u64x8, b: u64x8) -> u64x8;
    /// Lanewise `(a * b) mod p`.
    fn mul(p: Self::Div, simd: crate::V4, a: u64x8, b: u64x8) -> u64x8;
}
34
/// Marker type for the Solinas ("Goldilocks") prime `P = 2^64 - 2^32 + 1`,
/// whose special shape admits fast modular multiplication without any
/// precomputed division data.
#[derive(Copy, Clone, Debug)]
pub struct Solinas;

impl Solinas {
    /// The Solinas prime `2^64 - 2^32 + 1` (= `0xFFFF_FFFF_0000_0001`).
    pub const P: u64 = ((1u128 << 64) - (1u128 << 32) + 1u128) as u64;
}
41
42impl PrimeModulus for u64 {
43 type Div = Div64;
44
45 #[inline(always)]
46 fn add(self, a: u64, b: u64) -> u64 {
47 let p = self;
48 let neg_b = p - b;
53 if a >= neg_b {
54 a - neg_b
55 } else {
56 a + b
57 }
58 }
59
60 #[inline(always)]
61 fn sub(self, a: u64, b: u64) -> u64 {
62 let p = self;
63 let neg_b = p - b;
64 if a >= b {
65 a - b
66 } else {
67 a + neg_b
68 }
69 }
70
71 #[inline(always)]
72 fn mul(p: Self::Div, a: u64, b: u64) -> u64 {
73 Div64::rem_u128(a as u128 * b as u128, p)
74 }
75}
76
impl PrimeModulus for Solinas {
    // The Solinas structure makes reduction division-free, so `mul` needs
    // no precomputed data.
    type Div = ();

    /// `(a + b) mod P` without u64 overflow; assumes `a, b < P`.
    #[inline(always)]
    fn add(self, a: u64, b: u64) -> u64 {
        let p = Self::P;
        // P - b is the additive inverse of b mod P.
        let neg_b = p - b;
        if a >= neg_b {
            // a + b >= P: reduce, expressed as a - (P - b).
            a - neg_b
        } else {
            a + b
        }
    }

    /// `(a - b) mod P`; assumes `a, b < P`.
    #[inline(always)]
    fn sub(self, a: u64, b: u64) -> u64 {
        let p = Self::P;
        let neg_b = p - b;
        if a >= b {
            a - b
        } else {
            // Borrow: add P back, expressed as a + (P - b).
            a + neg_b
        }
    }

    /// `(a * b) mod P` via the Solinas reduction for `P = 2^64 - 2^32 + 1`.
    ///
    /// Split the 128-bit product into `lo` (bits 0..64), `mid` (bits 64..96)
    /// and `hi` (bits 96..128). Using `2^64 ≡ 2^32 - 1 (mod P)` and
    /// `2^96 ≡ -1 (mod P)`:
    ///
    ///   `a*b ≡ lo + mid*(2^32 - 1) - hi  (mod P)`
    #[inline(always)]
    fn mul(p: Self::Div, a: u64, b: u64) -> u64 {
        let _ = p;
        let p = Self::P;

        let wide = a as u128 * b as u128;

        let lo = wide as u64;
        let hi = (wide >> 64) as u64;
        // mid: bits 64..96 of the product; hi: bits 96..128.
        let mid = hi & 0x0000_0000_FFFF_FFFF;
        let hi = (hi & 0xFFFF_FFFF_0000_0000) >> 32;

        // low2 = lo - hi (mod P): on borrow, add P back.
        let mut low2 = lo.wrapping_sub(hi);
        if hi > lo {
            low2 = low2.wrapping_add(p);
        }

        // product = mid * (2^32 - 1); fits in u64 since mid < 2^32.
        let mut product = mid << 32;
        product -= mid;

        // result = low2 + product, then one conditional correction:
        // `result < product` detects wraparound past 2^64 (note that
        // -P mod 2^64 == 2^32 - 1, exactly the correction wrapping_sub(p)
        // applies); `result >= p` is the ordinary out-of-range case.
        let mut result = low2.wrapping_add(product);
        if (result < product) || (result >= p) {
            result = result.wrapping_sub(p);
        }
        result
    }
}
130
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
impl PrimeModulusV3 for u64 {
    // (p, p_div0, p_div1, p_div2, p_div3): the modulus followed by the four
    // little-endian 64-bit limbs of a precomputed 256-bit constant used by
    // `mul` for division-free (Barrett-style) reduction.
    type Div = (u64, u64, u64, u64, u64);

    /// Lanewise `(a + b) mod p`; assumes each lane of `a`, `b` is `< p`.
    #[inline(always)]
    fn add(self, simd: crate::V3, a: u64x4, b: u64x4) -> u64x4 {
        let p = simd.splat_u64x4(self);
        // neg_b = p - b, the additive inverse of b mod p.
        let neg_b = simd.wrapping_sub_u64x4(p, b);
        // Branch-free form of: if a >= neg_b { a - neg_b } else { a + b }.
        // (AVX2 has no unsigned >=, hence the inverted > compare.)
        let not_a_ge_neg_b = simd.cmp_gt_u64x4(neg_b, a);
        simd.select_u64x4(
            not_a_ge_neg_b,
            simd.wrapping_add_u64x4(a, b),
            simd.wrapping_sub_u64x4(a, neg_b),
        )
    }

    /// Lanewise `(a - b) mod p`; assumes each lane of `a`, `b` is `< p`.
    #[inline(always)]
    fn sub(self, simd: crate::V3, a: u64x4, b: u64x4) -> u64x4 {
        let p = simd.splat_u64x4(self);
        let neg_b = simd.wrapping_sub_u64x4(p, b);
        // Branch-free form of: if a >= b { a - b } else { a + (p - b) }.
        let not_a_ge_b = simd.cmp_gt_u64x4(b, a);
        simd.select_u64x4(
            not_a_ge_b,
            simd.wrapping_add_u64x4(a, neg_b),
            simd.wrapping_sub_u64x4(a, b),
        )
    }

    /// Lanewise `(a * b) mod p`.
    ///
    /// Barrett-style reduction: the 128-bit product `a*b` is multiplied by
    /// the precomputed 256-bit constant (keeping only the low 256 bits), and
    /// the high (fifth) limb of that value times `p` is the reduced result.
    #[inline(always)]
    fn mul(p: Self::Div, simd: crate::V3, a: u64x4, b: u64x4) -> u64x4 {
        // One limb step of a multi-limb multiply: lo + c, carry folded
        // into hi.
        #[inline(always)]
        fn mul_with_carry(simd: crate::V3, l: u64x4, r: u64x4, c: u64x4) -> (u64x4, u64x4) {
            let (lo, hi) = simd.widening_mul_u64x4(l, r);
            let lo_plus_c = simd.wrapping_add_u64x4(lo, c);
            // Overflow mask lanes are all-ones (-1); subtracting the mask
            // therefore increments hi by one on overflow lanes.
            let overflow = cast(simd.cmp_gt_u64x4(lo, lo_plus_c));
            (lo_plus_c, simd.wrapping_sub_u64x4(hi, overflow))
        }

        // 256-bit x 64-bit multiply; returns the five limbs little-endian.
        #[inline(always)]
        fn mul_u256_u64(
            simd: crate::V3,
            lhs0: u64x4,
            lhs1: u64x4,
            lhs2: u64x4,
            lhs3: u64x4,
            rhs: u64x4,
        ) -> (u64x4, u64x4, u64x4, u64x4, u64x4) {
            let (x0, carry) = simd.widening_mul_u64x4(lhs0, rhs);
            let (x1, carry) = mul_with_carry(simd, lhs1, rhs, carry);
            let (x2, carry) = mul_with_carry(simd, lhs2, rhs, carry);
            let (x3, carry) = mul_with_carry(simd, lhs3, rhs, carry);
            (x0, x1, x2, x3, carry)
        }

        // Low 256 bits of a 256-bit x 128-bit product; the rhs1 partial
        // product is offset one limb left when summed.
        #[inline(always)]
        fn wrapping_mul_u256_u128(
            simd: crate::V3,
            lhs0: u64x4,
            lhs1: u64x4,
            lhs2: u64x4,
            lhs3: u64x4,
            rhs0: u64x4,
            rhs1: u64x4,
        ) -> (u64x4, u64x4, u64x4, u64x4) {
            let (x0, x1, x2, x3, _) = mul_u256_u64(simd, lhs0, lhs1, lhs2, lhs3, rhs0);
            let (y0, y1, y2, _, _) = mul_u256_u64(simd, lhs0, lhs1, lhs2, lhs3, rhs1);

            let z0 = x0;

            // carry is 0 or all-ones (-1) per lane.
            let z1 = simd.wrapping_add_u64x4(x1, y0);
            let carry = cast(simd.cmp_gt_u64x4(x1, z1));

            let z2 = simd.wrapping_add_u64x4(x2, y1);
            let o0 = cast(simd.cmp_gt_u64x4(x2, z2));
            // `z2 - carry` adds 1 when carry is all-ones; that increment
            // overflows exactly when z2 == u64::MAX (== carry).
            // NOTE(review): when carry == 0 this compare also fires for
            // z2 == 0, which would set a spurious carry into z3 —
            // presumably unreachable for the p_div constants in use;
            // confirm against the scalar reference path.
            let o1 = cast(simd.cmp_eq_u64x4(z2, carry));
            let z2 = simd.wrapping_sub_u64x4(z2, carry);
            let carry = simd.or_u64x4(o0, o1);

            let z3 = simd.wrapping_add_u64x4(x3, y2);
            let z3 = simd.wrapping_sub_u64x4(z3, carry);

            (z0, z1, z2, z3)
        }

        let (p, p_div0, p_div1, p_div2, p_div3) = p;

        let p = simd.splat_u64x4(p as _);
        let p_div0 = simd.splat_u64x4(p_div0 as _);
        let p_div1 = simd.splat_u64x4(p_div1 as _);
        let p_div2 = simd.splat_u64x4(p_div2 as _);
        let p_div3 = simd.splat_u64x4(p_div3 as _);

        // Full 128-bit product, then low 256 bits of p_div * (a*b).
        let (lo, hi) = simd.widening_mul_u64x4(a, b);
        let (low_bits0, low_bits1, low_bits2, low_bits3) =
            wrapping_mul_u256_u128(simd, p_div0, p_div1, p_div2, p_div3, lo, hi);

        // The reduced product is the high (fifth) limb of low_bits * p.
        mul_u256_u64(simd, low_bits0, low_bits1, low_bits2, low_bits3, p).4
    }
}
230
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
impl PrimeModulusV3 for Solinas {
    // Solinas reduction needs no precomputed data.
    type Div = ();

    /// Lanewise `(a + b) mod P`; assumes lanes are `< P`.
    #[inline(always)]
    fn add(self, simd: crate::V3, a: u64x4, b: u64x4) -> u64x4 {
        let p = simd.splat_u64x4(Self::P);
        let neg_b = simd.wrapping_sub_u64x4(p, b);
        // Branch-free: if a >= neg_b { a - neg_b } else { a + b }.
        let not_a_ge_neg_b = simd.cmp_gt_u64x4(neg_b, a);
        simd.select_u64x4(
            not_a_ge_neg_b,
            simd.wrapping_add_u64x4(a, b),
            simd.wrapping_sub_u64x4(a, neg_b),
        )
    }

    /// Lanewise `(a - b) mod P`; assumes lanes are `< P`.
    #[inline(always)]
    fn sub(self, simd: crate::V3, a: u64x4, b: u64x4) -> u64x4 {
        let p = simd.splat_u64x4(Self::P);
        let neg_b = simd.wrapping_sub_u64x4(p, b);
        // Branch-free: if a >= b { a - b } else { a + (P - b) }.
        let not_a_ge_b = simd.cmp_gt_u64x4(b, a);
        simd.select_u64x4(
            not_a_ge_b,
            simd.wrapping_add_u64x4(a, neg_b),
            simd.wrapping_sub_u64x4(a, b),
        )
    }

    /// Lanewise `(a * b) mod P`, vectorized Solinas reduction — same algebra
    /// as the scalar `PrimeModulus for Solinas` impl: with the product split
    /// as lo (bits 0..64), mid (64..96), hi (96..128),
    /// `a*b ≡ lo - hi + mid*(2^32 - 1) (mod P)`.
    #[inline(always)]
    fn mul(p: Self::Div, simd: crate::V3, a: u64x4, b: u64x4) -> u64x4 {
        let _ = p;

        let p = simd.splat_u64x4(Self::P as _);

        let (lo, hi) = simd.widening_mul_u64x4(a, b);
        let mid = simd.and_u64x4(hi, simd.splat_u64x4(0x0000_0000_FFFF_FFFF));
        let hi = simd.and_u64x4(hi, simd.splat_u64x4(0xFFFF_FFFF_0000_0000));
        let hi = simd.shr_const_u64x4::<32>(hi);

        // low2 = lo - hi (mod P): add P back on the borrow lanes.
        let low2 = simd.wrapping_sub_u64x4(lo, hi);
        let low2 = simd.select_u64x4(
            simd.cmp_gt_u64x4(hi, lo),
            simd.wrapping_add_u64x4(low2, p),
            low2,
        );

        // product = mid * (2^32 - 1).
        let product = simd.shl_const_u64x4::<32>(mid);
        let product = simd.wrapping_sub_u64x4(product, mid);

        let result = simd.wrapping_add_u64x4(low2, product);

        // Subtract P on lanes where the add wrapped (result < product) or
        // result >= P. andnot(m, n) == !m & n, so not_cond marks the lanes
        // that need no correction.
        let product_gt_result = simd.cmp_gt_u64x4(product, result);
        let p_gt_result = simd.cmp_gt_u64x4(p, result);
        let not_cond = simd.andnot_m64x4(product_gt_result, p_gt_result);

        simd.select_u64x4(not_cond, result, simd.wrapping_sub_u64x4(result, p))
    }
}
293
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl PrimeModulusV4 for u64 {
    // (p, p_div0..p_div3): the modulus plus the four little-endian limbs of
    // a precomputed 256-bit constant used by `mul` for division-free
    // reduction (same scheme as the V3 impl).
    type Div = (u64, u64, u64, u64, u64);

    /// Lanewise `(a + b) mod p`; assumes lanes are `< p`. AVX512 provides an
    /// unsigned >= compare, so no inverted-mask trick is needed here.
    #[inline(always)]
    fn add(self, simd: crate::V4, a: u64x8, b: u64x8) -> u64x8 {
        let p = simd.splat_u64x8(self);
        // neg_b = p - b, the additive inverse of b mod p.
        let neg_b = simd.wrapping_sub_u64x8(p, b);
        let a_ge_neg_b = simd.cmp_ge_u64x8(a, neg_b);
        simd.select_u64x8(
            a_ge_neg_b,
            simd.wrapping_sub_u64x8(a, neg_b),
            simd.wrapping_add_u64x8(a, b),
        )
    }

    /// Lanewise `(a - b) mod p`; assumes lanes are `< p`.
    #[inline(always)]
    fn sub(self, simd: crate::V4, a: u64x8, b: u64x8) -> u64x8 {
        let p = simd.splat_u64x8(self);
        let neg_b = simd.wrapping_sub_u64x8(p, b);
        let a_ge_b = simd.cmp_ge_u64x8(a, b);
        simd.select_u64x8(
            a_ge_b,
            simd.wrapping_sub_u64x8(a, b),
            simd.wrapping_add_u64x8(a, neg_b),
        )
    }

    /// Lanewise `(a * b) mod p`; same Barrett-style scheme as the
    /// `PrimeModulusV3 for u64` impl, with AVX512 b8 mask registers instead
    /// of vector-valued masks.
    #[inline(always)]
    fn mul(p: Self::Div, simd: crate::V4, a: u64x8, b: u64x8) -> u64x8 {
        // One limb step of a multi-limb multiply: lo + c, carry into hi.
        #[inline(always)]
        fn mul_with_carry(simd: crate::V4, l: u64x8, r: u64x8, c: u64x8) -> (u64x8, u64x8) {
            let (lo, hi) = simd.widening_mul_u64x8(l, r);
            let lo_plus_c = simd.wrapping_add_u64x8(lo, c);
            let overflow = simd.cmp_gt_u64x8(lo, lo_plus_c);

            // The converted mask is all-ones (-1) on overflow lanes, so
            // subtracting it increments hi by one there.
            (
                lo_plus_c,
                simd.wrapping_sub_u64x8(hi, simd.convert_mask_b8_to_u64x8(overflow)),
            )
        }

        // 256-bit x 64-bit multiply; returns the five limbs little-endian.
        #[inline(always)]
        fn mul_u256_u64(
            simd: crate::V4,
            lhs0: u64x8,
            lhs1: u64x8,
            lhs2: u64x8,
            lhs3: u64x8,
            rhs: u64x8,
        ) -> (u64x8, u64x8, u64x8, u64x8, u64x8) {
            let (x0, carry) = simd.widening_mul_u64x8(lhs0, rhs);
            let (x1, carry) = mul_with_carry(simd, lhs1, rhs, carry);
            let (x2, carry) = mul_with_carry(simd, lhs2, rhs, carry);
            let (x3, carry) = mul_with_carry(simd, lhs3, rhs, carry);
            (x0, x1, x2, x3, carry)
        }

        // Low 256 bits of a 256-bit x 128-bit product; the rhs1 partial
        // product is offset one limb left when summed.
        #[inline(always)]
        fn wrapping_mul_u256_u128(
            simd: crate::V4,
            lhs0: u64x8,
            lhs1: u64x8,
            lhs2: u64x8,
            lhs3: u64x8,
            rhs0: u64x8,
            rhs1: u64x8,
        ) -> (u64x8, u64x8, u64x8, u64x8) {
            let (x0, x1, x2, x3, _) = mul_u256_u64(simd, lhs0, lhs1, lhs2, lhs3, rhs0);
            let (y0, y1, y2, _, _) = mul_u256_u64(simd, lhs0, lhs1, lhs2, lhs3, rhs1);

            let z0 = x0;

            // carry is 0 or all-ones (-1) per lane.
            let z1 = simd.wrapping_add_u64x8(x1, y0);
            let carry = simd.convert_mask_b8_to_u64x8(simd.cmp_gt_u64x8(x1, z1));

            let z2 = simd.wrapping_add_u64x8(x2, y1);
            let o0 = simd.cmp_gt_u64x8(x2, z2);
            // `z2 - carry` adds 1 when carry is all-ones; that increment
            // overflows exactly when z2 == u64::MAX (== carry).
            // NOTE(review): when carry == 0 this compare also fires for
            // z2 == 0, which would set a spurious carry into z3 —
            // presumably unreachable for the p_div constants in use;
            // confirm against the scalar reference path.
            let o1 = simd.cmp_eq_u64x8(z2, carry);
            let z2 = simd.wrapping_sub_u64x8(z2, carry);
            let carry = simd.convert_mask_b8_to_u64x8(b8(o0.0 | o1.0));

            let z3 = simd.wrapping_add_u64x8(x3, y2);
            let z3 = simd.wrapping_sub_u64x8(z3, carry);

            (z0, z1, z2, z3)
        }

        let (p, p_div0, p_div1, p_div2, p_div3) = p;

        let p = simd.splat_u64x8(p);
        let p_div0 = simd.splat_u64x8(p_div0);
        let p_div1 = simd.splat_u64x8(p_div1);
        let p_div2 = simd.splat_u64x8(p_div2);
        let p_div3 = simd.splat_u64x8(p_div3);

        // Full 128-bit product, then low 256 bits of p_div * (a*b).
        let (lo, hi) = simd.widening_mul_u64x8(a, b);
        let (low_bits0, low_bits1, low_bits2, low_bits3) =
            wrapping_mul_u256_u128(simd, p_div0, p_div1, p_div2, p_div3, lo, hi);

        // The reduced product is the high (fifth) limb of low_bits * p.
        mul_u256_u64(simd, low_bits0, low_bits1, low_bits2, low_bits3, p).4
    }
}
398
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl PrimeModulusV4 for Solinas {
    // Solinas reduction needs no precomputed data.
    type Div = ();

    /// Lanewise `(a + b) mod P`: reuses the generic `u64` modulus impl with
    /// the constant `P`.
    #[inline(always)]
    fn add(self, simd: crate::V4, a: u64x8, b: u64x8) -> u64x8 {
        PrimeModulusV4::add(Self::P, simd, a, b)
    }

    /// Lanewise `(a - b) mod P`: reuses the generic `u64` modulus impl.
    #[inline(always)]
    fn sub(self, simd: crate::V4, a: u64x8, b: u64x8) -> u64x8 {
        PrimeModulusV4::sub(Self::P, simd, a, b)
    }

    /// Lanewise `(a * b) mod P`, vectorized Solinas reduction — same algebra
    /// as the scalar impl: product split as lo (bits 0..64), mid (64..96),
    /// hi (96..128), with `a*b ≡ lo - hi + mid*(2^32 - 1) (mod P)`.
    #[inline(always)]
    fn mul(p: Self::Div, simd: crate::V4, a: u64x8, b: u64x8) -> u64x8 {
        let _ = p;

        let p = simd.splat_u64x8(Self::P);

        let (lo, hi) = simd.widening_mul_u64x8(a, b);
        let mid = simd.and_u64x8(hi, simd.splat_u64x8(0x0000_0000_FFFF_FFFF));
        let hi = simd.and_u64x8(hi, simd.splat_u64x8(0xFFFF_FFFF_0000_0000));
        let hi = simd.shr_const_u64x8::<32>(hi);

        // low2 = lo - hi (mod P): add P back on the borrow lanes.
        let low2 = simd.wrapping_sub_u64x8(lo, hi);
        let low2 = simd.select_u64x8(
            simd.cmp_gt_u64x8(hi, lo),
            simd.wrapping_add_u64x8(low2, p),
            low2,
        );

        // product = mid * (2^32 - 1).
        let product = simd.shl_const_u64x8::<32>(mid);
        let product = simd.wrapping_sub_u64x8(product, mid);

        let result = simd.wrapping_add_u64x8(low2, product);

        // Subtract P on lanes where the add wrapped (result < product) or
        // result >= P; not_cond marks lanes needing no correction.
        let product_gt_result = simd.cmp_gt_u64x8(product, result);
        let p_gt_result = simd.cmp_gt_u64x8(p, result);
        let not_cond = b8(!product_gt_result.0 & p_gt_result.0);

        simd.select_u64x8(not_cond, result, simd.wrapping_sub_u64x8(result, p))
    }
}
448
449pub(crate) fn fwd_breadth_first_scalar<P: PrimeModulus>(
450 data: &mut [u64],
451 p: P,
452 p_div: P::Div,
453 twid: &[u64],
454 recursion_depth: usize,
455 recursion_half: usize,
456) {
457 let n = data.len();
458 debug_assert!(n.is_power_of_two());
459
460 let mut t = n / 2;
461 let mut m = 1;
462 let mut w_idx = (m << recursion_depth) + recursion_half * m;
463
464 while m < n {
465 let w = &twid[w_idx..];
466
467 for (data, &w1) in zip(data.chunks_exact_mut(2 * t), w) {
468 let (z0, z1) = data.split_at_mut(t);
469
470 for (z0, z1) in zip(z0, z1) {
471 let z1w = P::mul(p_div, *z1, w1);
472
473 (*z0, *z1) = (p.add(*z0, z1w), p.sub(*z0, z1w));
474 }
475 }
476
477 t /= 2;
478 m *= 2;
479 w_idx *= 2;
480 }
481}
482
483pub(crate) fn inv_breadth_first_scalar<P: PrimeModulus>(
484 data: &mut [u64],
485 p: P,
486 p_div: P::Div,
487 inv_twid: &[u64],
488 recursion_depth: usize,
489 recursion_half: usize,
490) {
491 let n = data.len();
492 debug_assert!(n.is_power_of_two());
493
494 let mut t = 1;
495 let mut m = n;
496 let mut w_idx = (m << recursion_depth) + recursion_half * m;
497
498 while m > 1 {
499 m /= 2;
500 w_idx /= 2;
501
502 let w = &inv_twid[w_idx..];
503
504 for (data, &w1) in zip(data.chunks_exact_mut(2 * t), w) {
505 let (z0, z1) = data.split_at_mut(t);
506
507 for (z0, z1) in zip(z0, z1) {
508 (*z0, *z1) = (p.add(*z0, *z1), P::mul(p_div, p.sub(*z0, *z1), w1));
509 }
510 }
511
512 t *= 2;
513 }
514}
515
516pub(crate) fn inv_depth_first_scalar<P: PrimeModulus>(
517 data: &mut [u64],
518 p: P,
519 p_div: P::Div,
520 inv_twid: &[u64],
521 recursion_depth: usize,
522 recursion_half: usize,
523) {
524 let n = data.len();
525 debug_assert!(n.is_power_of_two());
526 if n <= RECURSION_THRESHOLD {
527 inv_breadth_first_scalar(data, p, p_div, inv_twid, recursion_depth, recursion_half);
528 } else {
529 let (data0, data1) = data.split_at_mut(n / 2);
530 inv_depth_first_scalar(
531 data0,
532 p,
533 p_div,
534 inv_twid,
535 recursion_depth + 1,
536 recursion_half * 2,
537 );
538 inv_depth_first_scalar(
539 data1,
540 p,
541 p_div,
542 inv_twid,
543 recursion_depth + 1,
544 recursion_half * 2 + 1,
545 );
546
547 let t = n / 2;
548 let m = 1;
549 let w_idx = (m << recursion_depth) + m * recursion_half;
550
551 let w = &inv_twid[w_idx..];
552
553 for (data, &w1) in zip(data.chunks_exact_mut(2 * t), w) {
554 let (z0, z1) = data.split_at_mut(t);
555
556 for (z0, z1) in zip(z0, z1) {
557 (*z0, *z1) = (p.add(*z0, *z1), P::mul(p_div, p.sub(*z0, *z1), w1));
558 }
559 }
560 }
561}
562
/// Breadth-first forward NTT pass (AVX2, 4 x u64 lanes).
///
/// Vector analogue of `fwd_breadth_first_scalar`. Layers whose butterfly
/// distance is >= 4 use whole vectors; the last two layers (distance 2 and
/// 1) are narrower than a vector and are handled with interleave/permute
/// shuffles so four butterflies still execute per iteration.
/// `recursion_depth`/`recursion_half` locate this call's twiddle subtree
/// when called from a depth-first driver (both 0 for a full transform).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub(crate) fn fwd_breadth_first_avx2<P: PrimeModulusV3>(
    simd: crate::V3,
    data: &mut [u64],
    p: P,
    p_div: P::Div,
    twid: &[u64],
    recursion_depth: usize,
    recursion_half: usize,
) {
    // Arguments are captured in a struct so the body can be instantiated
    // with the AVX2 target features enabled via `simd.vectorize`.
    struct Impl<'a, P: PrimeModulusV3> {
        simd: crate::V3,
        data: &'a mut [u64],
        p: P,
        p_div: P::Div,
        twid: &'a [u64],
        recursion_depth: usize,
        recursion_half: usize,
    }
    impl<P: PrimeModulusV3> pulp::NullaryFnOnce for Impl<'_, P> {
        type Output = ();

        #[inline(always)]
        fn call(self) -> Self::Output {
            let Self {
                simd,
                data,
                p,
                p_div,
                twid,
                recursion_depth,
                recursion_half,
            } = self;
            let n = data.len();
            debug_assert!(n.is_power_of_two());

            // t: butterfly half-distance; m: blocks in the layer;
            // w_idx: index of the layer's first twiddle.
            let mut t = n / 2;
            let mut m = 1;
            let mut w_idx = (m << recursion_depth) + recursion_half * m;
            // Layers with t >= 4: both butterfly operands fill whole vectors.
            while m < n / 4 {
                let w = &twid[w_idx..];

                for (data, &w1) in zip(data.chunks_exact_mut(2 * t), w) {
                    let (z0, z1) = data.split_at_mut(t);
                    let z0 = pulp::as_arrays_mut::<4, _>(z0).0;
                    let z1 = pulp::as_arrays_mut::<4, _>(z1).0;
                    // One twiddle per block, broadcast to all lanes.
                    let w1 = simd.splat_u64x4(w1);

                    for (z0_, z1_) in zip(z0, z1) {
                        let mut z0 = cast(*z0_);
                        let mut z1 = cast(*z1_);
                        // Butterfly: (z0, z1) <- (z0 + w*z1, z0 - w*z1).
                        let z1w = P::mul(p_div, simd, z1, w1);
                        (z0, z1) = (p.add(simd, z0, z1w), p.sub(simd, z0, z1w));
                        *z0_ = cast(z0);
                        *z1_ = cast(z1);
                    }
                }

                t /= 2;
                m *= 2;
                w_idx *= 2;
            }

            // Layer with t == 2: data lies as z0 z0 z1 z1 across a vector
            // pair; interleave2 regroups operands and permute2 spreads each
            // of the two twiddles over its two lanes.
            {
                let w = pulp::as_arrays::<2, _>(&twid[w_idx..]).0;
                let data = pulp::as_arrays_mut::<4, _>(data).0;
                let data = pulp::as_arrays_mut::<2, _>(data).0;

                for (z0z0z1z1, w1) in zip(data, w) {
                    let w1 = simd.permute2_u64x4(*w1);
                    let [mut z0, mut z1] = simd.interleave2_u64x4(cast(*z0z0z1z1));
                    let z1w = P::mul(p_div, simd, z1, w1);
                    (z0, z1) = (p.add(simd, z0, z1w), p.sub(simd, z0, z1w));
                    // Applying interleave2 again restores the memory layout.
                    *z0z0z1z1 = cast(simd.interleave2_u64x4([z0, z1]));
                }

                w_idx *= 2;
            }

            // Layer with t == 1: adjacent z0 z1 pairs; interleave1/permute1
            // give each butterfly its own twiddle lane.
            {
                let w = pulp::as_arrays::<4, _>(&twid[w_idx..]).0;
                let data = pulp::as_arrays_mut::<4, _>(data).0;
                let data = pulp::as_arrays_mut::<2, _>(data).0;

                for (z0z1, w1) in zip(data, w) {
                    let w1 = simd.permute1_u64x4(*w1);
                    let [mut z0, mut z1] = simd.interleave1_u64x4(cast(*z0z1));
                    let z1w = P::mul(p_div, simd, z1, w1);
                    (z0, z1) = (p.add(simd, z0, z1w), p.sub(simd, z0, z1w));
                    *z0z1 = cast(simd.interleave1_u64x4([z0, z1]));
                }
            }
        }
    }
    simd.vectorize(Impl {
        simd,
        data,
        p,
        p_div,
        twid,
        recursion_depth,
        recursion_half,
    });
}
671
/// Breadth-first inverse NTT pass (AVX2, 4 x u64 lanes).
///
/// Vector analogue of `inv_breadth_first_scalar`: layers run from butterfly
/// distance 1 up to n/2. The first two layers (distance 1 and 2) are
/// narrower than a vector and use interleave/permute shuffles; the rest use
/// whole vectors. No 1/n scaling is applied here.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub(crate) fn inv_breadth_first_avx2<P: PrimeModulusV3>(
    simd: crate::V3,
    data: &mut [u64],
    p: P,
    p_div: P::Div,
    inv_twid: &[u64],
    recursion_depth: usize,
    recursion_half: usize,
) {
    // Captured arguments for `simd.vectorize` (enables AVX2 codegen).
    struct Impl<'a, P: PrimeModulusV3> {
        simd: crate::V3,
        data: &'a mut [u64],
        p: P,
        p_div: P::Div,
        inv_twid: &'a [u64],
        recursion_depth: usize,
        recursion_half: usize,
    }
    impl<P: PrimeModulusV3> pulp::NullaryFnOnce for Impl<'_, P> {
        type Output = ();

        #[inline(always)]
        fn call(self) -> Self::Output {
            let Self {
                simd,
                data,
                p,
                p_div,
                inv_twid,
                recursion_depth,
                recursion_half,
            } = self;

            let n = data.len();
            debug_assert!(n.is_power_of_two());

            // t: butterfly half-distance; m: blocks in the current layer;
            // w_idx: index of the layer's first twiddle.
            let mut t = 1;
            let mut m = n;
            let mut w_idx = (m << recursion_depth) + recursion_half * m;

            // Layer with t == 1: adjacent z0 z1 pairs, one twiddle each.
            {
                m /= 2;
                w_idx /= 2;

                let w = pulp::as_arrays::<4, _>(&inv_twid[w_idx..]).0;
                let data = pulp::as_arrays_mut::<4, _>(data).0;
                let data = pulp::as_arrays_mut::<2, _>(data).0;

                for (z0z1, w1) in zip(data, w) {
                    let w1 = simd.permute1_u64x4(*w1);
                    let [mut z0, mut z1] = simd.interleave1_u64x4(cast(*z0z1));
                    // Inverse butterfly: (z0, z1) <- (z0 + z1, (z0 - z1)*w).
                    (z0, z1) = (
                        p.add(simd, z0, z1),
                        P::mul(p_div, simd, p.sub(simd, z0, z1), w1),
                    );
                    *z0z1 = cast(simd.interleave1_u64x4([z0, z1]));
                }

                t *= 2;
            }

            // Layer with t == 2: z0 z0 z1 z1 groups across a vector pair;
            // permute2 spreads each of the two twiddles over two lanes.
            {
                m /= 2;
                w_idx /= 2;

                let w = pulp::as_arrays::<2, _>(&inv_twid[w_idx..]).0;
                let data = pulp::as_arrays_mut::<4, _>(data).0;
                let data = pulp::as_arrays_mut::<2, _>(data).0;

                for (z0z0z1z1, w1) in zip(data, w) {
                    let w1 = simd.permute2_u64x4(*w1);
                    let [mut z0, mut z1] = simd.interleave2_u64x4(cast(*z0z0z1z1));
                    (z0, z1) = (
                        p.add(simd, z0, z1),
                        P::mul(p_div, simd, p.sub(simd, z0, z1), w1),
                    );
                    *z0z0z1z1 = cast(simd.interleave2_u64x4([z0, z1]));
                }

                t *= 2;
            }

            // Remaining layers (t >= 4): operands are whole vectors.
            while m > 1 {
                m /= 2;
                w_idx /= 2;

                let w = &inv_twid[w_idx..];

                for (data, &w1) in zip(data.chunks_exact_mut(2 * t), w) {
                    let (z0, z1) = data.split_at_mut(t);
                    let z0 = pulp::as_arrays_mut::<4, _>(z0).0;
                    let z1 = pulp::as_arrays_mut::<4, _>(z1).0;
                    let w1 = simd.splat_u64x4(w1);

                    for (z0_, z1_) in zip(z0, z1) {
                        let mut z0 = cast(*z0_);
                        let mut z1 = cast(*z1_);
                        (z0, z1) = (
                            p.add(simd, z0, z1),
                            P::mul(p_div, simd, p.sub(simd, z0, z1), w1),
                        );
                        *z0_ = cast(z0);
                        *z1_ = cast(z1);
                    }
                }

                t *= 2;
            }
        }
    }

    simd.vectorize(Impl {
        simd,
        data,
        p,
        p_div,
        inv_twid,
        recursion_depth,
        recursion_half,
    });
}
798
/// Breadth-first forward NTT pass (AVX512, 8 x u64 lanes).
///
/// Same structure as `fwd_breadth_first_avx2`, but with 8 lanes there are
/// three tail layers narrower than a vector (butterfly distance 4, 2 and 1),
/// each handled with the matching interleave/permute width.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
pub(crate) fn fwd_breadth_first_avx512<P: PrimeModulusV4>(
    simd: crate::V4,
    data: &mut [u64],
    p: P,
    p_div: P::Div,
    twid: &[u64],
    recursion_depth: usize,
    recursion_half: usize,
) {
    // Captured arguments for `simd.vectorize` (enables AVX512 codegen).
    struct Impl<'a, P: PrimeModulusV4> {
        simd: crate::V4,
        data: &'a mut [u64],
        p: P,
        p_div: P::Div,
        twid: &'a [u64],
        recursion_depth: usize,
        recursion_half: usize,
    }
    impl<P: PrimeModulusV4> pulp::NullaryFnOnce for Impl<'_, P> {
        type Output = ();

        #[inline(always)]
        fn call(self) -> Self::Output {
            let Self {
                simd,
                data,
                p,
                p_div,
                twid,
                recursion_depth,
                recursion_half,
            } = self;

            let n = data.len();
            debug_assert!(n.is_power_of_two());

            // t: butterfly half-distance; m: blocks in the layer;
            // w_idx: index of the layer's first twiddle.
            let mut t = n / 2;
            let mut m = 1;
            let mut w_idx = (m << recursion_depth) + recursion_half * m;
            // Layers with t >= 8: both butterfly operands fill whole vectors.
            while m < n / 8 {
                let w = &twid[w_idx..];

                for (data, &w1) in zip(data.chunks_exact_mut(2 * t), w) {
                    let (z0, z1) = data.split_at_mut(t);
                    let z0 = pulp::as_arrays_mut::<8, _>(z0).0;
                    let z1 = pulp::as_arrays_mut::<8, _>(z1).0;
                    // One twiddle per block, broadcast to all lanes.
                    let w1 = simd.splat_u64x8(w1);

                    for (z0_, z1_) in zip(z0, z1) {
                        let mut z0 = cast(*z0_);
                        let mut z1 = cast(*z1_);
                        // Butterfly: (z0, z1) <- (z0 + w*z1, z0 - w*z1).
                        let z1w = P::mul(p_div, simd, z1, w1);
                        (z0, z1) = (p.add(simd, z0, z1w), p.sub(simd, z0, z1w));
                        *z0_ = cast(z0);
                        *z1_ = cast(z1);
                    }
                }

                t /= 2;
                m *= 2;
                w_idx *= 2;
            }

            // Layer with t == 4: z0 x4 then z1 x4 within a vector pair;
            // permute4 spreads each of the two twiddles over four lanes.
            {
                let w = pulp::as_arrays::<2, _>(&twid[w_idx..]).0;
                let data = pulp::as_arrays_mut::<8, _>(data).0;
                let data = pulp::as_arrays_mut::<2, _>(data).0;

                for (z0z0z0z0z1z1z1z1, w1) in zip(data, w) {
                    let w1 = simd.permute4_u64x8(*w1);
                    let [mut z0, mut z1] = simd.interleave4_u64x8(cast(*z0z0z0z0z1z1z1z1));
                    let z1w = P::mul(p_div, simd, z1, w1);
                    (z0, z1) = (p.add(simd, z0, z1w), p.sub(simd, z0, z1w));
                    // Applying interleave4 again restores the memory layout.
                    *z0z0z0z0z1z1z1z1 = cast(simd.interleave4_u64x8([z0, z1]));
                }

                w_idx *= 2;
            }

            // Layer with t == 2: z0 z0 z1 z1 groups; four twiddles, each
            // spread over two lanes by permute2.
            {
                let w = pulp::as_arrays::<4, _>(&twid[w_idx..]).0;
                let data = pulp::as_arrays_mut::<8, _>(data).0;
                let data = pulp::as_arrays_mut::<2, _>(data).0;

                for (z0z0z1z1, w1) in zip(data, w) {
                    let w1 = simd.permute2_u64x8(*w1);
                    let [mut z0, mut z1] = simd.interleave2_u64x8(cast(*z0z0z1z1));
                    let z1w = P::mul(p_div, simd, z1, w1);
                    (z0, z1) = (p.add(simd, z0, z1w), p.sub(simd, z0, z1w));
                    *z0z0z1z1 = cast(simd.interleave2_u64x8([z0, z1]));
                }

                w_idx *= 2;
            }

            // Layer with t == 1: adjacent z0 z1 pairs; eight twiddles, one
            // per butterfly lane.
            {
                let w = pulp::as_arrays::<8, _>(&twid[w_idx..]).0;
                let data = pulp::as_arrays_mut::<8, _>(data).0;
                let data = pulp::as_arrays_mut::<2, _>(data).0;

                for (z0z1, w1) in zip(data, w) {
                    let w1 = simd.permute1_u64x8(*w1);
                    let [mut z0, mut z1] = simd.interleave1_u64x8(cast(*z0z1));
                    let z1w = P::mul(p_div, simd, z1, w1);
                    (z0, z1) = (p.add(simd, z0, z1w), p.sub(simd, z0, z1w));
                    *z0z1 = cast(simd.interleave1_u64x8([z0, z1]));
                }
            }
        }
    }

    simd.vectorize(Impl {
        simd,
        data,
        p,
        p_div,
        twid,
        recursion_depth,
        recursion_half,
    });
}
928
/// Depth-first forward NTT driver (AVX512).
///
/// Above `RECURSION_THRESHOLD`, applies the first cross-half butterfly
/// layer (8 lanes at a time) and recurses into each half for better cache
/// locality; at or below the threshold it switches to the breadth-first
/// kernel. `recursion_depth`/`recursion_half` track the position in the
/// twiddle tree across recursive calls.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
pub(crate) fn fwd_depth_first_avx512<P: PrimeModulusV4>(
    simd: crate::V4,
    data: &mut [u64],
    p: P,
    p_div: P::Div,
    twid: &[u64],
    recursion_depth: usize,
    recursion_half: usize,
) {
    // Captured arguments for `simd.vectorize` (enables AVX512 codegen).
    struct Impl<'a, P: PrimeModulusV4> {
        simd: crate::V4,
        data: &'a mut [u64],
        p: P,
        p_div: P::Div,
        twid: &'a [u64],
        recursion_depth: usize,
        recursion_half: usize,
    }
    impl<P: PrimeModulusV4> pulp::NullaryFnOnce for Impl<'_, P> {
        type Output = ();

        #[inline(always)]
        fn call(self) -> Self::Output {
            let Self {
                simd,
                data,
                p,
                p_div,
                twid,
                recursion_depth,
                recursion_half,
            } = self;

            let n = data.len();
            debug_assert!(n.is_power_of_two());

            if n <= RECURSION_THRESHOLD {
                fwd_breadth_first_avx512(
                    simd,
                    data,
                    p,
                    p_div,
                    twid,
                    recursion_depth,
                    recursion_half,
                );
            } else {
                // First layer: one block spanning the slice, single twiddle.
                let t = n / 2;
                let m = 1;
                let w_idx = (m << recursion_depth) + m * recursion_half;

                let w = &twid[w_idx..];

                for (data, &w1) in zip(data.chunks_exact_mut(2 * t), w) {
                    let (z0, z1) = data.split_at_mut(t);
                    let z0 = pulp::as_arrays_mut::<8, _>(z0).0;
                    let z1 = pulp::as_arrays_mut::<8, _>(z1).0;
                    let w1 = simd.splat_u64x8(w1);

                    for (z0_, z1_) in zip(z0, z1) {
                        let mut z0 = cast(*z0_);
                        let mut z1 = cast(*z1_);
                        // Butterfly: (z0, z1) <- (z0 + w*z1, z0 - w*z1).
                        let z1w = P::mul(p_div, simd, z1, w1);
                        (z0, z1) = (p.add(simd, z0, z1w), p.sub(simd, z0, z1w));
                        *z0_ = cast(z0);
                        *z1_ = cast(z1);
                    }
                }

                // Recurse into both halves one level deeper.
                let (data0, data1) = data.split_at_mut(n / 2);
                fwd_depth_first_avx512(
                    simd,
                    data0,
                    p,
                    p_div,
                    twid,
                    recursion_depth + 1,
                    recursion_half * 2,
                );
                fwd_depth_first_avx512(
                    simd,
                    data1,
                    p,
                    p_div,
                    twid,
                    recursion_depth + 1,
                    recursion_half * 2 + 1,
                );
            }
        }
    }

    simd.vectorize(Impl {
        simd,
        data,
        p,
        p_div,
        twid,
        recursion_depth,
        recursion_half,
    });
}
1033
/// Depth-first inverse NTT driver (AVX512).
///
/// Mirror of `fwd_depth_first_avx512`: above `RECURSION_THRESHOLD` it
/// recurses into both halves first and applies the final cross-half
/// butterfly layer afterwards (the inverse transform combines last);
/// otherwise it delegates to the breadth-first kernel.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
pub(crate) fn inv_depth_first_avx512<P: PrimeModulusV4>(
    simd: crate::V4,
    data: &mut [u64],
    p: P,
    p_div: P::Div,
    inv_twid: &[u64],
    recursion_depth: usize,
    recursion_half: usize,
) {
    // Captured arguments for `simd.vectorize` (enables AVX512 codegen).
    struct Impl<'a, P: PrimeModulusV4> {
        simd: crate::V4,
        data: &'a mut [u64],
        p: P,
        p_div: P::Div,
        inv_twid: &'a [u64],
        recursion_depth: usize,
        recursion_half: usize,
    }
    impl<P: PrimeModulusV4> pulp::NullaryFnOnce for Impl<'_, P> {
        type Output = ();

        #[inline(always)]
        fn call(self) -> Self::Output {
            let Self {
                simd,
                data,
                p,
                p_div,
                inv_twid,
                recursion_depth,
                recursion_half,
            } = self;
            let n = data.len();
            debug_assert!(n.is_power_of_two());

            if n <= RECURSION_THRESHOLD {
                inv_breadth_first_avx512(
                    simd,
                    data,
                    p,
                    p_div,
                    inv_twid,
                    recursion_depth,
                    recursion_half,
                );
            } else {
                // Recurse into both halves first.
                let (data0, data1) = data.split_at_mut(n / 2);
                inv_depth_first_avx512(
                    simd,
                    data0,
                    p,
                    p_div,
                    inv_twid,
                    recursion_depth + 1,
                    recursion_half * 2,
                );
                inv_depth_first_avx512(
                    simd,
                    data1,
                    p,
                    p_div,
                    inv_twid,
                    recursion_depth + 1,
                    recursion_half * 2 + 1,
                );

                // Final layer: one block spanning the slice, single twiddle.
                let t = n / 2;
                let m = 1;
                let w_idx = (m << recursion_depth) + m * recursion_half;

                let w = &inv_twid[w_idx..];

                for (data, &w1) in zip(data.chunks_exact_mut(2 * t), w) {
                    let (z0, z1) = data.split_at_mut(t);
                    let z0 = pulp::as_arrays_mut::<8, _>(z0).0;
                    let z1 = pulp::as_arrays_mut::<8, _>(z1).0;
                    let w1 = simd.splat_u64x8(w1);

                    for (z0_, z1_) in zip(z0, z1) {
                        let mut z0 = cast(*z0_);
                        let mut z1 = cast(*z1_);
                        // Inverse butterfly:
                        // (z0, z1) <- (z0 + z1, (z0 - z1) * w).
                        (z0, z1) = (
                            p.add(simd, z0, z1),
                            P::mul(p_div, simd, p.sub(simd, z0, z1), w1),
                        );
                        *z0_ = cast(z0);
                        *z1_ = cast(z1);
                    }
                }
            }
        }
    }

    simd.vectorize(Impl {
        simd,
        data,
        p,
        p_div,
        inv_twid,
        recursion_depth,
        recursion_half,
    });
}
1139
/// Depth-first inverse NTT driver (AVX2).
///
/// Mirror of `fwd_depth_first_avx2`: above `RECURSION_THRESHOLD` it recurses
/// into both halves first and applies the final cross-half butterfly layer
/// afterwards; otherwise it delegates to the breadth-first kernel.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub(crate) fn inv_depth_first_avx2<P: PrimeModulusV3>(
    simd: crate::V3,
    data: &mut [u64],
    p: P,
    p_div: P::Div,
    inv_twid: &[u64],
    recursion_depth: usize,
    recursion_half: usize,
) {
    // Captured arguments for `simd.vectorize` (enables AVX2 codegen).
    struct Impl<'a, P: PrimeModulusV3> {
        simd: crate::V3,
        data: &'a mut [u64],
        p: P,
        p_div: P::Div,
        inv_twid: &'a [u64],
        recursion_depth: usize,
        recursion_half: usize,
    }
    impl<P: PrimeModulusV3> pulp::NullaryFnOnce for Impl<'_, P> {
        type Output = ();

        #[inline(always)]
        fn call(self) -> Self::Output {
            let Self {
                simd,
                data,
                p,
                p_div,
                inv_twid,
                recursion_depth,
                recursion_half,
            } = self;
            let n = data.len();
            debug_assert!(n.is_power_of_two());

            if n <= RECURSION_THRESHOLD {
                inv_breadth_first_avx2(
                    simd,
                    data,
                    p,
                    p_div,
                    inv_twid,
                    recursion_depth,
                    recursion_half,
                );
            } else {
                // Recurse into both halves first.
                let (data0, data1) = data.split_at_mut(n / 2);
                inv_depth_first_avx2(
                    simd,
                    data0,
                    p,
                    p_div,
                    inv_twid,
                    recursion_depth + 1,
                    recursion_half * 2,
                );
                inv_depth_first_avx2(
                    simd,
                    data1,
                    p,
                    p_div,
                    inv_twid,
                    recursion_depth + 1,
                    recursion_half * 2 + 1,
                );

                // Final layer: one block spanning the slice, single twiddle.
                let t = n / 2;
                let m = 1;
                let w_idx = (m << recursion_depth) + m * recursion_half;

                let w = &inv_twid[w_idx..];

                for (data, &w1) in zip(data.chunks_exact_mut(2 * t), w) {
                    let (z0, z1) = data.split_at_mut(t);
                    let z0 = pulp::as_arrays_mut::<4, _>(z0).0;
                    let z1 = pulp::as_arrays_mut::<4, _>(z1).0;
                    let w1 = simd.splat_u64x4(w1);

                    for (z0_, z1_) in zip(z0, z1) {
                        let mut z0 = cast(*z0_);
                        let mut z1 = cast(*z1_);
                        // Inverse butterfly:
                        // (z0, z1) <- (z0 + z1, (z0 - z1) * w).
                        (z0, z1) = (
                            p.add(simd, z0, z1),
                            P::mul(p_div, simd, p.sub(simd, z0, z1), w1),
                        );
                        *z0_ = cast(z0);
                        *z1_ = cast(z1);
                    }
                }
            }
        }
    }
    simd.vectorize(Impl {
        simd,
        data,
        p,
        p_div,
        inv_twid,
        recursion_depth,
        recursion_half,
    });
}
1243
/// Depth-first forward NTT driver (AVX2).
///
/// Above `RECURSION_THRESHOLD`, applies the first cross-half butterfly
/// layer (4 lanes at a time) and recurses into each half for better cache
/// locality; at or below the threshold it switches to the breadth-first
/// kernel. `recursion_depth`/`recursion_half` track the position in the
/// twiddle tree across recursive calls.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub(crate) fn fwd_depth_first_avx2<P: PrimeModulusV3>(
    simd: crate::V3,
    data: &mut [u64],
    p: P,
    p_div: P::Div,
    twid: &[u64],
    recursion_depth: usize,
    recursion_half: usize,
) {
    // Captured arguments for `simd.vectorize` (enables AVX2 codegen).
    struct Impl<'a, P: PrimeModulusV3> {
        simd: crate::V3,
        data: &'a mut [u64],
        p: P,
        p_div: P::Div,
        twid: &'a [u64],
        recursion_depth: usize,
        recursion_half: usize,
    }
    impl<P: PrimeModulusV3> pulp::NullaryFnOnce for Impl<'_, P> {
        type Output = ();

        #[inline(always)]
        fn call(self) -> Self::Output {
            let Self {
                simd,
                data,
                p,
                p_div,
                twid,
                recursion_depth,
                recursion_half,
            } = self;
            let n = data.len();
            debug_assert!(n.is_power_of_two());

            if n <= RECURSION_THRESHOLD {
                fwd_breadth_first_avx2(simd, data, p, p_div, twid, recursion_depth, recursion_half);
            } else {
                // First layer: one block spanning the slice, single twiddle.
                let t = n / 2;
                let m = 1;
                let w_idx = (m << recursion_depth) + m * recursion_half;

                let w = &twid[w_idx..];

                for (data, &w1) in zip(data.chunks_exact_mut(2 * t), w) {
                    let (z0, z1) = data.split_at_mut(t);
                    let z0 = pulp::as_arrays_mut::<4, _>(z0).0;
                    let z1 = pulp::as_arrays_mut::<4, _>(z1).0;
                    let w1 = simd.splat_u64x4(w1);

                    for (z0_, z1_) in zip(z0, z1) {
                        let mut z0 = cast(*z0_);
                        let mut z1 = cast(*z1_);
                        // Butterfly: (z0, z1) <- (z0 + w*z1, z0 - w*z1).
                        let z1w = P::mul(p_div, simd, z1, w1);
                        (z0, z1) = (p.add(simd, z0, z1w), p.sub(simd, z0, z1w));
                        *z0_ = cast(z0);
                        *z1_ = cast(z1);
                    }
                }

                // Recurse into both halves one level deeper.
                let (data0, data1) = data.split_at_mut(n / 2);
                fwd_depth_first_avx2(
                    simd,
                    data0,
                    p,
                    p_div,
                    twid,
                    recursion_depth + 1,
                    recursion_half * 2,
                );
                fwd_depth_first_avx2(
                    simd,
                    data1,
                    p,
                    p_div,
                    twid,
                    recursion_depth + 1,
                    recursion_half * 2 + 1,
                );
            }
        }
    }
    simd.vectorize(Impl {
        simd,
        data,
        p,
        p_div,
        twid,
        recursion_depth,
        recursion_half,
    });
}
1337
1338pub(crate) fn fwd_depth_first_scalar<P: PrimeModulus>(
1339 data: &mut [u64],
1340 p: P,
1341 p_div: P::Div,
1342 twid: &[u64],
1343 recursion_depth: usize,
1344 recursion_half: usize,
1345) {
1346 let n = data.len();
1347 debug_assert!(n.is_power_of_two());
1348
1349 if n <= RECURSION_THRESHOLD {
1350 fwd_breadth_first_scalar(data, p, p_div, twid, recursion_depth, recursion_half);
1351 } else {
1352 let t = n / 2;
1353 let m = 1;
1354 let w_idx = (m << recursion_depth) + m * recursion_half;
1355
1356 let w = &twid[w_idx..];
1357
1358 for (data, &w1) in zip(data.chunks_exact_mut(2 * t), w) {
1359 let (z0, z1) = data.split_at_mut(t);
1360
1361 for (z0, z1) in zip(z0, z1) {
1362 let z1w = P::mul(p_div, *z1, w1);
1363
1364 (*z0, *z1) = (p.add(*z0, z1w), p.sub(*z0, z1w));
1365 }
1366 }
1367
1368 let (data0, data1) = data.split_at_mut(n / 2);
1369 fwd_depth_first_scalar(
1370 data0,
1371 p,
1372 p_div,
1373 twid,
1374 recursion_depth + 1,
1375 recursion_half * 2,
1376 );
1377 fwd_depth_first_scalar(
1378 data1,
1379 p,
1380 p_div,
1381 twid,
1382 recursion_depth + 1,
1383 recursion_half * 2 + 1,
1384 );
1385 }
1386}
1387
/// Forward transform of `data` in place, AVX-512 path.
///
/// Thin entry point: starts the depth-first recursion at the root of the
/// recursion tree (depth 0, half 0). `twid` holds the forward twiddle factors.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
pub(crate) fn fwd_avx512<P: PrimeModulusV4>(
    simd: crate::V4,
    data: &mut [u64],
    p: P,
    p_div: P::Div,
    twid: &[u64],
) {
    fwd_depth_first_avx512(simd, data, p, p_div, twid, 0, 0);
}
1399
/// Forward transform of `data` in place, AVX2 path.
///
/// Thin entry point: starts the depth-first recursion at the root of the
/// recursion tree (depth 0, half 0). `twid` holds the forward twiddle factors.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub(crate) fn fwd_avx2<P: PrimeModulusV3>(
    simd: crate::V3,
    data: &mut [u64],
    p: P,
    p_div: P::Div,
    twid: &[u64],
) {
    fwd_depth_first_avx2(simd, data, p, p_div, twid, 0, 0);
}
1410
/// Forward transform of `data` in place, scalar path.
///
/// Thin entry point: starts the depth-first recursion at the root of the
/// recursion tree (depth 0, half 0). `twid` holds the forward twiddle factors.
pub(crate) fn fwd_scalar<P: PrimeModulus>(data: &mut [u64], p: P, p_div: P::Div, twid: &[u64]) {
    fwd_depth_first_scalar(data, p, p_div, twid, 0, 0);
}
1414
/// Inverse transform of `data` in place, AVX-512 path.
///
/// Thin entry point: starts the depth-first recursion at the root (depth 0,
/// half 0) with the inverse twiddle table. Note the tests in this module
/// compare against `n * expected`, i.e. the inverse is not normalized by `1/n`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
pub(crate) fn inv_avx512<P: PrimeModulusV4>(
    simd: crate::V4,
    data: &mut [u64],
    p: P,
    p_div: P::Div,
    twid: &[u64],
) {
    inv_depth_first_avx512(simd, data, p, p_div, twid, 0, 0);
}
1426
/// Inverse transform of `data` in place, AVX2 path.
///
/// Thin entry point: starts the depth-first recursion at the root (depth 0,
/// half 0) with the inverse twiddle table. Note the tests in this module
/// compare against `n * expected`, i.e. the inverse is not normalized by `1/n`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub(crate) fn inv_avx2<P: PrimeModulusV3>(
    simd: crate::V3,
    data: &mut [u64],
    p: P,
    p_div: P::Div,
    twid: &[u64],
) {
    inv_depth_first_avx2(simd, data, p, p_div, twid, 0, 0);
}
1437
/// Inverse transform of `data` in place, scalar path.
///
/// Thin entry point: starts the depth-first recursion at the root (depth 0,
/// half 0) with the inverse twiddle table. Note the tests in this module
/// compare against `n * expected`, i.e. the inverse is not normalized by `1/n`.
pub(crate) fn inv_scalar<P: PrimeModulus>(data: &mut [u64], p: P, p_div: P::Div, twid: &[u64]) {
    inv_depth_first_scalar(data, p, p_div, twid, 0, 0);
}
1441
/// Inverse transform, breadth-first (iterative) AVX-512 kernel.
///
/// Runs all inverse butterfly layers over `data` in place, from stride-1 pairs
/// up to the two-halves layer. The inverse butterfly is the Gentleman–Sande
/// form: `z0 <- z0 + z1`, `z1 <- (z0 - z1) * w` mod p.
///
/// The first three layers operate *within* `u64x8` vectors (pair strides 1, 2
/// and 4), using `interleaveK_u64x8` to split/recombine lanes and
/// `permuteK_u64x8` to broadcast the matching twiddles; the remaining layers
/// work on whole-vector chunks with a single splatted twiddle each.
///
/// `recursion_depth`/`recursion_half` locate this slice in the enclosing
/// depth-first recursion (both 0 when called on the full transform); the layer
/// twiddle block starts at `w_idx = (m << depth) + half * m` and halves each
/// layer. `data.len()` must be a power of two (debug-asserted) and large
/// enough for the `u64x8` grouping — presumably at least 16, as used by the
/// AVX-512 tests below (TODO confirm caller contract).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
pub(crate) fn inv_breadth_first_avx512<P: PrimeModulusV4>(
    simd: crate::V4,
    data: &mut [u64],
    p: P,
    p_div: P::Div,
    inv_twid: &[u64],
    recursion_depth: usize,
    recursion_half: usize,
) {
    // Captured state for `pulp::vectorize`, which re-dispatches the body inside
    // a `#[target_feature]`-enabled context so the intrinsics can be used.
    struct Impl<'a, P: PrimeModulusV4> {
        simd: crate::V4,
        data: &'a mut [u64],
        p: P,
        p_div: P::Div,
        inv_twid: &'a [u64],
        recursion_depth: usize,
        recursion_half: usize,
    }
    impl<P: PrimeModulusV4> pulp::NullaryFnOnce for Impl<'_, P> {
        type Output = ();

        #[inline(always)]
        fn call(self) -> Self::Output {
            let Self {
                simd,
                data,
                p,
                p_div,
                inv_twid,
                recursion_depth,
                recursion_half,
            } = self;

            let n = data.len();
            debug_assert!(n.is_power_of_two());

            // t = pair stride, m = number of groups, w_idx = start of this
            // layer's twiddle block; m and w_idx halve after every layer.
            let mut t = 1;
            let mut m = n;
            let mut w_idx = (m << recursion_depth) + recursion_half * m;

            // Layer 1 (t == 1): partners are adjacent lanes within a vector.
            {
                m /= 2;
                w_idx /= 2;

                let w = pulp::as_arrays::<8, _>(&inv_twid[w_idx..]).0;
                let data = pulp::as_arrays_mut::<8, _>(data).0;
                let data = pulp::as_arrays_mut::<2, _>(data).0;

                for (z0z1, w1) in zip(data, w) {
                    // Arrange 8 twiddles to line up with the de-interleaved lanes.
                    let w1 = simd.permute1_u64x8(*w1);
                    // Split two consecutive vectors into even/odd lane vectors.
                    let [mut z0, mut z1] = simd.interleave1_u64x8(cast(*z0z1));
                    // Inverse butterfly: (z0, z1) <- (z0 + z1, (z0 - z1) * w).
                    (z0, z1) = (
                        p.add(simd, z0, z1),
                        P::mul(p_div, simd, p.sub(simd, z0, z1), w1),
                    );
                    // Re-interleave back into memory order.
                    *z0z1 = cast(simd.interleave1_u64x8([z0, z1]));
                }

                t *= 2;
            }

            // Layer 2 (t == 2): partners are 2 lanes apart within a vector.
            {
                m /= 2;
                w_idx /= 2;

                let w = pulp::as_arrays::<4, _>(&inv_twid[w_idx..]).0;
                let data = pulp::as_arrays_mut::<8, _>(data).0;
                let data = pulp::as_arrays_mut::<2, _>(data).0;

                for (z0z0z1z1, w1) in zip(data, w) {
                    // Broadcast each of the 4 twiddles across its lane pair.
                    let w1 = simd.permute2_u64x8(*w1);
                    let [mut z0, mut z1] = simd.interleave2_u64x8(cast(*z0z0z1z1));
                    // Inverse butterfly: (z0, z1) <- (z0 + z1, (z0 - z1) * w).
                    (z0, z1) = (
                        p.add(simd, z0, z1),
                        P::mul(p_div, simd, p.sub(simd, z0, z1), w1),
                    );
                    *z0z0z1z1 = cast(simd.interleave2_u64x8([z0, z1]));
                }

                t *= 2;
            }

            // Layer 3 (t == 4): partners are 4 lanes apart within a vector.
            {
                m /= 2;
                w_idx /= 2;

                let w = pulp::as_arrays::<2, _>(&inv_twid[w_idx..]).0;
                let data = pulp::as_arrays_mut::<8, _>(data).0;
                let data = pulp::as_arrays_mut::<2, _>(data).0;

                for (z0z0z0z0z1z1z1z1, w1) in zip(data, w) {
                    // Broadcast each of the 2 twiddles across its 4-lane half.
                    let w1 = simd.permute4_u64x8(*w1);
                    let [mut z0, mut z1] = simd.interleave4_u64x8(cast(*z0z0z0z0z1z1z1z1));
                    // Inverse butterfly: (z0, z1) <- (z0 + z1, (z0 - z1) * w).
                    (z0, z1) = (
                        p.add(simd, z0, z1),
                        P::mul(p_div, simd, p.sub(simd, z0, z1), w1),
                    );
                    *z0z0z0z0z1z1z1z1 = cast(simd.interleave4_u64x8([z0, z1]));
                }

                t *= 2;
            }

            // Remaining layers (t >= 8): partners live in different vectors, so
            // plain u64x8 chunks with one splatted twiddle per group suffice.
            while m > 1 {
                m /= 2;
                w_idx /= 2;

                let w = &inv_twid[w_idx..];

                for (data, &w1) in zip(data.chunks_exact_mut(2 * t), w) {
                    let (z0, z1) = data.split_at_mut(t);
                    let z0 = pulp::as_arrays_mut::<8, _>(z0).0;
                    let z1 = pulp::as_arrays_mut::<8, _>(z1).0;
                    let w1 = simd.splat_u64x8(w1);

                    for (z0_, z1_) in zip(z0, z1) {
                        let mut z0 = cast(*z0_);
                        let mut z1 = cast(*z1_);
                        // Inverse butterfly: (z0, z1) <- (z0 + z1, (z0 - z1) * w).
                        (z0, z1) = (
                            p.add(simd, z0, z1),
                            P::mul(p_div, simd, p.sub(simd, z0, z1), w1),
                        );
                        *z0_ = cast(z0);
                        *z1_ = cast(z1);
                    }
                }

                t *= 2;
            }
        }
    }

    simd.vectorize(Impl {
        simd,
        data,
        p,
        p_div,
        inv_twid,
        recursion_depth,
        recursion_half,
    });
}
1592
#[cfg(test)]
mod tests {
    // Round-trip tests: fwd transform both operands, multiply pointwise, then
    // inverse transform, and compare with a reference negacyclic convolution.
    // The inverse is not normalized, so the expected value is scaled by `n`.
    use super::*;
    use crate::prime64::{
        init_negacyclic_twiddles,
        tests::{mul, random_lhs_rhs_with_negacyclic_convolution},
    };
    use alloc::vec;

    extern crate alloc;

    // Scalar path with the generic `u64` modulus (Barrett-style via `Div64`).
    #[test]
    fn test_product() {
        for n in [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] {
            let p = Solinas::P;

            let (lhs, rhs, negacyclic_convolution) =
                random_lhs_rhs_with_negacyclic_convolution(n, p);

            let mut twid = vec![0u64; n];
            let mut inv_twid = vec![0u64; n];
            init_negacyclic_twiddles(p, n, &mut twid, &mut inv_twid);

            let mut prod = vec![0u64; n];
            let mut lhs_fourier = lhs.clone();
            let mut rhs_fourier = rhs.clone();

            fwd_breadth_first_scalar(&mut lhs_fourier, p, Div64::new(p), &twid, 0, 0);
            fwd_breadth_first_scalar(&mut rhs_fourier, p, Div64::new(p), &twid, 0, 0);

            // Pointwise product in the transform domain.
            for i in 0..n {
                prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
            }

            inv_breadth_first_scalar(&mut prod, p, Div64::new(p), &inv_twid, 0, 0);
            let result = prod;

            // Un-normalized inverse: compare against n * expected.
            for i in 0..n {
                assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u64));
            }
        }
    }

    // AVX2 path with the generic modulus; skipped when AVX2 is unavailable.
    // n starts at 8 to satisfy the u64x4 grouping.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[test]
    fn test_product_avx2() {
        if let Some(simd) = crate::V3::try_new() {
            for n in [8, 16, 32, 64, 128, 256, 512, 1024] {
                let p = Solinas::P;

                let (lhs, rhs, negacyclic_convolution) =
                    random_lhs_rhs_with_negacyclic_convolution(n, p);

                let mut twid = vec![0u64; n];
                let mut inv_twid = vec![0u64; n];
                init_negacyclic_twiddles(p, n, &mut twid, &mut inv_twid);

                let mut prod = vec![0u64; n];
                let mut lhs_fourier = lhs.clone();
                let mut rhs_fourier = rhs.clone();

                // The SIMD modulus is (p, double_reciprocal limbs) — the
                // `P::Div` tuple form used by the vectorized kernels.
                let crate::u256 { x0, x1, x2, x3 } = Div64::new(p).double_reciprocal;
                fwd_breadth_first_avx2(simd, &mut lhs_fourier, p, (p, x0, x1, x2, x3), &twid, 0, 0);
                fwd_breadth_first_avx2(simd, &mut rhs_fourier, p, (p, x0, x1, x2, x3), &twid, 0, 0);

                // Pointwise product in the transform domain.
                for i in 0..n {
                    prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
                }

                inv_breadth_first_avx2(simd, &mut prod, p, (p, x0, x1, x2, x3), &inv_twid, 0, 0);
                let result = prod;

                // Un-normalized inverse: compare against n * expected.
                for i in 0..n {
                    assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u64));
                }
            }
        }
    }

    // AVX-512 path with the generic modulus; skipped when AVX-512 is
    // unavailable. n starts at 16 to satisfy the u64x8 grouping.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[cfg(feature = "nightly")]
    #[test]
    fn test_product_avx512() {
        if let Some(simd) = crate::V4::try_new() {
            for n in [16, 32, 64, 128, 256, 512, 1024] {
                let p = Solinas::P;

                let (lhs, rhs, negacyclic_convolution) =
                    random_lhs_rhs_with_negacyclic_convolution(n, p);

                let mut twid = vec![0u64; n];
                let mut inv_twid = vec![0u64; n];
                init_negacyclic_twiddles(p, n, &mut twid, &mut inv_twid);

                let mut prod = vec![0u64; n];
                let mut lhs_fourier = lhs.clone();
                let mut rhs_fourier = rhs.clone();

                // The SIMD modulus is (p, double_reciprocal limbs) — the
                // `P::Div` tuple form used by the vectorized kernels.
                let crate::u256 { x0, x1, x2, x3 } = Div64::new(p).double_reciprocal;
                fwd_breadth_first_avx512(
                    simd,
                    &mut lhs_fourier,
                    p,
                    (p, x0, x1, x2, x3),
                    &twid,
                    0,
                    0,
                );
                fwd_breadth_first_avx512(
                    simd,
                    &mut rhs_fourier,
                    p,
                    (p, x0, x1, x2, x3),
                    &twid,
                    0,
                    0,
                );

                // Pointwise product in the transform domain.
                for i in 0..n {
                    prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
                }

                inv_breadth_first_avx512(simd, &mut prod, p, (p, x0, x1, x2, x3), &inv_twid, 0, 0);
                let result = prod;

                // Un-normalized inverse: compare against n * expected.
                for i in 0..n {
                    assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u64));
                }
            }
        }
    }

    // Scalar path with the specialized `Solinas` modulus (unit `Div`).
    #[test]
    fn test_product_solinas() {
        for n in [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] {
            let p = Solinas::P;

            let (lhs, rhs, negacyclic_convolution) =
                random_lhs_rhs_with_negacyclic_convolution(n, p);

            let mut twid = vec![0u64; n];
            let mut inv_twid = vec![0u64; n];
            init_negacyclic_twiddles(p, n, &mut twid, &mut inv_twid);

            let mut prod = vec![0u64; n];
            let mut lhs_fourier = lhs.clone();
            let mut rhs_fourier = rhs.clone();

            fwd_breadth_first_scalar(&mut lhs_fourier, Solinas, (), &twid, 0, 0);
            fwd_breadth_first_scalar(&mut rhs_fourier, Solinas, (), &twid, 0, 0);

            // Pointwise product in the transform domain.
            for i in 0..n {
                prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
            }

            inv_breadth_first_scalar(&mut prod, Solinas, (), &inv_twid, 0, 0);
            let result = prod;

            // Un-normalized inverse: compare against n * expected.
            for i in 0..n {
                assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u64));
            }
        }
    }

    // AVX2 path with the specialized `Solinas` modulus.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[test]
    fn test_product_solinas_avx2() {
        if let Some(simd) = crate::V3::try_new() {
            for n in [8, 16, 32, 64, 128, 256, 512, 1024] {
                let p = Solinas::P;

                let (lhs, rhs, negacyclic_convolution) =
                    random_lhs_rhs_with_negacyclic_convolution(n, p);

                let mut twid = vec![0u64; n];
                let mut inv_twid = vec![0u64; n];
                init_negacyclic_twiddles(p, n, &mut twid, &mut inv_twid);

                let mut prod = vec![0u64; n];
                let mut lhs_fourier = lhs.clone();
                let mut rhs_fourier = rhs.clone();

                fwd_breadth_first_avx2(simd, &mut lhs_fourier, Solinas, (), &twid, 0, 0);
                fwd_breadth_first_avx2(simd, &mut rhs_fourier, Solinas, (), &twid, 0, 0);

                // Pointwise product in the transform domain.
                for i in 0..n {
                    prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
                }

                inv_breadth_first_avx2(simd, &mut prod, Solinas, (), &inv_twid, 0, 0);
                let result = prod;

                // Un-normalized inverse: compare against n * expected.
                for i in 0..n {
                    assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u64));
                }
            }
        }
    }

    // AVX-512 path with the specialized `Solinas` modulus.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[cfg(feature = "nightly")]
    #[test]
    fn test_product_solinas_avx512() {
        if let Some(simd) = crate::V4::try_new() {
            for n in [16, 32, 64, 128, 256, 512, 1024] {
                let p = Solinas::P;

                let (lhs, rhs, negacyclic_convolution) =
                    random_lhs_rhs_with_negacyclic_convolution(n, p);

                let mut twid = vec![0u64; n];
                let mut inv_twid = vec![0u64; n];
                init_negacyclic_twiddles(p, n, &mut twid, &mut inv_twid);

                let mut prod = vec![0u64; n];
                let mut lhs_fourier = lhs.clone();
                let mut rhs_fourier = rhs.clone();

                fwd_breadth_first_avx512(simd, &mut lhs_fourier, Solinas, (), &twid, 0, 0);
                fwd_breadth_first_avx512(simd, &mut rhs_fourier, Solinas, (), &twid, 0, 0);

                // Pointwise product in the transform domain.
                for i in 0..n {
                    prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
                }

                inv_breadth_first_avx512(simd, &mut prod, Solinas, (), &inv_twid, 0, 0);
                let result = prod;

                // Un-normalized inverse: compare against n * expected.
                for i in 0..n {
                    assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u64));
                }
            }
        }
    }
}