1use crate::{u256::U256, Binary};
2use crunchy::unroll;
3
/// Lehmer update matrix with implicit signs.
///
/// The four `u64` fields hold the magnitudes of the entries in row-major
/// order. The boolean `.4` selects which of two alternating sign patterns
/// applies when the matrix is used by `lehmer_update`:
///
/// ```text
///     true          false
///  [ .0  -.1]    [-.0   .1]
///  [-.2   .3]    [ .2  -.3]
/// ```
#[derive(Clone, Debug, PartialEq, Eq)]
struct Matrix(u64, u64, u64, u64, bool);

impl Matrix {
    /// The update that leaves a pair unchanged. The GCD drivers in this file
    /// also treat it as the "no single-word progress possible" marker.
    const IDENTITY: Self = Matrix(1, 0, 0, 1, true);
}
20
21#[rustfmt::skip]
31#[allow(clippy::shadow_unrelated)]
33fn mat_mul(a: &mut U256, b: &mut U256, (q00, q01, q10, q11): (u64, u64, u64, u64)) {
34 use crate::algorithms::limb_operations::{mac, msb};
35 let (ai, ac) = mac( 0, q00, a.limb(0), 0);
36 let (ai, ab) = msb(ai, q01, b.limb(0), 0);
37 let (bi, bc) = mac( 0, q11, b.limb(0), 0);
38 let (bi, bb) = msb(bi, q10, a.limb(0), 0);
39 a.set_limb(0, ai);
40 b.set_limb(0, bi);
41 let (ai, ac) = mac( 0, q00, a.limb(1), ac);
42 let (ai, ab) = msb(ai, q01, b.limb(1), ab);
43 let (bi, bc) = mac( 0, q11, b.limb(1), bc);
44 let (bi, bb) = msb(bi, q10, a.limb(1), bb);
45 a.set_limb(1, ai);
46 b.set_limb(1, bi);
47 let (ai, ac) = mac( 0, q00, a.limb(2), ac);
48 let (ai, ab) = msb(ai, q01, b.limb(2), ab);
49 let (bi, bc) = mac( 0, q11, b.limb(2), bc);
50 let (bi, bb) = msb(bi, q10, a.limb(2), bb);
51 a.set_limb(2, ai);
52 b.set_limb(2, bi);
53 let (ai, _) = mac( 0, q00, a.limb(3), ac);
54 let (ai, _) = msb(ai, q01, b.limb(3), ab);
55 let (bi, _) = mac( 0, q11, b.limb(3), bc);
56 let (bi, _) = msb(bi, q10, a.limb(3), bb);
57 a.set_limb(3, ai);
58 b.set_limb(3, bi);
59}
60
61fn lehmer_update(a0: &mut U256, a1: &mut U256, Matrix(q00, q01, q10, q11, even): &Matrix) {
63 if *even {
64 mat_mul(a0, a1, (*q00, *q01, *q10, *q11));
65 } else {
66 mat_mul(a0, a1, (*q10, *q11, *q00, *q01));
67 core::mem::swap(a0, a1);
68 }
69}
70
71#[allow(clippy::cognitive_complexity)]
79fn div1(mut a: u64, b: u64) -> u64 {
80 debug_assert!(a >= b);
81 debug_assert!(b > 0);
82 unroll! {
83 for i in 1..20 {
84 a -= b;
85 if a < b {
86 return i as u64
87 }
88 }
89 }
90 19 + a / b
91}
92
93#[inline(always)]
107#[allow(clippy::cognitive_complexity)]
109fn lehmer_unroll(a2: u64, a3: &mut u64, k2: u64, k3: &mut u64) {
110 debug_assert!(a2 < *a3);
111 debug_assert!(a2 > 0);
112 unroll! {
113 for i in 1..17 {
114 *a3 -= a2;
115 *k3 += k2;
116 if *a3 < a2 {
117 return;
118 }
119 }
120 }
121 let q = *a3 / a2;
122 *a3 -= q * a2;
123 *k3 += q * k2;
124}
125
126#[allow(clippy::shadow_unrelated)]
132fn lehmer_small(mut r0: u64, mut r1: u64) -> Matrix {
133 debug_assert!(r0 >= r1);
134 if r1 == 0_u64 {
135 return Matrix::IDENTITY;
136 }
137 let mut q00 = 1_u64;
138 let mut q01 = 0_u64;
139 let mut q10 = 0_u64;
140 let mut q11 = 1_u64;
141 loop {
142 let q = div1(r0, r1);
144 r0 -= q * r1;
145 q00 += q * q10;
146 q01 += q * q11;
147 if r0 == 0_u64 {
148 return Matrix(q10, q11, q00, q01, false);
149 }
150 let q = div1(r1, r0);
151 r1 -= q * r0;
152 q10 += q * q00;
153 q11 += q * q01;
154 if r1 == 0_u64 {
155 return Matrix(q00, q01, q10, q11, true);
156 }
157 }
158}
159
/// Computes a Lehmer update matrix from two leading 64-bit words.
///
/// Runs Euclid's algorithm on the single words `a0 >= a1` for as long as the
/// remainders stay at or above `LIMIT` (2^32), keeping the last three
/// remainders and the last four cofactor pairs. It then checks the
/// acceptance tests at the bottom and returns the furthest cofactor pair
/// that passes. Returns `Matrix::IDENTITY` when no step can be accepted.
///
/// Debug builds assert that `a0` has its top bit set and that `a0 >= a1`.
///
/// NOTE(review): the intent appears to be that the returned matrix is valid
/// for any pair of multi-word integers whose leading words are `a0` and `a1`,
/// and the acceptance tests look like Jebelean's exact condition. Confirm
/// against the reference the author used.
// `q` is re-bound for each successive quotient.
#[allow(clippy::shadow_unrelated)]
fn lehmer_loop(a0: u64, mut a1: u64) -> Matrix {
    const LIMIT: u64 = 1_u64 << 32;
    debug_assert!(a0 >= 1_u64 << 63);
    debug_assert!(a0 >= a1);

    // Each `k` packs one cofactor pair into a single word: `u` in the high
    // 32 bits and `v` in the low 32 bits (see the unpacking further down).
    // So k0 is (u0, v0) = (1, 0) and k1 is (u1, v1) = (0, 1).
    let mut k0 = 1_u64 << 32; let mut k1 = 1_u64; let mut even = true;
    // `a1` is already below the limit: report that no update was found.
    if a1 < LIMIT {
        return Matrix::IDENTITY;
    }

    // First Euclid step: compute a2 and its packed cofactors.
    let q = div1(a0, a1);
    let mut a2 = a0 - q * a1;
    let mut k2 = k0 + q * k1;
    if a2 < LIMIT {
        let u2 = k2 >> 32;
        let v2 = k2 % LIMIT;

        // Only this single (odd) step is available; accept it if it passes
        // the test, otherwise report no update.
        if a2 >= v2 && a1 - a2 >= u2 {
            return Matrix(0, 1, u2, v2, false);
        } else {
            return Matrix::IDENTITY;
        }
    }

    // Second Euclid step: compute a3 and its packed cofactors.
    let q = div1(a1, a2);
    let mut a3 = a1 - q * a2;
    let mut k3 = k1 + q * k2;

    // Keep stepping until a3 drops below LIMIT. The body performs two steps
    // per iteration, so `even` only changes on the early exit in between.
    while a3 >= LIMIT {
        // Slide the window one step. a3 and k3 are seeded with the previous
        // a2 and k2, which `lehmer_unroll` then reduces and extends.
        a1 = a2;
        a2 = a3;
        a3 = a1;
        k0 = k1;
        k1 = k2;
        k2 = k3;
        k3 = k1;
        lehmer_unroll(a2, &mut a3, k2, &mut k3);
        if a3 < LIMIT {
            even = false;
            break;
        }
        // Same window slide and step again, restoring the even parity.
        a1 = a2;
        a2 = a3;
        a3 = a1;
        k0 = k1;
        k1 = k2;
        k2 = k3;
        k3 = k1;
        lehmer_unroll(a2, &mut a3, k2, &mut k3);
    }
    // Unpack every k into its cofactors u (high half) and v (low half).
    let u0 = k0 >> 32;
    let u1 = k1 >> 32;
    let u2 = k2 >> 32;
    let u3 = k3 >> 32;
    let v0 = k0 % LIMIT;
    let v1 = k1 % LIMIT;
    let v2 = k2 % LIMIT;
    let v3 = k3 % LIMIT;
    debug_assert!(a2 >= LIMIT);
    debug_assert!(a3 < LIMIT);

    // Decide how many of the last two steps can be kept. Pair (k0, k1) is
    // the fallback, (k1, k2) is one step further, (k2, k3) is two steps
    // further. Every extra step kept flips the parity flag of the result.
    if even {
        debug_assert!(a2 >= v2);
        // Test for keeping one more step.
        if a1 - a2 >= u2 + u1 {
            // Test for keeping two more steps.
            if a3 >= u3 && a2 - a3 >= v3 + v2 {
                Matrix(u2, v2, u3, v3, true)
            } else {
                Matrix(u1, v1, u2, v2, false)
            }
        } else {
            Matrix(u0, v0, u1, v1, true)
        }
    } else {
        // Same tests as above with the roles of u and v exchanged.
        debug_assert!(a2 >= u2);
        if a1 - a2 >= v2 + v1 {
            if a3 >= v3 && a2 - a3 >= u3 + u2 {
                Matrix(u2, v2, u3, v3, false)
            } else {
                Matrix(u1, v1, u2, v2, true)
            }
        } else {
            Matrix(u0, v0, u1, v1, false)
        }
    }
}
278
279#[allow(clippy::shadow_unrelated)]
294fn lehmer_double(mut r0: U256, mut r1: U256) -> Matrix {
295 debug_assert!(r0 >= r1);
296 if r0.leading_zeros() >= 192 {
297 debug_assert!(r1.leading_zeros() >= 192);
299 debug_assert!(r0.limb(0) >= r1.limb(0));
300 return lehmer_small(r0.limb(0), r1.limb(0));
301 }
302 let s = r0.leading_zeros();
303 let r0s = r0.clone() << s;
304 let r1s = r1.clone() << s;
305 let q = lehmer_loop(r0s.limb(3), r1s.limb(3));
306 if q == Matrix::IDENTITY {
307 return q;
308 }
309 lehmer_update(&mut r0, &mut r1, &q);
316 let s = r0.leading_zeros();
317 let r0s = r0 << s;
318 let r1s = r1 << s;
319 let qn = lehmer_loop(r0s.limb(3), r1s.limb(3));
320
321 Matrix(
323 qn.0 * q.0 + qn.1 * q.2,
324 qn.0 * q.1 + qn.1 * q.3,
325 qn.2 * q.0 + qn.3 * q.2,
326 qn.2 * q.1 + qn.3 * q.3,
327 qn.4 ^ !q.4,
328 )
329}
330
331pub(crate) fn gcd(mut r0: U256, mut r1: U256) -> U256 {
335 if r1 > r0 {
336 core::mem::swap(&mut r0, &mut r1);
337 }
338 debug_assert!(r0 >= r1);
339 while r1 != U256::ZERO {
340 let q = lehmer_double(r0.clone(), r1.clone());
341 if q == Matrix::IDENTITY {
342 let q = &r0 / &r1;
346 let t = r0 - &q * &r1;
347 r0 = r1;
348 r1 = t;
349 } else {
350 lehmer_update(&mut r0, &mut r1, &q);
351 }
352 }
353 r0
354}
355
356#[allow(clippy::module_name_repetitions)]
374pub(crate) fn gcd_extended(mut r0: U256, mut r1: U256) -> (U256, U256, U256, bool) {
375 let swapped = r1 > r0;
376 if swapped {
377 core::mem::swap(&mut r0, &mut r1);
378 }
379 debug_assert!(r0 >= r1);
380 let mut s0 = U256::ONE;
381 let mut s1 = U256::ZERO;
382 let mut t0 = U256::ZERO;
383 let mut t1 = U256::ONE;
384 let mut even = true;
385 while r1 != U256::ZERO {
386 let q = lehmer_double(r0.clone(), r1.clone());
387 if q == Matrix::IDENTITY {
388 let q = &r0 / &r1;
392 let t = r0 - &q * &r1;
393 r0 = r1;
394 r1 = t;
395 let t = s0 - &q * &s1;
396 s0 = s1;
397 s1 = t;
398 let t = t0 - q * &t1;
399 t0 = t1;
400 t1 = t;
401 even = !even;
402 } else {
403 lehmer_update(&mut r0, &mut r1, &q);
404 lehmer_update(&mut s0, &mut s1, &q);
405 lehmer_update(&mut t0, &mut t1, &q);
406 even ^= !q.4;
407 }
408 }
409 if even {
411 t0 = U256::ZERO - t0;
413 } else {
414 s0 = U256::ZERO - s0;
416 }
417 if swapped {
418 core::mem::swap(&mut s0, &mut t0);
419 even = !even;
420 }
421 (r0, s0, t0, even)
422}
423
424pub(crate) fn inv_mod(modulus: &U256, num: &U256) -> Option<U256> {
442 let mut r0 = modulus.clone();
443 let mut r1 = num.clone();
444 if r1 >= r0 {
445 r1 %= &r0;
446 }
447 let mut t0 = U256::ZERO;
448 let mut t1 = U256::ONE;
449 let mut even = true;
450 while r1 != U256::ZERO {
451 let q = lehmer_double(r0.clone(), r1.clone());
452 if q == Matrix::IDENTITY {
453 let q = &r0 / &r1;
456 let t = r0 - &q * &r1;
457 r0 = r1;
458 r1 = t;
459 let t = t0 - q * &t1;
460 t0 = t1;
461 t1 = t;
462 even = !even;
463 } else {
464 lehmer_update(&mut r0, &mut r1, &q);
465 lehmer_update(&mut t0, &mut t1, &q);
466 even ^= !q.4;
467 }
468 }
469 if r0 == U256::ONE {
470 Some(if even { modulus + t0 } else { t0 })
472 } else {
473 None
474 }
475}
476
// The test vectors below are long numeric literals.
#[allow(clippy::unreadable_literal)]
#[cfg(test)]
mod tests {
    use super::*;
    use num_traits::identities::{One, Zero};
    use proptest::prelude::*;
    use zkp_macros_decl::u256h;

    // Known-answer tests for the single-word matrix computation, one with an
    // odd result and one with an even result, plus the zero input.
    #[test]
    fn test_lehmer_small() {
        assert_eq!(lehmer_small(0, 0), Matrix::IDENTITY);
        assert_eq!(
            lehmer_small(14535145444257436950, 5818365597666026993),
            Matrix(
                379355176803460069,
                947685836737753349,
                831195085380860999,
                2076449349179633850,
                false
            )
        );
        assert_eq!(
            lehmer_small(15507080595343815048, 10841422679839906593),
            Matrix(
                40154122160696118,
                57434639988632077,
                3613807559946635531,
                5169026865114605016,
                true
            )
        );
    }

    // Only checks that `gcd` returns for this input; the result is ignored.
    // NOTE(review): presumably a regression test for a reported issue.
    #[test]
    fn test_issue() {
        let a = u256h!("0000000000000054000000000000004f000000000000001f0000000000000028");
        let b = u256h!("0000000000000054000000000000005b000000000000002b000000000000005d");
        let _ = gcd(a, b);
    }

    // Known-answer tests for the prefix matrix computation.
    #[test]
    fn test_lehmer_loop() {
        assert_eq!(lehmer_loop(1_u64 << 63, 0), Matrix::IDENTITY);
        assert_eq!(
            lehmer_loop(16194659139127649777, 14535145444257436950),
            Matrix(320831736, 357461893, 1018828859, 1135151083, true)
        );
        assert_eq!(
            lehmer_loop(15267531864828975732, 6325623274722585764,),
            Matrix(88810257, 214352542, 774927313, 1870365485, false)
        );
    }

    proptest!(
        // The matrix returned by `lehmer_loop` must equal one of the cofactor
        // matrices produced by running plain Euclid on the same two words.
        #[test]
        #[allow(clippy::shadow_unrelated)]
        fn test_lehmer_loop_match_gcd(mut a: u64, mut b: u64) {
            const LIMIT: u64 = 1_u64 << 32;

            // Establish the preconditions: top bit of `a` set and `a >= b`.
            a |= 1_u64 << 63;
            if b > a {
                core::mem::swap(&mut a, &mut b)
            }

            let update_matrix = lehmer_loop(a, b);

            // Every entry fits in 32 bits.
            assert!(update_matrix.0 < LIMIT);
            assert!(update_matrix.1 < LIMIT);
            assert!(update_matrix.2 < LIMIT);
            assert!(update_matrix.3 < LIMIT);
            prop_assume!(update_matrix != Matrix::IDENTITY);

            // Ordering relations between the entries.
            assert!(update_matrix.0 <= update_matrix.2);
            assert!(update_matrix.2 <= update_matrix.3);
            assert!(update_matrix.1 <= update_matrix.3);

            // Reference: Euclid with cofactor magnitudes and a parity flag.
            let mut a0 = a;
            let mut a1 = b;
            let mut s0 = 1;
            let mut s1 = 0;
            let mut t0 = 0;
            let mut t1 = 1;
            let mut even = true;
            let mut result = false;
            while a1 > 0 {
                let r = a0 / a1;
                let t = a0 - r * a1;
                a0 = a1;
                a1 = t;
                let t = s0 + r * s1;
                s0 = s1;
                s1 = t;
                let t = t0 + r * t1;
                t0 = t1;
                t1 = t;
                even = !even;
                if update_matrix == Matrix(s0, t0, s1, t1, even) {
                    result = true;
                    break;
                }
            }
            prop_assert!(result)
        }

        // `mat_mul` must agree with the two linear combinations written out
        // with the full-width operators.
        #[test]
        fn test_mat_mul_match_formula(a: U256, b: U256, q00: u64, q01: u64, q10: u64, q11: u64) {
            let a_expected = q00 * a.clone() - q01 * b.clone();
            let b_expected = q11 * b.clone() - q10 * a.clone();
            let mut a_result = a;
            let mut b_result = b;
            mat_mul(&mut a_result, &mut b_result, (q00, q01, q10, q11));
            prop_assert_eq!(a_result, a_expected);
            prop_assert_eq!(b_result, b_expected);
        }
    );

    // Known-answer test for the two-round matrix computation.
    #[test]
    fn test_lehmer_double() {
        assert_eq!(lehmer_double(U256::ZERO, U256::ZERO), Matrix::IDENTITY);
        assert_eq!(
            lehmer_double(
                u256h!("518a5cc4c55ac5b050a0831b65e827e5e39fd4515e4e094961c61509e7870814"),
                u256h!("018a5cc4c55ac5b050a0831b65e827e5e39fd4515e4e094961c61509e7870814")
            ),
            Matrix(
                2927556694930003,
                154961230633081597,
                3017020641586254,
                159696730135159213,
                true
            )
        );
    }

    // Known-answer tests for `gcd_extended`. The last two cases use the same
    // inputs in opposite order: the cofactors trade places and the parity
    // flag flips.
    #[test]
    fn test_gcd_lehmer() {
        assert_eq!(
            gcd_extended(U256::ZERO, U256::ZERO),
            (U256::ZERO, U256::ONE, U256::ZERO, true)
        );
        assert_eq!(
            gcd_extended(
                u256h!("fea5a792d0a17b24827908e5524bcceec3ec6a92a7a42eac3b93e2bb351cf4f2"),
                u256h!("00028735553c6c798ed1ffb8b694f8f37b672b1bab7f80c4e6f4c0e710c79fb4")
            ),
            (
                u256h!("0000000000000000000000000000000000000000000000000000000000000002"),
                u256h!("00000b5a5ecb4dfc4ea08773d0593986592959a646b2f97655ed839928274ebb"),
                u256h!("0477865490d3994853934bf7eae7dad9afac55ccbf412a60c18fc9bea58ec8ba"),
                false
            )
        );
        assert_eq!(
            gcd_extended(
                u256h!("518a5cc4c55ac5b050a0831b65e827e5e39fd4515e4e094961c61509e7870814"),
                u256h!("018a5cc4c55ac5b050a0831b65e827e5e39fd4515e4e094961c61509e7870814")
            ),
            (
                U256::from(4),
                u256h!("002c851a0dddfaa03b9db2e39d48067d9b57fa0d238b70c7feddf8d267accc41"),
                u256h!("0934869c752ae9c7d2ed8aa55e7754e5492aaac49f8c9f3416156313a16c1174"),
                true
            )
        );
        assert_eq!(
            gcd_extended(
                u256h!("7dfd26515f3cd365ea32e1a43dbac87a25d0326fd834a889cb1e4c6c3c8d368c"),
                u256h!("3d341ef315cbe5b9f0ab79255f9684e153deaf5f460a8425819c84ec1e80a2f3")
            ),
            (
                u256h!("0000000000000000000000000000000000000000000000000000000000000001"),
                u256h!("0bbc35a0c1fd8f1ae85377ead5a901d4fbf0345fa303a87a4b4b68429cd69293"),
                u256h!("18283a24821b7de14cf22afb0e1a7efb4212b7f373988f5a0d75f6ee0b936347"),
                false
            )
        );
        assert_eq!(
            gcd_extended(
                u256h!("836fab5d425345751b3425e733e8150a17fdab2d5fb840ede5e0879f41497a4f"),
                u256h!("196e875b381eb95d9b5c6c3f198c5092b3ccc21279a7e68bc42cb6bca2d2644d")
            ),
            (
                u256h!("000000000000000000000000000000000000000000000000c59f8490536754fd"),
                u256h!("000000000000000006865401d85836d50a2bd608f152186fb24072a122d0dc5d"),
                u256h!("000000000000000021b8940f60792f546cbeb17f8b852d33a00b14b323d6de70"),
                false
            )
        );
        assert_eq!(
            gcd_extended(
                u256h!("00253222ed7b612113dbea0be0e1a0b88f2c0c16250f54bf1ec35d62671bf83a"),
                u256h!("0000000000025d4e064960ef2964b2170f1cd63ab931968621dde8a867079fd4")
            ),
            (
                u256h!("000000000000000000000000000505b22b0a9fd5a6e2166e3486f0109e6f60b2"),
                u256h!("0000000000000000000000000000000000000000000000001f16d40433587ae9"),
                u256h!("0000000000000000000000000000000000000001e91177fbec66b1233e79662e"),
                true
            )
        );
        assert_eq!(
            gcd_extended(
                u256h!("0000000000025d4e064960ef2964b2170f1cd63ab931968621dde8a867079fd4"),
                u256h!("00253222ed7b612113dbea0be0e1a0b88f2c0c16250f54bf1ec35d62671bf83a")
            ),
            (
                u256h!("000000000000000000000000000505b22b0a9fd5a6e2166e3486f0109e6f60b2"),
                u256h!("0000000000000000000000000000000000000001e91177fbec66b1233e79662e"),
                u256h!("0000000000000000000000000000000000000000000000001f16d40433587ae9"),
                false
            )
        );
    }

    // Equal inputs: the result must divide both, report odd parity and
    // satisfy the odd-parity identity.
    #[test]
    fn test_gcd_lehmer_extended_equal_inputs() {
        let a = U256::from(10);
        let b = U256::from(10);
        let (gcd, u, v, even) = gcd_extended(a.clone(), b.clone());
        assert_eq!(&a % &gcd, U256::ZERO);
        assert_eq!(&b % &gcd, U256::ZERO);
        assert!(!even);
        assert_eq!(gcd, v * b - u * a);
    }

    proptest!(
        // The returned value divides both inputs and satisfies the identity
        // selected by the parity flag.
        #[test]
        fn test_gcd_lehmer_extended(a: U256, b: U256) {
            let (gcd, u, v, even) = gcd_extended(a.clone(), b.clone());
            prop_assert!((&a % &gcd).is_zero());
            prop_assert!((&b % &gcd).is_zero());

            if even {
                prop_assert_eq!(gcd, u * a - v * b);
            } else {
                prop_assert_eq!(gcd, v * b - u * a);
            }
        }

        // For this modulus, `inv_mod` may only fail on zero, and any value it
        // returns must multiply with the input to one.
        #[test]
        fn test_inv_lehmer(mut a: U256) {
            const MODULUS: U256 =
                u256h!("0800000000000011000000000000000000000000000000000000000000000001");
            a %= MODULUS;
            match inv_mod(&MODULUS, &a) {
                None => prop_assert!(a.is_zero()),
                Some(a_inv) => prop_assert!(a.mulmod(&a_inv, &MODULUS).is_one()),
            }
        }
    );
}