1use alloy::primitives::{I256, U256, U512};
9use tycho_common::simulation::errors::SimulationError;
10
11pub fn safe_mul_u256(a: U256, b: U256) -> Result<U256, SimulationError> {
12 let res = a.checked_mul(b);
13 _construc_result_u256(res)
14}
15
16pub fn safe_div_u256(a: U256, b: U256) -> Result<U256, SimulationError> {
17 if b.is_zero() {
18 return Err(SimulationError::FatalError("Division by zero".to_string()));
19 }
20 let res = a.checked_div(b);
21 _construc_result_u256(res)
22}
23
24pub fn safe_add_u256(a: U256, b: U256) -> Result<U256, SimulationError> {
25 let res = a.checked_add(b);
26 _construc_result_u256(res)
27}
28
29pub fn safe_sub_u256(a: U256, b: U256) -> Result<U256, SimulationError> {
30 let res = a.checked_sub(b);
31 _construc_result_u256(res)
32}
33
34pub fn div_mod_u256(a: U256, b: U256) -> Result<(U256, U256), SimulationError> {
35 if b.is_zero() {
36 return Err(SimulationError::FatalError("Division by zero".to_string()));
37 }
38 let result = a / b;
39 let rest = a % b;
40 Ok((result, rest))
41}
42
43pub fn _construc_result_u256(res: Option<U256>) -> Result<U256, SimulationError> {
44 match res {
45 None => Err(SimulationError::FatalError("U256 arithmetic overflow".to_string())),
46 Some(value) => Ok(value),
47 }
48}
49
50pub fn safe_mul_u512(a: U512, b: U512) -> Result<U512, SimulationError> {
51 let res = a.checked_mul(b);
52 _construc_result_u512(res)
53}
54
55pub fn safe_div_u512(a: U512, b: U512) -> Result<U512, SimulationError> {
56 if b.is_zero() {
57 return Err(SimulationError::FatalError("Division by zero".to_string()));
58 }
59 let res = a.checked_div(b);
60 _construc_result_u512(res)
61}
62
63pub fn safe_add_u512(a: U512, b: U512) -> Result<U512, SimulationError> {
64 let res = a.checked_add(b);
65 _construc_result_u512(res)
66}
67
68pub fn safe_sub_u512(a: U512, b: U512) -> Result<U512, SimulationError> {
69 let res = a.checked_sub(b);
70 _construc_result_u512(res)
71}
72
73pub fn div_mod_u512(a: U512, b: U512) -> Result<(U512, U512), SimulationError> {
74 if b.is_zero() {
75 return Err(SimulationError::FatalError("Division by zero".to_string()));
76 }
77 let result = a / b;
78 let rest = a % b;
79 Ok((result, rest))
80}
81
82pub fn _construc_result_u512(res: Option<U512>) -> Result<U512, SimulationError> {
83 match res {
84 None => Err(SimulationError::FatalError("U512 arithmetic overflow".to_string())),
85 Some(value) => Ok(value),
86 }
87}
88
89pub fn safe_mul_i256(a: I256, b: I256) -> Result<I256, SimulationError> {
90 let res = a.checked_mul(b);
91 _construc_result_i256(res)
92}
93
94pub fn safe_div_i256(a: I256, b: I256) -> Result<I256, SimulationError> {
95 if b.is_zero() {
96 return Err(SimulationError::FatalError("Division by zero".to_string()));
97 }
98 let res = a.checked_div(b);
99 _construc_result_i256(res)
100}
101
102pub fn safe_add_i256(a: I256, b: I256) -> Result<I256, SimulationError> {
103 let res = a.checked_add(b);
104 _construc_result_i256(res)
105}
106
107pub fn safe_sub_i256(a: I256, b: I256) -> Result<I256, SimulationError> {
108 let res = a.checked_sub(b);
109 _construc_result_i256(res)
110}
111
112pub fn _construc_result_i256(res: Option<I256>) -> Result<I256, SimulationError> {
113 match res {
114 None => Err(SimulationError::FatalError("I256 arithmetic overflow".to_string())),
115 Some(value) => Ok(value),
116 }
117}
118
119pub fn sqrt_u512(value: U512) -> U512 {
130 if value == U512::ZERO {
132 return U512::ZERO;
133 }
134
135 if value == U512::from(1u32) {
137 return U512::from(1u32);
138 }
139
140 let bits = 512 - value.leading_zeros();
143 let mut result = U512::from(1u32) << (bits / 2);
144
145 let mut decreasing = false;
148 loop {
149 let division = value / result;
151 let iter = (division + result) / U512::from(2u32);
152
153 if iter == result {
155 break;
157 }
158
159 if iter > result {
160 if decreasing {
161 break;
163 }
164 result =
166 if iter > result * U512::from(2u32) { result * U512::from(2u32) } else { iter };
167 } else {
168 decreasing = true;
170 result = iter;
171 }
172 }
173
174 result
175}
176
177pub fn sqrt_u256(value: U256) -> Result<U256, SimulationError> {
179 if value == U256::ZERO {
180 return Ok(U256::ZERO);
181 }
182
183 let bits = 256 - value.leading_zeros();
184 let mut remainder = U256::ZERO;
185 let mut temp = U256::ZERO;
186 let result = compute_karatsuba_sqrt(value, &mut remainder, &mut temp, bits);
187
188 let limbs = result.as_limbs();
190 Ok(U256::from_limbs([limbs[0], limbs[1], limbs[2], limbs[3]]))
191}
192
193fn compute_karatsuba_sqrt(x: U256, r: &mut U256, t: &mut U256, bits: usize) -> U256 {
198 if bits <= 64 {
200 let x_small = x.as_limbs()[0];
201 let result = (x_small as f64).sqrt() as u64;
202 *r = x - U256::from(result * result);
203 return U256::from(result);
204 }
205
206 let b = bits / 4;
209
210 let mut q = x >> (b * 2);
212
213 let mut s = compute_karatsuba_sqrt(q, r, t, bits - b * 2);
215
216 *t = (U256::from(1u32) << (b * 2)) - U256::from(1u32);
218
219 *r = (*r << b) | ((x & *t) >> b);
221
222 s <<= 1;
224 q = *r / s;
225 *r -= q * s;
226
227 s = (s << (b - 1)) + q;
229
230 *t = (U256::from(1u32) << b) - U256::from(1u32);
232 *r = (*r << b) | (x & *t);
233
234 let q_squared = q * q;
236
237 if *r < q_squared {
239 *t = (s << 1) - U256::from(1u32);
240 *r += *t;
241 s -= U256::from(1u32);
242 }
243
244 *r -= q_squared;
245 s
246}
247
#[cfg(test)]
mod safe_math_tests {
    use std::str::FromStr;

    use rstest::rstest;

    use super::*;

    const U256_MAX: U256 = U256::from_limbs([u64::MAX, u64::MAX, u64::MAX, u64::MAX]);
    const U512_MAX: U512 = U512::from_limbs([
        u64::MAX,
        u64::MAX,
        u64::MAX,
        u64::MAX,
        u64::MAX,
        u64::MAX,
        u64::MAX,
        u64::MAX,
    ]);
    // Two's-complement raw bits: every bit set except the sign bit (top limb is last).
    const I256_MAX: I256 = I256::from_raw(U256::from_limbs([
        u64::MAX,
        u64::MAX,
        u64::MAX,
        9223372036854775807u64, // 2^63 - 1
    ]));

    // Two's-complement raw bits: only the sign bit set.
    const I256_MIN: I256 = I256::from_raw(U256::from_limbs([
        0,
        0,
        0,
        9223372036854775808u64, // 2^63
    ]));

    fn u256(s: &str) -> U256 {
        U256::from_str(s).unwrap()
    }

    #[rstest]
    #[case(U256_MAX, u256("2"), true, false, u256("0"))]
    #[case(u256("3"), u256("2"), false, true, u256("6"))]
    fn test_safe_mul_u256(
        #[case] a: U256,
        #[case] b: U256,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: U256,
    ) {
        let res = safe_mul_u256(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);

        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[rstest]
    #[case(U256_MAX, u256("2"), true, false, u256("0"))]
    #[case(u256("3"), u256("2"), false, true, u256("5"))]
    fn test_safe_add_u256(
        #[case] a: U256,
        #[case] b: U256,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: U256,
    ) {
        let res = safe_add_u256(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);

        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[rstest]
    #[case(u256("0"), u256("2"), true, false, u256("0"))]
    #[case(u256("10"), u256("2"), false, true, u256("8"))]
    fn test_safe_sub_u256(
        #[case] a: U256,
        #[case] b: U256,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: U256,
    ) {
        let res = safe_sub_u256(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);

        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[rstest]
    #[case(u256("1"), u256("0"), true, false, u256("0"))]
    #[case(u256("10"), u256("2"), false, true, u256("5"))]
    fn test_safe_div_u256(
        #[case] a: U256,
        #[case] b: U256,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: U256,
    ) {
        let res = safe_div_u256(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);

        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    fn u512(s: &str) -> U512 {
        U512::from_str(s).unwrap()
    }

    #[rstest]
    #[case(U512_MAX, u512("2"), true, false, u512("0"))]
    #[case(u512("3"), u512("2"), false, true, u512("6"))]
    fn test_safe_mul_u512(
        #[case] a: U512,
        #[case] b: U512,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: U512,
    ) {
        let res = safe_mul_u512(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);

        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[rstest]
    #[case(U512_MAX, u512("2"), true, false, u512("0"))]
    #[case(u512("3"), u512("2"), false, true, u512("5"))]
    fn test_safe_add_u512(
        #[case] a: U512,
        #[case] b: U512,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: U512,
    ) {
        let res = safe_add_u512(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);

        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[rstest]
    #[case(u512("0"), u512("2"), true, false, u512("0"))]
    #[case(u512("10"), u512("2"), false, true, u512("8"))]
    fn test_safe_sub_u512(
        #[case] a: U512,
        #[case] b: U512,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: U512,
    ) {
        let res = safe_sub_u512(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);

        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[rstest]
    #[case(u512("1"), u512("0"), true, false, u512("0"))]
    #[case(u512("10"), u512("2"), false, true, u512("5"))]
    fn test_safe_div_u512(
        #[case] a: U512,
        #[case] b: U512,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: U512,
    ) {
        let res = safe_div_u512(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);

        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    fn i256(s: &str) -> I256 {
        I256::from_str(s).unwrap()
    }

    #[rstest]
    #[case(I256_MAX, i256("2"), true, false, i256("0"))]
    #[case(i256("3"), i256("2"), false, true, i256("6"))]
    fn test_safe_mul_i256(
        #[case] a: I256,
        #[case] b: I256,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: I256,
    ) {
        let res = safe_mul_i256(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);

        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[rstest]
    #[case(I256_MAX, i256("2"), true, false, i256("0"))]
    #[case(i256("3"), i256("2"), false, true, i256("5"))]
    fn test_safe_add_i256(
        #[case] a: I256,
        #[case] b: I256,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: I256,
    ) {
        let res = safe_add_i256(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);

        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[rstest]
    #[case(I256_MIN, i256("2"), true, false, i256("0"))]
    #[case(i256("10"), i256("2"), false, true, i256("8"))]
    fn test_safe_sub_i256(
        #[case] a: I256,
        #[case] b: I256,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: I256,
    ) {
        let res = safe_sub_i256(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);

        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[rstest]
    #[case(i256("1"), i256("0"), true, false, i256("0"))]
    #[case(i256("10"), i256("2"), false, true, i256("5"))]
    fn test_safe_div_i256(
        #[case] a: I256,
        #[case] b: I256,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: I256,
    ) {
        let res = safe_div_i256(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);

        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[test]
    fn test_sqrt_u512() {
        assert_eq!(sqrt_u512(U512::ZERO), U512::ZERO);
        assert_eq!(sqrt_u512(U512::from(1u32)), U512::from(1u32));

        // Perfect squares.
        assert_eq!(sqrt_u512(U512::from(4u32)), U512::from(2u32));
        assert_eq!(sqrt_u512(U512::from(100u32)), U512::from(10u32));
        assert_eq!(sqrt_u512(U512::from(10000u32)), U512::from(100u32));
        assert_eq!(sqrt_u512(U512::from(1000000u32)), U512::from(1000u32));

        // Non-squares round down.
        assert_eq!(sqrt_u512(U512::from(2u32)), U512::from(1u32));
        assert_eq!(sqrt_u512(U512::from(3u32)), U512::from(1u32));
        assert_eq!(sqrt_u512(U512::from(5u32)), U512::from(2u32));
        assert_eq!(sqrt_u512(U512::from(8u32)), U512::from(2u32));
        assert_eq!(sqrt_u512(U512::from(10u32)), U512::from(3u32));
        assert_eq!(sqrt_u512(U512::from(15u32)), U512::from(3u32));
        assert_eq!(sqrt_u512(U512::from(99u32)), U512::from(9u32));

        let large = U512::from_str("1000000000000000000000000000000000000").unwrap();
        let sqrt_large = sqrt_u512(large);
        assert!(sqrt_large * sqrt_large <= large);
        assert!((sqrt_large + U512::from(1u32)) * (sqrt_large + U512::from(1u32)) > large);
    }

    #[test]
    fn test_sqrt_u256() {
        // Exact results for small perfect squares and non-squares.
        let cases: &[(u64, u64)] = &[
            (0, 0),
            (1, 1),
            (2, 1),
            (3, 1),
            (4, 2),
            (8, 2),
            (9, 3),
            (99, 9),
            (100, 10),
            (10000, 100),
            (1000000, 1000),
        ];
        for &(input, expected) in cases {
            assert_eq!(sqrt_u256(U256::from(input)).unwrap(), U256::from(expected));
        }

        // Floor-sqrt property over a small contiguous range.
        for n in 0u64..=2000 {
            let v = U256::from(n);
            let r = sqrt_u256(v).unwrap();
            assert!(r * r <= v);
            assert!((r + U256::from(1u32)) * (r + U256::from(1u32)) > v);
        }

        // Multi-level Karatsuba recursion on a perfect square: (2^100)^2.
        let pow2_200 = U256::from(1u32) << 200usize;
        assert_eq!(sqrt_u256(pow2_200).unwrap(), U256::from(1u32) << 100usize);

        // 10^60 (200 bits) also exercises two recursion levels.
        let large = u256(&format!("1{}", "0".repeat(60)));
        let sqrt_large = sqrt_u256(large).unwrap();
        assert!(sqrt_large * sqrt_large <= large);
        assert!((sqrt_large + U256::from(1u32)) * (sqrt_large + U256::from(1u32)) > large);
    }
}