tycho_simulation/evm/protocol/
safe_math.rs

//! Safe Math
//!
//! This module contains basic functions that perform arithmetic on the numeric
//! types of the alloy crate while guarding against overflow, underflow and
//! division by zero. Should an operation fail, a `SimulationError` is returned
//! instead of the value silently wrapping.
//! Functions for the types I256, U256 and U512 are available, together with
//! integer square-root helpers for U256 and U512.
8use alloy::primitives::{I256, U256, U512};
9use tycho_common::simulation::errors::SimulationError;
10
11pub fn safe_mul_u256(a: U256, b: U256) -> Result<U256, SimulationError> {
12    let res = a.checked_mul(b);
13    _construc_result_u256(res)
14}
15
16pub fn safe_div_u256(a: U256, b: U256) -> Result<U256, SimulationError> {
17    if b.is_zero() {
18        return Err(SimulationError::FatalError("Division by zero".to_string()));
19    }
20    let res = a.checked_div(b);
21    _construc_result_u256(res)
22}
23
24pub fn safe_add_u256(a: U256, b: U256) -> Result<U256, SimulationError> {
25    let res = a.checked_add(b);
26    _construc_result_u256(res)
27}
28
29pub fn safe_sub_u256(a: U256, b: U256) -> Result<U256, SimulationError> {
30    let res = a.checked_sub(b);
31    _construc_result_u256(res)
32}
33
34pub fn div_mod_u256(a: U256, b: U256) -> Result<(U256, U256), SimulationError> {
35    if b.is_zero() {
36        return Err(SimulationError::FatalError("Division by zero".to_string()));
37    }
38    let result = a / b;
39    let rest = a % b;
40    Ok((result, rest))
41}
42
43pub fn _construc_result_u256(res: Option<U256>) -> Result<U256, SimulationError> {
44    match res {
45        None => Err(SimulationError::FatalError("U256 arithmetic overflow".to_string())),
46        Some(value) => Ok(value),
47    }
48}
49
50pub fn safe_mul_u512(a: U512, b: U512) -> Result<U512, SimulationError> {
51    let res = a.checked_mul(b);
52    _construc_result_u512(res)
53}
54
55pub fn safe_div_u512(a: U512, b: U512) -> Result<U512, SimulationError> {
56    if b.is_zero() {
57        return Err(SimulationError::FatalError("Division by zero".to_string()));
58    }
59    let res = a.checked_div(b);
60    _construc_result_u512(res)
61}
62
63pub fn safe_add_u512(a: U512, b: U512) -> Result<U512, SimulationError> {
64    let res = a.checked_add(b);
65    _construc_result_u512(res)
66}
67
68pub fn safe_sub_u512(a: U512, b: U512) -> Result<U512, SimulationError> {
69    let res = a.checked_sub(b);
70    _construc_result_u512(res)
71}
72
73pub fn div_mod_u512(a: U512, b: U512) -> Result<(U512, U512), SimulationError> {
74    if b.is_zero() {
75        return Err(SimulationError::FatalError("Division by zero".to_string()));
76    }
77    let result = a / b;
78    let rest = a % b;
79    Ok((result, rest))
80}
81
82pub fn _construc_result_u512(res: Option<U512>) -> Result<U512, SimulationError> {
83    match res {
84        None => Err(SimulationError::FatalError("U512 arithmetic overflow".to_string())),
85        Some(value) => Ok(value),
86    }
87}
88
89pub fn safe_mul_i256(a: I256, b: I256) -> Result<I256, SimulationError> {
90    let res = a.checked_mul(b);
91    _construc_result_i256(res)
92}
93
94pub fn safe_div_i256(a: I256, b: I256) -> Result<I256, SimulationError> {
95    if b.is_zero() {
96        return Err(SimulationError::FatalError("Division by zero".to_string()));
97    }
98    let res = a.checked_div(b);
99    _construc_result_i256(res)
100}
101
102pub fn safe_add_i256(a: I256, b: I256) -> Result<I256, SimulationError> {
103    let res = a.checked_add(b);
104    _construc_result_i256(res)
105}
106
107pub fn safe_sub_i256(a: I256, b: I256) -> Result<I256, SimulationError> {
108    let res = a.checked_sub(b);
109    _construc_result_i256(res)
110}
111
112pub fn _construc_result_i256(res: Option<I256>) -> Result<I256, SimulationError> {
113    match res {
114        None => Err(SimulationError::FatalError("I256 arithmetic overflow".to_string())),
115        Some(value) => Ok(value),
116    }
117}
118
119/// Computes the integer square root of a U512 value using Newton's method.
120///
121/// Returns the floor of the square root.
122///
123/// # Algorithm
124///
125/// Uses Newton's method iteration:
126/// - Start with initial guess based on bit length
127/// - Iterate: x_new = (x + n/x) / 2
128/// - Stop when convergence is reached or value stops decreasing
129pub fn sqrt_u512(value: U512) -> U512 {
130    // Handle zero case
131    if value == U512::ZERO {
132        return U512::ZERO;
133    }
134
135    // Handle small values
136    if value == U512::from(1u32) {
137        return U512::from(1u32);
138    }
139
140    // Initial guess: use bit length to get approximate starting point
141    // For square root, start with 2^(bits/2)
142    let bits = 512 - value.leading_zeros();
143    let mut result = U512::from(1u32) << (bits / 2);
144
145    // Newton's method iteration for square root
146    // x_new = (x + n/x) / 2
147    let mut decreasing = false;
148    loop {
149        // Calculate: (value / result + result) / 2
150        let division = value / result;
151        let iter = (division + result) / U512::from(2u32);
152
153        // Check convergence
154        if iter == result {
155            // Hit fixed point, we're done
156            break;
157        }
158
159        if iter > result {
160            if decreasing {
161                // Was decreasing, now increasing - we've converged
162                break;
163            }
164            // Limit increase to prevent slow convergence
165            result =
166                if iter > result * U512::from(2u32) { result * U512::from(2u32) } else { iter };
167        } else {
168            // Converging downwards
169            decreasing = true;
170            result = iter;
171        }
172    }
173
174    result
175}
176
177/// Integer square root for U256, returning U256
178pub fn sqrt_u256(value: U256) -> Result<U256, SimulationError> {
179    if value == U256::ZERO {
180        return Ok(U256::ZERO);
181    }
182
183    let bits = 256 - value.leading_zeros();
184    let mut remainder = U256::ZERO;
185    let mut temp = U256::ZERO;
186    let result = compute_karatsuba_sqrt(value, &mut remainder, &mut temp, bits);
187
188    // Extract lower 256 bits
189    let limbs = result.as_limbs();
190    Ok(U256::from_limbs([limbs[0], limbs[1], limbs[2], limbs[3]]))
191}
192
193/// Recursive Karatsuba square root implementation
194/// Computes sqrt(x) and stores remainder in r
195/// Uses temp variable t for intermediate calculations
196/// Ref: https://hal.inria.fr/file/index/docid/72854/filename/RR-3805.pdf
197fn compute_karatsuba_sqrt(x: U256, r: &mut U256, t: &mut U256, bits: usize) -> U256 {
198    // Base case: use simple method for small numbers
199    if bits <= 64 {
200        let x_small = x.as_limbs()[0];
201        let result = (x_small as f64).sqrt() as u64;
202        *r = x - U256::from(result * result);
203        return U256::from(result);
204    }
205
206    // Divide-and-conquer approach
207    // Split into quarters: process b bits at a time where b = bits/4
208    let b = bits / 4;
209
210    // q = x >> (2*b)  -- extract upper bits
211    let mut q = x >> (b * 2);
212
213    // Recursively compute sqrt of upper portion
214    let mut s = compute_karatsuba_sqrt(q, r, t, bits - b * 2);
215
216    // Build mask for extracting bits: (1 << (2*b)) - 1
217    *t = (U256::from(1u32) << (b * 2)) - U256::from(1u32);
218
219    // Extract middle bits and combine with remainder from recursive call
220    *r = (*r << b) | ((x & *t) >> b);
221
222    // Divide: t = r / (2*s), with quotient q and remainder r
223    s <<= 1;
224    q = *r / s;
225    *r -= q * s;
226
227    // Build s = (s << (b-1)) + q
228    s = (s << (b - 1)) + q;
229
230    // Extract lower b bits
231    *t = (U256::from(1u32) << b) - U256::from(1u32);
232    *r = (*r << b) | (x & *t);
233
234    // Compute q^2
235    let q_squared = q * q;
236
237    // Adjust if remainder is too small
238    if *r < q_squared {
239        *t = (s << 1) - U256::from(1u32);
240        *r += *t;
241        s -= U256::from(1u32);
242    }
243
244    *r -= q_squared;
245    s
246}
247
#[cfg(test)]
mod safe_math_tests {
    use std::{fmt::Debug, str::FromStr};

    use rstest::rstest;

    use super::*;

    fn u256(s: &str) -> U256 {
        U256::from_str(s).unwrap()
    }

    fn u512(s: &str) -> U512 {
        U512::from_str(s).unwrap()
    }

    fn i256(s: &str) -> I256 {
        I256::from_str(s).unwrap()
    }

    /// Checks a safe-math result against the expected outcome:
    /// `None` means the operation must fail, `Some(v)` means it must return `Ok(v)`.
    fn assert_outcome<T: PartialEq + Debug>(
        res: Result<T, SimulationError>,
        expected: Option<T>,
    ) {
        match expected {
            Some(value) => assert_eq!(res.unwrap(), value),
            None => assert!(res.is_err()),
        }
    }

    // ---- U256 ----

    #[rstest]
    #[case(U256::MAX, u256("2"), None)]
    #[case(u256("3"), u256("2"), Some(u256("6")))]
    fn test_safe_mul_u256(#[case] a: U256, #[case] b: U256, #[case] expected: Option<U256>) {
        assert_outcome(safe_mul_u256(a, b), expected);
    }

    #[rstest]
    #[case(U256::MAX, u256("2"), None)]
    #[case(u256("3"), u256("2"), Some(u256("5")))]
    fn test_safe_add_u256(#[case] a: U256, #[case] b: U256, #[case] expected: Option<U256>) {
        assert_outcome(safe_add_u256(a, b), expected);
    }

    #[rstest]
    #[case(u256("0"), u256("2"), None)]
    #[case(u256("10"), u256("2"), Some(u256("8")))]
    fn test_safe_sub_u256(#[case] a: U256, #[case] b: U256, #[case] expected: Option<U256>) {
        assert_outcome(safe_sub_u256(a, b), expected);
    }

    #[rstest]
    #[case(u256("1"), u256("0"), None)]
    #[case(u256("10"), u256("2"), Some(u256("5")))]
    fn test_safe_div_u256(#[case] a: U256, #[case] b: U256, #[case] expected: Option<U256>) {
        assert_outcome(safe_div_u256(a, b), expected);
    }

    #[rstest]
    #[case(u256("1"), u256("0"), None)]
    #[case(u256("10"), u256("3"), Some((u256("3"), u256("1"))))]
    fn test_div_mod_u256(
        #[case] a: U256,
        #[case] b: U256,
        #[case] expected: Option<(U256, U256)>,
    ) {
        assert_outcome(div_mod_u256(a, b), expected);
    }

    // ---- U512 ----

    #[rstest]
    #[case(U512::MAX, u512("2"), None)]
    #[case(u512("3"), u512("2"), Some(u512("6")))]
    fn test_safe_mul_u512(#[case] a: U512, #[case] b: U512, #[case] expected: Option<U512>) {
        assert_outcome(safe_mul_u512(a, b), expected);
    }

    #[rstest]
    #[case(U512::MAX, u512("2"), None)]
    #[case(u512("3"), u512("2"), Some(u512("5")))]
    fn test_safe_add_u512(#[case] a: U512, #[case] b: U512, #[case] expected: Option<U512>) {
        assert_outcome(safe_add_u512(a, b), expected);
    }

    #[rstest]
    #[case(u512("0"), u512("2"), None)]
    #[case(u512("10"), u512("2"), Some(u512("8")))]
    fn test_safe_sub_u512(#[case] a: U512, #[case] b: U512, #[case] expected: Option<U512>) {
        assert_outcome(safe_sub_u512(a, b), expected);
    }

    #[rstest]
    #[case(u512("1"), u512("0"), None)]
    #[case(u512("10"), u512("2"), Some(u512("5")))]
    fn test_safe_div_u512(#[case] a: U512, #[case] b: U512, #[case] expected: Option<U512>) {
        assert_outcome(safe_div_u512(a, b), expected);
    }

    #[rstest]
    #[case(u512("1"), u512("0"), None)]
    #[case(u512("10"), u512("3"), Some((u512("3"), u512("1"))))]
    fn test_div_mod_u512(
        #[case] a: U512,
        #[case] b: U512,
        #[case] expected: Option<(U512, U512)>,
    ) {
        assert_outcome(div_mod_u512(a, b), expected);
    }

    // ---- I256 ----

    #[rstest]
    #[case(I256::MAX, i256("2"), None)]
    #[case(i256("3"), i256("2"), Some(i256("6")))]
    fn test_safe_mul_i256(#[case] a: I256, #[case] b: I256, #[case] expected: Option<I256>) {
        assert_outcome(safe_mul_i256(a, b), expected);
    }

    #[rstest]
    #[case(I256::MAX, i256("2"), None)]
    #[case(i256("3"), i256("2"), Some(i256("5")))]
    fn test_safe_add_i256(#[case] a: I256, #[case] b: I256, #[case] expected: Option<I256>) {
        assert_outcome(safe_add_i256(a, b), expected);
    }

    #[rstest]
    #[case(I256::MIN, i256("2"), None)]
    #[case(i256("10"), i256("2"), Some(i256("8")))]
    fn test_safe_sub_i256(#[case] a: I256, #[case] b: I256, #[case] expected: Option<I256>) {
        assert_outcome(safe_sub_i256(a, b), expected);
    }

    #[rstest]
    #[case(i256("1"), i256("0"), None)]
    #[case(I256::MIN, I256::MINUS_ONE, None)]
    #[case(i256("10"), i256("2"), Some(i256("5")))]
    fn test_safe_div_i256(#[case] a: I256, #[case] b: I256, #[case] expected: Option<I256>) {
        assert_outcome(safe_div_i256(a, b), expected);
    }

    // ---- Square roots ----

    #[test]
    fn test_sqrt_u512() {
        // Test edge cases
        assert_eq!(sqrt_u512(U512::ZERO), U512::ZERO);
        assert_eq!(sqrt_u512(U512::from(1u32)), U512::from(1u32));

        // Test small perfect squares
        assert_eq!(sqrt_u512(U512::from(4u32)), U512::from(2u32));
        assert_eq!(sqrt_u512(U512::from(100u32)), U512::from(10u32));
        assert_eq!(sqrt_u512(U512::from(10000u32)), U512::from(100u32));
        assert_eq!(sqrt_u512(U512::from(1000000u32)), U512::from(1000u32));

        // For non-perfect squares, should return floor of sqrt
        assert_eq!(sqrt_u512(U512::from(2u32)), U512::from(1u32)); // sqrt(2) ≈ 1.41
        assert_eq!(sqrt_u512(U512::from(3u32)), U512::from(1u32)); // sqrt(3) ≈ 1.73
        assert_eq!(sqrt_u512(U512::from(5u32)), U512::from(2u32)); // sqrt(5) ≈ 2.23
        assert_eq!(sqrt_u512(U512::from(8u32)), U512::from(2u32)); // sqrt(8) ≈ 2.82
        assert_eq!(sqrt_u512(U512::from(10u32)), U512::from(3u32)); // sqrt(10) ≈ 3.16
        assert_eq!(sqrt_u512(U512::from(15u32)), U512::from(3u32)); // sqrt(15) ≈ 3.87
        assert_eq!(sqrt_u512(U512::from(99u32)), U512::from(9u32)); // sqrt(99) ≈ 9.94

        // Test large values
        let large = U512::from_str("1000000000000000000000000000000000000").unwrap();
        let sqrt_large = sqrt_u512(large);
        // Verify that sqrt_large^2 <= large < (sqrt_large + 1)^2
        assert!(sqrt_large * sqrt_large <= large);
        assert!((sqrt_large + U512::from(1u32)) * (sqrt_large + U512::from(1u32)) > large);
    }

    #[rstest]
    #[case(0, 0)]
    #[case(1, 1)]
    #[case(2, 1)]
    #[case(3, 1)]
    #[case(4, 2)]
    #[case(15, 3)]
    #[case(16, 4)]
    #[case(99, 9)]
    #[case(999_999, 999)]
    #[case(1_000_000, 1_000)]
    #[case(1_u64 << 50, 1_u64 << 25)]
    fn test_sqrt_u256(#[case] value: u64, #[case] expected: u64) {
        assert_eq!(sqrt_u256(U256::from(value)).unwrap(), U256::from(expected));
    }
}
551}