1use alloy::primitives::{I256, U256, U512};
9use tycho_common::simulation::errors::SimulationError;
10
11pub fn safe_mul_u256(a: U256, b: U256) -> Result<U256, SimulationError> {
12 let res = a.checked_mul(b);
13 _construc_result_u256(res)
14}
15
16pub fn safe_div_u256(a: U256, b: U256) -> Result<U256, SimulationError> {
17 if b.is_zero() {
18 return Err(SimulationError::FatalError("Division by zero".to_string()));
19 }
20 let res = a.checked_div(b);
21 _construc_result_u256(res)
22}
23
24pub fn safe_add_u256(a: U256, b: U256) -> Result<U256, SimulationError> {
25 let res = a.checked_add(b);
26 _construc_result_u256(res)
27}
28
29pub fn safe_sub_u256(a: U256, b: U256) -> Result<U256, SimulationError> {
30 let res = a.checked_sub(b);
31 _construc_result_u256(res)
32}
33
34pub fn div_mod_u256(a: U256, b: U256) -> Result<(U256, U256), SimulationError> {
35 if b.is_zero() {
36 return Err(SimulationError::FatalError("Division by zero".to_string()));
37 }
38 let result = a / b;
39 let rest = a % b;
40 Ok((result, rest))
41}
42
43pub fn _construc_result_u256(res: Option<U256>) -> Result<U256, SimulationError> {
44 match res {
45 None => Err(SimulationError::FatalError("U256 arithmetic overflow".to_string())),
46 Some(value) => Ok(value),
47 }
48}
49
50pub fn safe_mul_u512(a: U512, b: U512) -> Result<U512, SimulationError> {
51 let res = a.checked_mul(b);
52 _construc_result_u512(res)
53}
54
55pub fn safe_div_u512(a: U512, b: U512) -> Result<U512, SimulationError> {
56 if b.is_zero() {
57 return Err(SimulationError::FatalError("Division by zero".to_string()));
58 }
59 let res = a.checked_div(b);
60 _construc_result_u512(res)
61}
62
63pub fn safe_add_u512(a: U512, b: U512) -> Result<U512, SimulationError> {
64 let res = a.checked_add(b);
65 _construc_result_u512(res)
66}
67
68pub fn safe_sub_u512(a: U512, b: U512) -> Result<U512, SimulationError> {
69 let res = a.checked_sub(b);
70 _construc_result_u512(res)
71}
72
73pub fn div_mod_u512(a: U512, b: U512) -> Result<(U512, U512), SimulationError> {
74 if b.is_zero() {
75 return Err(SimulationError::FatalError("Division by zero".to_string()));
76 }
77 let result = a / b;
78 let rest = a % b;
79 Ok((result, rest))
80}
81
82pub fn _construc_result_u512(res: Option<U512>) -> Result<U512, SimulationError> {
83 match res {
84 None => Err(SimulationError::FatalError("U512 arithmetic overflow".to_string())),
85 Some(value) => Ok(value),
86 }
87}
88
89pub fn safe_mul_i256(a: I256, b: I256) -> Result<I256, SimulationError> {
90 let res = a.checked_mul(b);
91 _construc_result_i256(res)
92}
93
94pub fn safe_div_i256(a: I256, b: I256) -> Result<I256, SimulationError> {
95 if b.is_zero() {
96 return Err(SimulationError::FatalError("Division by zero".to_string()));
97 }
98 let res = a.checked_div(b);
99 _construc_result_i256(res)
100}
101
102pub fn safe_add_i256(a: I256, b: I256) -> Result<I256, SimulationError> {
103 let res = a.checked_add(b);
104 _construc_result_i256(res)
105}
106
107pub fn safe_sub_i256(a: I256, b: I256) -> Result<I256, SimulationError> {
108 let res = a.checked_sub(b);
109 _construc_result_i256(res)
110}
111
112pub fn _construc_result_i256(res: Option<I256>) -> Result<I256, SimulationError> {
113 match res {
114 None => Err(SimulationError::FatalError("I256 arithmetic overflow".to_string())),
115 Some(value) => Ok(value),
116 }
117}
118
119pub fn sqrt_u512(value: U512) -> U512 {
130 if value == U512::ZERO {
132 return U512::ZERO;
133 }
134
135 if value == U512::from(1u32) {
137 return U512::from(1u32);
138 }
139
140 let bits = 512 - value.leading_zeros();
143 let mut result = U512::from(1u32) << (bits / 2);
144
145 let mut decreasing = false;
148 loop {
149 let division = value / result;
151 let iter = (division + result) / U512::from(2u32);
152
153 if iter == result {
155 break;
157 }
158
159 if iter > result {
160 if decreasing {
161 break;
163 }
164 result =
166 if iter > result * U512::from(2u32) { result * U512::from(2u32) } else { iter };
167 } else {
168 decreasing = true;
170 result = iter;
171 }
172 }
173
174 result
175}
176
177pub fn sqrt_u256(value: U256) -> Result<U256, SimulationError> {
179 if value == U256::ZERO {
180 return Ok(U256::ZERO);
181 }
182
183 let bits = 256 - value.leading_zeros();
184 let mut remainder = U256::ZERO;
185 let mut temp = U256::ZERO;
186 let result = compute_karatsuba_sqrt(value, &mut remainder, &mut temp, bits);
187
188 let limbs = result.as_limbs();
190 Ok(U256::from_limbs([limbs[0], limbs[1], limbs[2], limbs[3]]))
191}
192
/// Recursive divide-and-conquer integer square root, following the structure of
/// Zimmermann's "Karatsuba Square Root" (SqrtRem).
///
/// Returns `s = floor(sqrt(x))` and stores the remainder `x - s^2` in `*r`.
/// `t` is only used as scratch space; its value on return carries no meaning.
///
/// `bits` is expected to be the exact bit length of `x` (top bit set). `sqrt_u256` passes
/// `256 - leading_zeros`, and the recursion passes `bits - 2b` for `x >> 2b`, which keeps
/// that property. With the top bit set, the high half's root is at least `2^(b-1)`. That
/// normalization is what lets the single correction step at the end suffice.
fn compute_karatsuba_sqrt(x: U256, r: &mut U256, t: &mut U256, bits: usize) -> U256 {
    // Base case: `x` fits in its lowest u64 limb, and `u64::isqrt` is exact.
    if bits <= 64 {
        let x_small = x.as_limbs()[0];
        let result = x_small.isqrt();
        // result <= 2^32 - 1, so `result * result` cannot overflow u64.
        *r = x - U256::from(result * result);
        return U256::from(result);
    }

    // Split x = a3*B^3 + a2*B^2 + a1*B + a0 with B = 2^b and a0, a1, a2 < B.
    // a3 holds the remaining top bits.
    let b = bits / 4;

    // High half: a3*B + a2.
    let mut q = x >> (b * 2);

    // s' = floor(sqrt(a3*B + a2)), with *r = its remainder r'.
    let mut s = compute_karatsuba_sqrt(q, r, t, bits - b * 2);

    // Mask of the low 2b bits of x.
    *t = (U256::from(1u32) << (b * 2)) - U256::from(1u32);

    // *r = r'*B + a1  (a1 = bits [b, 2b) of x).
    *r = (*r << b) | ((x & *t) >> b);

    // (q, u) = divrem(r'*B + a1, 2*s'); *r now holds u.
    s <<= 1;
    q = *r / s;
    *r -= q * s;

    // s = s'*B + q  (2*s' shifted by b-1 equals s' shifted by b).
    s = (s << (b - 1)) + q;

    // Mask of the low b bits; *r = u*B + a0.
    *t = (U256::from(1u32) << b) - U256::from(1u32);
    *r = (*r << b) | (x & *t);

    // Algebraically x - s^2 = u*B + a0 - q^2.
    let q_squared = q * q;

    // If that difference is negative, s is one too large: use s - 1.
    // The remainder is then (u*B + a0 - q^2) + 2s - 1.
    // Adding 2s - 1 before subtracting q^2 keeps the unsigned value non-negative.
    if *r < q_squared {
        *t = (s << 1) - U256::from(1u32);
        *r += *t;
        s -= U256::from(1u32);
    }

    *r -= q_squared;
    s
}
249
#[cfg(test)]
mod safe_math_tests {
    use std::str::FromStr;

    use rstest::rstest;

    use super::*;

    // Columns of the rstest tables below: (a, b, is_err, is_ok, expected).
    // `expected` is only compared when `is_ok` is true.

    fn u256(s: &str) -> U256 {
        U256::from_str(s).unwrap()
    }

    #[rstest]
    #[case(U256::MAX, u256("2"), true, false, u256("0"))]
    #[case(u256("3"), u256("2"), false, true, u256("6"))]
    fn test_safe_mul_u256(
        #[case] a: U256,
        #[case] b: U256,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: U256,
    ) {
        let res = safe_mul_u256(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);

        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[rstest]
    #[case(U256::MAX, u256("2"), true, false, u256("0"))]
    #[case(u256("3"), u256("2"), false, true, u256("5"))]
    fn test_safe_add_u256(
        #[case] a: U256,
        #[case] b: U256,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: U256,
    ) {
        let res = safe_add_u256(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);

        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[rstest]
    #[case(u256("0"), u256("2"), true, false, u256("0"))]
    #[case(u256("10"), u256("2"), false, true, u256("8"))]
    fn test_safe_sub_u256(
        #[case] a: U256,
        #[case] b: U256,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: U256,
    ) {
        let res = safe_sub_u256(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);

        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[rstest]
    #[case(u256("1"), u256("0"), true, false, u256("0"))]
    #[case(u256("10"), u256("2"), false, true, u256("5"))]
    #[case(u256("7"), u256("2"), false, true, u256("3"))]
    fn test_safe_div_u256(
        #[case] a: U256,
        #[case] b: U256,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: U256,
    ) {
        let res = safe_div_u256(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);

        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[test]
    fn test_div_mod_u256() {
        assert_eq!(div_mod_u256(u256("17"), u256("5")).unwrap(), (u256("3"), u256("2")));
        assert_eq!(div_mod_u256(u256("4"), u256("9")).unwrap(), (u256("0"), u256("4")));
        assert!(div_mod_u256(u256("1"), u256("0")).is_err());
    }

    fn u512(s: &str) -> U512 {
        U512::from_str(s).unwrap()
    }

    #[rstest]
    #[case(U512::MAX, u512("2"), true, false, u512("0"))]
    #[case(u512("3"), u512("2"), false, true, u512("6"))]
    fn test_safe_mul_u512(
        #[case] a: U512,
        #[case] b: U512,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: U512,
    ) {
        let res = safe_mul_u512(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);

        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[rstest]
    #[case(U512::MAX, u512("2"), true, false, u512("0"))]
    #[case(u512("3"), u512("2"), false, true, u512("5"))]
    fn test_safe_add_u512(
        #[case] a: U512,
        #[case] b: U512,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: U512,
    ) {
        let res = safe_add_u512(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);

        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[rstest]
    #[case(u512("0"), u512("2"), true, false, u512("0"))]
    #[case(u512("10"), u512("2"), false, true, u512("8"))]
    fn test_safe_sub_u512(
        #[case] a: U512,
        #[case] b: U512,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: U512,
    ) {
        let res = safe_sub_u512(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);

        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[rstest]
    #[case(u512("1"), u512("0"), true, false, u512("0"))]
    #[case(u512("10"), u512("2"), false, true, u512("5"))]
    #[case(u512("7"), u512("2"), false, true, u512("3"))]
    fn test_safe_div_u512(
        #[case] a: U512,
        #[case] b: U512,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: U512,
    ) {
        let res = safe_div_u512(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);

        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[test]
    fn test_div_mod_u512() {
        assert_eq!(div_mod_u512(u512("17"), u512("5")).unwrap(), (u512("3"), u512("2")));
        assert_eq!(div_mod_u512(u512("4"), u512("9")).unwrap(), (u512("0"), u512("4")));
        assert!(div_mod_u512(u512("1"), u512("0")).is_err());
    }

    fn i256(s: &str) -> I256 {
        I256::from_str(s).unwrap()
    }

    #[rstest]
    #[case(I256::MAX, i256("2"), true, false, i256("0"))]
    #[case(i256("3"), i256("2"), false, true, i256("6"))]
    fn test_safe_mul_i256(
        #[case] a: I256,
        #[case] b: I256,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: I256,
    ) {
        let res = safe_mul_i256(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);

        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[rstest]
    #[case(I256::MAX, i256("2"), true, false, i256("0"))]
    #[case(i256("3"), i256("2"), false, true, i256("5"))]
    fn test_safe_add_i256(
        #[case] a: I256,
        #[case] b: I256,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: I256,
    ) {
        let res = safe_add_i256(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);

        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[rstest]
    #[case(I256::MIN, i256("2"), true, false, i256("0"))]
    #[case(i256("10"), i256("2"), false, true, i256("8"))]
    fn test_safe_sub_i256(
        #[case] a: I256,
        #[case] b: I256,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: I256,
    ) {
        let res = safe_sub_i256(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);

        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[rstest]
    #[case(i256("1"), i256("0"), true, false, i256("0"))]
    #[case(i256("10"), i256("2"), false, true, i256("5"))]
    // The only non-zero-divisor overflow for signed division.
    #[case(I256::MIN, i256("-1"), true, false, i256("0"))]
    fn test_safe_div_i256(
        #[case] a: I256,
        #[case] b: I256,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: I256,
    ) {
        let res = safe_div_i256(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);

        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[test]
    fn test_sqrt_u512() {
        assert_eq!(sqrt_u512(U512::ZERO), U512::ZERO);
        assert_eq!(sqrt_u512(U512::from(1u32)), U512::from(1u32));

        // Perfect squares.
        assert_eq!(sqrt_u512(U512::from(4u32)), U512::from(2u32));
        assert_eq!(sqrt_u512(U512::from(100u32)), U512::from(10u32));
        assert_eq!(sqrt_u512(U512::from(10000u32)), U512::from(100u32));
        assert_eq!(sqrt_u512(U512::from(1000000u32)), U512::from(1000u32));

        // Non-squares round down.
        assert_eq!(sqrt_u512(U512::from(2u32)), U512::from(1u32));
        assert_eq!(sqrt_u512(U512::from(3u32)), U512::from(1u32));
        assert_eq!(sqrt_u512(U512::from(5u32)), U512::from(2u32));
        assert_eq!(sqrt_u512(U512::from(8u32)), U512::from(2u32));
        assert_eq!(sqrt_u512(U512::from(10u32)), U512::from(3u32));
        assert_eq!(sqrt_u512(U512::from(15u32)), U512::from(3u32));
        assert_eq!(sqrt_u512(U512::from(99u32)), U512::from(9u32));

        // Largest input: floor(sqrt(2^512 - 1)) = 2^256 - 1.
        assert_eq!(sqrt_u512(U512::MAX), (U512::from(1u32) << 256) - U512::from(1u32));

        let large = U512::from_str("1000000000000000000000000000000000000").unwrap();
        let sqrt_large = sqrt_u512(large);
        assert!(sqrt_large * sqrt_large <= large);
        assert!((sqrt_large + U512::from(1u32)) * (sqrt_large + U512::from(1u32)) > large);
    }

    /// Floor-root invariant around every power of two (both even and odd bit lengths).
    #[test]
    fn test_sqrt_u512_floor_invariant_powers_of_two() {
        let one = U512::from(1u32);
        for k in 1..512usize {
            let p = one << k;
            for x in [p - one, p, p + one] {
                let root = sqrt_u512(x);
                assert!(root * root <= x, "floor violated for x={x}");
                let next = root + one;
                // `next²` only overflows U512 when it is certainly larger than `x`.
                if let Some(square) = next.checked_mul(next) {
                    assert!(square > x, "not the greatest root for x={x}");
                }
            }
        }
    }

    #[test]
    fn test_sqrt_u256_u64_max() {
        let result = sqrt_u256(U256::from(u64::MAX)).unwrap();
        assert_eq!(result, U256::from(u32::MAX));
    }

    #[test]
    fn test_sqrt_u256_floor_near_perfect_square() {
        let x = U256::from(67108865u64 * 67108865u64 - 1);
        let result = sqrt_u256(x).unwrap();
        assert_eq!(result, U256::from(67108864u64));
    }

    /// Floor-root invariant at the edges of, and at pseudo-random points inside, every bit
    /// length from 1 to 256. Uses a fixed-seed xorshift so failures are reproducible.
    #[test]
    fn test_sqrt_u256_floor_invariant_full_range() {
        let mut rng_state = 0x9E3779B97F4A7C15u64;
        let mut next_rand = move || {
            rng_state ^= rng_state >> 12;
            rng_state ^= rng_state << 25;
            rng_state ^= rng_state >> 27;
            rng_state.wrapping_mul(0x2545F4914F6CDD1D)
        };

        let mut cases: Vec<U256> = vec![U256::ZERO, U256::from(1u64), U256::MAX];
        for bits in 1..=256u32 {
            let low = U256::from(1u64) << (bits - 1);
            let high =
                if bits == 256 { U256::MAX } else { (U256::from(1u64) << bits) - U256::from(1u64) };
            cases.push(low);
            cases.push(high);
            for _ in 0..4 {
                let mut draw = U256::ZERO;
                for limb in 0..4 {
                    draw |= U256::from(next_rand()) << (64 * limb);
                }
                cases.push(low + draw % (high - low + U256::from(1u64)));
            }
        }

        for x in cases {
            let result = sqrt_u256(x).unwrap();
            // Check in U512 so that (result + 1)^2 cannot overflow.
            let wide = U512::from(result);
            let x_wide = U512::from(x);
            assert!(wide * wide <= x_wide, "floor violated for x={x}");
            let next = wide + U512::from(1u64);
            assert!(next * next > x_wide, "not the greatest root for x={x}");
        }
    }

    #[test]
    fn test_sqrt_u256_floor_invariant_in_base_case_range() {
        for x_small in [
            0u64,
            1,
            2,
            3,
            4,
            (1 << 26) - 1,
            1 << 26,
            (1 << 53) - 1,
            1 << 53,
            (1 << 53) + 1,
            67108864 * 67108864,
            u32::MAX as u64,
            (u32::MAX as u64).pow(2),
            (u32::MAX as u64).pow(2) - 1,
            u64::MAX - 1,
            u64::MAX,
        ] {
            let x = U256::from(x_small);
            let result = sqrt_u256(x).unwrap();
            assert!(result * result <= x, "floor violated for {x_small}");
            let next = result + U256::from(1u64);
            assert!(next * next > x, "not the greatest root for {x_small}");
        }
    }

    /// Exact roots, plus the neighbours k²−1 and k²+1, for roots that exercise the
    /// recursive (non-base-case) path.
    #[test]
    fn test_sqrt_u256_recursive_perfect_square_boundaries() {
        let one = U256::from(1u64);
        let roots = [
            one << 33,
            one << 64,
            (one << 96) - one,
            one << 100,
            (one << 120) + U256::from(12345u64),
            (one << 127) - one,
            (one << 128) - one,
        ];
        for k in roots {
            let square = k * k;
            assert_eq!(sqrt_u256(square).unwrap(), k, "sqrt(k²) for k={k}");
            assert_eq!(sqrt_u256(square - one).unwrap(), k - one, "sqrt(k²−1) for k={k}");
            assert_eq!(sqrt_u256(square + one).unwrap(), k, "sqrt(k²+1) for k={k}");
        }
    }
}