tasm_lib/arithmetic/u160/
safe_mul.rs1use triton_vm::prelude::*;
2
3use crate::arithmetic;
4use crate::arithmetic::u64::mul_two_u64s_to_u128::MulTwoU64sToU128;
5use crate::prelude::*;
6
/// Multiply two `U160`s, crashing the VM if the product does not fit in 160 bits.
///
/// Internally, both operands are zero-extended to 192 bits and split into three
/// 64-bit words: `hi` (the most significant 32-bit limb, zero-extended), `mid`, and
/// `lo`, of weights 2^128, 2^64, and 1. The nine word products are then checked by
/// weight:
///
/// - products of weight 2^256 or 2^192 must vanish (see [`Self::OVERFLOW_0`],
///   [`Self::OVERFLOW_1`], and [`Self::OVERFLOW_2`]);
/// - products of weight 2^128 must sum to less than 2^32 ([`Self::OVERFLOW_3`]);
/// - products of weight 2^64 must sum to less than 2^96 ([`Self::OVERFLOW_4`]);
/// - the shifted partial sums and the weight-1 product are finally added with an
///   overflow-checked u160 addition.
///
/// Overflow may therefore also surface as an error raised inside one of the
/// imported sub-snippets (checked u64 multiplication, u64/u128/u160 addition).
///
/// ### Behavior
///
/// ```text
/// BEFORE: _ [right: u160] [left: u160]
/// AFTER:  _ [left · right: u160]
/// ```
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct SafeMul;

impl SafeMul {
    /// The most significant limbs (limb 4) of both operands are non-zero, so the
    /// product of the two `hi` words (weight 2^256) is non-zero.
    pub(crate) const OVERFLOW_0: i128 = 580;

    /// Limb 4 of `right` is non-zero and so is at least one of limbs 2 and 3 of
    /// `left`, so `right.hi · left.mid` (weight 2^192) is non-zero.
    pub(crate) const OVERFLOW_1: i128 = 581;

    /// Limb 4 of `left` is non-zero and so is at least one of limbs 2 and 3 of
    /// `right`, so `left.hi · right.mid` (weight 2^192) is non-zero.
    pub(crate) const OVERFLOW_2: i128 = 582;

    /// The word products of weight 2^128 (`hi · lo`, `mid · mid`, `lo · hi`) sum to
    /// 2^32 or more.
    pub(crate) const OVERFLOW_3: i128 = 583;

    /// The word products of weight 2^64 (`mid · lo`, `lo · mid`) sum to 2^96 or more.
    pub(crate) const OVERFLOW_4: i128 = 584;
}
36
37impl BasicSnippet for SafeMul {
38 fn parameters(&self) -> Vec<(DataType, String)> {
39 ["right", "left"]
40 .map(|side| (DataType::U160, side.to_string()))
41 .to_vec()
42 }
43
44 fn return_values(&self) -> Vec<(DataType, String)> {
45 vec![(DataType::U160, "product".to_string())]
46 }
47
48 fn entrypoint(&self) -> String {
49 "tasmlib_arithmetic_u160_safe_mul".to_string()
50 }
51
52 fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
53 let u64_to_u128_mul = library.import(Box::new(MulTwoU64sToU128));
54 let u64_safe_mul = library.import(Box::new(arithmetic::u64::safe_mul::SafeMul));
55 let u64_safe_add = library.import(Box::new(arithmetic::u64::add::Add));
56 let u128_safe_add = library.import(Box::new(arithmetic::u128::safe_add::SafeAdd));
57 let u160_safe_add = library.import(Box::new(arithmetic::u160::safe_add::SafeAdd));
58
59 triton_asm!(
60 {self.entrypoint()}:
63
64
65
66 push 0
68 place 10
69 push 0
70 place 5
71 dup 9
85 push 0
86 eq
87 dup 9
88 push 0
89 eq
90 mul
91 dup 5
94 push 0
95 eq
96 dup 12
99 push 0
100 eq
101 dup 6
104 push 0
105 eq
106 dup 6
107 push 0
108 eq
109 mul
110 dup 2
113 dup 2
114 add
115 pop_count
116 assert error_id {Self::OVERFLOW_0}
119 add
122 pop_count
123 assert error_id {Self::OVERFLOW_1}
126 add
129 pop_count
130 assert error_id {Self::OVERFLOW_2}
133 pick 11
141 pick 11
142 dup 3
143 dup 3
144 call {u64_safe_mul}
145 dup 11
148 dup 11
149 dup 7
150 dup 7
151 call {u64_safe_mul}
152 dup 11
155 dup 11
156 pick 11
157 pick 11
158 call {u64_safe_mul}
159 call {u64_safe_add}
163 call {u64_safe_add}
164 pick 1
169 push 0
170 eq
171 assert error_id {Self::OVERFLOW_3}
172 push 0
176 push 0
177 push 0
178 push 0
179 pick 12
184 pick 12
185 dup 8
186 dup 8
187 call {u64_to_u128_mul}
188 dup 14
191 dup 14
192 pick 14
193 pick 14
194 call {u64_to_u128_mul}
195 call {u128_safe_add}
198 pick 3
203 push 0
204 eq
205 assert error_id {Self::OVERFLOW_4}
206 push 0
209 push 0
210 push 0
213 pick 14
216 pick 14
217 pick 14
218 pick 14
219 call {u64_to_u128_mul}
220 call {u160_safe_add}
225 call {u160_safe_add}
226 return
231 )
232 }
233}
234
#[cfg(test)]
mod tests {
    use num::BigUint;
    use num::One;
    use rand::rngs::StdRng;

    use super::*;
    use crate::arithmetic::u160::u128_to_u160;
    use crate::arithmetic::u160::u128_to_u160_shl_32;
    use crate::arithmetic::u160::u128_to_u160_shl_32_lower_limb_filled;
    use crate::test_prelude::*;

    /// Every error id an overflowing multiplication may legitimately crash with.
    ///
    /// Besides this snippet's own checks, overflow can also surface inside the
    /// imported sub-snippets; e.g., left = 3·2^32 and right = u128::MAX / 3 + 1 pass
    /// all of `SafeMul`'s own checks and only overflow in the final u160 addition.
    /// Ids 100–103 and 570 are presumably those sub-snippets' overflow ids (the
    /// additions and the checked u64 multiplication) — TODO confirm.
    const OVERFLOW_ERROR_IDS: [i128; 10] = [
        SafeMul::OVERFLOW_0,
        SafeMul::OVERFLOW_1,
        SafeMul::OVERFLOW_2,
        SafeMul::OVERFLOW_3,
        SafeMul::OVERFLOW_4,
        100,
        101,
        102,
        103,
        570,
    ];

    impl SafeMul {
        /// Run the snippet on `left` and `right` and assert that the VM crashes with
        /// one of `error_ids`.
        fn test_assertion_failure(&self, left: [u32; 5], right: [u32; 5], error_ids: &[i128]) {
            test_assertion_failure(
                &ShadowedClosure::new(Self),
                InitVmState::with_stack(self.set_up_test_stack((right, left))),
                error_ids,
            );
        }
    }

    #[test]
    fn rust_shadow() {
        ShadowedClosure::new(SafeMul).test()
    }

    #[test]
    fn overflow_unit_test() {
        SafeMul.test_assertion_failure(
            u128_to_u160_shl_32(u128::MAX),
            u128_to_u160_shl_32(u128::MAX),
            &[SafeMul::OVERFLOW_0],
        );
        SafeMul.test_assertion_failure(
            u128_to_u160_shl_32(1u128 << 64),
            u128_to_u160_shl_32(u128::MAX),
            &[SafeMul::OVERFLOW_1],
        );
        SafeMul.test_assertion_failure(
            u128_to_u160_shl_32(u128::MAX),
            u128_to_u160_shl_32(1u128 << 64),
            &[SafeMul::OVERFLOW_2],
        );
        SafeMul.test_assertion_failure(
            u128_to_u160(1u128 << 64),
            u128_to_u160(1u128 << 96),
            &[SafeMul::OVERFLOW_3],
        );
        SafeMul.test_assertion_failure(
            u128_to_u160(1u128 << 96),
            u128_to_u160(1u128 << 64),
            &[SafeMul::OVERFLOW_3],
        );
        SafeMul.test_assertion_failure(
            u128_to_u160((1u128 << 64) - 1),
            u128_to_u160(1u128 << 99),
            &[SafeMul::OVERFLOW_4],
        );
        SafeMul.test_assertion_failure(
            u128_to_u160(1u128 << 99),
            u128_to_u160((1u128 << 64) - 1),
            &[SafeMul::OVERFLOW_4],
        );
        SafeMul.test_assertion_failure(
            u128_to_u160(2),
            u128_to_u160_shl_32(1 << 127),
            &[SafeMul::OVERFLOW_3],
        );
        SafeMul.test_assertion_failure(
            u128_to_u160_shl_32(1 << 127),
            u128_to_u160(2),
            &[SafeMul::OVERFLOW_3],
        );
    }

    #[proptest(cases = 100)]
    fn arbitrary_overflow_crashes_vm_u128(
        #[strategy(2_u128..)] left: u128,
        #[strategy(u128::MAX / #left + 1..)] right: u128,
    ) {
        let left = u128_to_u160_shl_32(left);
        let right = u128_to_u160(right);
        SafeMul.test_assertion_failure(left, right, &OVERFLOW_ERROR_IDS);
    }

    #[proptest(cases = 50)]
    fn marginal_overflow_crashes_vm(
        #[strategy(2_u8..128)] _log_upper_bound: u8,
        #[strategy(2_u128..(1 << #_log_upper_bound))] left: u128,
    ) {
        let right = u128::MAX / left + 1;

        SafeMul.test_assertion_failure(
            u128_to_u160_shl_32(left),
            u128_to_u160(right),
            &OVERFLOW_ERROR_IDS,
        );
        SafeMul.test_assertion_failure(
            u128_to_u160(left),
            u128_to_u160_shl_32(right),
            &OVERFLOW_ERROR_IDS,
        );
    }

    #[proptest(cases = 50)]
    fn arbitrary_overflow_crashes_vm(
        #[strategy(2_u8..128)] _log_upper_bound: u8,
        #[strategy(2_u128..(1 << #_log_upper_bound))] left: u128,
        #[strategy(u128::MAX / #left + 1..)] right: u128,
    ) {
        SafeMul.test_assertion_failure(
            u128_to_u160_shl_32(left),
            u128_to_u160(right),
            &OVERFLOW_ERROR_IDS,
        );
        SafeMul.test_assertion_failure(
            u128_to_u160(left),
            u128_to_u160_shl_32(right),
            &OVERFLOW_ERROR_IDS,
        );
        SafeMul.test_assertion_failure(
            u128_to_u160_shl_32_lower_limb_filled(left),
            u128_to_u160(right),
            &OVERFLOW_ERROR_IDS,
        );
        SafeMul.test_assertion_failure(
            u128_to_u160(left),
            u128_to_u160_shl_32_lower_limb_filled(right),
            &OVERFLOW_ERROR_IDS,
        );

        // Both operands shifted, in both orders. Previously the same order was
        // tested twice.
        SafeMul.test_assertion_failure(
            u128_to_u160_shl_32(left),
            u128_to_u160_shl_32(right),
            &OVERFLOW_ERROR_IDS,
        );
        SafeMul.test_assertion_failure(
            u128_to_u160_shl_32(right),
            u128_to_u160_shl_32(left),
            &OVERFLOW_ERROR_IDS,
        );
    }

    impl Closure for SafeMul {
        type Args = ([u32; 5], [u32; 5]);

        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            let left: [u32; 5] = pop_encodable(stack);
            let left = BigUint::new(left.to_vec());
            let right: [u32; 5] = pop_encodable(stack);
            let right = BigUint::new(right.to_vec());

            // Multiply by reference; both operands are still needed for the message.
            let mut prod = (&left * &right).to_u32_digits();
            assert!(prod.len() <= 5, "Overflow: left: {left}, right: {right}.");

            prod.resize(5, 0);
            let prod: [u32; 5] = prod.try_into().unwrap();

            push_encodable(stack, &prod);
        }

        /// Sample a random `lhs` and a random `rhs < u160::MAX / lhs`, such that the
        /// product does not overflow.
        fn pseudorandom_args(&self, seed: [u8; 32], _: Option<BenchmarkCase>) -> Self::Args {
            let mut rng = StdRng::from_seed(seed);
            let lhs: [u32; 5] = rng.random();
            let lhs_as_biguint = BigUint::new(lhs.to_vec());

            let u160_max = BigUint::from_bytes_be(&[0xFF; 20]);
            let max = &u160_max / &lhs_as_biguint;

            // Big-endian mask with the lowest `bits` bits set, left-padded to 20 bytes.
            // `max <= u160_max`, so the mask never needs more than 20 bytes.
            let bits: u32 = max.bits().try_into().unwrap();
            let mask_bytes = (BigUint::from(2u32).pow(bits) - BigUint::one()).to_bytes_be();
            let mut bit_mask = [0u8; 20];
            bit_mask[20 - mask_bytes.len()..].copy_from_slice(&mask_bytes);

            // Rejection sampling: the mask makes each draw succeed with probability
            // at least 1/2.
            let mut rhs_bytes = [0u8; 20];
            let rhs = loop {
                rng.fill(&mut rhs_bytes);
                for (byte, mask) in rhs_bytes.iter_mut().zip(bit_mask.iter()) {
                    *byte &= *mask;
                }
                let candidate = BigUint::from_bytes_be(&rhs_bytes);
                if candidate < max {
                    break candidate;
                }
            };

            assert!((&lhs_as_biguint * &rhs).to_u32_digits().len() <= 5);

            let mut rhs = rhs.to_u32_digits();
            rhs.resize(5, 0);

            (lhs, rhs.try_into().unwrap())
        }

        /// All non-overflowing pairs from a fixed set of limb-boundary values.
        fn corner_case_args(&self) -> Vec<Self::Args> {
            fn u160_checked_mul(l: [u32; 5], r: [u32; 5]) -> Option<[u32; 5]> {
                let l: BigUint = BigUint::new(l.to_vec());
                let r: BigUint = BigUint::new(r.to_vec());

                let prod = l * r;
                let mut prod = prod.to_u32_digits();

                if prod.len() > 5 {
                    None
                } else {
                    prod.resize(5, 0);
                    Some(prod.try_into().unwrap())
                }
            }

            let edge_case_points = vec![
                u128_to_u160(0),
                u128_to_u160(1),
                u128_to_u160(2),
                u128_to_u160(u8::MAX as u128),
                u128_to_u160(1 << 8),
                u128_to_u160(u16::MAX as u128),
                u128_to_u160(1 << 16),
                u128_to_u160(u32::MAX as u128),
                u128_to_u160(1 << 32),
                u128_to_u160(u64::MAX as u128),
                u128_to_u160(1 << 64),
                [u32::MAX, u32::MAX, u32::MAX, 0, 0],
                u128_to_u160(1 << 96),
                u128_to_u160(u128::MAX),
                [u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX >> 1],
                [u32::MAX; 5],
            ];

            edge_case_points
                .iter()
                .cartesian_product(&edge_case_points)
                .filter(|&(&l, &r)| u160_checked_mul(l, r).is_some())
                .map(|(&l, &r)| (l, r))
                .collect()
        }
    }
}
473
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmark the snippet via its Rust shadow.
    #[test]
    fn benchmark() {
        let snippet = ShadowedClosure::new(SafeMul);
        snippet.bench();
    }
}