// tasm_lib/arithmetic/u128/safe_mul.rs

1use std::collections::HashMap;
2
3use triton_vm::prelude::*;
4
5use crate::prelude::*;
6use crate::traits::basic_snippet::Reviewer;
7use crate::traits::basic_snippet::SignOffFingerprint;
8
/// Multiply two `u128`s and crash on overflow.
///
/// ### Behavior
///
/// ```text
/// BEFORE: _ [right: u128] [left: u128]
/// AFTER:  _ [left · right: u128]
/// ```
///
/// ### Preconditions
///
/// - all input arguments are properly [`BFieldCodec`] encoded
/// - the product of `left` and `right` is less than or equal to [`u128::MAX`];
///   otherwise, the VM crashes (see below)
///
/// ### Postconditions
///
/// - the output is the product of the inputs
/// - the output is properly [`BFieldCodec`] encoded
///
/// ### Crashes
///
/// Let `l_i` and `r_i` denote the `i`-th `u32` limb of `left` and `right`,
/// respectively, where limb 0 is the least significant one. If the product does
/// not fit into a `u128`, the first failing check determines the error ID:
///
/// - `500`: the column for limb 3 produces a carry, or one of the partial
///   products `l_i·r_j` with `i + j = 3` has a non-zero high limb
/// - `501`: `l_3·r_1 ≠ 0`
/// - `502`: `l_2·r_2 ≠ 0`
/// - `503`: `l_1·r_3 ≠ 0`
/// - `504`: `l_3·r_2 ≠ 0`
/// - `505`: `l_2·r_3 ≠ 0`
/// - `506`: `l_3·r_3 ≠ 0`
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash)]
pub struct SafeMul;
29
30impl BasicSnippet for SafeMul {
31    fn inputs(&self) -> Vec<(DataType, String)> {
32        ["right", "left"]
33            .map(|side| (DataType::U128, side.to_string()))
34            .to_vec()
35    }
36
37    fn outputs(&self) -> Vec<(DataType, String)> {
38        vec![(DataType::U128, "product".to_string())]
39    }
40
41    fn entrypoint(&self) -> String {
42        "tasmlib_arithmetic_u128_safe_mul".to_string()
43    }
44
45    fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
46        triton_asm!(
47            // BEFORE: _ r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0
48            // AFTER:  _ p_3 p_2 p_1 p_0
49            {self.entrypoint()}:
50                /*
51                 * p_0 is low limb, c_0 high limb of
52                 *        l_0·r_0
53                 *
54                 * p_1 is low limb, c_1 high limb of
55                 *        (l_1·r_0)_lo + (l_0·r_1)_lo
56                 *      + c_0
57                 *
58                 * p_2 is low limb, c_2 high limb of
59                 *        (l_1·r_0)_hi + (l_0·r_1)_hi
60                 *      + (l_2·r_0)_lo + (l_1·r_1)_lo + (l_0·r_2)_lo
61                 *      + c_1
62                 *
63                 * p_3 is low limb, c_3 high limb of
64                 *        (l_2·r_0)_hi + (l_1·r_1)_hi + (l_0·r_2)_hi
65                 *      + (l_3·r_0)_lo + (l_2·r_1)_lo + (l_1·r_2)_lo + (l_0·r_3)_lo
66                 *      + c_2
67                 *
68                 * All remaining limb combinations (l_3·r_1, l_3·r_2, l_3·r_3 l_2·r_2,
69                 * l_2·r_3, and l_1·r_3) as well as c_3 must be 0.
70                 */
71
72                /* p_0 */
73                dup 0 dup 5 mul split
74                // _ r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 c_0 p_0
75
76                place 9
77                // _ p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 c_0
78
79                /* p_1 */
80                dup 2 dup 6 mul split
81                // _ p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 c_0 (l_1·r_0)_hi (l_1·r_0)_lo
82
83                dup 3 dup 9 mul split
84                // _ p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 c_0 (l_1·r_0)_hi (l_1·r_0)_lo (l_0·r_1)_hi (l_0·r_1)_lo
85                //                                       ^^^              ^^^^^^^^^^^^              ^^^^^^^^^^^^
86
87                pick 2 pick 4
88                add add
89                split
90                // _ p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 (l_1·r_0)_hi (l_0·r_1)_hi c_1 p_1
91
92                place 12
93                // _ p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 (l_1·r_0)_hi (l_0·r_1)_hi c_1
94
95                /* p_2 */
96                add add
97                // _ p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 p_2_wip
98
99                dup 3 dup 6 mul split
100                // _ p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 p_2_wip (l_2·r_0)_hi (l_2·r_0)_lo
101
102                dup 4 dup 9 mul split
103                // _ p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 p_2_wip (l_2·r_0)_hi (l_2·r_0)_lo (l_1·r_1)_hi (l_1·r_1)_lo
104
105                dup 5 dup 12 mul split
106                // _ p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 p_2_wip (l_2·r_0)_hi (l_2·r_0)_lo (l_1·r_1)_hi (l_1·r_1)_lo (l_0·r_2)_hi (l_0·r_2)_lo
107                //                                           ^^^^^^^              ^^^^^^^^^^^^              ^^^^^^^^^^^^              ^^^^^^^^^^^^
108
109                pick 2 pick 4 pick 6
110                add add add
111                split
112                // _ p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 (l_2·r_0)_hi (l_1·r_1)_hi (l_0·r_2)_hi c_2 p_2
113
114                place 14
115                // _ p_2 p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 (l_2·r_0)_hi (l_1·r_1)_hi (l_0·r_2)_hi c_2
116
117                /* p_3 */
118                add add add
119                // _ p_2 p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 p_3_wip
120
121                dup 4 pick 6 mul split
122                // _ p_2 p_1 p_0 r_3 r_2 r_1 l_3 l_2 l_1 l_0 p_3_wip (l_3·r_0)_hi (l_3·r_0)_lo
123
124                dup 5 dup 8 mul split
125                // _ p_2 p_1 p_0 r_3 r_2 r_1 l_3 l_2 l_1 l_0 p_3_wip (l_3·r_0)_hi (l_3·r_0)_lo (l_2·r_1)_hi (l_2·r_1)_lo
126
127                dup 6 dup 11 mul split
128                // _ p_2 p_1 p_0 r_3 r_2 r_1 l_3 l_2 l_1 l_0 p_3_wip (l_3·r_0)_hi (l_3·r_0)_lo (l_2·r_1)_hi (l_2·r_1)_lo (l_1·r_2)_hi (l_1·r_2)_lo
129
130                pick 7 dup 13 mul split
131                // _ p_2 p_1 p_0 r_3 r_2 r_1 l_3 l_2 l_1 p_3_wip (l_3·r_0)_hi (l_3·r_0)_lo (l_2·r_1)_hi (l_2·r_1)_lo (l_1·r_2)_hi (l_1·r_2)_lo (l_0·l_3)_hi (l_0·l_3)_lo
132                //                                       ^^^^^^^              ^^^^^^^^^^^^              ^^^^^^^^^^^^              ^^^^^^^^^^^^              ^^^^^^^^^^^^
133
134                pick 2 pick 4 pick 6 pick 8
135                add add add add
136                split
137                // _ p_2 p_1 p_0 r_3 r_2 r_1 l_3 l_2 l_1 (l_3·r_0)_hi (l_2·r_1)_hi (l_1·r_2)_hi (l_0·l_3)_hi c_3 p_3
138
139                place 14
140                // _ [p; 4] r_3 r_2 r_1 l_3 l_2 l_1 (l_3·r_0)_hi (l_2·r_1)_hi (l_1·r_2)_hi (l_0·l_3)_hi c_3
141
142                /* overflow checks
143                 *
144                 * Carry c_3 and the high limbs still on stack are guaranteed to be smaller than
145                 * 2^32 since they resulted from instruction `split`. The sum of those 5 elements
146                 * cannot “wrap around” `BFieldElement::P`.
147                 */
148                add add add add
149                push 0 eq assert error_id 500
150                // _ [p; 4] r_3 r_2 r_1 l_3 l_2 l_1
151
152                /* l_3·r_1 */
153                dup 2 pick 4 mul
154                push 0 eq assert error_id 501
155                // _ [p; 4] r_3 r_2 l_3 l_2 l_1
156
157                /* l_2·r_2 */
158                dup 1 dup 4 mul
159                push 0 eq assert error_id 502
160                // _ [p; 4] r_3 r_2 l_3 l_2 l_1
161
162                /* l_1·r_3 */
163                dup 4 mul
164                push 0 eq assert error_id 503
165                // _ [p; 4] r_3 r_2 l_3 l_2
166
167                /* l_3·r_2 */
168                dup 1 pick 3 mul
169                push 0 eq assert error_id 504
170                // _ [p; 4] r_3 l_3 l_2
171
172                /* l_2·r_3 */
173                dup 2 mul
174                push 0 eq assert error_id 505
175                // _ [p; 4] r_3 l_3
176
177                /* l_3·r_3 */
178                mul
179                push 0 eq assert error_id 506
180                // _ [p; 4]
181
182                return
183        )
184    }
185
186    fn sign_offs(&self) -> HashMap<Reviewer, SignOffFingerprint> {
187        let mut sign_offs = HashMap::new();
188        sign_offs.insert(Reviewer("ferdinand"), 0x6a6ab0928dd2f0e4.into());
189        sign_offs
190    }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_prelude::*;
    use rand::rngs::StdRng;

    impl SafeMul {
        /// Run the snippet on `left · right` and assert that the VM crashes with
        /// one of the given `error_ids`.
        fn test_assertion_failure(&self, left: u128, right: u128, error_ids: &[i128]) {
            test_assertion_failure(
                &ShadowedClosure::new(Self),
                InitVmState::with_stack(self.set_up_test_stack((right, left))),
                error_ids,
            );
        }
    }

    impl Closure for SafeMul {
        type Args = (u128, u128);

        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            let (right, left) = pop_encodable::<Self::Args>(stack);
            let product = left.checked_mul(right).unwrap();
            push_encodable(stack, &product);
        }

        fn pseudorandom_args(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> Self::Args {
            let Some(bench_case) = bench_case else {
                let mut rng = StdRng::from_seed(seed);
                let left = rng.random_range(1..=u128::MAX);
                let right = rng.random_range(0..=u128::MAX / left);

                return (right, left);
            };

            match bench_case {
                BenchmarkCase::CommonCase => (1 << 63, (1 << 45) - 1),
                BenchmarkCase::WorstCase => (1 << 63, (1 << 63) - 1),
            }
        }

        fn corner_case_args(&self) -> Vec<Self::Args> {
            const LEFT_NOISE: u128 = 0xfd4e_3f84_8677_df6b_da64_b83c_8267_c72d;
            const RIGHT_NOISE: u128 = 0x538e_e051_c430_3e7a_0a29_a45a_5efb_67fa;

            // The noise-based cases below never contain an all-ones operand, never
            // multiply a non-zero value by 0, and never yield exactly u128::MAX.
            // Cover those boundaries explicitly, plus a carry-heavy product.
            let boundary_cases = [
                (0, 0),
                (0, u128::MAX),
                (u128::MAX, 0),
                (1, u128::MAX),
                (u128::MAX, 1),
                (3, u128::MAX / 3),
                (u128::from(u64::MAX), u128::from(u64::MAX)),
            ];

            (0..u128::BITS)
                .cartesian_product(0..u128::BITS)
                .map(|(l, r)| {
                    let left = (1 << l) | (((1 << l) - 1) & LEFT_NOISE);
                    let right = (1 << r) | (((1 << r) - 1) & RIGHT_NOISE);
                    (right, left)
                })
                .filter(|&(right, left)| left.checked_mul(right).is_some())
                .step_by(5) // test performance is atrocious otherwise
                .chain(boundary_cases)
                .collect()
        }
    }

    #[test]
    fn rust_shadow() {
        ShadowedClosure::new(SafeMul).test();
    }

    #[test]
    fn overflow_crashes_vm() {
        SafeMul.test_assertion_failure(1 << 127, 1 << 1, &[500]);
        SafeMul.test_assertion_failure(1 << 96, 1 << 32, &[501]);
        SafeMul.test_assertion_failure(1 << 64, 1 << 64, &[502]);
        SafeMul.test_assertion_failure(1 << 32, 1 << 96, &[503]);
        SafeMul.test_assertion_failure(1 << 96, 1 << 64, &[504]);
        SafeMul.test_assertion_failure(1 << 64, 1 << 96, &[505]);
        SafeMul.test_assertion_failure(1 << 96, 1 << 96, &[506]);

        // overflow purely through carry propagation: all high limbs of the
        // degree-3 partial products are 0, but c_3 is 1
        SafeMul.test_assertion_failure(u128::MAX / 3 + 1, 3, &[500]);
        SafeMul.test_assertion_failure(u128::MAX, 2, &[500]);

        for i in 1..64 {
            let left = u128::MAX >> i;
            let right = (1 << i) + 1;
            SafeMul.test_assertion_failure(left, right, &[500]);
            SafeMul.test_assertion_failure(right, left, &[500]);
        }

        for i in 1..128 {
            let left = 1 << i;
            let right = 1 << (128 - i);
            SafeMul.test_assertion_failure(left, right, &[500, 501, 502, 503]);
        }
    }

    #[proptest(cases = 1_000)]
    fn arbitrary_overflow_crashes_vm(
        #[strategy(2_u128..)] left: u128,
        #[strategy(u128::MAX / #left + 1..)] right: u128,
    ) {
        SafeMul.test_assertion_failure(left, right, &[500, 501, 502, 503, 504, 505, 506]);
    }
}
289        SafeMul.test_assertion_failure(left, right, &[500, 501, 502, 503, 504, 505, 506]);
290    }
291}
292
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks the [`SafeMul`] snippet through its shadowed-closure harness.
    #[test]
    fn benchmark() {
        let snippet = SafeMul::default();
        ShadowedClosure::new(snippet).bench();
    }
}
300        ShadowedClosure::new(SafeMul).bench();
301    }
302}