1#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5use {
6 crate::{
7 error::TokenError,
8 extension::{
9 account_len::TlvLenAccumulator,
10 confidential_mint_burn::ConfidentialMintBurn,
11 confidential_transfer::{ConfidentialTransferAccount, ConfidentialTransferMint},
12 confidential_transfer_fee::{
13 ConfidentialTransferFeeAmount, ConfidentialTransferFeeConfig,
14 },
15 cpi_guard::CpiGuard,
16 default_account_state::DefaultAccountState,
17 group_member_pointer::GroupMemberPointer,
18 group_pointer::GroupPointer,
19 immutable_owner::ImmutableOwner,
20 interest_bearing_mint::InterestBearingConfig,
21 memo_transfer::MemoTransfer,
22 metadata_pointer::MetadataPointer,
23 mint_close_authority::MintCloseAuthority,
24 non_transferable::{NonTransferable, NonTransferableAccount},
25 pausable::{PausableAccount, PausableConfig},
26 permanent_delegate::PermanentDelegate,
27 permissioned_burn::PermissionedBurnConfig,
28 scaled_ui_amount::ScaledUiAmountConfig,
29 transfer_fee::{TransferFeeAmount, TransferFeeConfig},
30 transfer_hook::{TransferHook, TransferHookAccount},
31 },
32 pod::{PodAccount, PodMint},
33 state::{Account, Mint, Multisig, PackedSizeOf},
34 },
35 alloc::{vec, vec::Vec},
36 bytemuck::{Pod, Zeroable},
37 core::{
38 cmp::Ordering,
39 convert::{TryFrom, TryInto},
40 mem::size_of,
41 },
42 num_enum::{IntoPrimitive, TryFromPrimitive},
43 solana_account_info::AccountInfo,
44 solana_program_error::ProgramError,
45 solana_program_pack::{IsInitialized, Pack},
46 solana_zero_copy::unaligned::U16,
47 spl_token_group_interface::state::{TokenGroup, TokenGroupMember},
48 spl_type_length_value::variable_len_pack::VariableLenPack,
49};
50
51pub mod account_len;
53pub mod confidential_transfer;
55pub mod confidential_transfer_fee;
57pub mod cpi_guard;
59pub mod default_account_state;
61pub mod group_member_pointer;
63pub mod group_pointer;
65pub mod immutable_owner;
67pub mod interest_bearing_mint;
69pub mod memo_transfer;
71pub mod metadata_pointer;
73pub mod mint_close_authority;
75pub mod non_transferable;
77pub mod pausable;
79pub mod permanent_delegate;
81pub mod permissioned_burn;
83pub mod scaled_ui_amount;
85pub mod token_group;
87pub mod token_metadata;
89pub mod transfer_fee;
91pub mod transfer_hook;
93
94pub mod confidential_mint_burn;
96
97#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
99#[repr(transparent)]
100pub struct Length(U16);
101impl From<Length> for usize {
102 fn from(n: Length) -> Self {
103 Self::from(u16::from(n.0))
104 }
105}
106impl TryFrom<usize> for Length {
107 type Error = ProgramError;
108 fn try_from(n: usize) -> Result<Self, Self::Error> {
109 u16::try_from(n)
110 .map(|v| Self(U16::from(v)))
111 .map_err(|_| ProgramError::AccountDataTooSmall)
112 }
113}
114
115fn get_tlv_indices(type_start: usize) -> TlvIndices {
117 let length_start = type_start.saturating_add(size_of::<ExtensionType>());
118 let value_start = length_start.saturating_add(size_of::<Length>());
119 TlvIndices {
120 type_start,
121 length_start,
122 value_start,
123 }
124}
125
126const fn adjust_len_for_multisig(account_len: usize) -> usize {
129 if account_len == Multisig::LEN {
130 account_len.saturating_add(size_of::<ExtensionType>())
131 } else {
132 account_len
133 }
134}
135
136const fn add_type_and_length_to_len(value_len: usize) -> usize {
139 value_len
140 .saturating_add(size_of::<ExtensionType>())
141 .saturating_add(size_of::<Length>())
142}
143
/// Byte offsets, within the TLV buffer, of the three fields of a single TLV entry.
#[derive(Debug)]
struct TlvIndices {
    /// Offset of the `ExtensionType` tag.
    pub type_start: usize,
    /// Offset of the `Length` prefix, which directly follows the tag.
    pub length_start: usize,
    /// Offset of the first value byte, which directly follows the length prefix.
    pub value_start: usize,
}
152fn get_extension_indices<V: Extension>(
153 tlv_data: &[u8],
154 init: bool,
155) -> Result<TlvIndices, ProgramError> {
156 let mut start_index = 0;
157 while start_index < tlv_data.len() {
158 let tlv_indices = get_tlv_indices(start_index);
159 if tlv_data.len() < tlv_indices.value_start {
160 return Err(ProgramError::InvalidAccountData);
161 }
162 let extension_type = u16::from_le_bytes(
163 tlv_data[tlv_indices.type_start..tlv_indices.length_start]
164 .try_into()
165 .map_err(|_| ProgramError::InvalidAccountData)?,
166 );
167 if extension_type == u16::from(V::TYPE) {
168 return Ok(tlv_indices);
170 } else if extension_type == u16::from(ExtensionType::Uninitialized) {
173 if init {
174 return Ok(tlv_indices);
175 } else {
176 return Err(TokenError::ExtensionNotFound.into());
177 }
178 } else {
179 let length = bytemuck::try_from_bytes::<Length>(
180 &tlv_data[tlv_indices.length_start..tlv_indices.value_start],
181 )
182 .map_err(|_| ProgramError::InvalidArgument)?;
183 let value_end_index = tlv_indices.value_start.saturating_add(usize::from(*length));
184 start_index = value_end_index;
185 }
186 }
187 Err(ProgramError::InvalidAccountData)
188}
189
190#[derive(Debug, PartialEq)]
193struct TlvDataInfo {
194 extension_types: Vec<ExtensionType>,
196 used_len: usize,
201}
202
203fn try_for_each_tlv_extension_type<F>(tlv_data: &[u8], mut f: F) -> Result<usize, ProgramError>
206where
207 F: FnMut(ExtensionType) -> Result<(), ProgramError>,
208{
209 let mut start_index = 0;
210 while start_index < tlv_data.len() {
211 let tlv_indices = get_tlv_indices(start_index);
212 if tlv_data.len() < tlv_indices.length_start {
213 return Ok(tlv_indices.type_start);
216 }
217 let extension_type =
218 ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
219 if extension_type == ExtensionType::Uninitialized {
220 return Ok(tlv_indices.type_start);
221 } else {
222 if tlv_data.len() < tlv_indices.value_start {
223 return Err(ProgramError::InvalidAccountData);
225 }
226 let length = bytemuck::try_from_bytes::<Length>(
227 &tlv_data[tlv_indices.length_start..tlv_indices.value_start],
228 )
229 .map_err(|_| ProgramError::InvalidArgument)?;
230
231 let value_end_index = tlv_indices.value_start.saturating_add(usize::from(*length));
232 if value_end_index > tlv_data.len() {
233 return Err(ProgramError::InvalidAccountData);
235 }
236 f(extension_type)?;
237 start_index = value_end_index;
238 }
239 }
240 Ok(start_index)
241}
242
243fn get_tlv_data_info(tlv_data: &[u8]) -> Result<TlvDataInfo, ProgramError> {
246 let mut extension_types = vec![];
247 let used_len = try_for_each_tlv_extension_type(tlv_data, |extension_type| {
248 extension_types.push(extension_type);
249 Ok(())
250 })?;
251 Ok(TlvDataInfo {
252 extension_types,
253 used_len,
254 })
255}
256
257fn get_first_extension_type(tlv_data: &[u8]) -> Result<Option<ExtensionType>, ProgramError> {
258 if tlv_data.is_empty() {
259 Ok(None)
260 } else {
261 let tlv_indices = get_tlv_indices(0);
262 if tlv_data.len() <= tlv_indices.length_start {
263 return Ok(None);
264 }
265 let extension_type =
266 ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
267 if extension_type == ExtensionType::Uninitialized {
268 Ok(None)
269 } else {
270 Ok(Some(extension_type))
271 }
272 }
273}
274
275fn check_min_len_and_not_multisig(input: &[u8], minimum_len: usize) -> Result<(), ProgramError> {
276 if input.len() == Multisig::LEN || input.len() < minimum_len {
277 Err(ProgramError::InvalidAccountData)
278 } else {
279 Ok(())
280 }
281}
282
283fn check_account_type<S: BaseState>(account_type: AccountType) -> Result<(), ProgramError> {
284 if account_type != S::ACCOUNT_TYPE {
285 Err(ProgramError::InvalidAccountData)
286 } else {
287 Ok(())
288 }
289}
290
291const BASE_ACCOUNT_LENGTH: usize = Account::LEN;
312const BASE_ACCOUNT_AND_TYPE_LENGTH: usize = BASE_ACCOUNT_LENGTH + size_of::<AccountType>();
315
316fn type_and_tlv_indices<S: BaseState>(
317 rest_input: &[u8],
318) -> Result<Option<(usize, usize)>, ProgramError> {
319 if rest_input.is_empty() {
320 Ok(None)
321 } else {
322 let account_type_index = BASE_ACCOUNT_LENGTH.saturating_sub(S::SIZE_OF);
323 let tlv_start_index = account_type_index.saturating_add(size_of::<AccountType>());
325 if rest_input.len() < tlv_start_index {
326 return Err(ProgramError::InvalidAccountData);
327 }
328 if rest_input[..account_type_index].iter().any(|&b| b != 0) {
329 Err(ProgramError::InvalidAccountData)
330 } else {
331 Ok(Some((account_type_index, tlv_start_index)))
332 }
333 }
334}
335
336fn is_initialized_account(input: &[u8]) -> Result<bool, ProgramError> {
339 const ACCOUNT_INITIALIZED_INDEX: usize = 108; if input.len() != BASE_ACCOUNT_LENGTH {
342 return Err(ProgramError::InvalidAccountData);
343 }
344 Ok(input[ACCOUNT_INITIALIZED_INDEX] != 0)
345}
346
347fn get_extension_bytes<S: BaseState, V: Extension>(tlv_data: &[u8]) -> Result<&[u8], ProgramError> {
348 if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
349 return Err(ProgramError::InvalidAccountData);
350 }
351 let TlvIndices {
352 type_start: _,
353 length_start,
354 value_start,
355 } = get_extension_indices::<V>(tlv_data, false)?;
356 let length = bytemuck::try_from_bytes::<Length>(&tlv_data[length_start..value_start])
359 .map_err(|_| ProgramError::InvalidArgument)?;
360 let value_end = value_start.saturating_add(usize::from(*length));
361 if tlv_data.len() < value_end {
362 return Err(ProgramError::InvalidAccountData);
363 }
364 Ok(&tlv_data[value_start..value_end])
365}
366
367fn get_extension_bytes_mut<S: BaseState, V: Extension>(
368 tlv_data: &mut [u8],
369) -> Result<&mut [u8], ProgramError> {
370 if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
371 return Err(ProgramError::InvalidAccountData);
372 }
373 let TlvIndices {
374 type_start: _,
375 length_start,
376 value_start,
377 } = get_extension_indices::<V>(tlv_data, false)?;
378 let length = bytemuck::try_from_bytes::<Length>(&tlv_data[length_start..value_start])
381 .map_err(|_| ProgramError::InvalidArgument)?;
382 let value_end = value_start.saturating_add(usize::from(*length));
383 if tlv_data.len() < value_end {
384 return Err(ProgramError::InvalidAccountData);
385 }
386 Ok(&mut tlv_data[value_start..value_end])
387}
388
389fn try_get_new_account_len_for_extension_len<S: BaseState, V: Extension>(
395 tlv_data: &[u8],
396 new_extension_len: usize,
397) -> Result<usize, ProgramError> {
398 let new_extension_tlv_len = add_type_and_length_to_len(new_extension_len);
400 let tlv_info = get_tlv_data_info(tlv_data)?;
401 let current_len = tlv_info
404 .used_len
405 .saturating_add(BASE_ACCOUNT_AND_TYPE_LENGTH);
406 let current_extension_len = get_extension_bytes::<S, V>(tlv_data)
408 .map(|x| add_type_and_length_to_len(x.len()))
409 .unwrap_or(0);
410 let new_len = current_len
411 .saturating_sub(current_extension_len)
412 .saturating_add(new_extension_tlv_len);
413 Ok(adjust_len_for_multisig(new_len))
414}
415
416pub trait BaseStateWithExtensions<S: BaseState> {
418 fn get_tlv_data(&self) -> &[u8];
420
421 fn get_extension_bytes<V: Extension>(&self) -> Result<&[u8], ProgramError> {
423 get_extension_bytes::<S, V>(self.get_tlv_data())
424 }
425
426 fn get_extension<V: Extension + Pod>(&self) -> Result<&V, ProgramError> {
428 bytemuck::try_from_bytes::<V>(self.get_extension_bytes::<V>()?)
429 .map_err(|_| ProgramError::InvalidArgument)
430 }
431
432 fn get_variable_len_extension<V: Extension + VariableLenPack>(
434 &self,
435 ) -> Result<V, ProgramError> {
436 let data = get_extension_bytes::<S, V>(self.get_tlv_data())?;
437 V::unpack_from_slice(data)
438 }
439
440 fn get_extension_types(&self) -> Result<Vec<ExtensionType>, ProgramError> {
442 get_tlv_data_info(self.get_tlv_data()).map(|x| x.extension_types)
443 }
444
445 fn get_first_extension_type(&self) -> Result<Option<ExtensionType>, ProgramError> {
447 get_first_extension_type(self.get_tlv_data())
448 }
449
450 fn try_get_account_len(&self) -> Result<usize, ProgramError> {
452 let tlv_info = get_tlv_data_info(self.get_tlv_data())?;
453 if tlv_info.extension_types.is_empty() {
454 Ok(S::SIZE_OF)
455 } else {
456 let total_len = tlv_info
457 .used_len
458 .saturating_add(BASE_ACCOUNT_AND_TYPE_LENGTH);
459 Ok(adjust_len_for_multisig(total_len))
460 }
461 }
462 fn try_get_new_account_len<V: Extension + Pod>(&self) -> Result<usize, ProgramError> {
467 try_get_new_account_len_for_extension_len::<S, V>(self.get_tlv_data(), size_of::<V>())
468 }
469
470 fn try_get_new_account_len_for_variable_len_extension<V: Extension + VariableLenPack>(
473 &self,
474 new_extension: &V,
475 ) -> Result<usize, ProgramError> {
476 try_get_new_account_len_for_extension_len::<S, V>(
477 self.get_tlv_data(),
478 new_extension.get_packed_len()?,
479 )
480 }
481}
482
483#[derive(Clone, Debug, PartialEq)]
486pub struct StateWithExtensionsOwned<S: BaseState> {
487 pub base: S,
489 tlv_data: Vec<u8>,
491}
492impl<S: BaseState + Pack> StateWithExtensionsOwned<S> {
493 pub fn unpack(mut input: Vec<u8>) -> Result<Self, ProgramError> {
497 check_min_len_and_not_multisig(&input, S::SIZE_OF)?;
498 let mut rest = input.split_off(S::SIZE_OF);
499 let base = S::unpack(&input)?;
500 if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(&rest)? {
501 let account_type = AccountType::try_from(rest[account_type_index])
503 .map_err(|_| ProgramError::InvalidAccountData)?;
504 check_account_type::<S>(account_type)?;
505 let tlv_data = rest.split_off(tlv_start_index);
506 Ok(Self { base, tlv_data })
507 } else {
508 Ok(Self {
509 base,
510 tlv_data: vec![],
511 })
512 }
513 }
514}
515
516impl<S: BaseState> BaseStateWithExtensions<S> for StateWithExtensionsOwned<S> {
517 fn get_tlv_data(&self) -> &[u8] {
518 &self.tlv_data
519 }
520}
521
522#[derive(Debug, PartialEq)]
525pub struct StateWithExtensions<'data, S: BaseState + Pack> {
526 pub base: S,
528 tlv_data: &'data [u8],
530}
531impl<'data, S: BaseState + Pack> StateWithExtensions<'data, S> {
532 pub fn unpack(input: &'data [u8]) -> Result<Self, ProgramError> {
536 check_min_len_and_not_multisig(input, S::SIZE_OF)?;
537 let (base_data, rest) = input.split_at(S::SIZE_OF);
538 let base = S::unpack(base_data)?;
539 let tlv_data = unpack_tlv_data::<S>(rest)?;
540 Ok(Self { base, tlv_data })
541 }
542}
543impl<S: BaseState + Pack> BaseStateWithExtensions<S> for StateWithExtensions<'_, S> {
544 fn get_tlv_data(&self) -> &[u8] {
545 self.tlv_data
546 }
547}
548
549#[derive(Debug, PartialEq)]
552pub struct PodStateWithExtensions<'data, S: BaseState + Pod> {
553 pub base: &'data S,
555 tlv_data: &'data [u8],
557}
558impl<'data, S: BaseState + Pod> PodStateWithExtensions<'data, S> {
559 pub fn unpack(input: &'data [u8]) -> Result<Self, ProgramError> {
563 check_min_len_and_not_multisig(input, S::SIZE_OF)?;
564 let (base_data, rest) = input.split_at(S::SIZE_OF);
565 let base =
566 bytemuck::try_from_bytes::<S>(base_data).map_err(|_| ProgramError::InvalidArgument)?;
567 if !base.is_initialized() {
568 Err(ProgramError::UninitializedAccount)
569 } else {
570 let tlv_data = unpack_tlv_data::<S>(rest)?;
571 Ok(Self { base, tlv_data })
572 }
573 }
574}
575impl<S: BaseState + Pod> BaseStateWithExtensions<S> for PodStateWithExtensions<'_, S> {
576 fn get_tlv_data(&self) -> &[u8] {
577 self.tlv_data
578 }
579}
580
/// Mutable access to a base state's TLV extension area.
///
/// Besides raw access, the trait provides helpers that allocate, resize and
/// initialize extension entries in place. Implementors supply only raw access to
/// the TLV bytes and to the account-type byte; everything else is built on those.
pub trait BaseStateWithExtensionsMut<S: BaseState>: BaseStateWithExtensions<S> {
    /// Mutable access to the raw TLV bytes.
    fn get_tlv_data_mut(&mut self) -> &mut [u8];

    /// Mutable access to the account-type byte.
    ///
    /// The slice may be empty (no extension space); `init_account_type` checks
    /// for that before writing.
    fn get_account_type_mut(&mut self) -> &mut [u8];

    /// Mutable value bytes of extension `V` (see the free function `get_extension_bytes_mut`).
    fn get_extension_bytes_mut<V: Extension>(&mut self) -> Result<&mut [u8], ProgramError> {
        get_extension_bytes_mut::<S, V>(self.get_tlv_data_mut())
    }

    /// Fixed-size extension `V`, reinterpreted in place for mutation.
    ///
    /// # Errors
    /// `InvalidArgument` when the stored bytes do not fit `V`'s size or alignment,
    /// plus anything returned by `get_extension_bytes_mut`.
    fn get_extension_mut<V: Extension + Pod>(&mut self) -> Result<&mut V, ProgramError> {
        bytemuck::try_from_bytes_mut::<V>(self.get_extension_bytes_mut::<V>()?)
            .map_err(|_| ProgramError::InvalidArgument)
    }

    /// Packs `extension` into the space already allocated for `V`.
    ///
    /// The destination slice is whatever size is currently allocated. If the packed
    /// size changes, the caller presumably has to `realloc` first (TODO confirm
    /// against callers).
    fn pack_variable_len_extension<V: Extension + VariableLenPack>(
        &mut self,
        extension: &V,
    ) -> Result<(), ProgramError> {
        let data = self.get_extension_bytes_mut::<V>()?;
        extension.pack_into_slice(data)
    }

    /// Allocates (or, with `overwrite`, reuses) the entry for `V` and resets its
    /// value to `V::default()`.
    ///
    /// # Errors
    /// Same as `alloc`, plus `InvalidArgument` if the slice cannot be cast to `V`.
    fn init_extension<V: Extension + Pod + Default>(
        &mut self,
        overwrite: bool,
    ) -> Result<&mut V, ProgramError> {
        let length = size_of::<V>();
        let buffer = self.alloc::<V>(length, overwrite)?;
        let extension_ref =
            bytemuck::try_from_bytes_mut::<V>(buffer).map_err(|_| ProgramError::InvalidArgument)?;
        *extension_ref = V::default();
        Ok(extension_ref)
    }

    /// Resizes the existing entry for `V` to the packed size of `new_extension`,
    /// then packs it.
    fn realloc_variable_len_extension<V: Extension + VariableLenPack>(
        &mut self,
        new_extension: &V,
    ) -> Result<(), ProgramError> {
        let data = self.realloc::<V>(new_extension.get_packed_len()?)?;
        new_extension.pack_into_slice(data)
    }

    /// Resizes the value of the already-present extension `V` to `length` bytes, in place.
    ///
    /// Every entry after `V` is shifted with `copy_within`. Bytes that become spare
    /// at the end (on shrink), or newly exposed inside the value (on grow), are
    /// zeroed. Returns the resized value slice; writing the new contents is up to
    /// the caller.
    ///
    /// # Errors
    /// - Whatever `get_extension_indices` returns when `V` is absent (e.g. `ExtensionNotFound`).
    /// - `InvalidAccountData` when growing would exceed the TLV buffer.
    /// - `Length::try_from` errors when `length` does not fit the length prefix.
    ///
    /// NOTE(review): unlike `alloc`, there is no `V::TYPE.get_account_type()` check
    /// here; presumably unnecessary because `V` must already be present. Confirm.
    fn realloc<V: Extension + VariableLenPack>(
        &mut self,
        length: usize,
    ) -> Result<&mut [u8], ProgramError> {
        let tlv_data = self.get_tlv_data_mut();
        let TlvIndices {
            type_start: _,
            length_start,
            value_start,
        } = get_extension_indices::<V>(tlv_data, false)?;
        // Bytes currently occupied by initialized entries.
        let tlv_len = get_tlv_data_info(tlv_data).map(|x| x.used_len)?;
        let data_len = tlv_data.len();

        let length_ref =
            bytemuck::try_from_bytes_mut::<Length>(&mut tlv_data[length_start..value_start])
                .map_err(|_| ProgramError::InvalidArgument)?;
        let old_length = usize::from(*length_ref);

        // When growing, make sure the shifted tail still fits; this also keeps the
        // `copy_within` below from panicking.
        if old_length < length {
            let new_tlv_len = tlv_len.saturating_add(length.saturating_sub(old_length));
            if new_tlv_len > data_len {
                return Err(ProgramError::InvalidAccountData);
            }
        }

        // The new length is written only after the capacity check, so the early
        // return above leaves the entry untouched.
        *length_ref = Length::try_from(length)?;

        let old_value_end = value_start.saturating_add(old_length);
        let new_value_end = value_start.saturating_add(length);
        // Move every entry that follows `V` to just past its new value end.
        tlv_data.copy_within(old_value_end..tlv_len, new_value_end);
        match old_length.cmp(&length) {
            Ordering::Greater => {
                // Shrunk: clear the bytes vacated at the end of the used region.
                let new_tlv_len = tlv_len.saturating_sub(old_length.saturating_sub(length));
                tlv_data[new_tlv_len..tlv_len].fill(0);
            }
            Ordering::Less => {
                // Grew: clear the newly exposed part of the value.
                tlv_data[old_value_end..new_value_end].fill(0);
            }
            Ordering::Equal => {}
        }

        Ok(&mut tlv_data[value_start..new_value_end])
    }

    /// Allocates (or, with `overwrite`, reuses) the entry for `V`, sized to
    /// `extension`'s packed length, and packs `extension` into it.
    fn init_variable_len_extension<V: Extension + VariableLenPack>(
        &mut self,
        extension: &V,
        overwrite: bool,
    ) -> Result<(), ProgramError> {
        let data = self.alloc::<V>(extension.get_packed_len()?, overwrite)?;
        extension.pack_into_slice(data)
    }

    /// Reserves a TLV entry for `V` with a `length`-byte value.
    ///
    /// The type tag and length prefix are written, and the value slice is returned
    /// without clearing it. The entry is either the first unused slot or, if `V` is
    /// already present, its existing entry:
    /// - with `overwrite`, the existing entry is reused, but only if its length
    ///   already equals `length`;
    /// - without `overwrite`, an existing entry is an error.
    ///
    /// # Errors
    /// - `InvalidAccountData` when `V` is not an extension for `S`, or when the
    ///   entry does not fit.
    /// - `TokenError::InvalidLengthForAlloc` when overwriting with a different length.
    /// - `TokenError::ExtensionAlreadyInitialized` when `V` exists and `overwrite` is false.
    fn alloc<V: Extension>(
        &mut self,
        length: usize,
        overwrite: bool,
    ) -> Result<&mut [u8], ProgramError> {
        if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
            return Err(ProgramError::InvalidAccountData);
        }
        let tlv_data = self.get_tlv_data_mut();
        let TlvIndices {
            type_start,
            length_start,
            value_start,
        } = get_extension_indices::<V>(tlv_data, true)?;

        // The whole entry (header + value) must fit in the remaining bytes.
        if tlv_data[type_start..].len() < add_type_and_length_to_len(length) {
            return Err(ProgramError::InvalidAccountData);
        }
        // Either `Uninitialized` (free slot) or `V::TYPE` (existing entry).
        let extension_type = ExtensionType::try_from(&tlv_data[type_start..length_start])?;

        if extension_type == ExtensionType::Uninitialized || overwrite {
            let extension_type_array: [u8; 2] = V::TYPE.into();
            let extension_type_ref = &mut tlv_data[type_start..length_start];
            extension_type_ref.copy_from_slice(&extension_type_array);
            let length_ref =
                bytemuck::try_from_bytes_mut::<Length>(&mut tlv_data[length_start..value_start])
                    .map_err(|_| ProgramError::InvalidArgument)?;

            // Overwriting an existing entry must keep its size; resizing is
            // `realloc`'s job. The tag rewritten above is identical to the existing
            // one in this case, so returning here leaves the data unchanged.
            if overwrite && extension_type == V::TYPE && usize::from(*length_ref) != length {
                return Err(TokenError::InvalidLengthForAlloc.into());
            }

            *length_ref = Length::try_from(length)?;

            let value_end = value_start.saturating_add(length);
            Ok(&mut tlv_data[value_start..value_end])
        } else {
            Err(TokenError::ExtensionAlreadyInitialized.into())
        }
    }

    /// Initializes the default value of an account-side extension, with overwrite enabled.
    ///
    /// Extension types whose account type is not `AccountType::Account` are ignored.
    /// `ConfidentialTransferAccount` is deliberately a no-op.
    ///
    /// NOTE(review): other account-side types such as `MemoTransfer`, `CpiGuard` and
    /// `ConfidentialTransferFeeAmount` are not matched below and will hit
    /// `unreachable!()` (a panic). Presumably callers only pass the listed types;
    /// confirm against callers.
    fn init_account_extension_from_type(
        &mut self,
        extension_type: ExtensionType,
    ) -> Result<(), ProgramError> {
        if extension_type.get_account_type() != AccountType::Account {
            return Ok(());
        }
        match extension_type {
            ExtensionType::TransferFeeAmount => {
                self.init_extension::<TransferFeeAmount>(true).map(|_| ())
            }
            ExtensionType::ImmutableOwner => {
                self.init_extension::<ImmutableOwner>(true).map(|_| ())
            }
            ExtensionType::NonTransferableAccount => self
                .init_extension::<NonTransferableAccount>(true)
                .map(|_| ()),
            ExtensionType::TransferHookAccount => {
                self.init_extension::<TransferHookAccount>(true).map(|_| ())
            }
            ExtensionType::ConfidentialTransferAccount => Ok(()),
            ExtensionType::PausableAccount => {
                self.init_extension::<PausableAccount>(true).map(|_| ())
            }
            #[cfg(test)]
            ExtensionType::AccountPaddingTest => {
                self.init_extension::<AccountPaddingTest>(true).map(|_| ())
            }
            _ => unreachable!(),
        }
    }

    /// Writes `S::ACCOUNT_TYPE` into the account-type byte, if the account has one.
    ///
    /// # Errors
    /// `TokenError::ExtensionBaseMismatch` when the first stored extension belongs to
    /// a different account type.
    fn init_account_type(&mut self) -> Result<(), ProgramError> {
        let first_extension_type = self.get_first_extension_type()?;
        let account_type = self.get_account_type_mut();
        // An empty slice means there is no extension space, hence no type byte.
        if !account_type.is_empty() {
            if let Some(extension_type) = first_extension_type {
                let account_type = extension_type.get_account_type();
                if account_type != S::ACCOUNT_TYPE {
                    return Err(TokenError::ExtensionBaseMismatch.into());
                }
            }
            account_type[0] = S::ACCOUNT_TYPE.into();
        }
        Ok(())
    }

    /// Checks that the first stored extension, if any, belongs to `S`'s account type.
    ///
    /// # Errors
    /// `TokenError::ExtensionBaseMismatch` on mismatch, plus errors from
    /// `get_first_extension_type`.
    fn check_account_type_matches_extension_type(&self) -> Result<(), ProgramError> {
        if let Some(extension_type) = self.get_first_extension_type()? {
            let account_type = extension_type.get_account_type();
            if account_type != S::ACCOUNT_TYPE {
                return Err(TokenError::ExtensionBaseMismatch.into());
            }
        }
        Ok(())
    }
}
833
834#[derive(Debug, PartialEq)]
837pub struct StateWithExtensionsMut<'data, S: BaseState> {
838 pub base: S,
840 base_data: &'data mut [u8],
842 account_type: &'data mut [u8],
844 tlv_data: &'data mut [u8],
846}
847impl<'data, S: BaseState + Pack> StateWithExtensionsMut<'data, S> {
848 pub fn unpack(input: &'data mut [u8]) -> Result<Self, ProgramError> {
852 check_min_len_and_not_multisig(input, S::SIZE_OF)?;
853 let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
854 let base = S::unpack(base_data)?;
855 let (account_type, tlv_data) = unpack_type_and_tlv_data_mut::<S>(rest)?;
856 Ok(Self {
857 base,
858 base_data,
859 account_type,
860 tlv_data,
861 })
862 }
863
864 pub fn unpack_uninitialized(input: &'data mut [u8]) -> Result<Self, ProgramError> {
869 check_min_len_and_not_multisig(input, S::SIZE_OF)?;
870 let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
871 let base = S::unpack_unchecked(base_data)?;
872 if base.is_initialized() {
873 return Err(TokenError::AlreadyInUse.into());
874 }
875 let (account_type, tlv_data) = unpack_uninitialized_type_and_tlv_data_mut::<S>(rest)?;
876 let state = Self {
877 base,
878 base_data,
879 account_type,
880 tlv_data,
881 };
882 state.check_account_type_matches_extension_type()?;
883 Ok(state)
884 }
885
886 pub fn pack_base(&mut self) {
888 S::pack_into_slice(&self.base, self.base_data);
889 }
890}
891impl<S: BaseState> BaseStateWithExtensions<S> for StateWithExtensionsMut<'_, S> {
892 fn get_tlv_data(&self) -> &[u8] {
893 self.tlv_data
894 }
895}
896impl<S: BaseState> BaseStateWithExtensionsMut<S> for StateWithExtensionsMut<'_, S> {
897 fn get_tlv_data_mut(&mut self) -> &mut [u8] {
898 self.tlv_data
899 }
900 fn get_account_type_mut(&mut self) -> &mut [u8] {
901 self.account_type
902 }
903}
904
905#[derive(Debug, PartialEq)]
908pub struct PodStateWithExtensionsMut<'data, S: BaseState> {
909 pub base: &'data mut S,
911 account_type: &'data mut [u8],
913 tlv_data: &'data mut [u8],
915}
916impl<'data, S: BaseState + Pod> PodStateWithExtensionsMut<'data, S> {
917 pub fn unpack(input: &'data mut [u8]) -> Result<Self, ProgramError> {
921 check_min_len_and_not_multisig(input, S::SIZE_OF)?;
922 let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
923 let base = bytemuck::try_from_bytes_mut::<S>(base_data)
924 .map_err(|_| ProgramError::InvalidArgument)?;
925 if !base.is_initialized() {
926 Err(ProgramError::UninitializedAccount)
927 } else {
928 let (account_type, tlv_data) = unpack_type_and_tlv_data_mut::<S>(rest)?;
929 Ok(Self {
930 base,
931 account_type,
932 tlv_data,
933 })
934 }
935 }
936
937 pub fn unpack_uninitialized(input: &'data mut [u8]) -> Result<Self, ProgramError> {
942 check_min_len_and_not_multisig(input, S::SIZE_OF)?;
943 let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
944 let base = bytemuck::try_from_bytes_mut::<S>(base_data)
945 .map_err(|_| ProgramError::InvalidArgument)?;
946 if base.is_initialized() {
947 return Err(TokenError::AlreadyInUse.into());
948 }
949 let (account_type, tlv_data) = unpack_uninitialized_type_and_tlv_data_mut::<S>(rest)?;
950 let state = Self {
951 base,
952 account_type,
953 tlv_data,
954 };
955 state.check_account_type_matches_extension_type()?;
956 Ok(state)
957 }
958}
959
960impl<S: BaseState> BaseStateWithExtensions<S> for PodStateWithExtensionsMut<'_, S> {
961 fn get_tlv_data(&self) -> &[u8] {
962 self.tlv_data
963 }
964}
965impl<S: BaseState> BaseStateWithExtensionsMut<S> for PodStateWithExtensionsMut<'_, S> {
966 fn get_tlv_data_mut(&mut self) -> &mut [u8] {
967 self.tlv_data
968 }
969 fn get_account_type_mut(&mut self) -> &mut [u8] {
970 self.account_type
971 }
972}
973
974fn unpack_tlv_data<S: BaseState>(rest: &[u8]) -> Result<&[u8], ProgramError> {
975 if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
976 let account_type = AccountType::try_from(rest[account_type_index])
978 .map_err(|_| ProgramError::InvalidAccountData)?;
979 check_account_type::<S>(account_type)?;
980 Ok(&rest[tlv_start_index..])
981 } else {
982 Ok(&[])
983 }
984}
985
986fn unpack_type_and_tlv_data_with_check_mut<
987 S: BaseState,
988 F: Fn(AccountType) -> Result<(), ProgramError>,
989>(
990 rest: &mut [u8],
991 check_fn: F,
992) -> Result<(&mut [u8], &mut [u8]), ProgramError> {
993 if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
994 let account_type = AccountType::try_from(rest[account_type_index])
996 .map_err(|_| ProgramError::InvalidAccountData)?;
997 check_fn(account_type)?;
998 let (account_type, tlv_data) = rest.split_at_mut(tlv_start_index);
999 Ok((
1000 &mut account_type[account_type_index..tlv_start_index],
1001 tlv_data,
1002 ))
1003 } else {
1004 Ok((&mut [], &mut []))
1005 }
1006}
1007
1008fn unpack_type_and_tlv_data_mut<S: BaseState>(
1009 rest: &mut [u8],
1010) -> Result<(&mut [u8], &mut [u8]), ProgramError> {
1011 unpack_type_and_tlv_data_with_check_mut::<S, _>(rest, check_account_type::<S>)
1012}
1013
1014fn unpack_uninitialized_type_and_tlv_data_mut<S: BaseState>(
1015 rest: &mut [u8],
1016) -> Result<(&mut [u8], &mut [u8]), ProgramError> {
1017 unpack_type_and_tlv_data_with_check_mut::<S, _>(rest, |account_type| {
1018 if account_type != AccountType::Uninitialized {
1019 Err(ProgramError::InvalidAccountData)
1020 } else {
1021 Ok(())
1022 }
1023 })
1024}
1025
1026pub fn set_account_type<S: BaseState>(input: &mut [u8]) -> Result<(), ProgramError> {
1031 check_min_len_and_not_multisig(input, S::SIZE_OF)?;
1032 let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
1033 if S::ACCOUNT_TYPE == AccountType::Account && !is_initialized_account(base_data)? {
1034 return Err(ProgramError::InvalidAccountData);
1035 }
1036 if let Some((account_type_index, _tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
1037 let mut account_type = AccountType::try_from(rest[account_type_index])
1038 .map_err(|_| ProgramError::InvalidAccountData)?;
1039 if account_type == AccountType::Uninitialized {
1040 rest[account_type_index] = S::ACCOUNT_TYPE.into();
1041 account_type = S::ACCOUNT_TYPE;
1042 }
1043 check_account_type::<S>(account_type)?;
1044 Ok(())
1045 } else {
1046 Err(ProgramError::InvalidAccountData)
1047 }
1048}
1049
1050#[repr(u8)]
1055#[derive(Clone, Copy, Debug, Default, PartialEq, TryFromPrimitive, IntoPrimitive)]
1056pub enum AccountType {
1057 #[default]
1059 Uninitialized,
1060 Mint,
1062 Account,
1064}
1065
1066#[repr(u16)]
1070#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
1071#[cfg_attr(feature = "serde", serde(rename_all = "camelCase"))]
1072#[cfg_attr(test, derive(strum_macros::EnumIter))]
1073#[derive(Clone, Copy, Debug, PartialEq, TryFromPrimitive, IntoPrimitive)]
1074pub enum ExtensionType {
1075 Uninitialized,
1078 TransferFeeConfig,
1081 TransferFeeAmount,
1083 MintCloseAuthority,
1085 ConfidentialTransferMint,
1087 ConfidentialTransferAccount,
1089 DefaultAccountState,
1091 ImmutableOwner,
1093 MemoTransfer,
1095 NonTransferable,
1097 InterestBearingConfig,
1099 CpiGuard,
1101 PermanentDelegate,
1103 NonTransferableAccount,
1106 TransferHook,
1109 TransferHookAccount,
1112 ConfidentialTransferFeeConfig,
1115 ConfidentialTransferFeeAmount,
1117 MetadataPointer,
1120 TokenMetadata,
1122 GroupPointer,
1125 TokenGroup,
1127 GroupMemberPointer,
1130 TokenGroupMember,
1132 ConfidentialMintBurn,
1134 ScaledUiAmount,
1136 Pausable,
1138 PausableAccount,
1140 PermissionedBurn,
1142
1143 #[cfg(test)]
1145 VariableLenMintTest = u16::MAX - 2,
1146 #[cfg(test)]
1149 AccountPaddingTest,
1150 #[cfg(test)]
1153 MintPaddingTest,
1154}
1155impl TryFrom<&[u8]> for ExtensionType {
1156 type Error = ProgramError;
1157 fn try_from(a: &[u8]) -> Result<Self, Self::Error> {
1158 Self::try_from(u16::from_le_bytes(
1159 a.try_into().map_err(|_| ProgramError::InvalidAccountData)?,
1160 ))
1161 .map_err(|_| ProgramError::InvalidAccountData)
1162 }
1163}
1164impl From<ExtensionType> for [u8; 2] {
1165 fn from(a: ExtensionType) -> Self {
1166 u16::from(a).to_le_bytes()
1167 }
1168}
1169impl ExtensionType {
1170 const fn sized(&self) -> bool {
1175 match self {
1176 ExtensionType::TokenMetadata => false,
1177 #[cfg(test)]
1178 ExtensionType::VariableLenMintTest => false,
1179 _ => true,
1180 }
1181 }
1182
1183 fn try_get_type_len(&self) -> Result<usize, ProgramError> {
1187 if !self.sized() {
1188 return Err(ProgramError::InvalidArgument);
1189 }
1190 Ok(match self {
1191 ExtensionType::Uninitialized => 0,
1192 ExtensionType::TransferFeeConfig => size_of::<TransferFeeConfig>(),
1193 ExtensionType::TransferFeeAmount => size_of::<TransferFeeAmount>(),
1194 ExtensionType::MintCloseAuthority => size_of::<MintCloseAuthority>(),
1195 ExtensionType::ImmutableOwner => size_of::<ImmutableOwner>(),
1196 ExtensionType::ConfidentialTransferMint => size_of::<ConfidentialTransferMint>(),
1197 ExtensionType::ConfidentialTransferAccount => size_of::<ConfidentialTransferAccount>(),
1198 ExtensionType::DefaultAccountState => size_of::<DefaultAccountState>(),
1199 ExtensionType::MemoTransfer => size_of::<MemoTransfer>(),
1200 ExtensionType::NonTransferable => size_of::<NonTransferable>(),
1201 ExtensionType::InterestBearingConfig => size_of::<InterestBearingConfig>(),
1202 ExtensionType::CpiGuard => size_of::<CpiGuard>(),
1203 ExtensionType::PermanentDelegate => size_of::<PermanentDelegate>(),
1204 ExtensionType::NonTransferableAccount => size_of::<NonTransferableAccount>(),
1205 ExtensionType::TransferHook => size_of::<TransferHook>(),
1206 ExtensionType::TransferHookAccount => size_of::<TransferHookAccount>(),
1207 ExtensionType::ConfidentialTransferFeeConfig => {
1208 size_of::<ConfidentialTransferFeeConfig>()
1209 }
1210 ExtensionType::ConfidentialTransferFeeAmount => {
1211 size_of::<ConfidentialTransferFeeAmount>()
1212 }
1213 ExtensionType::MetadataPointer => size_of::<MetadataPointer>(),
1214 ExtensionType::TokenMetadata => unreachable!(),
1215 ExtensionType::GroupPointer => size_of::<GroupPointer>(),
1216 ExtensionType::TokenGroup => size_of::<TokenGroup>(),
1217 ExtensionType::GroupMemberPointer => size_of::<GroupMemberPointer>(),
1218 ExtensionType::TokenGroupMember => size_of::<TokenGroupMember>(),
1219 ExtensionType::ConfidentialMintBurn => size_of::<ConfidentialMintBurn>(),
1220 ExtensionType::ScaledUiAmount => size_of::<ScaledUiAmountConfig>(),
1221 ExtensionType::Pausable => size_of::<PausableConfig>(),
1222 ExtensionType::PausableAccount => size_of::<PausableAccount>(),
1223 ExtensionType::PermissionedBurn => size_of::<PermissionedBurnConfig>(),
1224 #[cfg(test)]
1225 ExtensionType::AccountPaddingTest => size_of::<AccountPaddingTest>(),
1226 #[cfg(test)]
1227 ExtensionType::MintPaddingTest => size_of::<MintPaddingTest>(),
1228 #[cfg(test)]
1229 ExtensionType::VariableLenMintTest => unreachable!(),
1230 })
1231 }
1232
1233 fn try_get_tlv_len(&self) -> Result<usize, ProgramError> {
1237 Ok(add_type_and_length_to_len(self.try_get_type_len()?))
1238 }
1239
1240 pub fn try_calculate_account_len<S: BaseState>(
1244 extension_types: &[Self],
1245 ) -> Result<usize, ProgramError> {
1246 let mut tlv_len = TlvLenAccumulator::default();
1247 for &extension_type in extension_types {
1248 tlv_len.insert(extension_type)?;
1249 }
1250 Ok(tlv_len.account_len::<S>())
1251 }
1252
1253 pub fn get_account_type(&self) -> AccountType {
1255 match self {
1256 ExtensionType::Uninitialized => AccountType::Uninitialized,
1257 ExtensionType::TransferFeeConfig
1258 | ExtensionType::MintCloseAuthority
1259 | ExtensionType::ConfidentialTransferMint
1260 | ExtensionType::DefaultAccountState
1261 | ExtensionType::NonTransferable
1262 | ExtensionType::InterestBearingConfig
1263 | ExtensionType::PermanentDelegate
1264 | ExtensionType::TransferHook
1265 | ExtensionType::ConfidentialTransferFeeConfig
1266 | ExtensionType::MetadataPointer
1267 | ExtensionType::TokenMetadata
1268 | ExtensionType::GroupPointer
1269 | ExtensionType::TokenGroup
1270 | ExtensionType::GroupMemberPointer
1271 | ExtensionType::ConfidentialMintBurn
1272 | ExtensionType::TokenGroupMember
1273 | ExtensionType::ScaledUiAmount
1274 | ExtensionType::Pausable
1275 | ExtensionType::PermissionedBurn => AccountType::Mint,
1276 ExtensionType::ImmutableOwner
1277 | ExtensionType::TransferFeeAmount
1278 | ExtensionType::ConfidentialTransferAccount
1279 | ExtensionType::MemoTransfer
1280 | ExtensionType::NonTransferableAccount
1281 | ExtensionType::TransferHookAccount
1282 | ExtensionType::CpiGuard
1283 | ExtensionType::ConfidentialTransferFeeAmount
1284 | ExtensionType::PausableAccount => AccountType::Account,
1285 #[cfg(test)]
1286 ExtensionType::VariableLenMintTest => AccountType::Mint,
1287 #[cfg(test)]
1288 ExtensionType::AccountPaddingTest => AccountType::Account,
1289 #[cfg(test)]
1290 ExtensionType::MintPaddingTest => AccountType::Mint,
1291 }
1292 }
1293
1294 fn required_init_account_extensions(&self) -> &'static [Self] {
1297 match self {
1298 ExtensionType::TransferFeeConfig => &[ExtensionType::TransferFeeAmount],
1299 ExtensionType::NonTransferable => &[
1300 ExtensionType::NonTransferableAccount,
1301 ExtensionType::ImmutableOwner,
1302 ],
1303 ExtensionType::TransferHook => &[ExtensionType::TransferHookAccount],
1304 ExtensionType::Pausable => &[ExtensionType::PausableAccount],
1305 #[cfg(test)]
1306 ExtensionType::MintPaddingTest => &[ExtensionType::AccountPaddingTest],
1307 _ => &[],
1308 }
1309 }
1310
1311 #[deprecated(
1314 since = "3.0.1",
1315 note = "Use `account_len::try_for_each_required_init_account_extension` instead"
1316 )]
1317 pub fn get_required_init_account_extensions(mint_extension_types: &[Self]) -> Vec<Self> {
1318 mint_extension_types
1319 .iter()
1320 .flat_map(|extension_type| extension_type.required_init_account_extensions())
1321 .copied()
1322 .collect()
1323 }
1324
1325 pub fn check_for_invalid_mint_extension_combinations(
1327 mint_extension_types: &[Self],
1328 ) -> Result<(), TokenError> {
1329 let mut transfer_fee_config = false;
1330 let mut confidential_transfer_mint = false;
1331 let mut confidential_transfer_fee_config = false;
1332 let mut confidential_mint_burn = false;
1333 let mut interest_bearing = false;
1334 let mut scaled_ui_amount = false;
1335 let mut non_transferable = false;
1336
1337 for extension_type in mint_extension_types {
1338 match extension_type {
1339 ExtensionType::TransferFeeConfig => transfer_fee_config = true,
1340 ExtensionType::ConfidentialTransferMint => confidential_transfer_mint = true,
1341 ExtensionType::ConfidentialTransferFeeConfig => {
1342 confidential_transfer_fee_config = true
1343 }
1344 ExtensionType::ConfidentialMintBurn => confidential_mint_burn = true,
1345 ExtensionType::InterestBearingConfig => interest_bearing = true,
1346 ExtensionType::ScaledUiAmount => scaled_ui_amount = true,
1347 ExtensionType::NonTransferable => non_transferable = true,
1348 _ => (),
1349 }
1350 }
1351
1352 if confidential_transfer_fee_config && !(transfer_fee_config && confidential_transfer_mint)
1353 {
1354 return Err(TokenError::InvalidExtensionCombination);
1355 }
1356
1357 if transfer_fee_config && confidential_transfer_mint && !confidential_transfer_fee_config {
1358 return Err(TokenError::InvalidExtensionCombination);
1359 }
1360
1361 if confidential_mint_burn && !confidential_transfer_mint {
1362 return Err(TokenError::InvalidExtensionCombination);
1363 }
1364
1365 if scaled_ui_amount && interest_bearing {
1366 return Err(TokenError::InvalidExtensionCombination);
1367 }
1368
1369 if non_transferable && confidential_transfer_mint && !confidential_mint_burn {
1370 return Err(TokenError::InvalidExtensionCombination);
1371 }
1372
1373 Ok(())
1374 }
1375}
1376
/// A base state (token account or mint) that may be followed by TLV extension
/// data in the same account buffer.
///
/// Implemented in this module for `Account`, `Mint`, `PodAccount` and
/// `PodMint`.
pub trait BaseState: PackedSizeOf + IsInitialized {
    /// The `AccountType` discriminator associated with this base state:
    /// `AccountType::Account` for token accounts, `AccountType::Mint` for
    /// mints.
    const ACCOUNT_TYPE: AccountType;
}
// Token account base state (`state::Account`).
impl BaseState for Account {
    const ACCOUNT_TYPE: AccountType = AccountType::Account;
}
// Mint base state (`state::Mint`).
impl BaseState for Mint {
    const ACCOUNT_TYPE: AccountType = AccountType::Mint;
}
// Pod (bytemuck) token account layout; shares the `Account` discriminator.
impl BaseState for PodAccount {
    const ACCOUNT_TYPE: AccountType = AccountType::Account;
}
// Pod (bytemuck) mint layout; shares the `Mint` discriminator.
impl BaseState for PodMint {
    const ACCOUNT_TYPE: AccountType = AccountType::Mint;
}
1394
/// A type stored as an extension entry in an account's TLV data.
pub trait Extension {
    /// The `ExtensionType` variant identifying this extension. Generic
    /// accessors such as `init_extension::<V>` and `get_extension::<V>` are
    /// presumably keyed on this value (NOTE(review): confirm against their
    /// definitions).
    const TYPE: ExtensionType;
}
1401
#[cfg(test)]
/// Test-only mint extension consisting purely of padding bytes (185 in total).
///
/// One use is the `mint_with_multisig_len` test. It expects a mint carrying
/// only this extension to be sized `Multisig::LEN + size_of::<ExtensionType>()`.
/// NOTE(review): the payload seems to be split into three arrays so that each
/// array length has `Pod`/`Zeroable` impls; confirm against the bytemuck
/// version in use.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Pod, Zeroable)]
pub struct MintPaddingTest {
    /// First 128 padding bytes.
    pub padding1: [u8; 128],
    /// Next 48 padding bytes.
    pub padding2: [u8; 48],
    /// Final 9 padding bytes.
    pub padding3: [u8; 9],
}
#[cfg(test)]
// Registers the padding struct under the test-only `MintPaddingTest` tag.
impl Extension for MintPaddingTest {
    const TYPE: ExtensionType = ExtensionType::MintPaddingTest;
}
#[cfg(test)]
impl Default for MintPaddingTest {
    /// Fills `padding1` with 1s, `padding2` with 2s and `padding3` with 3s.
    fn default() -> Self {
        let padding1 = [1; 128];
        let padding2 = [2; 48];
        let padding3 = [3; 9];
        Self {
            padding1,
            padding2,
            padding3,
        }
    }
}
#[cfg(test)]
/// Test-only account extension wrapping the same padding layout as
/// `MintPaddingTest`.
///
/// `Default` is derived, so it reuses `MintPaddingTest::default`.
/// `required_init_account_extensions` maps `ExtensionType::MintPaddingTest` to
/// `ExtensionType::AccountPaddingTest`.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
pub struct AccountPaddingTest(MintPaddingTest);
#[cfg(test)]
// Registers the wrapper under the test-only `AccountPaddingTest` tag.
impl Extension for AccountPaddingTest {
    const TYPE: ExtensionType = ExtensionType::AccountPaddingTest;
}
1444
1445pub fn alloc_and_serialize<S: BaseState + Pod, V: Default + Extension + Pod>(
1460 account_info: &AccountInfo,
1461 new_extension: &V,
1462 overwrite: bool,
1463) -> Result<(), ProgramError> {
1464 let previous_account_len = account_info.try_data_len()?;
1465 let new_account_len = {
1466 let data = account_info.try_borrow_data()?;
1467 let state = PodStateWithExtensions::<S>::unpack(&data)?;
1468 state.try_get_new_account_len::<V>()?
1469 };
1470
1471 if new_account_len > previous_account_len {
1473 account_info.resize(new_account_len)?;
1474 }
1475 let mut buffer = account_info.try_borrow_mut_data()?;
1476 if previous_account_len <= BASE_ACCOUNT_LENGTH {
1477 set_account_type::<S>(*buffer)?;
1478 }
1479 let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
1480
1481 let extension = state.init_extension::<V>(overwrite)?;
1483 *extension = *new_extension;
1484
1485 Ok(())
1486}
1487
/// Packs the variable-length extension `new_extension` into the account behind
/// `account_info`. The account data is resized to the length reported by
/// `try_get_new_account_len_for_variable_len_extension`.
///
/// The order of operations depends on the direction of the size change:
/// - growing: the account is resized first, then the extension is written;
/// - shrinking or unchanged: the extension is written into the current
///   buffer first, then the account is resized.
///
/// An existing entry is rewritten with `realloc_variable_len_extension`. A new
/// entry is created with `init_variable_len_extension`.
///
/// # Errors
///
/// - `TokenError::ExtensionAlreadyInitialized` if an entry for `V` already
///   exists and `overwrite` is `false`. This is checked before any data is
///   modified.
/// - Any error from borrowing, unpacking, resizing, or writing the data.
pub fn alloc_and_serialize_variable_len_extension<
    S: BaseState + Pod,
    V: Extension + VariableLenPack,
>(
    account_info: &AccountInfo,
    new_extension: &V,
    overwrite: bool,
) -> Result<(), ProgramError> {
    let previous_account_len = account_info.try_data_len()?;
    // Compute the target length and probe for an existing entry under a
    // shared borrow. The borrow ends with this block, before any resize or
    // mutable borrow.
    let (new_account_len, extension_already_exists) = {
        let data = account_info.try_borrow_data()?;
        let state = PodStateWithExtensions::<S>::unpack(&data)?;
        let new_account_len =
            state.try_get_new_account_len_for_variable_len_extension(new_extension)?;
        let extension_already_exists = state.get_extension_bytes::<V>().is_ok();
        (new_account_len, extension_already_exists)
    };

    if extension_already_exists && !overwrite {
        return Err(TokenError::ExtensionAlreadyInitialized.into());
    }

    if previous_account_len < new_account_len {
        // Growing: resize first so the write below has room.
        account_info.resize(new_account_len)?;
        let mut buffer = account_info.try_borrow_mut_data()?;
        if extension_already_exists {
            let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
            state.realloc_variable_len_extension(new_extension)?;
        } else {
            // An account of at most BASE_ACCOUNT_LENGTH bytes had no room for
            // the account-type byte, so write it before unpacking.
            if previous_account_len <= BASE_ACCOUNT_LENGTH {
                set_account_type::<S>(*buffer)?;
            }
            let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
            state.init_variable_len_extension(new_extension, false)?;
        }
    } else {
        // Shrinking or same size: write into the current (larger or equal)
        // buffer first, then trim the account.
        let mut buffer = account_info.try_borrow_mut_data()?;
        let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
        if extension_already_exists {
            state.realloc_variable_len_extension(new_extension)?;
        } else {
            state.init_variable_len_extension(new_extension, false)?;
        }

        // In this branch previous_account_len >= new_account_len, so the
        // subtraction cannot underflow. The error path is purely defensive.
        let removed_bytes = previous_account_len
            .checked_sub(new_account_len)
            .ok_or(ProgramError::AccountDataTooSmall)?;
        if removed_bytes > 0 {
            // Release the mutable data borrow before resizing.
            // NOTE(review): `AccountInfo::resize` is believed to take its own
            // mutable borrow of the data, which would conflict with `buffer`.
            // Confirm this against solana-account-info.
            drop(buffer);
            account_info.resize(new_account_len)?;
        }
    }
    Ok(())
}
1556
1557#[cfg(test)]
1558mod test {
1559 use {
1560 super::*,
1561 crate::{
1562 pod::test::{TEST_POD_ACCOUNT, TEST_POD_MINT},
1563 state::test::{TEST_ACCOUNT_SLICE, TEST_MINT_SLICE},
1564 },
1565 bytemuck::Pod,
1566 solana_account_info::{
1567 Account as GetAccount, IntoAccountInfo, MAX_PERMITTED_DATA_INCREASE,
1568 },
1569 solana_address::Address,
1570 solana_nullable::MaybeNull,
1571 solana_zero_copy::unaligned::{Bool, U64},
1572 transfer_fee::test::test_transfer_fee_config,
1573 };
1574
1575 #[repr(C)]
1577 #[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
1578 struct FixedLenMintTest {
1579 data: [u8; 8],
1580 }
1581 impl Extension for FixedLenMintTest {
1582 const TYPE: ExtensionType = ExtensionType::MintPaddingTest;
1583 }
1584
1585 #[derive(Clone, Debug, PartialEq)]
1587 struct VariableLenMintTest {
1588 data: Vec<u8>,
1589 }
1590 impl Extension for VariableLenMintTest {
1591 const TYPE: ExtensionType = ExtensionType::VariableLenMintTest;
1592 }
1593 impl VariableLenPack for VariableLenMintTest {
1594 fn pack_into_slice(&self, dst: &mut [u8]) -> Result<(), ProgramError> {
1595 let data_start = size_of::<u64>();
1596 let end = data_start + self.data.len();
1597 if dst.len() < end {
1598 Err(ProgramError::InvalidAccountData)
1599 } else {
1600 dst[..data_start].copy_from_slice(&self.data.len().to_le_bytes());
1601 dst[data_start..end].copy_from_slice(&self.data);
1602 Ok(())
1603 }
1604 }
1605 fn unpack_from_slice(src: &[u8]) -> Result<Self, ProgramError> {
1606 let data_start = size_of::<u64>();
1607 let length = u64::from_le_bytes(src[..data_start].try_into().unwrap()) as usize;
1608 if src[data_start..data_start + length].len() != length {
1609 return Err(ProgramError::InvalidAccountData);
1610 }
1611 let data = Vec::from(&src[data_start..data_start + length]);
1612 Ok(Self { data })
1613 }
1614 fn get_packed_len(&self) -> Result<usize, ProgramError> {
1615 Ok(size_of::<u64>().saturating_add(self.data.len()))
1616 }
1617 }
1618
1619 const MINT_WITH_ACCOUNT_TYPE: &[u8] = &[
1620 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1621 1, 1, 1, 1, 1, 1, 42, 0, 0, 0, 0, 0, 0, 0, 7, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
1622 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1624 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1625 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ];
1628
1629 const MINT_WITH_EXTENSION: &[u8] = &[
1630 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1631 1, 1, 1, 1, 1, 1, 42, 0, 0, 0, 0, 0, 0, 0, 7, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
1632 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1634 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1635 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 3, 0, 32, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1640 1, 1, ];
1642
1643 const ACCOUNT_WITH_EXTENSION: &[u8] = &[
1644 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1645 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
1647 2, 2, 3, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
1650 4, 4, 4, 4, 4, 4, 2, 1, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
1655 7, 7, 7, 7, 7, 7, 2, 15, 0, 1, 0, 1, ];
1661
1662 #[test]
1663 fn unpack_opaque_buffer() {
1664 let state = PodStateWithExtensions::<PodMint>::unpack(MINT_WITH_ACCOUNT_TYPE).unwrap();
1666 assert_eq!(state.base, &TEST_POD_MINT);
1667 let state = PodStateWithExtensions::<PodMint>::unpack(MINT_WITH_EXTENSION).unwrap();
1668 assert_eq!(state.base, &TEST_POD_MINT);
1669 let extension = state.get_extension::<MintCloseAuthority>().unwrap();
1670 let close_authority: MaybeNull<Address> =
1671 Some(Address::new_from_array([1; 32])).try_into().unwrap();
1672 assert_eq!(extension.close_authority, close_authority);
1673 assert_eq!(
1674 state.get_extension::<TransferFeeConfig>(),
1675 Err(ProgramError::InvalidAccountData)
1676 );
1677 assert_eq!(
1678 PodStateWithExtensions::<PodAccount>::unpack(MINT_WITH_EXTENSION),
1679 Err(ProgramError::UninitializedAccount)
1680 );
1681
1682 let state = PodStateWithExtensions::<PodMint>::unpack(TEST_MINT_SLICE).unwrap();
1683 assert_eq!(state.base, &TEST_POD_MINT);
1684
1685 let mut test_mint = TEST_MINT_SLICE.to_vec();
1686 let state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut test_mint).unwrap();
1687 assert_eq!(state.base, &TEST_POD_MINT);
1688
1689 let state = PodStateWithExtensions::<PodAccount>::unpack(ACCOUNT_WITH_EXTENSION).unwrap();
1691 assert_eq!(state.base, &TEST_POD_ACCOUNT);
1692 let extension = state.get_extension::<TransferHookAccount>().unwrap();
1693 let transferring = Bool::from(true);
1694 assert_eq!(extension.transferring, transferring);
1695 assert_eq!(
1696 PodStateWithExtensions::<PodMint>::unpack(ACCOUNT_WITH_EXTENSION),
1697 Err(ProgramError::InvalidAccountData)
1698 );
1699
1700 let state = PodStateWithExtensions::<PodAccount>::unpack(TEST_ACCOUNT_SLICE).unwrap();
1701 assert_eq!(state.base, &TEST_POD_ACCOUNT);
1702
1703 let mut test_account = TEST_ACCOUNT_SLICE.to_vec();
1704 let state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut test_account).unwrap();
1705 assert_eq!(state.base, &TEST_POD_ACCOUNT);
1706 }
1707
1708 #[test]
1709 fn mint_fail_unpack_opaque_buffer() {
1710 let mut buffer = vec![0, 3];
1712 assert_eq!(
1713 PodStateWithExtensions::<PodMint>::unpack(&buffer),
1714 Err(ProgramError::InvalidAccountData)
1715 );
1716 assert_eq!(
1717 PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer),
1718 Err(ProgramError::InvalidAccountData)
1719 );
1720 assert_eq!(
1721 PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer),
1722 Err(ProgramError::InvalidAccountData)
1723 );
1724
1725 let mut buffer = MINT_WITH_EXTENSION.to_vec();
1727 buffer[BASE_ACCOUNT_LENGTH] = 3;
1728 assert_eq!(
1729 PodStateWithExtensions::<PodMint>::unpack(&buffer),
1730 Err(ProgramError::InvalidAccountData)
1731 );
1732
1733 let mut buffer = MINT_WITH_EXTENSION.to_vec();
1735 buffer[45] = 0;
1736 assert_eq!(
1737 PodStateWithExtensions::<PodMint>::unpack(&buffer),
1738 Err(ProgramError::UninitializedAccount)
1739 );
1740
1741 let mut buffer = MINT_WITH_EXTENSION.to_vec();
1743 buffer[PodMint::SIZE_OF] = 100;
1744 assert_eq!(
1745 PodStateWithExtensions::<PodMint>::unpack(&buffer),
1746 Err(ProgramError::InvalidAccountData)
1747 );
1748
1749 let mut buffer = MINT_WITH_EXTENSION.to_vec();
1751 buffer[BASE_ACCOUNT_LENGTH + 1] = 2;
1752 let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
1753 assert_eq!(
1754 state.get_extension::<TransferFeeConfig>(),
1755 Err(ProgramError::InvalidAccountData)
1756 );
1757
1758 let mut buffer = MINT_WITH_EXTENSION.to_vec();
1760 buffer[BASE_ACCOUNT_LENGTH + 3] = 100;
1761 let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
1762 assert_eq!(
1763 state.get_extension::<TransferFeeConfig>(),
1764 Err(ProgramError::InvalidAccountData)
1765 );
1766
1767 let mut buffer = MINT_WITH_EXTENSION.to_vec();
1769 buffer[BASE_ACCOUNT_LENGTH + 3] = 10;
1770 let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
1771 assert_eq!(
1772 state.get_extension::<TransferFeeConfig>(),
1773 Err(ProgramError::InvalidAccountData)
1774 );
1775
1776 let buffer = &MINT_WITH_EXTENSION[..MINT_WITH_EXTENSION.len() - 1];
1778 let state = PodStateWithExtensions::<PodMint>::unpack(buffer).unwrap();
1779 assert_eq!(
1780 state.get_extension::<MintCloseAuthority>(),
1781 Err(ProgramError::InvalidAccountData)
1782 );
1783 }
1784
1785 #[test]
1786 fn account_fail_unpack_opaque_buffer() {
1787 let mut buffer = vec![0, 3];
1789 assert_eq!(
1790 PodStateWithExtensions::<PodAccount>::unpack(&buffer),
1791 Err(ProgramError::InvalidAccountData)
1792 );
1793 assert_eq!(
1794 PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer),
1795 Err(ProgramError::InvalidAccountData)
1796 );
1797 assert_eq!(
1798 PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
1799 Err(ProgramError::InvalidAccountData)
1800 );
1801
1802 let mut buffer = vec![5; BASE_ACCOUNT_LENGTH];
1805 assert_eq!(
1806 PodStateWithExtensions::<PodAccount>::unpack(&buffer),
1807 Err(ProgramError::UninitializedAccount)
1808 );
1809 assert_eq!(
1810 PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer),
1811 Err(ProgramError::UninitializedAccount)
1812 );
1813
1814 let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
1816 buffer[BASE_ACCOUNT_LENGTH] = 3;
1817 assert_eq!(
1818 PodStateWithExtensions::<PodAccount>::unpack(&buffer),
1819 Err(ProgramError::InvalidAccountData)
1820 );
1821
1822 let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
1824 buffer[108] = 0;
1825 assert_eq!(
1826 PodStateWithExtensions::<PodAccount>::unpack(&buffer),
1827 Err(ProgramError::UninitializedAccount)
1828 );
1829
1830 let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
1832 buffer[BASE_ACCOUNT_LENGTH + 1] = 12;
1833 let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
1834 assert_eq!(
1835 state.get_extension::<TransferHookAccount>(),
1836 Err(ProgramError::InvalidAccountData),
1837 );
1838
1839 let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
1841 buffer[BASE_ACCOUNT_LENGTH + 3] = 100;
1842 let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
1843 assert_eq!(
1844 state.get_extension::<TransferHookAccount>(),
1845 Err(ProgramError::InvalidAccountData)
1846 );
1847
1848 let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
1850 buffer[BASE_ACCOUNT_LENGTH + 3] = 10;
1851 let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
1852 assert_eq!(
1853 state.get_extension::<TransferHookAccount>(),
1854 Err(ProgramError::InvalidAccountData)
1855 );
1856
1857 let buffer = &ACCOUNT_WITH_EXTENSION[..ACCOUNT_WITH_EXTENSION.len() - 1];
1859 let state = PodStateWithExtensions::<PodAccount>::unpack(buffer).unwrap();
1860 assert_eq!(
1861 state.get_extension::<TransferHookAccount>(),
1862 Err(ProgramError::InvalidAccountData)
1863 );
1864 }
1865
1866 #[test]
1867 fn get_extension_types_with_opaque_buffer() {
1868 assert_eq!(
1870 get_tlv_data_info(&[1, 0, 1, 1]).unwrap_err(),
1871 ProgramError::InvalidAccountData,
1872 );
1873 assert_eq!(
1875 get_tlv_data_info(&[0, 1, 0, 0]).unwrap_err(),
1876 ProgramError::InvalidAccountData,
1877 );
1878 assert_eq!(
1880 get_tlv_data_info(&[1, 0, 0, 0]).unwrap(),
1881 TlvDataInfo {
1882 extension_types: vec![ExtensionType::try_from(1).unwrap()],
1883 used_len: add_type_and_length_to_len(0),
1884 }
1885 );
1886 assert_eq!(
1888 get_tlv_data_info(&[0, 0]).unwrap(),
1889 TlvDataInfo {
1890 extension_types: vec![],
1891 used_len: 0
1892 }
1893 );
1894 }
1895
1896 #[test]
1897 fn mint_with_extension_pack_unpack() {
1898 let mint_size = ExtensionType::try_calculate_account_len::<PodMint>(&[
1899 ExtensionType::MintCloseAuthority,
1900 ExtensionType::TransferFeeConfig,
1901 ])
1902 .unwrap();
1903 let mut buffer = vec![0; mint_size];
1904
1905 assert_eq!(
1907 PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer),
1908 Err(ProgramError::UninitializedAccount),
1909 );
1910
1911 let mut state =
1912 PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
1913 assert_eq!(
1915 state.init_extension::<TransferFeeAmount>(true),
1916 Err(ProgramError::InvalidAccountData),
1917 );
1918
1919 let close_authority: MaybeNull<Address> =
1921 Some(Address::new_from_array([1; 32])).try_into().unwrap();
1922 let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
1923 extension.close_authority = close_authority;
1924 assert_eq!(
1925 &state.get_extension_types().unwrap(),
1926 &[ExtensionType::MintCloseAuthority]
1927 );
1928
1929 assert_eq!(
1931 state.init_extension::<MintCloseAuthority>(false),
1932 Err(ProgramError::Custom(
1933 TokenError::ExtensionAlreadyInitialized as u32
1934 ))
1935 );
1936
1937 assert_eq!(
1939 PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
1940 Err(ProgramError::Custom(
1941 TokenError::ExtensionBaseMismatch as u32
1942 ))
1943 );
1944
1945 assert_eq!(
1947 PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer.clone()),
1948 Err(ProgramError::UninitializedAccount),
1949 );
1950
1951 let mut state =
1953 PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
1954 *state.base = TEST_POD_MINT;
1955 state.init_account_type().unwrap();
1956
1957 let mut expect = TEST_MINT_SLICE.to_vec();
1959 expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); expect.push(AccountType::Mint.into());
1961 expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
1962 expect.extend_from_slice(&(size_of::<MintCloseAuthority>() as u16).to_le_bytes());
1963 expect.extend_from_slice(&[1; 32]); expect.extend_from_slice(&[0; size_of::<ExtensionType>()]);
1965 expect.extend_from_slice(&[0; size_of::<Length>()]);
1966 expect.extend_from_slice(&[0; size_of::<TransferFeeConfig>()]);
1967 assert_eq!(expect, buffer);
1968
1969 assert_eq!(
1971 PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer.clone()),
1972 Err(TokenError::AlreadyInUse.into()),
1973 );
1974
1975 let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
1977
1978 *state.base = TEST_POD_MINT;
1980 state.base.supply = (u64::from(state.base.supply) + 100).into();
1981
1982 let unpacked_extension = state.get_extension_mut::<MintCloseAuthority>().unwrap();
1984 assert_eq!(*unpacked_extension, MintCloseAuthority { close_authority });
1985
1986 let close_authority: MaybeNull<Address> = None.try_into().unwrap();
1988 unpacked_extension.close_authority = close_authority;
1989
1990 let base = *state.base;
1992 let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
1993 assert_eq!(state.base, &base);
1994 let unpacked_extension = state.get_extension::<MintCloseAuthority>().unwrap();
1995 assert_eq!(*unpacked_extension, MintCloseAuthority { close_authority });
1996
1997 let mut expect = vec![];
1999 expect.extend_from_slice(bytemuck::bytes_of(&base));
2000 expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); expect.push(AccountType::Mint.into());
2002 expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
2003 expect.extend_from_slice(&(size_of::<MintCloseAuthority>() as u16).to_le_bytes());
2004 expect.extend_from_slice(&[0; 32]);
2005 expect.extend_from_slice(&[0; size_of::<ExtensionType>()]);
2006 expect.extend_from_slice(&[0; size_of::<Length>()]);
2007 expect.extend_from_slice(&[0; size_of::<TransferFeeConfig>()]);
2008 assert_eq!(expect, buffer);
2009
2010 assert_eq!(
2012 PodStateWithExtensions::<PodAccount>::unpack(&buffer),
2013 Err(ProgramError::UninitializedAccount),
2014 );
2015
2016 let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
2017 let mint_transfer_fee = test_transfer_fee_config();
2019 let new_extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
2020 new_extension.transfer_fee_config_authority =
2021 mint_transfer_fee.transfer_fee_config_authority;
2022 new_extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
2023 new_extension.withheld_amount = mint_transfer_fee.withheld_amount;
2024 new_extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
2025 new_extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;
2026
2027 assert_eq!(
2028 &state.get_extension_types().unwrap(),
2029 &[
2030 ExtensionType::MintCloseAuthority,
2031 ExtensionType::TransferFeeConfig
2032 ]
2033 );
2034
2035 let mut expect = vec![];
2037 expect.extend_from_slice(bytemuck::bytes_of(&base));
2038 expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); expect.push(AccountType::Mint.into());
2040 expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
2041 expect.extend_from_slice(&(size_of::<MintCloseAuthority>() as u16).to_le_bytes());
2042 expect.extend_from_slice(&[0; 32]); expect.extend_from_slice(&(ExtensionType::TransferFeeConfig as u16).to_le_bytes());
2044 expect.extend_from_slice(&(size_of::<TransferFeeConfig>() as u16).to_le_bytes());
2045 expect.extend_from_slice(bytemuck::bytes_of(&mint_transfer_fee));
2046 assert_eq!(expect, buffer);
2047
2048 let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
2050 assert_eq!(
2051 state.init_extension::<MintPaddingTest>(true),
2052 Err(ProgramError::InvalidAccountData),
2053 );
2054 }
2055
2056 #[test]
2057 fn mint_extension_any_order() {
2058 let mint_size = ExtensionType::try_calculate_account_len::<PodMint>(&[
2059 ExtensionType::MintCloseAuthority,
2060 ExtensionType::TransferFeeConfig,
2061 ])
2062 .unwrap();
2063 let mut buffer = vec![0; mint_size];
2064
2065 let mut state =
2066 PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2067 let close_authority: MaybeNull<Address> =
2069 Some(Address::new_from_array([1; 32])).try_into().unwrap();
2070 let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
2071 extension.close_authority = close_authority;
2072
2073 let mint_transfer_fee = test_transfer_fee_config();
2074 let extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
2075 extension.transfer_fee_config_authority = mint_transfer_fee.transfer_fee_config_authority;
2076 extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
2077 extension.withheld_amount = mint_transfer_fee.withheld_amount;
2078 extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
2079 extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;
2080
2081 assert_eq!(
2082 &state.get_extension_types().unwrap(),
2083 &[
2084 ExtensionType::MintCloseAuthority,
2085 ExtensionType::TransferFeeConfig
2086 ]
2087 );
2088
2089 let mut state =
2091 PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2092 *state.base = TEST_POD_MINT;
2093 state.init_account_type().unwrap();
2094
2095 let mut other_buffer = vec![0; mint_size];
2096 let mut state =
2097 PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut other_buffer).unwrap();
2098
2099 *state.base = TEST_POD_MINT;
2101 state.init_account_type().unwrap();
2102
2103 let mint_transfer_fee = test_transfer_fee_config();
2105 let extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
2106 extension.transfer_fee_config_authority = mint_transfer_fee.transfer_fee_config_authority;
2107 extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
2108 extension.withheld_amount = mint_transfer_fee.withheld_amount;
2109 extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
2110 extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;
2111
2112 let close_authority: MaybeNull<Address> =
2113 Some(Address::new_from_array([1; 32])).try_into().unwrap();
2114 let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
2115 extension.close_authority = close_authority;
2116
2117 assert_eq!(
2118 &state.get_extension_types().unwrap(),
2119 &[
2120 ExtensionType::TransferFeeConfig,
2121 ExtensionType::MintCloseAuthority
2122 ]
2123 );
2124
2125 assert_ne!(buffer, other_buffer);
2127 let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
2128 let other_state = PodStateWithExtensions::<PodMint>::unpack(&other_buffer).unwrap();
2129
2130 assert_eq!(
2132 state.get_extension::<TransferFeeConfig>().unwrap(),
2133 other_state.get_extension::<TransferFeeConfig>().unwrap()
2134 );
2135 assert_eq!(
2136 state.get_extension::<MintCloseAuthority>().unwrap(),
2137 other_state.get_extension::<MintCloseAuthority>().unwrap()
2138 );
2139 assert_eq!(state.base, other_state.base);
2140 }
2141
2142 #[test]
2143 fn mint_with_multisig_len() {
2144 let mut buffer = vec![0; Multisig::LEN];
2145 assert_eq!(
2146 PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer),
2147 Err(ProgramError::InvalidAccountData),
2148 );
2149 let mint_size =
2150 ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MintPaddingTest])
2151 .unwrap();
2152 assert_eq!(mint_size, Multisig::LEN + size_of::<ExtensionType>());
2153 let mut buffer = vec![0; mint_size];
2154
2155 let mut state =
2157 PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2158 *state.base = TEST_POD_MINT;
2159 state.init_account_type().unwrap();
2160
2161 let extension = state.init_extension::<MintPaddingTest>(true).unwrap();
2163 extension.padding1 = [1; 128];
2164 extension.padding2 = [1; 48];
2165 extension.padding3 = [1; 9];
2166
2167 assert_eq!(
2168 &state.get_extension_types().unwrap(),
2169 &[ExtensionType::MintPaddingTest]
2170 );
2171
2172 let mut expect = TEST_MINT_SLICE.to_vec();
2174 expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); expect.push(AccountType::Mint.into());
2176 expect.extend_from_slice(&(ExtensionType::MintPaddingTest as u16).to_le_bytes());
2177 expect.extend_from_slice(&(size_of::<MintPaddingTest>() as u16).to_le_bytes());
2178 expect.extend_from_slice(&vec![1; size_of::<MintPaddingTest>()]);
2179 expect.extend_from_slice(&(ExtensionType::Uninitialized as u16).to_le_bytes());
2180 assert_eq!(expect, buffer);
2181 }
2182
    /// Round-trips an account carrying a single `TransferFeeAmount` extension:
    /// unpacking fails while the base is unset, the extension can be written
    /// before the base, the raw byte layout is checked, and edits made through a
    /// mutable view are observed through a later immutable view of `buffer`.
    #[test]
    fn account_with_extension_pack_unpack() {
        let account_size = ExtensionType::try_calculate_account_len::<PodAccount>(&[
            ExtensionType::TransferFeeAmount,
        ])
        .unwrap();
        let mut buffer = vec![0; account_size];

        // All-zero data: the base account is not initialized, so a full unpack fails.
        assert_eq!(
            PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer),
            Err(ProgramError::UninitializedAccount),
        );

        let mut state =
            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
        // A mint extension cannot be initialized on an account.
        assert_eq!(
            state.init_extension::<TransferFeeConfig>(true),
            Err(ProgramError::InvalidAccountData),
        );
        // The account-side extension can be written while the base is still empty.
        let withheld_amount = U64::from(u64::MAX);
        let extension = state.init_extension::<TransferFeeAmount>(true).unwrap();
        extension.withheld_amount = withheld_amount;

        assert_eq!(
            &state.get_extension_types().unwrap(),
            &[ExtensionType::TransferFeeAmount]
        );

        // Still no base data, so a full unpack (of a copy) keeps failing.
        assert_eq!(
            PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer.clone()),
            Err(ProgramError::UninitializedAccount),
        );

        // Write the base account and stamp the account type.
        let mut state =
            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
        *state.base = TEST_POD_ACCOUNT;
        state.init_account_type().unwrap();
        let base = *state.base;

        // Expected raw bytes: base account, account-type byte, then one TLV entry
        // (extension type as u16 LE, value length as u16 LE, value).
        let mut expect = TEST_ACCOUNT_SLICE.to_vec();
        expect.push(AccountType::Account.into());
        expect.extend_from_slice(&(ExtensionType::TransferFeeAmount as u16).to_le_bytes());
        expect.extend_from_slice(&(size_of::<TransferFeeAmount>() as u16).to_le_bytes());
        expect.extend_from_slice(&u64::from(withheld_amount).to_le_bytes());
        assert_eq!(expect, buffer);

        // A full unpack now succeeds and sees the same base and extension list.
        let mut state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
        assert_eq!(state.base, &base);
        assert_eq!(
            &state.get_extension_types().unwrap(),
            &[ExtensionType::TransferFeeAmount]
        );

        // Modify the base through the mutable view (bump `amount` by 100).
        *state.base = TEST_POD_ACCOUNT;
        state.base.amount = (u64::from(state.base.amount) + 100).into();

        // The extension still holds the value written earlier.
        let unpacked_extension = state.get_extension_mut::<TransferFeeAmount>().unwrap();
        assert_eq!(*unpacked_extension, TransferFeeAmount { withheld_amount });

        // Modify the extension in place.
        let withheld_amount = U64::from(u32::MAX as u64);
        unpacked_extension.withheld_amount = withheld_amount;

        // An immutable unpack of the same bytes observes both modifications.
        let base = *state.base;
        let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
        assert_eq!(state.base, &base);
        let unpacked_extension = state.get_extension::<TransferFeeAmount>().unwrap();
        assert_eq!(*unpacked_extension, TransferFeeAmount { withheld_amount });

        // The raw layout reflects the updated base and withheld amount.
        let mut expect = vec![];
        expect.extend_from_slice(bytemuck::bytes_of(&base));
        expect.push(AccountType::Account.into());
        expect.extend_from_slice(&(ExtensionType::TransferFeeAmount as u16).to_le_bytes());
        expect.extend_from_slice(&(size_of::<TransferFeeAmount>() as u16).to_le_bytes());
        expect.extend_from_slice(&u64::from(withheld_amount).to_le_bytes());
        assert_eq!(expect, buffer);

        // Account-typed data must not unpack as a mint.
        assert_eq!(
            PodStateWithExtensions::<PodMint>::unpack(&buffer),
            Err(ProgramError::InvalidAccountData),
        );
    }
2277
2278 #[test]
2279 fn account_with_multisig_len() {
2280 let mut buffer = vec![0; Multisig::LEN];
2281 assert_eq!(
2282 PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
2283 Err(ProgramError::InvalidAccountData),
2284 );
2285 let account_size = ExtensionType::try_calculate_account_len::<PodAccount>(&[
2286 ExtensionType::AccountPaddingTest,
2287 ])
2288 .unwrap();
2289 assert_eq!(account_size, Multisig::LEN + size_of::<ExtensionType>());
2290 let mut buffer = vec![0; account_size];
2291
2292 let mut state =
2294 PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
2295 *state.base = TEST_POD_ACCOUNT;
2296 state.init_account_type().unwrap();
2297
2298 let extension = state.init_extension::<AccountPaddingTest>(true).unwrap();
2300 extension.0.padding1 = [2; 128];
2301 extension.0.padding2 = [2; 48];
2302 extension.0.padding3 = [2; 9];
2303
2304 assert_eq!(
2305 &state.get_extension_types().unwrap(),
2306 &[ExtensionType::AccountPaddingTest]
2307 );
2308
2309 let mut expect = TEST_ACCOUNT_SLICE.to_vec();
2311 expect.push(AccountType::Account.into());
2312 expect.extend_from_slice(&(ExtensionType::AccountPaddingTest as u16).to_le_bytes());
2313 expect.extend_from_slice(&(size_of::<AccountPaddingTest>() as u16).to_le_bytes());
2314 expect.extend_from_slice(&vec![2; size_of::<AccountPaddingTest>()]);
2315 expect.extend_from_slice(&(ExtensionType::Uninitialized as u16).to_le_bytes());
2316 assert_eq!(expect, buffer);
2317 }
2318
2319 #[test]
2320 fn test_set_account_type() {
2321 let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
2323 let needed_len = ExtensionType::try_calculate_account_len::<PodAccount>(&[
2324 ExtensionType::ImmutableOwner,
2325 ])
2326 .unwrap()
2327 - buffer.len();
2328 buffer.append(&mut vec![0; needed_len]);
2329 let err = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap_err();
2330 assert_eq!(err, ProgramError::InvalidAccountData);
2331 set_account_type::<PodAccount>(&mut buffer).unwrap();
2332 let mut state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
2334 assert_eq!(state.base, &TEST_POD_ACCOUNT);
2335 assert_eq!(state.account_type[0], AccountType::Account as u8);
2336 state.init_extension::<ImmutableOwner>(true).unwrap(); let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
2340 buffer.append(&mut vec![0; 2]);
2341 let err = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap_err();
2342 assert_eq!(err, ProgramError::InvalidAccountData);
2343 set_account_type::<PodAccount>(&mut buffer).unwrap();
2344 let state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
2346 assert_eq!(state.base, &TEST_POD_ACCOUNT);
2347 assert_eq!(state.account_type[0], AccountType::Account as u8);
2348
2349 let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
2351 buffer.append(&mut vec![2, 0]);
2352 let _ = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
2353 set_account_type::<PodAccount>(&mut buffer).unwrap();
2354 let state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
2355 assert_eq!(state.base, &TEST_POD_ACCOUNT);
2356 assert_eq!(state.account_type[0], AccountType::Account as u8);
2357
2358 let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
2360 buffer.append(&mut vec![1, 0]);
2361 let err = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap_err();
2362 assert_eq!(err, ProgramError::InvalidAccountData);
2363 let err = set_account_type::<PodAccount>(&mut buffer).unwrap_err();
2364 assert_eq!(err, ProgramError::InvalidAccountData);
2365
2366 let mut buffer = TEST_MINT_SLICE.to_vec();
2368 let needed_len = ExtensionType::try_calculate_account_len::<PodMint>(&[
2369 ExtensionType::MintCloseAuthority,
2370 ])
2371 .unwrap()
2372 - buffer.len();
2373 buffer.append(&mut vec![0; needed_len]);
2374 let err = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap_err();
2375 assert_eq!(err, ProgramError::InvalidAccountData);
2376 set_account_type::<PodMint>(&mut buffer).unwrap();
2377 let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
2379 assert_eq!(state.base, &TEST_POD_MINT);
2380 assert_eq!(state.account_type[0], AccountType::Mint as u8);
2381 state.init_extension::<MintCloseAuthority>(true).unwrap();
2382
2383 let mut buffer = TEST_MINT_SLICE.to_vec();
2385 buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
2386 buffer.append(&mut vec![0; 2]);
2387 let err = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap_err();
2388 assert_eq!(err, ProgramError::InvalidAccountData);
2389 set_account_type::<PodMint>(&mut buffer).unwrap();
2390 let state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
2392 assert_eq!(state.base, &TEST_POD_MINT);
2393 assert_eq!(state.account_type[0], AccountType::Mint as u8);
2394
2395 let mut buffer = TEST_MINT_SLICE.to_vec();
2397 buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
2398 buffer.append(&mut vec![1, 0]);
2399 set_account_type::<PodMint>(&mut buffer).unwrap();
2400 let state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
2401 assert_eq!(state.base, &TEST_POD_MINT);
2402 assert_eq!(state.account_type[0], AccountType::Mint as u8);
2403
2404 let mut buffer = TEST_MINT_SLICE.to_vec();
2406 buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
2407 buffer.append(&mut vec![2, 0]);
2408 let err = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap_err();
2409 assert_eq!(err, ProgramError::InvalidAccountData);
2410 let err = set_account_type::<PodMint>(&mut buffer).unwrap_err();
2411 assert_eq!(err, ProgramError::InvalidAccountData);
2412 }
2413
2414 #[test]
2415 fn test_set_account_type_wrongly() {
2416 let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
2418 buffer.append(&mut vec![0; 2]);
2419 let err = set_account_type::<PodMint>(&mut buffer).unwrap_err();
2420 assert_eq!(err, ProgramError::InvalidAccountData);
2421
2422 let mut buffer = TEST_MINT_SLICE.to_vec();
2424 buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
2425 buffer.append(&mut vec![0; 2]);
2426 let err = set_account_type::<PodAccount>(&mut buffer).unwrap_err();
2427 assert_eq!(err, ProgramError::InvalidAccountData);
2428 }
2429
2430 #[test]
2431 #[allow(deprecated)]
2432 fn test_get_required_init_account_extensions() {
2433 let mint_extensions = vec![
2435 ExtensionType::MintCloseAuthority,
2436 ExtensionType::Uninitialized,
2437 ];
2438 assert_eq!(
2439 ExtensionType::get_required_init_account_extensions(&mint_extensions),
2440 vec![]
2441 );
2442
2443 let mint_extensions = vec![
2445 ExtensionType::TransferFeeConfig,
2446 ExtensionType::MintCloseAuthority,
2447 ];
2448 assert_eq!(
2449 ExtensionType::get_required_init_account_extensions(&mint_extensions),
2450 vec![ExtensionType::TransferFeeAmount]
2451 );
2452
2453 let mint_extensions = vec![
2455 ExtensionType::TransferFeeConfig,
2456 ExtensionType::MintPaddingTest,
2457 ];
2458 assert_eq!(
2459 ExtensionType::get_required_init_account_extensions(&mint_extensions),
2460 vec![
2461 ExtensionType::TransferFeeAmount,
2462 ExtensionType::AccountPaddingTest
2463 ]
2464 );
2465
2466 let mint_extensions = vec![
2468 ExtensionType::TransferFeeConfig,
2469 ExtensionType::TransferFeeConfig,
2470 ];
2471 assert_eq!(
2472 ExtensionType::get_required_init_account_extensions(&mint_extensions),
2473 vec![
2474 ExtensionType::TransferFeeAmount,
2475 ExtensionType::TransferFeeAmount
2476 ]
2477 );
2478 }
2479
2480 #[test]
2481 fn mint_without_extensions() {
2482 let space = ExtensionType::try_calculate_account_len::<PodMint>(&[]).unwrap();
2483 let mut buffer = vec![0; space];
2484 assert_eq!(
2485 PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
2486 Err(ProgramError::InvalidAccountData),
2487 );
2488
2489 let mut state =
2491 PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2492 *state.base = TEST_POD_MINT;
2493 state.init_account_type().unwrap();
2494
2495 assert_eq!(
2497 state.init_extension::<TransferFeeConfig>(true),
2498 Err(ProgramError::InvalidAccountData),
2499 );
2500
2501 assert_eq!(TEST_MINT_SLICE, buffer);
2502 }
2503
2504 #[test]
2505 fn test_init_nonzero_default() {
2506 let mint_size =
2507 ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MintPaddingTest])
2508 .unwrap();
2509 let mut buffer = vec![0; mint_size];
2510 let mut state =
2511 PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2512 *state.base = TEST_POD_MINT;
2513 state.init_account_type().unwrap();
2514 let extension = state.init_extension::<MintPaddingTest>(true).unwrap();
2515 assert_eq!(extension.padding1, [1; 128]);
2516 assert_eq!(extension.padding2, [2; 48]);
2517 assert_eq!(extension.padding3, [3; 9]);
2518 }
2519
2520 #[test]
2521 fn test_init_buffer_too_small() {
2522 let mint_size = ExtensionType::try_calculate_account_len::<PodMint>(&[
2523 ExtensionType::MintCloseAuthority,
2524 ])
2525 .unwrap();
2526 let mut buffer = vec![0; mint_size - 1];
2527 let mut state =
2528 PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2529 let err = state
2530 .init_extension::<MintCloseAuthority>(true)
2531 .unwrap_err();
2532 assert_eq!(err, ProgramError::InvalidAccountData);
2533
2534 state.tlv_data[0] = 3;
2535 state.tlv_data[2] = 32;
2536 let err = state.get_extension_mut::<MintCloseAuthority>().unwrap_err();
2537 assert_eq!(err, ProgramError::InvalidAccountData);
2538
2539 let mut buffer = vec![0; PodMint::SIZE_OF + 2];
2540 let err =
2541 PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap_err();
2542 assert_eq!(err, ProgramError::InvalidAccountData);
2543
2544 let mut buffer = vec![0; BASE_ACCOUNT_LENGTH + 3];
2546 let mut state =
2547 PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2548 let err = state.get_extension_mut::<MintCloseAuthority>().unwrap_err();
2549 assert_eq!(err, ProgramError::InvalidAccountData);
2550
2551 assert_eq!(state.get_extension_types().unwrap(), vec![]);
2552
2553 let mut buffer = vec![0; BASE_ACCOUNT_LENGTH + 2];
2555 let state =
2556 PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2557 assert_eq!(state.get_extension_types().unwrap(), []);
2558 }
2559
2560 #[test]
2561 fn test_extension_with_no_data() {
2562 let account_size = ExtensionType::try_calculate_account_len::<PodAccount>(&[
2563 ExtensionType::ImmutableOwner,
2564 ])
2565 .unwrap();
2566 let mut buffer = vec![0; account_size];
2567 let mut state =
2568 PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
2569 *state.base = TEST_POD_ACCOUNT;
2570 state.init_account_type().unwrap();
2571
2572 let err = state.get_extension::<ImmutableOwner>().unwrap_err();
2573 assert_eq!(
2574 err,
2575 ProgramError::Custom(TokenError::ExtensionNotFound as u32)
2576 );
2577
2578 state.init_extension::<ImmutableOwner>(true).unwrap();
2579 assert_eq!(
2580 get_first_extension_type(state.tlv_data).unwrap(),
2581 Some(ExtensionType::ImmutableOwner)
2582 );
2583 assert_eq!(
2584 get_tlv_data_info(state.tlv_data).unwrap(),
2585 TlvDataInfo {
2586 extension_types: vec![ExtensionType::ImmutableOwner],
2587 used_len: add_type_and_length_to_len(0)
2588 }
2589 );
2590 }
2591
2592 #[test]
2593 fn fail_account_len_with_metadata() {
2594 assert_eq!(
2595 ExtensionType::try_calculate_account_len::<PodMint>(&[
2596 ExtensionType::MintCloseAuthority,
2597 ExtensionType::VariableLenMintTest,
2598 ExtensionType::TransferFeeConfig,
2599 ])
2600 .unwrap_err(),
2601 ProgramError::InvalidArgument
2602 );
2603 }
2604
2605 #[test]
2606 fn alloc() {
2607 let variable_len = VariableLenMintTest { data: vec![1] };
2608 let alloc_size = variable_len.get_packed_len().unwrap();
2609 let account_size =
2610 BASE_ACCOUNT_LENGTH + size_of::<AccountType>() + add_type_and_length_to_len(alloc_size);
2611 let mut buffer = vec![0; account_size];
2612 let mut state =
2613 PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2614 state
2615 .init_variable_len_extension(&variable_len, false)
2616 .unwrap();
2617
2618 assert_eq!(
2620 state
2621 .init_variable_len_extension(&variable_len, false)
2622 .unwrap_err(),
2623 TokenError::ExtensionAlreadyInitialized.into()
2624 );
2625
2626 state
2628 .init_variable_len_extension(&variable_len, true)
2629 .unwrap();
2630
2631 assert_eq!(
2633 state
2634 .init_variable_len_extension(&VariableLenMintTest { data: vec![] }, true)
2635 .unwrap_err(),
2636 TokenError::InvalidLengthForAlloc.into()
2637 );
2638
2639 assert_eq!(
2641 state
2642 .init_variable_len_extension(&VariableLenMintTest { data: vec![1, 2] }, true)
2643 .unwrap_err(),
2644 ProgramError::InvalidAccountData
2645 );
2646 }
2647
    /// Reallocates a variable-length extension in place within a fixed-size
    /// buffer that also holds a `MetadataPointer` initialized after it. The test
    /// checks that:
    ///
    /// - the pointer survives growing and shrinking;
    /// - shrinking leaves the freed tail of the buffer zeroed;
    /// - growing past the buffer fails.
    #[test]
    fn realloc() {
        let small_variable_len = VariableLenMintTest {
            data: vec![1, 2, 3],
        };
        let base_variable_len = VariableLenMintTest {
            data: vec![1, 2, 3, 4],
        };
        let big_variable_len = VariableLenMintTest {
            data: vec![1, 2, 3, 4, 5],
        };
        let too_big_variable_len = VariableLenMintTest {
            data: vec![1, 2, 3, 4, 5, 6],
        };
        // Room for a MetadataPointer plus the *big* variable-length value;
        // `too_big_variable_len` therefore cannot fit.
        let account_size =
            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MetadataPointer])
                .unwrap()
                + add_type_and_length_to_len(big_variable_len.get_packed_len().unwrap());
        let mut buffer = vec![0; account_size];
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();

        // Allocate the variable-length extension first, then the pointer.
        state
            .init_variable_len_extension(&base_variable_len, false)
            .unwrap();
        let max_pubkey: MaybeNull<Address> =
            Some(Address::new_from_array([255; 32])).try_into().unwrap();
        let extension = state.init_extension::<MetadataPointer>(false).unwrap();
        extension.authority = max_pubkey;
        extension.metadata_address = max_pubkey;

        // Grow the variable-length entry; the pointer must be unchanged.
        state
            .realloc_variable_len_extension(&big_variable_len)
            .unwrap();
        let extension = state
            .get_variable_len_extension::<VariableLenMintTest>()
            .unwrap();
        assert_eq!(extension, big_variable_len);
        let extension = state.get_extension::<MetadataPointer>().unwrap();
        assert_eq!(extension.authority, max_pubkey);
        assert_eq!(extension.metadata_address, max_pubkey);

        // Shrink it; the pointer must still be intact.
        state
            .realloc_variable_len_extension(&small_variable_len)
            .unwrap();
        let extension = state
            .get_variable_len_extension::<VariableLenMintTest>()
            .unwrap();
        assert_eq!(extension, small_variable_len);
        let extension = state.get_extension::<MetadataPointer>().unwrap();
        assert_eq!(extension.authority, max_pubkey);
        assert_eq!(extension.metadata_address, max_pubkey);
        // The bytes released at the end of the buffer by the shrink are zero.
        let diff = big_variable_len.get_packed_len().unwrap()
            - small_variable_len.get_packed_len().unwrap();
        assert_eq!(&buffer[account_size - diff..account_size], vec![0; diff]);

        // `buffer` was read directly above, ending the previous `state` borrow,
        // so unpack it again.
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        // Growing beyond the buffer's size fails.
        assert_eq!(
            state
                .realloc_variable_len_extension(&too_big_variable_len)
                .unwrap_err(),
            ProgramError::InvalidAccountData,
        );
    }
2718
2719 #[test]
2720 fn account_len() {
2721 let small_variable_len = VariableLenMintTest {
2722 data: vec![20, 30, 40],
2723 };
2724 let variable_len = VariableLenMintTest {
2725 data: vec![20, 30, 40, 50],
2726 };
2727 let big_variable_len = VariableLenMintTest {
2728 data: vec![20, 30, 40, 50, 60],
2729 };
2730 let value_len = variable_len.get_packed_len().unwrap();
2731 let account_size =
2732 BASE_ACCOUNT_LENGTH + size_of::<AccountType>() + add_type_and_length_to_len(value_len);
2733 let mut buffer = vec![0; account_size];
2734 let mut state =
2735 PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2736
2737 let current_len = state.try_get_account_len().unwrap();
2740 assert_eq!(current_len, PodMint::SIZE_OF);
2741 let new_len = state
2742 .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(
2743 &variable_len,
2744 )
2745 .unwrap();
2746 assert_eq!(
2747 new_len,
2748 BASE_ACCOUNT_AND_TYPE_LENGTH.saturating_add(add_type_and_length_to_len(value_len))
2749 );
2750
2751 state
2752 .init_variable_len_extension::<VariableLenMintTest>(&variable_len, false)
2753 .unwrap();
2754 let current_len = state.try_get_account_len().unwrap();
2755 assert_eq!(current_len, new_len);
2756
2757 let new_len = state
2759 .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(
2760 &small_variable_len,
2761 )
2762 .unwrap();
2763 assert_eq!(current_len.checked_sub(new_len).unwrap(), 1);
2764
2765 let new_len = state
2767 .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(
2768 &big_variable_len,
2769 )
2770 .unwrap();
2771 assert_eq!(new_len.checked_sub(current_len).unwrap(), 1);
2772
2773 let new_len = state
2775 .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(
2776 &variable_len,
2777 )
2778 .unwrap();
2779 assert_eq!(new_len, current_len);
2780 }
2781
2782 struct SolanaAccountData {
2785 data: Vec<u8>,
2786 lamports: u64,
2787 owner: Address,
2788 }
2789 impl SolanaAccountData {
2790 fn new(account_data: &[u8]) -> Self {
2793 let mut data = vec![];
2794 data.extend_from_slice(&(account_data.len() as u64).to_le_bytes());
2795 data.extend_from_slice(account_data);
2796 data.extend_from_slice(&[0; MAX_PERMITTED_DATA_INCREASE]);
2797 Self {
2798 data,
2799 lamports: 10,
2800 owner: Address::new_unique(),
2801 }
2802 }
2803
2804 fn data(&self) -> &[u8] {
2807 let start = size_of::<u64>();
2808 let len = self.len();
2809 &self.data[start..start + len]
2810 }
2811
2812 fn len(&self) -> usize {
2814 self.data
2815 .get(..size_of::<u64>())
2816 .and_then(|slice| slice.try_into().ok())
2817 .map(u64::from_le_bytes)
2818 .unwrap() as usize
2819 }
2820 }
2821 impl GetAccount for SolanaAccountData {
2822 fn get(&mut self) -> (&mut u64, &mut [u8], &Address, bool) {
2823 let start = size_of::<u64>();
2825 let len = self.len();
2826 (
2827 &mut self.lamports,
2828 &mut self.data[start..start + len],
2829 &self.owner,
2830 false,
2831 )
2832 }
2833 }
2834
2835 #[test]
2836 fn alloc_new_fixed_len_tlv_in_account_info_from_base_size() {
2837 let fixed_len = FixedLenMintTest {
2838 data: [1, 2, 3, 4, 5, 6, 7, 8],
2839 };
2840 let value_len = size_of::<FixedLenMintTest>();
2841 let base_account_size = PodMint::SIZE_OF;
2842 let mut buffer = vec![0; base_account_size];
2843 let state =
2844 PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2845 *state.base = TEST_POD_MINT;
2846
2847 let mut data = SolanaAccountData::new(&buffer);
2848 let key = Address::new_unique();
2849 let account_info = (&key, &mut data).into_account_info();
2850
2851 alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, false).unwrap();
2852 let new_account_len = BASE_ACCOUNT_AND_TYPE_LENGTH + add_type_and_length_to_len(value_len);
2853 assert_eq!(data.len(), new_account_len);
2854 let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
2855 assert_eq!(
2856 state.get_extension::<FixedLenMintTest>().unwrap(),
2857 &fixed_len,
2858 );
2859
2860 let account_info = (&key, &mut data).into_account_info();
2862 alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, true).unwrap();
2863
2864 let account_info = (&key, &mut data).into_account_info();
2866 assert_eq!(
2867 alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, false).unwrap_err(),
2868 TokenError::ExtensionAlreadyInitialized.into()
2869 );
2870 }
2871
2872 #[test]
2873 fn alloc_new_variable_len_tlv_in_account_info_from_base_size() {
2874 let variable_len = VariableLenMintTest { data: vec![20, 99] };
2875 let value_len = variable_len.get_packed_len().unwrap();
2876 let base_account_size = PodMint::SIZE_OF;
2877 let mut buffer = vec![0; base_account_size];
2878 let state =
2879 PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2880 *state.base = TEST_POD_MINT;
2881
2882 let mut data = SolanaAccountData::new(&buffer);
2883 let key = Address::new_unique();
2884 let account_info = (&key, &mut data).into_account_info();
2885
2886 alloc_and_serialize_variable_len_extension::<PodMint, _>(
2887 &account_info,
2888 &variable_len,
2889 false,
2890 )
2891 .unwrap();
2892 let new_account_len = BASE_ACCOUNT_AND_TYPE_LENGTH + add_type_and_length_to_len(value_len);
2893 assert_eq!(data.len(), new_account_len);
2894 let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
2895 assert_eq!(
2896 state
2897 .get_variable_len_extension::<VariableLenMintTest>()
2898 .unwrap(),
2899 variable_len
2900 );
2901
2902 let account_info = (&key, &mut data).into_account_info();
2904 alloc_and_serialize_variable_len_extension::<PodMint, _>(
2905 &account_info,
2906 &variable_len,
2907 true,
2908 )
2909 .unwrap();
2910
2911 let account_info = (&key, &mut data).into_account_info();
2913 assert_eq!(
2914 alloc_and_serialize_variable_len_extension::<PodMint, _>(
2915 &account_info,
2916 &variable_len,
2917 false,
2918 )
2919 .unwrap_err(),
2920 TokenError::ExtensionAlreadyInitialized.into()
2921 );
2922 }
2923
2924 #[test]
2925 fn alloc_new_fixed_len_tlv_in_account_info_from_extended_size() {
2926 let fixed_len = FixedLenMintTest {
2927 data: [1, 2, 3, 4, 5, 6, 7, 8],
2928 };
2929 let value_len = size_of::<FixedLenMintTest>();
2930 let account_size =
2931 ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::GroupPointer])
2932 .unwrap()
2933 + add_type_and_length_to_len(value_len);
2934 let mut buffer = vec![0; account_size];
2935 let mut state =
2936 PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2937 *state.base = TEST_POD_MINT;
2938 state.init_account_type().unwrap();
2939
2940 let test_key: MaybeNull<Address> =
2941 Some(Address::new_from_array([20; 32])).try_into().unwrap();
2942 let extension = state.init_extension::<GroupPointer>(false).unwrap();
2943 extension.authority = test_key;
2944 extension.group_address = test_key;
2945
2946 let mut data = SolanaAccountData::new(&buffer);
2947 let key = Address::new_unique();
2948 let account_info = (&key, &mut data).into_account_info();
2949
2950 alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, false).unwrap();
2951 let new_account_len = BASE_ACCOUNT_AND_TYPE_LENGTH
2952 + add_type_and_length_to_len(value_len)
2953 + add_type_and_length_to_len(size_of::<GroupPointer>());
2954 assert_eq!(data.len(), new_account_len);
2955 let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
2956 assert_eq!(
2957 state.get_extension::<FixedLenMintTest>().unwrap(),
2958 &fixed_len,
2959 );
2960 let extension = state.get_extension::<GroupPointer>().unwrap();
2961 assert_eq!(extension.authority, test_key);
2962 assert_eq!(extension.group_address, test_key);
2963
2964 let account_info = (&key, &mut data).into_account_info();
2966 alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, true).unwrap();
2967
2968 let account_info = (&key, &mut data).into_account_info();
2970 assert_eq!(
2971 alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, false).unwrap_err(),
2972 TokenError::ExtensionAlreadyInitialized.into()
2973 );
2974 }
2975
2976 #[test]
2977 fn alloc_new_variable_len_tlv_in_account_info_from_extended_size() {
2978 let variable_len = VariableLenMintTest { data: vec![42, 6] };
2979 let value_len = variable_len.get_packed_len().unwrap();
2980 let account_size =
2981 ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MetadataPointer])
2982 .unwrap()
2983 + add_type_and_length_to_len(value_len);
2984 let mut buffer = vec![0; account_size];
2985 let mut state =
2986 PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2987 *state.base = TEST_POD_MINT;
2988 state.init_account_type().unwrap();
2989
2990 let test_key: MaybeNull<Address> =
2991 Some(Address::new_from_array([20; 32])).try_into().unwrap();
2992 let extension = state.init_extension::<MetadataPointer>(false).unwrap();
2993 extension.authority = test_key;
2994 extension.metadata_address = test_key;
2995
2996 let mut data = SolanaAccountData::new(&buffer);
2997 let key = Address::new_unique();
2998 let account_info = (&key, &mut data).into_account_info();
2999
3000 alloc_and_serialize_variable_len_extension::<PodMint, _>(
3001 &account_info,
3002 &variable_len,
3003 false,
3004 )
3005 .unwrap();
3006 let new_account_len = BASE_ACCOUNT_AND_TYPE_LENGTH
3007 + add_type_and_length_to_len(value_len)
3008 + add_type_and_length_to_len(size_of::<MetadataPointer>());
3009 assert_eq!(data.len(), new_account_len);
3010 let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
3011 assert_eq!(
3012 state
3013 .get_variable_len_extension::<VariableLenMintTest>()
3014 .unwrap(),
3015 variable_len
3016 );
3017 let extension = state.get_extension::<MetadataPointer>().unwrap();
3018 assert_eq!(extension.authority, test_key);
3019 assert_eq!(extension.metadata_address, test_key);
3020
3021 let account_info = (&key, &mut data).into_account_info();
3023 alloc_and_serialize_variable_len_extension::<PodMint, _>(
3024 &account_info,
3025 &variable_len,
3026 true,
3027 )
3028 .unwrap();
3029
3030 let account_info = (&key, &mut data).into_account_info();
3032 assert_eq!(
3033 alloc_and_serialize_variable_len_extension::<PodMint, _>(
3034 &account_info,
3035 &variable_len,
3036 false,
3037 )
3038 .unwrap_err(),
3039 TokenError::ExtensionAlreadyInitialized.into()
3040 );
3041 }
3042
3043 #[test]
3044 fn realloc_variable_len_tlv_in_account_info() {
3045 let variable_len = VariableLenMintTest {
3046 data: vec![1, 2, 3, 4, 5],
3047 };
3048 let alloc_size = variable_len.get_packed_len().unwrap();
3049 let account_size =
3050 ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MetadataPointer])
3051 .unwrap()
3052 + add_type_and_length_to_len(alloc_size);
3053 let mut buffer = vec![0; account_size];
3054 let mut state =
3055 PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
3056 *state.base = TEST_POD_MINT;
3057 state.init_account_type().unwrap();
3058
3059 state
3061 .init_variable_len_extension(&variable_len, false)
3062 .unwrap();
3063 let max_pubkey: MaybeNull<Address> =
3064 Some(Address::new_from_array([255; 32])).try_into().unwrap();
3065 let extension = state.init_extension::<MetadataPointer>(false).unwrap();
3066 extension.authority = max_pubkey;
3067 extension.metadata_address = max_pubkey;
3068
3069 let mut data = SolanaAccountData::new(&buffer);
3071 let key = Address::new_unique();
3072 let account_info = (&key, &mut data).into_account_info();
3073 let variable_len = VariableLenMintTest { data: vec![1, 2] };
3074 alloc_and_serialize_variable_len_extension::<PodMint, _>(
3075 &account_info,
3076 &variable_len,
3077 true,
3078 )
3079 .unwrap();
3080
3081 let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
3082 let extension = state.get_extension::<MetadataPointer>().unwrap();
3083 assert_eq!(extension.authority, max_pubkey);
3084 assert_eq!(extension.metadata_address, max_pubkey);
3085 let extension = state
3086 .get_variable_len_extension::<VariableLenMintTest>()
3087 .unwrap();
3088 assert_eq!(extension, variable_len);
3089 assert_eq!(data.len(), state.try_get_account_len().unwrap());
3090
3091 let account_info = (&key, &mut data).into_account_info();
3093 let variable_len = VariableLenMintTest {
3094 data: vec![1, 2, 3, 4, 5, 6, 7],
3095 };
3096 alloc_and_serialize_variable_len_extension::<PodMint, _>(
3097 &account_info,
3098 &variable_len,
3099 true,
3100 )
3101 .unwrap();
3102
3103 let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
3104 let extension = state.get_extension::<MetadataPointer>().unwrap();
3105 assert_eq!(extension.authority, max_pubkey);
3106 assert_eq!(extension.metadata_address, max_pubkey);
3107 let extension = state
3108 .get_variable_len_extension::<VariableLenMintTest>()
3109 .unwrap();
3110 assert_eq!(extension, variable_len);
3111 assert_eq!(data.len(), state.try_get_account_len().unwrap());
3112
3113 let account_info = (&key, &mut data).into_account_info();
3115 let variable_len = VariableLenMintTest {
3116 data: vec![7, 6, 5, 4, 3, 2, 1],
3117 };
3118 alloc_and_serialize_variable_len_extension::<PodMint, _>(
3119 &account_info,
3120 &variable_len,
3121 true,
3122 )
3123 .unwrap();
3124
3125 let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
3126 let extension = state.get_extension::<MetadataPointer>().unwrap();
3127 assert_eq!(extension.authority, max_pubkey);
3128 assert_eq!(extension.metadata_address, max_pubkey);
3129 let extension = state
3130 .get_variable_len_extension::<VariableLenMintTest>()
3131 .unwrap();
3132 assert_eq!(extension, variable_len);
3133 assert_eq!(data.len(), state.try_get_account_len().unwrap());
3134 }
3135}