Skip to main content

tasm_lib/arithmetic/u64/
safe_mul.rs

1use std::collections::HashMap;
2
3use triton_vm::prelude::*;
4
5use crate::prelude::*;
6use crate::traits::basic_snippet::Reviewer;
7use crate::traits::basic_snippet::SignOffFingerprint;
8
/// Multiply two `u64`s and crash on overflow.
///
/// ### Behavior
///
/// ```text
/// BEFORE: _ [right: u64] [left: u64]
/// AFTER:  _ [right · left: u64]
/// ```
///
/// ### Preconditions
///
/// - all input arguments are properly [`BFieldCodec`] encoded
/// - the product of `left` and `right` is less than or equal to [`u64::MAX`]
///
/// ### Postconditions
///
/// - the output is the product of the inputs
/// - the output is properly [`BFieldCodec`] encoded
///
/// ### Crashes
///
/// Both operands are handled as 32-bit limbs, `x = x_hi · 2^32 + x_lo`. If the
/// product does not fit in a `u64`, the VM crashes with the error ID of the
/// first of the following checks that fails (they are performed in this
/// order):
///
/// - `100`: `left_hi · right_hi` is non-zero
/// - `101`: `left_hi · right_lo` does not fit in 32 bits
/// - `102`: `left_lo · right_hi` does not fit in 32 bits
/// - `103`: the high limb of the product does not fit in 32 bits
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct SafeMul;
29
30impl BasicSnippet for SafeMul {
31    fn parameters(&self) -> Vec<(DataType, String)> {
32        ["rhs", "lhs"]
33            .map(|side| (DataType::U64, side.to_string()))
34            .to_vec()
35    }
36
37    fn return_values(&self) -> Vec<(DataType, String)> {
38        vec![(DataType::U64, "product".to_string())]
39    }
40
41    fn entrypoint(&self) -> String {
42        "tasmlib_arithmetic_u64_safe_mul".to_string()
43    }
44
45    fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
46        triton_asm!(
47            // BEFORE: _ right_hi right_lo left_hi left_lo
48            // AFTER:  _ prod_hi prod_lo
49            {self.entrypoint()}:
50                /* left_lo · right_lo */
51                dup 0
52                dup 3
53                mul
54                // _ right_hi right_lo left_hi left_lo (left_lo · right_lo)
55
56                /* left_lo · right_hi (consume left_lo) */
57                dup 4
58                pick 2
59                mul
60                // _ right_hi right_lo left_hi (left_lo · right_lo) (left_lo · right_hi)
61
62                /* left_hi · right_lo (consume right_lo) */
63                pick 3
64                dup 3
65                mul
66                // _ right_hi left_hi (left_lo · right_lo) (left_lo · right_hi) (left_hi · right_lo)
67
68                /* left_hi · right_hi (consume left_hi and right_hi) */
69                pick 4
70                pick 4
71                mul
72                // _ (left_lo · right_lo) (left_lo · right_hi) (left_hi · right_lo) (left_hi · right_hi)
73
74                /* assert left_hi · right_hi == 0 */
75                push 0
76                eq
77                assert error_id 100
78                // _ (left_lo · right_lo) (left_lo · right_hi) (left_hi · right_lo)
79                // _ lolo                 lohi                 hilo
80
81                /* prod_hi = lolo_hi + lohi_lo + hilo_lo */
82                split
83                pick 1
84                push 0
85                eq
86                assert error_id 101
87                // _ lolo lohi hilo_lo
88
89                pick 1
90                split
91                pick 1
92                push 0
93                eq
94                assert error_id 102
95                // _ lolo hilo_lo lohi_lo
96
97
98                pick 2
99                split
100                // _ hilo_lo lohi_lo lolo_hi lolo_lo
101                // _ hilo_lo lohi_lo lolo_hi prod_lo
102
103                place 3
104                add
105                add
106                // _ prod_lo (hilo_lo + lohi_lo + lolo_hi)
107
108                split
109                pick 1
110                push 0
111                eq
112                assert error_id 103
113                // _ prod_lo (hilo_lo + lohi_lo + lolo_hi)_lo
114                // _ prod_lo prod_hi
115
116                place 1
117                return
118        )
119    }
120
121    fn sign_offs(&self) -> HashMap<Reviewer, SignOffFingerprint> {
122        let mut sign_offs = HashMap::new();
123        sign_offs.insert(Reviewer("ferdinand"), 0xe472ad0d428ed30c.into());
124        sign_offs
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_prelude::*;

    impl Closure for SafeMul {
        // Ordered `(right, left)`, i.e., deepest stack argument first; see
        // the destructuring in `rust_shadow`.
        type Args = (u64, u64);

        /// Reference implementation: pops both operands and pushes their
        /// product, or reports an arithmetic overflow.
        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) -> Result<(), RustShadowError> {
            let (right, left) = pop_encodable::<Self::Args>(stack)?;
            let product = left
                .checked_mul(right)
                .ok_or(RustShadowError::ArithmeticOverflow)?;
            push_encodable(stack, &product);
            Ok(())
        }

        fn pseudorandom_args(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> Self::Args {
            let Some(bench_case) = bench_case else {
                let mut rng = StdRng::from_seed(seed);

                // Vary the magnitude of `left` so both of its limbs get used,
                // then choose `right` such that the product cannot overflow.
                // Unlike `u32`-sized operands, this reaches the cross terms
                // `left_hi · right_lo` and `left_lo · right_hi` as well as the
                // carry into the product's high limb.
                let left = rng.next_u64() >> (rng.next_u32() % 64);
                let max_right = u64::MAX / left.max(1);
                let right = rng.next_u64() % max_right.saturating_add(1);
                return (right, left);
            };

            match bench_case {
                BenchmarkCase::CommonCase => (1 << 31, (1 << 25) - 1),
                BenchmarkCase::WorstCase => (1 << 31, (1 << 31) - 1),
            }
        }
    }

    #[macro_rules_attr::apply(test)]
    fn rust_shadow() {
        ShadowedClosure::new(SafeMul).test();
    }

    #[macro_rules_attr::apply(test)]
    fn overflow_tests() {
        // Entries are `(right, left, expected_error_id)`; the operands are in
        // the same order as `Closure::Args`. Limbs: `x = x_hi · 2^32 + x_lo`.
        let failure_conditions = [
            (1 << 32, 1 << 32, 100),             // (left_hi · right_hi) != 0
            (1 << 31, 1 << 33, 101),             // (left_hi · right_lo)_hi != 0
            (1 << 33, 1 << 31, 102),             // (left_lo · right_hi)_hi != 0
            ((1 << 31) - 1, (1 << 33) + 5, 103), // (hilo_lo + lohi_lo + lolo_hi)_hi != 0
        ];

        let safe_mul = ShadowedClosure::new(SafeMul);
        for (right, left, error_id) in failure_conditions {
            let stack = SafeMul.set_up_test_stack((right, left));
            let vm_state = InitVmState::with_stack(stack);
            test_assertion_failure(&safe_mul, vm_state, &[error_id]);
        }
    }
}
185
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Runs the benchmark cases defined by `pseudorandom_args` for `SafeMul`.
    #[macro_rules_attr::apply(test)]
    fn benchmark() {
        let shadowed_snippet = ShadowedClosure::new(SafeMul);
        shadowed_snippet.bench();
    }
}