// spl_token_2022_interface/extension/transfer_fee/mod.rs

1use {
2    crate::{
3        error::TokenError,
4        extension::{Extension, ExtensionType},
5    },
6    bytemuck::{Pod, Zeroable},
7    core::{
8        cmp,
9        convert::{TryFrom, TryInto},
10    },
11    solana_address::Address,
12    solana_nullable::MaybeNull,
13    solana_program_error::ProgramResult,
14    solana_zero_copy::unaligned::{U16, U64},
15};
16#[cfg(feature = "serde")]
17use {
18    serde::{Deserialize, Serialize},
19    serde_with::{As, DisplayFromStr},
20};
21
/// Transfer fee extension instructions, declared in the `instruction`
/// submodule.
pub mod instruction;
24
/// Largest accepted `transfer_fee_basis_points` value: 10,000 basis points,
/// i.e. a fee of `100%` of the transferred amount.
pub const MAX_FEE_BASIS_POINTS: u16 = 10_000;
/// [`MAX_FEE_BASIS_POINTS`] widened to `u128`. This is the denominator that
/// turns a basis-point rate into a fraction in the fee arithmetic. It uses a
/// plain `as` cast because `From::from` cannot be called in a const context;
/// the widening is lossless.
const ONE_IN_BASIS_POINTS: u128 = MAX_FEE_BASIS_POINTS as u128;
28
/// Transfer fee information
///
/// A single fee schedule: a rate in basis points, a per-transfer cap on the
/// fee, and the first epoch in which the schedule applies. Fields are laid
/// out in declaration order (`#[repr(C)]`, `Pod`), so their order is part of
/// the byte layout and must not change.
#[repr(C)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "camelCase"))]
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
pub struct TransferFee {
    /// First epoch where the transfer fee takes effect
    pub epoch: U64, // epoch number
    /// Maximum fee assessed on transfers, expressed as an amount of tokens
    pub maximum_fee: U64,
    /// Share of the transfer amount collected as fees, expressed in basis
    /// points (increments of `0.01%`); meaningful values are at most
    /// `MAX_FEE_BASIS_POINTS` (`100%`)
    pub transfer_fee_basis_points: U16,
}
43impl TransferFee {
44    /// Calculate ceiling-division
45    ///
46    /// Ceiling-division
47    ///     `ceil[ numerator / denominator ]`
48    /// can be represented as a floor-division
49    ///     `floor[ (numerator + denominator - 1) / denominator]`
50    fn ceil_div(numerator: u128, denominator: u128) -> Option<u128> {
51        numerator
52            .checked_add(denominator)?
53            .checked_sub(1)?
54            .checked_div(denominator)
55    }
56
57    /// Calculate the transfer fee
58    pub fn calculate_fee(&self, pre_fee_amount: u64) -> Option<u64> {
59        let transfer_fee_basis_points = u16::from(self.transfer_fee_basis_points) as u128;
60        if transfer_fee_basis_points == 0 || pre_fee_amount == 0 {
61            Some(0)
62        } else {
63            let numerator = (pre_fee_amount as u128).checked_mul(transfer_fee_basis_points)?;
64            let raw_fee = Self::ceil_div(numerator, ONE_IN_BASIS_POINTS)?
65                .try_into() // guaranteed to be okay
66                .ok()?;
67
68            Some(cmp::min(raw_fee, u64::from(self.maximum_fee)))
69        }
70    }
71
72    /// Calculate the gross transfer amount after deducting fees
73    pub fn calculate_post_fee_amount(&self, pre_fee_amount: u64) -> Option<u64> {
74        pre_fee_amount.checked_sub(self.calculate_fee(pre_fee_amount)?)
75    }
76
77    /// Calculate the transfer amount that will result in a specified net
78    /// transfer amount.
79    ///
80    /// The original transfer amount may not always be unique due to rounding.
81    /// In this case, the smaller amount will be chosen.
82    /// e.g. Both transfer amount 10, 11 with `10%` fee rate results in net
83    /// transfer amount of 9. In this case, 10 will be chosen.
84    /// e.g. Fee rate is `100%`. In this case, 0 will be chosen.
85    ///
86    /// The original transfer amount may not always exist on large net transfer
87    /// amounts due to overflow. In this case, `None` is returned.
88    /// e.g. The net fee amount is `u64::MAX` with a positive fee rate.
89    pub fn calculate_pre_fee_amount(&self, post_fee_amount: u64) -> Option<u64> {
90        let maximum_fee = u64::from(self.maximum_fee);
91        let transfer_fee_basis_points = u16::from(self.transfer_fee_basis_points) as u128;
92        match (transfer_fee_basis_points, post_fee_amount) {
93            // no fee, same amount
94            (0, _) => Some(post_fee_amount),
95            // 0 zero out, 0 in
96            (_, 0) => Some(0),
97            // 100%, cap at max fee
98            (ONE_IN_BASIS_POINTS, _) => maximum_fee.checked_add(post_fee_amount),
99            _ => {
100                let numerator = (post_fee_amount as u128).checked_mul(ONE_IN_BASIS_POINTS)?;
101                let denominator = ONE_IN_BASIS_POINTS.checked_sub(transfer_fee_basis_points)?;
102                let raw_pre_fee_amount = Self::ceil_div(numerator, denominator)?;
103
104                if raw_pre_fee_amount.checked_sub(post_fee_amount as u128)? >= maximum_fee as u128 {
105                    post_fee_amount.checked_add(maximum_fee)
106                } else {
107                    // should return `None` if `pre_fee_amount` overflows
108                    u64::try_from(raw_pre_fee_amount).ok()
109                }
110            }
111        }
112    }
113
114    /// Calculate the fee that would produce the given output
115    ///
116    /// Note: this function is not an exact inverse operation of
117    /// `calculate_fee`. Meaning, it is not the case that:
118    ///
119    /// `calculate_fee(x) == calculate_inverse_fee(x - calculate_fee(x))`
120    ///
121    /// Only the following relationship holds:
122    ///
123    /// `calculate_fee(x) >= calculate_inverse_fee(x - calculate_fee(x))`
124    pub fn calculate_inverse_fee(&self, post_fee_amount: u64) -> Option<u64> {
125        let pre_fee_amount = self.calculate_pre_fee_amount(post_fee_amount)?;
126        self.calculate_fee(pre_fee_amount)
127    }
128}
129
/// Transfer fee extension data for mints.
///
/// Holds two fee schedules. `older_transfer_fee` applies to epochs before
/// `newer_transfer_fee.epoch`, and `newer_transfer_fee` applies from that
/// epoch on. Fields are laid out in declaration order (`#[repr(C)]`, `Pod`),
/// so their order is part of the byte layout and must not change.
#[repr(C)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "camelCase"))]
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
pub struct TransferFeeConfig {
    /// Optional authority to set the fee
    #[cfg_attr(feature = "serde", serde(with = "As::<Option<DisplayFromStr>>"))]
    pub transfer_fee_config_authority: MaybeNull<Address>,
    /// Withdraw from mint instructions must be signed by this key
    #[cfg_attr(feature = "serde", serde(with = "As::<Option<DisplayFromStr>>"))]
    pub withdraw_withheld_authority: MaybeNull<Address>,
    /// Withheld transfer fee tokens that have been moved to the mint for
    /// withdrawal
    pub withheld_amount: U64,
    /// Older transfer fee, used if `current epoch < newer_transfer_fee.epoch`
    pub older_transfer_fee: TransferFee,
    /// Newer transfer fee, used if `current epoch >= newer_transfer_fee.epoch`
    pub newer_transfer_fee: TransferFee,
}
150impl TransferFeeConfig {
151    /// Get the fee for the given epoch
152    pub fn get_epoch_fee(&self, epoch: u64) -> &TransferFee {
153        if epoch >= self.newer_transfer_fee.epoch.into() {
154            &self.newer_transfer_fee
155        } else {
156            &self.older_transfer_fee
157        }
158    }
159    /// Calculate the fee for the given epoch and input amount
160    pub fn calculate_epoch_fee(&self, epoch: u64, pre_fee_amount: u64) -> Option<u64> {
161        self.get_epoch_fee(epoch).calculate_fee(pre_fee_amount)
162    }
163    /// Calculate the fee for the given epoch and output amount
164    pub fn calculate_inverse_epoch_fee(&self, epoch: u64, post_fee_amount: u64) -> Option<u64> {
165        self.get_epoch_fee(epoch)
166            .calculate_inverse_fee(post_fee_amount)
167    }
168}
// Associates the mint-side `TransferFeeConfig` data with the
// `ExtensionType::TransferFeeConfig` tag.
impl Extension for TransferFeeConfig {
    const TYPE: ExtensionType = ExtensionType::TransferFeeConfig;
}
172
/// Transfer fee extension data for accounts.
///
/// `#[repr(C)]` + `Pod`: the single field below is the whole byte layout.
#[repr(C)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "camelCase"))]
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
pub struct TransferFeeAmount {
    /// Amount withheld during transfers, to be harvested to the mint
    pub withheld_amount: U64,
}
182impl TransferFeeAmount {
183    /// Check if the extension is in a closable state
184    pub fn closable(&self) -> ProgramResult {
185        if self.withheld_amount == 0.into() {
186            Ok(())
187        } else {
188            Err(TokenError::AccountHasWithheldTransferFees.into())
189        }
190    }
191}
// Associates the account-side `TransferFeeAmount` data with the
// `ExtensionType::TransferFeeAmount` tag.
impl Extension for TransferFeeAmount {
    const TYPE: ExtensionType = ExtensionType::TransferFeeAmount;
}
195
#[cfg(test)]
pub(crate) mod test {
    use {super::*, core::convert::TryFrom, proptest::prelude::*, solana_address::Address};

    const NEWER_EPOCH: u64 = 100;
    const OLDER_EPOCH: u64 = 1;

    /// Shared fixture: the older fee (1%, capped at 10 tokens) applies before
    /// `NEWER_EPOCH`, and the newer fee (0.01%, capped at 5_000 tokens) applies
    /// from `NEWER_EPOCH` on.
    pub(crate) fn test_transfer_fee_config() -> TransferFeeConfig {
        TransferFeeConfig {
            transfer_fee_config_authority: Some(Address::new_from_array([10; 32]))
                .try_into()
                .unwrap(),
            withdraw_withheld_authority: Some(Address::new_from_array([11; 32]))
                .try_into()
                .unwrap(),
            withheld_amount: U64::from(u64::MAX),
            older_transfer_fee: TransferFee {
                epoch: U64::from(OLDER_EPOCH),
                maximum_fee: U64::from(10),
                transfer_fee_basis_points: U16::from(100),
            },
            newer_transfer_fee: TransferFee {
                epoch: U64::from(NEWER_EPOCH),
                maximum_fee: U64::from(5_000),
                transfer_fee_basis_points: U16::from(1),
            },
        }
    }

    #[test]
    fn epoch_fee() {
        let transfer_fee_config = test_transfer_fee_config();
        // during epoch 100 and after, use newer transfer fee
        assert_eq!(
            transfer_fee_config.get_epoch_fee(NEWER_EPOCH).epoch,
            NEWER_EPOCH.into()
        );
        assert_eq!(
            transfer_fee_config.get_epoch_fee(NEWER_EPOCH + 1).epoch,
            NEWER_EPOCH.into()
        );
        assert_eq!(
            transfer_fee_config.get_epoch_fee(u64::MAX).epoch,
            NEWER_EPOCH.into()
        );
        // before that, use older transfer fee
        assert_eq!(
            transfer_fee_config.get_epoch_fee(NEWER_EPOCH - 1).epoch,
            OLDER_EPOCH.into()
        );
        assert_eq!(
            transfer_fee_config.get_epoch_fee(OLDER_EPOCH).epoch,
            OLDER_EPOCH.into()
        );
        assert_eq!(
            transfer_fee_config.get_epoch_fee(OLDER_EPOCH + 1).epoch,
            OLDER_EPOCH.into()
        );
    }

    #[test]
    fn calculate_fee_max() {
        let one = u64::try_from(ONE_IN_BASIS_POINTS).unwrap();
        let transfer_fee = TransferFee {
            epoch: U64::from(0),
            maximum_fee: U64::from(5_000),
            transfer_fee_basis_points: U16::from(1),
        };
        let maximum_fee = u64::from(transfer_fee.maximum_fee);
        // hit maximum fee
        assert_eq!(maximum_fee, transfer_fee.calculate_fee(u64::MAX).unwrap());
        // at exactly the max
        assert_eq!(
            maximum_fee,
            transfer_fee.calculate_fee(maximum_fee * one).unwrap()
        );
        // one token above, normally rounds up, but we're at the max
        assert_eq!(
            maximum_fee,
            transfer_fee.calculate_fee(maximum_fee * one + 1).unwrap()
        );
        // one token below, rounds up to the max
        assert_eq!(
            maximum_fee,
            transfer_fee.calculate_fee(maximum_fee * one - 1).unwrap()
        );
    }

    #[test]
    fn calculate_fee_min() {
        let one = u64::try_from(ONE_IN_BASIS_POINTS).unwrap();
        let transfer_fee = TransferFee {
            epoch: U64::from(0),
            maximum_fee: U64::from(5_000),
            transfer_fee_basis_points: U16::from(1),
        };
        let minimum_fee = 1;
        // hit minimum fee even with 1 token
        assert_eq!(minimum_fee, transfer_fee.calculate_fee(1).unwrap());
        // still minimum at 2 tokens
        assert_eq!(minimum_fee, transfer_fee.calculate_fee(2).unwrap());
        // still minimum at 10_000 tokens
        assert_eq!(minimum_fee, transfer_fee.calculate_fee(one).unwrap());
        // 2 token fee at 10_001
        assert_eq!(
            minimum_fee + 1,
            transfer_fee.calculate_fee(one + 1).unwrap()
        );
        // zero is always zero
        assert_eq!(0, transfer_fee.calculate_fee(0).unwrap());
    }

    #[test]
    fn calculate_fee_zero() {
        let one = u64::try_from(ONE_IN_BASIS_POINTS).unwrap();
        let transfer_fee = TransferFee {
            epoch: U64::from(0),
            maximum_fee: U64::from(u64::MAX),
            transfer_fee_basis_points: U16::from(0),
        };
        // always zero fee
        assert_eq!(0, transfer_fee.calculate_fee(0).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(u64::MAX).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(1).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(one).unwrap());

        let transfer_fee = TransferFee {
            epoch: U64::from(0),
            maximum_fee: U64::from(0),
            transfer_fee_basis_points: U16::from(MAX_FEE_BASIS_POINTS),
        };
        // always zero fee
        assert_eq!(0, transfer_fee.calculate_fee(0).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(u64::MAX).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(1).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(one).unwrap());
    }

    #[test]
    fn calculate_fee_exact_out_max() {
        let one = u64::try_from(ONE_IN_BASIS_POINTS).unwrap();
        let transfer_fee = TransferFee {
            epoch: U64::from(0),
            maximum_fee: U64::from(5_000),
            transfer_fee_basis_points: U16::from(1),
        };
        let maximum_fee = u64::from(transfer_fee.maximum_fee);
        // hit maximum fee
        assert_eq!(
            maximum_fee,
            transfer_fee
                .calculate_inverse_fee(u64::MAX - maximum_fee)
                .unwrap()
        );
        // at exactly the max
        assert_eq!(
            maximum_fee,
            transfer_fee
                .calculate_inverse_fee(maximum_fee * one - maximum_fee)
                .unwrap()
        );
        // one token above, normally rounds up, but we're at the max
        assert_eq!(
            maximum_fee,
            transfer_fee
                .calculate_inverse_fee(maximum_fee * one - maximum_fee + 1)
                .unwrap()
        );
        // one token below, rounds up to the max
        assert_eq!(
            maximum_fee,
            transfer_fee
                .calculate_inverse_fee(maximum_fee * one - maximum_fee - 1)
                .unwrap()
        );
    }

    #[test]
    fn calculate_pre_fee_amount_edge_cases() {
        let maximum_fee = 5_000;
        let transfer_fee = TransferFee {
            epoch: U64::from(0),
            maximum_fee: U64::from(maximum_fee),
            transfer_fee_basis_points: U16::from(u16::try_from(ONE_IN_BASIS_POINTS).unwrap()),
        };

        // 0 out, 0 in
        assert_eq!(0, transfer_fee.calculate_pre_fee_amount(0).unwrap());

        // cap at max fee
        assert_eq!(
            1 + maximum_fee,
            transfer_fee.calculate_pre_fee_amount(1).unwrap()
        );

        // no fee same amount
        let transfer_fee = TransferFee {
            epoch: U64::from(0),
            maximum_fee: U64::from(maximum_fee),
            transfer_fee_basis_points: U16::from(0),
        };
        assert_eq!(1, transfer_fee.calculate_pre_fee_amount(1).unwrap());
    }

    #[test]
    fn calculate_fee_exact_out_min() {
        let one = u64::try_from(ONE_IN_BASIS_POINTS).unwrap();
        let transfer_fee = TransferFee {
            epoch: U64::from(0),
            maximum_fee: U64::from(5_000),
            transfer_fee_basis_points: U16::from(1),
        };
        let minimum_fee = 1;
        // hit minimum fee even with 1 token
        assert_eq!(minimum_fee, transfer_fee.calculate_inverse_fee(1).unwrap());
        // still minimum at 2 tokens
        assert_eq!(minimum_fee, transfer_fee.calculate_inverse_fee(2).unwrap());
        // still minimum at 9_999 tokens
        assert_eq!(
            minimum_fee,
            transfer_fee.calculate_inverse_fee(one - 1).unwrap()
        );
        // 2 token fee at 10_000
        assert_eq!(
            minimum_fee + 1,
            transfer_fee.calculate_inverse_fee(one).unwrap()
        );
        // zero is zero token
        assert_eq!(0, transfer_fee.calculate_inverse_fee(0).unwrap());
    }

    #[test]
    fn round_trip_fee_calculation_small_amount() {
        // With a 10% fee, 10 and 11 both net 9 tokens (fees 1 and 2). The
        // inverse picks the smaller pre-fee amount, so the fees differ by 1
        // even though `amount_in / 10^12` is 0 here.
        let transfer_fee = TransferFee {
            epoch: U64::from(0),
            maximum_fee: U64::from(u64::MAX),
            transfer_fee_basis_points: U16::from(1_000),
        };
        assert_eq!(2, transfer_fee.calculate_fee(11).unwrap());
        assert_eq!(10, transfer_fee.calculate_pre_fee_amount(9).unwrap());
        assert_eq!(1, transfer_fee.calculate_inverse_fee(9).unwrap());
    }

    proptest! {
        #[test]
        fn round_trip_fee_calculation(
            transfer_fee_basis_points in 0u16..MAX_FEE_BASIS_POINTS,
            maximum_fee in u64::MIN..=u64::MAX,
            amount_in in 0..=u64::MAX
        ) {
            let transfer_fee = TransferFee {
                epoch: U64::from(0),
                maximum_fee: U64::from(maximum_fee),
                transfer_fee_basis_points: U16::from(transfer_fee_basis_points),
            };
            let fee = transfer_fee.calculate_fee(amount_in).unwrap();
            let amount_out = amount_in.checked_sub(fee).unwrap();
            let fee_exact_out = transfer_fee.calculate_inverse_fee(amount_out).unwrap();
            let diff = fee.abs_diff(fee_exact_out);
            // `calculate_inverse_fee` uses the smallest pre-fee amount netting
            // `amount_out`. Every such amount lies in a window narrower than
            // `10_000 / (10_000 - transfer_fee_basis_points) <= 10_000` tokens,
            // so `diff` stays below `one`. The relative bound
            // `amount_in / 10^12` alone is 0 for amounts under 10^12, where
            // `diff < 0` could never hold, hence the `max(one)` floor.
            let one = u64::from(MAX_FEE_BASIS_POINTS);
            let precision = amount_in / one / one / one;
            let bound = precision.max(one);
            assert!(diff < bound, "diff is {} for bound {}", diff, bound);
        }
    }

    proptest! {
        #[test]
        fn inverse_fee_relationship(
            transfer_fee_basis_points in 0u16..MAX_FEE_BASIS_POINTS,
            maximum_fee in u64::MIN..=u64::MAX,
            amount_in in 0..=u64::MAX
        ) {
            let transfer_fee = TransferFee {
                epoch: U64::from(0),
                maximum_fee: U64::from(maximum_fee),
                transfer_fee_basis_points: U16::from(transfer_fee_basis_points),
            };
            let fee = transfer_fee.calculate_fee(amount_in).unwrap();
            let amount_out = amount_in.checked_sub(fee).unwrap();
            let fee_exact_out = transfer_fee.calculate_inverse_fee(amount_out).unwrap();
            assert!(fee >= fee_exact_out);
        }
    }
}