1use num_traits::ToPrimitive;
4use stable_swap_client::fees::Fees;
5
/// Operand bound (2^32) for the 64-bit fast path in `mul_div`.
/// `MAX * MAX` is exactly 2^64, one more than `u64::MAX`.
const MAX: u64 = 0x1_0000_0000;
/// Bound (2^48) on the first operand `a` of `mul_div_imbalanced`.
const MAX_BIG: u64 = 0x1_0000_0000_0000;
/// Bound (2^16) on the second operand `b` of `mul_div_imbalanced`.
/// `MAX_BIG * MAX_SMALL` is exactly 2^64.
const MAX_SMALL: u64 = 0x1_0000;
9
10#[inline(always)]
13pub fn mul_div(a: u64, b: u64, c: u64) -> Option<u64> {
14 if a > MAX || b > MAX {
15 (a as u128)
16 .checked_mul(b as u128)?
17 .checked_div(c as u128)?
18 .to_u64()
19 } else {
20 a.checked_mul(b)?.checked_div(c)
21 }
22}
23
24#[inline(always)]
27pub fn mul_div_imbalanced(a: u64, b: u64, c: u64) -> Option<u64> {
28 if a > MAX_BIG || b > MAX_SMALL {
29 (a as u128)
30 .checked_mul(b as u128)?
31 .checked_div(c as u128)?
32 .to_u64()
33 } else {
34 a.checked_mul(b)?.checked_div(c)
35 }
36}
37
/// Fee arithmetic over a fee schedule of numerator/denominator ratios.
///
/// Every method returns `Option<u64>`. In the `Fees` implementation in this
/// file, `None` means the result could not be computed: a zero denominator,
/// an overflowing quotient, or an invalid `n_coins` for
/// `normalized_trade_fee`.
pub trait FeeCalculator {
    /// Applies the admin trade-fee ratio to `fee_amount` (presumably a fee
    /// previously returned by `trade_fee`).
    fn admin_trade_fee(&self, fee_amount: u64) -> Option<u64>;
    /// Applies the admin withdraw-fee ratio to `fee_amount` (presumably a fee
    /// previously returned by `withdraw_fee`).
    fn admin_withdraw_fee(&self, fee_amount: u64) -> Option<u64>;
    /// Applies the trade-fee ratio to `trade_amount`.
    fn trade_fee(&self, trade_amount: u64) -> Option<u64>;
    /// Applies the withdraw-fee ratio to `withdraw_amount`.
    fn withdraw_fee(&self, withdraw_amount: u64) -> Option<u64>;
    /// Applies the trade-fee ratio, scaled by `n / (4 * (n - 1))` for
    /// `n = n_coins`, to `amount`.
    fn normalized_trade_fee(&self, n_coins: u8, amount: u64) -> Option<u64>;
}
51
52impl FeeCalculator for Fees {
53 fn admin_trade_fee(&self, fee_amount: u64) -> Option<u64> {
55 mul_div_imbalanced(
56 fee_amount,
57 self.admin_trade_fee_numerator,
58 self.admin_trade_fee_denominator,
59 )
60 }
61
62 fn admin_withdraw_fee(&self, fee_amount: u64) -> Option<u64> {
64 mul_div_imbalanced(
65 fee_amount,
66 self.admin_withdraw_fee_numerator,
67 self.admin_withdraw_fee_denominator,
68 )
69 }
70
71 fn trade_fee(&self, trade_amount: u64) -> Option<u64> {
73 mul_div_imbalanced(
74 trade_amount,
75 self.trade_fee_numerator,
76 self.trade_fee_denominator,
77 )
78 }
79
80 fn withdraw_fee(&self, withdraw_amount: u64) -> Option<u64> {
82 mul_div_imbalanced(
83 withdraw_amount,
84 self.withdraw_fee_numerator,
85 self.withdraw_fee_denominator,
86 )
87 }
88
89 fn normalized_trade_fee(&self, n_coins: u8, amount: u64) -> Option<u64> {
91 let adjusted_trade_fee_numerator = mul_div(
95 self.trade_fee_numerator,
96 n_coins.into(),
97 (n_coins.checked_sub(1)?).checked_mul(4)?.into(),
98 )?;
99
100 mul_div(
101 amount,
102 adjusted_trade_fee_numerator,
103 self.trade_fee_denominator,
104 )
105 }
106}
107
#[cfg(test)]
// Unwrapping is acceptable here: an unexpected `None` should fail the test.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn fee_results() {
        let admin_trade_fee_numerator = 1;
        let admin_trade_fee_denominator = 2;
        let admin_withdraw_fee_numerator = 3;
        let admin_withdraw_fee_denominator = 4;
        let trade_fee_numerator = 5;
        let trade_fee_denominator = 6;
        let withdraw_fee_numerator = 7;
        let withdraw_fee_denominator = 8;
        let fees = Fees {
            admin_trade_fee_numerator,
            admin_trade_fee_denominator,
            admin_withdraw_fee_numerator,
            admin_withdraw_fee_denominator,
            trade_fee_numerator,
            trade_fee_denominator,
            withdraw_fee_numerator,
            withdraw_fee_denominator,
        };

        let trade_amount = 1_000_000_000;
        let expected_trade_fee = trade_amount * trade_fee_numerator / trade_fee_denominator;
        let trade_fee = fees.trade_fee(trade_amount).unwrap();
        assert_eq!(trade_fee, expected_trade_fee);
        let expected_admin_trade_fee =
            expected_trade_fee * admin_trade_fee_numerator / admin_trade_fee_denominator;
        assert_eq!(
            fees.admin_trade_fee(trade_fee).unwrap(),
            expected_admin_trade_fee
        );

        let withdraw_amount = 100_000_000_000;
        let expected_withdraw_fee =
            withdraw_amount * withdraw_fee_numerator / withdraw_fee_denominator;
        let withdraw_fee = fees.withdraw_fee(withdraw_amount).unwrap();
        assert_eq!(withdraw_fee, expected_withdraw_fee);
        let expected_admin_withdraw_fee =
            expected_withdraw_fee * admin_withdraw_fee_numerator / admin_withdraw_fee_denominator;
        // Feed the computed fee (as the trade case does), not the expectation.
        assert_eq!(
            fees.admin_withdraw_fee(withdraw_fee).unwrap(),
            expected_admin_withdraw_fee
        );

        let n_coins: u8 = 2;
        let adjusted_trade_fee_numerator: u64 =
            trade_fee_numerator * (n_coins as u64) / (4 * ((n_coins as u64) - 1));
        let expected_normalized_fee =
            trade_amount * adjusted_trade_fee_numerator / trade_fee_denominator;
        assert_eq!(
            fees.normalized_trade_fee(n_coins, trade_amount).unwrap(),
            expected_normalized_fee
        );
    }

    #[test]
    fn mul_div_wide_and_failing_inputs() {
        // Operands above the fast-path bounds are multiplied in u128.
        assert_eq!(mul_div(u64::MAX, 2, 4), Some(u64::MAX / 2));
        assert_eq!(mul_div_imbalanced(1 << 50, 3, 2), Some(3 << 49));
        // Quotient does not fit in u64.
        assert_eq!(mul_div(u64::MAX, 2, 1), None);
        assert_eq!(mul_div_imbalanced(u64::MAX, 1 << 16, 1), None);
        // Division by zero.
        assert_eq!(mul_div(10, 10, 0), None);
        assert_eq!(mul_div_imbalanced(10, 1, 0), None);
    }

    #[test]
    fn degenerate_fee_inputs() {
        let zero_denominators = Fees {
            admin_trade_fee_numerator: 1,
            admin_trade_fee_denominator: 0,
            admin_withdraw_fee_numerator: 1,
            admin_withdraw_fee_denominator: 0,
            trade_fee_numerator: 1,
            trade_fee_denominator: 0,
            withdraw_fee_numerator: 1,
            withdraw_fee_denominator: 0,
        };
        assert_eq!(zero_denominators.trade_fee(100), None);
        assert_eq!(zero_denominators.withdraw_fee(100), None);
        assert_eq!(zero_denominators.admin_trade_fee(100), None);
        assert_eq!(zero_denominators.admin_withdraw_fee(100), None);
        assert_eq!(zero_denominators.normalized_trade_fee(2, 100), None);

        let fees = Fees {
            admin_trade_fee_numerator: 1,
            admin_trade_fee_denominator: 2,
            admin_withdraw_fee_numerator: 3,
            admin_withdraw_fee_denominator: 4,
            trade_fee_numerator: 5,
            trade_fee_denominator: 6,
            withdraw_fee_numerator: 7,
            withdraw_fee_denominator: 8,
        };
        // n_coins == 0 underflows; n_coins == 1 yields a zero divisor.
        assert_eq!(fees.normalized_trade_fee(0, 1_000), None);
        assert_eq!(fees.normalized_trade_fee(1, 1_000), None);
    }
}