tasm_lib/arithmetic/u128/safe_mul.rs
1use std::collections::HashMap;
2
3use triton_vm::prelude::*;
4
5use crate::prelude::*;
6use crate::traits::basic_snippet::Reviewer;
7use crate::traits::basic_snippet::SignOffFingerprint;
8
/// Overflow-checked multiplication of two `u128`s: the VM crashes if the
/// product does not fit into a `u128`.
///
/// ### Behavior
///
/// ```text
/// BEFORE: _ [right: u128] [left: u128]
/// AFTER:  _ [left · right: u128]
/// ```
///
/// ### Preconditions
///
/// - every input argument is properly [`BFieldCodec`] encoded
/// - `left · right` does not exceed [`u128::MAX`]
///
/// ### Postconditions
///
/// - the output equals the product of the two inputs
/// - the output is properly [`BFieldCodec`] encoded
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct SafeMul;
29
30impl BasicSnippet for SafeMul {
31 fn inputs(&self) -> Vec<(DataType, String)> {
32 ["right", "left"]
33 .map(|side| (DataType::U128, side.to_string()))
34 .to_vec()
35 }
36
37 fn outputs(&self) -> Vec<(DataType, String)> {
38 vec![(DataType::U128, "product".to_string())]
39 }
40
41 fn entrypoint(&self) -> String {
42 "tasmlib_arithmetic_u128_safe_mul".to_string()
43 }
44
45 fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
46 triton_asm!(
47 // BEFORE: _ r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0
48 // AFTER: _ p_3 p_2 p_1 p_0
49 {self.entrypoint()}:
50 /*
51 * p_0 is low limb, c_0 high limb of
52 * l_0·r_0
53 *
54 * p_1 is low limb, c_1 high limb of
55 * (l_1·r_0)_lo + (l_0·r_1)_lo
56 * + c_0
57 *
58 * p_2 is low limb, c_2 high limb of
59 * (l_1·r_0)_hi + (l_0·r_1)_hi
60 * + (l_2·r_0)_lo + (l_1·r_1)_lo + (l_0·r_2)_lo
61 * + c_1
62 *
63 * p_3 is low limb, c_3 high limb of
64 * (l_2·r_0)_hi + (l_1·r_1)_hi + (l_0·r_2)_hi
65 * + (l_3·r_0)_lo + (l_2·r_1)_lo + (l_1·r_2)_lo + (l_0·r_3)_lo
66 * + c_2
67 *
68 * All remaining limb combinations (l_3·r_1, l_3·r_2, l_3·r_3 l_2·r_2,
69 * l_2·r_3, and l_1·r_3) as well as c_3 must be 0.
70 */
71
72 /* p_0 */
73 dup 0 dup 5 mul split
74 // _ r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 c_0 p_0
75
76 place 9
77 // _ p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 c_0
78
79 /* p_1 */
80 dup 2 dup 6 mul split
81 // _ p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 c_0 (l_1·r_0)_hi (l_1·r_0)_lo
82
83 dup 3 dup 9 mul split
84 // _ p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 c_0 (l_1·r_0)_hi (l_1·r_0)_lo (l_0·r_1)_hi (l_0·r_1)_lo
85 // ^^^ ^^^^^^^^^^^^ ^^^^^^^^^^^^
86
87 pick 2 pick 4
88 add add
89 split
90 // _ p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 (l_1·r_0)_hi (l_0·r_1)_hi c_1 p_1
91
92 place 12
93 // _ p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 (l_1·r_0)_hi (l_0·r_1)_hi c_1
94
95 /* p_2 */
96 add add
97 // _ p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 p_2_wip
98
99 dup 3 dup 6 mul split
100 // _ p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 p_2_wip (l_2·r_0)_hi (l_2·r_0)_lo
101
102 dup 4 dup 9 mul split
103 // _ p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 p_2_wip (l_2·r_0)_hi (l_2·r_0)_lo (l_1·r_1)_hi (l_1·r_1)_lo
104
105 dup 5 dup 12 mul split
106 // _ p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 p_2_wip (l_2·r_0)_hi (l_2·r_0)_lo (l_1·r_1)_hi (l_1·r_1)_lo (l_0·r_2)_hi (l_0·r_2)_lo
107 // ^^^^^^^ ^^^^^^^^^^^^ ^^^^^^^^^^^^ ^^^^^^^^^^^^
108
109 pick 2 pick 4 pick 6
110 add add add
111 split
112 // _ p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 (l_2·r_0)_hi (l_1·r_1)_hi (l_0·r_2)_hi c_2 p_2
113
114 place 14
115 // _ p_2 p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 (l_2·r_0)_hi (l_1·r_1)_hi (l_0·r_2)_hi c_2
116
117 /* p_3 */
118 add add add
119 // _ p_2 p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 p_3_wip
120
121 dup 4 pick 6 mul split
122 // _ p_2 p_1 p_0 r_3 r_2 r_1 l_3 l_2 l_1 l_0 p_3_wip (l_3·r_0)_hi (l_3·r_0)_lo
123
124 dup 5 dup 8 mul split
125 // _ p_2 p_1 p_0 r_3 r_2 r_1 l_3 l_2 l_1 l_0 p_3_wip (l_3·r_0)_hi (l_3·r_0)_lo (l_2·r_1)_hi (l_2·r_1)_lo
126
127 dup 6 dup 11 mul split
128 // _ p_2 p_1 p_0 r_3 r_2 r_1 l_3 l_2 l_1 l_0 p_3_wip (l_3·r_0)_hi (l_3·r_0)_lo (l_2·r_1)_hi (l_2·r_1)_lo (l_1·r_2)_hi (l_1·r_2)_lo
129
130 pick 7 dup 13 mul split
131 // _ p_2 p_1 p_0 r_3 r_2 r_1 l_3 l_2 l_1 p_3_wip (l_3·r_0)_hi (l_3·r_0)_lo (l_2·r_1)_hi (l_2·r_1)_lo (l_1·r_2)_hi (l_1·r_2)_lo (l_0·l_3)_hi (l_0·l_3)_lo
132 // ^^^^^^^ ^^^^^^^^^^^^ ^^^^^^^^^^^^ ^^^^^^^^^^^^ ^^^^^^^^^^^^
133
134 pick 2 pick 4 pick 6 pick 8
135 add add add add
136 split
137 // _ p_2 p_1 p_0 r_3 r_2 r_1 l_3 l_2 l_1 (l_3·r_0)_hi (l_2·r_1)_hi (l_1·r_2)_hi (l_0·l_3)_hi c_3 p_3
138
139 place 14
140 // _ [p; 4] r_3 r_2 r_1 l_3 l_2 l_1 (l_3·r_0)_hi (l_2·r_1)_hi (l_1·r_2)_hi (l_0·l_3)_hi c_3
141
142 /* overflow checks
143 *
144 * Carry c_3 and the high limbs still on stack are guaranteed to be smaller than
145 * 2^32 since they resulted from instruction `split`. The sum of those 5 elements
146 * cannot “wrap around” `BFieldElement::P`.
147 */
148 add add add add
149 push 0 eq assert error_id 500
150 // _ [p; 4] r_3 r_2 r_1 l_3 l_2 l_1
151
152 /* l_3·r_1 */
153 dup 2 pick 4 mul
154 push 0 eq assert error_id 501
155 // _ [p; 4] r_3 r_2 l_3 l_2 l_1
156
157 /* l_2·r_2 */
158 dup 1 dup 4 mul
159 push 0 eq assert error_id 502
160 // _ [p; 4] r_3 r_2 l_3 l_2 l_1
161
162 /* l_1·r_3 */
163 dup 4 mul
164 push 0 eq assert error_id 503
165 // _ [p; 4] r_3 r_2 l_3 l_2
166
167 /* l_3·r_2 */
168 dup 1 pick 3 mul
169 push 0 eq assert error_id 504
170 // _ [p; 4] r_3 l_3 l_2
171
172 /* l_2·r_3 */
173 dup 2 mul
174 push 0 eq assert error_id 505
175 // _ [p; 4] r_3 l_3
176
177 /* l_3·r_3 */
178 mul
179 push 0 eq assert error_id 506
180 // _ [p; 4]
181
182 return
183 )
184 }
185
186 fn sign_offs(&self) -> HashMap<Reviewer, SignOffFingerprint> {
187 let mut sign_offs = HashMap::new();
188 sign_offs.insert(Reviewer("ferdinand"), 0x6a6ab0928dd2f0e4.into());
189 sign_offs
190 }
191}
192
#[cfg(test)]
mod tests {
    use rand::rngs::StdRng;

    use super::*;
    use crate::test_prelude::*;

    impl SafeMul {
        /// Sets up the initial stack for `left · right` and hands it, together
        /// with the acceptable `error_ids`, to the free function of the same
        /// name.
        fn test_assertion_failure(&self, left: u128, right: u128, error_ids: &[i128]) {
            let initial_stack = self.set_up_test_stack((right, left));
            let initial_state = InitVmState::with_stack(initial_stack);

            test_assertion_failure(&ShadowedClosure::new(Self), initial_state, error_ids);
        }
    }

    impl Closure for SafeMul {
        type Args = (u128, u128);

        /// Reference implementation: pops both operands, pushes their product,
        /// and panics if the product overflows.
        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            let (rhs, lhs) = pop_encodable::<Self::Args>(stack);
            push_encodable(stack, &lhs.checked_mul(rhs).unwrap());
        }

        /// Fixed operands for the benchmark cases; otherwise a seeded random
        /// `(right, left)` pair whose product fits into a `u128`.
        fn pseudorandom_args(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> Self::Args {
            match bench_case {
                Some(BenchmarkCase::CommonCase) => (1 << 63, (1 << 45) - 1),
                Some(BenchmarkCase::WorstCase) => (1 << 63, (1 << 63) - 1),
                None => {
                    let mut rng = StdRng::from_seed(seed);

                    // `left` is at least 1, so the division below is fine, and
                    // `right` is capped such that `left · right` cannot overflow
                    let left = rng.random_range(1..=u128::MAX);
                    let right = rng.random_range(0..=u128::MAX / left);

                    (right, left)
                }
            }
        }

        /// Operand pairs with every combination of bit lengths, thinned out to
        /// every fifth non-overflowing pair, plus `(0, 0)`.
        fn corner_case_args(&self) -> Vec<Self::Args> {
            const LEFT_NOISE: u128 = 0xfd4e_3f84_8677_df6b_da64_b83c_8267_c72d;
            const RIGHT_NOISE: u128 = 0x538e_e051_c430_3e7a_0a29_a45a_5efb_67fa;

            // bit number `exponent` is set, all lower bits are taken from `noise`
            let noisy_power_of_two = |exponent: u32, noise: u128| {
                let top_bit = 1_u128 << exponent;
                top_bit | ((top_bit - 1) & noise)
            };

            let mut corner_cases: Vec<_> = (0..u128::BITS)
                .cartesian_product(0..u128::BITS)
                .map(|(l, r)| {
                    let left = noisy_power_of_two(l, LEFT_NOISE);
                    let right = noisy_power_of_two(r, RIGHT_NOISE);
                    (right, left)
                })
                .filter(|(right, left)| left.checked_mul(*right).is_some())
                .step_by(5) // test performance is atrocious otherwise
                .collect();
            corner_cases.push((0, 0));

            corner_cases
        }
    }

    #[test]
    fn rust_shadow() {
        ShadowedClosure::new(SafeMul).test();
    }

    #[test]
    fn overflow_crashes_vm() {
        // each pair of operands trips exactly one of the overflow checks
        let single_check_cases: [(u128, u128, i128); 7] = [
            (1 << 127, 1 << 1, 500),
            (1 << 96, 1 << 32, 501),
            (1 << 64, 1 << 64, 502),
            (1 << 32, 1 << 96, 503),
            (1 << 96, 1 << 64, 504),
            (1 << 64, 1 << 96, 505),
            (1 << 96, 1 << 96, 506),
        ];
        for (left, right, error_id) in single_check_cases {
            SafeMul.test_assertion_failure(left, right, &[error_id]);
        }

        // products that exceed `u128::MAX`, in both operand orders
        for shift in 1..64 {
            let large = u128::MAX >> shift;
            let small = (1_u128 << shift) + 1;
            SafeMul.test_assertion_failure(large, small, &[500]);
            SafeMul.test_assertion_failure(small, large, &[500]);
        }

        // products that equal 2^128, i.e., overflow by the smallest margin
        for shift in 1..128 {
            let left = 1_u128 << shift;
            let right = 1_u128 << (128 - shift);
            SafeMul.test_assertion_failure(left, right, &[500, 501, 502, 503]);
        }
    }

    /// Any `right` larger than `u128::MAX / left` makes the product overflow.
    #[proptest(cases = 1_000)]
    fn arbitrary_overflow_crashes_vm(
        #[strategy(2_u128..)] left: u128,
        #[strategy(u128::MAX / #left + 1..)] right: u128,
    ) {
        let any_overflow_check = [500, 501, 502, 503, 504, 505, 506];
        SafeMul.test_assertion_failure(left, right, &any_overflow_check);
    }
}
292
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Runs the benchmark routine of the shadowed [`SafeMul`] snippet.
    #[test]
    fn benchmark() {
        let shadowed_snippet = ShadowedClosure::new(SafeMul);
        shadowed_snippet.bench();
    }
}
302}