use crate::u256::U256;
use crunchy::unroll;
/// Update matrix produced by the single-word Lehmer routines in this file.
///
/// Fields `.0 ..= .3` hold the entries `q00, q01, q10, q11` (row-major) as
/// unsigned magnitudes. Field `.4` is the parity flag (`even`) that
/// `lehmer_update` consults to decide in which order the rows are applied and
/// whether the updated pair has to be swapped afterwards.
#[derive(Clone, Debug, PartialEq, Eq)]
struct Matrix(u64, u64, u64, u64, bool);

impl Matrix {
    /// The 2x2 identity with even parity.
    ///
    /// The routines in this file also return it as a sentinel meaning
    /// "no usable step was found"; callers compare against it with `==`.
    const IDENTITY: Self = Matrix(1, 0, 0, 1, true);
}
/// Applies a 2x2 update with unsigned 64-bit entries to the pair `(a, b)`, in place.
///
/// As pinned down by `test_mat_mul_match_formula`, the result is
///
/// ```text
/// a' = q00 * a - q01 * b
/// b' = q11 * b - q10 * a
/// ```
///
/// The computation runs limb by limb from `c0` up to `c3`, threading one
/// carry and one borrow per output between limbs. Whatever carries or borrows
/// out of the `c3` limb is discarded, so the results are truncated to four limbs.
///
/// NOTE(review): `mac` and `msb` are defined in `crate::utils`, which is not
/// visible here. From their use they appear to be multiply-accumulate
/// (`x + y * z + carry`) and multiply-subtract (`x - y * z - borrow`), each
/// returning `(limb, carry_or_borrow)` — confirm against their definitions.
#[rustfmt::skip]
#[allow(clippy::shadow_unrelated)]
fn mat_mul(a: &mut U256, b: &mut U256, (q00, q01, q10, q11): (u64, u64, u64, u64)) {
    use crate::utils::{mac, msb};
    // Limb 0. Both new limbs are computed from the *old* `a.c0` / `b.c0`
    // before either one is written back; the same holds for every limb below.
    let (ai, ac) = mac( 0, q00, a.c0, 0);
    let (ai, ab) = msb(ai, q01, b.c0, 0);
    let (bi, bc) = mac( 0, q11, b.c0, 0);
    let (bi, bb) = msb(bi, q10, a.c0, 0);
    a.c0 = ai;
    b.c0 = bi;
    // Limb 1: `ac`/`ab` and `bc`/`bb` carry over from limb 0.
    let (ai, ac) = mac( 0, q00, a.c1, ac);
    let (ai, ab) = msb(ai, q01, b.c1, ab);
    let (bi, bc) = mac( 0, q11, b.c1, bc);
    let (bi, bb) = msb(bi, q10, a.c1, bb);
    a.c1 = ai;
    b.c1 = bi;
    // Limb 2.
    let (ai, ac) = mac( 0, q00, a.c2, ac);
    let (ai, ab) = msb(ai, q01, b.c2, ab);
    let (bi, bc) = mac( 0, q11, b.c2, bc);
    let (bi, bb) = msb(bi, q10, a.c2, bb);
    a.c2 = ai;
    b.c2 = bi;
    // Limb 3: the outgoing carry/borrow is intentionally dropped.
    let (ai, _) = mac( 0, q00, a.c3, ac);
    let (ai, _) = msb(ai, q01, b.c3, ab);
    let (bi, _) = mac( 0, q11, b.c3, bc);
    let (bi, _) = msb(bi, q10, a.c3, bb);
    a.c3 = ai;
    b.c3 = bi;
}
/// Applies a Lehmer update matrix to the pair `(a0, a1)`, in place.
///
/// The matrix stores magnitudes only; its parity flag selects the signs:
///
/// ```text
/// even:  a0' = q00 * a0 - q01 * a1      a1' = q11 * a1 - q10 * a0
/// odd:   a0' = q01 * a1 - q00 * a0      a1' = q10 * a0 - q11 * a1
/// ```
fn lehmer_update(a0: &mut U256, a1: &mut U256, Matrix(q00, q01, q10, q11, even): &Matrix) {
    let (top, bottom) = ((*q00, *q01), (*q10, *q11));
    if !*even {
        // Odd parity: hand the rows to `mat_mul` in exchanged order so each
        // subtraction comes out with the right sign, then put the two results
        // back in their proper slots.
        mat_mul(a0, a1, (bottom.0, bottom.1, top.0, top.1));
        core::mem::swap(a0, a1);
        return;
    }
    mat_mul(a0, a1, (top.0, top.1, bottom.0, bottom.1));
}
/// Returns the quotient `a / b`, written for the case where that quotient is small.
///
/// Quotients below 20 are found by repeated subtraction (the loop is wrapped
/// in the `unroll!` macro); only larger quotients fall through to a `/`.
///
/// Callers must guarantee `a >= b > 0`; this is checked in debug builds only.
#[allow(clippy::cognitive_complexity)]
fn div1(mut a: u64, b: u64) -> u64 {
    debug_assert!(a >= b);
    debug_assert!(b > 0);
    unroll! {
        for count in 1..20 {
            a -= b;
            if b > a {
                // `count` copies of `b` fit and what is left is below `b`.
                return count as u64;
            }
        }
    }
    // Nineteen subtractions were not enough: divide what is left and add
    // back the nineteen already taken.
    a / b + 19
}
/// Performs one Euclid step on a remainder and its packed cofactor word.
///
/// With `q = *a3 / a2` (values on entry), leaves `*a3` reduced to
/// `*a3 - q * a2` and `*k3` advanced to `*k3 + q * k2`. Up to sixteen
/// subtract/add rounds are tried first (wrapped in `unroll!`); if the
/// remainder is still not below `a2` after that, the rest is done with one
/// division.
///
/// Callers must guarantee `0 < a2 < *a3`; this is checked in debug builds only.
#[inline(always)]
#[allow(clippy::cognitive_complexity)]
fn lehmer_unroll(a2: u64, a3: &mut u64, k2: u64, k3: &mut u64) {
    debug_assert!(a2 < *a3);
    debug_assert!(a2 > 0);
    unroll! {
        for _round in 1..17 {
            *a3 -= a2;
            *k3 += k2;
            if a2 > *a3 {
                return;
            }
        }
    }
    // Large quotient: finish the reduction in one go.
    let quotient = *a3 / a2;
    *a3 -= a2 * quotient;
    *k3 += k2 * quotient;
}
/// Runs the Euclidean algorithm to completion on two single-word values and
/// returns the accumulated cofactor matrix.
///
/// Requires `r0 >= r1` (checked in debug builds only). Returns
/// `Matrix::IDENTITY` when `r1` is zero. Otherwise the loop alternates
/// between reducing `r0` by `r1` and `r1` by `r0`, so no swap is needed. The
/// parity flag of the result is `false` when the remainder reached zero in
/// the first half of an iteration and `true` when it did so in the second.
#[allow(clippy::shadow_unrelated)]
fn lehmer_small(mut r0: u64, mut r1: u64) -> Matrix {
    debug_assert!(r0 >= r1);
    if r1 == 0 {
        return Matrix::IDENTITY;
    }
    // Cofactor rows, starting from the identity.
    let (mut q00, mut q01) = (1_u64, 0_u64);
    let (mut q10, mut q11) = (0_u64, 1_u64);
    loop {
        // First half: reduce r0 by r1 and update the first row.
        let quotient = div1(r0, r1);
        r0 -= quotient * r1;
        q00 += quotient * q10;
        q01 += quotient * q11;
        if r0 == 0 {
            // The rows are emitted in exchanged order for the odd case.
            break Matrix(q10, q11, q00, q01, false);
        }
        // Second half: the roles of r0 and r1 are exchanged.
        let quotient = div1(r1, r0);
        r1 -= quotient * r0;
        q10 += quotient * q00;
        q11 += quotient * q01;
        if r1 == 0 {
            break Matrix(q00, q01, q10, q11, true);
        }
    }
}
/// Runs Euclid steps on two 64-bit words for as long as the remainder stays
/// at or above `LIMIT` (2^32), then returns the most advanced cofactor matrix
/// that passes the validity checks at the end of the function.
///
/// `lehmer_double` calls this with the most significant limbs of two
/// normalised multi-word values. Requires `a0 >= 2^63` and `a0 >= a1`
/// (checked in debug builds only).
///
/// Returns `Matrix::IDENTITY` when no step can be validated; the callers in
/// this file treat that as "fall back to a full-precision step". Every entry
/// of the returned matrix is below 2^32 because it is unpacked from one half
/// of a `u64`.
#[allow(clippy::shadow_unrelated)]
fn lehmer_loop(a0: u64, mut a1: u64) -> Matrix {
    const LIMIT: u64 = 1_u64 << 32;
    debug_assert!(a0 >= 1_u64 << 63);
    debug_assert!(a0 >= a1);
    // Each `k` packs a cofactor pair into one u64: `u` in the high 32 bits and
    // `v` in the low 32 bits (see the unpacking further down), so one
    // multiply-add updates both at once.
    // NOTE(review): this relies on neither half overflowing 32 bits — confirm.
    let mut k0 = 1_u64 << 32; // u0 = 1, v0 = 0
    let mut k1 = 1_u64; // u1 = 0, v1 = 1
    let mut even = true;
    if a1 < LIMIT {
        return Matrix::IDENTITY;
    }
    // Compute a2 and its packed cofactors k2.
    let q = div1(a0, a1);
    let mut a2 = a0 - q * a1;
    let mut k2 = k0 + q * k1;
    if a2 < LIMIT {
        let u2 = k2 >> 32;
        let v2 = k2 % LIMIT;
        // Only a single step is available: accept it if it passes the check,
        // otherwise report that no step could be validated.
        if a2 >= v2 && a1 - a2 >= u2 {
            return Matrix(0, 1, u2, v2, false);
        } else {
            return Matrix::IDENTITY;
        }
    }
    // Compute a3 and its packed cofactors k3.
    let q = div1(a1, a2);
    let mut a3 = a1 - q * a2;
    let mut k3 = k1 + q * k2;
    // Loop until a3 < LIMIT, keeping the last three remainders (a1..a3) and the
    // last four cofactor words (k0..k3). The body is written out twice so that
    // `even` records in which half the loop stopped.
    while a3 >= LIMIT {
        // Slide the window by one step; a3/k3 are seeded with the older values
        // and then advanced by `lehmer_unroll`.
        a1 = a2;
        a2 = a3;
        a3 = a1;
        k0 = k1;
        k1 = k2;
        k2 = k3;
        k3 = k1;
        lehmer_unroll(a2, &mut a3, k2, &mut k3);
        if a3 < LIMIT {
            even = false;
            break;
        }
        // Second half: identical slide and step.
        a1 = a2;
        a2 = a3;
        a3 = a1;
        k0 = k1;
        k1 = k2;
        k2 = k3;
        k3 = k1;
        lehmer_unroll(a2, &mut a3, k2, &mut k3);
    }
    // Unpack the four cofactor words into their u (high) and v (low) halves.
    let u0 = k0 >> 32;
    let u1 = k1 >> 32;
    let u2 = k2 >> 32;
    let u3 = k3 >> 32;
    let v0 = k0 % LIMIT;
    let v1 = k1 % LIMIT;
    let v2 = k2 % LIMIT;
    let v3 = k3 % LIMIT;
    debug_assert!(a2 >= LIMIT);
    debug_assert!(a3 < LIMIT);
    // Pick the most advanced pair of cofactor rows that passes the checks below.
    // NOTE(review): these inequalities look like Jebelean's exact condition for
    // single-word Lehmer steps — confirm against the reference.
    if even {
        // Check the (k1, k2) rows first.
        debug_assert!(a2 >= v2);
        if a1 - a2 >= u2 + u1 {
            // Then check the (k2, k3) rows.
            if a3 >= u3 && a2 - a3 >= v3 + v2 {
                // Both checks hold: return the newest rows.
                Matrix(u2, v2, u3, v3, true)
            } else {
                // Only the first check holds: back off by one step.
                Matrix(u1, v1, u2, v2, false)
            }
        } else {
            // Neither holds: back off by two steps.
            Matrix(u0, v0, u1, v1, true)
        }
    } else {
        // Same cascade for the other parity, with the roles of u and v exchanged.
        debug_assert!(a2 >= u2);
        if a1 - a2 >= v2 + v1 {
            // Then check the (k2, k3) rows.
            if a3 >= v3 && a2 - a3 >= u3 + u2 {
                // Both checks hold: return the newest rows.
                Matrix(u2, v2, u3, v3, false)
            } else {
                // Only the first check holds: back off by one step.
                Matrix(u1, v1, u2, v2, true)
            }
        } else {
            // Neither holds: back off by two steps.
            Matrix(u0, v0, u1, v1, false)
        }
    }
}
#[allow(clippy::shadow_unrelated)]
fn lehmer_double(mut r0: U256, mut r1: U256) -> Matrix {
    debug_assert!(r0 >= r1);
    if r0.bits() < 64 {
        debug_assert!(r1.bits() < 64);
        debug_assert!(r0.c0 >= r1.c0);
        return lehmer_small(r0.c0, r1.c0);
    }
    let s = r0.leading_zeros();
    let r0s = r0.clone() << s;
    let r1s = r1.clone() << s;
    let q = lehmer_loop(r0s.c3, r1s.c3);
    if q == Matrix::IDENTITY {
        return q;
    }
    
    
    
    
    
    lehmer_update(&mut r0, &mut r1, &q);
    let s = r0.leading_zeros();
    let r0s = r0.clone() << s;
    let r1s = r1.clone() << s;
    let qn = lehmer_loop(r0s.c3, r1s.c3);
    
    Matrix(
        qn.0 * q.0 + qn.1 * q.2,
        qn.0 * q.1 + qn.1 * q.3,
        qn.2 * q.0 + qn.3 * q.2,
        qn.2 * q.1 + qn.3 * q.3,
        qn.4 ^ !q.4,
    )
}
/// Returns the greatest common divisor of `r0` and `r1`.
///
/// The arguments may be given in either order. When both are zero the result
/// is zero, because the loop never runs.
///
/// Each iteration asks `lehmer_double` for an update matrix and applies it.
/// When that returns the identity (no usable step), one full-precision
/// division step is performed instead.
pub fn gcd(mut r0: U256, mut r1: U256) -> U256 {
    if r1 > r0 {
        core::mem::swap(&mut r0, &mut r1);
    }
    debug_assert!(r0 >= r1);
    while r1 != U256::ZERO {
        let step = lehmer_double(r0.clone(), r1.clone());
        if step != Matrix::IDENTITY {
            lehmer_update(&mut r0, &mut r1, &step);
            continue;
        }
        // Fallback: one full-precision Euclid step, (r0, r1) <- (r1, r0 mod r1).
        let quotient = &r0 / &r1;
        let remainder = r0 - &quotient * &r1;
        r0 = core::mem::replace(&mut r1, remainder);
    }
    r0
}
/// Extended GCD: returns `(gcd, s, t, even)` for the inputs `(r0, r1)`.
///
/// With `a` and `b` the values passed as `r0` and `r1`, the cofactors satisfy
/// the identity checked by `test_gcd_lehmer_extended`:
///
/// ```text
/// even:   gcd = s * a - t * b
/// odd:    gcd = t * b - s * a
/// ```
///
/// The arguments may be given in either order. When they arrive with
/// `r1 > r0` they are exchanged internally, and the cofactors and parity are
/// exchanged back before returning.
#[allow(clippy::module_name_repetitions)]
pub fn gcd_extended(mut r0: U256, mut r1: U256) -> (U256, U256, U256, bool) {
    let swapped = r1 > r0;
    if swapped {
        core::mem::swap(&mut r0, &mut r1);
    }
    debug_assert!(r0 >= r1);

    let (mut s0, mut s1) = (U256::ONE, U256::ZERO);
    let (mut t0, mut t1) = (U256::ZERO, U256::ONE);
    let mut even = true;

    while r1 != U256::ZERO {
        let step = lehmer_double(r0.clone(), r1.clone());
        if step != Matrix::IDENTITY {
            // Apply the same update to the remainders and to both cofactor pairs.
            lehmer_update(&mut r0, &mut r1, &step);
            lehmer_update(&mut s0, &mut s1, &step);
            lehmer_update(&mut t0, &mut t1, &step);
            even ^= !step.4;
            continue;
        }
        // Fallback: one full-precision Euclid step on all three sequences.
        let quotient = &r0 / &r1;
        let next_r = r0 - &quotient * &r1;
        r0 = core::mem::replace(&mut r1, next_r);
        let next_s = s0 - &quotient * &s1;
        s0 = core::mem::replace(&mut s1, next_s);
        let next_t = t0 - quotient * &t1;
        t0 = core::mem::replace(&mut t1, next_t);
        even = !even;
    }

    // Patch the sign: depending on parity, one of the two cofactors is negated
    // by subtracting it from zero, so that both are returned as magnitudes.
    // NOTE(review): presumably `U256` subtraction wraps, leaving the negative
    // cofactor in two's-complement form up to this point — confirm.
    if even {
        t0 = U256::ZERO - t0;
    } else {
        s0 = U256::ZERO - s0;
    }

    if swapped {
        core::mem::swap(&mut s0, &mut t0);
        even = !even;
    }
    (r0, s0, t0, even)
}
/// Computes the multiplicative inverse of `num` modulo `modulus`.
///
/// `num` is first reduced modulo `modulus` if it is not already smaller.
/// Returns `Some(inverse)` when the GCD of the two values is one and `None`
/// otherwise.
///
/// Only the cofactor belonging to `num` is tracked. When the final parity is
/// even, `modulus` is added to it before it is returned.
/// NOTE(review): presumably the cofactor is negative in that case and stored
/// in wrapped (two's-complement) form, so adding `modulus` brings it into
/// range — confirm against `U256`'s arithmetic.
pub(crate) fn inv_mod(modulus: &U256, num: &U256) -> Option<U256> {
    let mut r0 = modulus.clone();
    let mut r1 = num.clone();
    if r1 >= r0 {
        r1 %= &r0;
    }

    let (mut t0, mut t1) = (U256::ZERO, U256::ONE);
    let mut even = true;

    while r1 != U256::ZERO {
        let step = lehmer_double(r0.clone(), r1.clone());
        if step != Matrix::IDENTITY {
            lehmer_update(&mut r0, &mut r1, &step);
            lehmer_update(&mut t0, &mut t1, &step);
            even ^= !step.4;
            continue;
        }
        // Fallback: one full-precision Euclid step on remainders and cofactor.
        let quotient = &r0 / &r1;
        let next_r = r0 - &quotient * &r1;
        r0 = core::mem::replace(&mut r1, next_r);
        let next_t = t0 - quotient * &t1;
        t0 = core::mem::replace(&mut t1, next_t);
        even = !even;
    }

    // An inverse exists only when the final remainder (the GCD) is one.
    if r0 != U256::ONE {
        return None;
    }
    Some(if even { modulus + t0 } else { t0 })
}
// Unit and property tests for the Lehmer GCD routines in this file.
#[allow(clippy::unreadable_literal)]
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck_macros::quickcheck;
    use zkp_macros_decl::u256h;
    /// Fixed vectors for `lehmer_small`, covering the zero case and one result
    /// of each parity.
    #[test]
    fn test_lehmer_small() {
        assert_eq!(lehmer_small(0, 0), Matrix::IDENTITY);
        assert_eq!(
            lehmer_small(14535145444257436950, 5818365597666026993),
            Matrix(
                379355176803460069,
                947685836737753349,
                831195085380860999,
                2076449349179633850,
                false
            )
        );
        assert_eq!(
            lehmer_small(15507080595343815048, 10841422679839906593),
            Matrix(
                40154122160696118,
                57434639988632077,
                3613807559946635531,
                5169026865114605016,
                true
            )
        );
    }
    /// Smoke test: `gcd` must finish on this input pair; the result itself is
    /// not checked.
    #[test]
    fn test_issue() {
        // NOTE(review): presumably a regression case from a reported issue — confirm.
        let a = u256h!("0000000000000054000000000000004f000000000000001f0000000000000028");
        let b = u256h!("0000000000000054000000000000005b000000000000002b000000000000005d");
        let _ = gcd(a, b);
    }
    /// Fixed vectors for `lehmer_loop`: the identity case plus one result of
    /// each parity.
    #[test]
    fn test_lehmer_loop() {
        assert_eq!(lehmer_loop(1_u64 << 63, 0), Matrix::IDENTITY);
        assert_eq!(
            // Expected result has even parity.
            lehmer_loop(16194659139127649777, 14535145444257436950),
            Matrix(320831736, 357461893, 1018828859, 1135151083, true)
        );
        assert_eq!(
            // Expected result has odd parity.
            lehmer_loop(15267531864828975732, 6325623274722585764,),
            Matrix(88810257, 214352542, 774927313, 1870365485, false)
        );
    }
    #[quickcheck]
    // Property: a non-identity result of `lehmer_loop` must equal the cofactor
    // matrix after some number of plain Euclid steps on the same inputs.
    #[allow(clippy::shadow_unrelated)]
    fn test_lehmer_loop_match_gcd(mut a: u64, mut b: u64) -> bool {
        const LIMIT: u64 = 1_u64 << 32;
        // Prepare valid inputs: top bit of `a` set and `a >= b`.
        a |= 1_u64 << 63;
        if b > a {
            core::mem::swap(&mut a, &mut b)
        }
        // Call the function under test.
        let update_matrix = lehmer_loop(a, b);
        // Every entry must fit in 32 bits.
        assert!(update_matrix.0 < LIMIT);
        assert!(update_matrix.1 < LIMIT);
        assert!(update_matrix.2 < LIMIT);
        assert!(update_matrix.3 < LIMIT);
        if update_matrix == Matrix::IDENTITY {
            return true;
        }
        assert!(update_matrix.0 <= update_matrix.2);
        assert!(update_matrix.2 <= update_matrix.3);
        assert!(update_matrix.1 <= update_matrix.3);
        // Compare with a simple step-by-step Euclid run that tracks cofactor magnitudes.
        let mut a0 = a;
        let mut a1 = b;
        let mut s0 = 1;
        let mut s1 = 0;
        let mut t0 = 0;
        let mut t1 = 1;
        let mut even = true;
        while a1 > 0 {
            let r = a0 / a1;
            let t = a0 - r * a1;
            a0 = a1;
            a1 = t;
            let t = s0 + r * s1;
            s0 = s1;
            s1 = t;
            let t = t0 + r * t1;
            t0 = t1;
            t1 = t;
            even = !even;
            if update_matrix == Matrix(s0, t0, s1, t1, even) {
                return true;
            }
        }
        false
    }
    /// Property: `mat_mul` must agree with the closed-form expressions
    /// evaluated with `U256` arithmetic.
    #[quickcheck]
    fn test_mat_mul_match_formula(
        a: U256,
        b: U256,
        q00: u64,
        q01: u64,
        q10: u64,
        q11: u64,
    ) -> bool {
        let a_expected = q00 * a.clone() - q01 * b.clone();
        let b_expected = q11 * b.clone() - q10 * a.clone();
        let mut a_result = a;
        let mut b_result = b;
        mat_mul(&mut a_result, &mut b_result, (q00, q01, q10, q11));
        a_result == a_expected && b_result == b_expected
    }
    /// Fixed vectors for `lehmer_double`: the zero case and one full-width input pair.
    #[test]
    fn test_lehmer_double() {
        assert_eq!(lehmer_double(U256::ZERO, U256::ZERO), Matrix::IDENTITY);
        assert_eq!(
            // Full-width inputs; expected result has even parity.
            lehmer_double(
                u256h!("518a5cc4c55ac5b050a0831b65e827e5e39fd4515e4e094961c61509e7870814"),
                u256h!("018a5cc4c55ac5b050a0831b65e827e5e39fd4515e4e094961c61509e7870814")
            ),
            Matrix(
                2927556694930003,
                154961230633081597,
                3017020641586254,
                159696730135159213,
                true
            )
        );
    }
    /// Fixed vectors for `gcd_extended`. The last two cases are the same pair
    /// in both argument orders, with the cofactors and parity exchanged
    /// accordingly.
    #[test]
    fn test_gcd_lehmer() {
        assert_eq!(
            gcd_extended(U256::ZERO, U256::ZERO),
            (U256::ZERO, U256::ONE, U256::ZERO, true)
        );
        assert_eq!(
            gcd_extended(
                u256h!("fea5a792d0a17b24827908e5524bcceec3ec6a92a7a42eac3b93e2bb351cf4f2"),
                u256h!("00028735553c6c798ed1ffb8b694f8f37b672b1bab7f80c4e6f4c0e710c79fb4")
            ),
            (
                u256h!("0000000000000000000000000000000000000000000000000000000000000002"),
                u256h!("00000b5a5ecb4dfc4ea08773d0593986592959a646b2f97655ed839928274ebb"),
                u256h!("0477865490d3994853934bf7eae7dad9afac55ccbf412a60c18fc9bea58ec8ba"),
                false
            )
        );
        assert_eq!(
            gcd_extended(
                u256h!("518a5cc4c55ac5b050a0831b65e827e5e39fd4515e4e094961c61509e7870814"),
                u256h!("018a5cc4c55ac5b050a0831b65e827e5e39fd4515e4e094961c61509e7870814")
            ),
            (
                U256::from(4),
                u256h!("002c851a0dddfaa03b9db2e39d48067d9b57fa0d238b70c7feddf8d267accc41"),
                u256h!("0934869c752ae9c7d2ed8aa55e7754e5492aaac49f8c9f3416156313a16c1174"),
                true
            )
        );
        assert_eq!(
            gcd_extended(
                u256h!("7dfd26515f3cd365ea32e1a43dbac87a25d0326fd834a889cb1e4c6c3c8d368c"),
                u256h!("3d341ef315cbe5b9f0ab79255f9684e153deaf5f460a8425819c84ec1e80a2f3")
            ),
            (
                u256h!("0000000000000000000000000000000000000000000000000000000000000001"),
                u256h!("0bbc35a0c1fd8f1ae85377ead5a901d4fbf0345fa303a87a4b4b68429cd69293"),
                u256h!("18283a24821b7de14cf22afb0e1a7efb4212b7f373988f5a0d75f6ee0b936347"),
                false
            )
        );
        assert_eq!(
            gcd_extended(
                u256h!("836fab5d425345751b3425e733e8150a17fdab2d5fb840ede5e0879f41497a4f"),
                u256h!("196e875b381eb95d9b5c6c3f198c5092b3ccc21279a7e68bc42cb6bca2d2644d")
            ),
            (
                u256h!("000000000000000000000000000000000000000000000000c59f8490536754fd"),
                u256h!("000000000000000006865401d85836d50a2bd608f152186fb24072a122d0dc5d"),
                u256h!("000000000000000021b8940f60792f546cbeb17f8b852d33a00b14b323d6de70"),
                false
            )
        );
        assert_eq!(
            gcd_extended(
                u256h!("00253222ed7b612113dbea0be0e1a0b88f2c0c16250f54bf1ec35d62671bf83a"),
                u256h!("0000000000025d4e064960ef2964b2170f1cd63ab931968621dde8a867079fd4")
            ),
            (
                u256h!("000000000000000000000000000505b22b0a9fd5a6e2166e3486f0109e6f60b2"),
                u256h!("0000000000000000000000000000000000000000000000001f16d40433587ae9"),
                u256h!("0000000000000000000000000000000000000001e91177fbec66b1233e79662e"),
                true
            )
        );
        assert_eq!(
            gcd_extended(
                u256h!("0000000000025d4e064960ef2964b2170f1cd63ab931968621dde8a867079fd4"),
                u256h!("00253222ed7b612113dbea0be0e1a0b88f2c0c16250f54bf1ec35d62671bf83a")
            ),
            (
                u256h!("000000000000000000000000000505b22b0a9fd5a6e2166e3486f0109e6f60b2"),
                u256h!("0000000000000000000000000000000000000001e91177fbec66b1233e79662e"),
                u256h!("0000000000000000000000000000000000000000000000001f16d40433587ae9"),
                false
            )
        );
    }
    /// Equal inputs: the returned GCD must divide both inputs, the parity must
    /// be odd, and the cofactors must satisfy the odd-parity identity.
    #[test]
    fn test_gcd_lehmer_extended_equal_inputs() {
        let a = U256::from(10);
        let b = U256::from(10);
        let (gcd, u, v, even) = gcd_extended(a.clone(), b.clone());
        assert_eq!(&a % &gcd, U256::ZERO);
        assert_eq!(&b % &gcd, U256::ZERO);
        assert!(!even);
        assert_eq!(gcd, v * b - u * a);
    }
    /// Property: the returned GCD divides both inputs and the cofactors
    /// satisfy the identity selected by the parity flag.
    #[quickcheck]
    fn test_gcd_lehmer_extended(a: U256, b: U256) -> bool {
        let (gcd, u, v, even) = gcd_extended(a.clone(), b.clone());
        &a % &gcd == U256::ZERO
            && &b % &gcd == U256::ZERO
            && gcd == if even { u * a - v * b } else { v * b - u * a }
    }
    /// Property: for a fixed modulus, `inv_mod` returns `None` exactly for
    /// zero, and otherwise a value whose product with the input is one modulo
    /// the modulus.
    #[quickcheck]
    fn test_inv_lehmer(mut a: U256) -> bool {
        const MODULUS: U256 =
            u256h!("0800000000000011000000000000000000000000000000000000000000000001");
        a %= MODULUS;
        match inv_mod(&MODULUS, &a) {
            None => a == U256::ZERO,
            Some(a_inv) => a.mulmod(&a_inv, &MODULUS) == U256::ONE,
        }
    }
}