tasm_lib/arithmetic/u128/safe_mul.rs

1use std::collections::HashMap;
2
3use triton_vm::prelude::*;
4
5use crate::prelude::*;
6use crate::traits::basic_snippet::Reviewer;
7use crate::traits::basic_snippet::SignOffFingerprint;
8
/// Multiplies two `u128`s, crashing the VM if the product does not fit into a `u128`.
///
/// ### Behavior
///
/// ```text
/// BEFORE: _ [right: u128] [left: u128]
/// AFTER:  _ [left · right: u128]
/// ```
///
/// ### Preconditions
///
/// - all input arguments are properly [`BFieldCodec`] encoded
/// - the product of `left` and `right` is at most [`u128::MAX`]
///
/// ### Postconditions
///
/// - the output is the product of the inputs
/// - the output is properly [`BFieldCodec`] encoded
///
/// ### Crashes
///
/// If the product exceeds [`u128::MAX`], an assertion fails with one of the error IDs
/// 500 through 506.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct SafeMul;
29
30impl BasicSnippet for SafeMul {
31    fn parameters(&self) -> Vec<(DataType, String)> {
32        ["right", "left"]
33            .map(|side| (DataType::U128, side.to_string()))
34            .to_vec()
35    }
36
37    fn return_values(&self) -> Vec<(DataType, String)> {
38        vec![(DataType::U128, "product".to_string())]
39    }
40
41    fn entrypoint(&self) -> String {
42        "tasmlib_arithmetic_u128_safe_mul".to_string()
43    }
44
45    fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
46        triton_asm!(
47            // BEFORE: _ r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0
48            // AFTER:  _ p_3 p_2 p_1 p_0
49            {self.entrypoint()}:
50                /*
51                 * p_0 is low limb, c_0 high limb of
52                 *        l_0·r_0
53                 *
54                 * p_1 is low limb, c_1 high limb of
55                 *        (l_1·r_0)_lo + (l_0·r_1)_lo
56                 *      + c_0
57                 *
58                 * p_2 is low limb, c_2 high limb of
59                 *        (l_1·r_0)_hi + (l_0·r_1)_hi
60                 *      + (l_2·r_0)_lo + (l_1·r_1)_lo + (l_0·r_2)_lo
61                 *      + c_1
62                 *
63                 * p_3 is low limb, c_3 high limb of
64                 *        (l_2·r_0)_hi + (l_1·r_1)_hi + (l_0·r_2)_hi
65                 *      + (l_3·r_0)_lo + (l_2·r_1)_lo + (l_1·r_2)_lo + (l_0·r_3)_lo
66                 *      + c_2
67                 *
68                 * All remaining limb combinations (l_3·r_1, l_3·r_2, l_3·r_3 l_2·r_2,
69                 * l_2·r_3, and l_1·r_3) as well as c_3 must be 0.
70                 */
71
72                /* p_0 */
73                dup 0 dup 5 mul split
74                // _ r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 c_0 p_0
75
76                place 9
77                // _ p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 c_0
78
79                /* p_1 */
80                dup 2 dup 6 mul split
81                // _ p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 c_0 (l_1·r_0)_hi (l_1·r_0)_lo
82
83                dup 3 dup 9 mul split
84                // _ p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 c_0 (l_1·r_0)_hi (l_1·r_0)_lo (l_0·r_1)_hi (l_0·r_1)_lo
85                //                                       ^^^              ^^^^^^^^^^^^              ^^^^^^^^^^^^
86
87                pick 2 pick 4
88                add add
89                split
90                // _ p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 (l_1·r_0)_hi (l_0·r_1)_hi c_1 p_1
91
92                place 12
93                // _ p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 (l_1·r_0)_hi (l_0·r_1)_hi c_1
94
95                /* p_2 */
96                add add
97                // _ p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 p_2_wip
98
99                dup 3 dup 6 mul split
100                // _ p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 p_2_wip (l_2·r_0)_hi (l_2·r_0)_lo
101
102                dup 4 dup 9 mul split
103                // _ p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 p_2_wip (l_2·r_0)_hi (l_2·r_0)_lo (l_1·r_1)_hi (l_1·r_1)_lo
104
105                dup 5 dup 12 mul split
106                // _ p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 p_2_wip (l_2·r_0)_hi (l_2·r_0)_lo (l_1·r_1)_hi (l_1·r_1)_lo (l_0·r_2)_hi (l_0·r_2)_lo
107                //                                           ^^^^^^^              ^^^^^^^^^^^^              ^^^^^^^^^^^^              ^^^^^^^^^^^^
108
109                pick 2 pick 4 pick 6
110                add add add
111                split
112                // _ p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 (l_2·r_0)_hi (l_1·r_1)_hi (l_0·r_2)_hi c_2 p_2
113
114                place 14
115                // _ p_2 p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 (l_2·r_0)_hi (l_1·r_1)_hi (l_0·r_2)_hi c_2
116
117                /* p_3 */
118                add add add
119                // _ p_2 p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 p_3_wip
120
121                dup 4 pick 6 mul split
122                // _ p_2 p_1 p_0 r_3 r_2 r_1 l_3 l_2 l_1 l_0 p_3_wip (l_3·r_0)_hi (l_3·r_0)_lo
123
124                dup 5 dup 8 mul split
125                // _ p_2 p_1 p_0 r_3 r_2 r_1 l_3 l_2 l_1 l_0 p_3_wip (l_3·r_0)_hi (l_3·r_0)_lo (l_2·r_1)_hi (l_2·r_1)_lo
126
127                dup 6 dup 11 mul split
128                // _ p_2 p_1 p_0 r_3 r_2 r_1 l_3 l_2 l_1 l_0 p_3_wip (l_3·r_0)_hi (l_3·r_0)_lo (l_2·r_1)_hi (l_2·r_1)_lo (l_1·r_2)_hi (l_1·r_2)_lo
129
130                pick 7 dup 13 mul split
131                // _ p_2 p_1 p_0 r_3 r_2 r_1 l_3 l_2 l_1 p_3_wip (l_3·r_0)_hi (l_3·r_0)_lo (l_2·r_1)_hi (l_2·r_1)_lo (l_1·r_2)_hi (l_1·r_2)_lo (l_0·l_3)_hi (l_0·l_3)_lo
132                //                                       ^^^^^^^              ^^^^^^^^^^^^              ^^^^^^^^^^^^              ^^^^^^^^^^^^              ^^^^^^^^^^^^
133
134                pick 2 pick 4 pick 6 pick 8
135                add add add add
136                split
137                // _ p_2 p_1 p_0 r_3 r_2 r_1 l_3 l_2 l_1 (l_3·r_0)_hi (l_2·r_1)_hi (l_1·r_2)_hi (l_0·l_3)_hi c_3 p_3
138
139                place 14
140                // _ [p; 4] r_3 r_2 r_1 l_3 l_2 l_1 (l_3·r_0)_hi (l_2·r_1)_hi (l_1·r_2)_hi (l_0·l_3)_hi c_3
141
142                /* overflow checks
143                 *
144                 * Carry c_3 and the high limbs still on stack are guaranteed to be smaller than
145                 * 2^32 since they resulted from instruction `split`. The sum of those 5 elements
146                 * cannot “wrap around” `BFieldElement::P`.
147                 */
148                add add add add
149                push 0 eq assert error_id 500
150                // _ [p; 4] r_3 r_2 r_1 l_3 l_2 l_1
151
152                /* l_3·r_1 */
153                dup 2 pick 4 mul
154                push 0 eq assert error_id 501
155                // _ [p; 4] r_3 r_2 l_3 l_2 l_1
156
157                /* l_2·r_2 */
158                dup 1 dup 4 mul
159                push 0 eq assert error_id 502
160                // _ [p; 4] r_3 r_2 l_3 l_2 l_1
161
162                /* l_1·r_3 */
163                dup 4 mul
164                push 0 eq assert error_id 503
165                // _ [p; 4] r_3 r_2 l_3 l_2
166
167                /* l_3·r_2 */
168                dup 1 pick 3 mul
169                push 0 eq assert error_id 504
170                // _ [p; 4] r_3 l_3 l_2
171
172                /* l_2·r_3 */
173                dup 2 mul
174                push 0 eq assert error_id 505
175                // _ [p; 4] r_3 l_3
176
177                /* l_3·r_3 */
178                mul
179                push 0 eq assert error_id 506
180                // _ [p; 4]
181
182                return
183        )
184    }
185
186    fn sign_offs(&self) -> HashMap<Reviewer, SignOffFingerprint> {
187        let mut sign_offs = HashMap::new();
188        sign_offs.insert(Reviewer("ferdinand"), 0xbba006a82c82b12f.into());
189        sign_offs
190    }
191}
192
#[cfg(test)]
mod tests {
    use rand::rngs::StdRng;

    use super::*;
    use crate::test_prelude::*;

    /// Every error ID with which [`SafeMul`] can crash when the product overflows.
    const ALL_OVERFLOW_ERROR_IDS: [i128; 7] = [500, 501, 502, 503, 504, 505, 506];

    impl SafeMul {
        /// Asserts that multiplying `left` by `right` with this snippet crashes the VM with one
        /// of the given `error_ids`.
        fn test_assertion_failure(&self, left: u128, right: u128, error_ids: &[i128]) {
            let shadowed_snippet = ShadowedClosure::new(*self);
            let initial_state = InitVmState::with_stack(self.set_up_test_stack((right, left)));
            test_assertion_failure(&shadowed_snippet, initial_state, error_ids);
        }
    }

    impl Closure for SafeMul {
        type Args = (u128, u128);

        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            let (right, left) = pop_encodable::<Self::Args>(stack);
            // panics on overflow, just like the snippet crashes the VM on overflow
            push_encodable(stack, &left.checked_mul(right).unwrap());
        }

        fn pseudorandom_args(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> Self::Args {
            match bench_case {
                Some(BenchmarkCase::CommonCase) => (1 << 63, (1 << 45) - 1),
                Some(BenchmarkCase::WorstCase) => (1 << 63, (1 << 63) - 1),
                None => {
                    let mut rng = StdRng::from_seed(seed);
                    let left = rng.random_range(1..=u128::MAX);
                    // bounded such that `left · right` never overflows
                    let right = rng.random_range(0..=u128::MAX / left);
                    (right, left)
                }
            }
        }

        fn corner_case_args(&self) -> Vec<Self::Args> {
            const LEFT_NOISE: u128 = 0xfd4e_3f84_8677_df6b_da64_b83c_8267_c72d;
            const RIGHT_NOISE: u128 = 0x538e_e051_c430_3e7a_0a29_a45a_5efb_67fa;

            /// The value whose most significant set bit is `1 << exponent` and whose lower
            /// bits are copied from `noise`.
            fn noisy_power_of_two(exponent: u32, noise: u128) -> u128 {
                let leading_bit = 1_u128 << exponent;
                leading_bit | ((leading_bit - 1) & noise)
            }

            let mut args: Vec<Self::Args> = (0..u128::BITS)
                .cartesian_product(0..u128::BITS)
                .map(|(l, r)| {
                    let left = noisy_power_of_two(l, LEFT_NOISE);
                    let right = noisy_power_of_two(r, RIGHT_NOISE);
                    (right, left)
                })
                .filter(|&(right, left)| left.checked_mul(right).is_some())
                // only every fifth pair: running all of them makes the test far too slow
                .step_by(5)
                .collect();
            args.push((0, 0));
            args
        }
    }

    #[test]
    fn rust_shadow() {
        let shadowed_snippet = ShadowedClosure::new(SafeMul);
        shadowed_snippet.test();
    }

    #[test]
    fn overflow_crashes_vm() {
        // each pair trips exactly one of the overflow checks
        let targeted_overflows: [(u128, u128, i128); 7] = [
            (1 << 127, 1 << 1, 500),
            (1 << 96, 1 << 32, 501),
            (1 << 64, 1 << 64, 502),
            (1 << 32, 1 << 96, 503),
            (1 << 96, 1 << 64, 504),
            (1 << 64, 1 << 96, 505),
            (1 << 96, 1 << 96, 506),
        ];
        for (left, right, error_id) in targeted_overflows {
            SafeMul.test_assertion_failure(left, right, &[error_id]);
        }

        for shift in 1..64_u32 {
            let almost_max = u128::MAX >> shift;
            let just_too_big = (1_u128 << shift) + 1;
            for (left, right) in [(almost_max, just_too_big), (just_too_big, almost_max)] {
                SafeMul.test_assertion_failure(left, right, &[500]);
            }
        }

        // products of exactly 2^128
        for exponent in 1..u128::BITS {
            let left = 1_u128 << exponent;
            let right = 1_u128 << (u128::BITS - exponent);
            SafeMul.test_assertion_failure(left, right, &[500, 501, 502, 503]);
        }
    }

    #[proptest(cases = 80)]
    fn arbitrary_overflow_crashes_vm(
        #[strategy(2_u8..128)] _log_upper_bound: u8,
        #[strategy(2_u128..(1 << #_log_upper_bound))] left: u128,
        #[strategy(u128::MAX / #left + 1..)] right: u128,
    ) {
        SafeMul.test_assertion_failure(left, right, &ALL_OVERFLOW_ERROR_IDS);
    }

    #[proptest(cases = 80)]
    fn marginal_overflow_crashes_vm(
        #[strategy(2_u8..128)] _log_upper_bound: u8,
        #[strategy(2_u128..(1 << #_log_upper_bound))] left: u128,
    ) {
        // the smallest `right` for which `left · right` overflows
        let smallest_overflowing_right = u128::MAX / left + 1;
        SafeMul.test_assertion_failure(left, smallest_overflowing_right, &ALL_OVERFLOW_ERROR_IDS);
    }

    #[proptest]
    fn arbitrary_overflow_crashes_vm_u128(
        #[strategy(2_u128..)] left: u128,
        #[strategy(u128::MAX / #left + 1..)] right: u128,
    ) {
        SafeMul.test_assertion_failure(left, right, &ALL_OVERFLOW_ERROR_IDS);
    }
}
311
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks [`SafeMul`] through its shadowed-closure wrapper.
    #[test]
    fn benchmark() {
        let shadowed_snippet = ShadowedClosure::new(SafeMul);
        shadowed_snippet.bench();
    }
}