stable_swap_math/
math.rs

//! Math helpers: overflow-aware multiply-then-divide routines and the fee calculations built on them.
2
3use num_traits::ToPrimitive;
4use stable_swap_client::fees::Fees;
5
6const MAX: u64 = 1 << 32;
7const MAX_BIG: u64 = 1 << 48;
8const MAX_SMALL: u64 = 1 << 16;
9
10/// Multiplies two u64s then divides by the third number.
11/// This function attempts to use 64 bit math if possible.
12#[inline(always)]
13pub fn mul_div(a: u64, b: u64, c: u64) -> Option<u64> {
14    if a > MAX || b > MAX {
15        (a as u128)
16            .checked_mul(b as u128)?
17            .checked_div(c as u128)?
18            .to_u64()
19    } else {
20        a.checked_mul(b)?.checked_div(c)
21    }
22}
23
24/// Multiplies two u64s then divides by the third number.
25/// This assumes that a > b.
26#[inline(always)]
27pub fn mul_div_imbalanced(a: u64, b: u64, c: u64) -> Option<u64> {
28    if a > MAX_BIG || b > MAX_SMALL {
29        (a as u128)
30            .checked_mul(b as u128)?
31            .checked_div(c as u128)?
32            .to_u64()
33    } else {
34        a.checked_mul(b)?.checked_div(c)
35    }
36}
37
/// Fee arithmetic for a pool.
///
/// Every method returns `Option<u64>`, where `None` means the fee could not be
/// computed. The `Fees` implementation in this module returns `None`, for
/// example, on a zero denominator or an out-of-range result.
pub trait FeeCalculator {
    /// Applies the admin trade fee rate to `fee_amount`.
    fn admin_trade_fee(&self, fee_amount: u64) -> Option<u64>;

    /// Applies the admin withdraw fee rate to `fee_amount`.
    fn admin_withdraw_fee(&self, fee_amount: u64) -> Option<u64>;

    /// Applies the trade fee rate to `trade_amount`.
    fn trade_fee(&self, trade_amount: u64) -> Option<u64>;

    /// Applies the withdraw fee rate to `withdraw_amount`.
    fn withdraw_fee(&self, withdraw_amount: u64) -> Option<u64>;

    /// Applies the trade fee rate to `amount`, normalized for a pool of
    /// `n_coins` coins.
    fn normalized_trade_fee(&self, n_coins: u8, amount: u64) -> Option<u64>;
}
51
52impl FeeCalculator for Fees {
53    /// Apply admin trade fee
54    fn admin_trade_fee(&self, fee_amount: u64) -> Option<u64> {
55        mul_div_imbalanced(
56            fee_amount,
57            self.admin_trade_fee_numerator,
58            self.admin_trade_fee_denominator,
59        )
60    }
61
62    /// Apply admin withdraw fee
63    fn admin_withdraw_fee(&self, fee_amount: u64) -> Option<u64> {
64        mul_div_imbalanced(
65            fee_amount,
66            self.admin_withdraw_fee_numerator,
67            self.admin_withdraw_fee_denominator,
68        )
69    }
70
71    /// Compute trade fee from amount
72    fn trade_fee(&self, trade_amount: u64) -> Option<u64> {
73        mul_div_imbalanced(
74            trade_amount,
75            self.trade_fee_numerator,
76            self.trade_fee_denominator,
77        )
78    }
79
80    /// Compute withdraw fee from amount
81    fn withdraw_fee(&self, withdraw_amount: u64) -> Option<u64> {
82        mul_div_imbalanced(
83            withdraw_amount,
84            self.withdraw_fee_numerator,
85            self.withdraw_fee_denominator,
86        )
87    }
88
89    /// Compute normalized fee for symmetric/asymmetric deposits/withdraws
90    fn normalized_trade_fee(&self, n_coins: u8, amount: u64) -> Option<u64> {
91        // adjusted_fee_numerator: uint256 = self.fee * N_COINS / (4 * (N_COINS - 1))
92        // The number 4 comes from Curve, originating from some sort of calculus
93        // https://github.com/curvefi/curve-contract/blob/e5fb8c0e0bcd2fe2e03634135806c0f36b245511/tests/simulation.py#L124
94        let adjusted_trade_fee_numerator = mul_div(
95            self.trade_fee_numerator,
96            n_coins.into(),
97            (n_coins.checked_sub(1)?).checked_mul(4)?.into(),
98        )?;
99
100        mul_div(
101            amount,
102            adjusted_trade_fee_numerator,
103            self.trade_fee_denominator,
104        )
105    }
106}
107
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Reference computation: plain `u64` multiply-then-divide (floor division).
    fn scale(value: u64, numerator: u64, denominator: u64) -> u64 {
        value * numerator / denominator
    }

    /// Checks every `FeeCalculator` method on `Fees` against values computed
    /// directly with `u64` arithmetic.
    #[test]
    fn fee_results() {
        let fees = Fees {
            admin_trade_fee_numerator: 1,
            admin_trade_fee_denominator: 2,
            admin_withdraw_fee_numerator: 3,
            admin_withdraw_fee_denominator: 4,
            trade_fee_numerator: 5,
            trade_fee_denominator: 6,
            withdraw_fee_numerator: 7,
            withdraw_fee_denominator: 8,
        };

        // Trade fee, then the admin fee applied to it.
        let trade_amount: u64 = 1_000_000_000;
        let expected_trade_fee = scale(
            trade_amount,
            fees.trade_fee_numerator,
            fees.trade_fee_denominator,
        );
        let trade_fee = fees.trade_fee(trade_amount).unwrap();
        assert_eq!(trade_fee, expected_trade_fee);
        assert_eq!(
            fees.admin_trade_fee(trade_fee).unwrap(),
            scale(
                expected_trade_fee,
                fees.admin_trade_fee_numerator,
                fees.admin_trade_fee_denominator,
            )
        );

        // Withdraw fee, then the admin fee applied to it.
        let withdraw_amount: u64 = 100_000_000_000;
        let expected_withdraw_fee = scale(
            withdraw_amount,
            fees.withdraw_fee_numerator,
            fees.withdraw_fee_denominator,
        );
        assert_eq!(
            fees.withdraw_fee(withdraw_amount).unwrap(),
            expected_withdraw_fee
        );
        assert_eq!(
            fees.admin_withdraw_fee(expected_withdraw_fee).unwrap(),
            scale(
                expected_withdraw_fee,
                fees.admin_withdraw_fee_numerator,
                fees.admin_withdraw_fee_denominator,
            )
        );

        // Normalized trade fee for a two-coin pool.
        let n_coins: u8 = 2;
        let coins = u64::from(n_coins);
        let adjusted_trade_fee_numerator = scale(fees.trade_fee_numerator, coins, 4 * (coins - 1));
        assert_eq!(
            fees.normalized_trade_fee(n_coins, trade_amount).unwrap(),
            scale(
                trade_amount,
                adjusted_trade_fee_numerator,
                fees.trade_fee_denominator,
            )
        );
    }
}
167}