1use {
4 crate::{
5 big_vec::BigVec, error::StakePoolError, MAX_WITHDRAWAL_FEE_INCREASE,
6 WITHDRAWAL_BASELINE_FEE,
7 },
8 borsh::{BorshDeserialize, BorshSchema, BorshSerialize},
9 num_derive::FromPrimitive,
10 num_traits::FromPrimitive,
11 solana_program::{
12 account_info::AccountInfo,
13 borsh::get_instance_packed_len,
14 msg,
15 program_error::ProgramError,
16 program_memory::sol_memcmp,
17 program_pack::{Pack, Sealed},
18 pubkey::{Pubkey, PUBKEY_BYTES},
19 stake::state::Lockup,
20 },
21 spl_token_2022::{
22 extension::{BaseStateWithExtensions, ExtensionType, StateWithExtensions},
23 state::{Account, AccountState, Mint},
24 },
25 std::{borrow::Borrow, convert::TryFrom, fmt, matches},
26};
27
/// Tag stored as the first field of the stake pool program's account
/// structures ([`StakePool`] and [`ValidatorListHeader`]), used by their
/// `is_valid` / `is_uninitialized` checks.
///
/// The derived Borsh encoding writes the variant index, so the variant order
/// below is part of the on-chain layout and must not be changed.
#[derive(Clone, Debug, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
pub enum AccountType {
    /// Not yet initialized (variant index 0).
    Uninitialized,
    /// The account holds a [`StakePool`].
    StakePool,
    /// The account holds a [`ValidatorList`].
    ValidatorList,
}
38
39impl Default for AccountType {
40 fn default() -> Self {
41 AccountType::Uninitialized
42 }
43}
44
/// Top-level state of a stake pool account.
///
/// Serialized with Borsh in declaration order: do not reorder, insert, or
/// remove fields without a migration.
#[repr(C)]
#[derive(Clone, Debug, Default, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
pub struct StakePool {
    /// Must be `AccountType::StakePool` for an initialized pool (see `is_valid`).
    pub account_type: AccountType,

    /// Manager key; `check_manager` requires it to match and to sign.
    pub manager: Pubkey,

    /// Staker key; `check_staker` requires it to match and to sign.
    pub staker: Pubkey,

    /// Authority compared against by `check_stake_deposit_authority`.
    pub stake_deposit_authority: Pubkey,

    /// Bump seed for the withdraw-authority PDA derived from
    /// `[pool address, AUTHORITY_WITHDRAW, bump]` (see `check_authority_withdraw`).
    pub stake_withdraw_bump_seed: u8,

    /// Address of the pool's validator list account (see `check_validator_list`).
    pub validator_list: Pubkey,

    /// Address of the pool's reserve stake account (see `check_reserve_stake`).
    pub reserve_stake: Pubkey,

    /// Mint of the pool token (see `check_mint` and `check_manager_fee_info`).
    pub pool_mint: Pubkey,

    /// Token account for the manager's fees.
    /// NOTE(review): not compared against in this file; presumably checked by callers.
    pub manager_fee_account: Pubkey,

    /// Token program expected to own the manager fee account.
    pub token_program_id: Pubkey,

    /// Lamports backing the pool; numerator/denominator of the exchange rate
    /// used by the `calc_*` conversions.
    pub total_lamports: u64,

    /// Pool tokens outstanding; the other side of the exchange rate.
    pub pool_token_supply: u64,

    /// Epoch of the last state update (presumably set by the update instruction — confirm).
    pub last_update_epoch: u64,

    /// Stake lockup. NOTE(review): how it is applied is not visible in this file.
    pub lockup: Lockup,

    /// Fee currently applied to epoch rewards (`calc_epoch_fee_amount`).
    pub epoch_fee: Fee,

    /// Epoch fee scheduled by `update_fee`, not yet in effect.
    pub next_epoch_fee: FutureEpoch<Fee>,

    /// Optional preferred validator for deposits (semantics enforced elsewhere).
    pub preferred_deposit_validator_vote_address: Option<Pubkey>,

    /// Optional preferred validator for withdrawals (semantics enforced elsewhere).
    pub preferred_withdraw_validator_vote_address: Option<Pubkey>,

    /// Fee on pool tokens minted for stake deposits.
    pub stake_deposit_fee: Fee,

    /// Fee on pool tokens burned for stake withdrawals.
    pub stake_withdrawal_fee: Fee,

    /// Stake withdrawal fee scheduled by `update_fee`, not yet in effect.
    pub next_stake_withdrawal_fee: FutureEpoch<Fee>,

    /// Percentage (0-100) of the stake deposit fee given to the referrer.
    pub stake_referral_fee: u8,

    /// When `Some`, SOL deposits must be signed by this key
    /// (see `check_sol_deposit_authority`).
    pub sol_deposit_authority: Option<Pubkey>,

    /// Fee on pool tokens minted for SOL deposits.
    pub sol_deposit_fee: Fee,

    /// Percentage (0-100) of the SOL deposit fee given to the referrer.
    pub sol_referral_fee: u8,

    /// When `Some`, SOL withdrawals must be signed by this key
    /// (see `check_sol_withdraw_authority`).
    pub sol_withdraw_authority: Option<Pubkey>,

    /// Fee on pool tokens burned for SOL withdrawals.
    pub sol_withdrawal_fee: Fee,

    /// SOL withdrawal fee scheduled by `update_fee`, not yet in effect.
    pub next_sol_withdrawal_fee: FutureEpoch<Fee>,

    /// Pool token supply recorded at the previous epoch (e.g. for APR estimates).
    pub last_epoch_pool_token_supply: u64,

    /// Total lamports recorded at the previous epoch (e.g. for APR estimates).
    pub last_epoch_total_lamports: u64,
}
159impl StakePool {
160 #[inline]
162 pub fn calc_pool_tokens_for_deposit(&self, stake_lamports: u64) -> Option<u64> {
163 if self.total_lamports == 0 || self.pool_token_supply == 0 {
164 return Some(stake_lamports);
165 }
166 u64::try_from(
167 (stake_lamports as u128)
168 .checked_mul(self.pool_token_supply as u128)?
169 .checked_div(self.total_lamports as u128)?,
170 )
171 .ok()
172 }
173
174 #[inline]
176 pub fn calc_lamports_withdraw_amount(&self, pool_tokens: u64) -> Option<u64> {
177 let numerator = (pool_tokens as u128).checked_mul(self.total_lamports as u128)?;
181 let denominator = self.pool_token_supply as u128;
182 if numerator < denominator || denominator == 0 {
183 Some(0)
184 } else {
185 u64::try_from(numerator.checked_div(denominator)?).ok()
186 }
187 }
188
189 #[inline]
191 pub fn calc_pool_tokens_stake_withdrawal_fee(&self, pool_tokens: u64) -> Option<u64> {
192 u64::try_from(self.stake_withdrawal_fee.apply(pool_tokens)?).ok()
193 }
194
195 #[inline]
197 pub fn calc_pool_tokens_sol_withdrawal_fee(&self, pool_tokens: u64) -> Option<u64> {
198 u64::try_from(self.sol_withdrawal_fee.apply(pool_tokens)?).ok()
199 }
200
201 #[inline]
203 pub fn calc_pool_tokens_stake_deposit_fee(&self, pool_tokens_minted: u64) -> Option<u64> {
204 u64::try_from(self.stake_deposit_fee.apply(pool_tokens_minted)?).ok()
205 }
206
207 #[inline]
209 pub fn calc_pool_tokens_stake_referral_fee(&self, stake_deposit_fee: u64) -> Option<u64> {
210 u64::try_from(
211 (stake_deposit_fee as u128)
212 .checked_mul(self.stake_referral_fee as u128)?
213 .checked_div(100u128)?,
214 )
215 .ok()
216 }
217
218 #[inline]
220 pub fn calc_pool_tokens_sol_deposit_fee(&self, pool_tokens_minted: u64) -> Option<u64> {
221 u64::try_from(self.sol_deposit_fee.apply(pool_tokens_minted)?).ok()
222 }
223
224 #[inline]
226 pub fn calc_pool_tokens_sol_referral_fee(&self, sol_deposit_fee: u64) -> Option<u64> {
227 u64::try_from(
228 (sol_deposit_fee as u128)
229 .checked_mul(self.sol_referral_fee as u128)?
230 .checked_div(100u128)?,
231 )
232 .ok()
233 }
234
235 #[inline]
240 pub fn calc_epoch_fee_amount(&self, reward_lamports: u64) -> Option<u64> {
241 if reward_lamports == 0 {
242 return Some(0);
243 }
244 let total_lamports = (self.total_lamports as u128).checked_add(reward_lamports as u128)?;
245 let fee_lamports = self.epoch_fee.apply(reward_lamports)?;
246 if total_lamports == fee_lamports || self.pool_token_supply == 0 {
247 Some(reward_lamports)
248 } else {
249 u64::try_from(
250 (self.pool_token_supply as u128)
251 .checked_mul(fee_lamports)?
252 .checked_div(total_lamports.checked_sub(fee_lamports)?)?,
253 )
254 .ok()
255 }
256 }
257
258 #[inline]
260 pub fn get_lamports_per_pool_token(&self) -> Option<u64> {
261 self.total_lamports
262 .checked_add(self.pool_token_supply)?
263 .checked_sub(1)?
264 .checked_div(self.pool_token_supply)
265 }
266
267 fn check_program_derived_authority(
269 authority_address: &Pubkey,
270 program_id: &Pubkey,
271 stake_pool_address: &Pubkey,
272 authority_seed: &[u8],
273 bump_seed: u8,
274 ) -> Result<(), ProgramError> {
275 let expected_address = Pubkey::create_program_address(
276 &[stake_pool_address.as_ref(), authority_seed, &[bump_seed]],
277 program_id,
278 )?;
279
280 if *authority_address == expected_address {
281 Ok(())
282 } else {
283 msg!(
284 "Incorrect authority provided, expected {}, received {}",
285 expected_address,
286 authority_address
287 );
288 Err(StakePoolError::InvalidProgramAddress.into())
289 }
290 }
291
292 pub(crate) fn check_manager_fee_info(
295 &self,
296 manager_fee_info: &AccountInfo,
297 ) -> Result<(), ProgramError> {
298 let account_data = manager_fee_info.try_borrow_data()?;
299 let token_account = StateWithExtensions::<Account>::unpack(&account_data)?;
300 if manager_fee_info.owner != &self.token_program_id
301 || token_account.base.state != AccountState::Initialized
302 || token_account.base.mint != self.pool_mint
303 {
304 msg!("Manager fee account is not owned by token program, is not initialized, or does not match stake pool's mint");
305 return Err(StakePoolError::InvalidFeeAccount.into());
306 }
307 let extensions = token_account.get_extension_types()?;
308 if extensions
309 .iter()
310 .any(|x| !is_extension_supported_for_fee_account(x))
311 {
312 return Err(StakePoolError::UnsupportedFeeAccountExtension.into());
313 }
314 Ok(())
315 }
316
317 #[inline]
319 pub(crate) fn check_authority_withdraw(
320 &self,
321 withdraw_authority: &Pubkey,
322 program_id: &Pubkey,
323 stake_pool_address: &Pubkey,
324 ) -> Result<(), ProgramError> {
325 Self::check_program_derived_authority(
326 withdraw_authority,
327 program_id,
328 stake_pool_address,
329 crate::AUTHORITY_WITHDRAW,
330 self.stake_withdraw_bump_seed,
331 )
332 }
333 #[inline]
335 pub(crate) fn check_stake_deposit_authority(
336 &self,
337 stake_deposit_authority: &Pubkey,
338 ) -> Result<(), ProgramError> {
339 if self.stake_deposit_authority == *stake_deposit_authority {
340 Ok(())
341 } else {
342 Err(StakePoolError::InvalidStakeDepositAuthority.into())
343 }
344 }
345
346 #[inline]
349 pub(crate) fn check_sol_deposit_authority(
350 &self,
351 maybe_sol_deposit_authority: Result<&AccountInfo, ProgramError>,
352 ) -> Result<(), ProgramError> {
353 if let Some(auth) = self.sol_deposit_authority {
354 let sol_deposit_authority = maybe_sol_deposit_authority?;
355 if auth != *sol_deposit_authority.key {
356 msg!("Expected {}, received {}", auth, sol_deposit_authority.key);
357 return Err(StakePoolError::InvalidSolDepositAuthority.into());
358 }
359 if !sol_deposit_authority.is_signer {
360 msg!("SOL Deposit authority signature missing");
361 return Err(StakePoolError::SignatureMissing.into());
362 }
363 }
364 Ok(())
365 }
366
367 #[inline]
370 pub(crate) fn check_sol_withdraw_authority(
371 &self,
372 maybe_sol_withdraw_authority: Result<&AccountInfo, ProgramError>,
373 ) -> Result<(), ProgramError> {
374 if let Some(auth) = self.sol_withdraw_authority {
375 let sol_withdraw_authority = maybe_sol_withdraw_authority?;
376 if auth != *sol_withdraw_authority.key {
377 return Err(StakePoolError::InvalidSolWithdrawAuthority.into());
378 }
379 if !sol_withdraw_authority.is_signer {
380 msg!("SOL withdraw authority signature missing");
381 return Err(StakePoolError::SignatureMissing.into());
382 }
383 }
384 Ok(())
385 }
386
387 #[inline]
389 pub(crate) fn check_mint(&self, mint_info: &AccountInfo) -> Result<u8, ProgramError> {
390 if *mint_info.key != self.pool_mint {
391 Err(StakePoolError::WrongPoolMint.into())
392 } else {
393 let mint_data = mint_info.try_borrow_data()?;
394 let mint = StateWithExtensions::<Mint>::unpack(&mint_data)?;
395 Ok(mint.base.decimals)
396 }
397 }
398
399 pub(crate) fn check_manager(&self, manager_info: &AccountInfo) -> Result<(), ProgramError> {
401 if *manager_info.key != self.manager {
402 msg!(
403 "Incorrect manager provided, expected {}, received {}",
404 self.manager,
405 manager_info.key
406 );
407 return Err(StakePoolError::WrongManager.into());
408 }
409 if !manager_info.is_signer {
410 msg!("Manager signature missing");
411 return Err(StakePoolError::SignatureMissing.into());
412 }
413 Ok(())
414 }
415
416 pub(crate) fn check_staker(&self, staker_info: &AccountInfo) -> Result<(), ProgramError> {
418 if *staker_info.key != self.staker {
419 msg!(
420 "Incorrect staker provided, expected {}, received {}",
421 self.staker,
422 staker_info.key
423 );
424 return Err(StakePoolError::WrongStaker.into());
425 }
426 if !staker_info.is_signer {
427 msg!("Staker signature missing");
428 return Err(StakePoolError::SignatureMissing.into());
429 }
430 Ok(())
431 }
432
433 pub fn check_validator_list(
435 &self,
436 validator_list_info: &AccountInfo,
437 ) -> Result<(), ProgramError> {
438 if *validator_list_info.key != self.validator_list {
439 msg!(
440 "Invalid validator list provided, expected {}, received {}",
441 self.validator_list,
442 validator_list_info.key
443 );
444 Err(StakePoolError::InvalidValidatorStakeList.into())
445 } else {
446 Ok(())
447 }
448 }
449
450 pub fn check_reserve_stake(
452 &self,
453 reserve_stake_info: &AccountInfo,
454 ) -> Result<(), ProgramError> {
455 if *reserve_stake_info.key != self.reserve_stake {
456 msg!(
457 "Invalid reserve stake provided, expected {}, received {}",
458 self.reserve_stake,
459 reserve_stake_info.key
460 );
461 Err(StakePoolError::InvalidProgramAddress.into())
462 } else {
463 Ok(())
464 }
465 }
466
467 pub fn is_valid(&self) -> bool {
469 self.account_type == AccountType::StakePool
470 }
471
472 pub fn is_uninitialized(&self) -> bool {
474 self.account_type == AccountType::Uninitialized
475 }
476
477 pub fn update_fee(&mut self, fee: &FeeType) -> Result<(), StakePoolError> {
479 match fee {
480 FeeType::SolReferral(new_fee) => self.sol_referral_fee = *new_fee,
481 FeeType::StakeReferral(new_fee) => self.stake_referral_fee = *new_fee,
482 FeeType::Epoch(new_fee) => self.next_epoch_fee = FutureEpoch::new(*new_fee),
483 FeeType::StakeWithdrawal(new_fee) => {
484 new_fee.check_withdrawal(&self.stake_withdrawal_fee)?;
485 self.next_stake_withdrawal_fee = FutureEpoch::new(*new_fee)
486 }
487 FeeType::SolWithdrawal(new_fee) => {
488 new_fee.check_withdrawal(&self.sol_withdrawal_fee)?;
489 self.next_sol_withdrawal_fee = FutureEpoch::new(*new_fee)
490 }
491 FeeType::SolDeposit(new_fee) => self.sol_deposit_fee = *new_fee,
492 FeeType::StakeDeposit(new_fee) => self.stake_deposit_fee = *new_fee,
493 };
494 Ok(())
495 }
496}
497
498pub fn is_extension_supported_for_mint(extension_type: &ExtensionType) -> bool {
500 const SUPPORTED_EXTENSIONS: [ExtensionType; 5] = [
501 ExtensionType::Uninitialized,
502 ExtensionType::TransferFeeConfig,
503 ExtensionType::ConfidentialTransferMint,
504 ExtensionType::DefaultAccountState, ExtensionType::InterestBearingConfig,
506 ];
507 if !SUPPORTED_EXTENSIONS.contains(extension_type) {
508 msg!(
509 "Stake pool mint account cannot have the {:?} extension",
510 extension_type
511 );
512 false
513 } else {
514 true
515 }
516}
517
518pub fn is_extension_supported_for_fee_account(extension_type: &ExtensionType) -> bool {
520 const SUPPORTED_EXTENSIONS: [ExtensionType; 4] = [
524 ExtensionType::Uninitialized,
525 ExtensionType::TransferFeeAmount,
526 ExtensionType::ImmutableOwner,
527 ExtensionType::CpiGuard,
528 ];
529 if !SUPPORTED_EXTENSIONS.contains(extension_type) {
530 msg!("Fee account cannot have the {:?} extension", extension_type);
531 false
532 } else {
533 true
534 }
535}
536
/// Contents of the validator list account: a fixed header followed by one
/// [`ValidatorStakeInfo`] entry per validator.
///
/// Borsh-serialized in field order; `validators` is encoded with a length
/// prefix (accounted for in `ValidatorList::calculate_max_validators`).
#[repr(C)]
#[derive(Clone, Debug, Default, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
pub struct ValidatorList {
    /// Account tag and capacity.
    pub header: ValidatorListHeader,

    /// Per-validator entries.
    pub validators: Vec<ValidatorStakeInfo>,
}
547
/// Fixed-size prefix of the validator list account (`ValidatorListHeader::LEN`
/// = 1 byte account type + 4 bytes capacity).
#[repr(C)]
#[derive(Clone, Debug, Default, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
pub struct ValidatorListHeader {
    /// Must be `AccountType::ValidatorList` once initialized.
    pub account_type: AccountType,

    /// Maximum number of validator entries the list was created for.
    pub max_validators: u32,
}
558
/// Lifecycle state of a validator entry.
///
/// The discriminants are part of the serialized layout: `is_not_removed`
/// decodes the raw status byte with `FromPrimitive::from_u8`, so the order
/// must not change.
#[derive(
    FromPrimitive, Copy, Clone, Debug, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema,
)]
pub enum StakeStatus {
    /// Normal, actively staked validator.
    Active,
    /// Only the transient stake remains to be cleaned up
    /// (`remove_transient_stake` moves this to `ReadyForRemoval`).
    DeactivatingTransient,
    /// Nothing left to clean up; the entry can be dropped from the list.
    ReadyForRemoval,
    /// Only the validator stake remains to be cleaned up
    /// (`remove_validator_stake` moves this to `ReadyForRemoval`).
    DeactivatingValidator,
    /// Both validator and transient stake remain to be cleaned up.
    DeactivatingAll,
}
579impl StakeStatus {
580 pub fn remove_validator_stake(&mut self) {
582 let new_self = match self {
583 Self::Active | Self::DeactivatingTransient | Self::ReadyForRemoval => *self,
584 Self::DeactivatingAll => Self::DeactivatingTransient,
585 Self::DeactivatingValidator => Self::ReadyForRemoval,
586 };
587 *self = new_self;
588 }
589 pub fn remove_transient_stake(&mut self) {
591 let new_self = match self {
592 Self::Active | Self::DeactivatingValidator | Self::ReadyForRemoval => *self,
593 Self::DeactivatingAll => Self::DeactivatingValidator,
594 Self::DeactivatingTransient => Self::ReadyForRemoval,
595 };
596 *self = new_self;
597 }
598}
599impl Default for StakeStatus {
600 fn default() -> Self {
601 Self::Active
602 }
603}
604
/// Which stake account a withdrawal is taken from.
/// NOTE(review): the variants' exact use is defined by callers outside this
/// file; the descriptions below follow the names and should be confirmed.
#[derive(Debug, PartialEq)]
pub(crate) enum StakeWithdrawSource {
    /// Presumably the validator's active stake account.
    Active,
    /// Presumably the validator's transient stake account.
    Transient,
    /// Presumably a withdrawal that also removes the validator.
    ValidatorRemoval,
}
615
/// One entry of the validator list, packed as 73 bytes (`Pack::LEN`).
///
/// Byte offsets, relied on by the raw-slice helpers in `impl ValidatorStakeInfo`:
/// `0..8` active, `8..16` transient, `16..24` last update epoch,
/// `24..32` transient seed, `32..36` unused, `36..40` validator seed,
/// `40` status, `41..73` vote account address. Do not reorder fields.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
pub struct ValidatorStakeInfo {
    /// Lamports in the validator's active stake account.
    pub active_stake_lamports: u64,

    /// Lamports in the validator's transient stake account.
    pub transient_stake_lamports: u64,

    /// Epoch this entry was last updated.
    pub last_update_epoch: u64,

    /// Seed suffix used for the transient stake account address (derivation elsewhere).
    pub transient_seed_suffix: u64,

    /// Not read in this file; appears reserved.
    pub unused: u32,

    /// Seed suffix used for the validator stake account address (derivation elsewhere).
    pub validator_seed_suffix: u32,

    /// Lifecycle state; its single byte sits at offset 40.
    pub status: StakeStatus,

    /// Vote account of the validator; bytes 41..73.
    pub vote_account_address: Pubkey,
}
657
658impl ValidatorStakeInfo {
659 pub fn stake_lamports(&self) -> Result<u64, StakePoolError> {
661 self.active_stake_lamports
662 .checked_add(self.transient_stake_lamports)
663 .ok_or(StakePoolError::CalculationFailure)
664 }
665
666 pub fn memcmp_pubkey(data: &[u8], vote_address: &Pubkey) -> bool {
669 sol_memcmp(
670 &data[41..41_usize.saturating_add(PUBKEY_BYTES)],
671 vote_address.as_ref(),
672 PUBKEY_BYTES,
673 ) == 0
674 }
675
676 pub fn active_lamports_greater_than(data: &[u8], lamports: &u64) -> bool {
679 u64::try_from_slice(&data[0..8]).unwrap() > *lamports
681 }
682
683 pub fn transient_lamports_greater_than(data: &[u8], lamports: &u64) -> bool {
686 u64::try_from_slice(&data[8..16]).unwrap() > *lamports
688 }
689
690 pub fn is_not_removed(data: &[u8]) -> bool {
692 FromPrimitive::from_u8(data[40]) != Some(StakeStatus::ReadyForRemoval)
693 }
694}
695
696impl Sealed for ValidatorStakeInfo {}
697
698impl Pack for ValidatorStakeInfo {
699 const LEN: usize = 73;
700 fn pack_into_slice(&self, data: &mut [u8]) {
701 let mut data = data;
702 self.serialize(&mut data).unwrap();
705 }
706 fn unpack_from_slice(src: &[u8]) -> Result<Self, ProgramError> {
707 let unpacked = Self::try_from_slice(src)?;
708 Ok(unpacked)
709 }
710}
711
712impl ValidatorList {
713 pub fn new(max_validators: u32) -> Self {
715 Self {
716 header: ValidatorListHeader {
717 account_type: AccountType::ValidatorList,
718 max_validators,
719 },
720 validators: vec![ValidatorStakeInfo::default(); max_validators as usize],
721 }
722 }
723
724 pub fn calculate_max_validators(buffer_length: usize) -> usize {
726 let header_size = ValidatorListHeader::LEN.saturating_add(4);
727 buffer_length
728 .saturating_sub(header_size)
729 .saturating_div(ValidatorStakeInfo::LEN)
730 }
731
732 pub fn contains(&self, vote_account_address: &Pubkey) -> bool {
734 self.validators
735 .iter()
736 .any(|x| x.vote_account_address == *vote_account_address)
737 }
738
739 pub fn find_mut(&mut self, vote_account_address: &Pubkey) -> Option<&mut ValidatorStakeInfo> {
741 self.validators
742 .iter_mut()
743 .find(|x| x.vote_account_address == *vote_account_address)
744 }
745 pub fn find(&self, vote_account_address: &Pubkey) -> Option<&ValidatorStakeInfo> {
747 self.validators
748 .iter()
749 .find(|x| x.vote_account_address == *vote_account_address)
750 }
751
752 pub fn has_active_stake(&self) -> bool {
754 self.validators.iter().any(|x| x.active_stake_lamports > 0)
755 }
756}
757
758impl ValidatorListHeader {
759 const LEN: usize = 1 + 4;
760
761 pub fn is_valid(&self) -> bool {
763 self.account_type == AccountType::ValidatorList
764 }
765
766 pub fn is_uninitialized(&self) -> bool {
768 self.account_type == AccountType::Uninitialized
769 }
770
771 pub fn deserialize_mut_slice(
774 data: &mut [u8],
775 skip: usize,
776 len: usize,
777 ) -> Result<(Self, Vec<&mut ValidatorStakeInfo>), ProgramError> {
778 let (header, mut big_vec) = Self::deserialize_vec(data)?;
779 let validator_list = big_vec.deserialize_mut_slice::<ValidatorStakeInfo>(skip, len)?;
780 Ok((header, validator_list))
781 }
782
783 pub fn deserialize_vec(data: &mut [u8]) -> Result<(Self, BigVec), ProgramError> {
785 let mut data_mut = data.borrow();
786 let header = ValidatorListHeader::deserialize(&mut data_mut)?;
787 let length = get_instance_packed_len(&header)?;
788
789 let big_vec = BigVec {
790 data: &mut data[length..],
791 };
792 Ok((header, big_vec))
793 }
794}
795
/// A value scheduled to take effect at a later epoch boundary.
///
/// `update_epoch` steps `Two` -> `One` -> `None`, and `get` only yields the
/// value in the `One` state. Variant order is part of the Borsh encoding.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, BorshSerialize, BorshDeserialize, BorshSchema)]
pub enum FutureEpoch<T> {
    /// Nothing scheduled.
    None,
    /// Scheduled value, visible through `get`; cleared by the next `update_epoch`.
    One(T),
    /// Freshly scheduled value (as created by `FutureEpoch::new`); becomes `One`
    /// after the next `update_epoch`.
    Two(T),
}
808impl<T> Default for FutureEpoch<T> {
809 fn default() -> Self {
810 Self::None
811 }
812}
813impl<T> FutureEpoch<T> {
814 pub fn new(value: T) -> Self {
816 Self::Two(value)
817 }
818}
819impl<T: Clone> FutureEpoch<T> {
820 pub fn update_epoch(&mut self) {
822 match self {
823 Self::None => {}
824 Self::One(_) => {
825 *self = Self::None;
827 }
828 Self::Two(v) => {
830 *self = Self::One(v.clone());
831 }
832 }
833 }
834
835 pub fn get(&self) -> Option<&T> {
837 match self {
838 Self::None | Self::Two(_) => None,
839 Self::One(v) => Some(v),
840 }
841 }
842}
843impl<T> From<FutureEpoch<T>> for Option<T> {
844 fn from(v: FutureEpoch<T>) -> Option<T> {
845 match v {
846 FutureEpoch::None => None,
847 FutureEpoch::One(inner) | FutureEpoch::Two(inner) => Some(inner),
848 }
849 }
850}
851
/// Fee expressed as the ratio `numerator / denominator`.
///
/// A zero denominator is treated as "no fee" by [`Fee::apply`]. Field order is
/// part of the Borsh encoding (denominator is serialized first) and must not
/// change.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq, BorshSerialize, BorshDeserialize, BorshSchema)]
pub struct Fee {
    /// Denominator of the fee ratio.
    pub denominator: u64,
    /// Numerator of the fee ratio.
    pub numerator: u64,
}
863
864impl Fee {
865 #[inline]
870 pub fn apply(&self, amt: u64) -> Option<u128> {
871 if self.denominator == 0 {
872 return Some(0);
873 }
874 (amt as u128)
875 .checked_mul(self.numerator as u128)?
876 .checked_div(self.denominator as u128)
877 }
878
879 pub fn check_withdrawal(&self, old_withdrawal_fee: &Fee) -> Result<(), StakePoolError> {
883 let (old_num, old_denom) =
886 if old_withdrawal_fee.denominator == 0 || old_withdrawal_fee.numerator == 0 {
887 (
888 WITHDRAWAL_BASELINE_FEE.numerator,
889 WITHDRAWAL_BASELINE_FEE.denominator,
890 )
891 } else {
892 (old_withdrawal_fee.numerator, old_withdrawal_fee.denominator)
893 };
894
895 if (old_num as u128)
898 .checked_mul(self.denominator as u128)
899 .map(|x| x.checked_mul(MAX_WITHDRAWAL_FEE_INCREASE.numerator as u128))
900 .ok_or(StakePoolError::CalculationFailure)?
901 < (self.numerator as u128)
902 .checked_mul(old_denom as u128)
903 .map(|x| x.checked_mul(MAX_WITHDRAWAL_FEE_INCREASE.denominator as u128))
904 .ok_or(StakePoolError::CalculationFailure)?
905 {
906 msg!(
907 "Fee increase exceeds maximum allowed, proposed increase factor ({} / {})",
908 self.numerator.saturating_mul(old_denom),
909 old_num.saturating_mul(self.denominator),
910 );
911 return Err(StakePoolError::FeeIncreaseTooHigh);
912 }
913 Ok(())
914 }
915}
916
917impl fmt::Display for Fee {
918 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
919 if self.numerator > 0 && self.denominator > 0 {
920 write!(f, "{}/{}", self.numerator, self.denominator)
921 } else {
922 write!(f, "none")
923 }
924 }
925}
926
/// A fee change requested for a pool, applied by `StakePool::update_fee`.
///
/// Variant order is part of the Borsh encoding and must not change.
#[derive(Clone, Debug, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
pub enum FeeType {
    /// Referrer share of the SOL deposit fee, in percent (0-100).
    SolReferral(u8),
    /// Referrer share of the stake deposit fee, in percent (0-100).
    StakeReferral(u8),
    /// Fee on epoch rewards; scheduled for a later epoch.
    Epoch(Fee),
    /// Fee on stake withdrawals; scheduled for a later epoch and rate-limited.
    StakeWithdrawal(Fee),
    /// Fee on SOL deposits; applied immediately.
    SolDeposit(Fee),
    /// Fee on stake deposits; applied immediately.
    StakeDeposit(Fee),
    /// Fee on SOL withdrawals; scheduled for a later epoch and rate-limited.
    SolWithdrawal(Fee),
}
945
946impl FeeType {
947 pub fn check_too_high(&self) -> Result<(), StakePoolError> {
949 let too_high = match self {
950 Self::SolReferral(pct) => *pct > 100u8,
951 Self::StakeReferral(pct) => *pct > 100u8,
952 Self::Epoch(fee) => fee.numerator > fee.denominator,
953 Self::StakeWithdrawal(fee) => fee.numerator > fee.denominator,
954 Self::SolWithdrawal(fee) => fee.numerator > fee.denominator,
955 Self::SolDeposit(fee) => fee.numerator > fee.denominator,
956 Self::StakeDeposit(fee) => fee.numerator > fee.denominator,
957 };
958 if too_high {
959 msg!("Fee greater than 100%: {:?}", self);
960 return Err(StakePoolError::FeeTooHigh);
961 }
962 Ok(())
963 }
964
965 #[inline]
967 pub fn can_only_change_next_epoch(&self) -> bool {
968 matches!(
969 self,
970 Self::StakeWithdrawal(_) | Self::SolWithdrawal(_) | Self::Epoch(_)
971 )
972 }
973}
974
975#[cfg(test)]
976mod test {
977 #![allow(clippy::integer_arithmetic)]
978 use {
979 super::*,
980 proptest::prelude::*,
981 solana_program::{
982 borsh::{get_instance_packed_len, get_packed_len, try_from_slice_unchecked},
983 clock::{DEFAULT_SLOTS_PER_EPOCH, DEFAULT_S_PER_SLOT, SECONDS_PER_DAY},
984 native_token::LAMPORTS_PER_SOL,
985 },
986 };
987
988 fn uninitialized_validator_list() -> ValidatorList {
989 ValidatorList {
990 header: ValidatorListHeader {
991 account_type: AccountType::Uninitialized,
992 max_validators: 0,
993 },
994 validators: vec![],
995 }
996 }
997
    /// Fixture: an initialized list with capacity `max_validators` and three
    /// entries covering three statuses. The first entry uses distinctive byte
    /// patterns to catch endianness and offset mistakes in (de)serialization.
    fn test_validator_list(max_validators: u32) -> ValidatorList {
        ValidatorList {
            header: ValidatorListHeader {
                account_type: AccountType::ValidatorList,
                max_validators,
            },
            validators: vec![
                ValidatorStakeInfo {
                    status: StakeStatus::Active,
                    vote_account_address: Pubkey::new_from_array([1; 32]),
                    // Every byte set to a different constant per field.
                    active_stake_lamports: u64::from_le_bytes([255; 8]),
                    transient_stake_lamports: u64::from_le_bytes([128; 8]),
                    last_update_epoch: u64::from_le_bytes([64; 8]),
                    transient_seed_suffix: 0,
                    unused: 0,
                    validator_seed_suffix: 0,
                },
                ValidatorStakeInfo {
                    status: StakeStatus::DeactivatingTransient,
                    vote_account_address: Pubkey::new_from_array([2; 32]),
                    active_stake_lamports: 998877665544,
                    transient_stake_lamports: 222222222,
                    last_update_epoch: 11223445566,
                    transient_seed_suffix: 0,
                    unused: 0,
                    validator_seed_suffix: 0,
                },
                ValidatorStakeInfo {
                    status: StakeStatus::ReadyForRemoval,
                    vote_account_address: Pubkey::new_from_array([3; 32]),
                    active_stake_lamports: 0,
                    transient_stake_lamports: 0,
                    last_update_epoch: 999999999999999,
                    transient_seed_suffix: 0,
                    unused: 0,
                    validator_seed_suffix: 0,
                },
            ],
        }
    }
1038
1039 #[test]
1040 fn state_packing() {
1041 let max_validators = 10_000;
1042 let size = get_instance_packed_len(&ValidatorList::new(max_validators)).unwrap();
1043 let stake_list = uninitialized_validator_list();
1044 let mut byte_vec = vec![0u8; size];
1045 let mut bytes = byte_vec.as_mut_slice();
1046 stake_list.serialize(&mut bytes).unwrap();
1047 let stake_list_unpacked = try_from_slice_unchecked::<ValidatorList>(&byte_vec).unwrap();
1048 assert_eq!(stake_list_unpacked, stake_list);
1049
1050 let stake_list = ValidatorList {
1052 header: ValidatorListHeader {
1053 account_type: AccountType::ValidatorList,
1054 max_validators: 0,
1055 },
1056 validators: vec![],
1057 };
1058 let mut byte_vec = vec![0u8; size];
1059 let mut bytes = byte_vec.as_mut_slice();
1060 stake_list.serialize(&mut bytes).unwrap();
1061 let stake_list_unpacked = try_from_slice_unchecked::<ValidatorList>(&byte_vec).unwrap();
1062 assert_eq!(stake_list_unpacked, stake_list);
1063
1064 let stake_list = test_validator_list(max_validators);
1066 let mut byte_vec = vec![0u8; size];
1067 let mut bytes = byte_vec.as_mut_slice();
1068 stake_list.serialize(&mut bytes).unwrap();
1069 let stake_list_unpacked = try_from_slice_unchecked::<ValidatorList>(&byte_vec).unwrap();
1070 assert_eq!(stake_list_unpacked, stake_list);
1071 }
1072
1073 #[test]
1074 fn validator_list_active_stake() {
1075 let max_validators = 10_000;
1076 let mut validator_list = test_validator_list(max_validators);
1077 assert!(validator_list.has_active_stake());
1078 for validator in validator_list.validators.iter_mut() {
1079 validator.active_stake_lamports = 0;
1080 }
1081 assert!(!validator_list.has_active_stake());
1082 }
1083
1084 #[test]
1085 fn validator_list_deserialize_mut_slice() {
1086 let max_validators = 10;
1087 let stake_list = test_validator_list(max_validators);
1088 let mut serialized = stake_list.try_to_vec().unwrap();
1089 let (header, list) = ValidatorListHeader::deserialize_mut_slice(
1090 &mut serialized,
1091 0,
1092 stake_list.validators.len(),
1093 )
1094 .unwrap();
1095 assert_eq!(header.account_type, AccountType::ValidatorList);
1096 assert_eq!(header.max_validators, max_validators);
1097 assert!(list
1098 .iter()
1099 .zip(stake_list.validators.iter())
1100 .all(|(a, b)| *a == b));
1101
1102 let (_, list) = ValidatorListHeader::deserialize_mut_slice(&mut serialized, 1, 2).unwrap();
1103 assert!(list
1104 .iter()
1105 .zip(stake_list.validators[1..].iter())
1106 .all(|(a, b)| *a == b));
1107 let (_, list) = ValidatorListHeader::deserialize_mut_slice(&mut serialized, 2, 1).unwrap();
1108 assert!(list
1109 .iter()
1110 .zip(stake_list.validators[2..].iter())
1111 .all(|(a, b)| *a == b));
1112 let (_, list) = ValidatorListHeader::deserialize_mut_slice(&mut serialized, 0, 2).unwrap();
1113 assert!(list
1114 .iter()
1115 .zip(stake_list.validators[..2].iter())
1116 .all(|(a, b)| *a == b));
1117
1118 assert_eq!(
1119 ValidatorListHeader::deserialize_mut_slice(&mut serialized, 0, 4).unwrap_err(),
1120 ProgramError::AccountDataTooSmall
1121 );
1122 assert_eq!(
1123 ValidatorListHeader::deserialize_mut_slice(&mut serialized, 1, 3).unwrap_err(),
1124 ProgramError::AccountDataTooSmall
1125 );
1126 }
1127
1128 #[test]
1129 fn validator_list_iter() {
1130 let max_validators = 10;
1131 let stake_list = test_validator_list(max_validators);
1132 let mut serialized = stake_list.try_to_vec().unwrap();
1133 let (_, big_vec) = ValidatorListHeader::deserialize_vec(&mut serialized).unwrap();
1134 for (a, b) in big_vec
1135 .iter::<ValidatorStakeInfo>()
1136 .zip(stake_list.validators.iter())
1137 {
1138 assert_eq!(a, b);
1139 }
1140 }
1141
    // `calculate_max_validators` must invert the serialized size of
    // `ValidatorList::new(n)`: exact size -> n, one extra byte -> still n,
    // one extra entry -> n + 1, one byte short -> n - 1 (saturating at 0).
    proptest! {
        #[test]
        fn stake_list_size_calculation(test_amount in 0..=100_000_u32) {
            let validators = ValidatorList::new(test_amount);
            let size = get_instance_packed_len(&validators).unwrap();
            assert_eq!(ValidatorList::calculate_max_validators(size), test_amount as usize);
            assert_eq!(ValidatorList::calculate_max_validators(size.saturating_add(1)), test_amount as usize);
            assert_eq!(ValidatorList::calculate_max_validators(size.saturating_add(get_packed_len::<ValidatorStakeInfo>())), (test_amount + 1)as usize);
            assert_eq!(ValidatorList::calculate_max_validators(size.saturating_sub(1)), (test_amount.saturating_sub(1)) as usize);
        }
    }
1153
    // Strategy for a fee ratio of at most 100%: a non-zero u16 denominator and
    // a numerator in 0..=denominator, returned as (numerator, denominator).
    prop_compose! {
        fn fee()(denominator in 1..=u16::MAX)(
            denominator in Just(denominator),
            numerator in 0..=denominator,
        ) -> (u64, u64) {
            (numerator as u64, denominator as u64)
        }
    }
1162
    // Strategy returning (pre-reward lamports, rewards) whose sum is a non-zero
    // total below u64::MAX, so adding the rewards back cannot overflow.
    prop_compose! {
        fn total_stake_and_rewards()(total_lamports in 1..u64::MAX)(
            total_lamports in Just(total_lamports),
            rewards in 0..=total_lamports,
        ) -> (u64, u64) {
            (total_lamports - rewards, rewards)
        }
    }
1171
1172 #[test]
1173 fn specific_fee_calculation() {
1174 let epoch_fee = Fee {
1176 numerator: 1,
1177 denominator: 10,
1178 };
1179 let mut stake_pool = StakePool {
1180 total_lamports: 100 * LAMPORTS_PER_SOL,
1181 pool_token_supply: 100 * LAMPORTS_PER_SOL,
1182 epoch_fee,
1183 ..StakePool::default()
1184 };
1185 let reward_lamports = 10 * LAMPORTS_PER_SOL;
1186 let pool_token_fee = stake_pool.calc_epoch_fee_amount(reward_lamports).unwrap();
1187
1188 stake_pool.total_lamports += reward_lamports;
1189 stake_pool.pool_token_supply += pool_token_fee;
1190
1191 let fee_lamports = stake_pool
1192 .calc_lamports_withdraw_amount(pool_token_fee)
1193 .unwrap();
1194 assert_eq!(fee_lamports, LAMPORTS_PER_SOL - 1); }
1196
1197 #[test]
1198 fn zero_withdraw_calculation() {
1199 let epoch_fee = Fee {
1200 numerator: 0,
1201 denominator: 1,
1202 };
1203 let stake_pool = StakePool {
1204 epoch_fee,
1205 ..StakePool::default()
1206 };
1207 let fee_lamports = stake_pool.calc_lamports_withdraw_amount(0).unwrap();
1208 assert_eq!(fee_lamports, 0);
1209 }
1210
1211 #[test]
1212 fn divide_by_zero_fee() {
1213 let stake_pool = StakePool {
1214 total_lamports: 0,
1215 epoch_fee: Fee {
1216 numerator: 1,
1217 denominator: 10,
1218 },
1219 ..StakePool::default()
1220 };
1221 let rewards = 10;
1222 let fee = stake_pool.calc_epoch_fee_amount(rewards).unwrap();
1223 assert_eq!(fee, rewards);
1224 }
1225
1226 #[test]
1227 fn approximate_apr_calculation() {
1228 let stake_pool = StakePool {
1230 last_epoch_total_lamports: 100_000,
1231 last_epoch_pool_token_supply: 100_000,
1232 total_lamports: 100_044,
1233 pool_token_supply: 100_000,
1234 ..StakePool::default()
1235 };
1236 let pool_token_value =
1237 stake_pool.total_lamports as f64 / stake_pool.pool_token_supply as f64;
1238 let last_epoch_pool_token_value = stake_pool.last_epoch_total_lamports as f64
1239 / stake_pool.last_epoch_pool_token_supply as f64;
1240 let epoch_rate = pool_token_value / last_epoch_pool_token_value - 1.0;
1241 const SECONDS_PER_EPOCH: f64 = DEFAULT_SLOTS_PER_EPOCH as f64 * DEFAULT_S_PER_SLOT;
1242 const EPOCHS_PER_YEAR: f64 = SECONDS_PER_DAY as f64 * 365.25 / SECONDS_PER_EPOCH;
1243 const EPSILON: f64 = 0.00001;
1244 let yearly_rate = epoch_rate * EPOCHS_PER_YEAR;
1245 assert!((yearly_rate - 0.080355).abs() < EPSILON);
1246 }
1247
    // Property: minting the epoch fee as pool tokens never lets the fee holder
    // withdraw more than `reward * fee`, and the shortfall due to rounding stays
    // within a small tolerance.
    proptest! {
        #[test]
        fn fee_calculation(
            (numerator, denominator) in fee(),
            (total_lamports, reward_lamports) in total_stake_and_rewards(),
        ) {
            let epoch_fee = Fee { denominator, numerator };
            // Start at a 1:1 exchange rate.
            let mut stake_pool = StakePool {
                total_lamports,
                pool_token_supply: total_lamports,
                epoch_fee,
                ..StakePool::default()
            };
            let pool_token_fee = stake_pool.calc_epoch_fee_amount(reward_lamports).unwrap();

            // Apply the epoch update: rewards arrive, fee tokens are minted.
            stake_pool.total_lamports += reward_lamports;
            stake_pool.pool_token_supply += pool_token_fee;

            let fee_lamports = stake_pool.calc_lamports_withdraw_amount(pool_token_fee).unwrap();
            let max_fee_lamports = u64::try_from((reward_lamports as u128) * (epoch_fee.numerator as u128) / (epoch_fee.denominator as u128)).unwrap();
            assert!(max_fee_lamports >= fee_lamports,
                "Max possible fee must always be greater than or equal to what is actually withdrawn, max {} actual {}",
                max_fee_lamports,
                fee_lamports);

            // Rounding tolerance; grows with the reward-to-stake ratio.
            // NOTE(review): the exact bound is empirical — confirm if tightened.
            let epsilon = 2 + reward_lamports / total_lamports;
            assert!(max_fee_lamports - fee_lamports <= epsilon,
                "Max expected fee in lamports {}, actually receive {}, epsilon {}",
                max_fee_lamports, fee_lamports, epsilon);
        }
    }
1282
1283 prop_compose! {
1284 fn total_tokens_and_deposit()(total_lamports in 1..u64::MAX)(
1285 total_lamports in Just(total_lamports),
1286 pool_token_supply in 1..=total_lamports,
1287 deposit_lamports in 1..total_lamports,
1288 ) -> (u64, u64, u64) {
1289 (total_lamports - deposit_lamports, pool_token_supply.saturating_sub(deposit_lamports).max(1), deposit_lamports)
1290 }
1291 }
1292
1293 proptest! {
1294 #[test]
1295 fn deposit_and_withdraw(
1296 (total_lamports, pool_token_supply, deposit_stake) in total_tokens_and_deposit()
1297 ) {
1298 let mut stake_pool = StakePool {
1299 total_lamports,
1300 pool_token_supply,
1301 ..StakePool::default()
1302 };
1303 let deposit_result = stake_pool.calc_pool_tokens_for_deposit(deposit_stake).unwrap();
1304 prop_assume!(deposit_result > 0);
1305 stake_pool.total_lamports += deposit_stake;
1306 stake_pool.pool_token_supply += deposit_result;
1307 let withdraw_result = stake_pool.calc_lamports_withdraw_amount(deposit_result).unwrap();
1308 assert!(withdraw_result <= deposit_stake);
1309
1310 if deposit_result >= 2 {
1312 let first_half_deposit = deposit_result / 2;
1313 let first_withdraw_result = stake_pool.calc_lamports_withdraw_amount(first_half_deposit).unwrap();
1314 stake_pool.total_lamports -= first_withdraw_result;
1315 stake_pool.pool_token_supply -= first_half_deposit;
1316 let second_half_deposit = deposit_result - first_half_deposit; let second_withdraw_result = stake_pool.calc_lamports_withdraw_amount(second_half_deposit).unwrap();
1318 assert!(first_withdraw_result + second_withdraw_result <= deposit_stake);
1319 }
1320 }
1321 }
1322
#[test]
fn specific_split_withdrawal() {
    // Each pool token is backed by 1.1 lamports. Deposit a tiny amount, then
    // check that redeeming half of the minted tokens, doubled, stays within
    // the original deposit.
    let deposit_stake = 3;
    let mut stake_pool = StakePool {
        total_lamports: 1_100_000_000_000,
        pool_token_supply: 1_000_000_000_000,
        ..StakePool::default()
    };

    let minted = stake_pool
        .calc_pool_tokens_for_deposit(deposit_stake)
        .unwrap();
    assert!(minted > 0);

    stake_pool.total_lamports += deposit_stake;
    stake_pool.pool_token_supply += minted;

    let half_withdrawal = stake_pool
        .calc_lamports_withdraw_amount(minted / 2)
        .unwrap();
    assert!(half_withdrawal * 2 <= deposit_stake);
}
1344
#[test]
fn withdraw_all() {
    let total_lamports = 1_100_000_000_000;
    let pool_token_supply = 1_000_000_000_000;
    // Both scenarios start from an identical, freshly built pool.
    let fresh_pool = || StakePool {
        total_lamports,
        pool_token_supply,
        ..StakePool::default()
    };

    // Scenario 1: redeeming the entire supply yields every lamport, both up
    // front and again after a single-token withdrawal has been applied.
    let mut pool = fresh_pool();
    let everything = pool
        .calc_lamports_withdraw_amount(pool_token_supply)
        .unwrap();
    assert_eq!(pool.total_lamports, everything);

    let single_token = pool.calc_lamports_withdraw_amount(1).unwrap();
    pool.total_lamports -= single_token;
    pool.pool_token_supply -= 1;
    let remainder = pool
        .calc_lamports_withdraw_amount(pool.pool_token_supply)
        .unwrap();
    assert_eq!(pool.total_lamports, remainder);

    // Scenario 2: after withdrawing all but one token, the lamports left over
    // must be non-zero, and the last token must redeem exactly all of them.
    let mut pool = fresh_pool();
    let all_but_one = pool
        .calc_lamports_withdraw_amount(pool_token_supply - 1)
        .unwrap();
    pool.total_lamports -= all_but_one;
    pool.pool_token_supply = 1;
    assert_ne!(pool.total_lamports, 0);

    let last_token = pool.calc_lamports_withdraw_amount(1).unwrap();
    assert_eq!(pool.total_lamports, last_token);
}
1385}