Skip to main content

tycho_simulation/evm/protocol/
safe_math.rs

//! Safe Math
//!
//! This module contains basic functions to perform arithmetic operations on
//! numerical types of the alloy crate while preventing them from overflowing.
//! Should an operation overflow, a `Result` containing a `SimulationError`
//! is returned.
//! Functions for the types I256, U256, U512 are available.
8use alloy::primitives::{I256, U256, U512};
9use tycho_common::simulation::errors::SimulationError;
10
11pub fn safe_mul_u256(a: U256, b: U256) -> Result<U256, SimulationError> {
12    let res = a.checked_mul(b);
13    _construc_result_u256(res)
14}
15
16pub fn safe_div_u256(a: U256, b: U256) -> Result<U256, SimulationError> {
17    if b.is_zero() {
18        return Err(SimulationError::FatalError("Division by zero".to_string()));
19    }
20    let res = a.checked_div(b);
21    _construc_result_u256(res)
22}
23
24pub fn safe_add_u256(a: U256, b: U256) -> Result<U256, SimulationError> {
25    let res = a.checked_add(b);
26    _construc_result_u256(res)
27}
28
29pub fn safe_sub_u256(a: U256, b: U256) -> Result<U256, SimulationError> {
30    let res = a.checked_sub(b);
31    _construc_result_u256(res)
32}
33
34pub fn div_mod_u256(a: U256, b: U256) -> Result<(U256, U256), SimulationError> {
35    if b.is_zero() {
36        return Err(SimulationError::FatalError("Division by zero".to_string()));
37    }
38    let result = a / b;
39    let rest = a % b;
40    Ok((result, rest))
41}
42
43pub fn _construc_result_u256(res: Option<U256>) -> Result<U256, SimulationError> {
44    match res {
45        None => Err(SimulationError::FatalError("U256 arithmetic overflow".to_string())),
46        Some(value) => Ok(value),
47    }
48}
49
50pub fn safe_mul_u512(a: U512, b: U512) -> Result<U512, SimulationError> {
51    let res = a.checked_mul(b);
52    _construc_result_u512(res)
53}
54
55pub fn safe_div_u512(a: U512, b: U512) -> Result<U512, SimulationError> {
56    if b.is_zero() {
57        return Err(SimulationError::FatalError("Division by zero".to_string()));
58    }
59    let res = a.checked_div(b);
60    _construc_result_u512(res)
61}
62
63pub fn safe_add_u512(a: U512, b: U512) -> Result<U512, SimulationError> {
64    let res = a.checked_add(b);
65    _construc_result_u512(res)
66}
67
68pub fn safe_sub_u512(a: U512, b: U512) -> Result<U512, SimulationError> {
69    let res = a.checked_sub(b);
70    _construc_result_u512(res)
71}
72
73pub fn div_mod_u512(a: U512, b: U512) -> Result<(U512, U512), SimulationError> {
74    if b.is_zero() {
75        return Err(SimulationError::FatalError("Division by zero".to_string()));
76    }
77    let result = a / b;
78    let rest = a % b;
79    Ok((result, rest))
80}
81
82pub fn _construc_result_u512(res: Option<U512>) -> Result<U512, SimulationError> {
83    match res {
84        None => Err(SimulationError::FatalError("U512 arithmetic overflow".to_string())),
85        Some(value) => Ok(value),
86    }
87}
88
89pub fn safe_mul_i256(a: I256, b: I256) -> Result<I256, SimulationError> {
90    let res = a.checked_mul(b);
91    _construc_result_i256(res)
92}
93
94pub fn safe_div_i256(a: I256, b: I256) -> Result<I256, SimulationError> {
95    if b.is_zero() {
96        return Err(SimulationError::FatalError("Division by zero".to_string()));
97    }
98    let res = a.checked_div(b);
99    _construc_result_i256(res)
100}
101
102pub fn safe_add_i256(a: I256, b: I256) -> Result<I256, SimulationError> {
103    let res = a.checked_add(b);
104    _construc_result_i256(res)
105}
106
107pub fn safe_sub_i256(a: I256, b: I256) -> Result<I256, SimulationError> {
108    let res = a.checked_sub(b);
109    _construc_result_i256(res)
110}
111
112pub fn _construc_result_i256(res: Option<I256>) -> Result<I256, SimulationError> {
113    match res {
114        None => Err(SimulationError::FatalError("I256 arithmetic overflow".to_string())),
115        Some(value) => Ok(value),
116    }
117}
118
119/// Computes the integer square root of a U512 value using Newton's method.
120///
121/// Returns the floor of the square root.
122///
123/// # Algorithm
124///
125/// Uses Newton's method iteration:
126/// - Start with initial guess based on bit length
127/// - Iterate: x_new = (x + n/x) / 2
128/// - Stop when convergence is reached or value stops decreasing
129pub fn sqrt_u512(value: U512) -> U512 {
130    // Handle zero case
131    if value == U512::ZERO {
132        return U512::ZERO;
133    }
134
135    // Handle small values
136    if value == U512::from(1u32) {
137        return U512::from(1u32);
138    }
139
140    // Initial guess: use bit length to get approximate starting point
141    // For square root, start with 2^(bits/2)
142    let bits = 512 - value.leading_zeros();
143    let mut result = U512::from(1u32) << (bits / 2);
144
145    // Newton's method iteration for square root
146    // x_new = (x + n/x) / 2
147    let mut decreasing = false;
148    loop {
149        // Calculate: (value / result + result) / 2
150        let division = value / result;
151        let iter = (division + result) / U512::from(2u32);
152
153        // Check convergence
154        if iter == result {
155            // Hit fixed point, we're done
156            break;
157        }
158
159        if iter > result {
160            if decreasing {
161                // Was decreasing, now increasing - we've converged
162                break;
163            }
164            // Limit increase to prevent slow convergence
165            result =
166                if iter > result * U512::from(2u32) { result * U512::from(2u32) } else { iter };
167        } else {
168            // Converging downwards
169            decreasing = true;
170            result = iter;
171        }
172    }
173
174    result
175}
176
177/// Integer square root for U256, returning U256
178pub fn sqrt_u256(value: U256) -> Result<U256, SimulationError> {
179    if value == U256::ZERO {
180        return Ok(U256::ZERO);
181    }
182
183    let bits = 256 - value.leading_zeros();
184    let mut remainder = U256::ZERO;
185    let mut temp = U256::ZERO;
186    let result = compute_karatsuba_sqrt(value, &mut remainder, &mut temp, bits);
187
188    // Extract lower 256 bits
189    let limbs = result.as_limbs();
190    Ok(U256::from_limbs([limbs[0], limbs[1], limbs[2], limbs[3]]))
191}
192
193/// Recursive Karatsuba square root implementation
194/// Computes sqrt(x) and stores remainder in r
195/// Uses temp variable t for intermediate calculations
196/// Ref: https://hal.inria.fr/file/index/docid/72854/filename/RR-3805.pdf
197fn compute_karatsuba_sqrt(x: U256, r: &mut U256, t: &mut U256, bits: usize) -> U256 {
198    // Base case: exact integer floor sqrt once the value fits in one 64-bit limb.
199    // `u64::isqrt` returns floor(sqrt) directly; the root is at most u32::MAX, so
200    // `result * result` cannot overflow u64.
201    if bits <= 64 {
202        let x_small = x.as_limbs()[0];
203        let result = x_small.isqrt();
204        *r = x - U256::from(result * result);
205        return U256::from(result);
206    }
207
208    // Divide-and-conquer approach
209    // Split into quarters: process b bits at a time where b = bits/4
210    let b = bits / 4;
211
212    // q = x >> (2*b)  -- extract upper bits
213    let mut q = x >> (b * 2);
214
215    // Recursively compute sqrt of upper portion
216    let mut s = compute_karatsuba_sqrt(q, r, t, bits - b * 2);
217
218    // Build mask for extracting bits: (1 << (2*b)) - 1
219    *t = (U256::from(1u32) << (b * 2)) - U256::from(1u32);
220
221    // Extract middle bits and combine with remainder from recursive call
222    *r = (*r << b) | ((x & *t) >> b);
223
224    // Divide: t = r / (2*s), with quotient q and remainder r
225    s <<= 1;
226    q = *r / s;
227    *r -= q * s;
228
229    // Build s = (s << (b-1)) + q
230    s = (s << (b - 1)) + q;
231
232    // Extract lower b bits
233    *t = (U256::from(1u32) << b) - U256::from(1u32);
234    *r = (*r << b) | (x & *t);
235
236    // Compute q^2
237    let q_squared = q * q;
238
239    // Adjust if remainder is too small
240    if *r < q_squared {
241        *t = (s << 1) - U256::from(1u32);
242        *r += *t;
243        s -= U256::from(1u32);
244    }
245
246    *r -= q_squared;
247    s
248}
249
#[cfg(test)]
mod safe_math_tests {
    use std::str::FromStr;

    use rstest::rstest;

    use super::*;

    /// Parses a `U256` from a string, panicking on malformed input.
    fn u256(s: &str) -> U256 {
        U256::from_str(s).unwrap()
    }

    /// Parses a `U512` from a string, panicking on malformed input.
    fn u512(s: &str) -> U512 {
        U512::from_str(s).unwrap()
    }

    /// Parses an `I256` from a string, panicking on malformed input.
    fn i256(s: &str) -> I256 {
        I256::from_str(s).unwrap()
    }

    // Arithmetic cases list both operands and the expected outcome:
    // `Some(value)` when the call must succeed, `None` when it must return an error.

    #[rstest]
    #[case(U256::MAX, u256("2"), None)]
    #[case(u256("3"), u256("2"), Some(u256("6")))]
    fn test_safe_mul_u256(#[case] a: U256, #[case] b: U256, #[case] expected: Option<U256>) {
        assert_eq!(safe_mul_u256(a, b).ok(), expected);
    }

    #[rstest]
    #[case(U256::MAX, u256("2"), None)]
    #[case(u256("3"), u256("2"), Some(u256("5")))]
    fn test_safe_add_u256(#[case] a: U256, #[case] b: U256, #[case] expected: Option<U256>) {
        assert_eq!(safe_add_u256(a, b).ok(), expected);
    }

    #[rstest]
    #[case(u256("0"), u256("2"), None)]
    #[case(u256("10"), u256("2"), Some(u256("8")))]
    fn test_safe_sub_u256(#[case] a: U256, #[case] b: U256, #[case] expected: Option<U256>) {
        assert_eq!(safe_sub_u256(a, b).ok(), expected);
    }

    #[rstest]
    #[case(u256("1"), u256("0"), None)]
    #[case(u256("10"), u256("2"), Some(u256("5")))]
    fn test_safe_div_u256(#[case] a: U256, #[case] b: U256, #[case] expected: Option<U256>) {
        assert_eq!(safe_div_u256(a, b).ok(), expected);
    }

    #[test]
    fn test_div_mod_u256() {
        assert_eq!(div_mod_u256(u256("10"), u256("3")).ok(), Some((u256("3"), u256("1"))));
        assert!(div_mod_u256(u256("10"), U256::ZERO).is_err());
    }

    #[rstest]
    #[case(U512::MAX, u512("2"), None)]
    #[case(u512("3"), u512("2"), Some(u512("6")))]
    fn test_safe_mul_u512(#[case] a: U512, #[case] b: U512, #[case] expected: Option<U512>) {
        assert_eq!(safe_mul_u512(a, b).ok(), expected);
    }

    #[rstest]
    #[case(U512::MAX, u512("2"), None)]
    #[case(u512("3"), u512("2"), Some(u512("5")))]
    fn test_safe_add_u512(#[case] a: U512, #[case] b: U512, #[case] expected: Option<U512>) {
        assert_eq!(safe_add_u512(a, b).ok(), expected);
    }

    #[rstest]
    #[case(u512("0"), u512("2"), None)]
    #[case(u512("10"), u512("2"), Some(u512("8")))]
    fn test_safe_sub_u512(#[case] a: U512, #[case] b: U512, #[case] expected: Option<U512>) {
        assert_eq!(safe_sub_u512(a, b).ok(), expected);
    }

    #[rstest]
    #[case(u512("1"), u512("0"), None)]
    #[case(u512("10"), u512("2"), Some(u512("5")))]
    fn test_safe_div_u512(#[case] a: U512, #[case] b: U512, #[case] expected: Option<U512>) {
        assert_eq!(safe_div_u512(a, b).ok(), expected);
    }

    #[test]
    fn test_div_mod_u512() {
        assert_eq!(div_mod_u512(u512("10"), u512("3")).ok(), Some((u512("3"), u512("1"))));
        assert!(div_mod_u512(u512("10"), U512::ZERO).is_err());
    }

    #[rstest]
    #[case(I256::MAX, i256("2"), None)]
    #[case(i256("3"), i256("2"), Some(i256("6")))]
    fn test_safe_mul_i256(#[case] a: I256, #[case] b: I256, #[case] expected: Option<I256>) {
        assert_eq!(safe_mul_i256(a, b).ok(), expected);
    }

    #[rstest]
    #[case(I256::MAX, i256("2"), None)]
    #[case(i256("3"), i256("2"), Some(i256("5")))]
    fn test_safe_add_i256(#[case] a: I256, #[case] b: I256, #[case] expected: Option<I256>) {
        assert_eq!(safe_add_i256(a, b).ok(), expected);
    }

    #[rstest]
    #[case(I256::MIN, i256("2"), None)]
    #[case(i256("10"), i256("2"), Some(i256("8")))]
    fn test_safe_sub_i256(#[case] a: I256, #[case] b: I256, #[case] expected: Option<I256>) {
        assert_eq!(safe_sub_i256(a, b).ok(), expected);
    }

    // The third case is the one signed division that overflows: -2^255 / -1 = 2^255,
    // which is one above I256::MAX.
    #[rstest]
    #[case(i256("1"), i256("0"), None)]
    #[case(i256("10"), i256("2"), Some(i256("5")))]
    #[case(I256::MIN, i256("-1"), None)]
    fn test_safe_div_i256(#[case] a: I256, #[case] b: I256, #[case] expected: Option<I256>) {
        assert_eq!(safe_div_i256(a, b).ok(), expected);
    }

    #[test]
    fn test_sqrt_u512() {
        // Test edge cases
        assert_eq!(sqrt_u512(U512::ZERO), U512::ZERO);
        assert_eq!(sqrt_u512(U512::from(1u32)), U512::from(1u32));

        // Test small perfect squares
        assert_eq!(sqrt_u512(U512::from(4u32)), U512::from(2u32));
        assert_eq!(sqrt_u512(U512::from(100u32)), U512::from(10u32));
        assert_eq!(sqrt_u512(U512::from(10000u32)), U512::from(100u32));
        assert_eq!(sqrt_u512(U512::from(1000000u32)), U512::from(1000u32));

        // For non-perfect squares, should return floor of sqrt
        assert_eq!(sqrt_u512(U512::from(2u32)), U512::from(1u32)); // sqrt(2) ≈ 1.41
        assert_eq!(sqrt_u512(U512::from(3u32)), U512::from(1u32)); // sqrt(3) ≈ 1.73
        assert_eq!(sqrt_u512(U512::from(5u32)), U512::from(2u32)); // sqrt(5) ≈ 2.23
        assert_eq!(sqrt_u512(U512::from(8u32)), U512::from(2u32)); // sqrt(8) ≈ 2.82
        assert_eq!(sqrt_u512(U512::from(10u32)), U512::from(3u32)); // sqrt(10) ≈ 3.16
        assert_eq!(sqrt_u512(U512::from(15u32)), U512::from(3u32)); // sqrt(15) ≈ 3.87
        assert_eq!(sqrt_u512(U512::from(99u32)), U512::from(9u32)); // sqrt(99) ≈ 9.94

        // Test large values
        let large = U512::from_str("1000000000000000000000000000000000000").unwrap();
        let sqrt_large = sqrt_u512(large);
        // Verify that sqrt_large^2 <= large < (sqrt_large + 1)^2
        assert!(sqrt_large * sqrt_large <= large);
        assert!((sqrt_large + U512::from(1u32)) * (sqrt_large + U512::from(1u32)) > large);
    }

    // u64::MAX is the largest input handled by the single-limb base case:
    // floor(sqrt(2^64 - 1)) = 2^32 - 1 = u32::MAX.
    #[test]
    fn test_sqrt_u256_u64_max() {
        let result = sqrt_u256(U256::from(u64::MAX)).unwrap();
        assert_eq!(result, U256::from(u32::MAX));
    }

    // 67108865^2 - 1 sits one below a perfect square, so the floor root must be
    // 67108864 and the invariant result² ≤ x < (result+1)² must hold.
    #[test]
    fn test_sqrt_u256_floor_near_perfect_square() {
        let x = U256::from(67108865u64 * 67108865u64 - 1);
        let result = sqrt_u256(x).unwrap();
        assert_eq!(result, U256::from(67108864u64));
    }

    // Whole-input-range sweep: boundary values and deterministic pseudo-random draws at
    // every bit length 1..=256. The floor invariant result² ≤ x < (result+1)² fully
    // characterizes integer sqrt, checked in U512 so x near U256::MAX cannot overflow.
    #[test]
    fn test_sqrt_u256_floor_invariant_full_range() {
        let mut rng_state = 0x9E3779B97F4A7C15u64;
        let mut next_rand = move || {
            rng_state ^= rng_state >> 12;
            rng_state ^= rng_state << 25;
            rng_state ^= rng_state >> 27;
            rng_state.wrapping_mul(0x2545F4914F6CDD1D)
        };

        let mut cases: Vec<U256> = vec![U256::ZERO, U256::from(1u64), U256::MAX];
        for bits in 1..=256u32 {
            let low = U256::from(1u64) << (bits - 1);
            let high =
                if bits == 256 { U256::MAX } else { (U256::from(1u64) << bits) - U256::from(1u64) };
            cases.push(low);
            cases.push(high);
            for _ in 0..4 {
                let mut draw = U256::ZERO;
                for limb in 0..4 {
                    draw |= U256::from(next_rand()) << (64 * limb);
                }
                cases.push(low + draw % (high - low + U256::from(1u64)));
            }
        }

        for x in cases {
            let result = sqrt_u256(x).unwrap();
            let wide = U512::from(result);
            let x_wide = U512::from(x);
            assert!(wide * wide <= x_wide, "floor violated for x={x}");
            let next = wide + U512::from(1u64);
            assert!(next * next > x_wide, "not the greatest root for x={x}");
        }
    }

    // Inputs that fit in a single u64 limb, i.e. the range served by the base case.
    #[test]
    fn test_sqrt_u256_floor_invariant_in_base_case_range() {
        for x_small in [
            0u64,
            1,
            2,
            3,
            4,
            (1 << 26) - 1,
            1 << 26,
            (1 << 53) - 1,
            1 << 53,
            (1 << 53) + 1,
            67108864 * 67108864,
            u32::MAX as u64,
            (u32::MAX as u64).pow(2),
            (u32::MAX as u64).pow(2) - 1,
            u64::MAX - 1,
            u64::MAX,
        ] {
            let x = U256::from(x_small);
            let result = sqrt_u256(x).unwrap();
            assert!(result * result <= x, "floor violated for {x_small}");
            let next = result + U256::from(1u64);
            assert!(next * next > x, "not the greatest root for {x_small}");
        }
    }

    // Adversarial perfect-square and k²±1 inputs at large bit-lengths, where the recursion
    // (not the base case) runs and the off-by-one correction in the combine step must hold.
    // Every root stays below 2^128, so k² fits U256 without overflow.
    #[test]
    fn test_sqrt_u256_recursive_perfect_square_boundaries() {
        let one = U256::from(1u64);
        let roots = [
            one << 33, // first root whose square (2^66) leaves the base case
            one << 64,
            (one << 96) - one,
            one << 100,
            (one << 120) + U256::from(12345u64),
            (one << 127) - one,
            (one << 128) - one, // floor(sqrt(U256::MAX))
        ];
        for k in roots {
            let square = k * k;
            assert_eq!(sqrt_u256(square).unwrap(), k, "sqrt(k²) for k={k}");
            assert_eq!(sqrt_u256(square - one).unwrap(), k - one, "sqrt(k²−1) for k={k}");
            // k² + 1 < (k+1)² for k >= 1, so the floor stays at k.
            assert_eq!(sqrt_u256(square + one).unwrap(), k, "sqrt(k²+1) for k={k}");
        }
    }
}