// tasm_lib/arithmetic/u64/safe_mul.rs

1use std::collections::HashMap;
2
3use triton_vm::prelude::*;
4
5use crate::prelude::*;
6use crate::traits::basic_snippet::Reviewer;
7use crate::traits::basic_snippet::SignOffFingerprint;
8
/// Multiply two `u64`s and crash on overflow.
///
/// ### Behavior
///
/// ```text
/// BEFORE: _ [right: u64] [left: u64]
/// AFTER:  _ [right · left: u64]
/// ```
///
/// ### Preconditions
///
/// - all input arguments are properly [`BFieldCodec`] encoded
/// - the product of `left` and `right` is less than or equal to [`u64::MAX`];
///   if it is not, the VM crashes with one of the error IDs listed below
///
/// ### Postconditions
///
/// - the output is the product of the inputs
/// - the output is properly [`BFieldCodec`] encoded
///
/// ### Crashes
///
/// Writing `x_hi` and `x_lo` for the high and low `u32` limbs of `x`, the
/// following checks are performed in this order, and the first one that fails
/// determines the error ID:
///
/// - `100` if `left_hi · right_hi` is non-zero
/// - `101` if the high limb of `left_hi · right_lo` is non-zero
/// - `102` if the high limb of `left_lo · right_hi` is non-zero
/// - `103` if the sum of the low limbs of both cross products and the high limb
///   of `left_lo · right_lo` does not fit in a `u32`
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct SafeMul;

/// [`BasicSnippet`] implementation for [`SafeMul`].
///
/// Each `u64` occupies two stack elements, a high and a low `u32` limb, with the
/// low limb on top. Writing `B = 2^32`, the full product is
///
/// ```text
/// left · right = (left_hi · right_hi) · B²
///              + (left_hi · right_lo + left_lo · right_hi) · B
///              + left_lo · right_lo
/// ```
///
/// and the program asserts that every contribution landing at or above `B²`
/// vanishes.
impl BasicSnippet for SafeMul {
    /// Two `u64` arguments: `rhs` deeper on the stack, `lhs` on top.
    fn inputs(&self) -> Vec<(DataType, String)> {
        ["rhs", "lhs"]
            .map(|side| (DataType::U64, side.to_string()))
            .to_vec()
    }

    /// One `u64` result named `product`.
    fn outputs(&self) -> Vec<(DataType, String)> {
        vec![(DataType::U64, "product".to_string())]
    }

    /// The label under which the generated code is emitted.
    fn entrypoint(&self) -> String {
        "tasmlib_arithmetic_u64_safe_mul".to_string()
    }

    /// Emits the overflow-checked multiplication.
    ///
    /// Checks, in order (the first failing one determines the error ID):
    /// - `100`: `left_hi · right_hi == 0`
    /// - `101`: the high limb of `left_hi · right_lo` is zero
    /// - `102`: the high limb of `left_lo · right_hi` is zero
    /// - `103`: `(left_hi · right_lo)_lo + (left_lo · right_hi)_lo + (left_lo · right_lo)_hi`
    ///   fits in a `u32`; its low limb then becomes the product's high limb
    ///
    /// NOTE(review): correctness relies on `mul` and `add` not wrapping for the
    /// operands used here (a product of two `u32`s, and a sum of three `u32`s).
    /// Confirm against the Triton VM field modulus.
    fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
        triton_asm!(
            // BEFORE: _ right_hi right_lo left_hi left_lo
            // AFTER:  _ prod_hi prod_lo
            {self.entrypoint()}:
                /* left_lo · right_lo */
                dup 0
                dup 3
                mul
                // _ right_hi right_lo left_hi left_lo (left_lo · right_lo)

                /* left_lo · right_hi (consume left_lo) */
                dup 4
                pick 2
                mul
                // _ right_hi right_lo left_hi (left_lo · right_lo) (left_lo · right_hi)

                /* left_hi · right_lo (consume right_lo) */
                pick 3
                dup 3
                mul
                // _ right_hi left_hi (left_lo · right_lo) (left_lo · right_hi) (left_hi · right_lo)

                /* left_hi · right_hi (consume left_hi and right_hi) */
                pick 4
                pick 4
                mul
                // _ (left_lo · right_lo) (left_lo · right_hi) (left_hi · right_lo) (left_hi · right_hi)

                /* assert left_hi · right_hi == 0 */
                // this term is scaled by 2^64, so any non-zero value overflows
                push 0
                eq
                assert error_id 100
                // _ (left_lo · right_lo) (left_lo · right_hi) (left_hi · right_lo)
                // _ lolo                 lohi                 hilo

                /* prod_hi = lolo_hi + lohi_lo + hilo_lo */
                // split leaves the low limb on top; bring the high limb up and
                // require it to be zero, because it would be scaled by 2^64
                split
                pick 1
                push 0
                eq
                assert error_id 101
                // _ lolo lohi hilo_lo

                // same check for the other cross product
                pick 1
                split
                pick 1
                push 0
                eq
                assert error_id 102
                // _ lolo hilo_lo lohi_lo

                pick 2
                split
                // _ hilo_lo lohi_lo lolo_hi lolo_lo
                // _ hilo_lo lohi_lo lolo_hi prod_lo

                // stash prod_lo beneath the three addends of prod_hi
                place 3
                add
                add
                // _ prod_lo (hilo_lo + lohi_lo + lolo_hi)

                // the sum must fit in a u32; a carry would overflow the u64
                split
                pick 1
                push 0
                eq
                assert error_id 103
                // _ prod_lo (hilo_lo + lohi_lo + lolo_hi)_lo
                // _ prod_lo prod_hi

                // restore u64 limb order: high limb below low limb
                place 1
                return
        )
    }

    /// Reviewer sign-offs for this snippet.
    ///
    /// NOTE(review): the fingerprint presumably commits to the generated code,
    /// so any change to `code` likely requires a fresh sign-off. Confirm.
    fn sign_offs(&self) -> HashMap<Reviewer, SignOffFingerprint> {
        let mut sign_offs = HashMap::new();
        sign_offs.insert(Reviewer("ferdinand"), 0xaaa2259189834687.into());
        sign_offs
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_prelude::*;

    impl Closure for SafeMul {
        /// `(right, left)`: `right` lies deeper on the stack, `left` on top.
        type Args = (u64, u64);

        /// Reference implementation: pops `(right, left)` and pushes their product.
        /// It panics on overflow, the same situations in which the VM program fails
        /// its assertions.
        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            let (right, left) = pop_encodable::<Self::Args>(stack);
            let (product, is_overflow) = left.overflowing_mul(right);
            assert!(!is_overflow);
            push_encodable(stack, &product);
        }

        /// Produces arguments `(right, left)` whose product never overflows.
        ///
        /// How random cases are drawn:
        /// - `left` gets a random bit length, so its high limb is non-zero about
        ///   half of the time.
        /// - `right` is drawn from `0..=u64::MAX / left`.
        ///
        /// This exercises the partial products involving the high limbs, which
        /// `u32`-only operands never reach.
        fn pseudorandom_args(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> Self::Args {
            let Some(bench_case) = bench_case else {
                let mut rng = StdRng::from_seed(seed);
                let left = rng.next_u64() >> (rng.next_u32() % 64);

                // `left.max(1)` avoids division by zero; the bound is at least 1.
                let right_bound = u64::MAX / left.max(1);

                // Saturate: for `left` in {0, 1} the bound is `u64::MAX` already.
                let right = rng.next_u64() % right_bound.saturating_add(1);
                return (right, left);
            };

            match bench_case {
                BenchmarkCase::CommonCase => (1 << 31, (1 << 25) - 1),
                BenchmarkCase::WorstCase => (1 << 31, (1 << 31) - 1),
            }
        }
    }

    /// Compares the snippet against its Rust shadow on pseudorandom inputs.
    #[test]
    fn rust_shadow() {
        ShadowedClosure::new(SafeMul).test();
    }

    /// Every overflow check fires with its own error ID.
    #[test]
    fn overflow_tests() {
        // Tuples are `(right, left, expected_error_id)`, in the same order as
        // `Args`: `right` first (deeper on the stack), then `left` (on top).
        let failure_conditions = [
            (1 << 32, 1 << 32, 100),             // (left_hi · right_hi) != 0
            (1 << 31, 1 << 33, 101),             // (left_hi · right_lo)_hi != 0
            (1 << 33, 1 << 31, 102),             // (left_lo · right_hi)_hi != 0
            ((1 << 31) - 1, (1 << 33) + 5, 103), // (hilo_lo + lohi_lo + lolo_hi)_hi != 0
        ];

        let safe_mul = ShadowedClosure::new(SafeMul);
        for (right, left, error_id) in failure_conditions {
            let stack = SafeMul.set_up_test_stack((right, left));
            let vm_state = InitVmState::with_stack(stack);
            test_assertion_failure(&safe_mul, vm_state, &[error_id]);
        }
    }
}

#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Runs the benchmark for [`SafeMul`] through its shadowed-closure wrapper.
    #[test]
    fn benchmark() {
        let shadowed_safe_mul = ShadowedClosure::new(SafeMul);
        shadowed_safe_mul.bench();
    }
}