1use {
4 crate::{
5 big_vec::BigVec, error::StakePoolError, MAX_WITHDRAWAL_FEE_INCREASE,
6 WITHDRAWAL_BASELINE_FEE,
7 },
8 borsh::{BorshDeserialize, BorshSchema, BorshSerialize},
9 bytemuck::{Pod, Zeroable},
10 num_derive::{FromPrimitive, ToPrimitive},
11 num_traits::{FromPrimitive, ToPrimitive},
12 solana_program::{
13 account_info::AccountInfo,
14 borsh1::get_instance_packed_len,
15 msg,
16 program_error::ProgramError,
17 program_memory::sol_memcmp,
18 program_pack::{Pack, Sealed},
19 pubkey::{Pubkey, PUBKEY_BYTES},
20 stake::state::Lockup,
21 },
22 spl_pod::primitives::{PodU32, PodU64},
23 spl_token_2022::{
24 extension::{BaseStateWithExtensions, ExtensionType, StateWithExtensions},
25 state::{Account, AccountState, Mint},
26 },
27 std::{borrow::Borrow, convert::TryFrom, fmt, matches},
28};
29
30#[derive(Clone, Debug, Default, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
32pub enum AccountType {
33 #[default]
35 Uninitialized,
36 StakePool,
38 ValidatorList,
40}
41
42#[repr(C)]
44#[derive(Clone, Debug, Default, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
45pub struct StakePool {
46 pub account_type: AccountType,
48
49 pub manager: Pubkey,
52
53 pub staker: Pubkey,
56
57 pub stake_deposit_authority: Pubkey,
67
68 pub stake_withdraw_bump_seed: u8,
71
72 pub validator_list: Pubkey,
74
75 pub reserve_stake: Pubkey,
77
78 pub pool_mint: Pubkey,
80
81 pub manager_fee_account: Pubkey,
83
84 pub token_program_id: Pubkey,
86
87 pub total_lamports: u64,
91
92 pub pool_token_supply: u64,
95
96 pub last_update_epoch: u64,
98
99 pub lockup: Lockup,
101
102 pub epoch_fee: Fee,
104
105 pub next_epoch_fee: FutureEpoch<Fee>,
107
108 pub preferred_deposit_validator_vote_address: Option<Pubkey>,
110
111 pub preferred_withdraw_validator_vote_address: Option<Pubkey>,
113
114 pub stake_deposit_fee: Fee,
116
117 pub stake_withdrawal_fee: Fee,
119
120 pub next_stake_withdrawal_fee: FutureEpoch<Fee>,
122
123 pub stake_referral_fee: u8,
129
130 pub sol_deposit_authority: Option<Pubkey>,
133
134 pub sol_deposit_fee: Fee,
136
137 pub sol_referral_fee: u8,
143
144 pub sol_withdraw_authority: Option<Pubkey>,
147
148 pub sol_withdrawal_fee: Fee,
150
151 pub next_sol_withdrawal_fee: FutureEpoch<Fee>,
153
154 pub last_epoch_pool_token_supply: u64,
156
157 pub last_epoch_total_lamports: u64,
159}
160impl StakePool {
161 #[inline]
164 pub fn calc_pool_tokens_for_deposit(&self, stake_lamports: u64) -> Option<u64> {
165 if self.total_lamports == 0 || self.pool_token_supply == 0 {
166 return Some(stake_lamports);
167 }
168 u64::try_from(
169 (stake_lamports as u128)
170 .checked_mul(self.pool_token_supply as u128)?
171 .checked_div(self.total_lamports as u128)?,
172 )
173 .ok()
174 }
175
176 #[inline]
178 pub fn calc_lamports_withdraw_amount(&self, pool_tokens: u64) -> Option<u64> {
179 let numerator = (pool_tokens as u128).checked_mul(self.total_lamports as u128)?;
183 let denominator = self.pool_token_supply as u128;
184 if numerator < denominator || denominator == 0 {
185 Some(0)
186 } else {
187 u64::try_from(numerator.checked_div(denominator)?).ok()
188 }
189 }
190
191 #[inline]
193 pub fn calc_pool_tokens_stake_withdrawal_fee(&self, pool_tokens: u64) -> Option<u64> {
194 u64::try_from(self.stake_withdrawal_fee.apply(pool_tokens)?).ok()
195 }
196
197 #[inline]
199 pub fn calc_pool_tokens_sol_withdrawal_fee(&self, pool_tokens: u64) -> Option<u64> {
200 u64::try_from(self.sol_withdrawal_fee.apply(pool_tokens)?).ok()
201 }
202
203 #[inline]
205 pub fn calc_pool_tokens_stake_deposit_fee(&self, pool_tokens_minted: u64) -> Option<u64> {
206 u64::try_from(self.stake_deposit_fee.apply(pool_tokens_minted)?).ok()
207 }
208
209 #[inline]
211 pub fn calc_pool_tokens_stake_referral_fee(&self, stake_deposit_fee: u64) -> Option<u64> {
212 u64::try_from(
213 (stake_deposit_fee as u128)
214 .checked_mul(self.stake_referral_fee as u128)?
215 .checked_div(100u128)?,
216 )
217 .ok()
218 }
219
220 #[inline]
222 pub fn calc_pool_tokens_sol_deposit_fee(&self, pool_tokens_minted: u64) -> Option<u64> {
223 u64::try_from(self.sol_deposit_fee.apply(pool_tokens_minted)?).ok()
224 }
225
226 #[inline]
229 pub fn calc_pool_tokens_sol_referral_fee(&self, sol_deposit_fee: u64) -> Option<u64> {
230 u64::try_from(
231 (sol_deposit_fee as u128)
232 .checked_mul(self.sol_referral_fee as u128)?
233 .checked_div(100u128)?,
234 )
235 .ok()
236 }
237
238 #[inline]
243 pub fn calc_epoch_fee_amount(&self, reward_lamports: u64) -> Option<u64> {
244 if reward_lamports == 0 {
245 return Some(0);
246 }
247 let total_lamports = (self.total_lamports as u128).checked_add(reward_lamports as u128)?;
248 let fee_lamports = self.epoch_fee.apply(reward_lamports)?;
249 if total_lamports == fee_lamports || self.pool_token_supply == 0 {
250 Some(reward_lamports)
251 } else {
252 u64::try_from(
253 (self.pool_token_supply as u128)
254 .checked_mul(fee_lamports)?
255 .checked_div(total_lamports.checked_sub(fee_lamports)?)?,
256 )
257 .ok()
258 }
259 }
260
261 #[inline]
263 pub fn get_lamports_per_pool_token(&self) -> Option<u64> {
264 self.total_lamports
265 .checked_add(self.pool_token_supply)?
266 .checked_sub(1)?
267 .checked_div(self.pool_token_supply)
268 }
269
270 fn check_program_derived_authority(
272 authority_address: &Pubkey,
273 program_id: &Pubkey,
274 stake_pool_address: &Pubkey,
275 authority_seed: &[u8],
276 bump_seed: u8,
277 ) -> Result<(), ProgramError> {
278 let expected_address = Pubkey::create_program_address(
279 &[stake_pool_address.as_ref(), authority_seed, &[bump_seed]],
280 program_id,
281 )?;
282
283 if *authority_address == expected_address {
284 Ok(())
285 } else {
286 msg!(
287 "Incorrect authority provided, expected {}, received {}",
288 expected_address,
289 authority_address
290 );
291 Err(StakePoolError::InvalidProgramAddress.into())
292 }
293 }
294
295 pub(crate) fn check_manager_fee_info(
298 &self,
299 manager_fee_info: &AccountInfo,
300 ) -> Result<(), ProgramError> {
301 let account_data = manager_fee_info.try_borrow_data()?;
302 let token_account = StateWithExtensions::<Account>::unpack(&account_data)?;
303 if manager_fee_info.owner != &self.token_program_id
304 || token_account.base.state != AccountState::Initialized
305 || token_account.base.mint != self.pool_mint
306 {
307 msg!("Manager fee account is not owned by token program, is not initialized, or does not match stake pool's mint");
308 return Err(StakePoolError::InvalidFeeAccount.into());
309 }
310 let extensions = token_account.get_extension_types()?;
311 if extensions
312 .iter()
313 .any(|x| !is_extension_supported_for_fee_account(x))
314 {
315 return Err(StakePoolError::UnsupportedFeeAccountExtension.into());
316 }
317 Ok(())
318 }
319
320 #[inline]
322 pub(crate) fn check_authority_withdraw(
323 &self,
324 withdraw_authority: &Pubkey,
325 program_id: &Pubkey,
326 stake_pool_address: &Pubkey,
327 ) -> Result<(), ProgramError> {
328 Self::check_program_derived_authority(
329 withdraw_authority,
330 program_id,
331 stake_pool_address,
332 crate::AUTHORITY_WITHDRAW,
333 self.stake_withdraw_bump_seed,
334 )
335 }
336 #[inline]
338 pub(crate) fn check_stake_deposit_authority(
339 &self,
340 stake_deposit_authority: &Pubkey,
341 ) -> Result<(), ProgramError> {
342 if self.stake_deposit_authority == *stake_deposit_authority {
343 Ok(())
344 } else {
345 Err(StakePoolError::InvalidStakeDepositAuthority.into())
346 }
347 }
348
349 #[inline]
352 pub(crate) fn check_sol_deposit_authority(
353 &self,
354 maybe_sol_deposit_authority: Result<&AccountInfo, ProgramError>,
355 ) -> Result<(), ProgramError> {
356 if let Some(auth) = self.sol_deposit_authority {
357 let sol_deposit_authority = maybe_sol_deposit_authority?;
358 if auth != *sol_deposit_authority.key {
359 msg!("Expected {}, received {}", auth, sol_deposit_authority.key);
360 return Err(StakePoolError::InvalidSolDepositAuthority.into());
361 }
362 if !sol_deposit_authority.is_signer {
363 msg!("SOL Deposit authority signature missing");
364 return Err(StakePoolError::SignatureMissing.into());
365 }
366 }
367 Ok(())
368 }
369
370 #[inline]
373 pub(crate) fn check_sol_withdraw_authority(
374 &self,
375 maybe_sol_withdraw_authority: Result<&AccountInfo, ProgramError>,
376 ) -> Result<(), ProgramError> {
377 if let Some(auth) = self.sol_withdraw_authority {
378 let sol_withdraw_authority = maybe_sol_withdraw_authority?;
379 if auth != *sol_withdraw_authority.key {
380 return Err(StakePoolError::InvalidSolWithdrawAuthority.into());
381 }
382 if !sol_withdraw_authority.is_signer {
383 msg!("SOL withdraw authority signature missing");
384 return Err(StakePoolError::SignatureMissing.into());
385 }
386 }
387 Ok(())
388 }
389
390 #[inline]
392 pub(crate) fn check_mint(&self, mint_info: &AccountInfo) -> Result<u8, ProgramError> {
393 if *mint_info.key != self.pool_mint {
394 Err(StakePoolError::WrongPoolMint.into())
395 } else {
396 let mint_data = mint_info.try_borrow_data()?;
397 let mint = StateWithExtensions::<Mint>::unpack(&mint_data)?;
398 Ok(mint.base.decimals)
399 }
400 }
401
402 pub(crate) fn check_manager(&self, manager_info: &AccountInfo) -> Result<(), ProgramError> {
404 if *manager_info.key != self.manager {
405 msg!(
406 "Incorrect manager provided, expected {}, received {}",
407 self.manager,
408 manager_info.key
409 );
410 return Err(StakePoolError::WrongManager.into());
411 }
412 if !manager_info.is_signer {
413 msg!("Manager signature missing");
414 return Err(StakePoolError::SignatureMissing.into());
415 }
416 Ok(())
417 }
418
419 pub(crate) fn check_staker(&self, staker_info: &AccountInfo) -> Result<(), ProgramError> {
421 if *staker_info.key != self.staker {
422 msg!(
423 "Incorrect staker provided, expected {}, received {}",
424 self.staker,
425 staker_info.key
426 );
427 return Err(StakePoolError::WrongStaker.into());
428 }
429 if !staker_info.is_signer {
430 msg!("Staker signature missing");
431 return Err(StakePoolError::SignatureMissing.into());
432 }
433 Ok(())
434 }
435
436 pub fn check_validator_list(
438 &self,
439 validator_list_info: &AccountInfo,
440 ) -> Result<(), ProgramError> {
441 if *validator_list_info.key != self.validator_list {
442 msg!(
443 "Invalid validator list provided, expected {}, received {}",
444 self.validator_list,
445 validator_list_info.key
446 );
447 Err(StakePoolError::InvalidValidatorStakeList.into())
448 } else {
449 Ok(())
450 }
451 }
452
453 pub fn check_reserve_stake(
455 &self,
456 reserve_stake_info: &AccountInfo,
457 ) -> Result<(), ProgramError> {
458 if *reserve_stake_info.key != self.reserve_stake {
459 msg!(
460 "Invalid reserve stake provided, expected {}, received {}",
461 self.reserve_stake,
462 reserve_stake_info.key
463 );
464 Err(StakePoolError::InvalidProgramAddress.into())
465 } else {
466 Ok(())
467 }
468 }
469
470 pub fn is_valid(&self) -> bool {
472 self.account_type == AccountType::StakePool
473 }
474
475 pub fn is_uninitialized(&self) -> bool {
477 self.account_type == AccountType::Uninitialized
478 }
479
480 pub fn update_fee(&mut self, fee: &FeeType) -> Result<(), StakePoolError> {
482 match fee {
483 FeeType::SolReferral(new_fee) => self.sol_referral_fee = *new_fee,
484 FeeType::StakeReferral(new_fee) => self.stake_referral_fee = *new_fee,
485 FeeType::Epoch(new_fee) => self.next_epoch_fee = FutureEpoch::new(*new_fee),
486 FeeType::StakeWithdrawal(new_fee) => {
487 new_fee.check_withdrawal(&self.stake_withdrawal_fee)?;
488 self.next_stake_withdrawal_fee = FutureEpoch::new(*new_fee)
489 }
490 FeeType::SolWithdrawal(new_fee) => {
491 new_fee.check_withdrawal(&self.sol_withdrawal_fee)?;
492 self.next_sol_withdrawal_fee = FutureEpoch::new(*new_fee)
493 }
494 FeeType::SolDeposit(new_fee) => self.sol_deposit_fee = *new_fee,
495 FeeType::StakeDeposit(new_fee) => self.stake_deposit_fee = *new_fee,
496 };
497 Ok(())
498 }
499}
500
501pub fn is_extension_supported_for_mint(extension_type: &ExtensionType) -> bool {
503 const SUPPORTED_EXTENSIONS: [ExtensionType; 8] = [
504 ExtensionType::Uninitialized,
505 ExtensionType::TransferFeeConfig,
506 ExtensionType::ConfidentialTransferMint,
507 ExtensionType::ConfidentialTransferFeeConfig,
508 ExtensionType::DefaultAccountState, ExtensionType::InterestBearingConfig,
510 ExtensionType::MetadataPointer,
511 ExtensionType::TokenMetadata,
512 ];
513 if !SUPPORTED_EXTENSIONS.contains(extension_type) {
514 msg!(
515 "Stake pool mint account cannot have the {:?} extension",
516 extension_type
517 );
518 false
519 } else {
520 true
521 }
522}
523
524pub fn is_extension_supported_for_fee_account(extension_type: &ExtensionType) -> bool {
526 const SUPPORTED_EXTENSIONS: [ExtensionType; 4] = [
530 ExtensionType::Uninitialized,
531 ExtensionType::TransferFeeAmount,
532 ExtensionType::ImmutableOwner,
533 ExtensionType::CpiGuard,
534 ];
535 if !SUPPORTED_EXTENSIONS.contains(extension_type) {
536 msg!("Fee account cannot have the {:?} extension", extension_type);
537 false
538 } else {
539 true
540 }
541}
542
543#[repr(C)]
545#[derive(Clone, Debug, Default, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
546pub struct ValidatorList {
547 pub header: ValidatorListHeader,
550
551 pub validators: Vec<ValidatorStakeInfo>,
553}
554
555#[repr(C)]
557#[derive(Clone, Debug, Default, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
558pub struct ValidatorListHeader {
559 pub account_type: AccountType,
561
562 pub max_validators: u32,
564}
565
566#[derive(
568 ToPrimitive,
569 FromPrimitive,
570 Copy,
571 Clone,
572 Debug,
573 PartialEq,
574 BorshDeserialize,
575 BorshSerialize,
576 BorshSchema,
577)]
578pub enum StakeStatus {
579 Active,
581 DeactivatingTransient,
584 ReadyForRemoval,
587 DeactivatingValidator,
590 DeactivatingAll,
593}
594impl Default for StakeStatus {
595 fn default() -> Self {
596 Self::Active
597 }
598}
599
600#[repr(transparent)]
603#[derive(
604 Clone,
605 Copy,
606 Debug,
607 Default,
608 PartialEq,
609 Pod,
610 Zeroable,
611 BorshDeserialize,
612 BorshSerialize,
613 BorshSchema,
614)]
615pub struct PodStakeStatus(u8);
616impl PodStakeStatus {
617 pub fn remove_validator_stake(&mut self) -> Result<(), ProgramError> {
620 let status = StakeStatus::try_from(*self)?;
621 let new_self = match status {
622 StakeStatus::Active
623 | StakeStatus::DeactivatingTransient
624 | StakeStatus::ReadyForRemoval => status,
625 StakeStatus::DeactivatingAll => StakeStatus::DeactivatingTransient,
626 StakeStatus::DeactivatingValidator => StakeStatus::ReadyForRemoval,
627 };
628 *self = new_self.into();
629 Ok(())
630 }
631 pub fn remove_transient_stake(&mut self) -> Result<(), ProgramError> {
634 let status = StakeStatus::try_from(*self)?;
635 let new_self = match status {
636 StakeStatus::Active
637 | StakeStatus::DeactivatingValidator
638 | StakeStatus::ReadyForRemoval => status,
639 StakeStatus::DeactivatingAll => StakeStatus::DeactivatingValidator,
640 StakeStatus::DeactivatingTransient => StakeStatus::ReadyForRemoval,
641 };
642 *self = new_self.into();
643 Ok(())
644 }
645}
646impl TryFrom<PodStakeStatus> for StakeStatus {
647 type Error = ProgramError;
648 fn try_from(pod: PodStakeStatus) -> Result<Self, Self::Error> {
649 FromPrimitive::from_u8(pod.0).ok_or(ProgramError::InvalidAccountData)
650 }
651}
652impl From<StakeStatus> for PodStakeStatus {
653 fn from(status: StakeStatus) -> Self {
654 PodStakeStatus(status.to_u8().unwrap())
657 }
658}
659
/// Source a stake withdrawal is taken from.
#[derive(PartialEq, Debug)]
pub(crate) enum StakeWithdrawSource {
    /// The validator's active stake.
    Active,
    /// The validator's transient stake.
    Transient,
    /// Presumably a withdrawal that removes the validator entry; confirm
    /// against the processor.
    ValidatorRemoval,
}
670
671#[repr(C)]
680#[derive(
681 Clone,
682 Copy,
683 Debug,
684 Default,
685 PartialEq,
686 Pod,
687 Zeroable,
688 BorshDeserialize,
689 BorshSerialize,
690 BorshSchema,
691)]
692pub struct ValidatorStakeInfo {
693 pub active_stake_lamports: PodU64,
698
699 pub transient_stake_lamports: PodU64,
704
705 pub last_update_epoch: PodU64,
707
708 pub transient_seed_suffix: PodU64,
711
712 pub unused: PodU32,
714
715 pub validator_seed_suffix: PodU32, pub status: PodStakeStatus,
720
721 pub vote_account_address: Pubkey,
723}
724
725impl ValidatorStakeInfo {
726 pub fn stake_lamports(&self) -> Result<u64, StakePoolError> {
728 u64::from(self.active_stake_lamports)
729 .checked_add(self.transient_stake_lamports.into())
730 .ok_or(StakePoolError::CalculationFailure)
731 }
732
733 pub fn memcmp_pubkey(data: &[u8], vote_address: &Pubkey) -> bool {
736 sol_memcmp(
737 &data[41..41_usize.saturating_add(PUBKEY_BYTES)],
738 vote_address.as_ref(),
739 PUBKEY_BYTES,
740 ) == 0
741 }
742
743 pub fn active_lamports_greater_than(data: &[u8], lamports: &u64) -> bool {
746 u64::try_from_slice(&data[0..8]).unwrap() > *lamports
748 }
749
750 pub fn transient_lamports_greater_than(data: &[u8], lamports: &u64) -> bool {
753 u64::try_from_slice(&data[8..16]).unwrap() > *lamports
755 }
756
757 pub fn is_not_removed(data: &[u8]) -> bool {
759 FromPrimitive::from_u8(data[40]) != Some(StakeStatus::ReadyForRemoval)
760 }
761}
762
763impl Sealed for ValidatorStakeInfo {}
764
765impl Pack for ValidatorStakeInfo {
766 const LEN: usize = 73;
767 fn pack_into_slice(&self, data: &mut [u8]) {
768 borsh::to_writer(data, self).unwrap();
771 }
772 fn unpack_from_slice(src: &[u8]) -> Result<Self, ProgramError> {
773 let unpacked = Self::try_from_slice(src)?;
774 Ok(unpacked)
775 }
776}
777
778impl ValidatorList {
779 pub fn new(max_validators: u32) -> Self {
782 Self {
783 header: ValidatorListHeader {
784 account_type: AccountType::ValidatorList,
785 max_validators,
786 },
787 validators: vec![ValidatorStakeInfo::default(); max_validators as usize],
788 }
789 }
790
791 pub fn calculate_max_validators(buffer_length: usize) -> usize {
794 let header_size = ValidatorListHeader::LEN.saturating_add(4);
795 buffer_length
796 .saturating_sub(header_size)
797 .saturating_div(ValidatorStakeInfo::LEN)
798 }
799
800 pub fn contains(&self, vote_account_address: &Pubkey) -> bool {
802 self.validators
803 .iter()
804 .any(|x| x.vote_account_address == *vote_account_address)
805 }
806
807 pub fn find_mut(&mut self, vote_account_address: &Pubkey) -> Option<&mut ValidatorStakeInfo> {
809 self.validators
810 .iter_mut()
811 .find(|x| x.vote_account_address == *vote_account_address)
812 }
813 pub fn find(&self, vote_account_address: &Pubkey) -> Option<&ValidatorStakeInfo> {
815 self.validators
816 .iter()
817 .find(|x| x.vote_account_address == *vote_account_address)
818 }
819
820 pub fn has_active_stake(&self) -> bool {
822 self.validators
823 .iter()
824 .any(|x| u64::from(x.active_stake_lamports) > 0)
825 }
826}
827
828impl ValidatorListHeader {
829 const LEN: usize = 1 + 4;
830
831 pub fn is_valid(&self) -> bool {
834 self.account_type == AccountType::ValidatorList
835 }
836
837 pub fn is_uninitialized(&self) -> bool {
839 self.account_type == AccountType::Uninitialized
840 }
841
842 pub fn deserialize_mut_slice<'a>(
845 big_vec: &'a mut BigVec,
846 skip: usize,
847 len: usize,
848 ) -> Result<&'a mut [ValidatorStakeInfo], ProgramError> {
849 big_vec.deserialize_mut_slice::<ValidatorStakeInfo>(skip, len)
850 }
851
852 pub fn deserialize_vec(data: &mut [u8]) -> Result<(Self, BigVec), ProgramError> {
854 let mut data_mut = data.borrow();
855 let header = ValidatorListHeader::deserialize(&mut data_mut)?;
856 let length = get_instance_packed_len(&header)?;
857
858 let big_vec = BigVec {
859 data: &mut data[length..],
860 };
861 Ok((header, big_vec))
862 }
863}
864
865#[repr(C)]
868#[derive(Clone, Copy, Debug, PartialEq, BorshSerialize, BorshDeserialize, BorshSchema)]
869pub enum FutureEpoch<T> {
870 None,
872 One(T),
874 Two(T),
876}
877impl<T> Default for FutureEpoch<T> {
878 fn default() -> Self {
879 Self::None
880 }
881}
882impl<T> FutureEpoch<T> {
883 pub fn new(value: T) -> Self {
885 Self::Two(value)
886 }
887}
888impl<T: Clone> FutureEpoch<T> {
889 pub fn update_epoch(&mut self) {
891 match self {
892 Self::None => {}
893 Self::One(_) => {
894 *self = Self::None;
896 }
897 Self::Two(v) => {
899 *self = Self::One(v.clone());
900 }
901 }
902 }
903
904 pub fn get(&self) -> Option<&T> {
906 match self {
907 Self::None | Self::Two(_) => None,
908 Self::One(v) => Some(v),
909 }
910 }
911}
912impl<T> From<FutureEpoch<T>> for Option<T> {
913 fn from(v: FutureEpoch<T>) -> Option<T> {
914 match v {
915 FutureEpoch::None => None,
916 FutureEpoch::One(inner) | FutureEpoch::Two(inner) => Some(inner),
917 }
918 }
919}
920
921#[repr(C)]
926#[derive(Clone, Copy, Debug, Default, PartialEq, BorshSerialize, BorshDeserialize, BorshSchema)]
927pub struct Fee {
928 pub denominator: u64,
930 pub numerator: u64,
932}
933
934impl Fee {
935 #[inline]
940 pub fn apply(&self, amt: u64) -> Option<u128> {
941 if self.denominator == 0 {
942 return Some(0);
943 }
944 let numerator = (amt as u128).checked_mul(self.numerator as u128)?;
945 let denominator = self.denominator as u128;
947 numerator
948 .checked_add(denominator)?
949 .checked_sub(1)?
950 .checked_div(denominator)
951 }
952
953 pub fn check_withdrawal(&self, old_withdrawal_fee: &Fee) -> Result<(), StakePoolError> {
956 let (old_num, old_denom) =
959 if old_withdrawal_fee.denominator == 0 || old_withdrawal_fee.numerator == 0 {
960 (
961 WITHDRAWAL_BASELINE_FEE.numerator,
962 WITHDRAWAL_BASELINE_FEE.denominator,
963 )
964 } else {
965 (old_withdrawal_fee.numerator, old_withdrawal_fee.denominator)
966 };
967
968 if (old_num as u128)
972 .checked_mul(self.denominator as u128)
973 .map(|x| x.checked_mul(MAX_WITHDRAWAL_FEE_INCREASE.numerator as u128))
974 .ok_or(StakePoolError::CalculationFailure)?
975 < (self.numerator as u128)
976 .checked_mul(old_denom as u128)
977 .map(|x| x.checked_mul(MAX_WITHDRAWAL_FEE_INCREASE.denominator as u128))
978 .ok_or(StakePoolError::CalculationFailure)?
979 {
980 msg!(
981 "Fee increase exceeds maximum allowed, proposed increase factor ({} / {})",
982 self.numerator.saturating_mul(old_denom),
983 old_num.saturating_mul(self.denominator),
984 );
985 return Err(StakePoolError::FeeIncreaseTooHigh);
986 }
987 Ok(())
988 }
989}
990
991impl fmt::Display for Fee {
992 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
993 if self.numerator > 0 && self.denominator > 0 {
994 write!(f, "{}/{}", self.numerator, self.denominator)
995 } else {
996 write!(f, "none")
997 }
998 }
999}
1000
1001#[derive(Clone, Debug, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
1003pub enum FeeType {
1004 SolReferral(u8),
1006 StakeReferral(u8),
1008 Epoch(Fee),
1010 StakeWithdrawal(Fee),
1012 SolDeposit(Fee),
1014 StakeDeposit(Fee),
1016 SolWithdrawal(Fee),
1018}
1019
1020impl FeeType {
1021 pub fn check_too_high(&self) -> Result<(), StakePoolError> {
1023 let too_high = match self {
1024 Self::SolReferral(pct) => *pct > 100u8,
1025 Self::StakeReferral(pct) => *pct > 100u8,
1026 Self::Epoch(fee) => fee.numerator > fee.denominator,
1027 Self::StakeWithdrawal(fee) => fee.numerator > fee.denominator,
1028 Self::SolWithdrawal(fee) => fee.numerator > fee.denominator,
1029 Self::SolDeposit(fee) => fee.numerator > fee.denominator,
1030 Self::StakeDeposit(fee) => fee.numerator > fee.denominator,
1031 };
1032 if too_high {
1033 msg!("Fee greater than 100%: {:?}", self);
1034 return Err(StakePoolError::FeeTooHigh);
1035 }
1036 Ok(())
1037 }
1038
1039 #[inline]
1042 pub fn can_only_change_next_epoch(&self) -> bool {
1043 matches!(
1044 self,
1045 Self::StakeWithdrawal(_) | Self::SolWithdrawal(_) | Self::Epoch(_)
1046 )
1047 }
1048}
1049
1050#[cfg(test)]
1051mod test {
1052 #![allow(clippy::arithmetic_side_effects)]
1053 use {
1054 super::*,
1055 proptest::prelude::*,
1056 solana_program::{
1057 borsh1::{get_packed_len, try_from_slice_unchecked},
1058 clock::{DEFAULT_SLOTS_PER_EPOCH, DEFAULT_S_PER_SLOT, SECONDS_PER_DAY},
1059 native_token::LAMPORTS_PER_SOL,
1060 },
1061 };
1062
1063 fn uninitialized_validator_list() -> ValidatorList {
1064 ValidatorList {
1065 header: ValidatorListHeader {
1066 account_type: AccountType::Uninitialized,
1067 max_validators: 0,
1068 },
1069 validators: vec![],
1070 }
1071 }
1072
1073 fn test_validator_list(max_validators: u32) -> ValidatorList {
1074 ValidatorList {
1075 header: ValidatorListHeader {
1076 account_type: AccountType::ValidatorList,
1077 max_validators,
1078 },
1079 validators: vec![
1080 ValidatorStakeInfo {
1081 status: StakeStatus::Active.into(),
1082 vote_account_address: Pubkey::new_from_array([1; 32]),
1083 active_stake_lamports: u64::from_le_bytes([255; 8]).into(),
1084 transient_stake_lamports: u64::from_le_bytes([128; 8]).into(),
1085 last_update_epoch: u64::from_le_bytes([64; 8]).into(),
1086 transient_seed_suffix: 0.into(),
1087 unused: 0.into(),
1088 validator_seed_suffix: 0.into(),
1089 },
1090 ValidatorStakeInfo {
1091 status: StakeStatus::DeactivatingTransient.into(),
1092 vote_account_address: Pubkey::new_from_array([2; 32]),
1093 active_stake_lamports: 998877665544.into(),
1094 transient_stake_lamports: 222222222.into(),
1095 last_update_epoch: 11223445566.into(),
1096 transient_seed_suffix: 0.into(),
1097 unused: 0.into(),
1098 validator_seed_suffix: 0.into(),
1099 },
1100 ValidatorStakeInfo {
1101 status: StakeStatus::ReadyForRemoval.into(),
1102 vote_account_address: Pubkey::new_from_array([3; 32]),
1103 active_stake_lamports: 0.into(),
1104 transient_stake_lamports: 0.into(),
1105 last_update_epoch: 999999999999999.into(),
1106 transient_seed_suffix: 0.into(),
1107 unused: 0.into(),
1108 validator_seed_suffix: 0.into(),
1109 },
1110 ],
1111 }
1112 }
1113
1114 #[test]
1115 fn state_packing() {
1116 let max_validators = 10_000;
1117 let size = get_instance_packed_len(&ValidatorList::new(max_validators)).unwrap();
1118 let stake_list = uninitialized_validator_list();
1119 let mut byte_vec = vec![0u8; size];
1120 let bytes = byte_vec.as_mut_slice();
1121 borsh::to_writer(bytes, &stake_list).unwrap();
1122 let stake_list_unpacked = try_from_slice_unchecked::<ValidatorList>(&byte_vec).unwrap();
1123 assert_eq!(stake_list_unpacked, stake_list);
1124
1125 let stake_list = ValidatorList {
1127 header: ValidatorListHeader {
1128 account_type: AccountType::ValidatorList,
1129 max_validators: 0,
1130 },
1131 validators: vec![],
1132 };
1133 let mut byte_vec = vec![0u8; size];
1134 let bytes = byte_vec.as_mut_slice();
1135 borsh::to_writer(bytes, &stake_list).unwrap();
1136 let stake_list_unpacked = try_from_slice_unchecked::<ValidatorList>(&byte_vec).unwrap();
1137 assert_eq!(stake_list_unpacked, stake_list);
1138
1139 let stake_list = test_validator_list(max_validators);
1141 let mut byte_vec = vec![0u8; size];
1142 let bytes = byte_vec.as_mut_slice();
1143 borsh::to_writer(bytes, &stake_list).unwrap();
1144 let stake_list_unpacked = try_from_slice_unchecked::<ValidatorList>(&byte_vec).unwrap();
1145 assert_eq!(stake_list_unpacked, stake_list);
1146 }
1147
1148 #[test]
1149 fn validator_list_active_stake() {
1150 let max_validators = 10_000;
1151 let mut validator_list = test_validator_list(max_validators);
1152 assert!(validator_list.has_active_stake());
1153 for validator in validator_list.validators.iter_mut() {
1154 validator.active_stake_lamports = 0.into();
1155 }
1156 assert!(!validator_list.has_active_stake());
1157 }
1158
1159 #[test]
1160 fn validator_list_deserialize_mut_slice() {
1161 let max_validators = 10;
1162 let stake_list = test_validator_list(max_validators);
1163 let mut serialized = borsh::to_vec(&stake_list).unwrap();
1164 let (header, mut big_vec) = ValidatorListHeader::deserialize_vec(&mut serialized).unwrap();
1165 let list = ValidatorListHeader::deserialize_mut_slice(
1166 &mut big_vec,
1167 0,
1168 stake_list.validators.len(),
1169 )
1170 .unwrap();
1171 assert_eq!(header.account_type, AccountType::ValidatorList);
1172 assert_eq!(header.max_validators, max_validators);
1173 assert!(list
1174 .iter()
1175 .zip(stake_list.validators.iter())
1176 .all(|(a, b)| a == b));
1177
1178 let list = ValidatorListHeader::deserialize_mut_slice(&mut big_vec, 1, 2).unwrap();
1179 assert!(list
1180 .iter()
1181 .zip(stake_list.validators[1..].iter())
1182 .all(|(a, b)| a == b));
1183 let list = ValidatorListHeader::deserialize_mut_slice(&mut big_vec, 2, 1).unwrap();
1184 assert!(list
1185 .iter()
1186 .zip(stake_list.validators[2..].iter())
1187 .all(|(a, b)| a == b));
1188 let list = ValidatorListHeader::deserialize_mut_slice(&mut big_vec, 0, 2).unwrap();
1189 assert!(list
1190 .iter()
1191 .zip(stake_list.validators[..2].iter())
1192 .all(|(a, b)| a == b));
1193
1194 assert_eq!(
1195 ValidatorListHeader::deserialize_mut_slice(&mut big_vec, 0, 4).unwrap_err(),
1196 ProgramError::AccountDataTooSmall
1197 );
1198 assert_eq!(
1199 ValidatorListHeader::deserialize_mut_slice(&mut big_vec, 1, 3).unwrap_err(),
1200 ProgramError::AccountDataTooSmall
1201 );
1202 }
1203
1204 #[test]
1205 fn validator_list_iter() {
1206 let max_validators = 10;
1207 let stake_list = test_validator_list(max_validators);
1208 let mut serialized = borsh::to_vec(&stake_list).unwrap();
1209 let (_, big_vec) = ValidatorListHeader::deserialize_vec(&mut serialized).unwrap();
1210 for (a, b) in big_vec
1211 .deserialize_slice::<ValidatorStakeInfo>(0, big_vec.len() as usize)
1212 .unwrap()
1213 .iter()
1214 .zip(stake_list.validators.iter())
1215 {
1216 assert_eq!(a, b);
1217 }
1218 }
1219
1220 proptest! {
1221 #[test]
1222 fn stake_list_size_calculation(test_amount in 0..=100_000_u32) {
1223 let validators = ValidatorList::new(test_amount);
1224 let size = get_instance_packed_len(&validators).unwrap();
1225 assert_eq!(ValidatorList::calculate_max_validators(size), test_amount as usize);
1226 assert_eq!(ValidatorList::calculate_max_validators(size.saturating_add(1)), test_amount as usize);
1227 assert_eq!(ValidatorList::calculate_max_validators(size.saturating_add(get_packed_len::<ValidatorStakeInfo>())), (test_amount + 1)as usize);
1228 assert_eq!(ValidatorList::calculate_max_validators(size.saturating_sub(1)), (test_amount.saturating_sub(1)) as usize);
1229 }
1230 }
1231
1232 prop_compose! {
1233 fn fee()(denominator in 1..=u16::MAX)(
1234 denominator in Just(denominator),
1235 numerator in 0..=denominator,
1236 ) -> (u64, u64) {
1237 (numerator as u64, denominator as u64)
1238 }
1239 }
1240
1241 prop_compose! {
1242 fn total_stake_and_rewards()(total_lamports in 1..u64::MAX)(
1243 total_lamports in Just(total_lamports),
1244 rewards in 0..=total_lamports,
1245 ) -> (u64, u64) {
1246 (total_lamports - rewards, rewards)
1247 }
1248 }
1249
1250 #[test]
1251 fn specific_fee_calculation() {
1252 let epoch_fee = Fee {
1254 numerator: 1,
1255 denominator: 10,
1256 };
1257 let mut stake_pool = StakePool {
1258 total_lamports: 100 * LAMPORTS_PER_SOL,
1259 pool_token_supply: 100 * LAMPORTS_PER_SOL,
1260 epoch_fee,
1261 ..StakePool::default()
1262 };
1263 let reward_lamports = 10 * LAMPORTS_PER_SOL;
1264 let pool_token_fee = stake_pool.calc_epoch_fee_amount(reward_lamports).unwrap();
1265
1266 stake_pool.total_lamports += reward_lamports;
1267 stake_pool.pool_token_supply += pool_token_fee;
1268
1269 let fee_lamports = stake_pool
1270 .calc_lamports_withdraw_amount(pool_token_fee)
1271 .unwrap();
1272 assert_eq!(fee_lamports, LAMPORTS_PER_SOL - 1); }
1275
1276 #[test]
1277 fn zero_withdraw_calculation() {
1278 let epoch_fee = Fee {
1279 numerator: 0,
1280 denominator: 1,
1281 };
1282 let stake_pool = StakePool {
1283 epoch_fee,
1284 ..StakePool::default()
1285 };
1286 let fee_lamports = stake_pool.calc_lamports_withdraw_amount(0).unwrap();
1287 assert_eq!(fee_lamports, 0);
1288 }
1289
1290 #[test]
1291 fn divide_by_zero_fee() {
1292 let stake_pool = StakePool {
1293 total_lamports: 0,
1294 epoch_fee: Fee {
1295 numerator: 1,
1296 denominator: 10,
1297 },
1298 ..StakePool::default()
1299 };
1300 let rewards = 10;
1301 let fee = stake_pool.calc_epoch_fee_amount(rewards).unwrap();
1302 assert_eq!(fee, rewards);
1303 }
1304
#[test]
fn approximate_apr_calculation() {
    // Epoch length in seconds, and how many such epochs fit in a 365.25-day
    // year.
    const SECONDS_PER_EPOCH: f64 = DEFAULT_SLOTS_PER_EPOCH as f64 * DEFAULT_S_PER_SLOT;
    const EPOCHS_PER_YEAR: f64 = SECONDS_PER_DAY as f64 * 365.25 / SECONDS_PER_EPOCH;
    const EPSILON: f64 = 0.00001;

    // Token value grows from 1.0 to 1.00044 over one epoch (0.044% per
    // epoch). Annualized, that is expected to be about 8.0355%.
    let stake_pool = StakePool {
        last_epoch_total_lamports: 100_000,
        last_epoch_pool_token_supply: 100_000,
        total_lamports: 100_044,
        pool_token_supply: 100_000,
        ..StakePool::default()
    };

    let current_value = stake_pool.total_lamports as f64 / stake_pool.pool_token_supply as f64;
    let previous_value = stake_pool.last_epoch_total_lamports as f64
        / stake_pool.last_epoch_pool_token_supply as f64;
    let per_epoch_rate = current_value / previous_value - 1.0;
    let annual_rate = per_epoch_rate * EPOCHS_PER_YEAR;

    assert!((annual_rate - 0.080355).abs() < EPSILON);
}
1326
1327 proptest! {
1328 #[test]
1329 fn fee_calculation(
1330 (numerator, denominator) in fee(),
1331 (total_lamports, reward_lamports) in total_stake_and_rewards(),
1332 ) {
1333 let epoch_fee = Fee { denominator, numerator };
1334 let mut stake_pool = StakePool {
1335 total_lamports,
1336 pool_token_supply: total_lamports,
1337 epoch_fee,
1338 ..StakePool::default()
1339 };
1340 let pool_token_fee = stake_pool.calc_epoch_fee_amount(reward_lamports).unwrap();
1341
1342 stake_pool.total_lamports += reward_lamports;
1343 stake_pool.pool_token_supply += pool_token_fee;
1344
1345 let fee_lamports = stake_pool.calc_lamports_withdraw_amount(pool_token_fee).unwrap();
1346 let max_fee_lamports = u64::try_from((reward_lamports as u128) * (epoch_fee.numerator as u128) / (epoch_fee.denominator as u128)).unwrap();
1347 assert!(max_fee_lamports >= fee_lamports,
1348 "Max possible fee must always be greater than or equal to what is actually withdrawn, max {} actual {}",
1349 max_fee_lamports,
1350 fee_lamports);
1351
1352 let epsilon = 2 + reward_lamports / total_lamports;
1356 assert!(max_fee_lamports - fee_lamports <= epsilon,
1357 "Max expected fee in lamports {}, actually receive {}, epsilon {}",
1358 max_fee_lamports, fee_lamports, epsilon);
1359 }
1360 }
1361
1362 prop_compose! {
1363 fn total_tokens_and_deposit()(total_lamports in 1..u64::MAX)(
1364 total_lamports in Just(total_lamports),
1365 pool_token_supply in 1..=total_lamports,
1366 deposit_lamports in 1..total_lamports,
1367 ) -> (u64, u64, u64) {
1368 (total_lamports - deposit_lamports, pool_token_supply.saturating_sub(deposit_lamports).max(1), deposit_lamports)
1369 }
1370 }
1371
1372 proptest! {
1373 #[test]
1374 fn deposit_and_withdraw(
1375 (total_lamports, pool_token_supply, deposit_stake) in total_tokens_and_deposit()
1376 ) {
1377 let mut stake_pool = StakePool {
1378 total_lamports,
1379 pool_token_supply,
1380 ..StakePool::default()
1381 };
1382 let deposit_result = stake_pool.calc_pool_tokens_for_deposit(deposit_stake).unwrap();
1383 prop_assume!(deposit_result > 0);
1384 stake_pool.total_lamports += deposit_stake;
1385 stake_pool.pool_token_supply += deposit_result;
1386 let withdraw_result = stake_pool.calc_lamports_withdraw_amount(deposit_result).unwrap();
1387 assert!(withdraw_result <= deposit_stake);
1388
1389 if deposit_result >= 2 {
1391 let first_half_deposit = deposit_result / 2;
1392 let first_withdraw_result = stake_pool.calc_lamports_withdraw_amount(first_half_deposit).unwrap();
1393 stake_pool.total_lamports -= first_withdraw_result;
1394 stake_pool.pool_token_supply -= first_half_deposit;
1395 let second_half_deposit = deposit_result - first_half_deposit; let second_withdraw_result = stake_pool.calc_lamports_withdraw_amount(second_half_deposit).unwrap();
1397 assert!(first_withdraw_result + second_withdraw_result <= deposit_stake);
1398 }
1399 }
1400 }
1401
#[test]
fn specific_split_withdrawal() {
    // Deposit 3 lamports into a pool valued at 1.1 lamports per token, then
    // redeem half of the minted tokens. The result must be at most half the
    // deposit, checked without division as `2 * withdrawn <= deposit`.
    let deposit_stake = 3;
    let mut pool = StakePool {
        total_lamports: 1_100_000_000_000,
        pool_token_supply: 1_000_000_000_000,
        ..StakePool::default()
    };

    let minted = pool.calc_pool_tokens_for_deposit(deposit_stake).unwrap();
    assert!(minted > 0);
    pool.total_lamports += deposit_stake;
    pool.pool_token_supply += minted;

    let half_withdrawn = pool.calc_lamports_withdraw_amount(minted / 2).unwrap();
    assert!(half_withdrawn * 2 <= deposit_stake);
}
1423
#[test]
fn withdraw_all() {
    let total_lamports = 1_100_000_000_000;
    let pool_token_supply = 1_000_000_000_000;
    // Every scenario starts from this same pool state.
    let fresh_pool = || StakePool {
        total_lamports,
        pool_token_supply,
        ..StakePool::default()
    };

    // Scenario 1: redeeming the whole supply at once claims every lamport.
    let mut pool = fresh_pool();
    let everything = pool
        .calc_lamports_withdraw_amount(pool_token_supply)
        .unwrap();
    assert_eq!(pool.total_lamports, everything);

    // Scenario 2: scenario 1 only computed an amount, so the pool is still
    // pristine. Redeem one token, apply it, then redeem the remainder.
    let single = pool.calc_lamports_withdraw_amount(1).unwrap();
    pool.total_lamports -= single;
    pool.pool_token_supply -= 1;
    let remainder = pool
        .calc_lamports_withdraw_amount(pool.pool_token_supply)
        .unwrap();
    assert_eq!(pool.total_lamports, remainder);

    // Scenario 3: redeem all but one token. Some lamports must remain, and
    // the final token must claim all of them.
    let mut pool = fresh_pool();
    let all_but_one = pool
        .calc_lamports_withdraw_amount(pool_token_supply - 1)
        .unwrap();
    pool.total_lamports -= all_but_one;
    pool.pool_token_supply = 1;
    assert_ne!(pool.total_lamports, 0);

    let last = pool.calc_lamports_withdraw_amount(1).unwrap();
    assert_eq!(pool.total_lamports, last);
}
1464}