1#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5use {
6 crate::{
7 error::TokenError,
8 extension::{
9 confidential_mint_burn::ConfidentialMintBurn,
10 confidential_transfer::{ConfidentialTransferAccount, ConfidentialTransferMint},
11 confidential_transfer_fee::{
12 ConfidentialTransferFeeAmount, ConfidentialTransferFeeConfig,
13 },
14 cpi_guard::CpiGuard,
15 default_account_state::DefaultAccountState,
16 group_member_pointer::GroupMemberPointer,
17 group_pointer::GroupPointer,
18 immutable_owner::ImmutableOwner,
19 interest_bearing_mint::InterestBearingConfig,
20 memo_transfer::MemoTransfer,
21 metadata_pointer::MetadataPointer,
22 mint_close_authority::MintCloseAuthority,
23 non_transferable::{NonTransferable, NonTransferableAccount},
24 pausable::{PausableAccount, PausableConfig},
25 permanent_delegate::PermanentDelegate,
26 permissioned_burn::PermissionedBurnConfig,
27 scaled_ui_amount::ScaledUiAmountConfig,
28 transfer_fee::{TransferFeeAmount, TransferFeeConfig},
29 transfer_hook::{TransferHook, TransferHookAccount},
30 },
31 pod::{PodAccount, PodMint},
32 state::{Account, Mint, Multisig, PackedSizeOf},
33 },
34 alloc::{vec, vec::Vec},
35 bytemuck::{Pod, Zeroable},
36 core::{
37 cmp::Ordering,
38 convert::{TryFrom, TryInto},
39 mem::size_of,
40 },
41 num_enum::{IntoPrimitive, TryFromPrimitive},
42 solana_account_info::AccountInfo,
43 solana_program_error::ProgramError,
44 solana_program_pack::{IsInitialized, Pack},
45 solana_zero_copy::unaligned::U16,
46 spl_token_group_interface::state::{TokenGroup, TokenGroupMember},
47 spl_type_length_value::variable_len_pack::VariableLenPack,
48};
49
50pub mod confidential_transfer;
52pub mod confidential_transfer_fee;
54pub mod cpi_guard;
56pub mod default_account_state;
58pub mod group_member_pointer;
60pub mod group_pointer;
62pub mod immutable_owner;
64pub mod interest_bearing_mint;
66pub mod memo_transfer;
68pub mod metadata_pointer;
70pub mod mint_close_authority;
72pub mod non_transferable;
74pub mod pausable;
76pub mod permanent_delegate;
78pub mod permissioned_burn;
80pub mod scaled_ui_amount;
82pub mod token_group;
84pub mod token_metadata;
86pub mod transfer_fee;
88pub mod transfer_hook;
90
91pub mod confidential_mint_burn;
93
94#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
96#[repr(transparent)]
97pub struct Length(U16);
98impl From<Length> for usize {
99 fn from(n: Length) -> Self {
100 Self::from(u16::from(n.0))
101 }
102}
103impl TryFrom<usize> for Length {
104 type Error = ProgramError;
105 fn try_from(n: usize) -> Result<Self, Self::Error> {
106 u16::try_from(n)
107 .map(|v| Self(U16::from(v)))
108 .map_err(|_| ProgramError::AccountDataTooSmall)
109 }
110}
111
112fn get_tlv_indices(type_start: usize) -> TlvIndices {
114 let length_start = type_start.saturating_add(size_of::<ExtensionType>());
115 let value_start = length_start.saturating_add(size_of::<Length>());
116 TlvIndices {
117 type_start,
118 length_start,
119 value_start,
120 }
121}
122
123const fn adjust_len_for_multisig(account_len: usize) -> usize {
126 if account_len == Multisig::LEN {
127 account_len.saturating_add(size_of::<ExtensionType>())
128 } else {
129 account_len
130 }
131}
132
133const fn add_type_and_length_to_len(value_len: usize) -> usize {
136 value_len
137 .saturating_add(size_of::<ExtensionType>())
138 .saturating_add(size_of::<Length>())
139}
140
/// Byte offsets of the three parts of one TLV entry inside a TLV buffer, as
/// produced by `get_tlv_indices`.
#[derive(Debug)]
struct TlvIndices {
    /// Offset at which the entry's extension type begins.
    pub type_start: usize,
    /// Offset at which the entry's length field begins.
    pub length_start: usize,
    /// Offset at which the entry's value begins.
    pub value_start: usize,
}
149fn get_extension_indices<V: Extension>(
150 tlv_data: &[u8],
151 init: bool,
152) -> Result<TlvIndices, ProgramError> {
153 let mut start_index = 0;
154 while start_index < tlv_data.len() {
155 let tlv_indices = get_tlv_indices(start_index);
156 if tlv_data.len() < tlv_indices.value_start {
157 return Err(ProgramError::InvalidAccountData);
158 }
159 let extension_type = u16::from_le_bytes(
160 tlv_data[tlv_indices.type_start..tlv_indices.length_start]
161 .try_into()
162 .map_err(|_| ProgramError::InvalidAccountData)?,
163 );
164 if extension_type == u16::from(V::TYPE) {
165 return Ok(tlv_indices);
167 } else if extension_type == u16::from(ExtensionType::Uninitialized) {
170 if init {
171 return Ok(tlv_indices);
172 } else {
173 return Err(TokenError::ExtensionNotFound.into());
174 }
175 } else {
176 let length = bytemuck::try_from_bytes::<Length>(
177 &tlv_data[tlv_indices.length_start..tlv_indices.value_start],
178 )
179 .map_err(|_| ProgramError::InvalidArgument)?;
180 let value_end_index = tlv_indices.value_start.saturating_add(usize::from(*length));
181 start_index = value_end_index;
182 }
183 }
184 Err(ProgramError::InvalidAccountData)
185}
186
/// Summary of a TLV buffer, as produced by `get_tlv_data_info`.
#[derive(Debug, PartialEq)]
struct TlvDataInfo {
    /// The extension types found in the buffer, in the order they appear.
    extension_types: Vec<ExtensionType>,
    /// Number of bytes occupied by the entries that were found.
    used_len: usize,
}
199
200fn get_tlv_data_info(tlv_data: &[u8]) -> Result<TlvDataInfo, ProgramError> {
203 let mut extension_types = vec![];
204 let mut start_index = 0;
205 while start_index < tlv_data.len() {
206 let tlv_indices = get_tlv_indices(start_index);
207 if tlv_data.len() < tlv_indices.length_start {
208 return Ok(TlvDataInfo {
211 extension_types,
212 used_len: tlv_indices.type_start,
213 });
214 }
215 let extension_type =
216 ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
217 if extension_type == ExtensionType::Uninitialized {
218 return Ok(TlvDataInfo {
219 extension_types,
220 used_len: tlv_indices.type_start,
221 });
222 } else {
223 if tlv_data.len() < tlv_indices.value_start {
224 return Err(ProgramError::InvalidAccountData);
226 }
227 extension_types.push(extension_type);
228 let length = bytemuck::try_from_bytes::<Length>(
229 &tlv_data[tlv_indices.length_start..tlv_indices.value_start],
230 )
231 .map_err(|_| ProgramError::InvalidArgument)?;
232
233 let value_end_index = tlv_indices.value_start.saturating_add(usize::from(*length));
234 if value_end_index > tlv_data.len() {
235 return Err(ProgramError::InvalidAccountData);
237 }
238 start_index = value_end_index;
239 }
240 }
241 Ok(TlvDataInfo {
242 extension_types,
243 used_len: start_index,
244 })
245}
246
247fn get_first_extension_type(tlv_data: &[u8]) -> Result<Option<ExtensionType>, ProgramError> {
248 if tlv_data.is_empty() {
249 Ok(None)
250 } else {
251 let tlv_indices = get_tlv_indices(0);
252 if tlv_data.len() <= tlv_indices.length_start {
253 return Ok(None);
254 }
255 let extension_type =
256 ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
257 if extension_type == ExtensionType::Uninitialized {
258 Ok(None)
259 } else {
260 Ok(Some(extension_type))
261 }
262 }
263}
264
265fn check_min_len_and_not_multisig(input: &[u8], minimum_len: usize) -> Result<(), ProgramError> {
266 if input.len() == Multisig::LEN || input.len() < minimum_len {
267 Err(ProgramError::InvalidAccountData)
268 } else {
269 Ok(())
270 }
271}
272
273fn check_account_type<S: BaseState>(account_type: AccountType) -> Result<(), ProgramError> {
274 if account_type != S::ACCOUNT_TYPE {
275 Err(ProgramError::InvalidAccountData)
276 } else {
277 Ok(())
278 }
279}
280
/// Length of the base region that precedes the account-type byte in any
/// extended state. It equals `Account::LEN`; shorter base states are followed
/// by zero padding up to this length (see `type_and_tlv_indices`).
const BASE_ACCOUNT_LENGTH: usize = Account::LEN;
/// `BASE_ACCOUNT_LENGTH` plus the account-type byte, i.e. the offset at which
/// TLV data starts in an extended account.
const BASE_ACCOUNT_AND_TYPE_LENGTH: usize = BASE_ACCOUNT_LENGTH + size_of::<AccountType>();
305
306fn type_and_tlv_indices<S: BaseState>(
307 rest_input: &[u8],
308) -> Result<Option<(usize, usize)>, ProgramError> {
309 if rest_input.is_empty() {
310 Ok(None)
311 } else {
312 let account_type_index = BASE_ACCOUNT_LENGTH.saturating_sub(S::SIZE_OF);
313 let tlv_start_index = account_type_index.saturating_add(size_of::<AccountType>());
315 if rest_input.len() < tlv_start_index {
316 return Err(ProgramError::InvalidAccountData);
317 }
318 if rest_input[..account_type_index] != vec![0; account_type_index] {
319 Err(ProgramError::InvalidAccountData)
320 } else {
321 Ok(Some((account_type_index, tlv_start_index)))
322 }
323 }
324}
325
326fn is_initialized_account(input: &[u8]) -> Result<bool, ProgramError> {
329 const ACCOUNT_INITIALIZED_INDEX: usize = 108; if input.len() != BASE_ACCOUNT_LENGTH {
332 return Err(ProgramError::InvalidAccountData);
333 }
334 Ok(input[ACCOUNT_INITIALIZED_INDEX] != 0)
335}
336
337fn get_extension_bytes<S: BaseState, V: Extension>(tlv_data: &[u8]) -> Result<&[u8], ProgramError> {
338 if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
339 return Err(ProgramError::InvalidAccountData);
340 }
341 let TlvIndices {
342 type_start: _,
343 length_start,
344 value_start,
345 } = get_extension_indices::<V>(tlv_data, false)?;
346 let length = bytemuck::try_from_bytes::<Length>(&tlv_data[length_start..value_start])
349 .map_err(|_| ProgramError::InvalidArgument)?;
350 let value_end = value_start.saturating_add(usize::from(*length));
351 if tlv_data.len() < value_end {
352 return Err(ProgramError::InvalidAccountData);
353 }
354 Ok(&tlv_data[value_start..value_end])
355}
356
357fn get_extension_bytes_mut<S: BaseState, V: Extension>(
358 tlv_data: &mut [u8],
359) -> Result<&mut [u8], ProgramError> {
360 if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
361 return Err(ProgramError::InvalidAccountData);
362 }
363 let TlvIndices {
364 type_start: _,
365 length_start,
366 value_start,
367 } = get_extension_indices::<V>(tlv_data, false)?;
368 let length = bytemuck::try_from_bytes::<Length>(&tlv_data[length_start..value_start])
371 .map_err(|_| ProgramError::InvalidArgument)?;
372 let value_end = value_start.saturating_add(usize::from(*length));
373 if tlv_data.len() < value_end {
374 return Err(ProgramError::InvalidAccountData);
375 }
376 Ok(&mut tlv_data[value_start..value_end])
377}
378
379fn try_get_new_account_len_for_extension_len<S: BaseState, V: Extension>(
385 tlv_data: &[u8],
386 new_extension_len: usize,
387) -> Result<usize, ProgramError> {
388 let new_extension_tlv_len = add_type_and_length_to_len(new_extension_len);
390 let tlv_info = get_tlv_data_info(tlv_data)?;
391 let current_len = tlv_info
394 .used_len
395 .saturating_add(BASE_ACCOUNT_AND_TYPE_LENGTH);
396 let current_extension_len = get_extension_bytes::<S, V>(tlv_data)
398 .map(|x| add_type_and_length_to_len(x.len()))
399 .unwrap_or(0);
400 let new_len = current_len
401 .saturating_sub(current_extension_len)
402 .saturating_add(new_extension_tlv_len);
403 Ok(adjust_len_for_multisig(new_len))
404}
405
406pub trait BaseStateWithExtensions<S: BaseState> {
408 fn get_tlv_data(&self) -> &[u8];
410
411 fn get_extension_bytes<V: Extension>(&self) -> Result<&[u8], ProgramError> {
413 get_extension_bytes::<S, V>(self.get_tlv_data())
414 }
415
416 fn get_extension<V: Extension + Pod>(&self) -> Result<&V, ProgramError> {
418 bytemuck::try_from_bytes::<V>(self.get_extension_bytes::<V>()?)
419 .map_err(|_| ProgramError::InvalidArgument)
420 }
421
422 fn get_variable_len_extension<V: Extension + VariableLenPack>(
424 &self,
425 ) -> Result<V, ProgramError> {
426 let data = get_extension_bytes::<S, V>(self.get_tlv_data())?;
427 V::unpack_from_slice(data)
428 }
429
430 fn get_extension_types(&self) -> Result<Vec<ExtensionType>, ProgramError> {
432 get_tlv_data_info(self.get_tlv_data()).map(|x| x.extension_types)
433 }
434
435 fn get_first_extension_type(&self) -> Result<Option<ExtensionType>, ProgramError> {
437 get_first_extension_type(self.get_tlv_data())
438 }
439
440 fn try_get_account_len(&self) -> Result<usize, ProgramError> {
442 let tlv_info = get_tlv_data_info(self.get_tlv_data())?;
443 if tlv_info.extension_types.is_empty() {
444 Ok(S::SIZE_OF)
445 } else {
446 let total_len = tlv_info
447 .used_len
448 .saturating_add(BASE_ACCOUNT_AND_TYPE_LENGTH);
449 Ok(adjust_len_for_multisig(total_len))
450 }
451 }
452 fn try_get_new_account_len<V: Extension + Pod>(&self) -> Result<usize, ProgramError> {
457 try_get_new_account_len_for_extension_len::<S, V>(self.get_tlv_data(), size_of::<V>())
458 }
459
460 fn try_get_new_account_len_for_variable_len_extension<V: Extension + VariableLenPack>(
463 &self,
464 new_extension: &V,
465 ) -> Result<usize, ProgramError> {
466 try_get_new_account_len_for_extension_len::<S, V>(
467 self.get_tlv_data(),
468 new_extension.get_packed_len()?,
469 )
470 }
471}
472
473#[derive(Clone, Debug, PartialEq)]
476pub struct StateWithExtensionsOwned<S: BaseState> {
477 pub base: S,
479 tlv_data: Vec<u8>,
481}
482impl<S: BaseState + Pack> StateWithExtensionsOwned<S> {
483 pub fn unpack(mut input: Vec<u8>) -> Result<Self, ProgramError> {
487 check_min_len_and_not_multisig(&input, S::SIZE_OF)?;
488 let mut rest = input.split_off(S::SIZE_OF);
489 let base = S::unpack(&input)?;
490 if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(&rest)? {
491 let account_type = AccountType::try_from(rest[account_type_index])
493 .map_err(|_| ProgramError::InvalidAccountData)?;
494 check_account_type::<S>(account_type)?;
495 let tlv_data = rest.split_off(tlv_start_index);
496 Ok(Self { base, tlv_data })
497 } else {
498 Ok(Self {
499 base,
500 tlv_data: vec![],
501 })
502 }
503 }
504}
505
506impl<S: BaseState> BaseStateWithExtensions<S> for StateWithExtensionsOwned<S> {
507 fn get_tlv_data(&self) -> &[u8] {
508 &self.tlv_data
509 }
510}
511
512#[derive(Debug, PartialEq)]
515pub struct StateWithExtensions<'data, S: BaseState + Pack> {
516 pub base: S,
518 tlv_data: &'data [u8],
520}
521impl<'data, S: BaseState + Pack> StateWithExtensions<'data, S> {
522 pub fn unpack(input: &'data [u8]) -> Result<Self, ProgramError> {
526 check_min_len_and_not_multisig(input, S::SIZE_OF)?;
527 let (base_data, rest) = input.split_at(S::SIZE_OF);
528 let base = S::unpack(base_data)?;
529 let tlv_data = unpack_tlv_data::<S>(rest)?;
530 Ok(Self { base, tlv_data })
531 }
532}
533impl<S: BaseState + Pack> BaseStateWithExtensions<S> for StateWithExtensions<'_, S> {
534 fn get_tlv_data(&self) -> &[u8] {
535 self.tlv_data
536 }
537}
538
539#[derive(Debug, PartialEq)]
542pub struct PodStateWithExtensions<'data, S: BaseState + Pod> {
543 pub base: &'data S,
545 tlv_data: &'data [u8],
547}
548impl<'data, S: BaseState + Pod> PodStateWithExtensions<'data, S> {
549 pub fn unpack(input: &'data [u8]) -> Result<Self, ProgramError> {
553 check_min_len_and_not_multisig(input, S::SIZE_OF)?;
554 let (base_data, rest) = input.split_at(S::SIZE_OF);
555 let base =
556 bytemuck::try_from_bytes::<S>(base_data).map_err(|_| ProgramError::InvalidArgument)?;
557 if !base.is_initialized() {
558 Err(ProgramError::UninitializedAccount)
559 } else {
560 let tlv_data = unpack_tlv_data::<S>(rest)?;
561 Ok(Self { base, tlv_data })
562 }
563 }
564}
565impl<S: BaseState + Pod> BaseStateWithExtensions<S> for PodStateWithExtensions<'_, S> {
566 fn get_tlv_data(&self) -> &[u8] {
567 self.tlv_data
568 }
569}
570
571pub trait BaseStateWithExtensionsMut<S: BaseState>: BaseStateWithExtensions<S> {
573 fn get_tlv_data_mut(&mut self) -> &mut [u8];
575
576 fn get_account_type_mut(&mut self) -> &mut [u8];
578
579 fn get_extension_bytes_mut<V: Extension>(&mut self) -> Result<&mut [u8], ProgramError> {
581 get_extension_bytes_mut::<S, V>(self.get_tlv_data_mut())
582 }
583
584 fn get_extension_mut<V: Extension + Pod>(&mut self) -> Result<&mut V, ProgramError> {
587 bytemuck::try_from_bytes_mut::<V>(self.get_extension_bytes_mut::<V>()?)
588 .map_err(|_| ProgramError::InvalidArgument)
589 }
590
591 fn pack_variable_len_extension<V: Extension + VariableLenPack>(
594 &mut self,
595 extension: &V,
596 ) -> Result<(), ProgramError> {
597 let data = self.get_extension_bytes_mut::<V>()?;
598 extension.pack_into_slice(data)
601 }
602
603 fn init_extension<V: Extension + Pod + Default>(
609 &mut self,
610 overwrite: bool,
611 ) -> Result<&mut V, ProgramError> {
612 let length = size_of::<V>();
613 let buffer = self.alloc::<V>(length, overwrite)?;
614 let extension_ref =
615 bytemuck::try_from_bytes_mut::<V>(buffer).map_err(|_| ProgramError::InvalidArgument)?;
616 *extension_ref = V::default();
617 Ok(extension_ref)
618 }
619
620 fn realloc_variable_len_extension<V: Extension + VariableLenPack>(
626 &mut self,
627 new_extension: &V,
628 ) -> Result<(), ProgramError> {
629 let data = self.realloc::<V>(new_extension.get_packed_len()?)?;
630 new_extension.pack_into_slice(data)
631 }
632
633 fn realloc<V: Extension + VariableLenPack>(
643 &mut self,
644 length: usize,
645 ) -> Result<&mut [u8], ProgramError> {
646 let tlv_data = self.get_tlv_data_mut();
647 let TlvIndices {
648 type_start: _,
649 length_start,
650 value_start,
651 } = get_extension_indices::<V>(tlv_data, false)?;
652 let tlv_len = get_tlv_data_info(tlv_data).map(|x| x.used_len)?;
653 let data_len = tlv_data.len();
654
655 let length_ref =
656 bytemuck::try_from_bytes_mut::<Length>(&mut tlv_data[length_start..value_start])
657 .map_err(|_| ProgramError::InvalidArgument)?;
658 let old_length = usize::from(*length_ref);
659
660 if old_length < length {
662 let new_tlv_len = tlv_len.saturating_add(length.saturating_sub(old_length));
663 if new_tlv_len > data_len {
664 return Err(ProgramError::InvalidAccountData);
665 }
666 }
667
668 *length_ref = Length::try_from(length)?;
671
672 let old_value_end = value_start.saturating_add(old_length);
673 let new_value_end = value_start.saturating_add(length);
674 tlv_data.copy_within(old_value_end..tlv_len, new_value_end);
675 match old_length.cmp(&length) {
676 Ordering::Greater => {
677 let new_tlv_len = tlv_len.saturating_sub(old_length.saturating_sub(length));
679 tlv_data[new_tlv_len..tlv_len].fill(0);
680 }
681 Ordering::Less => {
682 tlv_data[old_value_end..new_value_end].fill(0);
684 }
685 Ordering::Equal => {} }
687
688 Ok(&mut tlv_data[value_start..new_value_end])
689 }
690
691 fn init_variable_len_extension<V: Extension + VariableLenPack>(
697 &mut self,
698 extension: &V,
699 overwrite: bool,
700 ) -> Result<(), ProgramError> {
701 let data = self.alloc::<V>(extension.get_packed_len()?, overwrite)?;
702 extension.pack_into_slice(data)
703 }
704
705 fn alloc<V: Extension>(
707 &mut self,
708 length: usize,
709 overwrite: bool,
710 ) -> Result<&mut [u8], ProgramError> {
711 if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
712 return Err(ProgramError::InvalidAccountData);
713 }
714 let tlv_data = self.get_tlv_data_mut();
715 let TlvIndices {
716 type_start,
717 length_start,
718 value_start,
719 } = get_extension_indices::<V>(tlv_data, true)?;
720
721 if tlv_data[type_start..].len() < add_type_and_length_to_len(length) {
722 return Err(ProgramError::InvalidAccountData);
723 }
724 let extension_type = ExtensionType::try_from(&tlv_data[type_start..length_start])?;
725
726 if extension_type == ExtensionType::Uninitialized || overwrite {
727 let extension_type_array: [u8; 2] = V::TYPE.into();
729 let extension_type_ref = &mut tlv_data[type_start..length_start];
730 extension_type_ref.copy_from_slice(&extension_type_array);
731 let length_ref =
733 bytemuck::try_from_bytes_mut::<Length>(&mut tlv_data[length_start..value_start])
734 .map_err(|_| ProgramError::InvalidArgument)?;
735
736 if overwrite && extension_type == V::TYPE && usize::from(*length_ref) != length {
739 return Err(TokenError::InvalidLengthForAlloc.into());
740 }
741
742 *length_ref = Length::try_from(length)?;
743
744 let value_end = value_start.saturating_add(length);
745 Ok(&mut tlv_data[value_start..value_end])
746 } else {
747 Err(TokenError::ExtensionAlreadyInitialized.into())
749 }
750 }
751
752 fn init_account_extension_from_type(
759 &mut self,
760 extension_type: ExtensionType,
761 ) -> Result<(), ProgramError> {
762 if extension_type.get_account_type() != AccountType::Account {
763 return Ok(());
764 }
765 match extension_type {
766 ExtensionType::TransferFeeAmount => {
767 self.init_extension::<TransferFeeAmount>(true).map(|_| ())
768 }
769 ExtensionType::ImmutableOwner => {
770 self.init_extension::<ImmutableOwner>(true).map(|_| ())
771 }
772 ExtensionType::NonTransferableAccount => self
773 .init_extension::<NonTransferableAccount>(true)
774 .map(|_| ()),
775 ExtensionType::TransferHookAccount => {
776 self.init_extension::<TransferHookAccount>(true).map(|_| ())
777 }
778 ExtensionType::ConfidentialTransferAccount => Ok(()),
781 ExtensionType::PausableAccount => {
782 self.init_extension::<PausableAccount>(true).map(|_| ())
783 }
784 #[cfg(test)]
785 ExtensionType::AccountPaddingTest => {
786 self.init_extension::<AccountPaddingTest>(true).map(|_| ())
787 }
788 _ => unreachable!(),
789 }
790 }
791
792 fn init_account_type(&mut self) -> Result<(), ProgramError> {
797 let first_extension_type = self.get_first_extension_type()?;
798 let account_type = self.get_account_type_mut();
799 if !account_type.is_empty() {
800 if let Some(extension_type) = first_extension_type {
801 let account_type = extension_type.get_account_type();
802 if account_type != S::ACCOUNT_TYPE {
803 return Err(TokenError::ExtensionBaseMismatch.into());
804 }
805 }
806 account_type[0] = S::ACCOUNT_TYPE.into();
807 }
808 Ok(())
809 }
810
811 fn check_account_type_matches_extension_type(&self) -> Result<(), ProgramError> {
814 if let Some(extension_type) = self.get_first_extension_type()? {
815 let account_type = extension_type.get_account_type();
816 if account_type != S::ACCOUNT_TYPE {
817 return Err(TokenError::ExtensionBaseMismatch.into());
818 }
819 }
820 Ok(())
821 }
822}
823
824#[derive(Debug, PartialEq)]
827pub struct StateWithExtensionsMut<'data, S: BaseState> {
828 pub base: S,
830 base_data: &'data mut [u8],
832 account_type: &'data mut [u8],
834 tlv_data: &'data mut [u8],
836}
837impl<'data, S: BaseState + Pack> StateWithExtensionsMut<'data, S> {
838 pub fn unpack(input: &'data mut [u8]) -> Result<Self, ProgramError> {
842 check_min_len_and_not_multisig(input, S::SIZE_OF)?;
843 let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
844 let base = S::unpack(base_data)?;
845 let (account_type, tlv_data) = unpack_type_and_tlv_data_mut::<S>(rest)?;
846 Ok(Self {
847 base,
848 base_data,
849 account_type,
850 tlv_data,
851 })
852 }
853
854 pub fn unpack_uninitialized(input: &'data mut [u8]) -> Result<Self, ProgramError> {
859 check_min_len_and_not_multisig(input, S::SIZE_OF)?;
860 let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
861 let base = S::unpack_unchecked(base_data)?;
862 if base.is_initialized() {
863 return Err(TokenError::AlreadyInUse.into());
864 }
865 let (account_type, tlv_data) = unpack_uninitialized_type_and_tlv_data_mut::<S>(rest)?;
866 let state = Self {
867 base,
868 base_data,
869 account_type,
870 tlv_data,
871 };
872 state.check_account_type_matches_extension_type()?;
873 Ok(state)
874 }
875
876 pub fn pack_base(&mut self) {
878 S::pack_into_slice(&self.base, self.base_data);
879 }
880}
881impl<S: BaseState> BaseStateWithExtensions<S> for StateWithExtensionsMut<'_, S> {
882 fn get_tlv_data(&self) -> &[u8] {
883 self.tlv_data
884 }
885}
886impl<S: BaseState> BaseStateWithExtensionsMut<S> for StateWithExtensionsMut<'_, S> {
887 fn get_tlv_data_mut(&mut self) -> &mut [u8] {
888 self.tlv_data
889 }
890 fn get_account_type_mut(&mut self) -> &mut [u8] {
891 self.account_type
892 }
893}
894
895#[derive(Debug, PartialEq)]
898pub struct PodStateWithExtensionsMut<'data, S: BaseState> {
899 pub base: &'data mut S,
901 account_type: &'data mut [u8],
903 tlv_data: &'data mut [u8],
905}
906impl<'data, S: BaseState + Pod> PodStateWithExtensionsMut<'data, S> {
907 pub fn unpack(input: &'data mut [u8]) -> Result<Self, ProgramError> {
911 check_min_len_and_not_multisig(input, S::SIZE_OF)?;
912 let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
913 let base = bytemuck::try_from_bytes_mut::<S>(base_data)
914 .map_err(|_| ProgramError::InvalidArgument)?;
915 if !base.is_initialized() {
916 Err(ProgramError::UninitializedAccount)
917 } else {
918 let (account_type, tlv_data) = unpack_type_and_tlv_data_mut::<S>(rest)?;
919 Ok(Self {
920 base,
921 account_type,
922 tlv_data,
923 })
924 }
925 }
926
927 pub fn unpack_uninitialized(input: &'data mut [u8]) -> Result<Self, ProgramError> {
932 check_min_len_and_not_multisig(input, S::SIZE_OF)?;
933 let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
934 let base = bytemuck::try_from_bytes_mut::<S>(base_data)
935 .map_err(|_| ProgramError::InvalidArgument)?;
936 if base.is_initialized() {
937 return Err(TokenError::AlreadyInUse.into());
938 }
939 let (account_type, tlv_data) = unpack_uninitialized_type_and_tlv_data_mut::<S>(rest)?;
940 let state = Self {
941 base,
942 account_type,
943 tlv_data,
944 };
945 state.check_account_type_matches_extension_type()?;
946 Ok(state)
947 }
948}
949
950impl<S: BaseState> BaseStateWithExtensions<S> for PodStateWithExtensionsMut<'_, S> {
951 fn get_tlv_data(&self) -> &[u8] {
952 self.tlv_data
953 }
954}
955impl<S: BaseState> BaseStateWithExtensionsMut<S> for PodStateWithExtensionsMut<'_, S> {
956 fn get_tlv_data_mut(&mut self) -> &mut [u8] {
957 self.tlv_data
958 }
959 fn get_account_type_mut(&mut self) -> &mut [u8] {
960 self.account_type
961 }
962}
963
964fn unpack_tlv_data<S: BaseState>(rest: &[u8]) -> Result<&[u8], ProgramError> {
965 if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
966 let account_type = AccountType::try_from(rest[account_type_index])
968 .map_err(|_| ProgramError::InvalidAccountData)?;
969 check_account_type::<S>(account_type)?;
970 Ok(&rest[tlv_start_index..])
971 } else {
972 Ok(&[])
973 }
974}
975
976fn unpack_type_and_tlv_data_with_check_mut<
977 S: BaseState,
978 F: Fn(AccountType) -> Result<(), ProgramError>,
979>(
980 rest: &mut [u8],
981 check_fn: F,
982) -> Result<(&mut [u8], &mut [u8]), ProgramError> {
983 if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
984 let account_type = AccountType::try_from(rest[account_type_index])
986 .map_err(|_| ProgramError::InvalidAccountData)?;
987 check_fn(account_type)?;
988 let (account_type, tlv_data) = rest.split_at_mut(tlv_start_index);
989 Ok((
990 &mut account_type[account_type_index..tlv_start_index],
991 tlv_data,
992 ))
993 } else {
994 Ok((&mut [], &mut []))
995 }
996}
997
998fn unpack_type_and_tlv_data_mut<S: BaseState>(
999 rest: &mut [u8],
1000) -> Result<(&mut [u8], &mut [u8]), ProgramError> {
1001 unpack_type_and_tlv_data_with_check_mut::<S, _>(rest, check_account_type::<S>)
1002}
1003
1004fn unpack_uninitialized_type_and_tlv_data_mut<S: BaseState>(
1005 rest: &mut [u8],
1006) -> Result<(&mut [u8], &mut [u8]), ProgramError> {
1007 unpack_type_and_tlv_data_with_check_mut::<S, _>(rest, |account_type| {
1008 if account_type != AccountType::Uninitialized {
1009 Err(ProgramError::InvalidAccountData)
1010 } else {
1011 Ok(())
1012 }
1013 })
1014}
1015
1016pub fn set_account_type<S: BaseState>(input: &mut [u8]) -> Result<(), ProgramError> {
1021 check_min_len_and_not_multisig(input, S::SIZE_OF)?;
1022 let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
1023 if S::ACCOUNT_TYPE == AccountType::Account && !is_initialized_account(base_data)? {
1024 return Err(ProgramError::InvalidAccountData);
1025 }
1026 if let Some((account_type_index, _tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
1027 let mut account_type = AccountType::try_from(rest[account_type_index])
1028 .map_err(|_| ProgramError::InvalidAccountData)?;
1029 if account_type == AccountType::Uninitialized {
1030 rest[account_type_index] = S::ACCOUNT_TYPE.into();
1031 account_type = S::ACCOUNT_TYPE;
1032 }
1033 check_account_type::<S>(account_type)?;
1034 Ok(())
1035 } else {
1036 Err(ProgramError::InvalidAccountData)
1037 }
1038}
1039
/// One-byte marker stored after the base region of an extended state, telling
/// which kind of base state the buffer holds.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Default, PartialEq, TryFromPrimitive, IntoPrimitive)]
pub enum AccountType {
    /// No account type has been written yet (the default).
    #[default]
    Uninitialized,
    /// The buffer holds a mint.
    Mint,
    /// The buffer holds a token account.
    Account,
}
1055
/// Two-byte tag identifying the extension stored in a TLV entry.
///
/// Each variant below notes the account type it is associated with and the
/// struct used for its size, as far as `get_account_type` and
/// `try_get_type_len` show.
#[repr(u16)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "camelCase"))]
#[derive(Clone, Copy, Debug, PartialEq, TryFromPrimitive, IntoPrimitive)]
pub enum ExtensionType {
    /// Marks an unused slot; TLV scanning stops here. Has zero length.
    Uninitialized,
    /// Mint extension; sized as `TransferFeeConfig`.
    TransferFeeConfig,
    /// Account extension; sized as `TransferFeeAmount`.
    TransferFeeAmount,
    /// Mint extension; sized as `MintCloseAuthority`.
    MintCloseAuthority,
    /// Mint extension; sized as `ConfidentialTransferMint`.
    ConfidentialTransferMint,
    /// Account extension; sized as `ConfidentialTransferAccount`.
    ConfidentialTransferAccount,
    /// Mint extension; sized as `DefaultAccountState`.
    DefaultAccountState,
    /// Account extension; sized as `ImmutableOwner`.
    ImmutableOwner,
    /// Account extension; sized as `MemoTransfer`.
    MemoTransfer,
    /// Mint extension; sized as `NonTransferable`.
    NonTransferable,
    /// Mint extension; sized as `InterestBearingConfig`.
    InterestBearingConfig,
    /// Account extension; sized as `CpiGuard`.
    CpiGuard,
    /// Mint extension; sized as `PermanentDelegate`.
    PermanentDelegate,
    /// Account extension; sized as `NonTransferableAccount`.
    NonTransferableAccount,
    /// Mint extension; sized as `TransferHook`.
    TransferHook,
    /// Account extension; sized as `TransferHookAccount`.
    TransferHookAccount,
    /// Mint extension; sized as `ConfidentialTransferFeeConfig`.
    ConfidentialTransferFeeConfig,
    /// Account extension; sized as `ConfidentialTransferFeeAmount`.
    ConfidentialTransferFeeAmount,
    /// Mint extension; sized as `MetadataPointer`.
    MetadataPointer,
    /// Mint extension with a variable-length value (it has no fixed size).
    TokenMetadata,
    /// Mint extension; sized as `GroupPointer`.
    GroupPointer,
    /// Mint extension; sized as `TokenGroup`.
    TokenGroup,
    /// Mint extension; sized as `GroupMemberPointer`.
    GroupMemberPointer,
    /// Mint extension; sized as `TokenGroupMember`.
    TokenGroupMember,
    /// Mint extension; sized as `ConfidentialMintBurn`.
    ConfidentialMintBurn,
    /// Mint extension; sized as `ScaledUiAmountConfig`.
    ScaledUiAmount,
    /// Mint extension; sized as `PausableConfig`.
    Pausable,
    /// Sized as `PausableAccount`. It is initialized alongside the other
    /// account extensions, so presumably an account extension (TODO confirm
    /// against `get_account_type`).
    PausableAccount,
    /// Mint extension; sized as `PermissionedBurnConfig`.
    PermissionedBurn,

    /// Test-only variable-length extension, placed at `u16::MAX - 2`.
    #[cfg(test)]
    VariableLenMintTest = u16::MAX - 2,
    /// Test-only extension; sized as `AccountPaddingTest`.
    #[cfg(test)]
    AccountPaddingTest,
    /// Test-only extension; sized as `MintPaddingTest`.
    #[cfg(test)]
    MintPaddingTest,
}
1144impl TryFrom<&[u8]> for ExtensionType {
1145 type Error = ProgramError;
1146 fn try_from(a: &[u8]) -> Result<Self, Self::Error> {
1147 Self::try_from(u16::from_le_bytes(
1148 a.try_into().map_err(|_| ProgramError::InvalidAccountData)?,
1149 ))
1150 .map_err(|_| ProgramError::InvalidAccountData)
1151 }
1152}
1153impl From<ExtensionType> for [u8; 2] {
1154 fn from(a: ExtensionType) -> Self {
1155 u16::from(a).to_le_bytes()
1156 }
1157}
1158impl ExtensionType {
1159 const fn sized(&self) -> bool {
1164 match self {
1165 ExtensionType::TokenMetadata => false,
1166 #[cfg(test)]
1167 ExtensionType::VariableLenMintTest => false,
1168 _ => true,
1169 }
1170 }
1171
1172 fn try_get_type_len(&self) -> Result<usize, ProgramError> {
1176 if !self.sized() {
1177 return Err(ProgramError::InvalidArgument);
1178 }
1179 Ok(match self {
1180 ExtensionType::Uninitialized => 0,
1181 ExtensionType::TransferFeeConfig => size_of::<TransferFeeConfig>(),
1182 ExtensionType::TransferFeeAmount => size_of::<TransferFeeAmount>(),
1183 ExtensionType::MintCloseAuthority => size_of::<MintCloseAuthority>(),
1184 ExtensionType::ImmutableOwner => size_of::<ImmutableOwner>(),
1185 ExtensionType::ConfidentialTransferMint => size_of::<ConfidentialTransferMint>(),
1186 ExtensionType::ConfidentialTransferAccount => size_of::<ConfidentialTransferAccount>(),
1187 ExtensionType::DefaultAccountState => size_of::<DefaultAccountState>(),
1188 ExtensionType::MemoTransfer => size_of::<MemoTransfer>(),
1189 ExtensionType::NonTransferable => size_of::<NonTransferable>(),
1190 ExtensionType::InterestBearingConfig => size_of::<InterestBearingConfig>(),
1191 ExtensionType::CpiGuard => size_of::<CpiGuard>(),
1192 ExtensionType::PermanentDelegate => size_of::<PermanentDelegate>(),
1193 ExtensionType::NonTransferableAccount => size_of::<NonTransferableAccount>(),
1194 ExtensionType::TransferHook => size_of::<TransferHook>(),
1195 ExtensionType::TransferHookAccount => size_of::<TransferHookAccount>(),
1196 ExtensionType::ConfidentialTransferFeeConfig => {
1197 size_of::<ConfidentialTransferFeeConfig>()
1198 }
1199 ExtensionType::ConfidentialTransferFeeAmount => {
1200 size_of::<ConfidentialTransferFeeAmount>()
1201 }
1202 ExtensionType::MetadataPointer => size_of::<MetadataPointer>(),
1203 ExtensionType::TokenMetadata => unreachable!(),
1204 ExtensionType::GroupPointer => size_of::<GroupPointer>(),
1205 ExtensionType::TokenGroup => size_of::<TokenGroup>(),
1206 ExtensionType::GroupMemberPointer => size_of::<GroupMemberPointer>(),
1207 ExtensionType::TokenGroupMember => size_of::<TokenGroupMember>(),
1208 ExtensionType::ConfidentialMintBurn => size_of::<ConfidentialMintBurn>(),
1209 ExtensionType::ScaledUiAmount => size_of::<ScaledUiAmountConfig>(),
1210 ExtensionType::Pausable => size_of::<PausableConfig>(),
1211 ExtensionType::PausableAccount => size_of::<PausableAccount>(),
1212 ExtensionType::PermissionedBurn => size_of::<PermissionedBurnConfig>(),
1213 #[cfg(test)]
1214 ExtensionType::AccountPaddingTest => size_of::<AccountPaddingTest>(),
1215 #[cfg(test)]
1216 ExtensionType::MintPaddingTest => size_of::<MintPaddingTest>(),
1217 #[cfg(test)]
1218 ExtensionType::VariableLenMintTest => unreachable!(),
1219 })
1220 }
1221
1222 fn try_get_tlv_len(&self) -> Result<usize, ProgramError> {
1226 Ok(add_type_and_length_to_len(self.try_get_type_len()?))
1227 }
1228
1229 fn try_get_total_tlv_len(extension_types: &[Self]) -> Result<usize, ProgramError> {
1233 let mut extensions = vec![];
1235 for extension_type in extension_types {
1236 if !extensions.contains(&extension_type) {
1237 extensions.push(extension_type);
1238 }
1239 }
1240 extensions.iter().map(|e| e.try_get_tlv_len()).sum()
1241 }
1242
1243 pub fn try_calculate_account_len<S: BaseState>(
1247 extension_types: &[Self],
1248 ) -> Result<usize, ProgramError> {
1249 if extension_types.is_empty() {
1250 Ok(S::SIZE_OF)
1251 } else {
1252 let extension_size = Self::try_get_total_tlv_len(extension_types)?;
1253 let total_len = extension_size.saturating_add(BASE_ACCOUNT_AND_TYPE_LENGTH);
1254 Ok(adjust_len_for_multisig(total_len))
1255 }
1256 }
1257
1258 pub fn get_account_type(&self) -> AccountType {
1260 match self {
1261 ExtensionType::Uninitialized => AccountType::Uninitialized,
1262 ExtensionType::TransferFeeConfig
1263 | ExtensionType::MintCloseAuthority
1264 | ExtensionType::ConfidentialTransferMint
1265 | ExtensionType::DefaultAccountState
1266 | ExtensionType::NonTransferable
1267 | ExtensionType::InterestBearingConfig
1268 | ExtensionType::PermanentDelegate
1269 | ExtensionType::TransferHook
1270 | ExtensionType::ConfidentialTransferFeeConfig
1271 | ExtensionType::MetadataPointer
1272 | ExtensionType::TokenMetadata
1273 | ExtensionType::GroupPointer
1274 | ExtensionType::TokenGroup
1275 | ExtensionType::GroupMemberPointer
1276 | ExtensionType::ConfidentialMintBurn
1277 | ExtensionType::TokenGroupMember
1278 | ExtensionType::ScaledUiAmount
1279 | ExtensionType::Pausable
1280 | ExtensionType::PermissionedBurn => AccountType::Mint,
1281 ExtensionType::ImmutableOwner
1282 | ExtensionType::TransferFeeAmount
1283 | ExtensionType::ConfidentialTransferAccount
1284 | ExtensionType::MemoTransfer
1285 | ExtensionType::NonTransferableAccount
1286 | ExtensionType::TransferHookAccount
1287 | ExtensionType::CpiGuard
1288 | ExtensionType::ConfidentialTransferFeeAmount
1289 | ExtensionType::PausableAccount => AccountType::Account,
1290 #[cfg(test)]
1291 ExtensionType::VariableLenMintTest => AccountType::Mint,
1292 #[cfg(test)]
1293 ExtensionType::AccountPaddingTest => AccountType::Account,
1294 #[cfg(test)]
1295 ExtensionType::MintPaddingTest => AccountType::Mint,
1296 }
1297 }
1298
1299 pub fn get_required_init_account_extensions(mint_extension_types: &[Self]) -> Vec<Self> {
1302 let mut account_extension_types = vec![];
1303 for extension_type in mint_extension_types {
1304 match extension_type {
1305 ExtensionType::TransferFeeConfig => {
1306 account_extension_types.push(ExtensionType::TransferFeeAmount);
1307 }
1308 ExtensionType::NonTransferable => {
1309 account_extension_types.push(ExtensionType::NonTransferableAccount);
1310 account_extension_types.push(ExtensionType::ImmutableOwner);
1311 }
1312 ExtensionType::TransferHook => {
1313 account_extension_types.push(ExtensionType::TransferHookAccount);
1314 }
1315 ExtensionType::Pausable => {
1316 account_extension_types.push(ExtensionType::PausableAccount);
1317 }
1318 #[cfg(test)]
1319 ExtensionType::MintPaddingTest => {
1320 account_extension_types.push(ExtensionType::AccountPaddingTest);
1321 }
1322 _ => {}
1323 }
1324 }
1325 account_extension_types
1326 }
1327
1328 pub fn check_for_invalid_mint_extension_combinations(
1330 mint_extension_types: &[Self],
1331 ) -> Result<(), TokenError> {
1332 let mut transfer_fee_config = false;
1333 let mut confidential_transfer_mint = false;
1334 let mut confidential_transfer_fee_config = false;
1335 let mut confidential_mint_burn = false;
1336 let mut interest_bearing = false;
1337 let mut scaled_ui_amount = false;
1338 let mut non_transferable = false;
1339
1340 for extension_type in mint_extension_types {
1341 match extension_type {
1342 ExtensionType::TransferFeeConfig => transfer_fee_config = true,
1343 ExtensionType::ConfidentialTransferMint => confidential_transfer_mint = true,
1344 ExtensionType::ConfidentialTransferFeeConfig => {
1345 confidential_transfer_fee_config = true
1346 }
1347 ExtensionType::ConfidentialMintBurn => confidential_mint_burn = true,
1348 ExtensionType::InterestBearingConfig => interest_bearing = true,
1349 ExtensionType::ScaledUiAmount => scaled_ui_amount = true,
1350 ExtensionType::NonTransferable => non_transferable = true,
1351 _ => (),
1352 }
1353 }
1354
1355 if confidential_transfer_fee_config && !(transfer_fee_config && confidential_transfer_mint)
1356 {
1357 return Err(TokenError::InvalidExtensionCombination);
1358 }
1359
1360 if transfer_fee_config && confidential_transfer_mint && !confidential_transfer_fee_config {
1361 return Err(TokenError::InvalidExtensionCombination);
1362 }
1363
1364 if confidential_mint_burn && !confidential_transfer_mint {
1365 return Err(TokenError::InvalidExtensionCombination);
1366 }
1367
1368 if scaled_ui_amount && interest_bearing {
1369 return Err(TokenError::InvalidExtensionCombination);
1370 }
1371
1372 if non_transferable && confidential_transfer_mint && !confidential_mint_burn {
1373 return Err(TokenError::InvalidExtensionCombination);
1374 }
1375
1376 Ok(())
1377 }
1378}
1379
/// Trait for the base state types (mints and token accounts) that the
/// extension machinery in this module can be parameterized over.
///
/// Implementors must also provide a packed size (`PackedSizeOf`) and an
/// initialization check (`IsInitialized`).
pub trait BaseState: PackedSizeOf + IsInitialized {
    /// The [`AccountType`] associated with this base state.
    const ACCOUNT_TYPE: AccountType;
}
// Token accounts, in both their `Account` and `PodAccount` representations,
// are tagged as `AccountType::Account`; mints (`Mint` and `PodMint`) are
// tagged as `AccountType::Mint`.
impl BaseState for Account {
    const ACCOUNT_TYPE: AccountType = AccountType::Account;
}
impl BaseState for Mint {
    const ACCOUNT_TYPE: AccountType = AccountType::Mint;
}
impl BaseState for PodAccount {
    const ACCOUNT_TYPE: AccountType = AccountType::Account;
}
impl BaseState for PodMint {
    const ACCOUNT_TYPE: AccountType = AccountType::Mint;
}
1397
/// Trait that associates a struct with the [`ExtensionType`] identifying it.
pub trait Extension {
    /// The [`ExtensionType`] tag that identifies this extension.
    const TYPE: ExtensionType;
}
1404
/// Test-only mint extension made of 185 bytes (128 + 48 + 9) of padding.
///
/// NOTE(review): the size looks chosen so that a mint carrying only this
/// extension would land exactly on `Multisig::LEN` before
/// `adjust_len_for_multisig` is applied (see the `mint_with_multisig_len`
/// test) — confirm against the length constants.
#[cfg(test)]
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Pod, Zeroable)]
pub struct MintPaddingTest {
    /// First 128 bytes of padding.
    pub padding1: [u8; 128],
    /// Next 48 bytes of padding.
    pub padding2: [u8; 48],
    /// Final 9 bytes of padding.
    pub padding3: [u8; 9],
}
#[cfg(test)]
impl Extension for MintPaddingTest {
    const TYPE: ExtensionType = ExtensionType::MintPaddingTest;
}
#[cfg(test)]
impl Default for MintPaddingTest {
    /// Fills each padding field with a distinct byte value (1, 2 and 3) so
    /// the three regions can be told apart in a raw buffer.
    fn default() -> Self {
        Self {
            padding1: [1; 128],
            padding2: [2; 48],
            padding3: [3; 9],
        }
    }
}
/// Test-only token-account extension that wraps [`MintPaddingTest`], so it
/// has the same size and default contents as the mint padding extension.
#[cfg(test)]
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
pub struct AccountPaddingTest(MintPaddingTest);
#[cfg(test)]
impl Extension for AccountPaddingTest {
    const TYPE: ExtensionType = ExtensionType::AccountPaddingTest;
}
1447
/// Writes the fixed-length extension `new_extension` into the TLV data of
/// `account_info`, growing the account first when the required length
/// exceeds its current data length.
///
/// `overwrite` is forwarded to `init_extension`, which decides whether an
/// already-initialized extension of the same type may be replaced.
///
/// This function only ever grows the account; it never shrinks the data
/// buffer.
///
/// # Errors
///
/// Propagates errors from borrowing the account data, unpacking the base
/// state `S`, computing the new length, resizing the account and
/// initializing the extension.
pub fn alloc_and_serialize<S: BaseState + Pod, V: Default + Extension + Pod>(
    account_info: &AccountInfo,
    new_extension: &V,
    overwrite: bool,
) -> Result<(), ProgramError> {
    let previous_account_len = account_info.try_data_len()?;
    // Scope the immutable borrow so it is released before the account is
    // resized and mutably borrowed below.
    let new_account_len = {
        let data = account_info.try_borrow_data()?;
        let state = PodStateWithExtensions::<S>::unpack(&data)?;
        state.try_get_new_account_len::<V>()?
    };

    // Grow the account first, if needed.
    if new_account_len > previous_account_len {
        account_info.resize(new_account_len)?;
    }
    let mut buffer = account_info.try_borrow_mut_data()?;
    // NOTE(review): an account no longer than the base length presumably had
    // no account-type byte yet, so it is set before unpacking — confirm
    // against `set_account_type`.
    if previous_account_len <= BASE_ACCOUNT_LENGTH {
        set_account_type::<S>(*buffer)?;
    }
    let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;

    // Allocate the TLV entry and copy the new value into it.
    let extension = state.init_extension::<V>(overwrite)?;
    *extension = *new_extension;

    Ok(())
}
1490
/// Writes the variable-length extension `new_extension` into the TLV data of
/// `account_info`, resizing the account up or down to fit.
///
/// If the extension already exists and `overwrite` is `false`, the call
/// fails with `TokenError::ExtensionAlreadyInitialized`. If it exists and
/// `overwrite` is `true`, the existing entry is reallocated in place.
/// Otherwise a new entry is initialized.
///
/// # Errors
///
/// Besides the already-initialized case above, propagates errors from
/// borrowing the account data, unpacking the base state `S`, computing the
/// new length, resizing the account and writing the extension.
pub fn alloc_and_serialize_variable_len_extension<
    S: BaseState + Pod,
    V: Extension + VariableLenPack,
>(
    account_info: &AccountInfo,
    new_extension: &V,
    overwrite: bool,
) -> Result<(), ProgramError> {
    let previous_account_len = account_info.try_data_len()?;
    // Scope the immutable borrow so it is released before any resize or
    // mutable borrow below.
    let (new_account_len, extension_already_exists) = {
        let data = account_info.try_borrow_data()?;
        let state = PodStateWithExtensions::<S>::unpack(&data)?;
        let new_account_len =
            state.try_get_new_account_len_for_variable_len_extension(new_extension)?;
        let extension_already_exists = state.get_extension_bytes::<V>().is_ok();
        (new_account_len, extension_already_exists)
    };

    if extension_already_exists && !overwrite {
        return Err(TokenError::ExtensionAlreadyInitialized.into());
    }

    if previous_account_len < new_account_len {
        // The account grows: resize the account first, then the TLV entry,
        // then write the data.
        account_info.resize(new_account_len)?;
        let mut buffer = account_info.try_borrow_mut_data()?;
        if extension_already_exists {
            let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
            state.realloc_variable_len_extension(new_extension)?;
        } else {
            // NOTE(review): an account no longer than the base length
            // presumably had no account-type byte yet — confirm against
            // `set_account_type`.
            if previous_account_len <= BASE_ACCOUNT_LENGTH {
                set_account_type::<S>(*buffer)?;
            }
            let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
            // Allocate the entry in the TLV buffer and write the data.
            state.init_variable_len_extension(new_extension, false)?;
        }
    } else {
        // The account stays the same size or shrinks: do it the other way
        // round, writing into the TLV data first and resizing afterwards.
        let mut buffer = account_info.try_borrow_mut_data()?;
        let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
        if extension_already_exists {
            state.realloc_variable_len_extension(new_extension)?;
        } else {
            // Reached when the buffer already has room for an extension that
            // does not exist yet.
            state.init_variable_len_extension(new_extension, false)?;
        }

        let removed_bytes = previous_account_len
            .checked_sub(new_account_len)
            .ok_or(ProgramError::AccountDataTooSmall)?;
        if removed_bytes > 0 {
            // Release the mutable borrow before resizing the account.
            drop(buffer);
            account_info.resize(new_account_len)?;
        }
    }
    Ok(())
}
1559
1560#[cfg(test)]
1561mod test {
1562 use {
1563 super::*,
1564 crate::{
1565 pod::test::{TEST_POD_ACCOUNT, TEST_POD_MINT},
1566 state::test::{TEST_ACCOUNT_SLICE, TEST_MINT_SLICE},
1567 },
1568 bytemuck::Pod,
1569 solana_account_info::{
1570 Account as GetAccount, IntoAccountInfo, MAX_PERMITTED_DATA_INCREASE,
1571 },
1572 solana_address::Address,
1573 solana_nullable::MaybeNull,
1574 solana_zero_copy::unaligned::{Bool, U64},
1575 transfer_fee::test::test_transfer_fee_config,
1576 };
1577
/// Fixed-length (8-byte) extension used by the tests in this module.
///
/// It is tagged with `ExtensionType::MintPaddingTest`.
/// NOTE(review): presumably it reuses that tag to avoid a dedicated enum
/// variant — confirm.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
struct FixedLenMintTest {
    data: [u8; 8],
}
impl Extension for FixedLenMintTest {
    const TYPE: ExtensionType = ExtensionType::MintPaddingTest;
}
1587
1588 #[derive(Clone, Debug, PartialEq)]
1590 struct VariableLenMintTest {
1591 data: Vec<u8>,
1592 }
1593 impl Extension for VariableLenMintTest {
1594 const TYPE: ExtensionType = ExtensionType::VariableLenMintTest;
1595 }
1596 impl VariableLenPack for VariableLenMintTest {
1597 fn pack_into_slice(&self, dst: &mut [u8]) -> Result<(), ProgramError> {
1598 let data_start = size_of::<u64>();
1599 let end = data_start + self.data.len();
1600 if dst.len() < end {
1601 Err(ProgramError::InvalidAccountData)
1602 } else {
1603 dst[..data_start].copy_from_slice(&self.data.len().to_le_bytes());
1604 dst[data_start..end].copy_from_slice(&self.data);
1605 Ok(())
1606 }
1607 }
1608 fn unpack_from_slice(src: &[u8]) -> Result<Self, ProgramError> {
1609 let data_start = size_of::<u64>();
1610 let length = u64::from_le_bytes(src[..data_start].try_into().unwrap()) as usize;
1611 if src[data_start..data_start + length].len() != length {
1612 return Err(ProgramError::InvalidAccountData);
1613 }
1614 let data = Vec::from(&src[data_start..data_start + length]);
1615 Ok(Self { data })
1616 }
1617 fn get_packed_len(&self) -> Result<usize, ProgramError> {
1618 Ok(size_of::<u64>().saturating_add(self.data.len()))
1619 }
1620 }
1621
// Raw bytes of a mint with no extensions: the base mint bytes, zero padding,
// and a final account-type byte of 1. `unpack_opaque_buffer` checks that the
// base unpacks to `TEST_POD_MINT`.
const MINT_WITH_ACCOUNT_TYPE: &[u8] = &[
    1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 42, 0, 0, 0, 0, 0, 0, 0, 7, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
    2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ];
1631
// Raw bytes of a mint followed by one TLV entry. After the base mint, zero
// padding and the account-type byte (1) come a little-endian u16 type
// (3, 0), a little-endian u16 length (32, 0) and 32 value bytes of 1.
// `unpack_opaque_buffer` reads this entry back as a `MintCloseAuthority`
// whose close authority is the all-ones address.
const MINT_WITH_EXTENSION: &[u8] = &[
    1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 42, 0, 0, 0, 0, 0, 0, 0, 7, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
    2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 3, 0, 32, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, ];
1645
// Raw bytes of a token account followed by one TLV entry. After the base
// account bytes come the account-type byte (2), a little-endian u16 type
// (15, 0), a little-endian u16 length (1, 0) and a single value byte of 1.
// `unpack_opaque_buffer` reads this entry back as a `TransferHookAccount`
// with `transferring` set to true.
const ACCOUNT_WITH_EXTENSION: &[u8] = &[
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
    2, 2, 3, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
    4, 4, 4, 4, 4, 4, 2, 1, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 2, 15, 0, 1, 0, 1, ];
1664
// Unpacks the hand-written byte fixtures as mints and accounts and checks
// both the base state and the extension that each fixture carries.
#[test]
fn unpack_opaque_buffer() {
    // Mint fixtures: with only an account type, then with an extension.
    let state = PodStateWithExtensions::<PodMint>::unpack(MINT_WITH_ACCOUNT_TYPE).unwrap();
    assert_eq!(state.base, &TEST_POD_MINT);
    let state = PodStateWithExtensions::<PodMint>::unpack(MINT_WITH_EXTENSION).unwrap();
    assert_eq!(state.base, &TEST_POD_MINT);
    let extension = state.get_extension::<MintCloseAuthority>().unwrap();
    let close_authority: MaybeNull<Address> =
        Some(Address::new_from_array([1; 32])).try_into().unwrap();
    assert_eq!(extension.close_authority, close_authority);
    // An extension that is not in the buffer cannot be fetched.
    assert_eq!(
        state.get_extension::<TransferFeeConfig>(),
        Err(ProgramError::InvalidAccountData)
    );
    // The mint bytes do not unpack as a token account.
    assert_eq!(
        PodStateWithExtensions::<PodAccount>::unpack(MINT_WITH_EXTENSION),
        Err(ProgramError::UninitializedAccount)
    );

    // A plain base mint without any extension data also unpacks.
    let state = PodStateWithExtensions::<PodMint>::unpack(TEST_MINT_SLICE).unwrap();
    assert_eq!(state.base, &TEST_POD_MINT);

    // Same again through the mutable wrapper.
    let mut test_mint = TEST_MINT_SLICE.to_vec();
    let state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut test_mint).unwrap();
    assert_eq!(state.base, &TEST_POD_MINT);

    // Account fixture with an extension.
    let state = PodStateWithExtensions::<PodAccount>::unpack(ACCOUNT_WITH_EXTENSION).unwrap();
    assert_eq!(state.base, &TEST_POD_ACCOUNT);
    let extension = state.get_extension::<TransferHookAccount>().unwrap();
    let transferring = Bool::from(true);
    assert_eq!(extension.transferring, transferring);
    // The account bytes do not unpack as a mint.
    assert_eq!(
        PodStateWithExtensions::<PodMint>::unpack(ACCOUNT_WITH_EXTENSION),
        Err(ProgramError::InvalidAccountData)
    );

    // A plain base account without any extension data also unpacks.
    let state = PodStateWithExtensions::<PodAccount>::unpack(TEST_ACCOUNT_SLICE).unwrap();
    assert_eq!(state.base, &TEST_POD_ACCOUNT);

    // Same again through the mutable wrapper.
    let mut test_account = TEST_ACCOUNT_SLICE.to_vec();
    let state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut test_account).unwrap();
    assert_eq!(state.base, &TEST_POD_ACCOUNT);
}
1710
// Corrupts the mint fixture one field at a time and checks the error each
// corruption produces.
#[test]
fn mint_fail_unpack_opaque_buffer() {
    // Input buffer far too small to hold a mint.
    let mut buffer = vec![0, 3];
    assert_eq!(
        PodStateWithExtensions::<PodMint>::unpack(&buffer),
        Err(ProgramError::InvalidAccountData)
    );
    assert_eq!(
        PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer),
        Err(ProgramError::InvalidAccountData)
    );
    assert_eq!(
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer),
        Err(ProgramError::InvalidAccountData)
    );

    // Overwrite the account-type byte with a value that is not a mint.
    let mut buffer = MINT_WITH_EXTENSION.to_vec();
    buffer[BASE_ACCOUNT_LENGTH] = 3;
    assert_eq!(
        PodStateWithExtensions::<PodMint>::unpack(&buffer),
        Err(ProgramError::InvalidAccountData)
    );

    // Zero byte 45 of the base mint; unpacking then reports the mint as
    // uninitialized.
    let mut buffer = MINT_WITH_EXTENSION.to_vec();
    buffer[45] = 0;
    assert_eq!(
        PodStateWithExtensions::<PodMint>::unpack(&buffer),
        Err(ProgramError::UninitializedAccount)
    );

    // Put a non-zero byte into the padding right after the base mint.
    let mut buffer = MINT_WITH_EXTENSION.to_vec();
    buffer[PodMint::SIZE_OF] = 100;
    assert_eq!(
        PodStateWithExtensions::<PodMint>::unpack(&buffer),
        Err(ProgramError::InvalidAccountData)
    );

    // Change the low byte of the extension type.
    let mut buffer = MINT_WITH_EXTENSION.to_vec();
    buffer[BASE_ACCOUNT_LENGTH + 1] = 2;
    let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
    assert_eq!(
        state.get_extension::<TransferFeeConfig>(),
        Err(ProgramError::InvalidAccountData)
    );

    // Change the low byte of the extension length: too big.
    let mut buffer = MINT_WITH_EXTENSION.to_vec();
    buffer[BASE_ACCOUNT_LENGTH + 3] = 100;
    let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
    assert_eq!(
        state.get_extension::<TransferFeeConfig>(),
        Err(ProgramError::InvalidAccountData)
    );

    // Change the low byte of the extension length: too small.
    let mut buffer = MINT_WITH_EXTENSION.to_vec();
    buffer[BASE_ACCOUNT_LENGTH + 3] = 10;
    let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
    assert_eq!(
        state.get_extension::<TransferFeeConfig>(),
        Err(ProgramError::InvalidAccountData)
    );

    // Drop the last byte so the extension value is truncated.
    let buffer = &MINT_WITH_EXTENSION[..MINT_WITH_EXTENSION.len() - 1];
    let state = PodStateWithExtensions::<PodMint>::unpack(buffer).unwrap();
    assert_eq!(
        state.get_extension::<MintCloseAuthority>(),
        Err(ProgramError::InvalidAccountData)
    );
}
1787
// Corrupts the account fixture one field at a time and checks the error
// each corruption produces.
#[test]
fn account_fail_unpack_opaque_buffer() {
    // Input buffer far too small to hold an account.
    let mut buffer = vec![0, 3];
    assert_eq!(
        PodStateWithExtensions::<PodAccount>::unpack(&buffer),
        Err(ProgramError::InvalidAccountData)
    );
    assert_eq!(
        PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer),
        Err(ProgramError::InvalidAccountData)
    );
    assert_eq!(
        PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
        Err(ProgramError::InvalidAccountData)
    );

    // A base-length buffer filled with 5s is rejected as uninitialized.
    let mut buffer = vec![5; BASE_ACCOUNT_LENGTH];
    assert_eq!(
        PodStateWithExtensions::<PodAccount>::unpack(&buffer),
        Err(ProgramError::UninitializedAccount)
    );
    assert_eq!(
        PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer),
        Err(ProgramError::UninitializedAccount)
    );

    // Overwrite the account-type byte with a value that is not an account.
    let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
    buffer[BASE_ACCOUNT_LENGTH] = 3;
    assert_eq!(
        PodStateWithExtensions::<PodAccount>::unpack(&buffer),
        Err(ProgramError::InvalidAccountData)
    );

    // Zero byte 108 of the base account; unpacking then reports the account
    // as uninitialized.
    let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
    buffer[108] = 0;
    assert_eq!(
        PodStateWithExtensions::<PodAccount>::unpack(&buffer),
        Err(ProgramError::UninitializedAccount)
    );

    // Change the low byte of the extension type.
    let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
    buffer[BASE_ACCOUNT_LENGTH + 1] = 12;
    let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
    assert_eq!(
        state.get_extension::<TransferHookAccount>(),
        Err(ProgramError::InvalidAccountData),
    );

    // Change the low byte of the extension length: too big.
    let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
    buffer[BASE_ACCOUNT_LENGTH + 3] = 100;
    let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
    assert_eq!(
        state.get_extension::<TransferHookAccount>(),
        Err(ProgramError::InvalidAccountData)
    );

    // Change the low byte of the extension length to 10, which no longer
    // matches the one-byte value in the fixture.
    let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
    buffer[BASE_ACCOUNT_LENGTH + 3] = 10;
    let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
    assert_eq!(
        state.get_extension::<TransferHookAccount>(),
        Err(ProgramError::InvalidAccountData)
    );

    // Drop the last byte so the extension value is truncated.
    let buffer = &ACCOUNT_WITH_EXTENSION[..ACCOUNT_WITH_EXTENSION.len() - 1];
    let state = PodStateWithExtensions::<PodAccount>::unpack(buffer).unwrap();
    assert_eq!(
        state.get_extension::<TransferHookAccount>(),
        Err(ProgramError::InvalidAccountData)
    );
}
1868
// Feeds tiny raw TLV buffers to `get_tlv_data_info` and checks which are
// rejected and what is reported for the accepted ones.
#[test]
fn get_extension_types_with_opaque_buffer() {
    // Type 1 with a declared length (bytes 1, 1) that the buffer cannot hold.
    assert_eq!(
        get_tlv_data_info(&[1, 0, 1, 1]).unwrap_err(),
        ProgramError::InvalidAccountData,
    );
    // Type bytes (0, 1) decode to a value that is rejected as an extension
    // type.
    assert_eq!(
        get_tlv_data_info(&[0, 1, 0, 0]).unwrap_err(),
        ProgramError::InvalidAccountData,
    );
    // Type 1 with zero length is accepted and uses only the header bytes.
    assert_eq!(
        get_tlv_data_info(&[1, 0, 0, 0]).unwrap(),
        TlvDataInfo {
            extension_types: vec![ExtensionType::try_from(1).unwrap()],
            used_len: add_type_and_length_to_len(0),
        }
    );
    // Two zero bytes are treated as unused space: no extensions found.
    assert_eq!(
        get_tlv_data_info(&[0, 0]).unwrap(),
        TlvDataInfo {
            extension_types: vec![],
            used_len: 0
        }
    );
}
1898
// Walks a mint buffer sized for two extensions through its whole lifecycle:
// writing extensions before and after the base is initialized, updating
// them, and checking the raw bytes at each stage.
#[test]
fn mint_with_extension_pack_unpack() {
    let mint_size = ExtensionType::try_calculate_account_len::<PodMint>(&[
        ExtensionType::MintCloseAuthority,
        ExtensionType::TransferFeeConfig,
    ])
    .unwrap();
    let mut buffer = vec![0; mint_size];

    // A zeroed buffer has no initialized base, so a regular unpack fails.
    assert_eq!(
        PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer),
        Err(ProgramError::UninitializedAccount),
    );

    let mut state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
    // An account extension cannot be initialized on a mint.
    assert_eq!(
        state.init_extension::<TransferFeeAmount>(true),
        Err(ProgramError::InvalidAccountData),
    );

    // Write the first mint extension.
    let close_authority: MaybeNull<Address> =
        Some(Address::new_from_array([1; 32])).try_into().unwrap();
    let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
    extension.close_authority = close_authority;
    assert_eq!(
        &state.get_extension_types().unwrap(),
        &[ExtensionType::MintCloseAuthority]
    );

    // Initializing it again without `overwrite` is rejected.
    assert_eq!(
        state.init_extension::<MintCloseAuthority>(false),
        Err(ProgramError::Custom(
            TokenError::ExtensionAlreadyInitialized as u32
        ))
    );

    // The buffer now holds a mint extension, so it cannot be unpacked as an
    // account.
    assert_eq!(
        PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
        Err(ProgramError::Custom(
            TokenError::ExtensionBaseMismatch as u32
        ))
    );

    // The base is still unwritten, so a regular unpack keeps failing.
    assert_eq!(
        PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer.clone()),
        Err(ProgramError::UninitializedAccount),
    );

    // Write the base mint and the account type.
    let mut state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
    *state.base = TEST_POD_MINT;
    state.init_account_type().unwrap();

    // Check the raw buffer: base, padding, account type, the close-authority
    // entry, then zeroed space reserved for the transfer-fee entry.
    let mut expect = TEST_MINT_SLICE.to_vec();
    expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); expect.push(AccountType::Mint.into());
    expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
    expect.extend_from_slice(&(size_of::<MintCloseAuthority>() as u16).to_le_bytes());
    expect.extend_from_slice(&[1; 32]); expect.extend_from_slice(&[0; size_of::<ExtensionType>()]);
    expect.extend_from_slice(&[0; size_of::<Length>()]);
    expect.extend_from_slice(&[0; size_of::<TransferFeeConfig>()]);
    assert_eq!(expect, buffer);

    // Now that the base is initialized, `unpack_uninitialized` is rejected.
    assert_eq!(
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer.clone()),
        Err(TokenError::AlreadyInUse.into()),
    );

    // A regular unpack succeeds.
    let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();

    // Update the base.
    *state.base = TEST_POD_MINT;
    state.base.supply = (u64::from(state.base.supply) + 100).into();

    // The stored extension still has the value written earlier.
    let unpacked_extension = state.get_extension_mut::<MintCloseAuthority>().unwrap();
    assert_eq!(*unpacked_extension, MintCloseAuthority { close_authority });

    // Update the extension through the mutable reference.
    let close_authority: MaybeNull<Address> = None.try_into().unwrap();
    unpacked_extension.close_authority = close_authority;

    // Both updates are visible after unpacking again.
    let base = *state.base;
    let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
    assert_eq!(state.base, &base);
    let unpacked_extension = state.get_extension::<MintCloseAuthority>().unwrap();
    assert_eq!(*unpacked_extension, MintCloseAuthority { close_authority });

    // Check the raw buffer again with the updated base and cleared authority.
    let mut expect = vec![];
    expect.extend_from_slice(bytemuck::bytes_of(&base));
    expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); expect.push(AccountType::Mint.into());
    expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
    expect.extend_from_slice(&(size_of::<MintCloseAuthority>() as u16).to_le_bytes());
    expect.extend_from_slice(&[0; 32]);
    expect.extend_from_slice(&[0; size_of::<ExtensionType>()]);
    expect.extend_from_slice(&[0; size_of::<Length>()]);
    expect.extend_from_slice(&[0; size_of::<TransferFeeConfig>()]);
    assert_eq!(expect, buffer);

    // The initialized mint still cannot be unpacked as an account.
    assert_eq!(
        PodStateWithExtensions::<PodAccount>::unpack(&buffer),
        Err(ProgramError::UninitializedAccount),
    );

    // Initialize the second extension in the reserved space.
    let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
    let mint_transfer_fee = test_transfer_fee_config();
    let new_extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
    new_extension.transfer_fee_config_authority =
        mint_transfer_fee.transfer_fee_config_authority;
    new_extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
    new_extension.withheld_amount = mint_transfer_fee.withheld_amount;
    new_extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
    new_extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;

    assert_eq!(
        &state.get_extension_types().unwrap(),
        &[
            ExtensionType::MintCloseAuthority,
            ExtensionType::TransferFeeConfig
        ]
    );

    // Check the raw buffer with both entries populated.
    let mut expect = vec![];
    expect.extend_from_slice(bytemuck::bytes_of(&base));
    expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); expect.push(AccountType::Mint.into());
    expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
    expect.extend_from_slice(&(size_of::<MintCloseAuthority>() as u16).to_le_bytes());
    expect.extend_from_slice(&[0; 32]); expect.extend_from_slice(&(ExtensionType::TransferFeeConfig as u16).to_le_bytes());
    expect.extend_from_slice(&(size_of::<TransferFeeConfig>() as u16).to_le_bytes());
    expect.extend_from_slice(bytemuck::bytes_of(&mint_transfer_fee));
    assert_eq!(expect, buffer);

    // The buffer is full, so a third extension does not fit.
    let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
    assert_eq!(
        state.init_extension::<MintPaddingTest>(true),
        Err(ProgramError::InvalidAccountData),
    );
}
2058
// Writes the same two extensions into two buffers in opposite orders and
// checks that the raw bytes differ while the unpacked contents are equal.
#[test]
fn mint_extension_any_order() {
    let mint_size = ExtensionType::try_calculate_account_len::<PodMint>(&[
        ExtensionType::MintCloseAuthority,
        ExtensionType::TransferFeeConfig,
    ])
    .unwrap();
    let mut buffer = vec![0; mint_size];

    let mut state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
    // First buffer: close authority first, then the transfer fee config.
    let close_authority: MaybeNull<Address> =
        Some(Address::new_from_array([1; 32])).try_into().unwrap();
    let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
    extension.close_authority = close_authority;

    let mint_transfer_fee = test_transfer_fee_config();
    let extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
    extension.transfer_fee_config_authority = mint_transfer_fee.transfer_fee_config_authority;
    extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
    extension.withheld_amount = mint_transfer_fee.withheld_amount;
    extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
    extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;

    assert_eq!(
        &state.get_extension_types().unwrap(),
        &[
            ExtensionType::MintCloseAuthority,
            ExtensionType::TransferFeeConfig
        ]
    );

    // Write the base mint after the extensions.
    let mut state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
    *state.base = TEST_POD_MINT;
    state.init_account_type().unwrap();

    let mut other_buffer = vec![0; mint_size];
    let mut state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut other_buffer).unwrap();

    // Second buffer: write the base mint first this time.
    *state.base = TEST_POD_MINT;
    state.init_account_type().unwrap();

    // Then the extensions in reverse order: transfer fee config first.
    let mint_transfer_fee = test_transfer_fee_config();
    let extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
    extension.transfer_fee_config_authority = mint_transfer_fee.transfer_fee_config_authority;
    extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
    extension.withheld_amount = mint_transfer_fee.withheld_amount;
    extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
    extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;

    let close_authority: MaybeNull<Address> =
        Some(Address::new_from_array([1; 32])).try_into().unwrap();
    let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
    extension.close_authority = close_authority;

    assert_eq!(
        &state.get_extension_types().unwrap(),
        &[
            ExtensionType::TransferFeeConfig,
            ExtensionType::MintCloseAuthority
        ]
    );

    // The raw bytes differ because the entries are stored in write order.
    assert_ne!(buffer, other_buffer);
    let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
    let other_state = PodStateWithExtensions::<PodMint>::unpack(&other_buffer).unwrap();

    // The unpacked extensions and base state are identical.
    assert_eq!(
        state.get_extension::<TransferFeeConfig>().unwrap(),
        other_state.get_extension::<TransferFeeConfig>().unwrap()
    );
    assert_eq!(
        state.get_extension::<MintCloseAuthority>().unwrap(),
        other_state.get_extension::<MintCloseAuthority>().unwrap()
    );
    assert_eq!(state.base, other_state.base);
}
2144
// Checks how a mint whose natural size would equal `Multisig::LEN` is
// handled: that exact length is rejected, and the calculated length is
// bumped by one `ExtensionType` width.
#[test]
fn mint_with_multisig_len() {
    // A buffer of exactly `Multisig::LEN` bytes is not accepted as a mint.
    let mut buffer = vec![0; Multisig::LEN];
    assert_eq!(
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer),
        Err(ProgramError::InvalidAccountData),
    );
    // The length calculated for the padding extension is larger than
    // `Multisig::LEN` by the size of one extension type.
    let mint_size =
        ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MintPaddingTest])
            .unwrap();
    assert_eq!(mint_size, Multisig::LEN + size_of::<ExtensionType>());
    let mut buffer = vec![0; mint_size];

    // Write the base mint and the account type.
    let mut state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
    *state.base = TEST_POD_MINT;
    state.init_account_type().unwrap();

    // Write the padding extension, filling every byte with 1.
    let extension = state.init_extension::<MintPaddingTest>(true).unwrap();
    extension.padding1 = [1; 128];
    extension.padding2 = [1; 48];
    extension.padding3 = [1; 9];

    assert_eq!(
        &state.get_extension_types().unwrap(),
        &[ExtensionType::MintPaddingTest]
    );

    // Check the raw buffer; the extra trailing bytes read as an
    // `Uninitialized` extension type.
    let mut expect = TEST_MINT_SLICE.to_vec();
    expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); expect.push(AccountType::Mint.into());
    expect.extend_from_slice(&(ExtensionType::MintPaddingTest as u16).to_le_bytes());
    expect.extend_from_slice(&(size_of::<MintPaddingTest>() as u16).to_le_bytes());
    expect.extend_from_slice(&vec![1; size_of::<MintPaddingTest>()]);
    expect.extend_from_slice(&(ExtensionType::Uninitialized as u16).to_le_bytes());
    assert_eq!(expect, buffer);
}
2185
/// Walks a `PodAccount` buffer sized for a single `TransferFeeAmount`
/// extension through initialization, raw-byte verification, mutation through
/// the mutable view, and re-reading through the read-only view.
#[test]
fn account_with_extension_pack_unpack() {
    let account_size = ExtensionType::try_calculate_account_len::<PodAccount>(&[
        ExtensionType::TransferFeeAmount,
    ])
    .unwrap();
    let mut buffer = vec![0; account_size];

    // `unpack` on the all-zero buffer is rejected as uninitialized
    assert_eq!(
        PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer),
        Err(ProgramError::UninitializedAccount),
    );

    let mut state =
        PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
    // `TransferFeeConfig` cannot be initialized on this account state
    // (presumably because it is a mint-side extension -- confirm against
    // the extension definitions)
    assert_eq!(
        state.init_extension::<TransferFeeConfig>(true),
        Err(ProgramError::InvalidAccountData),
    );
    // initializing and writing the `TransferFeeAmount` extension succeeds
    let withheld_amount = U64::from(u64::MAX);
    let extension = state.init_extension::<TransferFeeAmount>(true).unwrap();
    extension.withheld_amount = withheld_amount;

    assert_eq!(
        &state.get_extension_types().unwrap(),
        &[ExtensionType::TransferFeeAmount]
    );

    // a copy of the buffer still fails `unpack`: the base account has not
    // been written at this point
    assert_eq!(
        PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer.clone()),
        Err(ProgramError::UninitializedAccount),
    );

    // write the base account and the account type
    let mut state =
        PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
    *state.base = TEST_POD_ACCOUNT;
    state.init_account_type().unwrap();
    let base = *state.base;

    // check the raw buffer: base bytes, account type byte, then the extension
    // type, length and value, each little-endian
    let mut expect = TEST_ACCOUNT_SLICE.to_vec();
    expect.push(AccountType::Account.into());
    expect.extend_from_slice(&(ExtensionType::TransferFeeAmount as u16).to_le_bytes());
    expect.extend_from_slice(&(size_of::<TransferFeeAmount>() as u16).to_le_bytes());
    expect.extend_from_slice(&u64::from(withheld_amount).to_le_bytes());
    assert_eq!(expect, buffer);

    // now that base and account type are written, `unpack` succeeds
    let mut state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
    assert_eq!(state.base, &base);
    assert_eq!(
        &state.get_extension_types().unwrap(),
        &[ExtensionType::TransferFeeAmount]
    );

    // update the base through the mutable view
    *state.base = TEST_POD_ACCOUNT;
    state.base.amount = (u64::from(state.base.amount) + 100).into();

    // the extension still holds the value written earlier
    let unpacked_extension = state.get_extension_mut::<TransferFeeAmount>().unwrap();
    assert_eq!(*unpacked_extension, TransferFeeAmount { withheld_amount });

    // update the extension through the mutable view
    let withheld_amount = U64::from(u32::MAX as u64);
    unpacked_extension.withheld_amount = withheld_amount;

    // both updates are visible through the read-only view
    let base = *state.base;
    let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
    assert_eq!(state.base, &base);
    let unpacked_extension = state.get_extension::<TransferFeeAmount>().unwrap();
    assert_eq!(*unpacked_extension, TransferFeeAmount { withheld_amount });

    // check the raw buffer again, this time built from the updated base
    let mut expect = vec![];
    expect.extend_from_slice(bytemuck::bytes_of(&base));
    expect.push(AccountType::Account.into());
    expect.extend_from_slice(&(ExtensionType::TransferFeeAmount as u16).to_le_bytes());
    expect.extend_from_slice(&(size_of::<TransferFeeAmount>() as u16).to_le_bytes());
    expect.extend_from_slice(&u64::from(withheld_amount).to_le_bytes());
    assert_eq!(expect, buffer);

    // the same bytes are rejected when unpacked as a mint
    assert_eq!(
        PodStateWithExtensions::<PodMint>::unpack(&buffer),
        Err(ProgramError::InvalidAccountData),
    );
}
2280
/// A buffer of exactly `Multisig::LEN` bytes must not unpack as an account,
/// while an account carrying `AccountPaddingTest` is laid out with one extra
/// extension-type slot beyond `Multisig::LEN`.
#[test]
fn account_with_multisig_len() {
    let mut multisig_sized = vec![0; Multisig::LEN];
    assert_eq!(
        PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut multisig_sized),
        Err(ProgramError::InvalidAccountData),
    );

    let extension_types = [ExtensionType::AccountPaddingTest];
    let padded_len =
        ExtensionType::try_calculate_account_len::<PodAccount>(&extension_types).unwrap();
    assert_eq!(padded_len, Multisig::LEN + size_of::<ExtensionType>());
    let mut bytes = vec![0; padded_len];

    // Write the base account and its account type.
    let mut account_state =
        PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut bytes).unwrap();
    *account_state.base = TEST_POD_ACCOUNT;
    account_state.init_account_type().unwrap();

    // Fill every padding field of the extension with the byte value 2.
    let padding = account_state
        .init_extension::<AccountPaddingTest>(true)
        .unwrap();
    padding.0.padding1 = [2; 128];
    padding.0.padding2 = [2; 48];
    padding.0.padding3 = [2; 9];

    assert_eq!(
        &account_state.get_extension_types().unwrap(),
        &[ExtensionType::AccountPaddingTest]
    );

    // Expected raw layout: base, account type, extension type, extension
    // length, the value bytes, and a trailing `Uninitialized` type.
    let mut expected_bytes = TEST_ACCOUNT_SLICE.to_vec();
    expected_bytes.push(AccountType::Account.into());
    expected_bytes.extend_from_slice(&(ExtensionType::AccountPaddingTest as u16).to_le_bytes());
    expected_bytes.extend_from_slice(&(size_of::<AccountPaddingTest>() as u16).to_le_bytes());
    let value_end = expected_bytes.len() + size_of::<AccountPaddingTest>();
    expected_bytes.resize(value_end, 2);
    expected_bytes.extend_from_slice(&(ExtensionType::Uninitialized as u16).to_le_bytes());
    assert_eq!(expected_bytes, bytes);
}
2321
/// Exercises `set_account_type` on account- and mint-shaped buffers with
/// different trailing bytes, checking when it succeeds and when it reports
/// `InvalidAccountData`.
#[test]
fn test_set_account_type() {
    // account with a buffer long enough for the account type and an
    // `ImmutableOwner` extension; unpack only works once the type is set
    let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
    let account_len = ExtensionType::try_calculate_account_len::<PodAccount>(&[
        ExtensionType::ImmutableOwner,
    ])
    .unwrap();
    buffer.resize(account_len, 0);
    let err = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap_err();
    assert_eq!(err, ProgramError::InvalidAccountData);
    set_account_type::<PodAccount>(&mut buffer).unwrap();
    let mut state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
    assert_eq!(state.base, &TEST_POD_ACCOUNT);
    assert_eq!(state.account_type[0], AccountType::Account as u8);
    // the extension can be initialized in the remaining space
    state.init_extension::<ImmutableOwner>(true).unwrap();

    // account with only two extra zero bytes after the base
    let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
    buffer.extend_from_slice(&[0; 2]);
    let err = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap_err();
    assert_eq!(err, ProgramError::InvalidAccountData);
    set_account_type::<PodAccount>(&mut buffer).unwrap();
    let state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
    assert_eq!(state.base, &TEST_POD_ACCOUNT);
    assert_eq!(state.account_type[0], AccountType::Account as u8);

    // account whose trailing bytes are `[2, 0]`: it already unpacks before
    // `set_account_type` is called, and the call still succeeds (presumably
    // 2 is the `Account` type value -- confirm against `AccountType`)
    let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
    buffer.extend_from_slice(&[2, 0]);
    let _ = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
    set_account_type::<PodAccount>(&mut buffer).unwrap();
    let state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
    assert_eq!(state.base, &TEST_POD_ACCOUNT);
    assert_eq!(state.account_type[0], AccountType::Account as u8);

    // account whose trailing bytes are `[1, 0]`: both unpacking and setting
    // the account type are rejected
    let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
    buffer.extend_from_slice(&[1, 0]);
    let err = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap_err();
    assert_eq!(err, ProgramError::InvalidAccountData);
    let err = set_account_type::<PodAccount>(&mut buffer).unwrap_err();
    assert_eq!(err, ProgramError::InvalidAccountData);

    // mint with a buffer long enough for the account type and a
    // `MintCloseAuthority` extension
    let mut buffer = TEST_MINT_SLICE.to_vec();
    let mint_len = ExtensionType::try_calculate_account_len::<PodMint>(&[
        ExtensionType::MintCloseAuthority,
    ])
    .unwrap();
    buffer.resize(mint_len, 0);
    let err = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap_err();
    assert_eq!(err, ProgramError::InvalidAccountData);
    set_account_type::<PodMint>(&mut buffer).unwrap();
    let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
    assert_eq!(state.base, &TEST_POD_MINT);
    assert_eq!(state.account_type[0], AccountType::Mint as u8);
    state.init_extension::<MintCloseAuthority>(true).unwrap();

    // mint zero-padded by the account/mint size difference, plus two extra
    // zero bytes
    let mut buffer = TEST_MINT_SLICE.to_vec();
    buffer.extend_from_slice(&[0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
    buffer.extend_from_slice(&[0; 2]);
    let err = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap_err();
    assert_eq!(err, ProgramError::InvalidAccountData);
    set_account_type::<PodMint>(&mut buffer).unwrap();
    let state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
    assert_eq!(state.base, &TEST_POD_MINT);
    assert_eq!(state.account_type[0], AccountType::Mint as u8);

    // padded mint whose trailing bytes are `[1, 0]`: setting the account
    // type succeeds
    let mut buffer = TEST_MINT_SLICE.to_vec();
    buffer.extend_from_slice(&[0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
    buffer.extend_from_slice(&[1, 0]);
    set_account_type::<PodMint>(&mut buffer).unwrap();
    let state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
    assert_eq!(state.base, &TEST_POD_MINT);
    assert_eq!(state.account_type[0], AccountType::Mint as u8);

    // padded mint whose trailing bytes are `[2, 0]`: both unpacking and
    // setting the account type are rejected
    let mut buffer = TEST_MINT_SLICE.to_vec();
    buffer.extend_from_slice(&[0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
    buffer.extend_from_slice(&[2, 0]);
    let err = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap_err();
    assert_eq!(err, ProgramError::InvalidAccountData);
    let err = set_account_type::<PodMint>(&mut buffer).unwrap_err();
    assert_eq!(err, ProgramError::InvalidAccountData);
}
2416
/// `set_account_type` must refuse to stamp a buffer with the type that does
/// not match its base data.
#[test]
fn test_set_account_type_wrongly() {
    // Account base data plus two zero bytes, typed as a mint.
    let mut account_bytes = TEST_ACCOUNT_SLICE.to_vec();
    account_bytes.extend_from_slice(&[0; 2]);
    assert_eq!(
        set_account_type::<PodMint>(&mut account_bytes).unwrap_err(),
        ProgramError::InvalidAccountData
    );

    // Mint base data padded by the account/mint size difference plus two
    // zero bytes, typed as an account.
    let mut mint_bytes = TEST_MINT_SLICE.to_vec();
    mint_bytes.extend_from_slice(&[0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
    mint_bytes.extend_from_slice(&[0; 2]);
    assert_eq!(
        set_account_type::<PodAccount>(&mut mint_bytes).unwrap_err(),
        ProgramError::InvalidAccountData
    );
}
2432
/// Table-driven check of `ExtensionType::get_required_init_account_extensions`.
#[test]
fn test_get_required_init_account_extensions() {
    // Each case pairs the mint extensions handed to the method with the
    // account extensions it is expected to return, in order.
    let cases = [
        // neither input yields an account extension
        (
            vec![
                ExtensionType::MintCloseAuthority,
                ExtensionType::Uninitialized,
            ],
            vec![],
        ),
        // only the first input yields an account extension
        (
            vec![
                ExtensionType::TransferFeeConfig,
                ExtensionType::MintCloseAuthority,
            ],
            vec![ExtensionType::TransferFeeAmount],
        ),
        // both inputs yield an account extension
        (
            vec![
                ExtensionType::TransferFeeConfig,
                ExtensionType::MintPaddingTest,
            ],
            vec![
                ExtensionType::TransferFeeAmount,
                ExtensionType::AccountPaddingTest,
            ],
        ),
        // duplicated input produces duplicated output
        (
            vec![
                ExtensionType::TransferFeeConfig,
                ExtensionType::TransferFeeConfig,
            ],
            vec![
                ExtensionType::TransferFeeAmount,
                ExtensionType::TransferFeeAmount,
            ],
        ),
    ];

    for (mint_extensions, expected_account_extensions) in &cases {
        assert_eq!(
            ExtensionType::get_required_init_account_extensions(mint_extensions),
            *expected_account_extensions
        );
    }
}
2481
/// A mint sized for zero extensions accepts base data but no extension, and
/// its bytes end up equal to `TEST_MINT_SLICE`.
#[test]
fn mint_without_extensions() {
    let mint_len = ExtensionType::try_calculate_account_len::<PodMint>(&[]).unwrap();
    let mut bytes = vec![0; mint_len];

    // The same buffer is rejected when treated as an account.
    assert_eq!(
        PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut bytes),
        Err(ProgramError::InvalidAccountData),
    );

    // Write the base mint and its account type.
    let mut mint_state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut bytes).unwrap();
    *mint_state.base = TEST_POD_MINT;
    mint_state.init_account_type().unwrap();

    // Initializing an extension is refused on this buffer.
    assert_eq!(
        mint_state.init_extension::<TransferFeeConfig>(true),
        Err(ProgramError::InvalidAccountData),
    );

    assert_eq!(TEST_MINT_SLICE, bytes);
}
2505
/// A freshly initialized `MintPaddingTest` extension carries its non-zero
/// default field values rather than zeroes.
#[test]
fn test_init_nonzero_default() {
    let extension_types = [ExtensionType::MintPaddingTest];
    let mint_len = ExtensionType::try_calculate_account_len::<PodMint>(&extension_types).unwrap();
    let mut bytes = vec![0; mint_len];

    let mut mint_state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut bytes).unwrap();
    *mint_state.base = TEST_POD_MINT;
    mint_state.init_account_type().unwrap();

    let padding = mint_state.init_extension::<MintPaddingTest>(true).unwrap();
    assert_eq!(padding.padding1, [1; 128]);
    assert_eq!(padding.padding2, [2; 48]);
    assert_eq!(padding.padding3, [3; 9]);
}
2521
/// Checks error handling when the buffer is shorter than an extension needs,
/// and which short buffer lengths can still be unpacked.
#[test]
fn test_init_buffer_too_small() {
    let mint_size = ExtensionType::try_calculate_account_len::<PodMint>(&[
        ExtensionType::MintCloseAuthority,
    ])
    .unwrap();
    // one byte short of the length computed for a `MintCloseAuthority` mint:
    // the buffer unpacks, but the extension cannot be initialized
    let mut buffer = vec![0; mint_size - 1];
    let mut state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
    let err = state
        .init_extension::<MintCloseAuthority>(true)
        .unwrap_err();
    assert_eq!(err, ProgramError::InvalidAccountData);

    // hand-write bytes 0 and 2 of the TLV region (presumably an extension
    // type of 3 and a length of 32 -- confirm against the TLV layout);
    // reading the extension from the short buffer then fails
    state.tlv_data[0] = 3;
    state.tlv_data[2] = 32;
    let err = state.get_extension_mut::<MintCloseAuthority>().unwrap_err();
    assert_eq!(err, ProgramError::InvalidAccountData);

    // a buffer of `PodMint::SIZE_OF + 2` bytes cannot be unpacked at all
    let mut buffer = vec![0; PodMint::SIZE_OF + 2];
    let err =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap_err();
    assert_eq!(err, ProgramError::InvalidAccountData);

    // `BASE_ACCOUNT_LENGTH + 3` bytes unpack; looking up an extension fails
    // and no extension types are reported
    let mut buffer = vec![0; BASE_ACCOUNT_LENGTH + 3];
    let mut state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
    let err = state.get_extension_mut::<MintCloseAuthority>().unwrap_err();
    assert_eq!(err, ProgramError::InvalidAccountData);

    assert_eq!(state.get_extension_types().unwrap(), vec![]);

    // `BASE_ACCOUNT_LENGTH + 2` bytes also unpack and report no extension
    // types
    let mut buffer = vec![0; BASE_ACCOUNT_LENGTH + 2];
    let state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
    assert_eq!(state.get_extension_types().unwrap(), []);
}
2561
/// `ImmutableOwner` is reported as missing until initialized; afterwards the
/// TLV data lists it with a used length equal to an empty value's type and
/// length header.
#[test]
fn test_extension_with_no_data() {
    let extension_types = [ExtensionType::ImmutableOwner];
    let account_len =
        ExtensionType::try_calculate_account_len::<PodAccount>(&extension_types).unwrap();
    let mut bytes = vec![0; account_len];

    let mut account_state =
        PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut bytes).unwrap();
    *account_state.base = TEST_POD_ACCOUNT;
    account_state.init_account_type().unwrap();

    // Before initialization the lookup reports `ExtensionNotFound`.
    let missing = account_state.get_extension::<ImmutableOwner>().unwrap_err();
    assert_eq!(
        missing,
        ProgramError::Custom(TokenError::ExtensionNotFound as u32)
    );

    account_state.init_extension::<ImmutableOwner>(true).unwrap();

    let first_type = get_first_extension_type(account_state.tlv_data).unwrap();
    assert_eq!(first_type, Some(ExtensionType::ImmutableOwner));

    let expected_info = TlvDataInfo {
        extension_types: vec![ExtensionType::ImmutableOwner],
        used_len: add_type_and_length_to_len(0),
    };
    assert_eq!(
        get_tlv_data_info(account_state.tlv_data).unwrap(),
        expected_info
    );
}
2593
/// `try_calculate_account_len` rejects a list containing
/// `VariableLenMintTest` with `InvalidArgument`.
#[test]
fn fail_account_len_with_metadata() {
    let extension_types = [
        ExtensionType::MintCloseAuthority,
        ExtensionType::VariableLenMintTest,
        ExtensionType::TransferFeeConfig,
    ];
    let err = ExtensionType::try_calculate_account_len::<PodMint>(&extension_types).unwrap_err();
    assert_eq!(err, ProgramError::InvalidArgument);
}
2606
/// Covers `init_variable_len_extension` on a buffer sized for exactly one
/// one-byte value: first write, repeated write, and writes of other sizes.
#[test]
fn alloc() {
    let one_byte = VariableLenMintTest { data: vec![1] };
    let packed_len = one_byte.get_packed_len().unwrap();
    let total_len =
        BASE_ACCOUNT_LENGTH + size_of::<AccountType>() + add_type_and_length_to_len(packed_len);
    let mut bytes = vec![0; total_len];
    let mut mint_state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut bytes).unwrap();
    mint_state
        .init_variable_len_extension(&one_byte, false)
        .unwrap();

    // A second write with the flag set to `false` is refused.
    let err = mint_state
        .init_variable_len_extension(&one_byte, false)
        .unwrap_err();
    assert_eq!(err, TokenError::ExtensionAlreadyInitialized.into());

    // With the flag set to `true`, the same value is accepted again.
    mint_state
        .init_variable_len_extension(&one_byte, true)
        .unwrap();

    // A shorter value is refused even with the flag set to `true`.
    let empty = VariableLenMintTest { data: vec![] };
    let err = mint_state
        .init_variable_len_extension(&empty, true)
        .unwrap_err();
    assert_eq!(err, TokenError::InvalidLengthForAlloc.into());

    // A longer value is refused with `InvalidAccountData`.
    let two_bytes = VariableLenMintTest { data: vec![1, 2] };
    let err = mint_state
        .init_variable_len_extension(&two_bytes, true)
        .unwrap_err();
    assert_eq!(err, ProgramError::InvalidAccountData);
}
2649
/// Exercises `realloc_variable_len_extension` growing and shrinking a
/// `VariableLenMintTest` value that shares the buffer with a
/// `MetadataPointer` extension.
#[test]
fn realloc() {
    let small_variable_len = VariableLenMintTest {
        data: vec![1, 2, 3],
    };
    let base_variable_len = VariableLenMintTest {
        data: vec![1, 2, 3, 4],
    };
    let big_variable_len = VariableLenMintTest {
        data: vec![1, 2, 3, 4, 5],
    };
    let too_big_variable_len = VariableLenMintTest {
        data: vec![1, 2, 3, 4, 5, 6],
    };
    // size the buffer for a `MetadataPointer` plus the `big` value, so `big`
    // is the largest of the four values the buffer was sized for
    let account_size =
        ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MetadataPointer])
            .unwrap()
            + add_type_and_length_to_len(big_variable_len.get_packed_len().unwrap());
    let mut buffer = vec![0; account_size];
    let mut state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();

    // initialize the variable-length extension first, then the
    // `MetadataPointer`, filling the pointer's fields with an all-0xFF address
    state
        .init_variable_len_extension(&base_variable_len, false)
        .unwrap();
    let max_pubkey: MaybeNull<Address> =
        Some(Address::new_from_array([255; 32])).try_into().unwrap();
    let extension = state.init_extension::<MetadataPointer>(false).unwrap();
    extension.authority = max_pubkey;
    extension.metadata_address = max_pubkey;

    // growing to the `big` value works and leaves the `MetadataPointer`
    // fields unchanged
    state
        .realloc_variable_len_extension(&big_variable_len)
        .unwrap();
    let extension = state
        .get_variable_len_extension::<VariableLenMintTest>()
        .unwrap();
    assert_eq!(extension, big_variable_len);
    let extension = state.get_extension::<MetadataPointer>().unwrap();
    assert_eq!(extension.authority, max_pubkey);
    assert_eq!(extension.metadata_address, max_pubkey);

    // shrinking to the `small` value works, leaves the `MetadataPointer`
    // fields unchanged, and the last `diff` bytes of the buffer read as zero
    state
        .realloc_variable_len_extension(&small_variable_len)
        .unwrap();
    let extension = state
        .get_variable_len_extension::<VariableLenMintTest>()
        .unwrap();
    assert_eq!(extension, small_variable_len);
    let extension = state.get_extension::<MetadataPointer>().unwrap();
    assert_eq!(extension.authority, max_pubkey);
    assert_eq!(extension.metadata_address, max_pubkey);
    let diff = big_variable_len.get_packed_len().unwrap()
        - small_variable_len.get_packed_len().unwrap();
    assert_eq!(&buffer[account_size - diff..account_size], vec![0; diff]);

    // a value one element longer than `big` does not fit and is rejected
    let mut state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
    assert_eq!(
        state
            .realloc_variable_len_extension(&too_big_variable_len)
            .unwrap_err(),
        ProgramError::InvalidAccountData,
    );
}
2720
/// Checks `try_get_account_len` and
/// `try_get_new_account_len_for_variable_len_extension` before and after a
/// `VariableLenMintTest` extension is initialized.
#[test]
fn account_len() {
    let small_variable_len = VariableLenMintTest {
        data: vec![20, 30, 40],
    };
    let variable_len = VariableLenMintTest {
        data: vec![20, 30, 40, 50],
    };
    let big_variable_len = VariableLenMintTest {
        data: vec![20, 30, 40, 50, 60],
    };
    let value_len = variable_len.get_packed_len().unwrap();
    let account_size =
        BASE_ACCOUNT_LENGTH + size_of::<AccountType>() + add_type_and_length_to_len(value_len);
    let mut buffer = vec![0; account_size];
    let mut state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();

    // with no extension initialized, the reported length is the bare mint size
    let current_len = state.try_get_account_len().unwrap();
    assert_eq!(current_len, PodMint::SIZE_OF);
    // adding the extension requires the base-and-type length plus the
    // extension's type/length header and value
    let new_len = state
        .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(
            &variable_len,
        )
        .unwrap();
    assert_eq!(
        new_len,
        BASE_ACCOUNT_AND_TYPE_LENGTH.saturating_add(add_type_and_length_to_len(value_len))
    );

    // once initialized, the current length matches that prediction
    state
        .init_variable_len_extension::<VariableLenMintTest>(&variable_len, false)
        .unwrap();
    let current_len = state.try_get_account_len().unwrap();
    assert_eq!(current_len, new_len);

    // a value with one element fewer needs one byte less
    let new_len = state
        .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(
            &small_variable_len,
        )
        .unwrap();
    assert_eq!(current_len.checked_sub(new_len).unwrap(), 1);

    // a value with one element more needs one byte more
    let new_len = state
        .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(
            &big_variable_len,
        )
        .unwrap();
    assert_eq!(new_len.checked_sub(current_len).unwrap(), 1);

    // a value of the same size needs no change
    let new_len = state
        .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(
            &variable_len,
        )
        .unwrap();
    assert_eq!(new_len, current_len);
}
2783
2784 struct SolanaAccountData {
2787 data: Vec<u8>,
2788 lamports: u64,
2789 owner: Address,
2790 }
2791 impl SolanaAccountData {
2792 fn new(account_data: &[u8]) -> Self {
2795 let mut data = vec![];
2796 data.extend_from_slice(&(account_data.len() as u64).to_le_bytes());
2797 data.extend_from_slice(account_data);
2798 data.extend_from_slice(&[0; MAX_PERMITTED_DATA_INCREASE]);
2799 Self {
2800 data,
2801 lamports: 10,
2802 owner: Address::new_unique(),
2803 }
2804 }
2805
2806 fn data(&self) -> &[u8] {
2809 let start = size_of::<u64>();
2810 let len = self.len();
2811 &self.data[start..start + len]
2812 }
2813
2814 fn len(&self) -> usize {
2816 self.data
2817 .get(..size_of::<u64>())
2818 .and_then(|slice| slice.try_into().ok())
2819 .map(u64::from_le_bytes)
2820 .unwrap() as usize
2821 }
2822 }
impl GetAccount for SolanaAccountData {
    /// Hands out mutable access to the lamports, the account-data window
    /// (the bytes after the 8-byte length prefix, limited to the stored
    /// length), a reference to the owner, and `false` for the final flag
    /// (presumably "executable" -- confirm against the `GetAccount` trait).
    fn get(&mut self) -> (&mut u64, &mut [u8], &Address, bool) {
        let start = size_of::<u64>();
        // read the length before building the tuple: `self.len()` borrows
        // all of `self`, which would conflict with the mutable field borrows
        // taken below
        let len = self.len();
        (
            &mut self.lamports,
            &mut self.data[start..start + len],
            &self.owner,
            false,
        )
    }
}
2836
/// Starting from a buffer of only `PodMint::SIZE_OF` bytes,
/// `alloc_and_serialize` grows the account to hold a `FixedLenMintTest`
/// extension, then is called twice more with the final flag set to `true`
/// and `false`.
#[test]
fn alloc_new_fixed_len_tlv_in_account_info_from_base_size() {
    let extension_value = FixedLenMintTest {
        data: [1, 2, 3, 4, 5, 6, 7, 8],
    };
    let extension_len = size_of::<FixedLenMintTest>();
    let mut base_bytes = vec![0; PodMint::SIZE_OF];
    let mint_state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut base_bytes).unwrap();
    *mint_state.base = TEST_POD_MINT;

    let mut account_data = SolanaAccountData::new(&base_bytes);
    let address = Address::new_unique();
    let info = (&address, &mut account_data).into_account_info();

    // First write: the account grows to base-and-type length plus one TLV entry.
    alloc_and_serialize::<PodMint, _>(&info, &extension_value, false).unwrap();
    let expected_len = BASE_ACCOUNT_AND_TYPE_LENGTH + add_type_and_length_to_len(extension_len);
    assert_eq!(account_data.len(), expected_len);
    let unpacked = PodStateWithExtensions::<PodMint>::unpack(account_data.data()).unwrap();
    assert_eq!(
        unpacked.get_extension::<FixedLenMintTest>().unwrap(),
        &extension_value,
    );

    // Writing again with the final flag set to `true` succeeds.
    let info = (&address, &mut account_data).into_account_info();
    alloc_and_serialize::<PodMint, _>(&info, &extension_value, true).unwrap();

    // Writing again with the final flag set to `false` is refused.
    let info = (&address, &mut account_data).into_account_info();
    let err = alloc_and_serialize::<PodMint, _>(&info, &extension_value, false).unwrap_err();
    assert_eq!(err, TokenError::ExtensionAlreadyInitialized.into());
}
2873
/// Starting from a buffer of only `PodMint::SIZE_OF` bytes,
/// `alloc_and_serialize_variable_len_extension` grows the account to hold a
/// `VariableLenMintTest` value, then is called twice more with the final flag
/// set to `true` and `false`.
#[test]
fn alloc_new_variable_len_tlv_in_account_info_from_base_size() {
    let extension_value = VariableLenMintTest { data: vec![20, 99] };
    let packed_len = extension_value.get_packed_len().unwrap();
    let mut base_bytes = vec![0; PodMint::SIZE_OF];
    let mint_state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut base_bytes).unwrap();
    *mint_state.base = TEST_POD_MINT;

    let mut account_data = SolanaAccountData::new(&base_bytes);
    let address = Address::new_unique();
    let info = (&address, &mut account_data).into_account_info();

    // First write: the account grows to base-and-type length plus one TLV entry.
    alloc_and_serialize_variable_len_extension::<PodMint, _>(&info, &extension_value, false)
        .unwrap();
    let expected_len = BASE_ACCOUNT_AND_TYPE_LENGTH + add_type_and_length_to_len(packed_len);
    assert_eq!(account_data.len(), expected_len);
    let unpacked = PodStateWithExtensions::<PodMint>::unpack(account_data.data()).unwrap();
    let stored = unpacked
        .get_variable_len_extension::<VariableLenMintTest>()
        .unwrap();
    assert_eq!(stored, extension_value);

    // Writing again with the final flag set to `true` succeeds.
    let info = (&address, &mut account_data).into_account_info();
    alloc_and_serialize_variable_len_extension::<PodMint, _>(&info, &extension_value, true)
        .unwrap();

    // Writing again with the final flag set to `false` is refused.
    let info = (&address, &mut account_data).into_account_info();
    let err =
        alloc_and_serialize_variable_len_extension::<PodMint, _>(&info, &extension_value, false)
            .unwrap_err();
    assert_eq!(err, TokenError::ExtensionAlreadyInitialized.into());
}
2925
/// Runs `alloc_and_serialize` for a `FixedLenMintTest` extension on a mint
/// that already holds a `GroupPointer` extension, and checks that both
/// extensions can be read back afterwards.
#[test]
fn alloc_new_fixed_len_tlv_in_account_info_from_extended_size() {
    let fixed_len = FixedLenMintTest {
        data: [1, 2, 3, 4, 5, 6, 7, 8],
    };
    let value_len = size_of::<FixedLenMintTest>();
    // buffer sized for a `GroupPointer` mint plus one more TLV entry holding
    // the fixed-length value
    let account_size =
        ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::GroupPointer])
            .unwrap()
            + add_type_and_length_to_len(value_len);
    let mut buffer = vec![0; account_size];
    let mut state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
    *state.base = TEST_POD_MINT;
    state.init_account_type().unwrap();

    // pre-existing extension whose contents must survive the allocation
    let test_key: MaybeNull<Address> =
        Some(Address::new_from_array([20; 32])).try_into().unwrap();
    let extension = state.init_extension::<GroupPointer>(false).unwrap();
    extension.authority = test_key;
    extension.group_address = test_key;

    // wrap the bytes in the test account-data helper and build an account
    // info over it
    let mut data = SolanaAccountData::new(&buffer);
    let key = Address::new_unique();
    let account_info = (&key, &mut data).into_account_info();

    // first write succeeds; the account length covers both TLV entries and
    // both extensions read back with the values written
    alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, false).unwrap();
    let new_account_len = BASE_ACCOUNT_AND_TYPE_LENGTH
        + add_type_and_length_to_len(value_len)
        + add_type_and_length_to_len(size_of::<GroupPointer>());
    assert_eq!(data.len(), new_account_len);
    let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
    assert_eq!(
        state.get_extension::<FixedLenMintTest>().unwrap(),
        &fixed_len,
    );
    let extension = state.get_extension::<GroupPointer>().unwrap();
    assert_eq!(extension.authority, test_key);
    assert_eq!(extension.group_address, test_key);

    // writing again with the final flag set to `true` succeeds
    let account_info = (&key, &mut data).into_account_info();
    alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, true).unwrap();

    // writing again with the final flag set to `false` is refused
    let account_info = (&key, &mut data).into_account_info();
    assert_eq!(
        alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, false).unwrap_err(),
        TokenError::ExtensionAlreadyInitialized.into()
    );
}
2977
/// Runs `alloc_and_serialize_variable_len_extension` for a
/// `VariableLenMintTest` value on a mint that already holds a
/// `MetadataPointer` extension, and checks that both extensions can be read
/// back afterwards.
#[test]
fn alloc_new_variable_len_tlv_in_account_info_from_extended_size() {
    let variable_len = VariableLenMintTest { data: vec![42, 6] };
    let value_len = variable_len.get_packed_len().unwrap();
    // buffer sized for a `MetadataPointer` mint plus one more TLV entry
    // holding the variable-length value
    let account_size =
        ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MetadataPointer])
            .unwrap()
            + add_type_and_length_to_len(value_len);
    let mut buffer = vec![0; account_size];
    let mut state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
    *state.base = TEST_POD_MINT;
    state.init_account_type().unwrap();

    // pre-existing extension whose contents must survive the allocation
    let test_key: MaybeNull<Address> =
        Some(Address::new_from_array([20; 32])).try_into().unwrap();
    let extension = state.init_extension::<MetadataPointer>(false).unwrap();
    extension.authority = test_key;
    extension.metadata_address = test_key;

    // wrap the bytes in the test account-data helper and build an account
    // info over it
    let mut data = SolanaAccountData::new(&buffer);
    let key = Address::new_unique();
    let account_info = (&key, &mut data).into_account_info();

    // first write succeeds; the account length covers both TLV entries and
    // both extensions read back with the values written
    alloc_and_serialize_variable_len_extension::<PodMint, _>(
        &account_info,
        &variable_len,
        false,
    )
    .unwrap();
    let new_account_len = BASE_ACCOUNT_AND_TYPE_LENGTH
        + add_type_and_length_to_len(value_len)
        + add_type_and_length_to_len(size_of::<MetadataPointer>());
    assert_eq!(data.len(), new_account_len);
    let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
    assert_eq!(
        state
            .get_variable_len_extension::<VariableLenMintTest>()
            .unwrap(),
        variable_len
    );
    let extension = state.get_extension::<MetadataPointer>().unwrap();
    assert_eq!(extension.authority, test_key);
    assert_eq!(extension.metadata_address, test_key);

    // writing again with the final flag set to `true` succeeds
    let account_info = (&key, &mut data).into_account_info();
    alloc_and_serialize_variable_len_extension::<PodMint, _>(
        &account_info,
        &variable_len,
        true,
    )
    .unwrap();

    // writing again with the final flag set to `false` is refused
    let account_info = (&key, &mut data).into_account_info();
    assert_eq!(
        alloc_and_serialize_variable_len_extension::<PodMint, _>(
            &account_info,
            &variable_len,
            false,
        )
        .unwrap_err(),
        TokenError::ExtensionAlreadyInitialized.into()
    );
}
3044
/// Rewrites an existing `VariableLenMintTest` value through
/// `alloc_and_serialize_variable_len_extension` with a shorter, a longer, and
/// a same-length value, checking each time that the neighbouring
/// `MetadataPointer` is intact and the account length matches the state.
#[test]
fn realloc_variable_len_tlv_in_account_info() {
    let variable_len = VariableLenMintTest {
        data: vec![1, 2, 3, 4, 5],
    };
    let alloc_size = variable_len.get_packed_len().unwrap();
    let account_size =
        ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MetadataPointer])
            .unwrap()
            + add_type_and_length_to_len(alloc_size);
    let mut buffer = vec![0; account_size];
    let mut state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
    *state.base = TEST_POD_MINT;
    state.init_account_type().unwrap();

    // initialize the variable-length extension first, then the
    // `MetadataPointer`, filling the pointer's fields with an all-0xFF address
    state
        .init_variable_len_extension(&variable_len, false)
        .unwrap();
    let max_pubkey: MaybeNull<Address> =
        Some(Address::new_from_array([255; 32])).try_into().unwrap();
    let extension = state.init_extension::<MetadataPointer>(false).unwrap();
    extension.authority = max_pubkey;
    extension.metadata_address = max_pubkey;

    // rewrite with a shorter value (two elements instead of five)
    let mut data = SolanaAccountData::new(&buffer);
    let key = Address::new_unique();
    let account_info = (&key, &mut data).into_account_info();
    let variable_len = VariableLenMintTest { data: vec![1, 2] };
    alloc_and_serialize_variable_len_extension::<PodMint, _>(
        &account_info,
        &variable_len,
        true,
    )
    .unwrap();

    let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
    let extension = state.get_extension::<MetadataPointer>().unwrap();
    assert_eq!(extension.authority, max_pubkey);
    assert_eq!(extension.metadata_address, max_pubkey);
    let extension = state
        .get_variable_len_extension::<VariableLenMintTest>()
        .unwrap();
    assert_eq!(extension, variable_len);
    // the stored account length agrees with what the state reports
    assert_eq!(data.len(), state.try_get_account_len().unwrap());

    // rewrite with a longer value (seven elements)
    let account_info = (&key, &mut data).into_account_info();
    let variable_len = VariableLenMintTest {
        data: vec![1, 2, 3, 4, 5, 6, 7],
    };
    alloc_and_serialize_variable_len_extension::<PodMint, _>(
        &account_info,
        &variable_len,
        true,
    )
    .unwrap();

    let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
    let extension = state.get_extension::<MetadataPointer>().unwrap();
    assert_eq!(extension.authority, max_pubkey);
    assert_eq!(extension.metadata_address, max_pubkey);
    let extension = state
        .get_variable_len_extension::<VariableLenMintTest>()
        .unwrap();
    assert_eq!(extension, variable_len);
    assert_eq!(data.len(), state.try_get_account_len().unwrap());

    // rewrite with a value of the same length but different contents
    let account_info = (&key, &mut data).into_account_info();
    let variable_len = VariableLenMintTest {
        data: vec![7, 6, 5, 4, 3, 2, 1],
    };
    alloc_and_serialize_variable_len_extension::<PodMint, _>(
        &account_info,
        &variable_len,
        true,
    )
    .unwrap();

    let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
    let extension = state.get_extension::<MetadataPointer>().unwrap();
    assert_eq!(extension.authority, max_pubkey);
    assert_eq!(extension.metadata_address, max_pubkey);
    let extension = state
        .get_variable_len_extension::<VariableLenMintTest>()
        .unwrap();
    assert_eq!(extension, variable_len);
    assert_eq!(data.len(), state.try_get_account_len().unwrap());
}
3137}