// tensor_amm/hooked/pool.rs

1use std::fmt::{self, Display, Formatter};
2
3use crate::accounts::Pool;
4use crate::errors::TensorAmmError;
5use crate::types::{
6    CurveType, Direction, EditPoolConfig, PoolConfig, PoolStats, PoolType, TakerSide,
7};
8use crate::HUNDRED_PCT_BPS;
9
10use spl_math::precise_number::PreciseNumber;
11
12#[allow(clippy::derivable_impls)]
13impl Default for PoolStats {
14    fn default() -> Self {
15        Self {
16            taker_sell_count: 0,
17            taker_buy_count: 0,
18            accumulated_mm_profit: 0,
19        }
20    }
21}
22
23impl Pool {
24    /// Shifts the price of a pool by a certain offset.
25    pub fn shift_price(&self, price_offset: i32, side: TakerSide) -> Result<u64, TensorAmmError> {
26        let direction = if price_offset > 0 {
27            Direction::Up
28        } else {
29            Direction::Down
30        };
31
32        let offset = price_offset.unsigned_abs();
33
34        let current_price = match self.config.curve_type {
35            CurveType::Linear => {
36                let base = self.config.starting_price;
37                let delta = self.config.delta;
38
39                match direction {
40                    Direction::Up => base
41                        .checked_add(
42                            delta
43                                .checked_mul(offset as u64)
44                                .ok_or(TensorAmmError::ArithmeticError)?,
45                        )
46                        .ok_or(TensorAmmError::ArithmeticError)?,
47                    Direction::Down => base
48                        .checked_sub(
49                            delta
50                                .checked_mul(offset as u64)
51                                .ok_or(TensorAmmError::ArithmeticError)?,
52                        )
53                        .ok_or(TensorAmmError::ArithmeticError)?,
54                }
55            }
56            CurveType::Exponential => {
57                let hundred_pct = PreciseNumber::new(HUNDRED_PCT_BPS.into())
58                    .ok_or(TensorAmmError::ArithmeticError)?;
59
60                let base = PreciseNumber::new(self.config.starting_price.into())
61                    .ok_or(TensorAmmError::ArithmeticError)?;
62
63                let factor = PreciseNumber::new(
64                    (HUNDRED_PCT_BPS)
65                        .checked_add(self.config.delta)
66                        .ok_or(TensorAmmError::ArithmeticError)?
67                        .into(),
68                )
69                .ok_or(TensorAmmError::ArithmeticError)?
70                .checked_div(&hundred_pct)
71                .ok_or(TensorAmmError::ArithmeticError)?
72                .checked_pow(offset.into())
73                .ok_or(TensorAmmError::ArithmeticError)?;
74
75                let result = match direction {
76                    // price * (1 + delta)^trade_count
77                    Direction::Up => base.checked_mul(&factor),
78                    //same but / instead of *
79                    Direction::Down => base.checked_div(&factor),
80                };
81
82                let rounded_result = match side {
83                    TakerSide::Buy => result.ok_or(TensorAmmError::ArithmeticError)?.ceiling(),
84                    TakerSide::Sell => result.ok_or(TensorAmmError::ArithmeticError)?.floor(),
85                };
86
87                let imprecise = rounded_result
88                    .ok_or(TensorAmmError::ArithmeticError)?
89                    .to_imprecise()
90                    .ok_or(TensorAmmError::ArithmeticError)?;
91
92                u64::try_from(imprecise)
93                    .ok()
94                    .ok_or(TensorAmmError::ArithmeticError)?
95            }
96        };
97
98        Ok(current_price)
99    }
100
101    /// Calculate the price of the pool after shifting it by a certain offset.
102    pub fn current_price(&self, side: TakerSide) -> Result<u64, TensorAmmError> {
103        match (self.config.pool_type, side) {
104            (PoolType::Trade, TakerSide::Buy)
105            | (PoolType::Token, TakerSide::Sell)
106            | (PoolType::NFT, TakerSide::Buy) => self.shift_price(self.price_offset, side),
107
108            // Trade pool sells require the price to be shifted down by 1 to prevent
109            // liquidity from being drained by repeated matched buys and sells.
110            (PoolType::Trade, TakerSide::Sell) => self.shift_price(self.price_offset - 1, side),
111
112            // Invalid combinations of pool type and side.
113            _ => Err(TensorAmmError::WrongPoolType),
114        }
115    }
116
117    /// Calculate the fee the MM receives when providing liquidity to a two-sided pool.
118    pub fn calc_mm_fee(&self, current_price: u64) -> Result<u64, TensorAmmError> {
119        let fee = match self.config.pool_type {
120            PoolType::Trade => (self.config.mm_fee_bps.into_base() as u64)
121                .checked_mul(current_price)
122                .ok_or(TensorAmmError::ArithmeticError)?
123                .checked_div(HUNDRED_PCT_BPS)
124                .ok_or(TensorAmmError::ArithmeticError)?,
125            PoolType::NFT | PoolType::Token => 0, // No mm fees for NFT or Token pools
126        };
127
128        Ok(fee)
129    }
130}
131
132impl Display for PoolType {
133    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
134        match self {
135            PoolType::Trade => write!(f, "Trade"),
136            PoolType::Token => write!(f, "Token"),
137            PoolType::NFT => write!(f, "NFT"),
138        }
139    }
140}
141
142impl Display for CurveType {
143    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
144        match self {
145            CurveType::Linear => write!(f, "Linear"),
146            CurveType::Exponential => write!(f, "Exponential"),
147        }
148    }
149}
150
151impl EditPoolConfig {
152    pub fn into_pool_config(self, pool_type: PoolType) -> PoolConfig {
153        PoolConfig {
154            pool_type,
155            curve_type: self.curve_type,
156            starting_price: self.starting_price,
157            delta: self.delta,
158            mm_compound_fees: self.mm_compound_fees,
159            mm_fee_bps: self.mm_fee_bps,
160        }
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    use solana_program::pubkey::Pubkey;

    use crate::{
        types::{PoolConfig, PoolStats},
        Currency, NullableAddress, NullableU16, LAMPORTS_PER_SOL,
    };

    impl Pool {
        /// Builds a pool whose account fields are fixed placeholders.
        ///
        /// Only the arguments (pool/curve type, starting price, delta, price offset, MM fee)
        /// vary between tests.
        pub fn new_test_pool(
            pool_type: PoolType,
            curve_type: CurveType,
            starting_price: u64,
            delta: u64,
            price_offset: i32,
            mm_fee_bps: NullableU16,
        ) -> Self {
            Self {
                discriminator: [0; 8],
                version: 1,
                bump: [1],
                created_at: 1234,
                updated_at: 0,
                expiry: 0,
                owner: Pubkey::default(),
                cosigner: NullableAddress::none(),
                maker_broker: NullableAddress::none(),
                rent_payer: Pubkey::default(),
                whitelist: Pubkey::default(),
                pool_id: [0; 32],
                config: PoolConfig {
                    pool_type,
                    curve_type,
                    starting_price,
                    delta,
                    mm_compound_fees: true,
                    mm_fee_bps,
                },
                price_offset,
                nfts_held: 0,
                stats: PoolStats::default(),
                currency: Currency::sol(),
                amount: 0,
                shared_escrow: NullableAddress::none(),
                max_taker_sell_count: 10,
                reserved: [0; 100],
            }
        }
    }

    // --------------------------------------- Linear

    // Token Pool

    #[test]
    fn test_linear_token_pool() {
        let delta = LAMPORTS_PER_SOL / 10;
        let mut p = Pool::new_test_pool(
            PoolType::Token,
            CurveType::Linear,
            LAMPORTS_PER_SOL,
            delta,
            0,
            NullableU16::none(),
        );
        assert_eq!(p.current_price(TakerSide::Sell).unwrap(), LAMPORTS_PER_SOL);

        // The pool has bought 1 NFT so has a trade "deficit".
        // The price should be shifted down by 1 delta.
        p.price_offset -= 1;
        assert_eq!(
            p.current_price(TakerSide::Sell).unwrap(),
            LAMPORTS_PER_SOL - delta
        );

        // The pool has bought 2 additional NFTs so has a trade "deficit" of 3.
        // The price should be shifted down by 3 deltas.
        p.price_offset -= 2;
        assert_eq!(
            p.current_price(TakerSide::Sell).unwrap(),
            LAMPORTS_PER_SOL - delta * 3
        );

        // The pool has bought 7 additional NFTs so has a trade "deficit" of 10.
        // The price should be shifted down by 10 deltas.
        p.price_offset -= 7;
        assert_eq!(
            p.current_price(TakerSide::Sell).unwrap(),
            LAMPORTS_PER_SOL - delta * 10
        );

        // 10 deltas use up the whole starting price, so the current price is now zero.
        assert_eq!(p.current_price(TakerSide::Sell).unwrap(), 0);
    }

    #[test]
    #[should_panic(expected = "ArithmeticError")]
    fn test_linear_token_pool_panic_overflow() {
        let delta = LAMPORTS_PER_SOL / 10;
        let p = Pool::new_test_pool(
            PoolType::Token,
            CurveType::Linear,
            LAMPORTS_PER_SOL,
            delta,
            -11,
            NullableU16::none(),
        );
        // 11 deltas exceed the starting price, so the subtraction underflows:
        // a negative price is not representable.
        p.current_price(TakerSide::Sell).unwrap();
    }

    #[test]
    #[should_panic(expected = "WrongPoolType")]
    fn test_linear_token_pool_panic_on_buy() {
        let delta = LAMPORTS_PER_SOL / 10;
        let p = Pool::new_test_pool(
            PoolType::Token,
            CurveType::Linear,
            LAMPORTS_PER_SOL,
            delta,
            0,
            NullableU16::none(),
        );
        // Token pools only buy NFTs (seller sells into them),
        // so the taker side cannot be buy.
        p.current_price(TakerSide::Buy).unwrap();
    }

    // NFT Pool

    #[test]
    fn test_linear_nft_pool() {
        let delta = LAMPORTS_PER_SOL / 10;
        let mut p = Pool::new_test_pool(
            PoolType::NFT,
            CurveType::Linear,
            LAMPORTS_PER_SOL,
            delta,
            0,
            NullableU16::none(),
        );
        assert_eq!(p.current_price(TakerSide::Buy).unwrap(), LAMPORTS_PER_SOL);

        // Trade surplus because the pool has sold an NFT to the taker.
        // Current price should be shifted up by 1 delta.
        p.price_offset += 1;
        assert_eq!(
            p.current_price(TakerSide::Buy).unwrap(),
            LAMPORTS_PER_SOL + delta
        );

        // Sell an additional 2 NFTs to the taker and the trade surplus is 3.
        p.price_offset += 2;
        assert_eq!(
            p.current_price(TakerSide::Buy).unwrap(),
            LAMPORTS_PER_SOL + delta * 3
        );

        // Price continues to go up.
        // Real pools will run out of NFTs to sell at some point,
        // but the price calculation in this test should still go up.
        p.price_offset += 9999996;
        assert_eq!(
            p.current_price(TakerSide::Buy).unwrap(),
            LAMPORTS_PER_SOL + delta * 9999999
        );
    }

    #[test]
    #[should_panic(expected = "ArithmeticError")]
    fn test_linear_nft_pool_panic_overflow() {
        let delta = LAMPORTS_PER_SOL / 10 * 100;
        let p = Pool::new_test_pool(
            PoolType::NFT,
            CurveType::Linear,
            LAMPORTS_PER_SOL * 100,
            delta,
            i32::MAX - 1, // delta * offset exceeds u64::MAX
            NullableU16::none(),
        );
        // Cannot go higher
        p.current_price(TakerSide::Buy).unwrap();
    }

    #[test]
    #[should_panic(expected = "WrongPoolType")]
    fn test_linear_nft_pool_panic_on_sell() {
        let delta = LAMPORTS_PER_SOL / 10 * 100;
        let p = Pool::new_test_pool(
            PoolType::NFT,
            CurveType::Linear,
            LAMPORTS_PER_SOL * 100,
            delta,
            0,
            NullableU16::none(),
        );
        // NFT pools only sell NFTs (buyer buys from them).
        p.current_price(TakerSide::Sell).unwrap();
    }

    // Trade Pool

    #[test]
    fn test_linear_trade_pool() {
        let delta = LAMPORTS_PER_SOL / 10;
        let mut p = Pool::new_test_pool(
            PoolType::Trade,
            CurveType::Linear,
            LAMPORTS_PER_SOL,
            delta,
            0,
            NullableU16::none(),
        );
        // NB: selling into the pool is always 1 delta lower than buying.

        assert_eq!(p.current_price(TakerSide::Buy).unwrap(), LAMPORTS_PER_SOL);
        assert_eq!(
            p.current_price(TakerSide::Sell).unwrap(),
            LAMPORTS_PER_SOL - delta
        );

        // Pool's a buyer -> price goes down.
        p.price_offset -= 1;
        assert_eq!(
            p.current_price(TakerSide::Buy).unwrap(),
            LAMPORTS_PER_SOL - delta
        );
        assert_eq!(
            p.current_price(TakerSide::Sell).unwrap(),
            LAMPORTS_PER_SOL - delta * 2
        );

        p.price_offset -= 2;
        assert_eq!(
            p.current_price(TakerSide::Buy).unwrap(),
            LAMPORTS_PER_SOL - delta * 3
        );
        assert_eq!(
            p.current_price(TakerSide::Sell).unwrap(),
            LAMPORTS_PER_SOL - delta * 4
        );

        // Pool can pay 0.
        p.price_offset -= 7;
        assert_eq!(
            p.current_price(TakerSide::Buy).unwrap(),
            LAMPORTS_PER_SOL - delta * 10
        );

        // The sell side needs one more delta (11 in total), which underflows.
        assert!(matches!(
            p.current_price(TakerSide::Sell),
            Err(TensorAmmError::ArithmeticError)
        ));

        // Pool's neutral.
        p.price_offset += 10;
        assert_eq!(p.current_price(TakerSide::Buy).unwrap(), LAMPORTS_PER_SOL);
        assert_eq!(
            p.current_price(TakerSide::Sell).unwrap(),
            LAMPORTS_PER_SOL - delta
        );

        // Pool's a seller -> price goes up.
        p.price_offset += 1;
        assert_eq!(
            p.current_price(TakerSide::Buy).unwrap(),
            LAMPORTS_PER_SOL + delta
        );
        assert_eq!(p.current_price(TakerSide::Sell).unwrap(), LAMPORTS_PER_SOL);

        p.price_offset += 2;
        assert_eq!(
            p.current_price(TakerSide::Buy).unwrap(),
            LAMPORTS_PER_SOL + delta * 3
        );
        assert_eq!(
            p.current_price(TakerSide::Sell).unwrap(),
            LAMPORTS_PER_SOL + delta * 2
        );

        // Go much higher.
        p.price_offset += 9999996;
        assert_eq!(
            p.current_price(TakerSide::Buy).unwrap(),
            LAMPORTS_PER_SOL + delta * 9999999
        );
        assert_eq!(
            p.current_price(TakerSide::Sell).unwrap(),
            LAMPORTS_PER_SOL + delta * 9999998
        );
    }

    #[test]
    #[should_panic(expected = "ArithmeticError")]
    fn test_linear_trade_pool_panic_lower() {
        let delta = LAMPORTS_PER_SOL / 10;
        let p = Pool::new_test_pool(
            PoolType::Trade,
            CurveType::Linear,
            LAMPORTS_PER_SOL,
            delta,
            -11,
            NullableU16::none(),
        );
        p.current_price(TakerSide::Buy).unwrap();
    }

    #[test]
    #[should_panic(expected = "ArithmeticError")]
    fn test_linear_trade_pool_panic_sell_side_lower() {
        let delta = LAMPORTS_PER_SOL / 10;
        let p = Pool::new_test_pool(
            PoolType::Trade,
            CurveType::Linear,
            LAMPORTS_PER_SOL,
            delta,
            -10, // 10 + 1 extra sell-side tick = 11 deltas, which underflows
            NullableU16::none(),
        );
        p.current_price(TakerSide::Sell).unwrap();
    }

    #[test]
    #[should_panic(expected = "ArithmeticError")]
    fn test_linear_trade_pool_panic_upper() {
        let delta = LAMPORTS_PER_SOL * 10_000_000_000;
        let p = Pool::new_test_pool(
            PoolType::Trade,
            CurveType::Linear,
            delta,
            delta,
            1, // just enough to overflow
            NullableU16::none(),
        );
        p.current_price(TakerSide::Buy).unwrap();
    }

    #[test]
    fn test_linear_trade_pool_sell_side_upper() {
        let delta = LAMPORTS_PER_SOL * 10_000_000_000;
        let p = Pool::new_test_pool(
            PoolType::Trade,
            CurveType::Linear,
            delta,
            delta,
            1,
            NullableU16::none(),
        );
        // This shouldn't overflow for the sell side (1 tick lower).
        assert_eq!(p.current_price(TakerSide::Sell).unwrap(), delta);
    }

    // --------------------------------------- Exponential

    // A delta of HUNDRED_PCT_BPS / 10 means each step scales the price by 1.1.

    #[test]
    fn test_exponential_nft_pool() {
        let delta = HUNDRED_PCT_BPS / 10;
        let mut p = Pool::new_test_pool(
            PoolType::NFT,
            CurveType::Exponential,
            LAMPORTS_PER_SOL,
            delta,
            0,
            NullableU16::none(),
        );
        assert_eq!(p.current_price(TakerSide::Buy).unwrap(), LAMPORTS_PER_SOL);

        // One step up: 1 SOL * 1.1
        p.price_offset += 1;
        assert_eq!(
            p.current_price(TakerSide::Buy).unwrap(),
            LAMPORTS_PER_SOL * 11 / 10
        );

        // Two steps up: 1 SOL * 1.1^2 = 1.21 SOL
        p.price_offset += 1;
        assert_eq!(
            p.current_price(TakerSide::Buy).unwrap(),
            LAMPORTS_PER_SOL * 121 / 100
        );
    }

    #[test]
    fn test_exponential_trade_pool() {
        let delta = HUNDRED_PCT_BPS / 10;
        let mut p = Pool::new_test_pool(
            PoolType::Trade,
            CurveType::Exponential,
            LAMPORTS_PER_SOL,
            delta,
            0,
            NullableU16::none(),
        );
        // The sell side is one step lower: 1 SOL / 1.1 = 909_090_909.09..., rounded down.
        assert_eq!(p.current_price(TakerSide::Buy).unwrap(), LAMPORTS_PER_SOL);
        assert_eq!(p.current_price(TakerSide::Sell).unwrap(), 909_090_909);

        // After one step up, the sell side returns to the starting price.
        p.price_offset += 1;
        assert_eq!(
            p.current_price(TakerSide::Buy).unwrap(),
            LAMPORTS_PER_SOL * 11 / 10
        );
        assert_eq!(p.current_price(TakerSide::Sell).unwrap(), LAMPORTS_PER_SOL);
    }

    #[test]
    fn test_exponential_rounding_direction() {
        let delta = HUNDRED_PCT_BPS / 10;
        // One step down from 1 SOL is 909_090_909.09... lamports.
        let nft_pool = Pool::new_test_pool(
            PoolType::NFT,
            CurveType::Exponential,
            LAMPORTS_PER_SOL,
            delta,
            -1,
            NullableU16::none(),
        );
        // Taker buys round up.
        assert_eq!(nft_pool.current_price(TakerSide::Buy).unwrap(), 909_090_910);

        let token_pool = Pool::new_test_pool(
            PoolType::Token,
            CurveType::Exponential,
            LAMPORTS_PER_SOL,
            delta,
            -1,
            NullableU16::none(),
        );
        // Taker sells round down.
        assert_eq!(
            token_pool.current_price(TakerSide::Sell).unwrap(),
            909_090_909
        );
    }
}
515}