1#[cfg(feature = "serde-traits")]
4use serde::{Deserialize, Serialize};
5use {
6 crate::{
7 error::TokenError,
8 extension::{
9 confidential_mint_burn::ConfidentialMintBurn,
10 confidential_transfer::{ConfidentialTransferAccount, ConfidentialTransferMint},
11 confidential_transfer_fee::{
12 ConfidentialTransferFeeAmount, ConfidentialTransferFeeConfig,
13 },
14 cpi_guard::CpiGuard,
15 default_account_state::DefaultAccountState,
16 group_member_pointer::GroupMemberPointer,
17 group_pointer::GroupPointer,
18 immutable_owner::ImmutableOwner,
19 interest_bearing_mint::InterestBearingConfig,
20 memo_transfer::MemoTransfer,
21 metadata_pointer::MetadataPointer,
22 mint_close_authority::MintCloseAuthority,
23 non_transferable::{NonTransferable, NonTransferableAccount},
24 pausable::{PausableAccount, PausableConfig},
25 permanent_delegate::PermanentDelegate,
26 scaled_ui_amount::ScaledUiAmountConfig,
27 transfer_fee::{TransferFeeAmount, TransferFeeConfig},
28 transfer_hook::{TransferHook, TransferHookAccount},
29 },
30 pod::{PodAccount, PodMint},
31 state::{Account, Mint, Multisig, PackedSizeOf},
32 },
33 bytemuck::{Pod, Zeroable},
34 num_enum::{IntoPrimitive, TryFromPrimitive},
35 solana_account_info::AccountInfo,
36 solana_program_error::ProgramError,
37 solana_program_pack::{IsInitialized, Pack},
38 spl_pod::{
39 bytemuck::{pod_from_bytes, pod_from_bytes_mut, pod_get_packed_len},
40 primitives::PodU16,
41 },
42 spl_token_group_interface::state::{TokenGroup, TokenGroupMember},
43 spl_type_length_value::variable_len_pack::VariableLenPack,
44 std::{
45 cmp::Ordering,
46 convert::{TryFrom, TryInto},
47 mem::size_of,
48 },
49};
50
51pub mod confidential_transfer;
53pub mod confidential_transfer_fee;
55pub mod cpi_guard;
57pub mod default_account_state;
59pub mod group_member_pointer;
61pub mod group_pointer;
63pub mod immutable_owner;
65pub mod interest_bearing_mint;
67pub mod memo_transfer;
69pub mod metadata_pointer;
71pub mod mint_close_authority;
73pub mod non_transferable;
75pub mod pausable;
77pub mod permanent_delegate;
79pub mod reallocate;
81pub mod scaled_ui_amount;
83pub mod token_group;
85pub mod token_metadata;
87pub mod transfer_fee;
89pub mod transfer_hook;
91
92pub mod confidential_mint_burn;
94
95#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
97#[repr(transparent)]
98pub struct Length(PodU16);
99impl From<Length> for usize {
100 fn from(n: Length) -> Self {
101 Self::from(u16::from(n.0))
102 }
103}
104impl TryFrom<usize> for Length {
105 type Error = ProgramError;
106 fn try_from(n: usize) -> Result<Self, Self::Error> {
107 u16::try_from(n)
108 .map(|v| Self(PodU16::from(v)))
109 .map_err(|_| ProgramError::AccountDataTooSmall)
110 }
111}
112
113fn get_tlv_indices(type_start: usize) -> TlvIndices {
115 let length_start = type_start.saturating_add(size_of::<ExtensionType>());
116 let value_start = length_start.saturating_add(pod_get_packed_len::<Length>());
117 TlvIndices {
118 type_start,
119 length_start,
120 value_start,
121 }
122}
123
124const fn adjust_len_for_multisig(account_len: usize) -> usize {
127 if account_len == Multisig::LEN {
128 account_len.saturating_add(size_of::<ExtensionType>())
129 } else {
130 account_len
131 }
132}
133
134const fn add_type_and_length_to_len(value_len: usize) -> usize {
137 value_len
138 .saturating_add(size_of::<ExtensionType>())
139 .saturating_add(pod_get_packed_len::<Length>())
140}
141
/// Offsets, relative to the start of the TLV data, of the three fields of a
/// single TLV entry (as computed by `get_tlv_indices`).
#[derive(Debug)]
struct TlvIndices {
    /// Offset of the `ExtensionType` tag.
    pub type_start: usize,
    /// Offset of the `Length` field (directly after the tag).
    pub length_start: usize,
    /// Offset of the first value byte (directly after the `Length` field).
    pub value_start: usize,
}
150fn get_extension_indices<V: Extension>(
151 tlv_data: &[u8],
152 init: bool,
153) -> Result<TlvIndices, ProgramError> {
154 let mut start_index = 0;
155 let v_account_type = V::TYPE.get_account_type();
156 while start_index < tlv_data.len() {
157 let tlv_indices = get_tlv_indices(start_index);
158 if tlv_data.len() < tlv_indices.value_start {
159 return Err(ProgramError::InvalidAccountData);
160 }
161 let extension_type =
162 ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
163 let account_type = extension_type.get_account_type();
164 if extension_type == V::TYPE {
165 return Ok(tlv_indices);
167 } else if extension_type == ExtensionType::Uninitialized {
170 if init {
171 return Ok(tlv_indices);
172 } else {
173 return Err(TokenError::ExtensionNotFound.into());
174 }
175 } else if v_account_type != account_type {
176 return Err(TokenError::ExtensionTypeMismatch.into());
177 } else {
178 let length = pod_from_bytes::<Length>(
179 &tlv_data[tlv_indices.length_start..tlv_indices.value_start],
180 )?;
181 let value_end_index = tlv_indices.value_start.saturating_add(usize::from(*length));
182 start_index = value_end_index;
183 }
184 }
185 Err(ProgramError::InvalidAccountData)
186}
187
/// Summary of the initialized portion of a TLV buffer, as produced by
/// `get_tlv_data_info`.
#[derive(Debug, PartialEq)]
struct TlvDataInfo {
    /// Extension types of the initialized entries, in storage order.
    extension_types: Vec<ExtensionType>,
    /// Number of bytes from the start of the TLV data up to the end of the
    /// last initialized entry.
    used_len: usize,
}
200
201fn get_tlv_data_info(tlv_data: &[u8]) -> Result<TlvDataInfo, ProgramError> {
204 let mut extension_types = vec![];
205 let mut start_index = 0;
206 while start_index < tlv_data.len() {
207 let tlv_indices = get_tlv_indices(start_index);
208 if tlv_data.len() < tlv_indices.length_start {
209 return Ok(TlvDataInfo {
212 extension_types,
213 used_len: tlv_indices.type_start,
214 });
215 }
216 let extension_type =
217 ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
218 if extension_type == ExtensionType::Uninitialized {
219 return Ok(TlvDataInfo {
220 extension_types,
221 used_len: tlv_indices.type_start,
222 });
223 } else {
224 if tlv_data.len() < tlv_indices.value_start {
225 return Err(ProgramError::InvalidAccountData);
227 }
228 extension_types.push(extension_type);
229 let length = pod_from_bytes::<Length>(
230 &tlv_data[tlv_indices.length_start..tlv_indices.value_start],
231 )?;
232
233 let value_end_index = tlv_indices.value_start.saturating_add(usize::from(*length));
234 if value_end_index > tlv_data.len() {
235 return Err(ProgramError::InvalidAccountData);
237 }
238 start_index = value_end_index;
239 }
240 }
241 Ok(TlvDataInfo {
242 extension_types,
243 used_len: start_index,
244 })
245}
246
247fn get_first_extension_type(tlv_data: &[u8]) -> Result<Option<ExtensionType>, ProgramError> {
248 if tlv_data.is_empty() {
249 Ok(None)
250 } else {
251 let tlv_indices = get_tlv_indices(0);
252 if tlv_data.len() <= tlv_indices.length_start {
253 return Ok(None);
254 }
255 let extension_type =
256 ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
257 if extension_type == ExtensionType::Uninitialized {
258 Ok(None)
259 } else {
260 Ok(Some(extension_type))
261 }
262 }
263}
264
265fn check_min_len_and_not_multisig(input: &[u8], minimum_len: usize) -> Result<(), ProgramError> {
266 if input.len() == Multisig::LEN || input.len() < minimum_len {
267 Err(ProgramError::InvalidAccountData)
268 } else {
269 Ok(())
270 }
271}
272
273fn check_account_type<S: BaseState>(account_type: AccountType) -> Result<(), ProgramError> {
274 if account_type != S::ACCOUNT_TYPE {
275 Err(ProgramError::InvalidAccountData)
276 } else {
277 Ok(())
278 }
279}
280
/// Length of a packed base `Account`. In an extended account the base state
/// (mint or token account) is followed by zero padding up to this length;
/// `type_and_tlv_indices` checks that padding.
const BASE_ACCOUNT_LENGTH: usize = Account::LEN;
/// `BASE_ACCOUNT_LENGTH` plus the account-type byte: the offset at which TLV
/// data begins in an extended account.
const BASE_ACCOUNT_AND_TYPE_LENGTH: usize = BASE_ACCOUNT_LENGTH + size_of::<AccountType>();
305
306fn type_and_tlv_indices<S: BaseState>(
307 rest_input: &[u8],
308) -> Result<Option<(usize, usize)>, ProgramError> {
309 if rest_input.is_empty() {
310 Ok(None)
311 } else {
312 let account_type_index = BASE_ACCOUNT_LENGTH.saturating_sub(S::SIZE_OF);
313 let tlv_start_index = account_type_index.saturating_add(size_of::<AccountType>());
315 if rest_input.len() < tlv_start_index {
316 return Err(ProgramError::InvalidAccountData);
317 }
318 if rest_input[..account_type_index] != vec![0; account_type_index] {
319 Err(ProgramError::InvalidAccountData)
320 } else {
321 Ok(Some((account_type_index, tlv_start_index)))
322 }
323 }
324}
325
326fn is_initialized_account(input: &[u8]) -> Result<bool, ProgramError> {
329 const ACCOUNT_INITIALIZED_INDEX: usize = 108; if input.len() != BASE_ACCOUNT_LENGTH {
332 return Err(ProgramError::InvalidAccountData);
333 }
334 Ok(input[ACCOUNT_INITIALIZED_INDEX] != 0)
335}
336
337fn get_extension_bytes<S: BaseState, V: Extension>(tlv_data: &[u8]) -> Result<&[u8], ProgramError> {
338 if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
339 return Err(ProgramError::InvalidAccountData);
340 }
341 let TlvIndices {
342 type_start: _,
343 length_start,
344 value_start,
345 } = get_extension_indices::<V>(tlv_data, false)?;
346 let length = pod_from_bytes::<Length>(&tlv_data[length_start..value_start])?;
349 let value_end = value_start.saturating_add(usize::from(*length));
350 if tlv_data.len() < value_end {
351 return Err(ProgramError::InvalidAccountData);
352 }
353 Ok(&tlv_data[value_start..value_end])
354}
355
356fn get_extension_bytes_mut<S: BaseState, V: Extension>(
357 tlv_data: &mut [u8],
358) -> Result<&mut [u8], ProgramError> {
359 if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
360 return Err(ProgramError::InvalidAccountData);
361 }
362 let TlvIndices {
363 type_start: _,
364 length_start,
365 value_start,
366 } = get_extension_indices::<V>(tlv_data, false)?;
367 let length = pod_from_bytes::<Length>(&tlv_data[length_start..value_start])?;
370 let value_end = value_start.saturating_add(usize::from(*length));
371 if tlv_data.len() < value_end {
372 return Err(ProgramError::InvalidAccountData);
373 }
374 Ok(&mut tlv_data[value_start..value_end])
375}
376
377fn try_get_new_account_len_for_extension_len<S: BaseState, V: Extension>(
383 tlv_data: &[u8],
384 new_extension_len: usize,
385) -> Result<usize, ProgramError> {
386 let new_extension_tlv_len = add_type_and_length_to_len(new_extension_len);
388 let tlv_info = get_tlv_data_info(tlv_data)?;
389 let current_len = tlv_info
392 .used_len
393 .saturating_add(BASE_ACCOUNT_AND_TYPE_LENGTH);
394 let current_extension_len = get_extension_bytes::<S, V>(tlv_data)
396 .map(|x| add_type_and_length_to_len(x.len()))
397 .unwrap_or(0);
398 let new_len = current_len
399 .saturating_sub(current_extension_len)
400 .saturating_add(new_extension_tlv_len);
401 Ok(adjust_len_for_multisig(new_len))
402}
403
404pub trait BaseStateWithExtensions<S: BaseState> {
406 fn get_tlv_data(&self) -> &[u8];
408
409 fn get_extension_bytes<V: Extension>(&self) -> Result<&[u8], ProgramError> {
411 get_extension_bytes::<S, V>(self.get_tlv_data())
412 }
413
414 fn get_extension<V: Extension + Pod>(&self) -> Result<&V, ProgramError> {
416 pod_from_bytes::<V>(self.get_extension_bytes::<V>()?)
417 }
418
419 fn get_variable_len_extension<V: Extension + VariableLenPack>(
421 &self,
422 ) -> Result<V, ProgramError> {
423 let data = get_extension_bytes::<S, V>(self.get_tlv_data())?;
424 V::unpack_from_slice(data)
425 }
426
427 fn get_extension_types(&self) -> Result<Vec<ExtensionType>, ProgramError> {
429 get_tlv_data_info(self.get_tlv_data()).map(|x| x.extension_types)
430 }
431
432 fn get_first_extension_type(&self) -> Result<Option<ExtensionType>, ProgramError> {
434 get_first_extension_type(self.get_tlv_data())
435 }
436
437 fn try_get_account_len(&self) -> Result<usize, ProgramError> {
439 let tlv_info = get_tlv_data_info(self.get_tlv_data())?;
440 if tlv_info.extension_types.is_empty() {
441 Ok(S::SIZE_OF)
442 } else {
443 let total_len = tlv_info
444 .used_len
445 .saturating_add(BASE_ACCOUNT_AND_TYPE_LENGTH);
446 Ok(adjust_len_for_multisig(total_len))
447 }
448 }
449 fn try_get_new_account_len<V: Extension + Pod>(&self) -> Result<usize, ProgramError> {
454 try_get_new_account_len_for_extension_len::<S, V>(
455 self.get_tlv_data(),
456 pod_get_packed_len::<V>(),
457 )
458 }
459
460 fn try_get_new_account_len_for_variable_len_extension<V: Extension + VariableLenPack>(
463 &self,
464 new_extension: &V,
465 ) -> Result<usize, ProgramError> {
466 try_get_new_account_len_for_extension_len::<S, V>(
467 self.get_tlv_data(),
468 new_extension.get_packed_len()?,
469 )
470 }
471}
472
473#[derive(Clone, Debug, PartialEq)]
476pub struct StateWithExtensionsOwned<S: BaseState> {
477 pub base: S,
479 tlv_data: Vec<u8>,
481}
482impl<S: BaseState + Pack> StateWithExtensionsOwned<S> {
483 pub fn unpack(mut input: Vec<u8>) -> Result<Self, ProgramError> {
487 check_min_len_and_not_multisig(&input, S::SIZE_OF)?;
488 let mut rest = input.split_off(S::SIZE_OF);
489 let base = S::unpack(&input)?;
490 if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(&rest)? {
491 let account_type = AccountType::try_from(rest[account_type_index])
493 .map_err(|_| ProgramError::InvalidAccountData)?;
494 check_account_type::<S>(account_type)?;
495 let tlv_data = rest.split_off(tlv_start_index);
496 Ok(Self { base, tlv_data })
497 } else {
498 Ok(Self {
499 base,
500 tlv_data: vec![],
501 })
502 }
503 }
504}
505
506impl<S: BaseState> BaseStateWithExtensions<S> for StateWithExtensionsOwned<S> {
507 fn get_tlv_data(&self) -> &[u8] {
508 &self.tlv_data
509 }
510}
511
512#[derive(Debug, PartialEq)]
515pub struct StateWithExtensions<'data, S: BaseState + Pack> {
516 pub base: S,
518 tlv_data: &'data [u8],
520}
521impl<'data, S: BaseState + Pack> StateWithExtensions<'data, S> {
522 pub fn unpack(input: &'data [u8]) -> Result<Self, ProgramError> {
526 check_min_len_and_not_multisig(input, S::SIZE_OF)?;
527 let (base_data, rest) = input.split_at(S::SIZE_OF);
528 let base = S::unpack(base_data)?;
529 let tlv_data = unpack_tlv_data::<S>(rest)?;
530 Ok(Self { base, tlv_data })
531 }
532}
533impl<S: BaseState + Pack> BaseStateWithExtensions<S> for StateWithExtensions<'_, S> {
534 fn get_tlv_data(&self) -> &[u8] {
535 self.tlv_data
536 }
537}
538
539#[derive(Debug, PartialEq)]
542pub struct PodStateWithExtensions<'data, S: BaseState + Pod> {
543 pub base: &'data S,
545 tlv_data: &'data [u8],
547}
548impl<'data, S: BaseState + Pod> PodStateWithExtensions<'data, S> {
549 pub fn unpack(input: &'data [u8]) -> Result<Self, ProgramError> {
553 check_min_len_and_not_multisig(input, S::SIZE_OF)?;
554 let (base_data, rest) = input.split_at(S::SIZE_OF);
555 let base = pod_from_bytes::<S>(base_data)?;
556 if !base.is_initialized() {
557 Err(ProgramError::UninitializedAccount)
558 } else {
559 let tlv_data = unpack_tlv_data::<S>(rest)?;
560 Ok(Self { base, tlv_data })
561 }
562 }
563}
564impl<S: BaseState + Pod> BaseStateWithExtensions<S> for PodStateWithExtensions<'_, S> {
565 fn get_tlv_data(&self) -> &[u8] {
566 self.tlv_data
567 }
568}
569
/// Mutable access to the TLV extension area and the account-type byte of an
/// account buffer, on top of the read-only [`BaseStateWithExtensions`].
pub trait BaseStateWithExtensionsMut<S: BaseState>: BaseStateWithExtensions<S> {
    /// Mutable view of the TLV extension bytes (after the account-type byte).
    fn get_tlv_data_mut(&mut self) -> &mut [u8];

    /// Mutable view of the account-type byte. For the implementors in this
    /// file it is empty when the account has no extension area.
    fn get_account_type_mut(&mut self) -> &mut [u8];

    /// Returns the mutable value bytes of extension `V`.
    ///
    /// # Errors
    /// Same as the free function `get_extension_bytes_mut`.
    fn get_extension_bytes_mut<V: Extension>(&mut self) -> Result<&mut [u8], ProgramError> {
        get_extension_bytes_mut::<S, V>(self.get_tlv_data_mut())
    }

    /// Returns a mutable, typed reference to the fixed-size extension `V`.
    fn get_extension_mut<V: Extension + Pod>(&mut self) -> Result<&mut V, ProgramError> {
        pod_from_bytes_mut::<V>(self.get_extension_bytes_mut::<V>()?)
    }

    /// Packs `extension` into the existing slot of variable-length extension
    /// `V`.
    ///
    /// The slot's stored length is not changed here. Use
    /// `realloc_variable_len_extension` to resize.
    fn pack_variable_len_extension<V: Extension + VariableLenPack>(
        &mut self,
        extension: &V,
    ) -> Result<(), ProgramError> {
        let data = self.get_extension_bytes_mut::<V>()?;
        // NOTE(review): relies on `pack_into_slice` to reject a value whose
        // packed size differs from the existing slot — confirm.
        extension.pack_into_slice(data)
    }

    /// Allocates a slot for fixed-size extension `V` (see `alloc`), writes
    /// `V::default()` into it, and returns a mutable reference to it.
    ///
    /// With `overwrite == false`, an existing `V` yields
    /// `TokenError::ExtensionAlreadyInitialized`.
    fn init_extension<V: Extension + Pod + Default>(
        &mut self,
        overwrite: bool,
    ) -> Result<&mut V, ProgramError> {
        let length = pod_get_packed_len::<V>();
        let buffer = self.alloc::<V>(length, overwrite)?;
        let extension_ref = pod_from_bytes_mut::<V>(buffer)?;
        *extension_ref = V::default();
        Ok(extension_ref)
    }

    /// Resizes the existing slot of variable-length extension `V` to the
    /// packed size of `new_extension` (see `realloc`), then packs it in.
    fn realloc_variable_len_extension<V: Extension + VariableLenPack>(
        &mut self,
        new_extension: &V,
    ) -> Result<(), ProgramError> {
        let data = self.realloc::<V>(new_extension.get_packed_len()?)?;
        new_extension.pack_into_slice(data)
    }

    /// Changes the value length of the existing extension `V` to `length`
    /// bytes, shifts every following TLV entry accordingly, and returns the
    /// resized value slice.
    ///
    /// The TLV buffer itself is not resized. Growing needs enough unused bytes
    /// after the last initialized entry. On shrink, the bytes freed at the end
    /// are zeroed. On grow, the newly exposed value bytes are zeroed.
    ///
    /// # Errors
    /// - From `get_extension_indices` (e.g. `ExtensionNotFound` if `V` is absent).
    /// - `InvalidAccountData` if growing would run past the end of the buffer.
    /// - From `Length::try_from` if `length` does not fit in a `u16`.
    fn realloc<V: Extension + VariableLenPack>(
        &mut self,
        length: usize,
    ) -> Result<&mut [u8], ProgramError> {
        let tlv_data = self.get_tlv_data_mut();
        let TlvIndices {
            type_start: _,
            length_start,
            value_start,
        } = get_extension_indices::<V>(tlv_data, false)?;
        // End offset of the last initialized entry; only bytes up to here move.
        let tlv_len = get_tlv_data_info(tlv_data).map(|x| x.used_len)?;
        let data_len = tlv_data.len();

        let length_ref = pod_from_bytes_mut::<Length>(&mut tlv_data[length_start..value_start])?;
        let old_length = usize::from(*length_ref);

        // When growing, verify the shifted data still fits, so the
        // `copy_within` below cannot panic.
        if old_length < length {
            let new_tlv_len = tlv_len.saturating_add(length.saturating_sub(old_length));
            if new_tlv_len > data_len {
                return Err(ProgramError::InvalidAccountData);
            }
        }

        // The new length is written only after the size check, so a failed
        // grow (or an oversized `length`) leaves the stored length unchanged.
        *length_ref = Length::try_from(length)?;

        let old_value_end = value_start.saturating_add(old_length);
        let new_value_end = value_start.saturating_add(length);
        // Move every entry after `V` so it starts right after the resized value.
        tlv_data.copy_within(old_value_end..tlv_len, new_value_end);
        match old_length.cmp(&length) {
            Ordering::Greater => {
                // Shrunk: zero the bytes vacated at the end of the used region.
                let new_tlv_len = tlv_len.saturating_sub(old_length.saturating_sub(length));
                tlv_data[new_tlv_len..tlv_len].fill(0);
            }
            Ordering::Less => {
                // Grew: the new tail of the value still holds bytes of the
                // shifted entries, so clear it.
                tlv_data[old_value_end..new_value_end].fill(0);
            }
            Ordering::Equal => {} // Same size: nothing moved, nothing to clear.
        }

        Ok(&mut tlv_data[value_start..new_value_end])
    }

    /// Allocates a slot for variable-length extension `V` sized for
    /// `extension` (see `alloc`), then packs `extension` into it.
    fn init_variable_len_extension<V: Extension + VariableLenPack>(
        &mut self,
        extension: &V,
        overwrite: bool,
    ) -> Result<(), ProgramError> {
        let data = self.alloc::<V>(extension.get_packed_len()?, overwrite)?;
        extension.pack_into_slice(data)
    }

    /// Reserves a TLV slot for extension `V` with a value of `length` bytes,
    /// writes the type and length headers, and returns the value slice. The
    /// value bytes are returned as they are and are not cleared here.
    ///
    /// The slot is either the existing entry for `V` (usable only with
    /// `overwrite`, and then `length` must equal the stored length) or the
    /// first `Uninitialized` slot.
    ///
    /// # Errors
    /// - `InvalidAccountData` if `V` is not an extension of `S::ACCOUNT_TYPE`,
    ///   or the remaining buffer cannot hold the headers plus `length` bytes.
    /// - `TokenError::InvalidLengthForAlloc` when overwriting `V` with a
    ///   different length.
    /// - `TokenError::ExtensionAlreadyInitialized` when `V` exists and
    ///   `overwrite` is false.
    /// - From `get_extension_indices` / `Length::try_from`.
    fn alloc<V: Extension>(
        &mut self,
        length: usize,
        overwrite: bool,
    ) -> Result<&mut [u8], ProgramError> {
        if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
            return Err(ProgramError::InvalidAccountData);
        }
        let tlv_data = self.get_tlv_data_mut();
        let TlvIndices {
            type_start,
            length_start,
            value_start,
        } = get_extension_indices::<V>(tlv_data, true)?;

        // The whole entry (headers + value) must fit in what remains.
        if tlv_data[type_start..].len() < add_type_and_length_to_len(length) {
            return Err(ProgramError::InvalidAccountData);
        }
        let extension_type = ExtensionType::try_from(&tlv_data[type_start..length_start])?;

        if extension_type == ExtensionType::Uninitialized || overwrite {
            // Write the extension type tag.
            let extension_type_array: [u8; 2] = V::TYPE.into();
            let extension_type_ref = &mut tlv_data[type_start..length_start];
            extension_type_ref.copy_from_slice(&extension_type_array);
            // Write the length header.
            let length_ref =
                pod_from_bytes_mut::<Length>(&mut tlv_data[length_start..value_start])?;

            // Overwriting in place must keep the size; resizing goes through
            // `realloc`. The tag written above equals the old one on this
            // path, so erroring here leaves the entry unchanged.
            if overwrite && extension_type == V::TYPE && usize::from(*length_ref) != length {
                return Err(TokenError::InvalidLengthForAlloc.into());
            }

            *length_ref = Length::try_from(length)?;

            let value_end = value_start.saturating_add(length);
            Ok(&mut tlv_data[value_start..value_end])
        } else {
            // `V` is already present and overwriting was not requested.
            Err(TokenError::ExtensionAlreadyInitialized.into())
        }
    }

    /// Initializes the account-side extension `extension_type` with its
    /// default value, overwriting any existing entry.
    ///
    /// Non-account (e.g. mint) extension types are ignored and return `Ok(())`.
    /// `ConfidentialTransferAccount` is deliberately skipped.
    fn init_account_extension_from_type(
        &mut self,
        extension_type: ExtensionType,
    ) -> Result<(), ProgramError> {
        if extension_type.get_account_type() != AccountType::Account {
            return Ok(());
        }
        match extension_type {
            ExtensionType::TransferFeeAmount => {
                self.init_extension::<TransferFeeAmount>(true).map(|_| ())
            }
            ExtensionType::ImmutableOwner => {
                self.init_extension::<ImmutableOwner>(true).map(|_| ())
            }
            ExtensionType::NonTransferableAccount => self
                .init_extension::<NonTransferableAccount>(true)
                .map(|_| ()),
            ExtensionType::TransferHookAccount => {
                self.init_extension::<TransferHookAccount>(true).map(|_| ())
            }
            // Left uninitialized here (no-op).
            ExtensionType::ConfidentialTransferAccount => Ok(()),
            ExtensionType::PausableAccount => {
                self.init_extension::<PausableAccount>(true).map(|_| ())
            }
            #[cfg(test)]
            ExtensionType::AccountPaddingTest => {
                self.init_extension::<AccountPaddingTest>(true).map(|_| ())
            }
            // NOTE(review): account extensions not listed above (MemoTransfer,
            // CpiGuard, ConfidentialTransferFeeAmount) reach this arm and
            // panic. Callers presumably only pass types that must be
            // initialized with the account — confirm before exposing more widely.
            _ => unreachable!(),
        }
    }

    /// Writes `S::ACCOUNT_TYPE` into the account-type byte, if the account has
    /// one, after checking that the first stored extension (if any) belongs to
    /// `S`.
    ///
    /// # Errors
    /// `TokenError::ExtensionBaseMismatch` if the first extension belongs to a
    /// different account type.
    fn init_account_type(&mut self) -> Result<(), ProgramError> {
        let first_extension_type = self.get_first_extension_type()?;
        let account_type = self.get_account_type_mut();
        if !account_type.is_empty() {
            if let Some(extension_type) = first_extension_type {
                let account_type = extension_type.get_account_type();
                if account_type != S::ACCOUNT_TYPE {
                    return Err(TokenError::ExtensionBaseMismatch.into());
                }
            }
            account_type[0] = S::ACCOUNT_TYPE.into();
        }
        Ok(())
    }

    /// Checks that the first stored extension (if any) belongs to `S`.
    ///
    /// # Errors
    /// `TokenError::ExtensionBaseMismatch` on mismatch.
    fn check_account_type_matches_extension_type(&self) -> Result<(), ProgramError> {
        if let Some(extension_type) = self.get_first_extension_type()? {
            let account_type = extension_type.get_account_type();
            if account_type != S::ACCOUNT_TYPE {
                return Err(TokenError::ExtensionBaseMismatch.into());
            }
        }
        Ok(())
    }
}
817
818#[derive(Debug, PartialEq)]
821pub struct StateWithExtensionsMut<'data, S: BaseState> {
822 pub base: S,
824 base_data: &'data mut [u8],
826 account_type: &'data mut [u8],
828 tlv_data: &'data mut [u8],
830}
831impl<'data, S: BaseState + Pack> StateWithExtensionsMut<'data, S> {
832 pub fn unpack(input: &'data mut [u8]) -> Result<Self, ProgramError> {
836 check_min_len_and_not_multisig(input, S::SIZE_OF)?;
837 let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
838 let base = S::unpack(base_data)?;
839 let (account_type, tlv_data) = unpack_type_and_tlv_data_mut::<S>(rest)?;
840 Ok(Self {
841 base,
842 base_data,
843 account_type,
844 tlv_data,
845 })
846 }
847
848 pub fn unpack_uninitialized(input: &'data mut [u8]) -> Result<Self, ProgramError> {
853 check_min_len_and_not_multisig(input, S::SIZE_OF)?;
854 let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
855 let base = S::unpack_unchecked(base_data)?;
856 if base.is_initialized() {
857 return Err(TokenError::AlreadyInUse.into());
858 }
859 let (account_type, tlv_data) = unpack_uninitialized_type_and_tlv_data_mut::<S>(rest)?;
860 let state = Self {
861 base,
862 base_data,
863 account_type,
864 tlv_data,
865 };
866 state.check_account_type_matches_extension_type()?;
867 Ok(state)
868 }
869
870 pub fn pack_base(&mut self) {
872 S::pack_into_slice(&self.base, self.base_data);
873 }
874}
875impl<S: BaseState> BaseStateWithExtensions<S> for StateWithExtensionsMut<'_, S> {
876 fn get_tlv_data(&self) -> &[u8] {
877 self.tlv_data
878 }
879}
880impl<S: BaseState> BaseStateWithExtensionsMut<S> for StateWithExtensionsMut<'_, S> {
881 fn get_tlv_data_mut(&mut self) -> &mut [u8] {
882 self.tlv_data
883 }
884 fn get_account_type_mut(&mut self) -> &mut [u8] {
885 self.account_type
886 }
887}
888
889#[derive(Debug, PartialEq)]
892pub struct PodStateWithExtensionsMut<'data, S: BaseState> {
893 pub base: &'data mut S,
895 account_type: &'data mut [u8],
897 tlv_data: &'data mut [u8],
899}
900impl<'data, S: BaseState + Pod> PodStateWithExtensionsMut<'data, S> {
901 pub fn unpack(input: &'data mut [u8]) -> Result<Self, ProgramError> {
905 check_min_len_and_not_multisig(input, S::SIZE_OF)?;
906 let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
907 let base = pod_from_bytes_mut::<S>(base_data)?;
908 if !base.is_initialized() {
909 Err(ProgramError::UninitializedAccount)
910 } else {
911 let (account_type, tlv_data) = unpack_type_and_tlv_data_mut::<S>(rest)?;
912 Ok(Self {
913 base,
914 account_type,
915 tlv_data,
916 })
917 }
918 }
919
920 pub fn unpack_uninitialized(input: &'data mut [u8]) -> Result<Self, ProgramError> {
925 check_min_len_and_not_multisig(input, S::SIZE_OF)?;
926 let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
927 let base = pod_from_bytes_mut::<S>(base_data)?;
928 if base.is_initialized() {
929 return Err(TokenError::AlreadyInUse.into());
930 }
931 let (account_type, tlv_data) = unpack_uninitialized_type_and_tlv_data_mut::<S>(rest)?;
932 let state = Self {
933 base,
934 account_type,
935 tlv_data,
936 };
937 state.check_account_type_matches_extension_type()?;
938 Ok(state)
939 }
940}
941
942impl<S: BaseState> BaseStateWithExtensions<S> for PodStateWithExtensionsMut<'_, S> {
943 fn get_tlv_data(&self) -> &[u8] {
944 self.tlv_data
945 }
946}
947impl<S: BaseState> BaseStateWithExtensionsMut<S> for PodStateWithExtensionsMut<'_, S> {
948 fn get_tlv_data_mut(&mut self) -> &mut [u8] {
949 self.tlv_data
950 }
951 fn get_account_type_mut(&mut self) -> &mut [u8] {
952 self.account_type
953 }
954}
955
956fn unpack_tlv_data<S: BaseState>(rest: &[u8]) -> Result<&[u8], ProgramError> {
957 if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
958 let account_type = AccountType::try_from(rest[account_type_index])
960 .map_err(|_| ProgramError::InvalidAccountData)?;
961 check_account_type::<S>(account_type)?;
962 Ok(&rest[tlv_start_index..])
963 } else {
964 Ok(&[])
965 }
966}
967
968fn unpack_type_and_tlv_data_with_check_mut<
969 S: BaseState,
970 F: Fn(AccountType) -> Result<(), ProgramError>,
971>(
972 rest: &mut [u8],
973 check_fn: F,
974) -> Result<(&mut [u8], &mut [u8]), ProgramError> {
975 if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
976 let account_type = AccountType::try_from(rest[account_type_index])
978 .map_err(|_| ProgramError::InvalidAccountData)?;
979 check_fn(account_type)?;
980 let (account_type, tlv_data) = rest.split_at_mut(tlv_start_index);
981 Ok((
982 &mut account_type[account_type_index..tlv_start_index],
983 tlv_data,
984 ))
985 } else {
986 Ok((&mut [], &mut []))
987 }
988}
989
990fn unpack_type_and_tlv_data_mut<S: BaseState>(
991 rest: &mut [u8],
992) -> Result<(&mut [u8], &mut [u8]), ProgramError> {
993 unpack_type_and_tlv_data_with_check_mut::<S, _>(rest, check_account_type::<S>)
994}
995
996fn unpack_uninitialized_type_and_tlv_data_mut<S: BaseState>(
997 rest: &mut [u8],
998) -> Result<(&mut [u8], &mut [u8]), ProgramError> {
999 unpack_type_and_tlv_data_with_check_mut::<S, _>(rest, |account_type| {
1000 if account_type != AccountType::Uninitialized {
1001 Err(ProgramError::InvalidAccountData)
1002 } else {
1003 Ok(())
1004 }
1005 })
1006}
1007
1008pub fn set_account_type<S: BaseState>(input: &mut [u8]) -> Result<(), ProgramError> {
1013 check_min_len_and_not_multisig(input, S::SIZE_OF)?;
1014 let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
1015 if S::ACCOUNT_TYPE == AccountType::Account && !is_initialized_account(base_data)? {
1016 return Err(ProgramError::InvalidAccountData);
1017 }
1018 if let Some((account_type_index, _tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
1019 let mut account_type = AccountType::try_from(rest[account_type_index])
1020 .map_err(|_| ProgramError::InvalidAccountData)?;
1021 if account_type == AccountType::Uninitialized {
1022 rest[account_type_index] = S::ACCOUNT_TYPE.into();
1023 account_type = S::ACCOUNT_TYPE;
1024 }
1025 check_account_type::<S>(account_type)?;
1026 Ok(())
1027 } else {
1028 Err(ProgramError::InvalidAccountData)
1029 }
1030}
1031
/// Kind of account held by a buffer. In an extended account it is stored as a
/// single byte right after the zero padding that extends the base state to
/// `BASE_ACCOUNT_LENGTH`.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, TryFromPrimitive, IntoPrimitive)]
pub enum AccountType {
    /// Account-type byte not yet written (value 0).
    Uninitialized,
    /// Mint account (value 1).
    Mint,
    /// Token account (value 2).
    Account,
}
// NOTE(review): `Default` is implemented by hand instead of derived with
// `#[default]`. num_enum's derives may treat a `#[default]`-marked variant as a
// catch-all for unknown values, which would make `try_from` accept any byte.
// Verify against the num_enum version in use before switching to a derive.
impl Default for AccountType {
    fn default() -> Self {
        Self::Uninitialized
    }
}
1051
/// Tag identifying the extension stored in a TLV entry.
///
/// The `u16` discriminant is written to account data as the entry's type tag
/// (little-endian, see the `From<ExtensionType> for [u8; 2]` impl). Variants
/// must therefore never be reordered or removed, and new ones must be
/// appended. Tag values follow declaration order starting at 0.
#[repr(u16)]
#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde-traits", serde(rename_all = "camelCase"))]
#[derive(Clone, Copy, Debug, PartialEq, TryFromPrimitive, IntoPrimitive)]
pub enum ExtensionType {
    /// Tag 0: empty slot; nothing is stored after the first one.
    Uninitialized,
    /// Tag 1: mint extension, value `TransferFeeConfig`.
    TransferFeeConfig,
    /// Tag 2: account extension, value `TransferFeeAmount`.
    TransferFeeAmount,
    /// Tag 3: mint extension, value `MintCloseAuthority`.
    MintCloseAuthority,
    /// Tag 4: mint extension, value `ConfidentialTransferMint`.
    ConfidentialTransferMint,
    /// Tag 5: account extension, value `ConfidentialTransferAccount`.
    ConfidentialTransferAccount,
    /// Tag 6: mint extension, value `DefaultAccountState`.
    DefaultAccountState,
    /// Tag 7: account extension, value `ImmutableOwner`.
    ImmutableOwner,
    /// Tag 8: account extension, value `MemoTransfer`.
    MemoTransfer,
    /// Tag 9: mint extension, value `NonTransferable`.
    NonTransferable,
    /// Tag 10: mint extension, value `InterestBearingConfig`.
    InterestBearingConfig,
    /// Tag 11: account extension, value `CpiGuard`.
    CpiGuard,
    /// Tag 12: mint extension, value `PermanentDelegate`.
    PermanentDelegate,
    /// Tag 13: account extension, value `NonTransferableAccount`.
    NonTransferableAccount,
    /// Tag 14: mint extension, value `TransferHook`.
    TransferHook,
    /// Tag 15: account extension, value `TransferHookAccount`.
    TransferHookAccount,
    /// Tag 16: mint extension, value `ConfidentialTransferFeeConfig`.
    ConfidentialTransferFeeConfig,
    /// Tag 17: account extension, value `ConfidentialTransferFeeAmount`.
    ConfidentialTransferFeeAmount,
    /// Tag 18: mint extension, value `MetadataPointer`.
    MetadataPointer,
    /// Tag 19: mint extension with a variable-length value (not `sized`).
    TokenMetadata,
    /// Tag 20: mint extension, value `GroupPointer`.
    GroupPointer,
    /// Tag 21: mint extension, value `TokenGroup`.
    TokenGroup,
    /// Tag 22: mint extension, value `GroupMemberPointer`.
    GroupMemberPointer,
    /// Tag 23: mint extension, value `TokenGroupMember`.
    TokenGroupMember,
    /// Tag 24: mint extension, value `ConfidentialMintBurn`.
    ConfidentialMintBurn,
    /// Tag 25: mint extension, value `ScaledUiAmountConfig`.
    ScaledUiAmount,
    /// Tag 26: mint extension, value `PausableConfig`.
    Pausable,
    /// Tag 27: account extension, value `PausableAccount`.
    PausableAccount,

    /// Test-only variable-length mint extension (tag `u16::MAX - 2`).
    #[cfg(test)]
    VariableLenMintTest = u16::MAX - 2,
    /// Test-only account padding extension (tag `u16::MAX - 1`).
    #[cfg(test)]
    AccountPaddingTest,
    /// Test-only padding extension (tag `u16::MAX`).
    #[cfg(test)]
    MintPaddingTest,
}
1138impl TryFrom<&[u8]> for ExtensionType {
1139 type Error = ProgramError;
1140 fn try_from(a: &[u8]) -> Result<Self, Self::Error> {
1141 Self::try_from(u16::from_le_bytes(
1142 a.try_into().map_err(|_| ProgramError::InvalidAccountData)?,
1143 ))
1144 .map_err(|_| ProgramError::InvalidAccountData)
1145 }
1146}
1147impl From<ExtensionType> for [u8; 2] {
1148 fn from(a: ExtensionType) -> Self {
1149 u16::from(a).to_le_bytes()
1150 }
1151}
1152impl ExtensionType {
1153 const fn sized(&self) -> bool {
1158 match self {
1159 ExtensionType::TokenMetadata => false,
1160 #[cfg(test)]
1161 ExtensionType::VariableLenMintTest => false,
1162 _ => true,
1163 }
1164 }
1165
1166 fn try_get_type_len(&self) -> Result<usize, ProgramError> {
1170 if !self.sized() {
1171 return Err(ProgramError::InvalidArgument);
1172 }
1173 Ok(match self {
1174 ExtensionType::Uninitialized => 0,
1175 ExtensionType::TransferFeeConfig => pod_get_packed_len::<TransferFeeConfig>(),
1176 ExtensionType::TransferFeeAmount => pod_get_packed_len::<TransferFeeAmount>(),
1177 ExtensionType::MintCloseAuthority => pod_get_packed_len::<MintCloseAuthority>(),
1178 ExtensionType::ImmutableOwner => pod_get_packed_len::<ImmutableOwner>(),
1179 ExtensionType::ConfidentialTransferMint => {
1180 pod_get_packed_len::<ConfidentialTransferMint>()
1181 }
1182 ExtensionType::ConfidentialTransferAccount => {
1183 pod_get_packed_len::<ConfidentialTransferAccount>()
1184 }
1185 ExtensionType::DefaultAccountState => pod_get_packed_len::<DefaultAccountState>(),
1186 ExtensionType::MemoTransfer => pod_get_packed_len::<MemoTransfer>(),
1187 ExtensionType::NonTransferable => pod_get_packed_len::<NonTransferable>(),
1188 ExtensionType::InterestBearingConfig => pod_get_packed_len::<InterestBearingConfig>(),
1189 ExtensionType::CpiGuard => pod_get_packed_len::<CpiGuard>(),
1190 ExtensionType::PermanentDelegate => pod_get_packed_len::<PermanentDelegate>(),
1191 ExtensionType::NonTransferableAccount => pod_get_packed_len::<NonTransferableAccount>(),
1192 ExtensionType::TransferHook => pod_get_packed_len::<TransferHook>(),
1193 ExtensionType::TransferHookAccount => pod_get_packed_len::<TransferHookAccount>(),
1194 ExtensionType::ConfidentialTransferFeeConfig => {
1195 pod_get_packed_len::<ConfidentialTransferFeeConfig>()
1196 }
1197 ExtensionType::ConfidentialTransferFeeAmount => {
1198 pod_get_packed_len::<ConfidentialTransferFeeAmount>()
1199 }
1200 ExtensionType::MetadataPointer => pod_get_packed_len::<MetadataPointer>(),
1201 ExtensionType::TokenMetadata => unreachable!(),
1202 ExtensionType::GroupPointer => pod_get_packed_len::<GroupPointer>(),
1203 ExtensionType::TokenGroup => pod_get_packed_len::<TokenGroup>(),
1204 ExtensionType::GroupMemberPointer => pod_get_packed_len::<GroupMemberPointer>(),
1205 ExtensionType::TokenGroupMember => pod_get_packed_len::<TokenGroupMember>(),
1206 ExtensionType::ConfidentialMintBurn => pod_get_packed_len::<ConfidentialMintBurn>(),
1207 ExtensionType::ScaledUiAmount => pod_get_packed_len::<ScaledUiAmountConfig>(),
1208 ExtensionType::Pausable => pod_get_packed_len::<PausableConfig>(),
1209 ExtensionType::PausableAccount => pod_get_packed_len::<PausableAccount>(),
1210 #[cfg(test)]
1211 ExtensionType::AccountPaddingTest => pod_get_packed_len::<AccountPaddingTest>(),
1212 #[cfg(test)]
1213 ExtensionType::MintPaddingTest => pod_get_packed_len::<MintPaddingTest>(),
1214 #[cfg(test)]
1215 ExtensionType::VariableLenMintTest => unreachable!(),
1216 })
1217 }
1218
1219 fn try_get_tlv_len(&self) -> Result<usize, ProgramError> {
1223 Ok(add_type_and_length_to_len(self.try_get_type_len()?))
1224 }
1225
1226 fn try_get_total_tlv_len(extension_types: &[Self]) -> Result<usize, ProgramError> {
1230 let mut extensions = vec![];
1232 for extension_type in extension_types {
1233 if !extensions.contains(&extension_type) {
1234 extensions.push(extension_type);
1235 }
1236 }
1237 extensions.iter().map(|e| e.try_get_tlv_len()).sum()
1238 }
1239
1240 pub fn try_calculate_account_len<S: BaseState>(
1244 extension_types: &[Self],
1245 ) -> Result<usize, ProgramError> {
1246 if extension_types.is_empty() {
1247 Ok(S::SIZE_OF)
1248 } else {
1249 let extension_size = Self::try_get_total_tlv_len(extension_types)?;
1250 let total_len = extension_size.saturating_add(BASE_ACCOUNT_AND_TYPE_LENGTH);
1251 Ok(adjust_len_for_multisig(total_len))
1252 }
1253 }
1254
1255 pub fn get_account_type(&self) -> AccountType {
1257 match self {
1258 ExtensionType::Uninitialized => AccountType::Uninitialized,
1259 ExtensionType::TransferFeeConfig
1260 | ExtensionType::MintCloseAuthority
1261 | ExtensionType::ConfidentialTransferMint
1262 | ExtensionType::DefaultAccountState
1263 | ExtensionType::NonTransferable
1264 | ExtensionType::InterestBearingConfig
1265 | ExtensionType::PermanentDelegate
1266 | ExtensionType::TransferHook
1267 | ExtensionType::ConfidentialTransferFeeConfig
1268 | ExtensionType::MetadataPointer
1269 | ExtensionType::TokenMetadata
1270 | ExtensionType::GroupPointer
1271 | ExtensionType::TokenGroup
1272 | ExtensionType::GroupMemberPointer
1273 | ExtensionType::ConfidentialMintBurn
1274 | ExtensionType::TokenGroupMember
1275 | ExtensionType::ScaledUiAmount
1276 | ExtensionType::Pausable => AccountType::Mint,
1277 ExtensionType::ImmutableOwner
1278 | ExtensionType::TransferFeeAmount
1279 | ExtensionType::ConfidentialTransferAccount
1280 | ExtensionType::MemoTransfer
1281 | ExtensionType::NonTransferableAccount
1282 | ExtensionType::TransferHookAccount
1283 | ExtensionType::CpiGuard
1284 | ExtensionType::ConfidentialTransferFeeAmount
1285 | ExtensionType::PausableAccount => AccountType::Account,
1286 #[cfg(test)]
1287 ExtensionType::VariableLenMintTest => AccountType::Mint,
1288 #[cfg(test)]
1289 ExtensionType::AccountPaddingTest => AccountType::Account,
1290 #[cfg(test)]
1291 ExtensionType::MintPaddingTest => AccountType::Mint,
1292 }
1293 }
1294
1295 pub fn get_required_init_account_extensions(mint_extension_types: &[Self]) -> Vec<Self> {
1298 let mut account_extension_types = vec![];
1299 for extension_type in mint_extension_types {
1300 match extension_type {
1301 ExtensionType::TransferFeeConfig => {
1302 account_extension_types.push(ExtensionType::TransferFeeAmount);
1303 }
1304 ExtensionType::NonTransferable => {
1305 account_extension_types.push(ExtensionType::NonTransferableAccount);
1306 account_extension_types.push(ExtensionType::ImmutableOwner);
1307 }
1308 ExtensionType::TransferHook => {
1309 account_extension_types.push(ExtensionType::TransferHookAccount);
1310 }
1311 ExtensionType::Pausable => {
1312 account_extension_types.push(ExtensionType::PausableAccount);
1313 }
1314 #[cfg(test)]
1315 ExtensionType::MintPaddingTest => {
1316 account_extension_types.push(ExtensionType::AccountPaddingTest);
1317 }
1318 _ => {}
1319 }
1320 }
1321 account_extension_types
1322 }
1323
1324 pub fn check_for_invalid_mint_extension_combinations(
1326 mint_extension_types: &[Self],
1327 ) -> Result<(), TokenError> {
1328 let mut transfer_fee_config = false;
1329 let mut confidential_transfer_mint = false;
1330 let mut confidential_transfer_fee_config = false;
1331 let mut confidential_mint_burn = false;
1332 let mut interest_bearing = false;
1333 let mut scaled_ui_amount = false;
1334
1335 for extension_type in mint_extension_types {
1336 match extension_type {
1337 ExtensionType::TransferFeeConfig => transfer_fee_config = true,
1338 ExtensionType::ConfidentialTransferMint => confidential_transfer_mint = true,
1339 ExtensionType::ConfidentialTransferFeeConfig => {
1340 confidential_transfer_fee_config = true
1341 }
1342 ExtensionType::ConfidentialMintBurn => confidential_mint_burn = true,
1343 ExtensionType::InterestBearingConfig => interest_bearing = true,
1344 ExtensionType::ScaledUiAmount => scaled_ui_amount = true,
1345 _ => (),
1346 }
1347 }
1348
1349 if confidential_transfer_fee_config && !(transfer_fee_config && confidential_transfer_mint)
1350 {
1351 return Err(TokenError::InvalidExtensionCombination);
1352 }
1353
1354 if transfer_fee_config && confidential_transfer_mint && !confidential_transfer_fee_config {
1355 return Err(TokenError::InvalidExtensionCombination);
1356 }
1357
1358 if confidential_mint_burn && !confidential_transfer_mint {
1359 return Err(TokenError::InvalidExtensionCombination);
1360 }
1361
1362 if scaled_ui_amount && interest_bearing {
1363 return Err(TokenError::InvalidExtensionCombination);
1364 }
1365
1366 Ok(())
1367 }
1368}
1369
/// Base state that extension TLV data can follow.
///
/// Requires the packed-size and initialization checks from `PackedSizeOf` and
/// `IsInitialized`, plus the `AccountType` the state corresponds to. In this
/// file it is implemented for token accounts and mints, in both their `Pack`
/// (`Account`, `Mint`) and `Pod` (`PodAccount`, `PodMint`) representations.
pub trait BaseState: PackedSizeOf + IsInitialized {
    /// The account kind this base state represents.
    const ACCOUNT_TYPE: AccountType;
}
impl BaseState for Account {
    const ACCOUNT_TYPE: AccountType = AccountType::Account;
}
impl BaseState for Mint {
    const ACCOUNT_TYPE: AccountType = AccountType::Mint;
}
impl BaseState for PodAccount {
    const ACCOUNT_TYPE: AccountType = AccountType::Account;
}
impl BaseState for PodMint {
    const ACCOUNT_TYPE: AccountType = AccountType::Mint;
}
1387
/// An extension that can be stored in an account's TLV data.
///
/// Fixed-length extensions are additionally `Pod`; variable-length ones
/// implement `VariableLenPack`.
pub trait Extension {
    /// The `ExtensionType` tag that identifies this extension's TLV entry.
    const TYPE: ExtensionType;
}
1394
#[cfg(test)]
/// Test-only mint extension made of 185 bytes of padding (128 + 48 + 9).
///
/// NOTE(review): the size appears chosen so that a mint carrying only this
/// extension would otherwise land on `Multisig::LEN`, as exercised by the
/// `mint_with_multisig_len` test — confirm.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Pod, Zeroable)]
pub struct MintPaddingTest {
    /// First padding chunk; `Default` fills it with `1`s.
    pub padding1: [u8; 128],
    /// Second padding chunk; `Default` fills it with `2`s.
    pub padding2: [u8; 48],
    /// Last padding chunk; `Default` fills it with `3`s.
    pub padding3: [u8; 9],
}
#[cfg(test)]
impl Extension for MintPaddingTest {
    const TYPE: ExtensionType = ExtensionType::MintPaddingTest;
}
#[cfg(test)]
impl Default for MintPaddingTest {
    // Distinct fill bytes per field make misplaced chunks easy to spot.
    fn default() -> Self {
        Self {
            padding1: [1; 128],
            padding2: [2; 48],
            padding3: [3; 9],
        }
    }
}
#[cfg(test)]
/// Test-only account extension with the same layout as `MintPaddingTest`.
///
/// `get_required_init_account_extensions` requires it on accounts whose mint
/// carries `MintPaddingTest`.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
pub struct AccountPaddingTest(MintPaddingTest);
#[cfg(test)]
impl Extension for AccountPaddingTest {
    const TYPE: ExtensionType = ExtensionType::AccountPaddingTest;
}
1437
/// Writes a fixed-length extension `V` into `account_info`, growing the
/// account first if the extension does not fit yet.
///
/// Steps:
/// 1. Unpack the current data read-only to compute the length needed once `V`
///    is present.
/// 2. `realloc` the account if that length exceeds the current one. The account
///    is never shrunk here.
/// 3. If the account previously held no more than `BASE_ACCOUNT_LENGTH` bytes,
///    write the account-type byte for `S` via `set_account_type`.
/// 4. Initialize the `V` entry (honoring `overwrite`) and copy `new_extension`
///    into it.
///
/// # Errors
///
/// Propagates errors from borrowing the account data, unpacking the state,
/// computing the new length, reallocating, setting the account type, and
/// `init_extension`. For example, `init_extension` yields
/// `ExtensionAlreadyInitialized` when `V` already exists and `overwrite` is
/// false.
pub(crate) fn alloc_and_serialize<S: BaseState + Pod, V: Default + Extension + Pod>(
    account_info: &AccountInfo,
    new_extension: &V,
    overwrite: bool,
) -> Result<(), ProgramError> {
    let previous_account_len = account_info.try_data_len()?;
    // Scoped so the read-only data borrow is released before the account is
    // reallocated and borrowed mutably below.
    let new_account_len = {
        let data = account_info.try_borrow_data()?;
        let state = PodStateWithExtensions::<S>::unpack(&data)?;
        state.try_get_new_account_len::<V>()?
    };

    // Grow the account before writing, if needed.
    if new_account_len > previous_account_len {
        account_info.realloc(new_account_len, false)?;
    }
    let mut buffer = account_info.try_borrow_mut_data()?;
    // An account that had at most the base length has no account-type byte yet.
    if previous_account_len <= BASE_ACCOUNT_LENGTH {
        set_account_type::<S>(*buffer)?;
    }
    let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;

    // Allocate (or reuse, when overwriting) the TLV slot and copy the value in.
    let extension = state.init_extension::<V>(overwrite)?;
    *extension = *new_extension;

    Ok(())
}
1480
/// Writes a variable-length extension `V` into `account_info`, resizing the
/// account to fit the new value.
///
/// If `V` already exists and `overwrite` is false, this fails with
/// `ExtensionAlreadyInitialized` before anything is modified. The order of
/// operations depends on the size change:
/// - When the account grows, it reallocates the account first, then resizes
///   or allocates the TLV entry and writes the value.
/// - When the size is unchanged or shrinks, it rewrites the TLV entry inside
///   the existing buffer first, then reallocates the account down if bytes were
///   freed.
///
/// # Errors
///
/// - `ExtensionAlreadyInitialized`, as above.
/// - `AccountDataTooSmall` from the defensive length check in the
///   non-growing branch.
/// - Any error from borrowing, unpacking, reallocating, or the TLV helpers.
pub(crate) fn alloc_and_serialize_variable_len_extension<
    S: BaseState + Pod,
    V: Extension + VariableLenPack,
>(
    account_info: &AccountInfo,
    new_extension: &V,
    overwrite: bool,
) -> Result<(), ProgramError> {
    let previous_account_len = account_info.try_data_len()?;
    // Read-only pass: compute the target length and whether `V` is already
    // present. The borrow ends with this block, before any realloc.
    let (new_account_len, extension_already_exists) = {
        let data = account_info.try_borrow_data()?;
        let state = PodStateWithExtensions::<S>::unpack(&data)?;
        let new_account_len =
            state.try_get_new_account_len_for_variable_len_extension(new_extension)?;
        let extension_already_exists = state.get_extension_bytes::<V>().is_ok();
        (new_account_len, extension_already_exists)
    };

    if extension_already_exists && !overwrite {
        return Err(TokenError::ExtensionAlreadyInitialized.into());
    }

    if previous_account_len < new_account_len {
        // Growing: enlarge the account first so the TLV area has room, then
        // write the value.
        account_info.realloc(new_account_len, false)?;
        let mut buffer = account_info.try_borrow_mut_data()?;
        if extension_already_exists {
            let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
            state.realloc_variable_len_extension(new_extension)?;
        } else {
            // An account that had at most the base length has no account-type
            // byte yet.
            if previous_account_len <= BASE_ACCOUNT_LENGTH {
                set_account_type::<S>(*buffer)?;
            }
            let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
            state.init_variable_len_extension(new_extension, false)?;
        }
    } else {
        // Same size or shrinking: rewrite the TLV data in place first, then
        // release any freed bytes from the account.
        let mut buffer = account_info.try_borrow_mut_data()?;
        let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
        if extension_already_exists {
            state.realloc_variable_len_extension(new_extension)?;
        } else {
            state.init_variable_len_extension(new_extension, false)?;
        }

        // Cannot underflow here because previous_account_len >= new_account_len
        // in this branch. The error is purely defensive.
        let removed_bytes = previous_account_len
            .checked_sub(new_account_len)
            .ok_or(ProgramError::AccountDataTooSmall)?;
        if removed_bytes > 0 {
            // Release the mutable data borrow before reallocating the account.
            drop(buffer);
            account_info.realloc(new_account_len, false)?;
        }
    }
    Ok(())
}
1548}
1549
1550#[cfg(test)]
1551mod test {
1552 use {
1553 super::*,
1554 crate::{
1555 pod::test::{TEST_POD_ACCOUNT, TEST_POD_MINT},
1556 state::test::{TEST_ACCOUNT_SLICE, TEST_MINT_SLICE},
1557 },
1558 bytemuck::Pod,
1559 solana_account_info::{
1560 Account as GetAccount, IntoAccountInfo, MAX_PERMITTED_DATA_INCREASE,
1561 },
1562 solana_clock::Epoch,
1563 solana_pubkey::Pubkey,
1564 spl_pod::{
1565 bytemuck::pod_bytes_of,
1566 optional_keys::OptionalNonZeroPubkey,
1567 primitives::{PodBool, PodU64},
1568 },
1569 transfer_fee::test::test_transfer_fee_config,
1570 };
1571
    /// Minimal 8-byte fixed-length test extension. It reuses the test-only
    /// `ExtensionType::MintPaddingTest` tag.
    #[repr(C)]
    #[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
    struct FixedLenMintTest {
        data: [u8; 8],
    }
    impl Extension for FixedLenMintTest {
        const TYPE: ExtensionType = ExtensionType::MintPaddingTest;
    }
1581
1582 #[derive(Clone, Debug, PartialEq)]
1584 struct VariableLenMintTest {
1585 data: Vec<u8>,
1586 }
1587 impl Extension for VariableLenMintTest {
1588 const TYPE: ExtensionType = ExtensionType::VariableLenMintTest;
1589 }
1590 impl VariableLenPack for VariableLenMintTest {
1591 fn pack_into_slice(&self, dst: &mut [u8]) -> Result<(), ProgramError> {
1592 let data_start = size_of::<u64>();
1593 let end = data_start + self.data.len();
1594 if dst.len() < end {
1595 Err(ProgramError::InvalidAccountData)
1596 } else {
1597 dst[..data_start].copy_from_slice(&self.data.len().to_le_bytes());
1598 dst[data_start..end].copy_from_slice(&self.data);
1599 Ok(())
1600 }
1601 }
1602 fn unpack_from_slice(src: &[u8]) -> Result<Self, ProgramError> {
1603 let data_start = size_of::<u64>();
1604 let length = u64::from_le_bytes(src[..data_start].try_into().unwrap()) as usize;
1605 if src[data_start..data_start + length].len() != length {
1606 return Err(ProgramError::InvalidAccountData);
1607 }
1608 let data = Vec::from(&src[data_start..data_start + length]);
1609 Ok(Self { data })
1610 }
1611 fn get_packed_len(&self) -> Result<usize, ProgramError> {
1612 Ok(size_of::<u64>().saturating_add(self.data.len()))
1613 }
1614 }
1615
    // Mint base state, zero padding up to BASE_ACCOUNT_LENGTH (165), then the
    // account-type byte `1` for a mint, with no TLV entries.
    // Apparent base fields:
    // - mint_authority = Some([1; 32]);
    // - supply = 42;
    // - decimals = 7;
    // - is_initialized (byte 45);
    // - freeze_authority = Some([2; 32]).
    // The `unpack_opaque_buffer` test expects it to equal TEST_POD_MINT.
    const MINT_WITH_ACCOUNT_TYPE: &[u8] = &[
        1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 42, 0, 0, 0, 0, 0, 0, 0, 7, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ];
1625
    // Same mint as MINT_WITH_ACCOUNT_TYPE, followed by one TLV entry:
    // - type 3 as u16 LE (read back as MintCloseAuthority);
    // - length 32 as u16 LE;
    // - a 32-byte close authority of all `1`s.
    const MINT_WITH_EXTENSION: &[u8] = &[
        1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 42, 0, 0, 0, 0, 0, 0, 0, 7, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 3, 0, 32, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, ];
1639
    // Token account base state (165 bytes), then the account-type byte `2` for
    // an account, then one TLV entry:
    // - type 15 as u16 LE (read back as TransferHookAccount);
    // - length 1;
    // - a value byte of 1 (transferring = true).
    // Apparent base fields:
    // - mint = [1; 32];
    // - owner = [2; 32];
    // - amount = 3;
    // - delegate = Some([4; 32]);
    // - state byte 2 at index 108;
    // - is_native = Some(5);
    // - delegated_amount = 6;
    // - close_authority = Some([7; 32]).
    const ACCOUNT_WITH_EXTENSION: &[u8] = &[
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 3, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 2, 1, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        7, 7, 7, 7, 7, 7, 2, 15, 0, 1, 0, 1, ];
1658
    /// Unpacks hand-built mint and account buffers, with and without TLV data,
    /// and checks both the base state and the extension contents.
    #[test]
    fn unpack_opaque_buffer() {
        // Mint with only the account-type byte after the base state.
        let state = PodStateWithExtensions::<PodMint>::unpack(MINT_WITH_ACCOUNT_TYPE).unwrap();
        assert_eq!(state.base, &TEST_POD_MINT);
        // Mint carrying a MintCloseAuthority entry.
        let state = PodStateWithExtensions::<PodMint>::unpack(MINT_WITH_EXTENSION).unwrap();
        assert_eq!(state.base, &TEST_POD_MINT);
        let extension = state.get_extension::<MintCloseAuthority>().unwrap();
        let close_authority =
            OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([1; 32]))).unwrap();
        assert_eq!(extension.close_authority, close_authority);
        // Asking for an extension that is not present fails.
        assert_eq!(
            state.get_extension::<TransferFeeConfig>(),
            Err(ProgramError::InvalidAccountData)
        );
        // Mint bytes do not unpack as an account.
        assert_eq!(
            PodStateWithExtensions::<PodAccount>::unpack(MINT_WITH_EXTENSION),
            Err(ProgramError::UninitializedAccount)
        );

        // A plain mint slice with no extension area unpacks through both views.
        let state = PodStateWithExtensions::<PodMint>::unpack(TEST_MINT_SLICE).unwrap();
        assert_eq!(state.base, &TEST_POD_MINT);

        let mut test_mint = TEST_MINT_SLICE.to_vec();
        let state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut test_mint).unwrap();
        assert_eq!(state.base, &TEST_POD_MINT);

        // Account carrying a TransferHookAccount entry.
        let state = PodStateWithExtensions::<PodAccount>::unpack(ACCOUNT_WITH_EXTENSION).unwrap();
        assert_eq!(state.base, &TEST_POD_ACCOUNT);
        let extension = state.get_extension::<TransferHookAccount>().unwrap();
        let transferring = PodBool::from(true);
        assert_eq!(extension.transferring, transferring);
        // Account bytes do not unpack as a mint.
        assert_eq!(
            PodStateWithExtensions::<PodMint>::unpack(ACCOUNT_WITH_EXTENSION),
            Err(ProgramError::InvalidAccountData)
        );

        // A plain account slice with no extension area unpacks through both views.
        let state = PodStateWithExtensions::<PodAccount>::unpack(TEST_ACCOUNT_SLICE).unwrap();
        assert_eq!(state.base, &TEST_POD_ACCOUNT);

        let mut test_account = TEST_ACCOUNT_SLICE.to_vec();
        let state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut test_account).unwrap();
        assert_eq!(state.base, &TEST_POD_ACCOUNT);
    }
1704
    /// Corrupts MINT_WITH_EXTENSION in various ways and checks that each
    /// corruption is reported with the expected error.
    #[test]
    fn mint_fail_unpack_opaque_buffer() {
        // Far too short to hold a mint.
        let mut buffer = vec![0, 3];
        assert_eq!(
            PodStateWithExtensions::<PodMint>::unpack(&buffer),
            Err(ProgramError::InvalidAccountData)
        );
        assert_eq!(
            PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer),
            Err(ProgramError::InvalidAccountData)
        );
        assert_eq!(
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer),
            Err(ProgramError::InvalidAccountData)
        );

        // Corrupt the account-type byte.
        let mut buffer = MINT_WITH_EXTENSION.to_vec();
        buffer[BASE_ACCOUNT_LENGTH] = 3;
        assert_eq!(
            PodStateWithExtensions::<PodMint>::unpack(&buffer),
            Err(ProgramError::InvalidAccountData)
        );

        // Clear byte 45, presumably the mint's is_initialized flag.
        let mut buffer = MINT_WITH_EXTENSION.to_vec();
        buffer[45] = 0;
        assert_eq!(
            PodStateWithExtensions::<PodMint>::unpack(&buffer),
            Err(ProgramError::UninitializedAccount)
        );

        // Put a non-zero byte in the padding between the mint and the account type.
        let mut buffer = MINT_WITH_EXTENSION.to_vec();
        buffer[PodMint::SIZE_OF] = 100;
        assert_eq!(
            PodStateWithExtensions::<PodMint>::unpack(&buffer),
            Err(ProgramError::InvalidAccountData)
        );

        // Change the extension type's low byte from 3 to 2.
        let mut buffer = MINT_WITH_EXTENSION.to_vec();
        buffer[BASE_ACCOUNT_LENGTH + 1] = 2;
        let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
        assert_eq!(
            state.get_extension::<TransferFeeConfig>(),
            Err(ProgramError::Custom(
                TokenError::ExtensionTypeMismatch as u32
            ))
        );

        // Declare a length of 100 instead of 32.
        let mut buffer = MINT_WITH_EXTENSION.to_vec();
        buffer[BASE_ACCOUNT_LENGTH + 3] = 100;
        let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
        assert_eq!(
            state.get_extension::<TransferFeeConfig>(),
            Err(ProgramError::InvalidAccountData)
        );

        // Declare a length of 10 instead of 32.
        let mut buffer = MINT_WITH_EXTENSION.to_vec();
        buffer[BASE_ACCOUNT_LENGTH + 3] = 10;
        let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
        assert_eq!(
            state.get_extension::<TransferFeeConfig>(),
            Err(ProgramError::InvalidAccountData)
        );

        // Drop the last byte so the MintCloseAuthority value is incomplete.
        let buffer = &MINT_WITH_EXTENSION[..MINT_WITH_EXTENSION.len() - 1];
        let state = PodStateWithExtensions::<PodMint>::unpack(buffer).unwrap();
        assert_eq!(
            state.get_extension::<MintCloseAuthority>(),
            Err(ProgramError::InvalidAccountData)
        );
    }
1783
    /// Corrupts ACCOUNT_WITH_EXTENSION in various ways and checks that each
    /// corruption is reported with the expected error.
    #[test]
    fn account_fail_unpack_opaque_buffer() {
        // Far too short to hold an account.
        let mut buffer = vec![0, 3];
        assert_eq!(
            PodStateWithExtensions::<PodAccount>::unpack(&buffer),
            Err(ProgramError::InvalidAccountData)
        );
        assert_eq!(
            PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer),
            Err(ProgramError::InvalidAccountData)
        );
        assert_eq!(
            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
            Err(ProgramError::InvalidAccountData)
        );

        // Base-length buffer filled with 5s is rejected as uninitialized.
        // NOTE(review): presumably because byte 108 (the state) is not a valid
        // initialized state — confirm.
        let mut buffer = vec![5; BASE_ACCOUNT_LENGTH];
        assert_eq!(
            PodStateWithExtensions::<PodAccount>::unpack(&buffer),
            Err(ProgramError::UninitializedAccount)
        );
        assert_eq!(
            PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer),
            Err(ProgramError::UninitializedAccount)
        );

        // Corrupt the account-type byte.
        let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
        buffer[BASE_ACCOUNT_LENGTH] = 3;
        assert_eq!(
            PodStateWithExtensions::<PodAccount>::unpack(&buffer),
            Err(ProgramError::InvalidAccountData)
        );

        // Clear byte 108, the account state.
        let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
        buffer[108] = 0;
        assert_eq!(
            PodStateWithExtensions::<PodAccount>::unpack(&buffer),
            Err(ProgramError::UninitializedAccount)
        );

        // Change the extension type's low byte from 15 to 12.
        let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
        buffer[BASE_ACCOUNT_LENGTH + 1] = 12;
        let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
        assert_eq!(
            state.get_extension::<TransferHookAccount>(),
            Err(ProgramError::Custom(
                TokenError::ExtensionTypeMismatch as u32
            ))
        );

        // Declare a length of 100 instead of 1.
        let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
        buffer[BASE_ACCOUNT_LENGTH + 3] = 100;
        let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
        assert_eq!(
            state.get_extension::<TransferHookAccount>(),
            Err(ProgramError::InvalidAccountData)
        );

        // Declare a length of 10 instead of 1.
        let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
        buffer[BASE_ACCOUNT_LENGTH + 3] = 10;
        let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
        assert_eq!(
            state.get_extension::<TransferHookAccount>(),
            Err(ProgramError::InvalidAccountData)
        );

        // Drop the last byte so the TransferHookAccount value is missing.
        let buffer = &ACCOUNT_WITH_EXTENSION[..ACCOUNT_WITH_EXTENSION.len() - 1];
        let state = PodStateWithExtensions::<PodAccount>::unpack(buffer).unwrap();
        assert_eq!(
            state.get_extension::<TransferHookAccount>(),
            Err(ProgramError::InvalidAccountData)
        );
    }
1866
    /// Checks `get_tlv_data_info` on tiny raw TLV buffers.
    #[test]
    fn get_extension_types_with_opaque_buffer() {
        // Type 1 declaring a 257-byte value (0x0101 LE) in a 4-byte buffer.
        assert_eq!(
            get_tlv_data_info(&[1, 0, 1, 1]).unwrap_err(),
            ProgramError::InvalidAccountData,
        );
        // Type 0x0100 (256). NOTE(review): presumably not a valid ExtensionType.
        assert_eq!(
            get_tlv_data_info(&[0, 1, 0, 0]).unwrap_err(),
            ProgramError::InvalidAccountData,
        );
        // One type-1 entry with an empty value.
        assert_eq!(
            get_tlv_data_info(&[1, 0, 0, 0]).unwrap(),
            TlvDataInfo {
                extension_types: vec![ExtensionType::try_from(1).unwrap()],
                used_len: add_type_and_length_to_len(0),
            }
        );
        // A leading Uninitialized (zero) type yields no extensions and no used bytes.
        assert_eq!(
            get_tlv_data_info(&[0, 0]).unwrap(),
            TlvDataInfo {
                extension_types: vec![],
                used_len: 0
            }
        );
    }
1896
    /// Builds a mint with MintCloseAuthority and TransferFeeConfig step by step,
    /// checking the raw bytes and the error cases along the way.
    #[test]
    fn mint_with_extension_pack_unpack() {
        // Zeroed buffer sized for exactly these two extensions.
        let mint_size = ExtensionType::try_calculate_account_len::<PodMint>(&[
            ExtensionType::MintCloseAuthority,
            ExtensionType::TransferFeeConfig,
        ])
        .unwrap();
        let mut buffer = vec![0; mint_size];

        // All-zero data is not an initialized mint.
        assert_eq!(
            PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer),
            Err(ProgramError::UninitializedAccount),
        );

        // An account extension cannot be initialized on a mint.
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        assert_eq!(
            state.init_extension::<TransferFeeAmount>(true),
            Err(ProgramError::InvalidAccountData),
        );

        // Initialize the close-authority extension before the base state exists.
        let close_authority =
            OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([1; 32]))).unwrap();
        let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
        extension.close_authority = close_authority;
        assert_eq!(
            &state.get_extension_types().unwrap(),
            &[ExtensionType::MintCloseAuthority]
        );

        // Re-initializing without `overwrite` is rejected.
        assert_eq!(
            state.init_extension::<MintCloseAuthority>(false),
            Err(ProgramError::Custom(
                TokenError::ExtensionAlreadyInitialized as u32
            ))
        );

        // With a mint extension present, the buffer cannot be treated as an account.
        assert_eq!(
            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
            Err(ProgramError::Custom(
                TokenError::ExtensionBaseMismatch as u32
            ))
        );

        // The base mint is still zeroed, so a checked unpack still fails.
        assert_eq!(
            PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer.clone()),
            Err(ProgramError::UninitializedAccount),
        );

        // Write the base mint and the account-type byte.
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        *state.base = TEST_POD_MINT;
        state.init_account_type().unwrap();

        // Expected layout:
        // - the mint base, zero padding, and the account type;
        // - the MintCloseAuthority entry;
        // - a still-zeroed slot for the second extension.
        let mut expect = TEST_MINT_SLICE.to_vec();
        expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); expect.push(AccountType::Mint.into());
        expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
        expect
            .extend_from_slice(&(pod_get_packed_len::<MintCloseAuthority>() as u16).to_le_bytes());
        expect.extend_from_slice(&[1; 32]); expect.extend_from_slice(&[0; size_of::<ExtensionType>()]);
        expect.extend_from_slice(&[0; size_of::<Length>()]);
        expect.extend_from_slice(&[0; size_of::<TransferFeeConfig>()]);
        assert_eq!(expect, buffer);

        // Once initialized, the buffer is refused by `unpack_uninitialized`.
        assert_eq!(
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer.clone()),
            Err(TokenError::AlreadyInUse.into()),
        );

        let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();

        // Mutate the base and the extension through the mutable view.
        *state.base = TEST_POD_MINT;
        state.base.supply = (u64::from(state.base.supply) + 100).into();

        let unpacked_extension = state.get_extension_mut::<MintCloseAuthority>().unwrap();
        assert_eq!(*unpacked_extension, MintCloseAuthority { close_authority });

        let close_authority = OptionalNonZeroPubkey::try_from(None).unwrap();
        unpacked_extension.close_authority = close_authority;

        // The changes are visible through a fresh read-only view.
        let base = *state.base;
        let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
        assert_eq!(state.base, &base);
        let unpacked_extension = state.get_extension::<MintCloseAuthority>().unwrap();
        assert_eq!(*unpacked_extension, MintCloseAuthority { close_authority });

        // Same layout as before, now with the updated base and a cleared authority.
        let mut expect = vec![];
        expect.extend_from_slice(bytemuck::bytes_of(&base));
        expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); expect.push(AccountType::Mint.into());
        expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
        expect
            .extend_from_slice(&(pod_get_packed_len::<MintCloseAuthority>() as u16).to_le_bytes());
        expect.extend_from_slice(&[0; 32]);
        expect.extend_from_slice(&[0; size_of::<ExtensionType>()]);
        expect.extend_from_slice(&[0; size_of::<Length>()]);
        expect.extend_from_slice(&[0; size_of::<TransferFeeConfig>()]);
        assert_eq!(expect, buffer);

        // An initialized mint never unpacks as an account.
        assert_eq!(
            PodStateWithExtensions::<PodAccount>::unpack(&buffer),
            Err(ProgramError::UninitializedAccount),
        );

        // Fill the remaining slot with the TransferFeeConfig extension.
        let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
        let mint_transfer_fee = test_transfer_fee_config();
        let new_extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
        new_extension.transfer_fee_config_authority =
            mint_transfer_fee.transfer_fee_config_authority;
        new_extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
        new_extension.withheld_amount = mint_transfer_fee.withheld_amount;
        new_extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
        new_extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;

        assert_eq!(
            &state.get_extension_types().unwrap(),
            &[
                ExtensionType::MintCloseAuthority,
                ExtensionType::TransferFeeConfig
            ]
        );

        // Final layout with both TLV entries written.
        let mut expect = vec![];
        expect.extend_from_slice(pod_bytes_of(&base));
        expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); expect.push(AccountType::Mint.into());
        expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
        expect
            .extend_from_slice(&(pod_get_packed_len::<MintCloseAuthority>() as u16).to_le_bytes());
        expect.extend_from_slice(&[0; 32]); expect.extend_from_slice(&(ExtensionType::TransferFeeConfig as u16).to_le_bytes());
        expect.extend_from_slice(&(pod_get_packed_len::<TransferFeeConfig>() as u16).to_le_bytes());
        expect.extend_from_slice(pod_bytes_of(&mint_transfer_fee));
        assert_eq!(expect, buffer);

        // The buffer was sized for exactly two extensions, so a third one fails.
        let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
        assert_eq!(
            state.init_extension::<MintPaddingTest>(true),
            Err(ProgramError::InvalidAccountData),
        );
    }
2059
    /// Initializing the same extensions in a different order (and before or
    /// after the base state) yields different bytes but identical decoded
    /// contents.
    #[test]
    fn mint_extension_any_order() {
        let mint_size = ExtensionType::try_calculate_account_len::<PodMint>(&[
            ExtensionType::MintCloseAuthority,
            ExtensionType::TransferFeeConfig,
        ])
        .unwrap();
        let mut buffer = vec![0; mint_size];

        // First buffer: write extensions (close authority, then transfer fee)
        // before the base state.
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        let close_authority =
            OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([1; 32]))).unwrap();
        let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
        extension.close_authority = close_authority;

        let mint_transfer_fee = test_transfer_fee_config();
        let extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
        extension.transfer_fee_config_authority = mint_transfer_fee.transfer_fee_config_authority;
        extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
        extension.withheld_amount = mint_transfer_fee.withheld_amount;
        extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
        extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;

        assert_eq!(
            &state.get_extension_types().unwrap(),
            &[
                ExtensionType::MintCloseAuthority,
                ExtensionType::TransferFeeConfig
            ]
        );

        // Then write the base mint and account type.
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        *state.base = TEST_POD_MINT;
        state.init_account_type().unwrap();

        // Second buffer: base state first, then the extensions in reverse order.
        let mut other_buffer = vec![0; mint_size];
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut other_buffer).unwrap();

        *state.base = TEST_POD_MINT;
        state.init_account_type().unwrap();

        let mint_transfer_fee = test_transfer_fee_config();
        let extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
        extension.transfer_fee_config_authority = mint_transfer_fee.transfer_fee_config_authority;
        extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
        extension.withheld_amount = mint_transfer_fee.withheld_amount;
        extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
        extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;

        let close_authority =
            OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([1; 32]))).unwrap();
        let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
        extension.close_authority = close_authority;

        assert_eq!(
            &state.get_extension_types().unwrap(),
            &[
                ExtensionType::TransferFeeConfig,
                ExtensionType::MintCloseAuthority
            ]
        );

        // The raw bytes differ (TLV order), but decoded contents match.
        assert_ne!(buffer, other_buffer);
        let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
        let other_state = PodStateWithExtensions::<PodMint>::unpack(&other_buffer).unwrap();

        assert_eq!(
            state.get_extension::<TransferFeeConfig>().unwrap(),
            other_state.get_extension::<TransferFeeConfig>().unwrap()
        );
        assert_eq!(
            state.get_extension::<MintCloseAuthority>().unwrap(),
            other_state.get_extension::<MintCloseAuthority>().unwrap()
        );
        assert_eq!(state.base, other_state.base);
    }
2145
    /// A mint is never allowed to be exactly `Multisig::LEN` bytes long. The
    /// calculated length is moved past it and the extra bytes hold a trailing
    /// Uninitialized type tag.
    #[test]
    fn mint_with_multisig_len() {
        // A buffer of exactly Multisig::LEN bytes is rejected as a mint.
        let mut buffer = vec![0; Multisig::LEN];
        assert_eq!(
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer),
            Err(ProgramError::InvalidAccountData),
        );
        // With MintPaddingTest the calculated size is bumped by one type tag.
        let mint_size =
            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MintPaddingTest])
                .unwrap();
        assert_eq!(mint_size, Multisig::LEN + size_of::<ExtensionType>());
        let mut buffer = vec![0; mint_size];

        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        *state.base = TEST_POD_MINT;
        state.init_account_type().unwrap();

        let extension = state.init_extension::<MintPaddingTest>(true).unwrap();
        extension.padding1 = [1; 128];
        extension.padding2 = [1; 48];
        extension.padding3 = [1; 9];

        assert_eq!(
            &state.get_extension_types().unwrap(),
            &[ExtensionType::MintPaddingTest]
        );

        // Layout: base, padding, account type, the MintPaddingTest entry, and a
        // trailing zeroed (Uninitialized) extension-type tag.
        let mut expect = TEST_MINT_SLICE.to_vec();
        expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); expect.push(AccountType::Mint.into());
        expect.extend_from_slice(&(ExtensionType::MintPaddingTest as u16).to_le_bytes());
        expect.extend_from_slice(&(pod_get_packed_len::<MintPaddingTest>() as u16).to_le_bytes());
        expect.extend_from_slice(&vec![1; pod_get_packed_len::<MintPaddingTest>()]);
        expect.extend_from_slice(&(ExtensionType::Uninitialized as u16).to_le_bytes());
        assert_eq!(expect, buffer);
    }
2186
#[test]
fn account_with_extension_pack_unpack() {
    // Bytes that follow the base account when exactly one `TransferFeeAmount`
    // entry is present: the account-type byte, then the TLV type, length and value.
    fn transfer_fee_amount_tail(withheld: PodU64) -> Vec<u8> {
        let mut tail: Vec<u8> = vec![AccountType::Account.into()];
        tail.extend_from_slice(&(ExtensionType::TransferFeeAmount as u16).to_le_bytes());
        tail.extend_from_slice(
            &(pod_get_packed_len::<TransferFeeAmount>() as u16).to_le_bytes(),
        );
        tail.extend_from_slice(&u64::from(withheld).to_le_bytes());
        tail
    }

    let account_len = ExtensionType::try_calculate_account_len::<PodAccount>(&[
        ExtensionType::TransferFeeAmount,
    ])
    .unwrap();
    let mut buffer = vec![0; account_len];

    // A zeroed buffer has no initialized base, so a full unpack fails.
    assert_eq!(
        PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer),
        Err(ProgramError::UninitializedAccount),
    );

    let mut state =
        PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
    // A mint extension cannot be initialized on account data.
    assert_eq!(
        state.init_extension::<TransferFeeConfig>(true),
        Err(ProgramError::InvalidAccountData),
    );
    let first_withheld = PodU64::from(u64::MAX);
    let fee_amount = state.init_extension::<TransferFeeAmount>(true).unwrap();
    fee_amount.withheld_amount = first_withheld;

    assert_eq!(
        &state.get_extension_types().unwrap(),
        &[ExtensionType::TransferFeeAmount]
    );

    // The extension is written but the base is still zeroed, so unpack still fails.
    assert_eq!(
        PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer.clone()),
        Err(ProgramError::UninitializedAccount),
    );

    let mut state =
        PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
    *state.base = TEST_POD_ACCOUNT;
    state.init_account_type().unwrap();
    let base = *state.base;

    let mut expected = TEST_ACCOUNT_SLICE.to_vec();
    expected.extend(transfer_fee_amount_tail(first_withheld));
    assert_eq!(expected, buffer);

    // With the base populated, a full mutable unpack succeeds.
    let mut state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
    assert_eq!(state.base, &base);
    assert_eq!(
        &state.get_extension_types().unwrap(),
        &[ExtensionType::TransferFeeAmount]
    );

    // Mutate the base and the extension in place through the mutable view.
    *state.base = TEST_POD_ACCOUNT;
    state.base.amount = (u64::from(state.base.amount) + 100).into();

    let fee_amount = state.get_extension_mut::<TransferFeeAmount>().unwrap();
    assert_eq!(
        *fee_amount,
        TransferFeeAmount {
            withheld_amount: first_withheld
        }
    );
    let second_withheld = PodU64::from(u64::from(u32::MAX));
    fee_amount.withheld_amount = second_withheld;

    let base = *state.base;
    // The immutable view sees every mutation made above.
    let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
    assert_eq!(state.base, &base);
    assert_eq!(
        *state.get_extension::<TransferFeeAmount>().unwrap(),
        TransferFeeAmount {
            withheld_amount: second_withheld
        }
    );

    let mut expected = pod_bytes_of(&base).to_vec();
    expected.extend(transfer_fee_amount_tail(second_withheld));
    assert_eq!(expected, buffer);

    // Account data is rejected when unpacked as a mint.
    assert_eq!(
        PodStateWithExtensions::<PodMint>::unpack(&buffer),
        Err(ProgramError::InvalidAccountData),
    );
}
2281
#[test]
fn account_with_multisig_len() {
    // A buffer of exactly Multisig::LEN is rejected as account data.
    let mut multisig_sized = vec![0; Multisig::LEN];
    assert_eq!(
        PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut multisig_sized),
        Err(ProgramError::InvalidAccountData),
    );

    // Here the calculator returns Multisig::LEN plus one extension-type width
    // (presumably extra padding so the length cannot equal a multisig's).
    let padded_len = ExtensionType::try_calculate_account_len::<PodAccount>(&[
        ExtensionType::AccountPaddingTest,
    ])
    .unwrap();
    assert_eq!(padded_len, Multisig::LEN + size_of::<ExtensionType>());
    let mut buffer = vec![0; padded_len];

    let mut account =
        PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
    *account.base = TEST_POD_ACCOUNT;
    account.init_account_type().unwrap();

    let padded = account.init_extension::<AccountPaddingTest>(true).unwrap();
    padded.0.padding1 = [2; 128];
    padded.0.padding2 = [2; 48];
    padded.0.padding3 = [2; 9];

    assert_eq!(
        &account.get_extension_types().unwrap(),
        &[ExtensionType::AccountPaddingTest]
    );

    // Base, type byte, one TLV entry filled with 2s, then a trailing
    // Uninitialized type marker in the extra padding bytes.
    let value_len = pod_get_packed_len::<AccountPaddingTest>();
    let mut expected = TEST_ACCOUNT_SLICE.to_vec();
    expected.push(AccountType::Account.into());
    expected.extend_from_slice(&(ExtensionType::AccountPaddingTest as u16).to_le_bytes());
    expected.extend_from_slice(&(value_len as u16).to_le_bytes());
    expected.resize(expected.len() + value_len, 2);
    expected.extend_from_slice(&(ExtensionType::Uninitialized as u16).to_le_bytes());
    assert_eq!(expected, buffer);
}
2323
#[test]
fn test_set_account_type() {
    // Account with room for the AccountType byte and one ImmutableOwner entry:
    // a full unpack fails until set_account_type writes the type byte.
    let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
    let needed_len = ExtensionType::try_calculate_account_len::<PodAccount>(&[
        ExtensionType::ImmutableOwner,
    ])
    .unwrap()
        - buffer.len();
    buffer.append(&mut vec![0; needed_len]);
    let err = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap_err();
    assert_eq!(err, ProgramError::InvalidAccountData);
    set_account_type::<PodAccount>(&mut buffer).unwrap();
    let mut state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
    assert_eq!(state.base, &TEST_POD_ACCOUNT);
    assert_eq!(state.account_type[0], AccountType::Account as u8);
    // The remaining space can still hold the extension.
    state.init_extension::<ImmutableOwner>(true).unwrap();

    // Account with room for the AccountType byte only.
    let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
    buffer.append(&mut vec![0; 2]);
    let err = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap_err();
    assert_eq!(err, ProgramError::InvalidAccountData);
    set_account_type::<PodAccount>(&mut buffer).unwrap();
    let state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
    assert_eq!(state.base, &TEST_POD_ACCOUNT);
    assert_eq!(state.account_type[0], AccountType::Account as u8);

    // Account whose type byte is already set to 2 (the assertions below confirm
    // this is AccountType::Account): the buffer unpacks before and after the call.
    let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
    buffer.append(&mut vec![2, 0]);
    let _ = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
    set_account_type::<PodAccount>(&mut buffer).unwrap();
    let state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
    assert_eq!(state.base, &TEST_POD_ACCOUNT);
    assert_eq!(state.account_type[0], AccountType::Account as u8);

    // Account carrying the wrong type byte is rejected by both calls.
    let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
    buffer.append(&mut vec![1, 0]);
    let err = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap_err();
    assert_eq!(err, ProgramError::InvalidAccountData);
    let err = set_account_type::<PodAccount>(&mut buffer).unwrap_err();
    assert_eq!(err, ProgramError::InvalidAccountData);

    // Mint with room for the AccountType byte and one MintCloseAuthority entry.
    let mut buffer = TEST_MINT_SLICE.to_vec();
    let needed_len = ExtensionType::try_calculate_account_len::<PodMint>(&[
        ExtensionType::MintCloseAuthority,
    ])
    .unwrap()
        - buffer.len();
    buffer.append(&mut vec![0; needed_len]);
    let err = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap_err();
    assert_eq!(err, ProgramError::InvalidAccountData);
    set_account_type::<PodMint>(&mut buffer).unwrap();
    let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
    assert_eq!(state.base, &TEST_POD_MINT);
    assert_eq!(state.account_type[0], AccountType::Mint as u8);
    state.init_extension::<MintCloseAuthority>(true).unwrap();

    // Mint padded to account size with room for the AccountType byte only.
    let mut buffer = TEST_MINT_SLICE.to_vec();
    buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
    buffer.append(&mut vec![0; 2]);
    let err = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap_err();
    assert_eq!(err, ProgramError::InvalidAccountData);
    set_account_type::<PodMint>(&mut buffer).unwrap();
    let state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
    assert_eq!(state.base, &TEST_POD_MINT);
    assert_eq!(state.account_type[0], AccountType::Mint as u8);

    // Mint whose type byte is already set to 1: mirror the account case and
    // confirm the buffer already unpacks *before* set_account_type, so the
    // call is shown to be a no-op rather than a write.
    let mut buffer = TEST_MINT_SLICE.to_vec();
    buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
    buffer.append(&mut vec![1, 0]);
    let _ = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
    set_account_type::<PodMint>(&mut buffer).unwrap();
    let state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
    assert_eq!(state.base, &TEST_POD_MINT);
    assert_eq!(state.account_type[0], AccountType::Mint as u8);

    // Mint carrying the wrong type byte is rejected by both calls.
    let mut buffer = TEST_MINT_SLICE.to_vec();
    buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
    buffer.append(&mut vec![2, 0]);
    let err = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap_err();
    assert_eq!(err, ProgramError::InvalidAccountData);
    let err = set_account_type::<PodMint>(&mut buffer).unwrap_err();
    assert_eq!(err, ProgramError::InvalidAccountData);
}
2418
#[test]
fn test_set_account_type_wrongly() {
    // Account data followed by two spare bytes cannot be typed as a mint.
    let mut account_bytes = TEST_ACCOUNT_SLICE.to_vec();
    account_bytes.extend_from_slice(&[0, 0]);
    assert_eq!(
        set_account_type::<PodMint>(&mut account_bytes).unwrap_err(),
        ProgramError::InvalidAccountData
    );

    // Mint data padded to account size plus two spare bytes cannot be typed
    // as an account.
    let mut mint_bytes = TEST_MINT_SLICE.to_vec();
    mint_bytes.resize(
        mint_bytes.len() + (PodAccount::SIZE_OF - PodMint::SIZE_OF) + 2,
        0,
    );
    assert_eq!(
        set_account_type::<PodAccount>(&mut mint_bytes).unwrap_err(),
        ProgramError::InvalidAccountData
    );
}
2434
#[test]
fn test_get_required_init_account_extensions() {
    // Each case pairs a list of mint extensions with the account extensions
    // the function is expected to return for it, in order.
    let cases: [(&[ExtensionType], &[ExtensionType]); 4] = [
        // Neither mint extension requires anything on the account side.
        (
            &[ExtensionType::MintCloseAuthority, ExtensionType::Uninitialized],
            &[],
        ),
        // Only TransferFeeConfig contributes a requirement.
        (
            &[ExtensionType::TransferFeeConfig, ExtensionType::MintCloseAuthority],
            &[ExtensionType::TransferFeeAmount],
        ),
        // Two contributing mint extensions yield both requirements, in mint order.
        (
            &[ExtensionType::TransferFeeConfig, ExtensionType::MintPaddingTest],
            &[ExtensionType::TransferFeeAmount, ExtensionType::AccountPaddingTest],
        ),
        // Duplicate mint extensions produce duplicate requirements.
        (
            &[ExtensionType::TransferFeeConfig, ExtensionType::TransferFeeConfig],
            &[ExtensionType::TransferFeeAmount, ExtensionType::TransferFeeAmount],
        ),
    ];

    for (mint_extensions, expected) in cases {
        assert_eq!(
            ExtensionType::get_required_init_account_extensions(mint_extensions),
            expected
        );
    }
}
2483
#[test]
fn mint_without_extensions() {
    let mint_len = ExtensionType::try_calculate_account_len::<PodMint>(&[]).unwrap();
    let mut buffer = vec![0; mint_len];

    // A buffer of bare-mint size is rejected when viewed as an account.
    assert_eq!(
        PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
        Err(ProgramError::InvalidAccountData),
    );

    let mut mint =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
    *mint.base = TEST_POD_MINT;
    mint.init_account_type().unwrap();

    // There is no TLV space, so initializing an extension fails.
    assert_eq!(
        mint.init_extension::<TransferFeeConfig>(true),
        Err(ProgramError::InvalidAccountData),
    );

    // The buffer holds exactly the base mint bytes and nothing else.
    assert_eq!(TEST_MINT_SLICE, buffer);
}
2507
#[test]
fn test_init_nonzero_default() {
    let buffer_len =
        ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MintPaddingTest])
            .unwrap();
    let mut buffer = vec![0; buffer_len];
    let mut mint =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
    *mint.base = TEST_POD_MINT;
    mint.init_account_type().unwrap();

    // A freshly initialized entry carries these non-zero values rather than the
    // zeros that were in the buffer.
    let padding = mint.init_extension::<MintPaddingTest>(true).unwrap();
    assert_eq!(padding.padding1, [1; 128]);
    assert_eq!(padding.padding2, [2; 48]);
    assert_eq!(padding.padding3, [3; 9]);
}
2523
#[test]
fn test_init_buffer_too_small() {
    let full_len = ExtensionType::try_calculate_account_len::<PodMint>(&[
        ExtensionType::MintCloseAuthority,
    ])
    .unwrap();

    // One byte short of the space MintCloseAuthority needs.
    let mut short_buffer = vec![0; full_len - 1];
    let mut state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut short_buffer).unwrap();
    assert_eq!(
        state
            .init_extension::<MintCloseAuthority>(true)
            .unwrap_err(),
        ProgramError::InvalidAccountData
    );

    // Hand-write a TLV header (type low byte 3, length low byte 32) whose value
    // cannot fit in the remaining space; 3 is presumably MintCloseAuthority's
    // discriminant.
    state.tlv_data[0] = 3;
    state.tlv_data[2] = 32;
    assert_eq!(
        state.get_extension_mut::<MintCloseAuthority>().unwrap_err(),
        ProgramError::InvalidAccountData
    );

    // Mint data plus two bytes is rejected outright.
    let mut odd_buffer = vec![0; PodMint::SIZE_OF + 2];
    assert_eq!(
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut odd_buffer)
            .unwrap_err(),
        ProgramError::InvalidAccountData
    );

    // Base, account-type byte and a two-byte TLV type, but no length field.
    let mut buffer = vec![0; BASE_ACCOUNT_LENGTH + 3];
    let mut state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
    assert_eq!(
        state.get_extension_mut::<MintCloseAuthority>().unwrap_err(),
        ProgramError::InvalidAccountData
    );
    assert_eq!(state.get_extension_types().unwrap(), vec![]);

    // Base, account-type byte and a single spare byte: no extensions reported.
    let mut buffer = vec![0; BASE_ACCOUNT_LENGTH + 2];
    let state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
    assert_eq!(state.get_extension_types().unwrap(), []);
}
2563
#[test]
fn test_extension_with_no_data() {
    let buffer_len = ExtensionType::try_calculate_account_len::<PodAccount>(&[
        ExtensionType::ImmutableOwner,
    ])
    .unwrap();
    let mut buffer = vec![0; buffer_len];
    let mut account =
        PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
    *account.base = TEST_POD_ACCOUNT;
    account.init_account_type().unwrap();

    // Before initialization the extension is reported as not found.
    assert_eq!(
        account.get_extension::<ImmutableOwner>().unwrap_err(),
        ProgramError::Custom(TokenError::ExtensionNotFound as u32)
    );

    account.init_extension::<ImmutableOwner>(true).unwrap();

    // A zero-length value still occupies a type/length header in the TLV area.
    assert_eq!(
        get_first_extension_type(account.tlv_data).unwrap(),
        Some(ExtensionType::ImmutableOwner)
    );
    let expected_info = TlvDataInfo {
        extension_types: vec![ExtensionType::ImmutableOwner],
        used_len: add_type_and_length_to_len(0),
    };
    assert_eq!(get_tlv_data_info(account.tlv_data).unwrap(), expected_info);
}
2595
#[test]
fn fail_account_len_with_metadata() {
    // Including a variable-length extension type makes the length calculation
    // fail with InvalidArgument.
    let requested = [
        ExtensionType::MintCloseAuthority,
        ExtensionType::VariableLenMintTest,
        ExtensionType::TransferFeeConfig,
    ];
    let err = ExtensionType::try_calculate_account_len::<PodMint>(&requested).unwrap_err();
    assert_eq!(err, ProgramError::InvalidArgument);
}
2608
#[test]
fn alloc() {
    let one_byte = VariableLenMintTest { data: vec![1] };
    let value_len = one_byte.get_packed_len().unwrap();
    // Exactly enough room for the base, the account type and one TLV entry.
    let buffer_len =
        BASE_ACCOUNT_LENGTH + size_of::<AccountType>() + add_type_and_length_to_len(value_len);
    let mut buffer = vec![0; buffer_len];
    let mut state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
    state.init_variable_len_extension(&one_byte, false).unwrap();

    // Re-initializing without overwrite is rejected.
    assert_eq!(
        state
            .init_variable_len_extension(&one_byte, false)
            .unwrap_err(),
        TokenError::ExtensionAlreadyInitialized.into()
    );

    // Overwriting with a value of the same length succeeds.
    state.init_variable_len_extension(&one_byte, true).unwrap();

    // Overwriting with a shorter value fails with InvalidLengthForAlloc...
    assert_eq!(
        state
            .init_variable_len_extension(&VariableLenMintTest { data: vec![] }, true)
            .unwrap_err(),
        TokenError::InvalidLengthForAlloc.into()
    );

    // ...and a longer one fails with InvalidAccountData.
    assert_eq!(
        state
            .init_variable_len_extension(&VariableLenMintTest { data: vec![1, 2] }, true)
            .unwrap_err(),
        ProgramError::InvalidAccountData
    );
}
2651
#[test]
fn realloc() {
    let shrunk = VariableLenMintTest {
        data: vec![1, 2, 3],
    };
    let initial = VariableLenMintTest {
        data: vec![1, 2, 3, 4],
    };
    let grown = VariableLenMintTest {
        data: vec![1, 2, 3, 4, 5],
    };
    let oversized = VariableLenMintTest {
        data: vec![1, 2, 3, 4, 5, 6],
    };

    // Room for a MetadataPointer plus the largest value expected to fit.
    let account_len =
        ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MetadataPointer])
            .unwrap()
            + add_type_and_length_to_len(grown.get_packed_len().unwrap());
    let mut buffer = vec![0; account_len];
    let mut state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();

    // The variable-length entry goes before the MetadataPointer, so the checks
    // below show the later entry survives each realloc.
    state
        .init_variable_len_extension(&initial, false)
        .unwrap();
    let max_pubkey =
        OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([255; 32]))).unwrap();
    let pointer = state.init_extension::<MetadataPointer>(false).unwrap();
    pointer.authority = max_pubkey;
    pointer.metadata_address = max_pubkey;

    let assert_pointer_intact = |pointer: &MetadataPointer| {
        assert_eq!(pointer.authority, max_pubkey);
        assert_eq!(pointer.metadata_address, max_pubkey);
    };

    // Grow the entry in place.
    state.realloc_variable_len_extension(&grown).unwrap();
    assert_eq!(
        state
            .get_variable_len_extension::<VariableLenMintTest>()
            .unwrap(),
        grown
    );
    assert_pointer_intact(state.get_extension::<MetadataPointer>().unwrap());

    // Shrink it again; the bytes freed at the end of the buffer are zero.
    state.realloc_variable_len_extension(&shrunk).unwrap();
    assert_eq!(
        state
            .get_variable_len_extension::<VariableLenMintTest>()
            .unwrap(),
        shrunk
    );
    assert_pointer_intact(state.get_extension::<MetadataPointer>().unwrap());
    let freed = grown.get_packed_len().unwrap() - shrunk.get_packed_len().unwrap();
    assert_eq!(&buffer[account_len - freed..account_len], vec![0; freed]);

    // Re-unpack the buffer; growing past the allocated space fails.
    let mut state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
    assert_eq!(
        state
            .realloc_variable_len_extension(&oversized)
            .unwrap_err(),
        ProgramError::InvalidAccountData,
    );
}
2722
#[test]
fn account_len() {
    let shorter = VariableLenMintTest {
        data: vec![20, 30, 40],
    };
    let current = VariableLenMintTest {
        data: vec![20, 30, 40, 50],
    };
    let longer = VariableLenMintTest {
        data: vec![20, 30, 40, 50, 60],
    };
    let value_len = current.get_packed_len().unwrap();
    let buffer_len =
        BASE_ACCOUNT_LENGTH + size_of::<AccountType>() + add_type_and_length_to_len(value_len);
    let mut buffer = vec![0; buffer_len];
    let mut state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();

    // With no extensions yet, the reported length is just the bare mint.
    assert_eq!(state.try_get_account_len().unwrap(), PodMint::SIZE_OF);

    // Adding the first extension needs the base plus account type plus one TLV entry.
    let projected_len = state
        .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(&current)
        .unwrap();
    assert_eq!(
        projected_len,
        BASE_ACCOUNT_AND_TYPE_LENGTH.saturating_add(add_type_and_length_to_len(value_len))
    );

    state
        .init_variable_len_extension::<VariableLenMintTest>(&current, false)
        .unwrap();
    let actual_len = state.try_get_account_len().unwrap();
    assert_eq!(actual_len, projected_len);

    // One fewer data byte shrinks the projected length by one.
    let shrunk_len = state
        .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(&shorter)
        .unwrap();
    assert_eq!(actual_len.checked_sub(shrunk_len).unwrap(), 1);

    // One more data byte grows it by one.
    let grown_len = state
        .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(&longer)
        .unwrap();
    assert_eq!(grown_len.checked_sub(actual_len).unwrap(), 1);

    // Re-serializing the same value leaves the length unchanged.
    let same_len = state
        .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(&current)
        .unwrap();
    assert_eq!(same_len, actual_len);
}
2785
2786 struct SolanaAccountData {
2789 data: Vec<u8>,
2790 lamports: u64,
2791 owner: Pubkey,
2792 }
2793 impl SolanaAccountData {
2794 fn new(account_data: &[u8]) -> Self {
2797 let mut data = vec![];
2798 data.extend_from_slice(&(account_data.len() as u64).to_le_bytes());
2799 data.extend_from_slice(account_data);
2800 data.extend_from_slice(&[0; MAX_PERMITTED_DATA_INCREASE]);
2801 Self {
2802 data,
2803 lamports: 10,
2804 owner: Pubkey::new_unique(),
2805 }
2806 }
2807
2808 fn data(&self) -> &[u8] {
2811 let start = size_of::<u64>();
2812 let len = self.len();
2813 &self.data[start..start + len]
2814 }
2815
2816 fn len(&self) -> usize {
2818 self.data
2819 .get(..size_of::<u64>())
2820 .and_then(|slice| slice.try_into().ok())
2821 .map(u64::from_le_bytes)
2822 .unwrap() as usize
2823 }
2824 }
2825 impl GetAccount for SolanaAccountData {
2826 fn get(&mut self) -> (&mut u64, &mut [u8], &Pubkey, bool, Epoch) {
2827 let start = size_of::<u64>();
2829 let len = self.len();
2830 (
2831 &mut self.lamports,
2832 &mut self.data[start..start + len],
2833 &self.owner,
2834 false,
2835 Epoch::default(),
2836 )
2837 }
2838 }
2839
#[test]
fn alloc_new_fixed_len_tlv_in_account_info_from_base_size() {
    let fixed_len = FixedLenMintTest {
        data: [1, 2, 3, 4, 5, 6, 7, 8],
    };
    let value_len = pod_get_packed_len::<FixedLenMintTest>();

    // Start from a bare mint with no room for an account type or extensions.
    let mut buffer = vec![0; PodMint::SIZE_OF];
    let mint = PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
    *mint.base = TEST_POD_MINT;

    let mut account_data = SolanaAccountData::new(&buffer);
    let key = Pubkey::new_unique();

    // The first allocation grows the account to hold the type byte and the entry.
    let account_info = (&key, &mut account_data).into_account_info();
    alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, false).unwrap();
    assert_eq!(
        account_data.len(),
        BASE_ACCOUNT_AND_TYPE_LENGTH + add_type_and_length_to_len(value_len)
    );
    let state = PodStateWithExtensions::<PodMint>::unpack(account_data.data()).unwrap();
    assert_eq!(
        state.get_extension::<FixedLenMintTest>().unwrap(),
        &fixed_len,
    );

    // Overwriting the existing entry is allowed...
    let account_info = (&key, &mut account_data).into_account_info();
    alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, true).unwrap();

    // ...but initializing it again without overwrite is not.
    let account_info = (&key, &mut account_data).into_account_info();
    assert_eq!(
        alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, false).unwrap_err(),
        TokenError::ExtensionAlreadyInitialized.into()
    );
}
2876
#[test]
fn alloc_new_variable_len_tlv_in_account_info_from_base_size() {
    let variable_len = VariableLenMintTest { data: vec![20, 99] };
    let value_len = variable_len.get_packed_len().unwrap();

    // Start from a bare mint with no room for an account type or extensions.
    let mut buffer = vec![0; PodMint::SIZE_OF];
    let mint = PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
    *mint.base = TEST_POD_MINT;

    let mut account_data = SolanaAccountData::new(&buffer);
    let key = Pubkey::new_unique();

    // The first allocation grows the account to hold the type byte and the entry.
    let account_info = (&key, &mut account_data).into_account_info();
    alloc_and_serialize_variable_len_extension::<PodMint, _>(
        &account_info,
        &variable_len,
        false,
    )
    .unwrap();
    assert_eq!(
        account_data.len(),
        BASE_ACCOUNT_AND_TYPE_LENGTH + add_type_and_length_to_len(value_len)
    );
    let state = PodStateWithExtensions::<PodMint>::unpack(account_data.data()).unwrap();
    assert_eq!(
        state
            .get_variable_len_extension::<VariableLenMintTest>()
            .unwrap(),
        variable_len
    );

    // Overwriting the existing entry is allowed...
    let account_info = (&key, &mut account_data).into_account_info();
    alloc_and_serialize_variable_len_extension::<PodMint, _>(
        &account_info,
        &variable_len,
        true,
    )
    .unwrap();

    // ...but initializing it again without overwrite is not.
    let account_info = (&key, &mut account_data).into_account_info();
    assert_eq!(
        alloc_and_serialize_variable_len_extension::<PodMint, _>(
            &account_info,
            &variable_len,
            false,
        )
        .unwrap_err(),
        TokenError::ExtensionAlreadyInitialized.into()
    );
}
2928
#[test]
fn alloc_new_fixed_len_tlv_in_account_info_from_extended_size() {
    let fixed_len = FixedLenMintTest {
        data: [1, 2, 3, 4, 5, 6, 7, 8],
    };
    let value_len = pod_get_packed_len::<FixedLenMintTest>();
    let buffer_len =
        ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::GroupPointer])
            .unwrap()
            + add_type_and_length_to_len(value_len);
    let mut buffer = vec![0; buffer_len];

    // Pre-populate a mint that already carries a GroupPointer.
    let mut mint =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
    *mint.base = TEST_POD_MINT;
    mint.init_account_type().unwrap();
    let test_key =
        OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([20; 32]))).unwrap();
    let group_pointer = mint.init_extension::<GroupPointer>(false).unwrap();
    group_pointer.authority = test_key;
    group_pointer.group_address = test_key;

    let mut account_data = SolanaAccountData::new(&buffer);
    let key = Pubkey::new_unique();

    // Adding the fixed-length entry leaves the account holding both entries.
    let account_info = (&key, &mut account_data).into_account_info();
    alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, false).unwrap();
    let expected_len = BASE_ACCOUNT_AND_TYPE_LENGTH
        + add_type_and_length_to_len(value_len)
        + add_type_and_length_to_len(size_of::<GroupPointer>());
    assert_eq!(account_data.len(), expected_len);
    let state = PodStateWithExtensions::<PodMint>::unpack(account_data.data()).unwrap();
    assert_eq!(
        state.get_extension::<FixedLenMintTest>().unwrap(),
        &fixed_len,
    );
    let group_pointer = state.get_extension::<GroupPointer>().unwrap();
    assert_eq!(group_pointer.authority, test_key);
    assert_eq!(group_pointer.group_address, test_key);

    // Overwriting the new entry is allowed...
    let account_info = (&key, &mut account_data).into_account_info();
    alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, true).unwrap();

    // ...but initializing it again without overwrite is not.
    let account_info = (&key, &mut account_data).into_account_info();
    assert_eq!(
        alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, false).unwrap_err(),
        TokenError::ExtensionAlreadyInitialized.into()
    );
}
2980
#[test]
fn alloc_new_variable_len_tlv_in_account_info_from_extended_size() {
    let variable_len = VariableLenMintTest { data: vec![42, 6] };
    let value_len = variable_len.get_packed_len().unwrap();
    let buffer_len =
        ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MetadataPointer])
            .unwrap()
            + add_type_and_length_to_len(value_len);
    let mut buffer = vec![0; buffer_len];

    // Pre-populate a mint that already carries a MetadataPointer.
    let mut mint =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
    *mint.base = TEST_POD_MINT;
    mint.init_account_type().unwrap();
    let test_key =
        OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([20; 32]))).unwrap();
    let metadata_pointer = mint.init_extension::<MetadataPointer>(false).unwrap();
    metadata_pointer.authority = test_key;
    metadata_pointer.metadata_address = test_key;

    let mut account_data = SolanaAccountData::new(&buffer);
    let key = Pubkey::new_unique();

    // Adding the variable-length entry leaves the account holding both entries.
    let account_info = (&key, &mut account_data).into_account_info();
    alloc_and_serialize_variable_len_extension::<PodMint, _>(
        &account_info,
        &variable_len,
        false,
    )
    .unwrap();
    let expected_len = BASE_ACCOUNT_AND_TYPE_LENGTH
        + add_type_and_length_to_len(value_len)
        + add_type_and_length_to_len(size_of::<MetadataPointer>());
    assert_eq!(account_data.len(), expected_len);
    let state = PodStateWithExtensions::<PodMint>::unpack(account_data.data()).unwrap();
    assert_eq!(
        state
            .get_variable_len_extension::<VariableLenMintTest>()
            .unwrap(),
        variable_len
    );
    let metadata_pointer = state.get_extension::<MetadataPointer>().unwrap();
    assert_eq!(metadata_pointer.authority, test_key);
    assert_eq!(metadata_pointer.metadata_address, test_key);

    // Overwriting the new entry is allowed...
    let account_info = (&key, &mut account_data).into_account_info();
    alloc_and_serialize_variable_len_extension::<PodMint, _>(
        &account_info,
        &variable_len,
        true,
    )
    .unwrap();

    // ...but initializing it again without overwrite is not.
    let account_info = (&key, &mut account_data).into_account_info();
    assert_eq!(
        alloc_and_serialize_variable_len_extension::<PodMint, _>(
            &account_info,
            &variable_len,
            false,
        )
        .unwrap_err(),
        TokenError::ExtensionAlreadyInitialized.into()
    );
}
3047
#[test]
fn realloc_variable_len_tlv_in_account_info() {
    let initial = VariableLenMintTest {
        data: vec![1, 2, 3, 4, 5],
    };
    let initial_len = initial.get_packed_len().unwrap();
    let buffer_len =
        ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MetadataPointer])
            .unwrap()
            + add_type_and_length_to_len(initial_len);
    let mut buffer = vec![0; buffer_len];
    let mut mint =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
    *mint.base = TEST_POD_MINT;
    mint.init_account_type().unwrap();

    // The variable-length entry first, then a MetadataPointer after it.
    mint.init_variable_len_extension(&initial, false).unwrap();
    let max_pubkey =
        OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([255; 32]))).unwrap();
    let pointer = mint.init_extension::<MetadataPointer>(false).unwrap();
    pointer.authority = max_pubkey;
    pointer.metadata_address = max_pubkey;

    let mut account_data = SolanaAccountData::new(&buffer);
    let key = Pubkey::new_unique();

    // Shrink, grow, then rewrite at the same length. After each step the
    // account length matches the computed length, the new value round-trips,
    // and the trailing MetadataPointer is untouched.
    let replacements = [
        vec![1, 2],
        vec![1, 2, 3, 4, 5, 6, 7],
        vec![7, 6, 5, 4, 3, 2, 1],
    ];
    for payload in replacements {
        let replacement = VariableLenMintTest { data: payload };
        let account_info = (&key, &mut account_data).into_account_info();
        alloc_and_serialize_variable_len_extension::<PodMint, _>(
            &account_info,
            &replacement,
            true,
        )
        .unwrap();

        let state = PodStateWithExtensions::<PodMint>::unpack(account_data.data()).unwrap();
        let pointer = state.get_extension::<MetadataPointer>().unwrap();
        assert_eq!(pointer.authority, max_pubkey);
        assert_eq!(pointer.metadata_address, max_pubkey);
        assert_eq!(
            state
                .get_variable_len_extension::<VariableLenMintTest>()
                .unwrap(),
            replacement
        );
        assert_eq!(account_data.len(), state.try_get_account_len().unwrap());
    }
}
3140}