spl_stake_pool/
state.rs

1//! State transition types
2
3use {
4    crate::{
5        big_vec::BigVec, error::StakePoolError, MAX_WITHDRAWAL_FEE_INCREASE,
6        WITHDRAWAL_BASELINE_FEE,
7    },
8    borsh::{BorshDeserialize, BorshSchema, BorshSerialize},
9    bytemuck::{Pod, Zeroable},
10    num_derive::{FromPrimitive, ToPrimitive},
11    num_traits::{FromPrimitive, ToPrimitive},
12    solana_program::{
13        account_info::AccountInfo,
14        borsh1::get_instance_packed_len,
15        msg,
16        program_error::ProgramError,
17        program_memory::sol_memcmp,
18        program_pack::{Pack, Sealed},
19        pubkey::{Pubkey, PUBKEY_BYTES},
20        stake::state::Lockup,
21    },
22    spl_pod::primitives::{PodU32, PodU64},
23    spl_token_2022::{
24        extension::{BaseStateWithExtensions, ExtensionType, StateWithExtensions},
25        state::{Account, AccountState, Mint},
26    },
27    std::{borrow::Borrow, convert::TryFrom, fmt, matches},
28};
29
/// Enum representing the account type managed by the program
///
/// This is the first field of both `StakePool` and `ValidatorListHeader`.
/// Borsh encodes it as a single discriminant byte taken from the variant
/// order below, so existing variants must not be reordered; new variants
/// should only be appended.
#[derive(Clone, Debug, Default, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
pub enum AccountType {
    /// If the account has not been initialized, the enum will be 0
    #[default]
    Uninitialized,
    /// Stake pool
    StakePool,
    /// Validator stake list
    ValidatorList,
}
41
/// Initialized program details.
///
/// The struct derives Borsh (de)serialization, which encodes fields in
/// declaration order. Presumably this is the stored layout of the stake pool
/// account, so fields must not be reordered.
#[repr(C)]
#[derive(Clone, Debug, Default, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
pub struct StakePool {
    /// Account type, must be `StakePool` currently
    pub account_type: AccountType,

    /// Manager authority, allows for updating the staker, manager, and fee
    /// account
    pub manager: Pubkey,

    /// Staker authority, allows for adding and removing validators, and
    /// managing stake distribution
    pub staker: Pubkey,

    /// Stake deposit authority
    ///
    /// If a depositor pubkey is specified on initialization, then deposits must
    /// be signed by this authority. If no deposit authority is specified,
    /// then the stake pool will default to the result of:
    /// `Pubkey::find_program_address(
    ///     &[&stake_pool_address.as_ref(), b"deposit"],
    ///     program_id,
    /// )`
    pub stake_deposit_authority: Pubkey,

    /// Stake withdrawal authority bump seed
    /// for `create_program_address(&[state::StakePool account, "withdrawal"])`
    pub stake_withdraw_bump_seed: u8,

    /// Validator stake list storage account
    pub validator_list: Pubkey,

    /// Reserve stake account, holds deactivated stake
    pub reserve_stake: Pubkey,

    /// Pool Mint
    pub pool_mint: Pubkey,

    /// Manager fee account
    pub manager_fee_account: Pubkey,

    /// Pool token program id
    pub token_program_id: Pubkey,

    /// Total stake under management.
    /// Note that if `last_update_epoch` does not match the current epoch then
    /// this field may not be accurate
    pub total_lamports: u64,

    /// Total supply of pool tokens (should always match the supply in the Pool
    /// Mint)
    pub pool_token_supply: u64,

    /// Last epoch the `total_lamports` field was updated
    pub last_update_epoch: u64,

    /// Lockup that all stakes in the pool must have
    pub lockup: Lockup,

    /// Fee taken as a proportion of rewards each epoch
    pub epoch_fee: Fee,

    /// Fee for next epoch
    pub next_epoch_fee: FutureEpoch<Fee>,

    /// Preferred deposit validator vote account pubkey
    pub preferred_deposit_validator_vote_address: Option<Pubkey>,

    /// Preferred withdraw validator vote account pubkey
    pub preferred_withdraw_validator_vote_address: Option<Pubkey>,

    /// Fee assessed on stake deposits
    pub stake_deposit_fee: Fee,

    /// Fee assessed on withdrawals
    pub stake_withdrawal_fee: Fee,

    /// Future stake withdrawal fee, to be set for the following epoch
    pub next_stake_withdrawal_fee: FutureEpoch<Fee>,

    /// Fees paid out to referrers on referred stake deposits.
    /// Expressed as a percentage (0 - 100) of deposit fees.
    /// i.e. `stake_deposit_fee`% of stake deposited is collected as deposit
    /// fees for every deposit and `stake_referral_fee`% of the collected
    /// stake deposit fees is paid out to the referrer
    pub stake_referral_fee: u8,

    /// Toggles whether the `DepositSol` instruction requires a signature from
    /// this `sol_deposit_authority`
    ///
    /// When `Some`, the provided authority must match this key and sign; when
    /// `None`, no SOL deposit authority check is performed.
    pub sol_deposit_authority: Option<Pubkey>,

    /// Fee assessed on SOL deposits
    pub sol_deposit_fee: Fee,

    /// Fees paid out to referrers on referred SOL deposits.
    /// Expressed as a percentage (0 - 100) of SOL deposit fees.
    /// i.e. `sol_deposit_fee`% of SOL deposited is collected as deposit fees
    /// for every deposit and `sol_referral_fee`% of the collected SOL
    /// deposit fees is paid out to the referrer
    pub sol_referral_fee: u8,

    /// Toggles whether the `WithdrawSol` instruction requires a signature from
    /// this `sol_withdraw_authority`
    ///
    /// When `Some`, the provided authority must match this key and sign; when
    /// `None`, no SOL withdraw authority check is performed.
    pub sol_withdraw_authority: Option<Pubkey>,

    /// Fee assessed on SOL withdrawals
    pub sol_withdrawal_fee: Fee,

    /// Future SOL withdrawal fee, to be set for the following epoch
    pub next_sol_withdrawal_fee: FutureEpoch<Fee>,

    /// Last epoch's total pool tokens, used only for APR estimation
    pub last_epoch_pool_token_supply: u64,

    /// Last epoch's total lamports, used only for APR estimation
    pub last_epoch_total_lamports: u64,
}
160impl StakePool {
161    /// calculate the pool tokens that should be minted for a deposit of
162    /// `stake_lamports`
163    #[inline]
164    pub fn calc_pool_tokens_for_deposit(&self, stake_lamports: u64) -> Option<u64> {
165        if self.total_lamports == 0 || self.pool_token_supply == 0 {
166            return Some(stake_lamports);
167        }
168        u64::try_from(
169            (stake_lamports as u128)
170                .checked_mul(self.pool_token_supply as u128)?
171                .checked_div(self.total_lamports as u128)?,
172        )
173        .ok()
174    }
175
176    /// calculate lamports amount on withdrawal
177    #[inline]
178    pub fn calc_lamports_withdraw_amount(&self, pool_tokens: u64) -> Option<u64> {
179        // `checked_div` returns `None` for a 0 quotient result, but in this
180        // case, a return of 0 is valid for small amounts of pool tokens. So
181        // we check for that separately
182        let numerator = (pool_tokens as u128).checked_mul(self.total_lamports as u128)?;
183        let denominator = self.pool_token_supply as u128;
184        if numerator < denominator || denominator == 0 {
185            Some(0)
186        } else {
187            u64::try_from(numerator.checked_div(denominator)?).ok()
188        }
189    }
190
191    /// calculate pool tokens to be deducted as withdrawal fees
192    #[inline]
193    pub fn calc_pool_tokens_stake_withdrawal_fee(&self, pool_tokens: u64) -> Option<u64> {
194        u64::try_from(self.stake_withdrawal_fee.apply(pool_tokens)?).ok()
195    }
196
197    /// calculate pool tokens to be deducted as withdrawal fees
198    #[inline]
199    pub fn calc_pool_tokens_sol_withdrawal_fee(&self, pool_tokens: u64) -> Option<u64> {
200        u64::try_from(self.sol_withdrawal_fee.apply(pool_tokens)?).ok()
201    }
202
203    /// calculate pool tokens to be deducted as stake deposit fees
204    #[inline]
205    pub fn calc_pool_tokens_stake_deposit_fee(&self, pool_tokens_minted: u64) -> Option<u64> {
206        u64::try_from(self.stake_deposit_fee.apply(pool_tokens_minted)?).ok()
207    }
208
209    /// calculate pool tokens to be deducted from deposit fees as referral fees
210    #[inline]
211    pub fn calc_pool_tokens_stake_referral_fee(&self, stake_deposit_fee: u64) -> Option<u64> {
212        u64::try_from(
213            (stake_deposit_fee as u128)
214                .checked_mul(self.stake_referral_fee as u128)?
215                .checked_div(100u128)?,
216        )
217        .ok()
218    }
219
220    /// calculate pool tokens to be deducted as SOL deposit fees
221    #[inline]
222    pub fn calc_pool_tokens_sol_deposit_fee(&self, pool_tokens_minted: u64) -> Option<u64> {
223        u64::try_from(self.sol_deposit_fee.apply(pool_tokens_minted)?).ok()
224    }
225
226    /// calculate pool tokens to be deducted from SOL deposit fees as referral
227    /// fees
228    #[inline]
229    pub fn calc_pool_tokens_sol_referral_fee(&self, sol_deposit_fee: u64) -> Option<u64> {
230        u64::try_from(
231            (sol_deposit_fee as u128)
232                .checked_mul(self.sol_referral_fee as u128)?
233                .checked_div(100u128)?,
234        )
235        .ok()
236    }
237
238    /// Calculate the fee in pool tokens that goes to the manager
239    ///
240    /// This function assumes that `reward_lamports` has not already been added
241    /// to the stake pool's `total_lamports`
242    #[inline]
243    pub fn calc_epoch_fee_amount(&self, reward_lamports: u64) -> Option<u64> {
244        if reward_lamports == 0 {
245            return Some(0);
246        }
247        let total_lamports = (self.total_lamports as u128).checked_add(reward_lamports as u128)?;
248        let fee_lamports = self.epoch_fee.apply(reward_lamports)?;
249        if total_lamports == fee_lamports || self.pool_token_supply == 0 {
250            Some(reward_lamports)
251        } else {
252            u64::try_from(
253                (self.pool_token_supply as u128)
254                    .checked_mul(fee_lamports)?
255                    .checked_div(total_lamports.checked_sub(fee_lamports)?)?,
256            )
257            .ok()
258        }
259    }
260
261    /// Get the current value of pool tokens, rounded up
262    #[inline]
263    pub fn get_lamports_per_pool_token(&self) -> Option<u64> {
264        self.total_lamports
265            .checked_add(self.pool_token_supply)?
266            .checked_sub(1)?
267            .checked_div(self.pool_token_supply)
268    }
269
270    /// Checks that the withdraw or deposit authority is valid
271    fn check_program_derived_authority(
272        authority_address: &Pubkey,
273        program_id: &Pubkey,
274        stake_pool_address: &Pubkey,
275        authority_seed: &[u8],
276        bump_seed: u8,
277    ) -> Result<(), ProgramError> {
278        let expected_address = Pubkey::create_program_address(
279            &[stake_pool_address.as_ref(), authority_seed, &[bump_seed]],
280            program_id,
281        )?;
282
283        if *authority_address == expected_address {
284            Ok(())
285        } else {
286            msg!(
287                "Incorrect authority provided, expected {}, received {}",
288                expected_address,
289                authority_address
290            );
291            Err(StakePoolError::InvalidProgramAddress.into())
292        }
293    }
294
295    /// Check if the manager fee info is a valid token program account
296    /// capable of receiving tokens from the mint.
297    pub(crate) fn check_manager_fee_info(
298        &self,
299        manager_fee_info: &AccountInfo,
300    ) -> Result<(), ProgramError> {
301        let account_data = manager_fee_info.try_borrow_data()?;
302        let token_account = StateWithExtensions::<Account>::unpack(&account_data)?;
303        if manager_fee_info.owner != &self.token_program_id
304            || token_account.base.state != AccountState::Initialized
305            || token_account.base.mint != self.pool_mint
306        {
307            msg!("Manager fee account is not owned by token program, is not initialized, or does not match stake pool's mint");
308            return Err(StakePoolError::InvalidFeeAccount.into());
309        }
310        let extensions = token_account.get_extension_types()?;
311        if extensions
312            .iter()
313            .any(|x| !is_extension_supported_for_fee_account(x))
314        {
315            return Err(StakePoolError::UnsupportedFeeAccountExtension.into());
316        }
317        Ok(())
318    }
319
320    /// Checks that the withdraw authority is valid
321    #[inline]
322    pub(crate) fn check_authority_withdraw(
323        &self,
324        withdraw_authority: &Pubkey,
325        program_id: &Pubkey,
326        stake_pool_address: &Pubkey,
327    ) -> Result<(), ProgramError> {
328        Self::check_program_derived_authority(
329            withdraw_authority,
330            program_id,
331            stake_pool_address,
332            crate::AUTHORITY_WITHDRAW,
333            self.stake_withdraw_bump_seed,
334        )
335    }
336    /// Checks that the deposit authority is valid
337    #[inline]
338    pub(crate) fn check_stake_deposit_authority(
339        &self,
340        stake_deposit_authority: &Pubkey,
341    ) -> Result<(), ProgramError> {
342        if self.stake_deposit_authority == *stake_deposit_authority {
343            Ok(())
344        } else {
345            Err(StakePoolError::InvalidStakeDepositAuthority.into())
346        }
347    }
348
349    /// Checks that the deposit authority is valid
350    /// Does nothing if `sol_deposit_authority` is currently not set
351    #[inline]
352    pub(crate) fn check_sol_deposit_authority(
353        &self,
354        maybe_sol_deposit_authority: Result<&AccountInfo, ProgramError>,
355    ) -> Result<(), ProgramError> {
356        if let Some(auth) = self.sol_deposit_authority {
357            let sol_deposit_authority = maybe_sol_deposit_authority?;
358            if auth != *sol_deposit_authority.key {
359                msg!("Expected {}, received {}", auth, sol_deposit_authority.key);
360                return Err(StakePoolError::InvalidSolDepositAuthority.into());
361            }
362            if !sol_deposit_authority.is_signer {
363                msg!("SOL Deposit authority signature missing");
364                return Err(StakePoolError::SignatureMissing.into());
365            }
366        }
367        Ok(())
368    }
369
370    /// Checks that the sol withdraw authority is valid
371    /// Does nothing if `sol_withdraw_authority` is currently not set
372    #[inline]
373    pub(crate) fn check_sol_withdraw_authority(
374        &self,
375        maybe_sol_withdraw_authority: Result<&AccountInfo, ProgramError>,
376    ) -> Result<(), ProgramError> {
377        if let Some(auth) = self.sol_withdraw_authority {
378            let sol_withdraw_authority = maybe_sol_withdraw_authority?;
379            if auth != *sol_withdraw_authority.key {
380                return Err(StakePoolError::InvalidSolWithdrawAuthority.into());
381            }
382            if !sol_withdraw_authority.is_signer {
383                msg!("SOL withdraw authority signature missing");
384                return Err(StakePoolError::SignatureMissing.into());
385            }
386        }
387        Ok(())
388    }
389
390    /// Check mint is correct
391    #[inline]
392    pub(crate) fn check_mint(&self, mint_info: &AccountInfo) -> Result<u8, ProgramError> {
393        if *mint_info.key != self.pool_mint {
394            Err(StakePoolError::WrongPoolMint.into())
395        } else {
396            let mint_data = mint_info.try_borrow_data()?;
397            let mint = StateWithExtensions::<Mint>::unpack(&mint_data)?;
398            Ok(mint.base.decimals)
399        }
400    }
401
402    /// Check manager validity and signature
403    pub(crate) fn check_manager(&self, manager_info: &AccountInfo) -> Result<(), ProgramError> {
404        if *manager_info.key != self.manager {
405            msg!(
406                "Incorrect manager provided, expected {}, received {}",
407                self.manager,
408                manager_info.key
409            );
410            return Err(StakePoolError::WrongManager.into());
411        }
412        if !manager_info.is_signer {
413            msg!("Manager signature missing");
414            return Err(StakePoolError::SignatureMissing.into());
415        }
416        Ok(())
417    }
418
419    /// Check staker validity and signature
420    pub(crate) fn check_staker(&self, staker_info: &AccountInfo) -> Result<(), ProgramError> {
421        if *staker_info.key != self.staker {
422            msg!(
423                "Incorrect staker provided, expected {}, received {}",
424                self.staker,
425                staker_info.key
426            );
427            return Err(StakePoolError::WrongStaker.into());
428        }
429        if !staker_info.is_signer {
430            msg!("Staker signature missing");
431            return Err(StakePoolError::SignatureMissing.into());
432        }
433        Ok(())
434    }
435
436    /// Check the validator list is valid
437    pub fn check_validator_list(
438        &self,
439        validator_list_info: &AccountInfo,
440    ) -> Result<(), ProgramError> {
441        if *validator_list_info.key != self.validator_list {
442            msg!(
443                "Invalid validator list provided, expected {}, received {}",
444                self.validator_list,
445                validator_list_info.key
446            );
447            Err(StakePoolError::InvalidValidatorStakeList.into())
448        } else {
449            Ok(())
450        }
451    }
452
453    /// Check the reserve stake is valid
454    pub fn check_reserve_stake(
455        &self,
456        reserve_stake_info: &AccountInfo,
457    ) -> Result<(), ProgramError> {
458        if *reserve_stake_info.key != self.reserve_stake {
459            msg!(
460                "Invalid reserve stake provided, expected {}, received {}",
461                self.reserve_stake,
462                reserve_stake_info.key
463            );
464            Err(StakePoolError::InvalidProgramAddress.into())
465        } else {
466            Ok(())
467        }
468    }
469
470    /// Check if `StakePool` is actually initialized as a stake pool
471    pub fn is_valid(&self) -> bool {
472        self.account_type == AccountType::StakePool
473    }
474
475    /// Check if `StakePool` is currently uninitialized
476    pub fn is_uninitialized(&self) -> bool {
477        self.account_type == AccountType::Uninitialized
478    }
479
480    /// Updates one of the `StakePool`'s fees.
481    pub fn update_fee(&mut self, fee: &FeeType) -> Result<(), StakePoolError> {
482        match fee {
483            FeeType::SolReferral(new_fee) => self.sol_referral_fee = *new_fee,
484            FeeType::StakeReferral(new_fee) => self.stake_referral_fee = *new_fee,
485            FeeType::Epoch(new_fee) => self.next_epoch_fee = FutureEpoch::new(*new_fee),
486            FeeType::StakeWithdrawal(new_fee) => {
487                new_fee.check_withdrawal(&self.stake_withdrawal_fee)?;
488                self.next_stake_withdrawal_fee = FutureEpoch::new(*new_fee)
489            }
490            FeeType::SolWithdrawal(new_fee) => {
491                new_fee.check_withdrawal(&self.sol_withdrawal_fee)?;
492                self.next_sol_withdrawal_fee = FutureEpoch::new(*new_fee)
493            }
494            FeeType::SolDeposit(new_fee) => self.sol_deposit_fee = *new_fee,
495            FeeType::StakeDeposit(new_fee) => self.stake_deposit_fee = *new_fee,
496        };
497        Ok(())
498    }
499}
500
501/// Checks if the given extension is supported for the stake pool mint
502pub fn is_extension_supported_for_mint(extension_type: &ExtensionType) -> bool {
503    const SUPPORTED_EXTENSIONS: [ExtensionType; 8] = [
504        ExtensionType::Uninitialized,
505        ExtensionType::TransferFeeConfig,
506        ExtensionType::ConfidentialTransferMint,
507        ExtensionType::ConfidentialTransferFeeConfig,
508        ExtensionType::DefaultAccountState, // ok, but a freeze authority is not
509        ExtensionType::InterestBearingConfig,
510        ExtensionType::MetadataPointer,
511        ExtensionType::TokenMetadata,
512    ];
513    if !SUPPORTED_EXTENSIONS.contains(extension_type) {
514        msg!(
515            "Stake pool mint account cannot have the {:?} extension",
516            extension_type
517        );
518        false
519    } else {
520        true
521    }
522}
523
524/// Checks if the given extension is supported for the stake pool's fee account
525pub fn is_extension_supported_for_fee_account(extension_type: &ExtensionType) -> bool {
526    // Note: this does not include the `ConfidentialTransferAccount` extension
527    // because it is possible to block non-confidential transfers with the
528    // extension enabled.
529    const SUPPORTED_EXTENSIONS: [ExtensionType; 4] = [
530        ExtensionType::Uninitialized,
531        ExtensionType::TransferFeeAmount,
532        ExtensionType::ImmutableOwner,
533        ExtensionType::CpiGuard,
534    ];
535    if !SUPPORTED_EXTENSIONS.contains(extension_type) {
536        msg!("Fee account cannot have the {:?} extension", extension_type);
537        false
538    } else {
539        true
540    }
541}
542
/// Storage list for all validator stake accounts in the pool.
///
/// Borsh encodes the header followed by the `validators` vec (a `u32` length
/// prefix, then each entry). `ValidatorListHeader::deserialize_vec` reads only
/// the header and exposes the remaining bytes as a `BigVec`, so both fields'
/// order and types form the account layout.
#[repr(C)]
#[derive(Clone, Debug, Default, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
pub struct ValidatorList {
    /// Data outside of the validator list, separated out for cheaper
    /// deserialization
    pub header: ValidatorListHeader,

    /// List of stake info for each validator in the pool
    pub validators: Vec<ValidatorStakeInfo>,
}
554
/// Helper type to deserialize just the start of a `ValidatorList`
///
/// Its serialized size is `ValidatorListHeader::LEN` (1 + 4 bytes): the
/// one-byte `account_type` followed by the `u32` `max_validators`.
#[repr(C)]
#[derive(Clone, Debug, Default, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
pub struct ValidatorListHeader {
    /// Account type, must be `ValidatorList` currently
    pub account_type: AccountType,

    /// Maximum allowable number of validators
    pub max_validators: u32,
}
565
566/// Status of the stake account in the validator list, for accounting
567#[derive(
568    ToPrimitive,
569    FromPrimitive,
570    Copy,
571    Clone,
572    Debug,
573    PartialEq,
574    BorshDeserialize,
575    BorshSerialize,
576    BorshSchema,
577)]
578pub enum StakeStatus {
579    /// Stake account is active, there may be a transient stake as well
580    Active,
581    /// Only transient stake account exists, when a transient stake is
582    /// deactivating during validator removal
583    DeactivatingTransient,
584    /// No more validator stake accounts exist, entry ready for removal during
585    /// `UpdateStakePoolBalance`
586    ReadyForRemoval,
587    /// Only the validator stake account is deactivating, no transient stake
588    /// account exists
589    DeactivatingValidator,
590    /// Both the transient and validator stake account are deactivating, when
591    /// a validator is removed with a transient stake active
592    DeactivatingAll,
593}
594impl Default for StakeStatus {
595    fn default() -> Self {
596        Self::Active
597    }
598}
599
/// Wrapper struct that can be `Pod`, containing a byte that *should* be a valid
/// `StakeStatus` underneath.
///
/// The byte is not validated on construction or deserialization; decoding goes
/// through `StakeStatus::try_from`, which rejects unknown values with
/// `ProgramError::InvalidAccountData`. The derived `Default` is the byte `0`.
#[repr(transparent)]
#[derive(
    Clone,
    Copy,
    Debug,
    Default,
    PartialEq,
    Pod,
    Zeroable,
    BorshDeserialize,
    BorshSerialize,
    BorshSchema,
)]
pub struct PodStakeStatus(u8);
616impl PodStakeStatus {
617    /// Downgrade the status towards ready for removal by removing the validator
618    /// stake
619    pub fn remove_validator_stake(&mut self) -> Result<(), ProgramError> {
620        let status = StakeStatus::try_from(*self)?;
621        let new_self = match status {
622            StakeStatus::Active
623            | StakeStatus::DeactivatingTransient
624            | StakeStatus::ReadyForRemoval => status,
625            StakeStatus::DeactivatingAll => StakeStatus::DeactivatingTransient,
626            StakeStatus::DeactivatingValidator => StakeStatus::ReadyForRemoval,
627        };
628        *self = new_self.into();
629        Ok(())
630    }
631    /// Downgrade the status towards ready for removal by removing the transient
632    /// stake
633    pub fn remove_transient_stake(&mut self) -> Result<(), ProgramError> {
634        let status = StakeStatus::try_from(*self)?;
635        let new_self = match status {
636            StakeStatus::Active
637            | StakeStatus::DeactivatingValidator
638            | StakeStatus::ReadyForRemoval => status,
639            StakeStatus::DeactivatingAll => StakeStatus::DeactivatingValidator,
640            StakeStatus::DeactivatingTransient => StakeStatus::ReadyForRemoval,
641        };
642        *self = new_self.into();
643        Ok(())
644    }
645}
646impl TryFrom<PodStakeStatus> for StakeStatus {
647    type Error = ProgramError;
648    fn try_from(pod: PodStakeStatus) -> Result<Self, Self::Error> {
649        FromPrimitive::from_u8(pod.0).ok_or(ProgramError::InvalidAccountData)
650    }
651}
652impl From<StakeStatus> for PodStakeStatus {
653    fn from(status: StakeStatus) -> Self {
654        // unwrap is safe here because the variants of `StakeStatus` fit very
655        // comfortably within a `u8`
656        PodStakeStatus(status.to_u8().unwrap())
657    }
658}
659
/// Where the stake for a withdrawal comes from, decided while running
/// `process_withdraw_stake`
#[derive(PartialEq, Debug)]
pub(crate) enum StakeWithdrawSource {
    /// Part, but not all, of an active stake account
    Active,
    /// Part of a transient stake account
    Transient,
    /// An entire validator stake account
    ValidatorRemoval,
}
670
/// Information about a validator in the pool
///
/// NOTE: ORDER IS VERY IMPORTANT HERE, PLEASE DO NOT RE-ORDER THE FIELDS UNLESS
/// THERE'S AN EXTREMELY GOOD REASON.
///
/// To save on BPF instructions, the serialized bytes are reinterpreted with a
/// `bytemuck` transmute, which means that this structure cannot have any
/// undeclared alignment-padding in its representation.
///
/// Serialized byte layout (73 bytes, matching `Pack::LEN`), which the raw-slice
/// helpers `active_lamports_greater_than`, `transient_lamports_greater_than`,
/// `is_not_removed`, and `memcmp_pubkey` index into directly:
/// - `0..8`: `active_stake_lamports`
/// - `8..16`: `transient_stake_lamports`
/// - `16..24`: `last_update_epoch`
/// - `24..32`: `transient_seed_suffix`
/// - `32..36`: `unused`
/// - `36..40`: `validator_seed_suffix`
/// - `40`: `status`
/// - `41..73`: `vote_account_address`
#[repr(C)]
#[derive(
    Clone,
    Copy,
    Debug,
    Default,
    PartialEq,
    Pod,
    Zeroable,
    BorshDeserialize,
    BorshSerialize,
    BorshSchema,
)]
pub struct ValidatorStakeInfo {
    /// Amount of lamports on the validator stake account, including rent
    ///
    /// Note that if `last_update_epoch` does not match the current epoch then
    /// this field may not be accurate
    pub active_stake_lamports: PodU64,

    /// Amount of transient stake delegated to this validator
    ///
    /// Note that if `last_update_epoch` does not match the current epoch then
    /// this field may not be accurate
    pub transient_stake_lamports: PodU64,

    /// Last epoch the active and transient stake lamports fields were updated
    pub last_update_epoch: PodU64,

    /// Transient account seed suffix, used to derive the transient stake
    /// account address
    pub transient_seed_suffix: PodU64,

    /// Unused space, initially meant to specify the end of seed suffixes
    pub unused: PodU32,

    /// Validator account seed suffix
    pub validator_seed_suffix: PodU32, // really `Option<NonZeroU32>` so 0 is `None`

    /// Status of the validator stake account
    pub status: PodStakeStatus,

    /// Validator vote account address
    pub vote_account_address: Pubkey,
}
724
725impl ValidatorStakeInfo {
726    /// Get the total lamports on this validator (active and transient)
727    pub fn stake_lamports(&self) -> Result<u64, StakePoolError> {
728        u64::from(self.active_stake_lamports)
729            .checked_add(self.transient_stake_lamports.into())
730            .ok_or(StakePoolError::CalculationFailure)
731    }
732
733    /// Performs a very cheap comparison, for checking if this validator stake
734    /// info matches the vote account address
735    pub fn memcmp_pubkey(data: &[u8], vote_address: &Pubkey) -> bool {
736        sol_memcmp(
737            &data[41..41_usize.saturating_add(PUBKEY_BYTES)],
738            vote_address.as_ref(),
739            PUBKEY_BYTES,
740        ) == 0
741    }
742
743    /// Performs a comparison, used to check if this validator stake
744    /// info has more active lamports than some limit
745    pub fn active_lamports_greater_than(data: &[u8], lamports: &u64) -> bool {
746        // without this unwrap, compute usage goes up significantly
747        u64::try_from_slice(&data[0..8]).unwrap() > *lamports
748    }
749
750    /// Performs a comparison, used to check if this validator stake
751    /// info has more transient lamports than some limit
752    pub fn transient_lamports_greater_than(data: &[u8], lamports: &u64) -> bool {
753        // without this unwrap, compute usage goes up significantly
754        u64::try_from_slice(&data[8..16]).unwrap() > *lamports
755    }
756
757    /// Check that the validator stake info is valid
758    pub fn is_not_removed(data: &[u8]) -> bool {
759        FromPrimitive::from_u8(data[40]) != Some(StakeStatus::ReadyForRemoval)
760    }
761}
762
763impl Sealed for ValidatorStakeInfo {}
764
765impl Pack for ValidatorStakeInfo {
766    const LEN: usize = 73;
767    fn pack_into_slice(&self, data: &mut [u8]) {
768        // Removing this unwrap would require changing from `Pack` to some other
769        // trait or `bytemuck`, so it stays in for now
770        borsh::to_writer(data, self).unwrap();
771    }
772    fn unpack_from_slice(src: &[u8]) -> Result<Self, ProgramError> {
773        let unpacked = Self::try_from_slice(src)?;
774        Ok(unpacked)
775    }
776}
777
778impl ValidatorList {
779    /// Create an empty instance containing space for `max_validators` and
780    /// preferred validator keys
781    pub fn new(max_validators: u32) -> Self {
782        Self {
783            header: ValidatorListHeader {
784                account_type: AccountType::ValidatorList,
785                max_validators,
786            },
787            validators: vec![ValidatorStakeInfo::default(); max_validators as usize],
788        }
789    }
790
791    /// Calculate the number of validator entries that fit in the provided
792    /// length
793    pub fn calculate_max_validators(buffer_length: usize) -> usize {
794        let header_size = ValidatorListHeader::LEN.saturating_add(4);
795        buffer_length
796            .saturating_sub(header_size)
797            .saturating_div(ValidatorStakeInfo::LEN)
798    }
799
800    /// Check if contains validator with particular pubkey
801    pub fn contains(&self, vote_account_address: &Pubkey) -> bool {
802        self.validators
803            .iter()
804            .any(|x| x.vote_account_address == *vote_account_address)
805    }
806
807    /// Check if contains validator with particular pubkey
808    pub fn find_mut(&mut self, vote_account_address: &Pubkey) -> Option<&mut ValidatorStakeInfo> {
809        self.validators
810            .iter_mut()
811            .find(|x| x.vote_account_address == *vote_account_address)
812    }
813    /// Check if contains validator with particular pubkey
814    pub fn find(&self, vote_account_address: &Pubkey) -> Option<&ValidatorStakeInfo> {
815        self.validators
816            .iter()
817            .find(|x| x.vote_account_address == *vote_account_address)
818    }
819
820    /// Check if the list has any active stake
821    pub fn has_active_stake(&self) -> bool {
822        self.validators
823            .iter()
824            .any(|x| u64::from(x.active_stake_lamports) > 0)
825    }
826}
827
828impl ValidatorListHeader {
829    const LEN: usize = 1 + 4;
830
831    /// Check if validator stake list is actually initialized as a validator
832    /// stake list
833    pub fn is_valid(&self) -> bool {
834        self.account_type == AccountType::ValidatorList
835    }
836
837    /// Check if the validator stake list is uninitialized
838    pub fn is_uninitialized(&self) -> bool {
839        self.account_type == AccountType::Uninitialized
840    }
841
842    /// Extracts a slice of `ValidatorStakeInfo` types from the vec part
843    /// of the `ValidatorList`
844    pub fn deserialize_mut_slice<'a>(
845        big_vec: &'a mut BigVec,
846        skip: usize,
847        len: usize,
848    ) -> Result<&'a mut [ValidatorStakeInfo], ProgramError> {
849        big_vec.deserialize_mut_slice::<ValidatorStakeInfo>(skip, len)
850    }
851
852    /// Extracts the validator list into its header and internal `BigVec`
853    pub fn deserialize_vec(data: &mut [u8]) -> Result<(Self, BigVec), ProgramError> {
854        let mut data_mut = data.borrow();
855        let header = ValidatorListHeader::deserialize(&mut data_mut)?;
856        let length = get_instance_packed_len(&header)?;
857
858        let big_vec = BigVec {
859            data: &mut data[length..],
860        };
861        Ok((header, big_vec))
862    }
863}
864
/// Wrapper type that "counts down" epochs, which is Borsh-compatible with the
/// native `Option`
///
/// A value is scheduled with `FutureEpoch::new` (state `Two`). A call to
/// `FutureEpoch::update_epoch` moves it to `One`, where `FutureEpoch::get`
/// returns it, and the next `update_epoch` clears it back to `None`.
///
/// The variant order defines the Borsh discriminant, so it must not change.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, BorshSerialize, BorshDeserialize, BorshSchema)]
pub enum FutureEpoch<T> {
    /// Nothing is set
    None,
    /// Value is ready after the next epoch boundary
    One(T),
    /// Value is ready after two epoch boundaries
    Two(T),
}
877impl<T> Default for FutureEpoch<T> {
878    fn default() -> Self {
879        Self::None
880    }
881}
882impl<T> FutureEpoch<T> {
883    /// Create a new value to be unlocked in a two epochs
884    pub fn new(value: T) -> Self {
885        Self::Two(value)
886    }
887}
888impl<T: Clone> FutureEpoch<T> {
889    /// Update the epoch, to be done after `get`ting the underlying value
890    pub fn update_epoch(&mut self) {
891        match self {
892            Self::None => {}
893            Self::One(_) => {
894                // The value has waited its last epoch
895                *self = Self::None;
896            }
897            // The value still has to wait one more epoch after this
898            Self::Two(v) => {
899                *self = Self::One(v.clone());
900            }
901        }
902    }
903
904    /// Get the value if it's ready, which is only at `One` epoch remaining
905    pub fn get(&self) -> Option<&T> {
906        match self {
907            Self::None | Self::Two(_) => None,
908            Self::One(v) => Some(v),
909        }
910    }
911}
912impl<T> From<FutureEpoch<T>> for Option<T> {
913    fn from(v: FutureEpoch<T>) -> Option<T> {
914        match v {
915            FutureEpoch::None => None,
916            FutureEpoch::One(inner) | FutureEpoch::Two(inner) => Some(inner),
917        }
918    }
919}
920
/// Fee rate as the ratio `numerator / denominator`.
///
/// As the epoch fee it is minted on `UpdateStakePoolBalance` as a proportion
/// of the rewards. The same type also carries the deposit and withdrawal fees
/// listed in `FeeType`.
///
/// If either the numerator or the denominator is 0, the fee is considered to
/// be 0: `Fee::apply` then yields 0, and `Display` prints `none`. Ratios above
/// 100% (`numerator > denominator`) are rejected by `FeeType::check_too_high`.
///
/// Field order (denominator first) is part of the Borsh serialized layout and
/// must not change.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq, BorshSerialize, BorshDeserialize, BorshSchema)]
pub struct Fee {
    /// denominator of the fee ratio
    pub denominator: u64,
    /// numerator of the fee ratio
    pub numerator: u64,
}
933
934impl Fee {
935    /// Applies the Fee's rates to a given amount, `amt`
936    /// returning the amount to be subtracted from it as fees
937    /// (0 if denominator is 0 or amt is 0),
938    /// or None if overflow occurs
939    #[inline]
940    pub fn apply(&self, amt: u64) -> Option<u128> {
941        if self.denominator == 0 {
942            return Some(0);
943        }
944        let numerator = (amt as u128).checked_mul(self.numerator as u128)?;
945        // ceiling the calculation by adding (denominator - 1) to the numerator
946        let denominator = self.denominator as u128;
947        numerator
948            .checked_add(denominator)?
949            .checked_sub(1)?
950            .checked_div(denominator)
951    }
952
953    /// Withdrawal fees have some additional restrictions, this function checks
954    /// if those are met, returning an error if not.
955    pub fn check_withdrawal(&self, old_withdrawal_fee: &Fee) -> Result<(), StakePoolError> {
956        // If the previous withdrawal fee was 0, we allow the fee to be set to a
957        // maximum of (WITHDRAWAL_BASELINE_FEE * MAX_WITHDRAWAL_FEE_INCREASE)
958        let (old_num, old_denom) =
959            if old_withdrawal_fee.denominator == 0 || old_withdrawal_fee.numerator == 0 {
960                (
961                    WITHDRAWAL_BASELINE_FEE.numerator,
962                    WITHDRAWAL_BASELINE_FEE.denominator,
963                )
964            } else {
965                (old_withdrawal_fee.numerator, old_withdrawal_fee.denominator)
966            };
967
968        // Check that new_fee / old_fee <= MAX_WITHDRAWAL_FEE_INCREASE
969        // Program fails if provided numerator or denominator is too large, resulting in
970        // overflow
971        if (old_num as u128)
972            .checked_mul(self.denominator as u128)
973            .map(|x| x.checked_mul(MAX_WITHDRAWAL_FEE_INCREASE.numerator as u128))
974            .ok_or(StakePoolError::CalculationFailure)?
975            < (self.numerator as u128)
976                .checked_mul(old_denom as u128)
977                .map(|x| x.checked_mul(MAX_WITHDRAWAL_FEE_INCREASE.denominator as u128))
978                .ok_or(StakePoolError::CalculationFailure)?
979        {
980            msg!(
981                "Fee increase exceeds maximum allowed, proposed increase factor ({} / {})",
982                self.numerator.saturating_mul(old_denom),
983                old_num.saturating_mul(self.denominator),
984            );
985            return Err(StakePoolError::FeeIncreaseTooHigh);
986        }
987        Ok(())
988    }
989}
990
991impl fmt::Display for Fee {
992    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
993        if self.numerator > 0 && self.denominator > 0 {
994            write!(f, "{}/{}", self.numerator, self.denominator)
995        } else {
996            write!(f, "none")
997        }
998    }
999}
1000
/// The type of fees that can be set on the stake pool
///
/// Referral variants carry a percentage, and values above 100 are rejected by
/// `FeeType::check_too_high`. All other variants carry a `Fee` ratio.
/// `Epoch`, `StakeWithdrawal` and `SolWithdrawal` can only take effect from the
/// next epoch (see `FeeType::can_only_change_next_epoch`).
///
/// The variant order defines the Borsh discriminant, so it must not change.
#[derive(Clone, Debug, PartialEq, BorshDeserialize, BorshSerialize, BorshSchema)]
pub enum FeeType {
    /// Referral fees for SOL deposits, as a percentage
    SolReferral(u8),
    /// Referral fees for stake deposits, as a percentage
    StakeReferral(u8),
    /// Management fee paid per epoch
    Epoch(Fee),
    /// Stake withdrawal fee
    StakeWithdrawal(Fee),
    /// Deposit fee for SOL deposits
    SolDeposit(Fee),
    /// Deposit fee for stake deposits
    StakeDeposit(Fee),
    /// SOL withdrawal fee
    SolWithdrawal(Fee),
}
1019
1020impl FeeType {
1021    /// Checks if the provided fee is too high, returning an error if so
1022    pub fn check_too_high(&self) -> Result<(), StakePoolError> {
1023        let too_high = match self {
1024            Self::SolReferral(pct) => *pct > 100u8,
1025            Self::StakeReferral(pct) => *pct > 100u8,
1026            Self::Epoch(fee) => fee.numerator > fee.denominator,
1027            Self::StakeWithdrawal(fee) => fee.numerator > fee.denominator,
1028            Self::SolWithdrawal(fee) => fee.numerator > fee.denominator,
1029            Self::SolDeposit(fee) => fee.numerator > fee.denominator,
1030            Self::StakeDeposit(fee) => fee.numerator > fee.denominator,
1031        };
1032        if too_high {
1033            msg!("Fee greater than 100%: {:?}", self);
1034            return Err(StakePoolError::FeeTooHigh);
1035        }
1036        Ok(())
1037    }
1038
1039    /// Returns if the contained fee can only be updated earliest on the next
1040    /// epoch
1041    #[inline]
1042    pub fn can_only_change_next_epoch(&self) -> bool {
1043        matches!(
1044            self,
1045            Self::StakeWithdrawal(_) | Self::SolWithdrawal(_) | Self::Epoch(_)
1046        )
1047    }
1048}
1049
#[cfg(test)]
mod test {
    #![allow(clippy::arithmetic_side_effects)]
    use {
        super::*,
        proptest::prelude::*,
        solana_program::{
            borsh1::{get_packed_len, try_from_slice_unchecked},
            clock::{DEFAULT_SLOTS_PER_EPOCH, DEFAULT_S_PER_SLOT, SECONDS_PER_DAY},
            native_token::LAMPORTS_PER_SOL,
        },
    };

    /// Validator list whose header is still `Uninitialized` and which has no
    /// entries
    fn uninitialized_validator_list() -> ValidatorList {
        ValidatorList {
            header: ValidatorListHeader {
                account_type: AccountType::Uninitialized,
                max_validators: 0,
            },
            validators: vec![],
        }
    }

    /// Initialized validator list with three hand-picked entries (active,
    /// deactivating-transient and ready-for-removal); `max_validators` is only
    /// written into the header
    fn test_validator_list(max_validators: u32) -> ValidatorList {
        ValidatorList {
            header: ValidatorListHeader {
                account_type: AccountType::ValidatorList,
                max_validators,
            },
            validators: vec![
                ValidatorStakeInfo {
                    status: StakeStatus::Active.into(),
                    vote_account_address: Pubkey::new_from_array([1; 32]),
                    active_stake_lamports: u64::from_le_bytes([255; 8]).into(),
                    transient_stake_lamports: u64::from_le_bytes([128; 8]).into(),
                    last_update_epoch: u64::from_le_bytes([64; 8]).into(),
                    transient_seed_suffix: 0.into(),
                    unused: 0.into(),
                    validator_seed_suffix: 0.into(),
                },
                ValidatorStakeInfo {
                    status: StakeStatus::DeactivatingTransient.into(),
                    vote_account_address: Pubkey::new_from_array([2; 32]),
                    active_stake_lamports: 998877665544.into(),
                    transient_stake_lamports: 222222222.into(),
                    last_update_epoch: 11223445566.into(),
                    transient_seed_suffix: 0.into(),
                    unused: 0.into(),
                    validator_seed_suffix: 0.into(),
                },
                ValidatorStakeInfo {
                    status: StakeStatus::ReadyForRemoval.into(),
                    vote_account_address: Pubkey::new_from_array([3; 32]),
                    active_stake_lamports: 0.into(),
                    transient_stake_lamports: 0.into(),
                    last_update_epoch: 999999999999999.into(),
                    transient_seed_suffix: 0.into(),
                    unused: 0.into(),
                    validator_seed_suffix: 0.into(),
                },
            ],
        }
    }

    #[test]
    fn state_packing() {
        let max_validators = 10_000;
        let size = get_instance_packed_len(&ValidatorList::new(max_validators)).unwrap();
        let stake_list = uninitialized_validator_list();
        let mut byte_vec = vec![0u8; size];
        let bytes = byte_vec.as_mut_slice();
        borsh::to_writer(bytes, &stake_list).unwrap();
        let stake_list_unpacked = try_from_slice_unchecked::<ValidatorList>(&byte_vec).unwrap();
        assert_eq!(stake_list_unpacked, stake_list);

        // Empty, one preferred key
        let stake_list = ValidatorList {
            header: ValidatorListHeader {
                account_type: AccountType::ValidatorList,
                max_validators: 0,
            },
            validators: vec![],
        };
        let mut byte_vec = vec![0u8; size];
        let bytes = byte_vec.as_mut_slice();
        borsh::to_writer(bytes, &stake_list).unwrap();
        let stake_list_unpacked = try_from_slice_unchecked::<ValidatorList>(&byte_vec).unwrap();
        assert_eq!(stake_list_unpacked, stake_list);

        // With several accounts
        let stake_list = test_validator_list(max_validators);
        let mut byte_vec = vec![0u8; size];
        let bytes = byte_vec.as_mut_slice();
        borsh::to_writer(bytes, &stake_list).unwrap();
        let stake_list_unpacked = try_from_slice_unchecked::<ValidatorList>(&byte_vec).unwrap();
        assert_eq!(stake_list_unpacked, stake_list);
    }

    #[test]
    fn validator_list_active_stake() {
        let max_validators = 10_000;
        let mut validator_list = test_validator_list(max_validators);
        assert!(validator_list.has_active_stake());
        for validator in validator_list.validators.iter_mut() {
            validator.active_stake_lamports = 0.into();
        }
        assert!(!validator_list.has_active_stake());
    }

    #[test]
    fn validator_list_deserialize_mut_slice() {
        let max_validators = 10;
        let stake_list = test_validator_list(max_validators);
        let mut serialized = borsh::to_vec(&stake_list).unwrap();
        let (header, mut big_vec) = ValidatorListHeader::deserialize_vec(&mut serialized).unwrap();
        let list = ValidatorListHeader::deserialize_mut_slice(
            &mut big_vec,
            0,
            stake_list.validators.len(),
        )
        .unwrap();
        assert_eq!(header.account_type, AccountType::ValidatorList);
        assert_eq!(header.max_validators, max_validators);
        assert!(list
            .iter()
            .zip(stake_list.validators.iter())
            .all(|(a, b)| a == b));

        let list = ValidatorListHeader::deserialize_mut_slice(&mut big_vec, 1, 2).unwrap();
        assert!(list
            .iter()
            .zip(stake_list.validators[1..].iter())
            .all(|(a, b)| a == b));
        let list = ValidatorListHeader::deserialize_mut_slice(&mut big_vec, 2, 1).unwrap();
        assert!(list
            .iter()
            .zip(stake_list.validators[2..].iter())
            .all(|(a, b)| a == b));
        let list = ValidatorListHeader::deserialize_mut_slice(&mut big_vec, 0, 2).unwrap();
        assert!(list
            .iter()
            .zip(stake_list.validators[..2].iter())
            .all(|(a, b)| a == b));

        assert_eq!(
            ValidatorListHeader::deserialize_mut_slice(&mut big_vec, 0, 4).unwrap_err(),
            ProgramError::AccountDataTooSmall
        );
        assert_eq!(
            ValidatorListHeader::deserialize_mut_slice(&mut big_vec, 1, 3).unwrap_err(),
            ProgramError::AccountDataTooSmall
        );
    }

    #[test]
    fn validator_list_iter() {
        let max_validators = 10;
        let stake_list = test_validator_list(max_validators);
        let mut serialized = borsh::to_vec(&stake_list).unwrap();
        let (_, big_vec) = ValidatorListHeader::deserialize_vec(&mut serialized).unwrap();
        for (a, b) in big_vec
            .deserialize_slice::<ValidatorStakeInfo>(0, big_vec.len() as usize)
            .unwrap()
            .iter()
            .zip(stake_list.validators.iter())
        {
            assert_eq!(a, b);
        }
    }

    proptest! {
        #[test]
        fn stake_list_size_calculation(test_amount in 0..=100_000_u32) {
            let validators = ValidatorList::new(test_amount);
            let size = get_instance_packed_len(&validators).unwrap();
            assert_eq!(ValidatorList::calculate_max_validators(size), test_amount as usize);
            assert_eq!(ValidatorList::calculate_max_validators(size.saturating_add(1)), test_amount as usize);
            assert_eq!(ValidatorList::calculate_max_validators(size.saturating_add(get_packed_len::<ValidatorStakeInfo>())), (test_amount + 1)as usize);
            assert_eq!(ValidatorList::calculate_max_validators(size.saturating_sub(1)), (test_amount.saturating_sub(1)) as usize);
        }
    }

    prop_compose! {
        fn fee()(denominator in 1..=u16::MAX)(
            denominator in Just(denominator),
            numerator in 0..=denominator,
        ) -> (u64, u64) {
            (numerator as u64, denominator as u64)
        }
    }

    prop_compose! {
        // Generates `(pre_reward_lamports, rewards)` with
        // `pre_reward_lamports >= 1`.
        fn total_stake_and_rewards()(total_lamports in 1..u64::MAX)(
            total_lamports in Just(total_lamports),
            // Exclusive upper bound: `rewards == total_lamports` would make the
            // first element 0, and `fee_calculation` divides by it when
            // computing `epsilon`
            rewards in 0..total_lamports,
        ) -> (u64, u64) {
            (total_lamports - rewards, rewards)
        }
    }

    #[test]
    fn specific_fee_calculation() {
        // 10% of 10 SOL in rewards should be 1 SOL in fees
        let epoch_fee = Fee {
            numerator: 1,
            denominator: 10,
        };
        let mut stake_pool = StakePool {
            total_lamports: 100 * LAMPORTS_PER_SOL,
            pool_token_supply: 100 * LAMPORTS_PER_SOL,
            epoch_fee,
            ..StakePool::default()
        };
        let reward_lamports = 10 * LAMPORTS_PER_SOL;
        let pool_token_fee = stake_pool.calc_epoch_fee_amount(reward_lamports).unwrap();

        stake_pool.total_lamports += reward_lamports;
        stake_pool.pool_token_supply += pool_token_fee;

        let fee_lamports = stake_pool
            .calc_lamports_withdraw_amount(pool_token_fee)
            .unwrap();
        // off-by-one due to truncation
        assert_eq!(fee_lamports, LAMPORTS_PER_SOL - 1);
    }

    #[test]
    fn zero_withdraw_calculation() {
        let epoch_fee = Fee {
            numerator: 0,
            denominator: 1,
        };
        let stake_pool = StakePool {
            epoch_fee,
            ..StakePool::default()
        };
        let fee_lamports = stake_pool.calc_lamports_withdraw_amount(0).unwrap();
        assert_eq!(fee_lamports, 0);
    }

    #[test]
    fn divide_by_zero_fee() {
        let stake_pool = StakePool {
            total_lamports: 0,
            epoch_fee: Fee {
                numerator: 1,
                denominator: 10,
            },
            ..StakePool::default()
        };
        let rewards = 10;
        let fee = stake_pool.calc_epoch_fee_amount(rewards).unwrap();
        assert_eq!(fee, rewards);
    }

    #[test]
    fn approximate_apr_calculation() {
        // 8% / year means roughly .044% / epoch
        let stake_pool = StakePool {
            last_epoch_total_lamports: 100_000,
            last_epoch_pool_token_supply: 100_000,
            total_lamports: 100_044,
            pool_token_supply: 100_000,
            ..StakePool::default()
        };
        let pool_token_value =
            stake_pool.total_lamports as f64 / stake_pool.pool_token_supply as f64;
        let last_epoch_pool_token_value = stake_pool.last_epoch_total_lamports as f64
            / stake_pool.last_epoch_pool_token_supply as f64;
        let epoch_rate = pool_token_value / last_epoch_pool_token_value - 1.0;
        const SECONDS_PER_EPOCH: f64 = DEFAULT_SLOTS_PER_EPOCH as f64 * DEFAULT_S_PER_SLOT;
        const EPOCHS_PER_YEAR: f64 = SECONDS_PER_DAY as f64 * 365.25 / SECONDS_PER_EPOCH;
        const EPSILON: f64 = 0.00001;
        let yearly_rate = epoch_rate * EPOCHS_PER_YEAR;
        assert!((yearly_rate - 0.080355).abs() < EPSILON);
    }

    proptest! {
        #[test]
        fn fee_calculation(
            (numerator, denominator) in fee(),
            (total_lamports, reward_lamports) in total_stake_and_rewards(),
        ) {
            let epoch_fee = Fee { denominator, numerator };
            let mut stake_pool = StakePool {
                total_lamports,
                pool_token_supply: total_lamports,
                epoch_fee,
                ..StakePool::default()
            };
            let pool_token_fee = stake_pool.calc_epoch_fee_amount(reward_lamports).unwrap();

            stake_pool.total_lamports += reward_lamports;
            stake_pool.pool_token_supply += pool_token_fee;

            let fee_lamports = stake_pool.calc_lamports_withdraw_amount(pool_token_fee).unwrap();
            let max_fee_lamports = u64::try_from((reward_lamports as u128) * (epoch_fee.numerator as u128) / (epoch_fee.denominator as u128)).unwrap();
            assert!(max_fee_lamports >= fee_lamports,
                "Max possible fee must always be greater than or equal to what is actually withdrawn, max {} actual {}",
                max_fee_lamports,
                fee_lamports);

            // since we do two "flooring" conversions, the max epsilon should be
            // correct up to 2 lamports (one for each floor division), plus a
            // correction for huge discrepancies between rewards and total stake
            let epsilon = 2 + reward_lamports / total_lamports;
            assert!(max_fee_lamports - fee_lamports <= epsilon,
                "Max expected fee in lamports {}, actually receive {}, epsilon {}",
                max_fee_lamports, fee_lamports, epsilon);
        }
    }

    prop_compose! {
        fn total_tokens_and_deposit()(total_lamports in 1..u64::MAX)(
            total_lamports in Just(total_lamports),
            pool_token_supply in 1..=total_lamports,
            deposit_lamports in 1..total_lamports,
        ) -> (u64, u64, u64) {
            (total_lamports - deposit_lamports, pool_token_supply.saturating_sub(deposit_lamports).max(1), deposit_lamports)
        }
    }

    proptest! {
        #[test]
        fn deposit_and_withdraw(
            (total_lamports, pool_token_supply, deposit_stake) in total_tokens_and_deposit()
        ) {
            let mut stake_pool = StakePool {
                total_lamports,
                pool_token_supply,
                ..StakePool::default()
            };
            let deposit_result = stake_pool.calc_pool_tokens_for_deposit(deposit_stake).unwrap();
            prop_assume!(deposit_result > 0);
            stake_pool.total_lamports += deposit_stake;
            stake_pool.pool_token_supply += deposit_result;
            let withdraw_result = stake_pool.calc_lamports_withdraw_amount(deposit_result).unwrap();
            assert!(withdraw_result <= deposit_stake);

            // also test splitting the withdrawal in two operations
            if deposit_result >= 2 {
                let first_half_deposit = deposit_result / 2;
                let first_withdraw_result = stake_pool.calc_lamports_withdraw_amount(first_half_deposit).unwrap();
                stake_pool.total_lamports -= first_withdraw_result;
                stake_pool.pool_token_supply -= first_half_deposit;
                let second_half_deposit = deposit_result - first_half_deposit; // do the whole thing
                let second_withdraw_result = stake_pool.calc_lamports_withdraw_amount(second_half_deposit).unwrap();
                assert!(first_withdraw_result + second_withdraw_result <= deposit_stake);
            }
        }
    }

    #[test]
    fn specific_split_withdrawal() {
        let total_lamports = 1_100_000_000_000;
        let pool_token_supply = 1_000_000_000_000;
        let deposit_stake = 3;
        let mut stake_pool = StakePool {
            total_lamports,
            pool_token_supply,
            ..StakePool::default()
        };
        let deposit_result = stake_pool
            .calc_pool_tokens_for_deposit(deposit_stake)
            .unwrap();
        assert!(deposit_result > 0);
        stake_pool.total_lamports += deposit_stake;
        stake_pool.pool_token_supply += deposit_result;
        let withdraw_result = stake_pool
            .calc_lamports_withdraw_amount(deposit_result / 2)
            .unwrap();
        assert!(withdraw_result * 2 <= deposit_stake);
    }

    #[test]
    fn withdraw_all() {
        let total_lamports = 1_100_000_000_000;
        let pool_token_supply = 1_000_000_000_000;
        let mut stake_pool = StakePool {
            total_lamports,
            pool_token_supply,
            ..StakePool::default()
        };
        // take everything out at once
        let withdraw_result = stake_pool
            .calc_lamports_withdraw_amount(pool_token_supply)
            .unwrap();
        assert_eq!(stake_pool.total_lamports, withdraw_result);

        // take out 1, then the rest
        let withdraw_result = stake_pool.calc_lamports_withdraw_amount(1).unwrap();
        stake_pool.total_lamports -= withdraw_result;
        stake_pool.pool_token_supply -= 1;
        let withdraw_result = stake_pool
            .calc_lamports_withdraw_amount(stake_pool.pool_token_supply)
            .unwrap();
        assert_eq!(stake_pool.total_lamports, withdraw_result);

        // take out all except 1, then the rest
        let mut stake_pool = StakePool {
            total_lamports,
            pool_token_supply,
            ..StakePool::default()
        };
        let withdraw_result = stake_pool
            .calc_lamports_withdraw_amount(pool_token_supply - 1)
            .unwrap();
        stake_pool.total_lamports -= withdraw_result;
        stake_pool.pool_token_supply = 1;
        assert_ne!(stake_pool.total_lamports, 0);

        let withdraw_result = stake_pool.calc_lamports_withdraw_amount(1).unwrap();
        assert_eq!(stake_pool.total_lamports, withdraw_result);
    }
}