1use std::fmt::{self, Display, Formatter};
2
3use crate::accounts::Pool;
4use crate::errors::TensorAmmError;
5use crate::types::{
6 CurveType, Direction, EditPoolConfig, PoolConfig, PoolStats, PoolType, TakerSide,
7};
8use crate::HUNDRED_PCT_BPS;
9
10use spl_math::precise_number::PreciseNumber;
11
12#[allow(clippy::derivable_impls)]
13impl Default for PoolStats {
14 fn default() -> Self {
15 Self {
16 taker_sell_count: 0,
17 taker_buy_count: 0,
18 accumulated_mm_profit: 0,
19 }
20 }
21}
22
23impl Pool {
24 pub fn shift_price(&self, price_offset: i32, side: TakerSide) -> Result<u64, TensorAmmError> {
26 let direction = if price_offset > 0 {
27 Direction::Up
28 } else {
29 Direction::Down
30 };
31
32 let offset = price_offset.unsigned_abs();
33
34 let current_price = match self.config.curve_type {
35 CurveType::Linear => {
36 let base = self.config.starting_price;
37 let delta = self.config.delta;
38
39 match direction {
40 Direction::Up => base
41 .checked_add(
42 delta
43 .checked_mul(offset as u64)
44 .ok_or(TensorAmmError::ArithmeticError)?,
45 )
46 .ok_or(TensorAmmError::ArithmeticError)?,
47 Direction::Down => base
48 .checked_sub(
49 delta
50 .checked_mul(offset as u64)
51 .ok_or(TensorAmmError::ArithmeticError)?,
52 )
53 .ok_or(TensorAmmError::ArithmeticError)?,
54 }
55 }
56 CurveType::Exponential => {
57 let hundred_pct = PreciseNumber::new(HUNDRED_PCT_BPS.into())
58 .ok_or(TensorAmmError::ArithmeticError)?;
59
60 let base = PreciseNumber::new(self.config.starting_price.into())
61 .ok_or(TensorAmmError::ArithmeticError)?;
62
63 let factor = PreciseNumber::new(
64 (HUNDRED_PCT_BPS)
65 .checked_add(self.config.delta)
66 .ok_or(TensorAmmError::ArithmeticError)?
67 .into(),
68 )
69 .ok_or(TensorAmmError::ArithmeticError)?
70 .checked_div(&hundred_pct)
71 .ok_or(TensorAmmError::ArithmeticError)?
72 .checked_pow(offset.into())
73 .ok_or(TensorAmmError::ArithmeticError)?;
74
75 let result = match direction {
76 Direction::Up => base.checked_mul(&factor),
78 Direction::Down => base.checked_div(&factor),
80 };
81
82 let rounded_result = match side {
83 TakerSide::Buy => result.ok_or(TensorAmmError::ArithmeticError)?.ceiling(),
84 TakerSide::Sell => result.ok_or(TensorAmmError::ArithmeticError)?.floor(),
85 };
86
87 let imprecise = rounded_result
88 .ok_or(TensorAmmError::ArithmeticError)?
89 .to_imprecise()
90 .ok_or(TensorAmmError::ArithmeticError)?;
91
92 u64::try_from(imprecise)
93 .ok()
94 .ok_or(TensorAmmError::ArithmeticError)?
95 }
96 };
97
98 Ok(current_price)
99 }
100
101 pub fn current_price(&self, side: TakerSide) -> Result<u64, TensorAmmError> {
103 match (self.config.pool_type, side) {
104 (PoolType::Trade, TakerSide::Buy)
105 | (PoolType::Token, TakerSide::Sell)
106 | (PoolType::NFT, TakerSide::Buy) => self.shift_price(self.price_offset, side),
107
108 (PoolType::Trade, TakerSide::Sell) => self.shift_price(self.price_offset - 1, side),
111
112 _ => Err(TensorAmmError::WrongPoolType),
114 }
115 }
116
117 pub fn calc_mm_fee(&self, current_price: u64) -> Result<u64, TensorAmmError> {
119 let fee = match self.config.pool_type {
120 PoolType::Trade => (self.config.mm_fee_bps.into_base() as u64)
121 .checked_mul(current_price)
122 .ok_or(TensorAmmError::ArithmeticError)?
123 .checked_div(HUNDRED_PCT_BPS)
124 .ok_or(TensorAmmError::ArithmeticError)?,
125 PoolType::NFT | PoolType::Token => 0, };
127
128 Ok(fee)
129 }
130}
131
132impl Display for PoolType {
133 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
134 match self {
135 PoolType::Trade => write!(f, "Trade"),
136 PoolType::Token => write!(f, "Token"),
137 PoolType::NFT => write!(f, "NFT"),
138 }
139 }
140}
141
142impl Display for CurveType {
143 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
144 match self {
145 CurveType::Linear => write!(f, "Linear"),
146 CurveType::Exponential => write!(f, "Exponential"),
147 }
148 }
149}
150
151impl EditPoolConfig {
152 pub fn into_pool_config(self, pool_type: PoolType) -> PoolConfig {
153 PoolConfig {
154 pool_type,
155 curve_type: self.curve_type,
156 starting_price: self.starting_price,
157 delta: self.delta,
158 mm_compound_fees: self.mm_compound_fees,
159 mm_fee_bps: self.mm_fee_bps,
160 }
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    use solana_program::pubkey::Pubkey;

    use crate::{
        types::{PoolConfig, PoolStats},
        Currency, NullableAddress, NullableU16, LAMPORTS_PER_SOL,
    };

    impl Pool {
        /// Builds an in-memory pool for pricing tests.
        ///
        /// The pricing parameters come from the arguments. Every other field gets a
        /// fixed placeholder value.
        pub fn new_test_pool(
            pool_type: PoolType,
            curve_type: CurveType,
            starting_price: u64,
            delta: u64,
            price_offset: i32,
            mm_fee_bps: NullableU16,
        ) -> Self {
            Self {
                discriminator: [0; 8],
                version: 1,
                bump: [1],
                created_at: 1234,
                updated_at: 0,
                expiry: 0,
                owner: Pubkey::default(),
                cosigner: NullableAddress::none(),
                maker_broker: NullableAddress::none(),
                rent_payer: Pubkey::default(),
                whitelist: Pubkey::default(),
                pool_id: [0; 32],
                config: PoolConfig {
                    pool_type,
                    curve_type,
                    starting_price,
                    delta,
                    mm_compound_fees: true,
                    mm_fee_bps,
                },
                price_offset,
                nfts_held: 0,
                stats: PoolStats::default(),
                currency: Currency::sol(),
                amount: 0,
                shared_escrow: NullableAddress::none(),
                max_taker_sell_count: 10,
                reserved: [0; 100],
            }
        }
    }

    #[test]
    fn test_linear_token_pool() {
        let delta = LAMPORTS_PER_SOL / 10;
        let mut p = Pool::new_test_pool(
            PoolType::Token,
            CurveType::Linear,
            LAMPORTS_PER_SOL,
            delta,
            0,
            NullableU16::none(),
        );
        assert_eq!(p.current_price(TakerSide::Sell).unwrap(), LAMPORTS_PER_SOL);

        p.price_offset -= 1;
        assert_eq!(
            p.current_price(TakerSide::Sell).unwrap(),
            LAMPORTS_PER_SOL - delta
        );

        p.price_offset -= 2;
        assert_eq!(
            p.current_price(TakerSide::Sell).unwrap(),
            LAMPORTS_PER_SOL - delta * 3
        );

        p.price_offset -= 7;
        assert_eq!(
            p.current_price(TakerSide::Sell).unwrap(),
            LAMPORTS_PER_SOL - delta * 10
        );

        // Ten 0.1 SOL steps down from 1 SOL land exactly on a zero price, not an underflow.
        assert_eq!(p.current_price(TakerSide::Sell).unwrap(), 0);
    }

    #[test]
    #[should_panic(expected = "ArithmeticError")]
    fn test_linear_token_pool_panic_overflow() {
        let delta = LAMPORTS_PER_SOL / 10;
        let p = Pool::new_test_pool(
            PoolType::Token,
            CurveType::Linear,
            LAMPORTS_PER_SOL,
            delta,
            -11,
            NullableU16::none(),
        );
        p.current_price(TakerSide::Sell).unwrap();
    }

    #[test]
    #[should_panic(expected = "WrongPoolType")]
    fn test_linear_token_pool_panic_on_buy() {
        let delta = LAMPORTS_PER_SOL / 10;
        let p = Pool::new_test_pool(
            PoolType::Token,
            CurveType::Linear,
            LAMPORTS_PER_SOL,
            delta,
            0,
            NullableU16::none(),
        );
        p.current_price(TakerSide::Buy).unwrap();
    }

    #[test]
    fn test_linear_nft_pool() {
        let delta = LAMPORTS_PER_SOL / 10;
        let mut p = Pool::new_test_pool(
            PoolType::NFT,
            CurveType::Linear,
            LAMPORTS_PER_SOL,
            delta,
            0,
            NullableU16::none(),
        );
        assert_eq!(p.current_price(TakerSide::Buy).unwrap(), LAMPORTS_PER_SOL);

        p.price_offset += 1;
        assert_eq!(
            p.current_price(TakerSide::Buy).unwrap(),
            LAMPORTS_PER_SOL + delta
        );

        p.price_offset += 2;
        assert_eq!(
            p.current_price(TakerSide::Buy).unwrap(),
            LAMPORTS_PER_SOL + delta * 3
        );

        p.price_offset += 9999996;
        assert_eq!(
            p.current_price(TakerSide::Buy).unwrap(),
            LAMPORTS_PER_SOL + delta * 9999999
        );
    }

    #[test]
    #[should_panic(expected = "ArithmeticError")]
    fn test_linear_nft_pool_panic_overflow() {
        let delta = LAMPORTS_PER_SOL / 10 * 100;
        let p = Pool::new_test_pool(
            PoolType::NFT,
            CurveType::Linear,
            LAMPORTS_PER_SOL * 100,
            delta,
            i32::MAX - 1,
            NullableU16::none(),
        );
        p.current_price(TakerSide::Buy).unwrap();
    }

    #[test]
    #[should_panic(expected = "WrongPoolType")]
    fn test_linear_nft_pool_panic_on_sell() {
        let delta = LAMPORTS_PER_SOL / 10 * 100;
        let p = Pool::new_test_pool(
            PoolType::NFT,
            CurveType::Linear,
            LAMPORTS_PER_SOL * 100,
            delta,
            0,
            NullableU16::none(),
        );
        p.current_price(TakerSide::Sell).unwrap();
    }

    #[test]
    fn test_linear_trade_pool() {
        let delta = LAMPORTS_PER_SOL / 10;
        let mut p = Pool::new_test_pool(
            PoolType::Trade,
            CurveType::Linear,
            LAMPORTS_PER_SOL,
            delta,
            0,
            NullableU16::none(),
        );
        // Trade pools quote sells one step below buys.
        assert_eq!(p.current_price(TakerSide::Buy).unwrap(), LAMPORTS_PER_SOL);
        assert_eq!(
            p.current_price(TakerSide::Sell).unwrap(),
            LAMPORTS_PER_SOL - delta
        );

        p.price_offset -= 1;
        assert_eq!(
            p.current_price(TakerSide::Buy).unwrap(),
            LAMPORTS_PER_SOL - delta
        );
        assert_eq!(
            p.current_price(TakerSide::Sell).unwrap(),
            LAMPORTS_PER_SOL - delta * 2
        );

        p.price_offset -= 2;
        assert_eq!(
            p.current_price(TakerSide::Buy).unwrap(),
            LAMPORTS_PER_SOL - delta * 3
        );
        assert_eq!(
            p.current_price(TakerSide::Sell).unwrap(),
            LAMPORTS_PER_SOL - delta * 4
        );

        p.price_offset -= 7;
        assert_eq!(
            p.current_price(TakerSide::Buy).unwrap(),
            LAMPORTS_PER_SOL - delta * 10
        );

        p.price_offset += 10;
        assert_eq!(p.current_price(TakerSide::Buy).unwrap(), LAMPORTS_PER_SOL);
        assert_eq!(
            p.current_price(TakerSide::Sell).unwrap(),
            LAMPORTS_PER_SOL - delta
        );

        p.price_offset += 1;
        assert_eq!(
            p.current_price(TakerSide::Buy).unwrap(),
            LAMPORTS_PER_SOL + delta
        );
        assert_eq!(p.current_price(TakerSide::Sell).unwrap(), LAMPORTS_PER_SOL);

        p.price_offset += 2;
        assert_eq!(
            p.current_price(TakerSide::Buy).unwrap(),
            LAMPORTS_PER_SOL + delta * 3
        );
        assert_eq!(
            p.current_price(TakerSide::Sell).unwrap(),
            LAMPORTS_PER_SOL + delta * 2
        );

        p.price_offset += 9999996;
        assert_eq!(
            p.current_price(TakerSide::Buy).unwrap(),
            LAMPORTS_PER_SOL + delta * 9999999
        );
        assert_eq!(
            p.current_price(TakerSide::Sell).unwrap(),
            LAMPORTS_PER_SOL + delta * 9999998
        );
    }

    #[test]
    #[should_panic(expected = "ArithmeticError")]
    fn test_linear_trade_pool_panic_lower() {
        let delta = LAMPORTS_PER_SOL / 10;
        let p = Pool::new_test_pool(
            PoolType::Trade,
            CurveType::Linear,
            LAMPORTS_PER_SOL,
            delta,
            -11,
            NullableU16::none(),
        );
        p.current_price(TakerSide::Buy).unwrap();
    }

    #[test]
    #[should_panic(expected = "ArithmeticError")]
    fn test_linear_trade_pool_panic_sell_side_lower() {
        let delta = LAMPORTS_PER_SOL / 10;
        let p = Pool::new_test_pool(
            PoolType::Trade,
            CurveType::Linear,
            LAMPORTS_PER_SOL,
            delta,
            -10,
            NullableU16::none(),
        );
        p.current_price(TakerSide::Sell).unwrap();
    }

    #[test]
    #[should_panic(expected = "ArithmeticError")]
    fn test_linear_trade_pool_panic_upper() {
        let delta = LAMPORTS_PER_SOL * 10_000_000_000;
        let p = Pool::new_test_pool(
            PoolType::Trade,
            CurveType::Linear,
            delta,
            delta,
            1,
            NullableU16::none(),
        );
        p.current_price(TakerSide::Buy).unwrap();
    }

    #[test]
    fn test_linear_trade_pool_sell_side_upper() {
        let delta = LAMPORTS_PER_SOL * 10_000_000_000;
        let p = Pool::new_test_pool(
            PoolType::Trade,
            CurveType::Linear,
            delta,
            delta,
            1,
            NullableU16::none(),
        );
        assert_eq!(p.current_price(TakerSide::Sell).unwrap(), delta);
    }

    #[test]
    fn test_exponential_nft_pool() {
        // A delta of HUNDRED_PCT_BPS / 10 is a 10% step, i.e. a factor of 1.1 per offset.
        let delta = HUNDRED_PCT_BPS / 10;
        let starting_price: u64 = 1_000_000_000;
        let mut p = Pool::new_test_pool(
            PoolType::NFT,
            CurveType::Exponential,
            starting_price,
            delta,
            0,
            NullableU16::none(),
        );
        assert_eq!(p.current_price(TakerSide::Buy).unwrap(), starting_price);

        // 1e9 * 1.1 and 1e9 * 1.1^2 are whole numbers, so rounding does not affect them.
        p.price_offset += 1;
        assert_eq!(p.current_price(TakerSide::Buy).unwrap(), 1_100_000_000);

        p.price_offset += 1;
        assert_eq!(p.current_price(TakerSide::Buy).unwrap(), 1_210_000_000);
    }

    #[test]
    fn test_exponential_trade_pool_rounding() {
        // 10% step: factor 1.1 per offset.
        let delta = HUNDRED_PCT_BPS / 10;
        let starting_price: u64 = 1_000_000_000;
        let mut p = Pool::new_test_pool(
            PoolType::Trade,
            CurveType::Exponential,
            starting_price,
            delta,
            0,
            NullableU16::none(),
        );
        assert_eq!(p.current_price(TakerSide::Buy).unwrap(), starting_price);
        // Sell is quoted one step down: 1e9 / 1.1 = 909_090_909.09..., rounded down.
        assert_eq!(p.current_price(TakerSide::Sell).unwrap(), 909_090_909);

        p.price_offset -= 1;
        // Same fractional price on the buy side, rounded up.
        assert_eq!(p.current_price(TakerSide::Buy).unwrap(), 909_090_910);
        // 1e9 / 1.21 = 826_446_280.99..., rounded down.
        assert_eq!(p.current_price(TakerSide::Sell).unwrap(), 826_446_280);
    }

    #[test]
    fn test_mm_fee_zero_for_one_sided_pools() {
        for pool_type in [PoolType::NFT, PoolType::Token] {
            let p = Pool::new_test_pool(
                pool_type,
                CurveType::Linear,
                LAMPORTS_PER_SOL,
                LAMPORTS_PER_SOL / 10,
                0,
                NullableU16::none(),
            );
            assert_eq!(p.calc_mm_fee(LAMPORTS_PER_SOL).unwrap(), 0);
        }
    }
}