tasm_lib/arithmetic/u160/
safe_mul.rs

1use triton_vm::prelude::*;
2
3use crate::arithmetic;
4use crate::arithmetic::u64::mul_two_u64s_to_u128::MulTwoU64sToU128;
5use crate::prelude::*;
6
/// Multiply two `u160`s and crash on overflow.
///
/// ### Behavior
///
/// ```text
/// BEFORE: _ [right: u160] [left: u160]
/// AFTER:  _ [left · right: u160]
/// ```
///
/// ### Preconditions
///
/// - all input arguments are properly [`BFieldCodec`] encoded
/// - the product of `left` and `right` is less than or equal to
///   0xff..ff = $2^{160} - 1$; if it is not, execution crashes on a failing
///   assertion instead of producing a truncated result
///
/// ### Postconditions
///
/// - the output is the product of the input
/// - the output is properly [`BFieldCodec`] encoded
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash)]
pub struct SafeMul;
28
29impl SafeMul {
30    pub(crate) const OVERFLOW_0: i128 = 580;
31    pub(crate) const OVERFLOW_1: i128 = 581;
32    pub(crate) const OVERFLOW_2: i128 = 582;
33    pub(crate) const OVERFLOW_3: i128 = 583;
34    pub(crate) const OVERFLOW_4: i128 = 584;
35}
36
37impl BasicSnippet for SafeMul {
38    fn parameters(&self) -> Vec<(DataType, String)> {
39        ["right", "left"]
40            .map(|side| (DataType::U160, side.to_string()))
41            .to_vec()
42    }
43
44    fn return_values(&self) -> Vec<(DataType, String)> {
45        vec![(DataType::U160, "product".to_string())]
46    }
47
48    fn entrypoint(&self) -> String {
49        "tasmlib_arithmetic_u160_safe_mul".to_string()
50    }
51
52    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
53        let u64_to_u128_mul = library.import(Box::new(MulTwoU64sToU128));
54        let u64_safe_mul = library.import(Box::new(arithmetic::u64::safe_mul::SafeMul));
55        let u64_safe_add = library.import(Box::new(arithmetic::u64::add::Add));
56        let u128_safe_add = library.import(Box::new(arithmetic::u128::safe_add::SafeAdd));
57        let u160_safe_add = library.import(Box::new(arithmetic::u160::safe_add::SafeAdd));
58
59        triton_asm!(
60            // BEFORE: _ r_4 r_3 r_2 r_1 r_0 l_4 l_3 l_2 l_1 l_0
61            // AFTER:  _ p_4 p_3 p_2 p_1 p_0
62            {self.entrypoint()}:
63
64
65
66                /* Cast higest limbs, r_4 & l_4, to u64 */
67                push 0
68                place 10
69                push 0
70                place 5
71                // _ 0 r_4 r_3 r_2 r_1 r_0 0 l_4 l_3 l_2 l_1 l_0
72
73                /* Reinterpret as limbs of u64s */
74                // _ [r_c] [r_b] [r_a] [l_c] [l_b] [l_a]
75
76                /* Verify required high limbs are zero */
77                /*
78                   r_b * l_c == 0 ∧ r_c * l_b == 0 ∧ r_c * l_c == 0 =>
79                   (r_b == 0 || l_c == 0) &&
80                   (r_c == 0 || l_b == 0) &&
81                   (r_c == 0 || l_c == 0)
82                */
83
84                dup 9
85                push 0
86                eq
87                dup 9
88                push 0
89                eq
90                mul
91                // _ ... (r_b == 0)
92
93                dup 5
94                push 0
95                eq
96                // _ ... (r_b == 0) (l_c == 0)
97
98                dup 12
99                push 0
100                eq
101                // _ ... (r_b == 0) (l_c == 0) (r_c == 0)
102
103                dup 6
104                push 0
105                eq
106                dup 6
107                push 0
108                eq
109                mul
110                // _ ... (r_b == 0) (l_c == 0) (r_c == 0) (l_b == 0)
111
112                dup 2
113                dup 2
114                add
115                pop_count
116                // _ ... (r_b == 0) (l_c == 0) (r_c == 0) (l_b == 0) (r_c == 0 || l_c == 0)
117
118                assert error_id {Self::OVERFLOW_0}
119                // _ ... (r_b == 0) (l_c == 0) (r_c == 0) (l_b == 0)
120
121                add
122                pop_count
123                // _ ... (r_b == 0) (l_c == 0) (r_c == 0 || l_b == 0)
124
125                assert error_id {Self::OVERFLOW_1}
126                // _ ... (r_b == 0) (l_c == 0)
127
128                add
129                pop_count
130                // _ ... (r_b == 0 || l_c == 0)
131
132                assert error_id {Self::OVERFLOW_2}
133                // _ ...
134
135                // _ [r_c] [r_b] [r_a] [l_c] [l_b] [l_a]
136
137
138                /* Calculate 2^128 products */
139                /* r_c*l_a + r_b*l_b + r_a*l_c */
140                pick 11
141                pick 11
142                dup 3
143                dup 3
144                call {u64_safe_mul}
145                // _ [r_b] [r_a] [l_c] [l_b] [l_a] [r_c * l_a]
146
147                dup 11
148                dup 11
149                dup 7
150                dup 7
151                call {u64_safe_mul}
152                // _ [r_b] [r_a] [l_c] [l_b] [l_a] [r_c * l_a] [r_b * l_b]
153
154                dup 11
155                dup 11
156                pick 11
157                pick 11
158                call {u64_safe_mul}
159                // _ [r_b] [r_a] [l_b] [l_a] [r_c * l_a] [r_b * l_b] [r_a * l_c]
160
161
162                call {u64_safe_add}
163                call {u64_safe_add}
164                // _ [r_b] [r_a] [l_b] [l_a] [r_c * l_a + r_b * l_b + r_a * l_c]
165
166
167                /* Verify bound by u32::MAX as term is multiples of 2^{128} */
168                pick 1
169                push 0
170                eq
171                assert error_id {Self::OVERFLOW_3}
172                // _ [r_b] [r_a] [l_b] [l_a] (r_c * l_a + r_b * l_b + r_a * l_c: u32)
173                // _ [r_b] [r_a] [l_b] [l_a] (fact2: u32) <-rename
174
175                push 0
176                push 0
177                push 0
178                push 0
179                // _ [r_b] [r_a] [l_b] [l_a] [term2: u160]
180
181
182                /* Calculate 2^64 products */
183                pick 12
184                pick 12
185                dup 8
186                dup 8
187                call {u64_to_u128_mul}
188                // _ [r_a] [l_b] [l_a] [term2: u160] [r_b * l_a: u128]
189
190                dup 14
191                dup 14
192                pick 14
193                pick 14
194                call {u64_to_u128_mul}
195                // _ [r_a] [l_a] [term2: u160] [r_b * l_a: u128] [r_a * l_b: u128]
196
197                call {u128_safe_add}
198                // _ [r_a] [l_a] [term2: u160] [r_b * l_a + r_a * l_b: u128]
199                // _ [r_a] [l_a] [term2: u160] [fact1: u128] <- rename
200
201                /* Ensure fact1 bounded by 2^{96} */
202                pick 3
203                push 0
204                eq
205                assert error_id {Self::OVERFLOW_4}
206                // _ [r_a] [l_a] [term2: u160] [fact1: u96]
207
208                push 0
209                push 0
210                // _ [r_a] [l_a] [term2: u160] [term1: u160]
211
212                push 0
213                // _ [r_a] [l_a] [term2: u160] [term1: u160] 0
214
215                pick 14
216                pick 14
217                pick 14
218                pick 14
219                call {u64_to_u128_mul}
220                // _ [term2: u160] [term1: u160] 0 [r_a * l_a: u128]
221                // _ [term2: u160] [term1: u160] [r_a * l_a: u160]
222                // _ [term2: u160] [term1: u160] [term0: u160] <- rename
223
224                call {u160_safe_add}
225                call {u160_safe_add}
226                // _ [term2 + term1 + term0: u160]
227
228                // _ [prod: u160]
229
230                return
231        )
232    }
233}
234
#[cfg(test)]
mod tests {
    use num::BigUint;
    use num::One;
    use rand::rngs::StdRng;

    use super::*;
    use crate::arithmetic::u160::u128_to_u160;
    use crate::arithmetic::u160::u128_to_u160_shl_32;
    use crate::arithmetic::u160::u128_to_u160_shl_32_lower_limb_filled;
    use crate::test_prelude::*;

    /// Every error ID an overflowing multiplication is allowed to crash with:
    /// the snippet's own IDs plus IDs 100–103 and 570.
    ///
    /// NOTE(review): 100–103 and 570 are presumably raised by the imported
    /// sub-snippets (`u64` safe multiplication, `u160` safe addition) —
    /// confirm against their definitions.
    const OVERFLOW_ERROR_IDS: [i128; 10] = [
        SafeMul::OVERFLOW_0,
        SafeMul::OVERFLOW_1,
        SafeMul::OVERFLOW_2,
        SafeMul::OVERFLOW_3,
        SafeMul::OVERFLOW_4,
        100,
        101,
        102,
        103,
        570,
    ];

    impl SafeMul {
        /// Run the snippet on `left · right` and require a crash with one of
        /// the given error IDs.
        fn test_assertion_failure(&self, left: [u32; 5], right: [u32; 5], error_ids: &[i128]) {
            test_assertion_failure(
                &ShadowedClosure::new(Self),
                InitVmState::with_stack(self.set_up_test_stack((right, left))),
                error_ids,
            );
        }
    }

    #[test]
    fn rust_shadow() {
        ShadowedClosure::new(SafeMul).test()
    }

    /// Hand-picked operands, each pinned to the exact assertion it trips.
    #[test]
    fn overflow_unit_test() {
        SafeMul.test_assertion_failure(
            u128_to_u160_shl_32(u128::MAX),
            u128_to_u160_shl_32(u128::MAX),
            &[SafeMul::OVERFLOW_0],
        );
        SafeMul.test_assertion_failure(
            u128_to_u160_shl_32(1u128 << 64),
            u128_to_u160_shl_32(u128::MAX),
            &[SafeMul::OVERFLOW_1],
        );
        SafeMul.test_assertion_failure(
            u128_to_u160_shl_32(u128::MAX),
            u128_to_u160_shl_32(1u128 << 64),
            &[SafeMul::OVERFLOW_2],
        );
        SafeMul.test_assertion_failure(
            u128_to_u160(1u128 << 64),
            u128_to_u160(1u128 << 96),
            &[SafeMul::OVERFLOW_3],
        );
        SafeMul.test_assertion_failure(
            u128_to_u160(1u128 << 96),
            u128_to_u160(1u128 << 64),
            &[SafeMul::OVERFLOW_3],
        );
        SafeMul.test_assertion_failure(
            u128_to_u160((1u128 << 64) - 1),
            u128_to_u160(1u128 << 99),
            &[SafeMul::OVERFLOW_4],
        );
        SafeMul.test_assertion_failure(
            u128_to_u160(1u128 << 99),
            u128_to_u160((1u128 << 64) - 1),
            &[SafeMul::OVERFLOW_4],
        );
        SafeMul.test_assertion_failure(
            u128_to_u160(2),
            u128_to_u160_shl_32(1 << 127),
            &[SafeMul::OVERFLOW_3],
        );
        SafeMul.test_assertion_failure(
            u128_to_u160_shl_32(1 << 127),
            u128_to_u160(2),
            &[SafeMul::OVERFLOW_3],
        );
    }

    #[proptest(cases = 100)]
    fn arbitrary_overflow_crashes_vm_u128(
        #[strategy(2_u128..)] left: u128,
        #[strategy(u128::MAX / #left + 1..)] right: u128,
    ) {
        let left = u128_to_u160_shl_32(left);
        let right = u128_to_u160(right);

        // The full list is needed here, too: e.g., left = 2^127, right = 2^33
        // passes all of the snippet's own zero-limb checks and only crashes
        // inside the `u64` multiplication of `r_a` and `l_c`.
        SafeMul.test_assertion_failure(left, right, &OVERFLOW_ERROR_IDS);
    }

    #[proptest(cases = 50)]
    fn marginal_overflow_crashes_vm(
        #[strategy(2_u8..128)] _log_upper_bound: u8,
        #[strategy(2_u128..(1 << #_log_upper_bound))] left: u128,
    ) {
        // smallest `right` for which `left · right` exceeds `u128::MAX`
        let right = u128::MAX / left + 1;

        SafeMul.test_assertion_failure(
            u128_to_u160_shl_32(left),
            u128_to_u160(right),
            &OVERFLOW_ERROR_IDS,
        );
        SafeMul.test_assertion_failure(
            u128_to_u160(left),
            u128_to_u160_shl_32(right),
            &OVERFLOW_ERROR_IDS,
        );
    }

    #[proptest(cases = 50)]
    fn arbitrary_overflow_crashes_vm(
        #[strategy(2_u8..128)] _log_upper_bound: u8,
        #[strategy(2_u128..(1 << #_log_upper_bound))] left: u128,
        #[strategy(u128::MAX / #left + 1..)] right: u128,
    ) {
        SafeMul.test_assertion_failure(
            u128_to_u160_shl_32(left),
            u128_to_u160(right),
            &OVERFLOW_ERROR_IDS,
        );
        SafeMul.test_assertion_failure(
            u128_to_u160(left),
            u128_to_u160_shl_32(right),
            &OVERFLOW_ERROR_IDS,
        );
        SafeMul.test_assertion_failure(
            u128_to_u160_shl_32_lower_limb_filled(left),
            u128_to_u160(right),
            &OVERFLOW_ERROR_IDS,
        );
        SafeMul.test_assertion_failure(
            u128_to_u160(left),
            u128_to_u160_shl_32_lower_limb_filled(right),
            &OVERFLOW_ERROR_IDS,
        );

        // Much overflow: both operands shifted, with and without set low limbs
        SafeMul.test_assertion_failure(
            u128_to_u160_shl_32(left),
            u128_to_u160_shl_32(right),
            &OVERFLOW_ERROR_IDS,
        );
        SafeMul.test_assertion_failure(
            u128_to_u160_shl_32_lower_limb_filled(left),
            u128_to_u160_shl_32_lower_limb_filled(right),
            &OVERFLOW_ERROR_IDS,
        );
    }

    impl Closure for SafeMul {
        type Args = ([u32; 5], [u32; 5]);

        /// Reference implementation: arbitrary-precision multiplication that
        /// panics if the product needs more than five `u32` limbs.
        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            let left: [u32; 5] = pop_encodable(stack);
            let left = BigUint::new(left.to_vec());
            let right: [u32; 5] = pop_encodable(stack);
            let right = BigUint::new(right.to_vec());

            let mut prod = (&left * &right).to_u32_digits();
            assert!(prod.len() <= 5, "Overflow: left: {left}, right: {right}.");

            prod.resize(5, 0);
            let prod: [u32; 5] = prod.try_into().unwrap();

            push_encodable(stack, &prod);
        }

        /// A uniformly random `lhs` and, via rejection sampling, an `rhs`
        /// small enough that the product fits into a `u160`.
        fn pseudorandom_args(&self, seed: [u8; 32], _: Option<BenchmarkCase>) -> Self::Args {
            let mut rng = StdRng::from_seed(seed);
            let lhs: [u32; 5] = rng.random();
            let lhs_as_biguint = BigUint::new(lhs.to_vec());

            let u160_max = BigUint::from_bytes_be(&[0xFF; 20]);
            let max = &u160_max / &lhs_as_biguint;

            // Mask with as many bits as `max` has, left-padded with zero
            // bytes to the 20 big-endian bytes of a `u160`. Masking the
            // random bytes keeps the number of rejected candidates low.
            let bits: u32 = max.bits().try_into().unwrap();
            let bit_mask = BigUint::from(2u32).pow(bits) - BigUint::one();
            let mut bit_mask = bit_mask.to_bytes_be();
            bit_mask.reverse();
            bit_mask.resize(20, 0);
            bit_mask.reverse();

            let mut rhs_bytes = [0u8; 20];
            let rhs = loop {
                rng.fill(&mut rhs_bytes);
                for (byte, mask) in rhs_bytes.iter_mut().zip(&bit_mask) {
                    *byte &= mask;
                }
                let candidate = BigUint::from_bytes_be(&rhs_bytes);
                if candidate < max {
                    break candidate;
                }
            };

            // sanity check: the generated pair must not overflow
            let prod = &lhs_as_biguint * &rhs;
            assert!(prod.to_u32_digits().len() <= 5);

            let mut rhs = rhs.to_u32_digits();
            rhs.resize(5, 0);

            (lhs, rhs.try_into().unwrap())
        }

        /// All pairs of interesting boundary values whose product fits into
        /// a `u160`.
        fn corner_case_args(&self) -> Vec<Self::Args> {
            fn u160_checked_mul(l: [u32; 5], r: [u32; 5]) -> Option<[u32; 5]> {
                let l = BigUint::new(l.to_vec());
                let r = BigUint::new(r.to_vec());

                let mut prod = (l * r).to_u32_digits();
                if prod.len() > 5 {
                    return None;
                }

                prod.resize(5, 0);
                Some(prod.try_into().unwrap())
            }

            let edge_case_points = vec![
                u128_to_u160(0),
                u128_to_u160(1),
                u128_to_u160(2),
                u128_to_u160(u8::MAX as u128),
                u128_to_u160(1 << 8),
                u128_to_u160(u16::MAX as u128),
                u128_to_u160(1 << 16),
                u128_to_u160(u32::MAX as u128),
                u128_to_u160(1 << 32),
                u128_to_u160(u64::MAX as u128),
                u128_to_u160(1 << 64),
                [u32::MAX, u32::MAX, u32::MAX, 0, 0],
                u128_to_u160(1 << 96),
                u128_to_u160(u128::MAX),
                [u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX >> 1],
                [u32::MAX; 5],
            ];

            edge_case_points
                .iter()
                .cartesian_product(&edge_case_points)
                .filter(|&(&l, &r)| u160_checked_mul(l, r).is_some())
                .map(|(&l, &r)| (l, r))
                .collect()
        }
    }
}
473
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmark the snippet through its `ShadowedClosure` wrapper.
    #[test]
    fn benchmark() {
        ShadowedClosure::new(SafeMul).bench()
    }
}