1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
use crate::arithmetic::{bits_384::to_bits, utils::*};

/// Adds `a` and `b` limb by limb and then reduces the sum once against `p`.
///
/// Limbs are processed from index 0 upwards, so index 0 is the least
/// significant limb. The carry produced by the top limb is discarded, and the
/// 6-limb sum is handed to [`sub`] with `p` as both subtrahend and modulus,
/// which subtracts `p` and adds it back if that subtraction borrows.
///
/// NOTE(review): because the top carry is dropped, this presumably relies on
/// `a + b` fitting in 384 bits (i.e. `p` leaving headroom) — confirm with the
/// moduli actually used by callers.
#[inline]
pub const fn add(a: [u64; 6], b: [u64; 6], p: [u64; 6]) -> [u64; 6] {
    let mut sum = [0u64; 6];
    let mut carry = 0;
    let mut i = 0;
    while i < 6 {
        let (limb, next_carry) = adc(a[i], b[i], carry);
        sum[i] = limb;
        carry = next_carry;
        i += 1;
    }

    // Single conditional subtraction of the modulus.
    sub(sum, p, p)
}

/// Computes `a - b` over six limbs and conditionally adds `p` back.
///
/// Limbs are processed from index 0 upwards, so index 0 is the least
/// significant limb. The borrow word left over after the last limb is ANDed
/// into every limb of `p` before the add-back pass.
///
/// NOTE(review): this assumes `sbb` yields an all-ones word when the
/// subtraction underflows and zero otherwise, so that the AND selects either
/// `p` or `0` — confirm against `utils::sbb`.
#[inline]
pub const fn sub(a: [u64; 6], b: [u64; 6], p: [u64; 6]) -> [u64; 6] {
    // Pass 1: limb-wise subtraction, threading the borrow word through.
    let mut diff = [0u64; 6];
    let mut borrow = 0;
    let mut i = 0;
    while i < 6 {
        let (limb, next_borrow) = sbb(a[i], b[i], borrow);
        diff[i] = limb;
        borrow = next_borrow;
        i += 1;
    }

    // Pass 2: add `p & borrow` back in; the carry out of the top limb is
    // dropped.
    let mut carry = 0;
    let mut j = 0;
    while j < 6 {
        let (limb, next_carry) = adc(diff[j], p[j] & borrow, carry);
        diff[j] = limb;
        carry = next_carry;
        j += 1;
    }

    diff
}

/// Doubles `a` by delegating to [`add`] with `a` as both operands and `p` as
/// the modulus argument.
#[inline]
pub const fn double(a: [u64; 6], p: [u64; 6]) -> [u64; 6] {
    add(a, a, p)
}

/// Multiplies `a` by `b` into a 12-limb product and reduces it with [`mont`].
///
/// The product is built row by row: for each limb `a[i]`, the partial products
/// `a[i] * b[j]` are accumulated into positions `i + j` via `mac`, and the
/// carry left at the end of the row is stored at position `i + 6`.
///
/// `p` and `inv` are passed through to [`mont`] unchanged.
///
/// NOTE(review): assumes `mac(acc, x, y, carry)` returns the low and high
/// words of `acc + x * y + carry` — confirm against `utils::mac`.
#[inline]
pub const fn mul(a: [u64; 6], b: [u64; 6], p: [u64; 6], inv: u64) -> [u64; 6] {
    let mut wide = [0u64; 12];

    let mut i = 0;
    while i < 6 {
        let mut carry = 0;
        let mut j = 0;
        while j < 6 {
            let (lo, hi) = mac(wide[i + j], a[i], b[j], carry);
            wide[i + j] = lo;
            carry = hi;
            j += 1;
        }
        // Position `i + 6` has not been written by any earlier row, so the
        // row's final carry is stored directly rather than accumulated.
        wide[i + 6] = carry;
        i += 1;
    }

    mont(wide, p, inv)
}

/// Squares `a` by delegating to [`mul`] with `a` as both operands; `p` and
/// `inv` are forwarded unchanged.
#[inline]
pub const fn square(a: [u64; 6], p: [u64; 6], inv: u64) -> [u64; 6] {
    mul(a, a, p, inv)
}

/// Negates `a` with respect to `p`.
///
/// When every limb of `a` is zero, `a` is returned unchanged; otherwise the
/// result is `sub(p, a, p)`. Note that this branches on the value of `a`.
#[inline]
pub const fn neg(a: [u64; 6], p: [u64; 6]) -> [u64; 6] {
    // OR-ing all limbs together is non-zero exactly when some limb is non-zero.
    let any_bits = a[0] | a[1] | a[2] | a[3] | a[4] | a[5];
    if any_bits != 0 {
        return sub(p, a, p);
    }
    a
}

/// Reduces the 12-limb value `a` to six limbs in six rounds, then applies one
/// conditional subtraction of `p` via [`sub`].
///
/// In round `i`, the current limb `i` is multiplied by `inv` (wrapping) to get
/// a factor `k`; `k * p` is then accumulated into limbs `i..i + 6` with `mac`,
/// and the row's carry is folded into limb `i + 6` together with the carry left
/// over from the previous round. The low word produced for limb `i` itself is
/// thrown away. After the last round, limbs `6..12` are passed to [`sub`].
///
/// NOTE(review): this has the shape of a word-by-word Montgomery reduction, in
/// which case `inv` is presumably `-p^{-1} mod 2^64` — confirm where `inv` is
/// defined. Also assumes `mac(acc, x, y, carry)` returns the low and high words
/// of `acc + x * y + carry`.
#[inline]
pub const fn mont(a: [u64; 12], p: [u64; 6], inv: u64) -> [u64; 6] {
    let mut t = a;
    // Carry out of the previous round's top-limb addition.
    let mut round_carry = 0;

    let mut i = 0;
    while i < 6 {
        let k = t[i].wrapping_mul(inv);

        // Limb `i`: only the carry is kept, the low word is discarded.
        let (_, mut carry) = mac(t[i], k, p[0], 0);

        let mut j = 1;
        while j < 6 {
            let (lo, hi) = mac(t[i + j], k, p[j], carry);
            t[i + j] = lo;
            carry = hi;
            j += 1;
        }

        // Limb `i + 6` still holds its input value here; add both carries.
        // The carry out of the final round is never consumed.
        let (lo, hi) = adc(t[i + 6], round_carry, carry);
        t[i + 6] = lo;
        round_carry = hi;

        i += 1;
    }

    sub([t[6], t[7], t[8], t[9], t[10], t[11]], p, p)
}

/// Returns `None` when every limb of `a` is zero, otherwise
/// `Some(pow(a, little_fermat, identity, p, inv))`.
///
/// NOTE(review): the parameter name suggests `little_fermat` is the exponent
/// `p - 2`, which would make the result the modular inverse of `a` — confirm
/// at the call sites.
#[inline]
pub fn invert(
    a: [u64; 6],
    little_fermat: [u64; 6],
    identity: [u64; 6],
    p: [u64; 6],
    inv: u64,
) -> Option<[u64; 6]> {
    // Zero is rejected up front rather than exponentiated.
    if a == [0u64; 6] {
        return None;
    }
    Some(pow(a, little_fermat, identity, p, inv))
}

/// Raises `a` to the exponent `b` by square-and-multiply, starting from
/// `identity`.
///
/// The bits of `b` are visited in the order produced by `to_bits(b).iter()`.
/// For each bit the accumulator is squared, and when the bit equals `1` it is
/// additionally multiplied by `a`. `p` and `inv` are forwarded to [`square`]
/// and [`mul`].
///
/// NOTE(review): this yields `a^b` only if `to_bits` emits the most
/// significant bit first and `identity` is the multiplicative identity in the
/// same representation as `a` — confirm against `bits_384::to_bits` and the
/// callers.
pub fn pow(a: [u64; 6], b: [u64; 6], identity: [u64; 6], p: [u64; 6], inv: u64) -> [u64; 6] {
    to_bits(b).iter().fold(identity, |acc, &bit| {
        let squared = square(acc, p, inv);
        if bit == 1 {
            mul(squared, a, p, inv)
        } else {
            squared
        }
    })
}