tasm_lib/arithmetic/u64/safe_mul.rs
1use std::collections::HashMap;
2
3use triton_vm::prelude::*;
4
5use crate::prelude::*;
6use crate::traits::basic_snippet::Reviewer;
7use crate::traits::basic_snippet::SignOffFingerprint;
8
/// Multiply two `u64`s and crash on overflow.
///
/// ### Behavior
///
/// ```text
/// BEFORE: _ [right: u64] [left: u64]
/// AFTER:  _ [right · left: u64]
/// ```
///
/// ### Preconditions
///
/// - all input arguments are properly [`BFieldCodec`] encoded
/// - the product of `left` and `right` is less than or equal to [`u64::MAX`]
///
/// ### Postconditions
///
/// - the output is the product of the inputs
/// - the output is properly [`BFieldCodec`] encoded
///
/// ### Crashes
///
/// If the product exceeds [`u64::MAX`], execution fails on one of the snippet's
/// assertions, with error id 100, 101, 102, or 103 depending on which partial
/// product or carry overflowed.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct SafeMul;
29
30impl BasicSnippet for SafeMul {
31 fn inputs(&self) -> Vec<(DataType, String)> {
32 ["rhs", "lhs"]
33 .map(|side| (DataType::U64, side.to_string()))
34 .to_vec()
35 }
36
37 fn outputs(&self) -> Vec<(DataType, String)> {
38 vec![(DataType::U64, "product".to_string())]
39 }
40
41 fn entrypoint(&self) -> String {
42 "tasmlib_arithmetic_u64_safe_mul".to_string()
43 }
44
45 fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
46 triton_asm!(
47 // BEFORE: _ right_hi right_lo left_hi left_lo
48 // AFTER: _ prod_hi prod_lo
49 {self.entrypoint()}:
50 /* left_lo · right_lo */
51 dup 0
52 dup 3
53 mul
54 // _ right_hi right_lo left_hi left_lo (left_lo · right_lo)
55
56 /* left_lo · right_hi (consume left_lo) */
57 dup 4
58 pick 2
59 mul
60 // _ right_hi right_lo left_hi (left_lo · right_lo) (left_lo · right_hi)
61
62 /* left_hi · right_lo (consume right_lo) */
63 pick 3
64 dup 3
65 mul
66 // _ right_hi left_hi (left_lo · right_lo) (left_lo · right_hi) (left_hi · right_lo)
67
68 /* left_hi · right_hi (consume left_hi and right_hi) */
69 pick 4
70 pick 4
71 mul
72 // _ (left_lo · right_lo) (left_lo · right_hi) (left_hi · right_lo) (left_hi · right_hi)
73
74 /* assert left_hi · right_hi == 0 */
75 push 0
76 eq
77 assert error_id 100
78 // _ (left_lo · right_lo) (left_lo · right_hi) (left_hi · right_lo)
79 // _ lolo lohi hilo
80
81 /* prod_hi = lolo_hi + lohi_lo + hilo_lo */
82 split
83 pick 1
84 push 0
85 eq
86 assert error_id 101
87 // _ lolo lohi hilo_lo
88
89 pick 1
90 split
91 pick 1
92 push 0
93 eq
94 assert error_id 102
95 // _ lolo hilo_lo lohi_lo
96
97
98 pick 2
99 split
100 // _ hilo_lo lohi_lo lolo_hi lolo_lo
101 // _ hilo_lo lohi_lo lolo_hi prod_lo
102
103 place 3
104 add
105 add
106 // _ prod_lo (hilo_lo + lohi_lo + lolo_hi)
107
108 split
109 pick 1
110 push 0
111 eq
112 assert error_id 103
113 // _ prod_lo (hilo_lo + lohi_lo + lolo_hi)_lo
114 // _ prod_lo prod_hi
115
116 place 1
117 return
118 )
119 }
120
121 fn sign_offs(&self) -> HashMap<Reviewer, SignOffFingerprint> {
122 let mut sign_offs = HashMap::new();
123 sign_offs.insert(Reviewer("ferdinand"), 0xaaa2259189834687.into());
124 sign_offs
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_prelude::*;

    impl Closure for SafeMul {
        /// `(right, left)`, matching the stack layout `_ [right] [left]`.
        type Args = (u64, u64);

        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            let (right, left) = pop_encodable::<Self::Args>(stack);
            let (product, is_overflow) = left.overflowing_mul(right);
            assert!(!is_overflow);
            push_encodable(stack, &product);
        }

        fn pseudorandom_args(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> Self::Args {
            let Some(bench_case) = bench_case else {
                // Split the 64 available bits at a random position: `left < 2^k` and
                // `right < 2^(64 - k)`, so the product never overflows. Unlike sampling
                // two `u32`s, this also exercises the high limbs of both operands.
                let mut rng = StdRng::from_seed(seed);
                let left_bits = rng.next_u32() % 65;
                // `checked_shr` yields `None` for a shift of 64, i.e., an empty operand.
                let left = rng.next_u64().checked_shr(64 - left_bits).unwrap_or(0);
                let right = rng.next_u64().checked_shr(left_bits).unwrap_or(0);
                return (right, left);
            };

            match bench_case {
                BenchmarkCase::CommonCase => (1 << 31, (1 << 25) - 1),
                BenchmarkCase::WorstCase => (1 << 31, (1 << 31) - 1),
            }
        }
    }

    #[test]
    fn rust_shadow() {
        ShadowedClosure::new(SafeMul).test();
    }

    #[test]
    fn overflow_tests() {
        // Each case is `(right, left, expected_error_id)`, in the same order as
        // `Closure::Args`. Limb names refer to the snippet's stack layout
        // `_ right_hi right_lo left_hi left_lo`.
        let failure_conditions = [
            (1 << 32, 1 << 32, 100),             // (left_hi · right_hi) != 0
            (1 << 31, 1 << 33, 101),             // (left_hi · right_lo)_hi != 0
            (1 << 33, 1 << 31, 102),             // (left_lo · right_hi)_hi != 0
            ((1 << 31) - 1, (1 << 33) + 5, 103), // (hilo_lo + lohi_lo + lolo_hi)_hi != 0
        ];

        for (right, left, error_id) in failure_conditions {
            let safe_mul = ShadowedClosure::new(SafeMul);
            let stack = SafeMul.set_up_test_stack((right, left));
            let vm_state = InitVmState::with_stack(stack);
            test_assertion_failure(&safe_mul, vm_state, &[error_id]);
        }
    }
}
182
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks [`SafeMul`] through its shadowed-closure wrapper.
    #[test]
    fn benchmark() {
        let shadowed = ShadowedClosure::new(SafeMul);
        shadowed.bench();
    }
}