1use {
2 crate::{
3 error::TokenError,
4 extension::{Extension, ExtensionType},
5 },
6 bytemuck::{Pod, Zeroable},
7 core::{
8 cmp,
9 convert::{TryFrom, TryInto},
10 },
11 solana_address::Address,
12 solana_nullable::MaybeNull,
13 solana_program_error::ProgramResult,
14 solana_zero_copy::unaligned::{U16, U64},
15};
16#[cfg(feature = "serde")]
17use {
18 serde::{Deserialize, Serialize},
19 serde_with::{As, DisplayFromStr},
20};
21
22pub mod instruction;
24
/// Largest fee rate that represents a whole transfer, in basis points
/// (10_000 basis points = 100%).
pub const MAX_FEE_BASIS_POINTS: u16 = 10_000;
/// `MAX_FEE_BASIS_POINTS` widened to `u128`: the fixed-point denominator
/// ("one", i.e. 100%) used by the fee arithmetic in `TransferFee`.
const ONE_IN_BASIS_POINTS: u128 = MAX_FEE_BASIS_POINTS as u128;
28
29#[repr(C)]
31#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
32#[cfg_attr(feature = "serde", serde(rename_all = "camelCase"))]
33#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
34pub struct TransferFee {
35 pub epoch: U64, pub maximum_fee: U64,
39 pub transfer_fee_basis_points: U16,
42}
43impl TransferFee {
44 fn ceil_div(numerator: u128, denominator: u128) -> Option<u128> {
51 numerator
52 .checked_add(denominator)?
53 .checked_sub(1)?
54 .checked_div(denominator)
55 }
56
57 pub fn calculate_fee(&self, pre_fee_amount: u64) -> Option<u64> {
59 let transfer_fee_basis_points = u16::from(self.transfer_fee_basis_points) as u128;
60 if transfer_fee_basis_points == 0 || pre_fee_amount == 0 {
61 Some(0)
62 } else {
63 let numerator = (pre_fee_amount as u128).checked_mul(transfer_fee_basis_points)?;
64 let raw_fee = Self::ceil_div(numerator, ONE_IN_BASIS_POINTS)?
65 .try_into() .ok()?;
67
68 Some(cmp::min(raw_fee, u64::from(self.maximum_fee)))
69 }
70 }
71
72 pub fn calculate_post_fee_amount(&self, pre_fee_amount: u64) -> Option<u64> {
74 pre_fee_amount.checked_sub(self.calculate_fee(pre_fee_amount)?)
75 }
76
77 pub fn calculate_pre_fee_amount(&self, post_fee_amount: u64) -> Option<u64> {
90 let maximum_fee = u64::from(self.maximum_fee);
91 let transfer_fee_basis_points = u16::from(self.transfer_fee_basis_points) as u128;
92 match (transfer_fee_basis_points, post_fee_amount) {
93 (0, _) => Some(post_fee_amount),
95 (_, 0) => Some(0),
97 (ONE_IN_BASIS_POINTS, _) => maximum_fee.checked_add(post_fee_amount),
99 _ => {
100 let numerator = (post_fee_amount as u128).checked_mul(ONE_IN_BASIS_POINTS)?;
101 let denominator = ONE_IN_BASIS_POINTS.checked_sub(transfer_fee_basis_points)?;
102 let raw_pre_fee_amount = Self::ceil_div(numerator, denominator)?;
103
104 if raw_pre_fee_amount.checked_sub(post_fee_amount as u128)? >= maximum_fee as u128 {
105 post_fee_amount.checked_add(maximum_fee)
106 } else {
107 u64::try_from(raw_pre_fee_amount).ok()
109 }
110 }
111 }
112 }
113
114 pub fn calculate_inverse_fee(&self, post_fee_amount: u64) -> Option<u64> {
125 let pre_fee_amount = self.calculate_pre_fee_amount(post_fee_amount)?;
126 self.calculate_fee(pre_fee_amount)
127 }
128}
129
130#[repr(C)]
132#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
133#[cfg_attr(feature = "serde", serde(rename_all = "camelCase"))]
134#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
135pub struct TransferFeeConfig {
136 #[cfg_attr(feature = "serde", serde(with = "As::<Option<DisplayFromStr>>"))]
138 pub transfer_fee_config_authority: MaybeNull<Address>,
139 #[cfg_attr(feature = "serde", serde(with = "As::<Option<DisplayFromStr>>"))]
141 pub withdraw_withheld_authority: MaybeNull<Address>,
142 pub withheld_amount: U64,
145 pub older_transfer_fee: TransferFee,
147 pub newer_transfer_fee: TransferFee,
149}
150impl TransferFeeConfig {
151 pub fn get_epoch_fee(&self, epoch: u64) -> &TransferFee {
153 if epoch >= self.newer_transfer_fee.epoch.into() {
154 &self.newer_transfer_fee
155 } else {
156 &self.older_transfer_fee
157 }
158 }
159 pub fn calculate_epoch_fee(&self, epoch: u64, pre_fee_amount: u64) -> Option<u64> {
161 self.get_epoch_fee(epoch).calculate_fee(pre_fee_amount)
162 }
163 pub fn calculate_inverse_epoch_fee(&self, epoch: u64, post_fee_amount: u64) -> Option<u64> {
165 self.get_epoch_fee(epoch)
166 .calculate_inverse_fee(post_fee_amount)
167 }
168}
/// Associates `TransferFeeConfig` with the `ExtensionType::TransferFeeConfig`
/// tag through the `Extension` trait.
impl Extension for TransferFeeConfig {
    const TYPE: ExtensionType = ExtensionType::TransferFeeConfig;
}
172
173#[repr(C)]
175#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
176#[cfg_attr(feature = "serde", serde(rename_all = "camelCase"))]
177#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
178pub struct TransferFeeAmount {
179 pub withheld_amount: U64,
181}
182impl TransferFeeAmount {
183 pub fn closable(&self) -> ProgramResult {
185 if self.withheld_amount == 0.into() {
186 Ok(())
187 } else {
188 Err(TokenError::AccountHasWithheldTransferFees.into())
189 }
190 }
191}
/// Associates `TransferFeeAmount` with the `ExtensionType::TransferFeeAmount`
/// tag through the `Extension` trait.
impl Extension for TransferFeeAmount {
    const TYPE: ExtensionType = ExtensionType::TransferFeeAmount;
}
195
#[cfg(test)]
pub(crate) mod test {
    use {super::*, core::convert::TryFrom, proptest::prelude::*, solana_address::Address};

    const NEWER_EPOCH: u64 = 100;
    const OLDER_EPOCH: u64 = 1;

    /// Builds a config with distinct older/newer schedules. It is `pub(crate)`
    /// so other test modules in the crate can reuse it.
    pub(crate) fn test_transfer_fee_config() -> TransferFeeConfig {
        TransferFeeConfig {
            transfer_fee_config_authority: Some(Address::new_from_array([10; 32]))
                .try_into()
                .unwrap(),
            withdraw_withheld_authority: Some(Address::new_from_array([11; 32]))
                .try_into()
                .unwrap(),
            withheld_amount: U64::from(u64::MAX),
            older_transfer_fee: TransferFee {
                epoch: U64::from(OLDER_EPOCH),
                maximum_fee: U64::from(10),
                transfer_fee_basis_points: U16::from(100),
            },
            newer_transfer_fee: TransferFee {
                epoch: U64::from(NEWER_EPOCH),
                maximum_fee: U64::from(5_000),
                transfer_fee_basis_points: U16::from(1),
            },
        }
    }

    /// Largest possible gap between `calculate_fee(x)` and
    /// `calculate_inverse_fee(x - calculate_fee(x))` for a rate of
    /// `basis_points`, which must be below `MAX_FEE_BASIS_POINTS`.
    ///
    /// Both directions round up, so the reconstructed pre-fee amount can fall
    /// short of `x` by at most `(one - 1) / (one - bps)` units (real-valued).
    /// That shortfall changes the fee by at most
    /// `ceil(bps / (one - bps)) == (one - 1) / (one - bps)` (integer division).
    /// When the cap applies, the inverse recovers the capped fee, so the gap
    /// stays within the same bound.
    ///
    /// The bound depends on neither the amount nor `maximum_fee`. For example,
    /// at 9_999 bps an input of 19_999 loses 9_999 units of fee on the round
    /// trip. This replaces the earlier `amount_in / one^3` tolerance: that
    /// tolerance was 0 (unsatisfiable with `<`) for every amount below
    /// 10^12, and it did not cover rates close to 100%.
    fn max_round_trip_fee_error(basis_points: u16) -> u64 {
        let one = u64::from(MAX_FEE_BASIS_POINTS);
        (one - 1) / (one - u64::from(basis_points))
    }

    #[test]
    fn epoch_fee() {
        let transfer_fee_config = test_transfer_fee_config();
        // At or after the newer epoch, the newer schedule is selected.
        assert_eq!(
            transfer_fee_config.get_epoch_fee(NEWER_EPOCH).epoch,
            NEWER_EPOCH.into()
        );
        assert_eq!(
            transfer_fee_config.get_epoch_fee(NEWER_EPOCH + 1).epoch,
            NEWER_EPOCH.into()
        );
        assert_eq!(
            transfer_fee_config.get_epoch_fee(u64::MAX).epoch,
            NEWER_EPOCH.into()
        );
        // Before the newer epoch, the older schedule is selected.
        assert_eq!(
            transfer_fee_config.get_epoch_fee(NEWER_EPOCH - 1).epoch,
            OLDER_EPOCH.into()
        );
        assert_eq!(
            transfer_fee_config.get_epoch_fee(OLDER_EPOCH).epoch,
            OLDER_EPOCH.into()
        );
        assert_eq!(
            transfer_fee_config.get_epoch_fee(OLDER_EPOCH + 1).epoch,
            OLDER_EPOCH.into()
        );
    }

    #[test]
    fn calculate_fee_max() {
        let one = u64::try_from(ONE_IN_BASIS_POINTS).unwrap();
        let transfer_fee = TransferFee {
            epoch: U64::from(0),
            maximum_fee: U64::from(5_000),
            transfer_fee_basis_points: U16::from(1),
        };
        let maximum_fee = u64::from(transfer_fee.maximum_fee);
        assert_eq!(maximum_fee, transfer_fee.calculate_fee(u64::MAX).unwrap());
        assert_eq!(
            maximum_fee,
            transfer_fee.calculate_fee(maximum_fee * one).unwrap()
        );
        assert_eq!(
            maximum_fee,
            transfer_fee.calculate_fee(maximum_fee * one + 1).unwrap()
        );
        assert_eq!(
            maximum_fee,
            transfer_fee.calculate_fee(maximum_fee * one - 1).unwrap()
        );
    }

    #[test]
    fn calculate_fee_min() {
        let one = u64::try_from(ONE_IN_BASIS_POINTS).unwrap();
        let transfer_fee = TransferFee {
            epoch: U64::from(0),
            maximum_fee: U64::from(5_000),
            transfer_fee_basis_points: U16::from(1),
        };
        let minimum_fee = 1;
        // Rounding up means any non-zero amount pays at least the minimum fee.
        assert_eq!(minimum_fee, transfer_fee.calculate_fee(1).unwrap());
        assert_eq!(minimum_fee, transfer_fee.calculate_fee(2).unwrap());
        assert_eq!(minimum_fee, transfer_fee.calculate_fee(one).unwrap());
        assert_eq!(
            minimum_fee + 1,
            transfer_fee.calculate_fee(one + 1).unwrap()
        );
        assert_eq!(0, transfer_fee.calculate_fee(0).unwrap());
    }

    #[test]
    fn calculate_fee_zero() {
        let one = u64::try_from(ONE_IN_BASIS_POINTS).unwrap();
        // Zero rate: no fee regardless of cap.
        let transfer_fee = TransferFee {
            epoch: U64::from(0),
            maximum_fee: U64::from(u64::MAX),
            transfer_fee_basis_points: U16::from(0),
        };
        assert_eq!(0, transfer_fee.calculate_fee(0).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(u64::MAX).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(1).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(one).unwrap());

        // Zero cap: no fee regardless of rate.
        let transfer_fee = TransferFee {
            epoch: U64::from(0),
            maximum_fee: U64::from(0),
            transfer_fee_basis_points: U16::from(MAX_FEE_BASIS_POINTS),
        };
        assert_eq!(0, transfer_fee.calculate_fee(0).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(u64::MAX).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(1).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(one).unwrap());
    }

    #[test]
    fn calculate_fee_exact_out_max() {
        let one = u64::try_from(ONE_IN_BASIS_POINTS).unwrap();
        let transfer_fee = TransferFee {
            epoch: U64::from(0),
            maximum_fee: U64::from(5_000),
            transfer_fee_basis_points: U16::from(1),
        };
        let maximum_fee = u64::from(transfer_fee.maximum_fee);
        assert_eq!(
            maximum_fee,
            transfer_fee
                .calculate_inverse_fee(u64::MAX - maximum_fee)
                .unwrap()
        );
        assert_eq!(
            maximum_fee,
            transfer_fee
                .calculate_inverse_fee(maximum_fee * one - maximum_fee)
                .unwrap()
        );
        assert_eq!(
            maximum_fee,
            transfer_fee
                .calculate_inverse_fee(maximum_fee * one - maximum_fee + 1)
                .unwrap()
        );
        assert_eq!(
            maximum_fee,
            transfer_fee
                .calculate_inverse_fee(maximum_fee * one - maximum_fee - 1)
                .unwrap()
        );
    }

    #[test]
    fn calculate_pre_fee_amount_edge_cases() {
        let maximum_fee = 5_000;
        // 100% rate: the pre-fee amount is always post + maximum_fee.
        let transfer_fee = TransferFee {
            epoch: U64::from(0),
            maximum_fee: U64::from(maximum_fee),
            transfer_fee_basis_points: U16::from(u16::try_from(ONE_IN_BASIS_POINTS).unwrap()),
        };

        assert_eq!(0, transfer_fee.calculate_pre_fee_amount(0).unwrap());

        assert_eq!(
            1 + maximum_fee,
            transfer_fee.calculate_pre_fee_amount(1).unwrap()
        );

        // Zero rate: the pre-fee amount equals the post-fee amount.
        let transfer_fee = TransferFee {
            epoch: U64::from(0),
            maximum_fee: U64::from(maximum_fee),
            transfer_fee_basis_points: U16::from(0),
        };
        assert_eq!(1, transfer_fee.calculate_pre_fee_amount(1).unwrap());
    }

    #[test]
    fn calculate_fee_exact_out_min() {
        let one = u64::try_from(ONE_IN_BASIS_POINTS).unwrap();
        let transfer_fee = TransferFee {
            epoch: U64::from(0),
            maximum_fee: U64::from(5_000),
            transfer_fee_basis_points: U16::from(1),
        };
        let minimum_fee = 1;
        assert_eq!(minimum_fee, transfer_fee.calculate_inverse_fee(1).unwrap());
        assert_eq!(minimum_fee, transfer_fee.calculate_inverse_fee(2).unwrap());
        assert_eq!(
            minimum_fee,
            transfer_fee.calculate_inverse_fee(one - 1).unwrap()
        );
        assert_eq!(
            minimum_fee + 1,
            transfer_fee.calculate_inverse_fee(one).unwrap()
        );
        assert_eq!(0, transfer_fee.calculate_inverse_fee(0).unwrap());
    }

    /// Deterministic regression test for the round-trip tolerance on small
    /// amounts. The old `amount_in / one^3` tolerance was 0 here, so every
    /// case failed.
    #[test]
    fn round_trip_fee_error_is_bounded_for_small_amounts() {
        for &transfer_fee_basis_points in &[0u16, 1, 3_333, 5_000, 9_000, 9_999] {
            let tolerance = max_round_trip_fee_error(transfer_fee_basis_points);
            for &maximum_fee in &[0u64, 20, u64::MAX] {
                let transfer_fee = TransferFee {
                    epoch: U64::from(0),
                    maximum_fee: U64::from(maximum_fee),
                    transfer_fee_basis_points: U16::from(transfer_fee_basis_points),
                };
                for amount_in in 0..=20_000u64 {
                    let fee = transfer_fee.calculate_fee(amount_in).unwrap();
                    let amount_out = amount_in.checked_sub(fee).unwrap();
                    let fee_exact_out = transfer_fee.calculate_inverse_fee(amount_out).unwrap();
                    assert!(fee >= fee_exact_out);
                    assert!(
                        fee - fee_exact_out <= tolerance,
                        "diff is {} for tolerance {} (bps {}, max fee {}, amount {})",
                        fee - fee_exact_out,
                        tolerance,
                        transfer_fee_basis_points,
                        maximum_fee,
                        amount_in
                    );
                }
            }
        }
    }

    proptest! {
        #[test]
        fn round_trip_fee_calculation(
            transfer_fee_basis_points in 0u16..MAX_FEE_BASIS_POINTS,
            maximum_fee in u64::MIN..=u64::MAX,
            amount_in in 0..=u64::MAX
        ) {
            let transfer_fee = TransferFee {
                epoch: U64::from(0),
                maximum_fee: U64::from(maximum_fee),
                transfer_fee_basis_points: U16::from(transfer_fee_basis_points),
            };
            let fee = transfer_fee.calculate_fee(amount_in).unwrap();
            let amount_out = amount_in.checked_sub(fee).unwrap();
            let fee_exact_out = transfer_fee.calculate_inverse_fee(amount_out).unwrap();
            let diff = fee.abs_diff(fee_exact_out);
            // Rounding error depends only on the rate; see `max_round_trip_fee_error`.
            let tolerance = max_round_trip_fee_error(transfer_fee_basis_points);
            assert!(diff <= tolerance, "diff is {} for tolerance {}", diff, tolerance);
        }
    }

    proptest! {
        #[test]
        fn inverse_fee_relationship(
            transfer_fee_basis_points in 0u16..MAX_FEE_BASIS_POINTS,
            maximum_fee in u64::MIN..=u64::MAX,
            amount_in in 0..=u64::MAX
        ) {
            let transfer_fee = TransferFee {
                epoch: U64::from(0),
                maximum_fee: U64::from(maximum_fee),
                transfer_fee_basis_points: U16::from(transfer_fee_basis_points),
            };
            let fee = transfer_fee.calculate_fee(amount_in).unwrap();
            let amount_out = amount_in.checked_sub(fee).unwrap();
            let fee_exact_out = transfer_fee.calculate_inverse_fee(amount_out).unwrap();
            assert!(fee >= fee_exact_out);
        }
    }
}