// tasm_lib/arithmetic/u128/safe_mul.rs
use std::collections::HashMap;

use triton_vm::prelude::*;

use crate::prelude::*;
use crate::traits::basic_snippet::Reviewer;
use crate::traits::basic_snippet::SignOffFingerprint;
8
/// Multiply two `u128`s and crash on overflow.
///
/// ### Behavior
///
/// ```text
/// BEFORE: _ [right: u128] [left: u128]
/// AFTER:  _ [left · right: u128]
/// ```
///
/// ### Preconditions
///
/// - all input arguments are properly [`BFieldCodec`] encoded
/// - the product of `left` and `right` is less than or equal to [`u128::MAX`]
///
/// ### Postconditions
///
/// - the output is the product of the input
/// - the output is properly [`BFieldCodec`] encoded
///
/// ### Crashes
///
/// If the product does not fit into a `u128`, one of the snippet's assertions
/// fails; the assertion error IDs used are 500 through 506.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash)]
pub struct SafeMul;
29
30impl BasicSnippet for SafeMul {
31 fn parameters(&self) -> Vec<(DataType, String)> {
32 ["right", "left"]
33 .map(|side| (DataType::U128, side.to_string()))
34 .to_vec()
35 }
36
37 fn return_values(&self) -> Vec<(DataType, String)> {
38 vec![(DataType::U128, "product".to_string())]
39 }
40
41 fn entrypoint(&self) -> String {
42 "tasmlib_arithmetic_u128_safe_mul".to_string()
43 }
44
45 fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
46 triton_asm!(
47 // BEFORE: _ r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0
48 // AFTER: _ p_3 p_2 p_1 p_0
49 {self.entrypoint()}:
50 /*
51 * p_0 is low limb, c_0 high limb of
52 * l_0·r_0
53 *
54 * p_1 is low limb, c_1 high limb of
55 * (l_1·r_0)_lo + (l_0·r_1)_lo
56 * + c_0
57 *
58 * p_2 is low limb, c_2 high limb of
59 * (l_1·r_0)_hi + (l_0·r_1)_hi
60 * + (l_2·r_0)_lo + (l_1·r_1)_lo + (l_0·r_2)_lo
61 * + c_1
62 *
63 * p_3 is low limb, c_3 high limb of
64 * (l_2·r_0)_hi + (l_1·r_1)_hi + (l_0·r_2)_hi
65 * + (l_3·r_0)_lo + (l_2·r_1)_lo + (l_1·r_2)_lo + (l_0·r_3)_lo
66 * + c_2
67 *
68 * All remaining limb combinations (l_3·r_1, l_3·r_2, l_3·r_3 l_2·r_2,
69 * l_2·r_3, and l_1·r_3) as well as c_3 must be 0.
70 */
71
72 /* p_0 */
73 dup 0 dup 5 mul split
74 // _ r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 c_0 p_0
75
76 place 9
77 // _ p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 c_0
78
79 /* p_1 */
80 dup 2 dup 6 mul split
81 // _ p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 c_0 (l_1·r_0)_hi (l_1·r_0)_lo
82
83 dup 3 dup 9 mul split
84 // _ p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 c_0 (l_1·r_0)_hi (l_1·r_0)_lo (l_0·r_1)_hi (l_0·r_1)_lo
85 // ^^^ ^^^^^^^^^^^^ ^^^^^^^^^^^^
86
87 pick 2 pick 4
88 add add
89 split
90 // _ p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 (l_1·r_0)_hi (l_0·r_1)_hi c_1 p_1
91
92 place 12
93 // _ p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 (l_1·r_0)_hi (l_0·r_1)_hi c_1
94
95 /* p_2 */
96 add add
97 // _ p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 p_2_wip
98
99 dup 3 dup 6 mul split
100 // _ p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 p_2_wip (l_2·r_0)_hi (l_2·r_0)_lo
101
102 dup 4 dup 9 mul split
103 // _ p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 p_2_wip (l_2·r_0)_hi (l_2·r_0)_lo (l_1·r_1)_hi (l_1·r_1)_lo
104
105 dup 5 dup 12 mul split
106 // _ p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 p_2_wip (l_2·r_0)_hi (l_2·r_0)_lo (l_1·r_1)_hi (l_1·r_1)_lo (l_0·r_2)_hi (l_0·r_2)_lo
107 // ^^^^^^^ ^^^^^^^^^^^^ ^^^^^^^^^^^^ ^^^^^^^^^^^^
108
109 pick 2 pick 4 pick 6
110 add add add
111 split
112 // _ p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 (l_2·r_0)_hi (l_1·r_1)_hi (l_0·r_2)_hi c_2 p_2
113
114 place 14
115 // _ p_2 p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 (l_2·r_0)_hi (l_1·r_1)_hi (l_0·r_2)_hi c_2
116
117 /* p_3 */
118 add add add
119 // _ p_2 p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 p_3_wip
120
121 dup 4 pick 6 mul split
122 // _ p_2 p_1 p_0 r_3 r_2 r_1 l_3 l_2 l_1 l_0 p_3_wip (l_3·r_0)_hi (l_3·r_0)_lo
123
124 dup 5 dup 8 mul split
125 // _ p_2 p_1 p_0 r_3 r_2 r_1 l_3 l_2 l_1 l_0 p_3_wip (l_3·r_0)_hi (l_3·r_0)_lo (l_2·r_1)_hi (l_2·r_1)_lo
126
127 dup 6 dup 11 mul split
128 // _ p_2 p_1 p_0 r_3 r_2 r_1 l_3 l_2 l_1 l_0 p_3_wip (l_3·r_0)_hi (l_3·r_0)_lo (l_2·r_1)_hi (l_2·r_1)_lo (l_1·r_2)_hi (l_1·r_2)_lo
129
130 pick 7 dup 13 mul split
131 // _ p_2 p_1 p_0 r_3 r_2 r_1 l_3 l_2 l_1 p_3_wip (l_3·r_0)_hi (l_3·r_0)_lo (l_2·r_1)_hi (l_2·r_1)_lo (l_1·r_2)_hi (l_1·r_2)_lo (l_0·l_3)_hi (l_0·l_3)_lo
132 // ^^^^^^^ ^^^^^^^^^^^^ ^^^^^^^^^^^^ ^^^^^^^^^^^^ ^^^^^^^^^^^^
133
134 pick 2 pick 4 pick 6 pick 8
135 add add add add
136 split
137 // _ p_2 p_1 p_0 r_3 r_2 r_1 l_3 l_2 l_1 (l_3·r_0)_hi (l_2·r_1)_hi (l_1·r_2)_hi (l_0·l_3)_hi c_3 p_3
138
139 place 14
140 // _ [p; 4] r_3 r_2 r_1 l_3 l_2 l_1 (l_3·r_0)_hi (l_2·r_1)_hi (l_1·r_2)_hi (l_0·l_3)_hi c_3
141
142 /* overflow checks
143 *
144 * Carry c_3 and the high limbs still on stack are guaranteed to be smaller than
145 * 2^32 since they resulted from instruction `split`. The sum of those 5 elements
146 * cannot “wrap around” `BFieldElement::P`.
147 */
148 add add add add
149 push 0 eq assert error_id 500
150 // _ [p; 4] r_3 r_2 r_1 l_3 l_2 l_1
151
152 /* l_3·r_1 */
153 dup 2 pick 4 mul
154 push 0 eq assert error_id 501
155 // _ [p; 4] r_3 r_2 l_3 l_2 l_1
156
157 /* l_2·r_2 */
158 dup 1 dup 4 mul
159 push 0 eq assert error_id 502
160 // _ [p; 4] r_3 r_2 l_3 l_2 l_1
161
162 /* l_1·r_3 */
163 dup 4 mul
164 push 0 eq assert error_id 503
165 // _ [p; 4] r_3 r_2 l_3 l_2
166
167 /* l_3·r_2 */
168 dup 1 pick 3 mul
169 push 0 eq assert error_id 504
170 // _ [p; 4] r_3 l_3 l_2
171
172 /* l_2·r_3 */
173 dup 2 mul
174 push 0 eq assert error_id 505
175 // _ [p; 4] r_3 l_3
176
177 /* l_3·r_3 */
178 mul
179 push 0 eq assert error_id 506
180 // _ [p; 4]
181
182 return
183 )
184 }
185
186 fn sign_offs(&self) -> HashMap<Reviewer, SignOffFingerprint> {
187 let mut sign_offs = HashMap::new();
188 sign_offs.insert(Reviewer("ferdinand"), 0xbba006a82c82b12f.into());
189 sign_offs
190 }
191}
192
#[cfg(test)]
mod tests {
    use rand::rngs::StdRng;

    use super::*;
    use crate::test_prelude::*;

    impl SafeMul {
        /// Run the snippet on the given factors and expect an assertion failure
        /// with one of the given `error_ids`.
        fn test_assertion_failure(&self, left: u128, right: u128, error_ids: &[i128]) {
            test_assertion_failure(
                &ShadowedClosure::new(Self),
                InitVmState::with_stack(self.set_up_test_stack((right, left))),
                error_ids,
            );
        }
    }

    impl Closure for SafeMul {
        type Args = (u128, u128);

        /// Reference implementation: checked `u128` multiplication that reports
        /// an arithmetic overflow instead of producing a wrapped result.
        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) -> Result<(), RustShadowError> {
            let (right, left) = pop_encodable::<Self::Args>(stack)?;
            let product = left
                .checked_mul(right)
                .ok_or(RustShadowError::ArithmeticOverflow)?;
            push_encodable(stack, &product);
            Ok(())
        }

        /// Produce `(right, left)`. Without a benchmark case, `left` is drawn
        /// from `1..=u128::MAX` and `right` from `0..=u128::MAX / left`, so the
        /// product never overflows.
        fn pseudorandom_args(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> Self::Args {
            let Some(bench_case) = bench_case else {
                let mut rng = StdRng::from_seed(seed);
                let left = rng.random_range(1..=u128::MAX);
                let right = rng.random_range(0..=u128::MAX / left);

                return (right, left);
            };

            match bench_case {
                BenchmarkCase::CommonCase => (1 << 63, (1 << 45) - 1),
                BenchmarkCase::WorstCase => (1 << 63, (1 << 63) - 1),
            }
        }

        /// For every pair of bit positions `(l, r)`, build factors whose highest
        /// set bit is `l` (respectively `r`) with fixed noise in the bits below.
        /// Overflowing pairs are dropped, every fifth remaining pair is kept,
        /// and `(0, 0)` is appended.
        fn corner_case_args(&self) -> Vec<Self::Args> {
            const LEFT_NOISE: u128 = 0xfd4e_3f84_8677_df6b_da64_b83c_8267_c72d;
            const RIGHT_NOISE: u128 = 0x538e_e051_c430_3e7a_0a29_a45a_5efb_67fa;

            (0..u128::BITS)
                .cartesian_product(0..u128::BITS)
                .map(|(l, r)| {
                    let left = (1 << l) | (((1 << l) - 1) & LEFT_NOISE);
                    let right = (1 << r) | (((1 << r) - 1) & RIGHT_NOISE);
                    (right, left)
                })
                .filter(|&(right, left)| left.checked_mul(right).is_some())
                .step_by(5) // test performance is atrocious otherwise
                .chain([(0, 0)])
                .collect()
        }
    }

    // Compare the snippet against its Rust shadow.
    #[macro_rules_attr::apply(test)]
    fn rust_shadow() {
        ShadowedClosure::new(SafeMul).test()
    }

    // Each of the seven overflow assertions is triggered at least once.
    #[macro_rules_attr::apply(test)]
    fn overflow_crashes_vm() {
        // one hand-picked pair of factors per assertion
        SafeMul.test_assertion_failure(1 << 127, 1 << 1, &[500]);
        SafeMul.test_assertion_failure(1 << 96, 1 << 32, &[501]);
        SafeMul.test_assertion_failure(1 << 64, 1 << 64, &[502]);
        SafeMul.test_assertion_failure(1 << 32, 1 << 96, &[503]);
        SafeMul.test_assertion_failure(1 << 96, 1 << 64, &[504]);
        SafeMul.test_assertion_failure(1 << 64, 1 << 96, &[505]);
        SafeMul.test_assertion_failure(1 << 96, 1 << 96, &[506]);

        // (2^(128-i) − 1)·(2^i + 1) = 2^128 + 2^(128-i) − 2^i − 1 overflows for
        // every i < 64; all directly checked limb products are 0 here, so the
        // overflow is caught by the carry / high-limb check.
        for i in 1..64 {
            let left = u128::MAX >> i;
            let right = (1 << i) + 1;
            SafeMul.test_assertion_failure(left, right, &[500]);
            SafeMul.test_assertion_failure(right, left, &[500]);
        }

        // 2^i · 2^(128-i) = 2^128: the smallest overflowing product. If i is a
        // multiple of 32, one of l_3·r_1, l_2·r_2, l_1·r_3 is non-zero (501 to
        // 503); otherwise a high limb feeding the 500 check is non-zero.
        for i in 1..128 {
            let left = 1 << i;
            let right = 1 << (128 - i);
            SafeMul.test_assertion_failure(left, right, &[500, 501, 502, 503]);
        }
    }

    // Overflowing factors where `left` is bounded by a random power of two,
    // which biases the test towards small `left`.
    #[macro_rules_attr::apply(proptest(cases = 80))]
    fn arbitrary_overflow_crashes_vm(
        #[strategy(2_u8..128)] _log_upper_bound: u8,
        #[strategy(2_u128..(1 << #_log_upper_bound))] left: u128,
        #[strategy(u128::MAX / #left + 1..)] right: u128,
    ) {
        SafeMul.test_assertion_failure(left, right, &[500, 501, 502, 503, 504, 505, 506]);
    }

    // `right` is the smallest factor for which `left · right` overflows.
    #[macro_rules_attr::apply(proptest(cases = 80))]
    fn marginal_overflow_crashes_vm(
        #[strategy(2_u8..128)] _log_upper_bound: u8,
        #[strategy(2_u128..(1 << #_log_upper_bound))] left: u128,
    ) {
        let right = u128::MAX / left + 1;
        SafeMul.test_assertion_failure(left, right, &[500, 501, 502, 503, 504, 505, 506]);
    }

    // Overflowing factors with `left` drawn from the whole `u128` range.
    #[macro_rules_attr::apply(proptest)]
    fn arbitrary_overflow_crashes_vm_u128(
        #[strategy(2_u128..)] left: u128,
        #[strategy(u128::MAX / #left + 1..)] right: u128,
    ) {
        SafeMul.test_assertion_failure(left, right, &[500, 501, 502, 503, 504, 505, 506]);
    }
}
314
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    // Run the benchmark harness of `ShadowedClosure` on the snippet.
    #[macro_rules_attr::apply(test)]
    fn benchmark() {
        ShadowedClosure::new(SafeMul).bench();
    }
}