secret_cosmwasm_std/math/isqrt.rs

1use std::{cmp, ops};
2
3use crate::{Uint128, Uint256, Uint512, Uint64};
4
/// Integer square root: for an input `n`, the largest value `r` such that `r * r <= n`.
///
/// See <https://en.wikipedia.org/wiki/Integer_square_root>.
pub trait Isqrt {
    /// Returns the [integer square root](https://en.wikipedia.org/wiki/Integer_square_root)
    /// of `self`, i.e. its real square root rounded down.
    fn isqrt(self) -> Self;
}
11
12impl<I> Isqrt for I
13where
14    I: Unsigned
15        + ops::Add<I, Output = I>
16        + ops::Div<I, Output = I>
17        + ops::Shr<u32, Output = I>
18        + cmp::PartialOrd
19        + Copy
20        + From<u8>,
21{
22    /// Algorithm adapted from
23    /// [Wikipedia](https://en.wikipedia.org/wiki/Integer_square_root#Example_implementation_in_C).
24    fn isqrt(self) -> Self {
25        let mut x0 = self >> 1;
26
27        if x0 > 0.into() {
28            let mut x1 = (x0 + self / x0) >> 1;
29
30            while x1 < x0 {
31                x0 = x1;
32                x1 = (x0 + self / x0) >> 1;
33            }
34
35            return x0;
36        }
37        self
38    }
39}
40
41/// Marker trait for types that represent unsigned integers.
42pub trait Unsigned {}
43impl Unsigned for u8 {}
44impl Unsigned for u16 {}
45impl Unsigned for u32 {}
46impl Unsigned for u64 {}
47impl Unsigned for u128 {}
48impl Unsigned for Uint64 {}
49impl Unsigned for Uint128 {}
50impl Unsigned for Uint256 {}
51impl Unsigned for Uint512 {}
52impl Unsigned for usize {}
53
#[cfg(test)]
mod tests {
    use super::*;

    // Primitive integers are called through the trait explicitly (`Isqrt::isqrt(x)`):
    // since Rust 1.84 they have an inherent `isqrt` method, and method-call syntax
    // (`x.isqrt()`) resolves to the inherent method first, so it would not test this trait.

    #[test]
    fn isqrt_primitives() {
        // Let's check correctness.
        assert_eq!(Isqrt::isqrt(0u8), 0);
        assert_eq!(Isqrt::isqrt(1u8), 1);
        assert_eq!(Isqrt::isqrt(24u8), 4);
        assert_eq!(Isqrt::isqrt(25u8), 5);
        assert_eq!(Isqrt::isqrt(26u8), 5);
        assert_eq!(Isqrt::isqrt(36u8), 6);

        // Let's also check different types.
        assert_eq!(Isqrt::isqrt(26u16), 5);
        assert_eq!(Isqrt::isqrt(26u32), 5);
        assert_eq!(Isqrt::isqrt(26u64), 5);
        assert_eq!(Isqrt::isqrt(26u128), 5);
        assert_eq!(Isqrt::isqrt(26usize), 5);
    }

    #[test]
    fn isqrt_primitives_max() {
        // The square root of 2^(2k) - 1 is 2^k - 1; this also checks the Newton step
        // does not overflow at the top of each type's range.
        assert_eq!(Isqrt::isqrt(u8::MAX), 15);
        assert_eq!(Isqrt::isqrt(u16::MAX), 255);
        assert_eq!(Isqrt::isqrt(u32::MAX), u32::from(u16::MAX));
        assert_eq!(Isqrt::isqrt(u64::MAX), u64::from(u32::MAX));
        assert_eq!(Isqrt::isqrt(u128::MAX), u128::from(u64::MAX));
    }

    #[test]
    fn isqrt_is_floor_of_real_sqrt() {
        // Exhaustively check `r * r <= n < (r + 1) * (r + 1)` for small types,
        // widening to u64 so the upper bound cannot overflow.
        for n in 0..=u8::MAX {
            let root = u64::from(Isqrt::isqrt(n));
            let n = u64::from(n);
            assert!(root * root <= n && n < (root + 1) * (root + 1), "wrong isqrt for {}", n);
        }
        for n in 0..=u16::MAX {
            let root = u64::from(Isqrt::isqrt(n));
            let n = u64::from(n);
            assert!(root * root <= n && n < (root + 1) * (root + 1), "wrong isqrt for {}", n);
        }
    }

    #[test]
    fn isqrt_uint64() {
        assert_eq!(Uint64::new(24).isqrt(), Uint64::new(4));
        assert_eq!(Uint64::MAX.isqrt(), Uint64::new(u64::from(u32::MAX)));
    }

    #[test]
    fn isqrt_uint128() {
        assert_eq!(Uint128::new(24).isqrt(), Uint128::new(4));
        assert_eq!(Uint128::MAX.isqrt(), Uint128::new(u128::from(u64::MAX)));
    }

    #[test]
    fn isqrt_uint256() {
        assert_eq!(Uint256::from(24u32).isqrt(), Uint256::from(4u32));
        assert_eq!(
            (Uint256::from(u128::MAX) * Uint256::from(u128::MAX)).isqrt(),
            Uint256::try_from("340282366920938463463374607431768211455").unwrap()
        );
    }

    #[test]
    fn isqrt_uint512() {
        assert_eq!(Uint512::from(24u32).isqrt(), Uint512::from(4u32));
        assert_eq!(
            (Uint512::from(Uint256::MAX) * Uint512::from(Uint256::MAX)).isqrt(),
            Uint512::try_from(
                "115792089237316195423570985008687907853269984665640564039457584007913129639935"
            )
            .unwrap()
        );
    }
}
106}