Skip to main content

spl_token_2022_interface/extension/
mod.rs

1//! Extensions available to token mints and accounts
2
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5use {
6    crate::{
7        error::TokenError,
8        extension::{
9            confidential_mint_burn::ConfidentialMintBurn,
10            confidential_transfer::{ConfidentialTransferAccount, ConfidentialTransferMint},
11            confidential_transfer_fee::{
12                ConfidentialTransferFeeAmount, ConfidentialTransferFeeConfig,
13            },
14            cpi_guard::CpiGuard,
15            default_account_state::DefaultAccountState,
16            group_member_pointer::GroupMemberPointer,
17            group_pointer::GroupPointer,
18            immutable_owner::ImmutableOwner,
19            interest_bearing_mint::InterestBearingConfig,
20            memo_transfer::MemoTransfer,
21            metadata_pointer::MetadataPointer,
22            mint_close_authority::MintCloseAuthority,
23            non_transferable::{NonTransferable, NonTransferableAccount},
24            pausable::{PausableAccount, PausableConfig},
25            permanent_delegate::PermanentDelegate,
26            permissioned_burn::PermissionedBurnConfig,
27            scaled_ui_amount::ScaledUiAmountConfig,
28            transfer_fee::{TransferFeeAmount, TransferFeeConfig},
29            transfer_hook::{TransferHook, TransferHookAccount},
30        },
31        pod::{PodAccount, PodMint},
32        state::{Account, Mint, Multisig, PackedSizeOf},
33    },
34    alloc::{vec, vec::Vec},
35    bytemuck::{Pod, Zeroable},
36    core::{
37        cmp::Ordering,
38        convert::{TryFrom, TryInto},
39        mem::size_of,
40    },
41    num_enum::{IntoPrimitive, TryFromPrimitive},
42    solana_account_info::AccountInfo,
43    solana_program_error::ProgramError,
44    solana_program_pack::{IsInitialized, Pack},
45    solana_zero_copy::unaligned::U16,
46    spl_token_group_interface::state::{TokenGroup, TokenGroupMember},
47    spl_type_length_value::variable_len_pack::VariableLenPack,
48};
49
50/// Confidential Transfer extension
51pub mod confidential_transfer;
52/// Confidential Transfer Fee extension
53pub mod confidential_transfer_fee;
54/// CPI Guard extension
55pub mod cpi_guard;
56/// Default Account State extension
57pub mod default_account_state;
58/// Group Member Pointer extension
59pub mod group_member_pointer;
60/// Group Pointer extension
61pub mod group_pointer;
62/// Immutable Owner extension
63pub mod immutable_owner;
64/// Interest-Bearing Mint extension
65pub mod interest_bearing_mint;
66/// Memo Transfer extension
67pub mod memo_transfer;
68/// Metadata Pointer extension
69pub mod metadata_pointer;
70/// Mint Close Authority extension
71pub mod mint_close_authority;
72/// Non Transferable extension
73pub mod non_transferable;
74/// Pausable extension
75pub mod pausable;
76/// Permanent Delegate extension
77pub mod permanent_delegate;
78/// Permissioned burn extension
79pub mod permissioned_burn;
80/// Scaled UI Amount extension
81pub mod scaled_ui_amount;
82/// Token-group extension
83pub mod token_group;
84/// Token-metadata extension
85pub mod token_metadata;
86/// Transfer Fee extension
87pub mod transfer_fee;
88/// Transfer Hook extension
89pub mod transfer_hook;
90
91/// Confidential mint-burn extension
92pub mod confidential_mint_burn;
93
94/// Length in TLV structure
95#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
96#[repr(transparent)]
97pub struct Length(U16);
98impl From<Length> for usize {
99    fn from(n: Length) -> Self {
100        Self::from(u16::from(n.0))
101    }
102}
103impl TryFrom<usize> for Length {
104    type Error = ProgramError;
105    fn try_from(n: usize) -> Result<Self, Self::Error> {
106        u16::try_from(n)
107            .map(|v| Self(U16::from(v)))
108            .map_err(|_| ProgramError::AccountDataTooSmall)
109    }
110}
111
112/// Helper function to get the current `TlvIndices` from the current spot
113fn get_tlv_indices(type_start: usize) -> TlvIndices {
114    let length_start = type_start.saturating_add(size_of::<ExtensionType>());
115    let value_start = length_start.saturating_add(size_of::<Length>());
116    TlvIndices {
117        type_start,
118        length_start,
119        value_start,
120    }
121}
122
123/// Helper function to tack on the size of an extension bytes if an account with
124/// extensions is exactly the size of a multisig
125const fn adjust_len_for_multisig(account_len: usize) -> usize {
126    if account_len == Multisig::LEN {
127        account_len.saturating_add(size_of::<ExtensionType>())
128    } else {
129        account_len
130    }
131}
132
133/// Helper function to calculate exactly how many bytes a value will take up,
134/// given the value's length
135const fn add_type_and_length_to_len(value_len: usize) -> usize {
136    value_len
137        .saturating_add(size_of::<ExtensionType>())
138        .saturating_add(size_of::<Length>())
139}
140
/// Byte offsets of the three parts of a single TLV entry, as produced by
/// `get_tlv_indices`
#[derive(Debug)]
struct TlvIndices {
    /// Offset at which the extension type starts
    pub type_start: usize,
    /// Offset of the value's length, immediately after the type
    pub length_start: usize,
    /// Offset of the value itself, immediately after the length
    pub value_start: usize,
}
149fn get_extension_indices<V: Extension>(
150    tlv_data: &[u8],
151    init: bool,
152) -> Result<TlvIndices, ProgramError> {
153    let mut start_index = 0;
154    while start_index < tlv_data.len() {
155        let tlv_indices = get_tlv_indices(start_index);
156        if tlv_data.len() < tlv_indices.value_start {
157            return Err(ProgramError::InvalidAccountData);
158        }
159        let extension_type = u16::from_le_bytes(
160            tlv_data[tlv_indices.type_start..tlv_indices.length_start]
161                .try_into()
162                .map_err(|_| ProgramError::InvalidAccountData)?,
163        );
164        if extension_type == u16::from(V::TYPE) {
165            // found an instance of the extension that we're initializing, return!
166            return Ok(tlv_indices);
167        // got to an empty spot, init here, or error if we're searching, since
168        // nothing is written after an Uninitialized spot
169        } else if extension_type == u16::from(ExtensionType::Uninitialized) {
170            if init {
171                return Ok(tlv_indices);
172            } else {
173                return Err(TokenError::ExtensionNotFound.into());
174            }
175        } else {
176            let length = bytemuck::try_from_bytes::<Length>(
177                &tlv_data[tlv_indices.length_start..tlv_indices.value_start],
178            )
179            .map_err(|_| ProgramError::InvalidArgument)?;
180            let value_end_index = tlv_indices.value_start.saturating_add(usize::from(*length));
181            start_index = value_end_index;
182        }
183    }
184    Err(ProgramError::InvalidAccountData)
185}
186
187/// Basic information about the TLV buffer, collected from iterating through all
188/// entries
189#[derive(Debug, PartialEq)]
190struct TlvDataInfo {
191    /// The extension types written in the TLV buffer
192    extension_types: Vec<ExtensionType>,
193    /// The total number bytes allocated for all TLV entries.
194    ///
195    /// Each TLV entry's allocated bytes comprises two bytes for the `type`, two
196    /// bytes for the `length`, and `length` number of bytes for the `value`.
197    used_len: usize,
198}
199
200/// Fetches basic information about the TLV buffer by iterating through all
201/// TLV entries.
202fn get_tlv_data_info(tlv_data: &[u8]) -> Result<TlvDataInfo, ProgramError> {
203    let mut extension_types = vec![];
204    let mut start_index = 0;
205    while start_index < tlv_data.len() {
206        let tlv_indices = get_tlv_indices(start_index);
207        if tlv_data.len() < tlv_indices.length_start {
208            // There aren't enough bytes to store the next type, which means we
209            // got to the end. The last byte could be used during a realloc!
210            return Ok(TlvDataInfo {
211                extension_types,
212                used_len: tlv_indices.type_start,
213            });
214        }
215        let extension_type =
216            ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
217        if extension_type == ExtensionType::Uninitialized {
218            return Ok(TlvDataInfo {
219                extension_types,
220                used_len: tlv_indices.type_start,
221            });
222        } else {
223            if tlv_data.len() < tlv_indices.value_start {
224                // not enough bytes to store the length, malformed
225                return Err(ProgramError::InvalidAccountData);
226            }
227            extension_types.push(extension_type);
228            let length = bytemuck::try_from_bytes::<Length>(
229                &tlv_data[tlv_indices.length_start..tlv_indices.value_start],
230            )
231            .map_err(|_| ProgramError::InvalidArgument)?;
232
233            let value_end_index = tlv_indices.value_start.saturating_add(usize::from(*length));
234            if value_end_index > tlv_data.len() {
235                // value blows past the size of the slice, malformed
236                return Err(ProgramError::InvalidAccountData);
237            }
238            start_index = value_end_index;
239        }
240    }
241    Ok(TlvDataInfo {
242        extension_types,
243        used_len: start_index,
244    })
245}
246
247fn get_first_extension_type(tlv_data: &[u8]) -> Result<Option<ExtensionType>, ProgramError> {
248    if tlv_data.is_empty() {
249        Ok(None)
250    } else {
251        let tlv_indices = get_tlv_indices(0);
252        if tlv_data.len() <= tlv_indices.length_start {
253            return Ok(None);
254        }
255        let extension_type =
256            ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
257        if extension_type == ExtensionType::Uninitialized {
258            Ok(None)
259        } else {
260            Ok(Some(extension_type))
261        }
262    }
263}
264
265fn check_min_len_and_not_multisig(input: &[u8], minimum_len: usize) -> Result<(), ProgramError> {
266    if input.len() == Multisig::LEN || input.len() < minimum_len {
267        Err(ProgramError::InvalidAccountData)
268    } else {
269        Ok(())
270    }
271}
272
273fn check_account_type<S: BaseState>(account_type: AccountType) -> Result<(), ProgramError> {
274    if account_type != S::ACCOUNT_TYPE {
275        Err(ProgramError::InvalidAccountData)
276    } else {
277        Ok(())
278    }
279}
280
281/// Any account with extensions must be at least `Account::LEN`.  Both mints and
282/// accounts can have extensions
283/// A mint with extensions that takes it past 165 could be indiscernible from an
284/// Account with an extension, even if we add the account type. For example,
285/// let's say we have:
286///
287/// ```text
288/// Account: 165 bytes... + [2, 0, 3, 0, 100, ....]
289///                          ^     ^       ^     ^
290///                     acct type  extension length data...
291///
292/// Mint: 82 bytes... + 83 bytes of other extension data
293///     + [2, 0, 3, 0, 100, ....]
294///      (data in extension just happens to look like this)
295/// ```
296///
297/// With this approach, we only start writing the TLV data after `Account::LEN`,
298/// which means we always know that the account type is going to be right after
299/// that. We do a special case checking for a Multisig length, because those
300/// aren't extensible under any circumstances.
301const BASE_ACCOUNT_LENGTH: usize = Account::LEN;
302/// Helper that tacks on the `AccountType` length, which gives the minimum for
303/// any account with extensions
304const BASE_ACCOUNT_AND_TYPE_LENGTH: usize = BASE_ACCOUNT_LENGTH + size_of::<AccountType>();
305
306fn type_and_tlv_indices<S: BaseState>(
307    rest_input: &[u8],
308) -> Result<Option<(usize, usize)>, ProgramError> {
309    if rest_input.is_empty() {
310        Ok(None)
311    } else {
312        let account_type_index = BASE_ACCOUNT_LENGTH.saturating_sub(S::SIZE_OF);
313        // check padding is all zeroes
314        let tlv_start_index = account_type_index.saturating_add(size_of::<AccountType>());
315        if rest_input.len() < tlv_start_index {
316            return Err(ProgramError::InvalidAccountData);
317        }
318        if rest_input[..account_type_index] != vec![0; account_type_index] {
319            Err(ProgramError::InvalidAccountData)
320        } else {
321            Ok(Some((account_type_index, tlv_start_index)))
322        }
323    }
324}
325
326/// Checks a base buffer to verify if it is an Account without having to
327/// completely deserialize it
328fn is_initialized_account(input: &[u8]) -> Result<bool, ProgramError> {
329    const ACCOUNT_INITIALIZED_INDEX: usize = 108; // See state.rs#L99
330
331    if input.len() != BASE_ACCOUNT_LENGTH {
332        return Err(ProgramError::InvalidAccountData);
333    }
334    Ok(input[ACCOUNT_INITIALIZED_INDEX] != 0)
335}
336
337fn get_extension_bytes<S: BaseState, V: Extension>(tlv_data: &[u8]) -> Result<&[u8], ProgramError> {
338    if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
339        return Err(ProgramError::InvalidAccountData);
340    }
341    let TlvIndices {
342        type_start: _,
343        length_start,
344        value_start,
345    } = get_extension_indices::<V>(tlv_data, false)?;
346    // get_extension_indices has checked that tlv_data is long enough to include
347    // these indices
348    let length = bytemuck::try_from_bytes::<Length>(&tlv_data[length_start..value_start])
349        .map_err(|_| ProgramError::InvalidArgument)?;
350    let value_end = value_start.saturating_add(usize::from(*length));
351    if tlv_data.len() < value_end {
352        return Err(ProgramError::InvalidAccountData);
353    }
354    Ok(&tlv_data[value_start..value_end])
355}
356
357fn get_extension_bytes_mut<S: BaseState, V: Extension>(
358    tlv_data: &mut [u8],
359) -> Result<&mut [u8], ProgramError> {
360    if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
361        return Err(ProgramError::InvalidAccountData);
362    }
363    let TlvIndices {
364        type_start: _,
365        length_start,
366        value_start,
367    } = get_extension_indices::<V>(tlv_data, false)?;
368    // get_extension_indices has checked that tlv_data is long enough to include
369    // these indices
370    let length = bytemuck::try_from_bytes::<Length>(&tlv_data[length_start..value_start])
371        .map_err(|_| ProgramError::InvalidArgument)?;
372    let value_end = value_start.saturating_add(usize::from(*length));
373    if tlv_data.len() < value_end {
374        return Err(ProgramError::InvalidAccountData);
375    }
376    Ok(&mut tlv_data[value_start..value_end])
377}
378
379/// Calculate the new expected size if the state allocates the given number
380/// of bytes for the given extension type.
381///
382/// Provides the correct answer regardless if the extension is already present
383/// in the TLV data.
384fn try_get_new_account_len_for_extension_len<S: BaseState, V: Extension>(
385    tlv_data: &[u8],
386    new_extension_len: usize,
387) -> Result<usize, ProgramError> {
388    // get the new length used by the extension
389    let new_extension_tlv_len = add_type_and_length_to_len(new_extension_len);
390    let tlv_info = get_tlv_data_info(tlv_data)?;
391    // If we're adding an extension, then we must have at least BASE_ACCOUNT_LENGTH
392    // and account type
393    let current_len = tlv_info
394        .used_len
395        .saturating_add(BASE_ACCOUNT_AND_TYPE_LENGTH);
396    // get the current length used by the extension
397    let current_extension_len = get_extension_bytes::<S, V>(tlv_data)
398        .map(|x| add_type_and_length_to_len(x.len()))
399        .unwrap_or(0);
400    let new_len = current_len
401        .saturating_sub(current_extension_len)
402        .saturating_add(new_extension_tlv_len);
403    Ok(adjust_len_for_multisig(new_len))
404}
405
406/// Trait for base state with extension
407pub trait BaseStateWithExtensions<S: BaseState> {
408    /// Get the buffer containing all extension data
409    fn get_tlv_data(&self) -> &[u8];
410
411    /// Fetch the bytes for a TLV entry
412    fn get_extension_bytes<V: Extension>(&self) -> Result<&[u8], ProgramError> {
413        get_extension_bytes::<S, V>(self.get_tlv_data())
414    }
415
416    /// Unpack a portion of the TLV data as the desired type
417    fn get_extension<V: Extension + Pod>(&self) -> Result<&V, ProgramError> {
418        bytemuck::try_from_bytes::<V>(self.get_extension_bytes::<V>()?)
419            .map_err(|_| ProgramError::InvalidArgument)
420    }
421
422    /// Unpacks a portion of the TLV data as the desired variable-length type
423    fn get_variable_len_extension<V: Extension + VariableLenPack>(
424        &self,
425    ) -> Result<V, ProgramError> {
426        let data = get_extension_bytes::<S, V>(self.get_tlv_data())?;
427        V::unpack_from_slice(data)
428    }
429
430    /// Iterates through the TLV entries, returning only the types
431    fn get_extension_types(&self) -> Result<Vec<ExtensionType>, ProgramError> {
432        get_tlv_data_info(self.get_tlv_data()).map(|x| x.extension_types)
433    }
434
435    /// Get just the first extension type, useful to track mixed initialization
436    fn get_first_extension_type(&self) -> Result<Option<ExtensionType>, ProgramError> {
437        get_first_extension_type(self.get_tlv_data())
438    }
439
440    /// Get the total number of bytes used by TLV entries and the base type
441    fn try_get_account_len(&self) -> Result<usize, ProgramError> {
442        let tlv_info = get_tlv_data_info(self.get_tlv_data())?;
443        if tlv_info.extension_types.is_empty() {
444            Ok(S::SIZE_OF)
445        } else {
446            let total_len = tlv_info
447                .used_len
448                .saturating_add(BASE_ACCOUNT_AND_TYPE_LENGTH);
449            Ok(adjust_len_for_multisig(total_len))
450        }
451    }
452    /// Calculate the new expected size if the state allocates the given
453    /// fixed-length extension instance.
454    /// If the state already has the extension, the resulting account length
455    /// will be unchanged.
456    fn try_get_new_account_len<V: Extension + Pod>(&self) -> Result<usize, ProgramError> {
457        try_get_new_account_len_for_extension_len::<S, V>(self.get_tlv_data(), size_of::<V>())
458    }
459
460    /// Calculate the new expected size if the state allocates the given
461    /// variable-length extension instance.
462    fn try_get_new_account_len_for_variable_len_extension<V: Extension + VariableLenPack>(
463        &self,
464        new_extension: &V,
465    ) -> Result<usize, ProgramError> {
466        try_get_new_account_len_for_extension_len::<S, V>(
467            self.get_tlv_data(),
468            new_extension.get_packed_len()?,
469        )
470    }
471}
472
473/// Encapsulates owned immutable base state data (mint or account) with possible
474/// extensions
475#[derive(Clone, Debug, PartialEq)]
476pub struct StateWithExtensionsOwned<S: BaseState> {
477    /// Unpacked base data
478    pub base: S,
479    /// Raw TLV data, deserialized on demand
480    tlv_data: Vec<u8>,
481}
482impl<S: BaseState + Pack> StateWithExtensionsOwned<S> {
483    /// Unpack base state, leaving the extension data as a slice
484    ///
485    /// Fails if the base state is not initialized.
486    pub fn unpack(mut input: Vec<u8>) -> Result<Self, ProgramError> {
487        check_min_len_and_not_multisig(&input, S::SIZE_OF)?;
488        let mut rest = input.split_off(S::SIZE_OF);
489        let base = S::unpack(&input)?;
490        if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(&rest)? {
491            // type_and_tlv_indices() checks that returned indexes are within range
492            let account_type = AccountType::try_from(rest[account_type_index])
493                .map_err(|_| ProgramError::InvalidAccountData)?;
494            check_account_type::<S>(account_type)?;
495            let tlv_data = rest.split_off(tlv_start_index);
496            Ok(Self { base, tlv_data })
497        } else {
498            Ok(Self {
499                base,
500                tlv_data: vec![],
501            })
502        }
503    }
504}
505
506impl<S: BaseState> BaseStateWithExtensions<S> for StateWithExtensionsOwned<S> {
507    fn get_tlv_data(&self) -> &[u8] {
508        &self.tlv_data
509    }
510}
511
512/// Encapsulates immutable base state data (mint or account) with possible
513/// extensions
514#[derive(Debug, PartialEq)]
515pub struct StateWithExtensions<'data, S: BaseState + Pack> {
516    /// Unpacked base data
517    pub base: S,
518    /// Slice of data containing all TLV data, deserialized on demand
519    tlv_data: &'data [u8],
520}
521impl<'data, S: BaseState + Pack> StateWithExtensions<'data, S> {
522    /// Unpack base state, leaving the extension data as a slice
523    ///
524    /// Fails if the base state is not initialized.
525    pub fn unpack(input: &'data [u8]) -> Result<Self, ProgramError> {
526        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
527        let (base_data, rest) = input.split_at(S::SIZE_OF);
528        let base = S::unpack(base_data)?;
529        let tlv_data = unpack_tlv_data::<S>(rest)?;
530        Ok(Self { base, tlv_data })
531    }
532}
533impl<S: BaseState + Pack> BaseStateWithExtensions<S> for StateWithExtensions<'_, S> {
534    fn get_tlv_data(&self) -> &[u8] {
535        self.tlv_data
536    }
537}
538
539/// Encapsulates immutable base state data (mint or account) with possible
540/// extensions, where the base state is Pod for zero-copy serde.
541#[derive(Debug, PartialEq)]
542pub struct PodStateWithExtensions<'data, S: BaseState + Pod> {
543    /// Unpacked base data
544    pub base: &'data S,
545    /// Slice of data containing all TLV data, deserialized on demand
546    tlv_data: &'data [u8],
547}
548impl<'data, S: BaseState + Pod> PodStateWithExtensions<'data, S> {
549    /// Unpack base state, leaving the extension data as a slice
550    ///
551    /// Fails if the base state is not initialized.
552    pub fn unpack(input: &'data [u8]) -> Result<Self, ProgramError> {
553        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
554        let (base_data, rest) = input.split_at(S::SIZE_OF);
555        let base =
556            bytemuck::try_from_bytes::<S>(base_data).map_err(|_| ProgramError::InvalidArgument)?;
557        if !base.is_initialized() {
558            Err(ProgramError::UninitializedAccount)
559        } else {
560            let tlv_data = unpack_tlv_data::<S>(rest)?;
561            Ok(Self { base, tlv_data })
562        }
563    }
564}
565impl<S: BaseState + Pod> BaseStateWithExtensions<S> for PodStateWithExtensions<'_, S> {
566    fn get_tlv_data(&self) -> &[u8] {
567        self.tlv_data
568    }
569}
570
571/// Trait for mutable base state with extension
572pub trait BaseStateWithExtensionsMut<S: BaseState>: BaseStateWithExtensions<S> {
573    /// Get the underlying TLV data as mutable
574    fn get_tlv_data_mut(&mut self) -> &mut [u8];
575
576    /// Get the underlying account type as mutable
577    fn get_account_type_mut(&mut self) -> &mut [u8];
578
579    /// Unpack a portion of the TLV data as the base mutable bytes
580    fn get_extension_bytes_mut<V: Extension>(&mut self) -> Result<&mut [u8], ProgramError> {
581        get_extension_bytes_mut::<S, V>(self.get_tlv_data_mut())
582    }
583
584    /// Unpack a portion of the TLV data as the desired type that allows
585    /// modifying the type
586    fn get_extension_mut<V: Extension + Pod>(&mut self) -> Result<&mut V, ProgramError> {
587        bytemuck::try_from_bytes_mut::<V>(self.get_extension_bytes_mut::<V>()?)
588            .map_err(|_| ProgramError::InvalidArgument)
589    }
590
591    /// Packs a variable-length extension into its appropriate data segment.
592    /// Fails if space hasn't already been allocated for the given extension
593    fn pack_variable_len_extension<V: Extension + VariableLenPack>(
594        &mut self,
595        extension: &V,
596    ) -> Result<(), ProgramError> {
597        let data = self.get_extension_bytes_mut::<V>()?;
598        // NOTE: Do *not* use `pack`, since the length check will cause
599        // reallocations to smaller sizes to fail
600        extension.pack_into_slice(data)
601    }
602
603    /// Packs the default extension data into an open slot if not already found
604    /// in the data buffer. If extension is already found in the buffer, it
605    /// overwrites the existing extension with the default state if
606    /// `overwrite` is set. If extension found, but `overwrite` is not set,
607    /// it returns error.
608    fn init_extension<V: Extension + Pod + Default>(
609        &mut self,
610        overwrite: bool,
611    ) -> Result<&mut V, ProgramError> {
612        let length = size_of::<V>();
613        let buffer = self.alloc::<V>(length, overwrite)?;
614        let extension_ref =
615            bytemuck::try_from_bytes_mut::<V>(buffer).map_err(|_| ProgramError::InvalidArgument)?;
616        *extension_ref = V::default();
617        Ok(extension_ref)
618    }
619
620    /// Reallocate and overwrite the TLV entry for the given variable-length
621    /// extension.
622    ///
623    /// Returns an error if the extension is not present, or if there is not
624    /// enough space in the buffer.
625    fn realloc_variable_len_extension<V: Extension + VariableLenPack>(
626        &mut self,
627        new_extension: &V,
628    ) -> Result<(), ProgramError> {
629        let data = self.realloc::<V>(new_extension.get_packed_len()?)?;
630        new_extension.pack_into_slice(data)
631    }
632
633    /// Reallocate the TLV entry for the given extension to the given number of
634    /// bytes.
635    ///
636    /// If the new length is smaller, it will compact the rest of the buffer and
637    /// zero out the difference at the end. If it's larger, it will move the
638    /// rest of the buffer data and zero out the new data.
639    ///
640    /// Returns an error if the extension is not present, or if this is not
641    /// enough space in the buffer.
642    fn realloc<V: Extension + VariableLenPack>(
643        &mut self,
644        length: usize,
645    ) -> Result<&mut [u8], ProgramError> {
646        let tlv_data = self.get_tlv_data_mut();
647        let TlvIndices {
648            type_start: _,
649            length_start,
650            value_start,
651        } = get_extension_indices::<V>(tlv_data, false)?;
652        let tlv_len = get_tlv_data_info(tlv_data).map(|x| x.used_len)?;
653        let data_len = tlv_data.len();
654
655        let length_ref =
656            bytemuck::try_from_bytes_mut::<Length>(&mut tlv_data[length_start..value_start])
657                .map_err(|_| ProgramError::InvalidArgument)?;
658        let old_length = usize::from(*length_ref);
659
660        // Length check to avoid a panic later in `copy_within`
661        if old_length < length {
662            let new_tlv_len = tlv_len.saturating_add(length.saturating_sub(old_length));
663            if new_tlv_len > data_len {
664                return Err(ProgramError::InvalidAccountData);
665            }
666        }
667
668        // write new length after the check, to avoid getting into a bad situation
669        // if trying to recover from an error
670        *length_ref = Length::try_from(length)?;
671
672        let old_value_end = value_start.saturating_add(old_length);
673        let new_value_end = value_start.saturating_add(length);
674        tlv_data.copy_within(old_value_end..tlv_len, new_value_end);
675        match old_length.cmp(&length) {
676            Ordering::Greater => {
677                // realloc to smaller, zero out the end
678                let new_tlv_len = tlv_len.saturating_sub(old_length.saturating_sub(length));
679                tlv_data[new_tlv_len..tlv_len].fill(0);
680            }
681            Ordering::Less => {
682                // realloc to bigger, zero out the new bytes
683                tlv_data[old_value_end..new_value_end].fill(0);
684            }
685            Ordering::Equal => {} // nothing needed!
686        }
687
688        Ok(&mut tlv_data[value_start..new_value_end])
689    }
690
691    /// Allocate the given number of bytes for the given variable-length
692    /// extension and write its contents into the TLV buffer.
693    ///
694    /// This can only be used for variable-sized types, such as `String` or
695    /// `Vec`. `Pod` types must use `init_extension`
696    fn init_variable_len_extension<V: Extension + VariableLenPack>(
697        &mut self,
698        extension: &V,
699        overwrite: bool,
700    ) -> Result<(), ProgramError> {
701        let data = self.alloc::<V>(extension.get_packed_len()?, overwrite)?;
702        extension.pack_into_slice(data)
703    }
704
705    /// Allocate some space for the extension in the TLV data
706    fn alloc<V: Extension>(
707        &mut self,
708        length: usize,
709        overwrite: bool,
710    ) -> Result<&mut [u8], ProgramError> {
711        if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
712            return Err(ProgramError::InvalidAccountData);
713        }
714        let tlv_data = self.get_tlv_data_mut();
715        let TlvIndices {
716            type_start,
717            length_start,
718            value_start,
719        } = get_extension_indices::<V>(tlv_data, true)?;
720
721        if tlv_data[type_start..].len() < add_type_and_length_to_len(length) {
722            return Err(ProgramError::InvalidAccountData);
723        }
724        let extension_type = ExtensionType::try_from(&tlv_data[type_start..length_start])?;
725
726        if extension_type == ExtensionType::Uninitialized || overwrite {
727            // write extension type
728            let extension_type_array: [u8; 2] = V::TYPE.into();
729            let extension_type_ref = &mut tlv_data[type_start..length_start];
730            extension_type_ref.copy_from_slice(&extension_type_array);
731            // write length
732            let length_ref =
733                bytemuck::try_from_bytes_mut::<Length>(&mut tlv_data[length_start..value_start])
734                    .map_err(|_| ProgramError::InvalidArgument)?;
735
736            // check that the length is the same if we're doing an alloc
737            // with overwrite, otherwise a realloc should be done
738            if overwrite && extension_type == V::TYPE && usize::from(*length_ref) != length {
739                return Err(TokenError::InvalidLengthForAlloc.into());
740            }
741
742            *length_ref = Length::try_from(length)?;
743
744            let value_end = value_start.saturating_add(length);
745            Ok(&mut tlv_data[value_start..value_end])
746        } else {
747            // extension is already initialized, but no overwrite permission
748            Err(TokenError::ExtensionAlreadyInitialized.into())
749        }
750    }
751
752    /// If `extension_type` is an Account-associated `ExtensionType` that
753    /// requires initialization on `InitializeAccount`, this method packs
754    /// the default relevant `Extension` of an `ExtensionType` into an open
755    /// slot if not already found in the data buffer, otherwise overwrites
756    /// the existing extension with the default state. For all other
757    /// `ExtensionType`s, this is a no-op.
758    fn init_account_extension_from_type(
759        &mut self,
760        extension_type: ExtensionType,
761    ) -> Result<(), ProgramError> {
762        if extension_type.get_account_type() != AccountType::Account {
763            return Ok(());
764        }
765        match extension_type {
766            ExtensionType::TransferFeeAmount => {
767                self.init_extension::<TransferFeeAmount>(true).map(|_| ())
768            }
769            ExtensionType::ImmutableOwner => {
770                self.init_extension::<ImmutableOwner>(true).map(|_| ())
771            }
772            ExtensionType::NonTransferableAccount => self
773                .init_extension::<NonTransferableAccount>(true)
774                .map(|_| ()),
775            ExtensionType::TransferHookAccount => {
776                self.init_extension::<TransferHookAccount>(true).map(|_| ())
777            }
778            // ConfidentialTransfers are currently opt-in only, so this is a no-op for extra safety
779            // on InitializeAccount
780            ExtensionType::ConfidentialTransferAccount => Ok(()),
781            ExtensionType::PausableAccount => {
782                self.init_extension::<PausableAccount>(true).map(|_| ())
783            }
784            #[cfg(test)]
785            ExtensionType::AccountPaddingTest => {
786                self.init_extension::<AccountPaddingTest>(true).map(|_| ())
787            }
788            _ => unreachable!(),
789        }
790    }
791
792    /// Write the account type into the buffer, done during the base
793    /// state initialization
794    /// Noops if there is no room for an extension in the account, needed for
795    /// pure base mints / accounts.
796    fn init_account_type(&mut self) -> Result<(), ProgramError> {
797        let first_extension_type = self.get_first_extension_type()?;
798        let account_type = self.get_account_type_mut();
799        if !account_type.is_empty() {
800            if let Some(extension_type) = first_extension_type {
801                let account_type = extension_type.get_account_type();
802                if account_type != S::ACCOUNT_TYPE {
803                    return Err(TokenError::ExtensionBaseMismatch.into());
804                }
805            }
806            account_type[0] = S::ACCOUNT_TYPE.into();
807        }
808        Ok(())
809    }
810
811    /// Check that the account type on the account (if initialized) matches the
812    /// account type for any extensions initialized on the TLV data
813    fn check_account_type_matches_extension_type(&self) -> Result<(), ProgramError> {
814        if let Some(extension_type) = self.get_first_extension_type()? {
815            let account_type = extension_type.get_account_type();
816            if account_type != S::ACCOUNT_TYPE {
817                return Err(TokenError::ExtensionBaseMismatch.into());
818            }
819        }
820        Ok(())
821    }
822}
823
824/// Encapsulates mutable base state data (mint or account) with possible
825/// extensions
826#[derive(Debug, PartialEq)]
827pub struct StateWithExtensionsMut<'data, S: BaseState> {
828    /// Unpacked base data
829    pub base: S,
830    /// Raw base data
831    base_data: &'data mut [u8],
832    /// Writable account type
833    account_type: &'data mut [u8],
834    /// Slice of data containing all TLV data, deserialized on demand
835    tlv_data: &'data mut [u8],
836}
837impl<'data, S: BaseState + Pack> StateWithExtensionsMut<'data, S> {
838    /// Unpack base state, leaving the extension data as a mutable slice
839    ///
840    /// Fails if the base state is not initialized.
841    pub fn unpack(input: &'data mut [u8]) -> Result<Self, ProgramError> {
842        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
843        let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
844        let base = S::unpack(base_data)?;
845        let (account_type, tlv_data) = unpack_type_and_tlv_data_mut::<S>(rest)?;
846        Ok(Self {
847            base,
848            base_data,
849            account_type,
850            tlv_data,
851        })
852    }
853
854    /// Unpack an uninitialized base state, leaving the extension data as a
855    /// mutable slice
856    ///
857    /// Fails if the base state has already been initialized.
858    pub fn unpack_uninitialized(input: &'data mut [u8]) -> Result<Self, ProgramError> {
859        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
860        let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
861        let base = S::unpack_unchecked(base_data)?;
862        if base.is_initialized() {
863            return Err(TokenError::AlreadyInUse.into());
864        }
865        let (account_type, tlv_data) = unpack_uninitialized_type_and_tlv_data_mut::<S>(rest)?;
866        let state = Self {
867            base,
868            base_data,
869            account_type,
870            tlv_data,
871        };
872        state.check_account_type_matches_extension_type()?;
873        Ok(state)
874    }
875
876    /// Packs base state data into the base data portion
877    pub fn pack_base(&mut self) {
878        S::pack_into_slice(&self.base, self.base_data);
879    }
880}
881impl<S: BaseState> BaseStateWithExtensions<S> for StateWithExtensionsMut<'_, S> {
882    fn get_tlv_data(&self) -> &[u8] {
883        self.tlv_data
884    }
885}
886impl<S: BaseState> BaseStateWithExtensionsMut<S> for StateWithExtensionsMut<'_, S> {
887    fn get_tlv_data_mut(&mut self) -> &mut [u8] {
888        self.tlv_data
889    }
890    fn get_account_type_mut(&mut self) -> &mut [u8] {
891        self.account_type
892    }
893}
894
895/// Encapsulates mutable base state data (mint or account) with possible
896/// extensions, where the base state is Pod for zero-copy serde.
897#[derive(Debug, PartialEq)]
898pub struct PodStateWithExtensionsMut<'data, S: BaseState> {
899    /// Unpacked base data
900    pub base: &'data mut S,
901    /// Writable account type
902    account_type: &'data mut [u8],
903    /// Slice of data containing all TLV data, deserialized on demand
904    tlv_data: &'data mut [u8],
905}
906impl<'data, S: BaseState + Pod> PodStateWithExtensionsMut<'data, S> {
907    /// Unpack base state, leaving the extension data as a mutable slice
908    ///
909    /// Fails if the base state is not initialized.
910    pub fn unpack(input: &'data mut [u8]) -> Result<Self, ProgramError> {
911        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
912        let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
913        let base = bytemuck::try_from_bytes_mut::<S>(base_data)
914            .map_err(|_| ProgramError::InvalidArgument)?;
915        if !base.is_initialized() {
916            Err(ProgramError::UninitializedAccount)
917        } else {
918            let (account_type, tlv_data) = unpack_type_and_tlv_data_mut::<S>(rest)?;
919            Ok(Self {
920                base,
921                account_type,
922                tlv_data,
923            })
924        }
925    }
926
927    /// Unpack an uninitialized base state, leaving the extension data as a
928    /// mutable slice
929    ///
930    /// Fails if the base state has already been initialized.
931    pub fn unpack_uninitialized(input: &'data mut [u8]) -> Result<Self, ProgramError> {
932        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
933        let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
934        let base = bytemuck::try_from_bytes_mut::<S>(base_data)
935            .map_err(|_| ProgramError::InvalidArgument)?;
936        if base.is_initialized() {
937            return Err(TokenError::AlreadyInUse.into());
938        }
939        let (account_type, tlv_data) = unpack_uninitialized_type_and_tlv_data_mut::<S>(rest)?;
940        let state = Self {
941            base,
942            account_type,
943            tlv_data,
944        };
945        state.check_account_type_matches_extension_type()?;
946        Ok(state)
947    }
948}
949
950impl<S: BaseState> BaseStateWithExtensions<S> for PodStateWithExtensionsMut<'_, S> {
951    fn get_tlv_data(&self) -> &[u8] {
952        self.tlv_data
953    }
954}
955impl<S: BaseState> BaseStateWithExtensionsMut<S> for PodStateWithExtensionsMut<'_, S> {
956    fn get_tlv_data_mut(&mut self) -> &mut [u8] {
957        self.tlv_data
958    }
959    fn get_account_type_mut(&mut self) -> &mut [u8] {
960        self.account_type
961    }
962}
963
964fn unpack_tlv_data<S: BaseState>(rest: &[u8]) -> Result<&[u8], ProgramError> {
965    if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
966        // type_and_tlv_indices() checks that returned indexes are within range
967        let account_type = AccountType::try_from(rest[account_type_index])
968            .map_err(|_| ProgramError::InvalidAccountData)?;
969        check_account_type::<S>(account_type)?;
970        Ok(&rest[tlv_start_index..])
971    } else {
972        Ok(&[])
973    }
974}
975
976fn unpack_type_and_tlv_data_with_check_mut<
977    S: BaseState,
978    F: Fn(AccountType) -> Result<(), ProgramError>,
979>(
980    rest: &mut [u8],
981    check_fn: F,
982) -> Result<(&mut [u8], &mut [u8]), ProgramError> {
983    if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
984        // type_and_tlv_indices() checks that returned indexes are within range
985        let account_type = AccountType::try_from(rest[account_type_index])
986            .map_err(|_| ProgramError::InvalidAccountData)?;
987        check_fn(account_type)?;
988        let (account_type, tlv_data) = rest.split_at_mut(tlv_start_index);
989        Ok((
990            &mut account_type[account_type_index..tlv_start_index],
991            tlv_data,
992        ))
993    } else {
994        Ok((&mut [], &mut []))
995    }
996}
997
998fn unpack_type_and_tlv_data_mut<S: BaseState>(
999    rest: &mut [u8],
1000) -> Result<(&mut [u8], &mut [u8]), ProgramError> {
1001    unpack_type_and_tlv_data_with_check_mut::<S, _>(rest, check_account_type::<S>)
1002}
1003
1004fn unpack_uninitialized_type_and_tlv_data_mut<S: BaseState>(
1005    rest: &mut [u8],
1006) -> Result<(&mut [u8], &mut [u8]), ProgramError> {
1007    unpack_type_and_tlv_data_with_check_mut::<S, _>(rest, |account_type| {
1008        if account_type != AccountType::Uninitialized {
1009            Err(ProgramError::InvalidAccountData)
1010        } else {
1011            Ok(())
1012        }
1013    })
1014}
1015
1016/// If `AccountType` is uninitialized, set it to the `BaseState`'s
1017/// `ACCOUNT_TYPE`; if `AccountType` is already set, check is set correctly for
1018/// `BaseState`. This method assumes that the `base_data` has already been
1019/// packed with data of the desired type.
1020pub fn set_account_type<S: BaseState>(input: &mut [u8]) -> Result<(), ProgramError> {
1021    check_min_len_and_not_multisig(input, S::SIZE_OF)?;
1022    let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
1023    if S::ACCOUNT_TYPE == AccountType::Account && !is_initialized_account(base_data)? {
1024        return Err(ProgramError::InvalidAccountData);
1025    }
1026    if let Some((account_type_index, _tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
1027        let mut account_type = AccountType::try_from(rest[account_type_index])
1028            .map_err(|_| ProgramError::InvalidAccountData)?;
1029        if account_type == AccountType::Uninitialized {
1030            rest[account_type_index] = S::ACCOUNT_TYPE.into();
1031            account_type = S::ACCOUNT_TYPE;
1032        }
1033        check_account_type::<S>(account_type)?;
1034        Ok(())
1035    } else {
1036        Err(ProgramError::InvalidAccountData)
1037    }
1038}
1039
1040/// Different kinds of accounts. Note that `Mint`, `Account`, and `Multisig`
1041/// types are determined exclusively by the size of the account, and are not
1042/// included in the account data. `AccountType` is only included if extensions
1043/// have been initialized.
1044#[repr(u8)]
1045#[derive(Clone, Copy, Debug, Default, PartialEq, TryFromPrimitive, IntoPrimitive)]
1046pub enum AccountType {
1047    /// Marker for 0 data
1048    #[default]
1049    Uninitialized,
1050    /// Mint account with additional extensions
1051    Mint,
1052    /// Token holding account with additional extensions
1053    Account,
1054}
1055
1056/// Extensions that can be applied to mints or accounts.  Mint extensions must
1057/// only be applied to mint accounts, and account extensions must only be
1058/// applied to token holding accounts.
1059#[repr(u16)]
1060#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
1061#[cfg_attr(feature = "serde", serde(rename_all = "camelCase"))]
1062#[derive(Clone, Copy, Debug, PartialEq, TryFromPrimitive, IntoPrimitive)]
1063pub enum ExtensionType {
1064    /// Used as padding if the account size would otherwise be 355, same as a
1065    /// multisig
1066    Uninitialized,
1067    /// Includes transfer fee rate info and accompanying authorities to withdraw
1068    /// and set the fee
1069    TransferFeeConfig,
1070    /// Includes withheld transfer fees
1071    TransferFeeAmount,
1072    /// Includes an optional mint close authority
1073    MintCloseAuthority,
1074    /// Auditor configuration for confidential transfers
1075    ConfidentialTransferMint,
1076    /// State for confidential transfers
1077    ConfidentialTransferAccount,
1078    /// Specifies the default Account::state for new Accounts
1079    DefaultAccountState,
1080    /// Indicates that the Account owner authority cannot be changed
1081    ImmutableOwner,
1082    /// Require inbound transfers to have memo
1083    MemoTransfer,
1084    /// Indicates that the tokens from this mint can't be transferred
1085    NonTransferable,
1086    /// Tokens accrue interest over time,
1087    InterestBearingConfig,
1088    /// Locks privileged token operations from happening via CPI
1089    CpiGuard,
1090    /// Includes an optional permanent delegate
1091    PermanentDelegate,
1092    /// Indicates that the tokens in this account belong to a non-transferable
1093    /// mint
1094    NonTransferableAccount,
1095    /// Mint requires a CPI to a program implementing the "transfer hook"
1096    /// interface
1097    TransferHook,
1098    /// Indicates that the tokens in this account belong to a mint with a
1099    /// transfer hook
1100    TransferHookAccount,
1101    /// Includes encrypted withheld fees and the encryption public that they are
1102    /// encrypted under
1103    ConfidentialTransferFeeConfig,
1104    /// Includes confidential withheld transfer fees
1105    ConfidentialTransferFeeAmount,
1106    /// Mint contains a pointer to another account (or the same account) that
1107    /// holds metadata
1108    MetadataPointer,
1109    /// Mint contains token-metadata
1110    TokenMetadata,
1111    /// Mint contains a pointer to another account (or the same account) that
1112    /// holds group configurations
1113    GroupPointer,
1114    /// Mint contains token group configurations
1115    TokenGroup,
1116    /// Mint contains a pointer to another account (or the same account) that
1117    /// holds group member configurations
1118    GroupMemberPointer,
1119    /// Mint contains token group member configurations
1120    TokenGroupMember,
1121    /// Mint allowing the minting and burning of confidential tokens
1122    ConfidentialMintBurn,
1123    /// Tokens whose UI amount is scaled by a given amount
1124    ScaledUiAmount,
1125    /// Tokens where minting / burning / transferring can be paused
1126    Pausable,
1127    /// Indicates that the account belongs to a pausable mint
1128    PausableAccount,
1129    /// Tokens burning requires approval from authority.
1130    PermissionedBurn,
1131
1132    /// Test variable-length mint extension
1133    #[cfg(test)]
1134    VariableLenMintTest = u16::MAX - 2,
1135    /// Padding extension used to make an account exactly Multisig::LEN, used
1136    /// for testing
1137    #[cfg(test)]
1138    AccountPaddingTest,
1139    /// Padding extension used to make a mint exactly Multisig::LEN, used for
1140    /// testing
1141    #[cfg(test)]
1142    MintPaddingTest,
1143}
1144impl TryFrom<&[u8]> for ExtensionType {
1145    type Error = ProgramError;
1146    fn try_from(a: &[u8]) -> Result<Self, Self::Error> {
1147        Self::try_from(u16::from_le_bytes(
1148            a.try_into().map_err(|_| ProgramError::InvalidAccountData)?,
1149        ))
1150        .map_err(|_| ProgramError::InvalidAccountData)
1151    }
1152}
1153impl From<ExtensionType> for [u8; 2] {
1154    fn from(a: ExtensionType) -> Self {
1155        u16::from(a).to_le_bytes()
1156    }
1157}
1158impl ExtensionType {
1159    /// Returns true if the given extension type is sized
1160    ///
1161    /// Most extension types should be sized, so any variable-length extension
1162    /// types should be added here by hand
1163    const fn sized(&self) -> bool {
1164        match self {
1165            ExtensionType::TokenMetadata => false,
1166            #[cfg(test)]
1167            ExtensionType::VariableLenMintTest => false,
1168            _ => true,
1169        }
1170    }
1171
1172    /// Get the data length of the type associated with the enum
1173    ///
1174    /// Fails if the extension type has a variable length
1175    fn try_get_type_len(&self) -> Result<usize, ProgramError> {
1176        if !self.sized() {
1177            return Err(ProgramError::InvalidArgument);
1178        }
1179        Ok(match self {
1180            ExtensionType::Uninitialized => 0,
1181            ExtensionType::TransferFeeConfig => size_of::<TransferFeeConfig>(),
1182            ExtensionType::TransferFeeAmount => size_of::<TransferFeeAmount>(),
1183            ExtensionType::MintCloseAuthority => size_of::<MintCloseAuthority>(),
1184            ExtensionType::ImmutableOwner => size_of::<ImmutableOwner>(),
1185            ExtensionType::ConfidentialTransferMint => size_of::<ConfidentialTransferMint>(),
1186            ExtensionType::ConfidentialTransferAccount => size_of::<ConfidentialTransferAccount>(),
1187            ExtensionType::DefaultAccountState => size_of::<DefaultAccountState>(),
1188            ExtensionType::MemoTransfer => size_of::<MemoTransfer>(),
1189            ExtensionType::NonTransferable => size_of::<NonTransferable>(),
1190            ExtensionType::InterestBearingConfig => size_of::<InterestBearingConfig>(),
1191            ExtensionType::CpiGuard => size_of::<CpiGuard>(),
1192            ExtensionType::PermanentDelegate => size_of::<PermanentDelegate>(),
1193            ExtensionType::NonTransferableAccount => size_of::<NonTransferableAccount>(),
1194            ExtensionType::TransferHook => size_of::<TransferHook>(),
1195            ExtensionType::TransferHookAccount => size_of::<TransferHookAccount>(),
1196            ExtensionType::ConfidentialTransferFeeConfig => {
1197                size_of::<ConfidentialTransferFeeConfig>()
1198            }
1199            ExtensionType::ConfidentialTransferFeeAmount => {
1200                size_of::<ConfidentialTransferFeeAmount>()
1201            }
1202            ExtensionType::MetadataPointer => size_of::<MetadataPointer>(),
1203            ExtensionType::TokenMetadata => unreachable!(),
1204            ExtensionType::GroupPointer => size_of::<GroupPointer>(),
1205            ExtensionType::TokenGroup => size_of::<TokenGroup>(),
1206            ExtensionType::GroupMemberPointer => size_of::<GroupMemberPointer>(),
1207            ExtensionType::TokenGroupMember => size_of::<TokenGroupMember>(),
1208            ExtensionType::ConfidentialMintBurn => size_of::<ConfidentialMintBurn>(),
1209            ExtensionType::ScaledUiAmount => size_of::<ScaledUiAmountConfig>(),
1210            ExtensionType::Pausable => size_of::<PausableConfig>(),
1211            ExtensionType::PausableAccount => size_of::<PausableAccount>(),
1212            ExtensionType::PermissionedBurn => size_of::<PermissionedBurnConfig>(),
1213            #[cfg(test)]
1214            ExtensionType::AccountPaddingTest => size_of::<AccountPaddingTest>(),
1215            #[cfg(test)]
1216            ExtensionType::MintPaddingTest => size_of::<MintPaddingTest>(),
1217            #[cfg(test)]
1218            ExtensionType::VariableLenMintTest => unreachable!(),
1219        })
1220    }
1221
1222    /// Get the TLV length for an `ExtensionType`
1223    ///
1224    /// Fails if the extension type has a variable length
1225    fn try_get_tlv_len(&self) -> Result<usize, ProgramError> {
1226        Ok(add_type_and_length_to_len(self.try_get_type_len()?))
1227    }
1228
1229    /// Get the TLV length for a set of `ExtensionType`s
1230    ///
1231    /// Fails if any of the extension types has a variable length
1232    fn try_get_total_tlv_len(extension_types: &[Self]) -> Result<usize, ProgramError> {
1233        // dedupe extensions
1234        let mut extensions = vec![];
1235        for extension_type in extension_types {
1236            if !extensions.contains(&extension_type) {
1237                extensions.push(extension_type);
1238            }
1239        }
1240        extensions.iter().map(|e| e.try_get_tlv_len()).sum()
1241    }
1242
1243    /// Get the required account data length for the given `ExtensionType`s
1244    ///
1245    /// Fails if any of the extension types has a variable length
1246    pub fn try_calculate_account_len<S: BaseState>(
1247        extension_types: &[Self],
1248    ) -> Result<usize, ProgramError> {
1249        if extension_types.is_empty() {
1250            Ok(S::SIZE_OF)
1251        } else {
1252            let extension_size = Self::try_get_total_tlv_len(extension_types)?;
1253            let total_len = extension_size.saturating_add(BASE_ACCOUNT_AND_TYPE_LENGTH);
1254            Ok(adjust_len_for_multisig(total_len))
1255        }
1256    }
1257
1258    /// Get the associated account type
1259    pub fn get_account_type(&self) -> AccountType {
1260        match self {
1261            ExtensionType::Uninitialized => AccountType::Uninitialized,
1262            ExtensionType::TransferFeeConfig
1263            | ExtensionType::MintCloseAuthority
1264            | ExtensionType::ConfidentialTransferMint
1265            | ExtensionType::DefaultAccountState
1266            | ExtensionType::NonTransferable
1267            | ExtensionType::InterestBearingConfig
1268            | ExtensionType::PermanentDelegate
1269            | ExtensionType::TransferHook
1270            | ExtensionType::ConfidentialTransferFeeConfig
1271            | ExtensionType::MetadataPointer
1272            | ExtensionType::TokenMetadata
1273            | ExtensionType::GroupPointer
1274            | ExtensionType::TokenGroup
1275            | ExtensionType::GroupMemberPointer
1276            | ExtensionType::ConfidentialMintBurn
1277            | ExtensionType::TokenGroupMember
1278            | ExtensionType::ScaledUiAmount
1279            | ExtensionType::Pausable
1280            | ExtensionType::PermissionedBurn => AccountType::Mint,
1281            ExtensionType::ImmutableOwner
1282            | ExtensionType::TransferFeeAmount
1283            | ExtensionType::ConfidentialTransferAccount
1284            | ExtensionType::MemoTransfer
1285            | ExtensionType::NonTransferableAccount
1286            | ExtensionType::TransferHookAccount
1287            | ExtensionType::CpiGuard
1288            | ExtensionType::ConfidentialTransferFeeAmount
1289            | ExtensionType::PausableAccount => AccountType::Account,
1290            #[cfg(test)]
1291            ExtensionType::VariableLenMintTest => AccountType::Mint,
1292            #[cfg(test)]
1293            ExtensionType::AccountPaddingTest => AccountType::Account,
1294            #[cfg(test)]
1295            ExtensionType::MintPaddingTest => AccountType::Mint,
1296        }
1297    }
1298
1299    /// Based on a set of `AccountType::Mint` `ExtensionType`s, get the list of
1300    /// `AccountType::Account` `ExtensionType`s required on `InitializeAccount`
1301    pub fn get_required_init_account_extensions(mint_extension_types: &[Self]) -> Vec<Self> {
1302        let mut account_extension_types = vec![];
1303        for extension_type in mint_extension_types {
1304            match extension_type {
1305                ExtensionType::TransferFeeConfig => {
1306                    account_extension_types.push(ExtensionType::TransferFeeAmount);
1307                }
1308                ExtensionType::NonTransferable => {
1309                    account_extension_types.push(ExtensionType::NonTransferableAccount);
1310                    account_extension_types.push(ExtensionType::ImmutableOwner);
1311                }
1312                ExtensionType::TransferHook => {
1313                    account_extension_types.push(ExtensionType::TransferHookAccount);
1314                }
1315                ExtensionType::Pausable => {
1316                    account_extension_types.push(ExtensionType::PausableAccount);
1317                }
1318                #[cfg(test)]
1319                ExtensionType::MintPaddingTest => {
1320                    account_extension_types.push(ExtensionType::AccountPaddingTest);
1321                }
1322                _ => {}
1323            }
1324        }
1325        account_extension_types
1326    }
1327
1328    /// Check for invalid combination of mint extensions
1329    pub fn check_for_invalid_mint_extension_combinations(
1330        mint_extension_types: &[Self],
1331    ) -> Result<(), TokenError> {
1332        let mut transfer_fee_config = false;
1333        let mut confidential_transfer_mint = false;
1334        let mut confidential_transfer_fee_config = false;
1335        let mut confidential_mint_burn = false;
1336        let mut interest_bearing = false;
1337        let mut scaled_ui_amount = false;
1338        let mut non_transferable = false;
1339
1340        for extension_type in mint_extension_types {
1341            match extension_type {
1342                ExtensionType::TransferFeeConfig => transfer_fee_config = true,
1343                ExtensionType::ConfidentialTransferMint => confidential_transfer_mint = true,
1344                ExtensionType::ConfidentialTransferFeeConfig => {
1345                    confidential_transfer_fee_config = true
1346                }
1347                ExtensionType::ConfidentialMintBurn => confidential_mint_burn = true,
1348                ExtensionType::InterestBearingConfig => interest_bearing = true,
1349                ExtensionType::ScaledUiAmount => scaled_ui_amount = true,
1350                ExtensionType::NonTransferable => non_transferable = true,
1351                _ => (),
1352            }
1353        }
1354
1355        if confidential_transfer_fee_config && !(transfer_fee_config && confidential_transfer_mint)
1356        {
1357            return Err(TokenError::InvalidExtensionCombination);
1358        }
1359
1360        if transfer_fee_config && confidential_transfer_mint && !confidential_transfer_fee_config {
1361            return Err(TokenError::InvalidExtensionCombination);
1362        }
1363
1364        if confidential_mint_burn && !confidential_transfer_mint {
1365            return Err(TokenError::InvalidExtensionCombination);
1366        }
1367
1368        if scaled_ui_amount && interest_bearing {
1369            return Err(TokenError::InvalidExtensionCombination);
1370        }
1371
1372        if non_transferable && confidential_transfer_mint && !confidential_mint_burn {
1373            return Err(TokenError::InvalidExtensionCombination);
1374        }
1375
1376        Ok(())
1377    }
1378}
1379
1380/// Trait for base states, specifying the associated enum
1381pub trait BaseState: PackedSizeOf + IsInitialized {
1382    /// Associated extension type enum, checked at the start of TLV entries
1383    const ACCOUNT_TYPE: AccountType;
1384}
1385impl BaseState for Account {
1386    const ACCOUNT_TYPE: AccountType = AccountType::Account;
1387}
1388impl BaseState for Mint {
1389    const ACCOUNT_TYPE: AccountType = AccountType::Mint;
1390}
1391impl BaseState for PodAccount {
1392    const ACCOUNT_TYPE: AccountType = AccountType::Account;
1393}
1394impl BaseState for PodMint {
1395    const ACCOUNT_TYPE: AccountType = AccountType::Mint;
1396}
1397
1398/// Trait to be implemented by all extension states, specifying which extension
1399/// and account type they are associated with
1400pub trait Extension {
1401    /// Associated extension type enum, checked at the start of TLV entries
1402    const TYPE: ExtensionType;
1403}
1404
/// Test-only mint extension that pads a mint to exactly `Multisig::LEN`.
///
/// The padding is 185 bytes: `Multisig::LEN` (355) minus `Account::LEN`
/// (165), the `AccountType` byte (1), the `ExtensionType` field (2), and the
/// `Length` field (2).
///
/// ```
/// assert_eq!(355 - 165 - 1 - 2 - 2, 185);
/// ```
#[cfg(test)]
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Pod, Zeroable)]
pub struct MintPaddingTest {
    /// Largest array length below 185 that implements Pod
    pub padding1: [u8; 128],
    /// Largest array length below the remaining 57 that implements Pod
    pub padding2: [u8; 48],
    /// Remaining bytes needed to reach 185
    pub padding3: [u8; 9],
}
#[cfg(test)]
impl Extension for MintPaddingTest {
    const TYPE: ExtensionType = ExtensionType::MintPaddingTest;
}
#[cfg(test)]
impl Default for MintPaddingTest {
    /// Fills `padding1` with 1s, `padding2` with 2s, and `padding3` with 3s.
    fn default() -> Self {
        let padding1 = [1; 128];
        let padding2 = [2; 48];
        let padding3 = [3; 9];
        Self {
            padding1,
            padding2,
            padding3,
        }
    }
}
/// Test-only account counterpart of `MintPaddingTest`, sharing its layout.
#[cfg(test)]
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
pub struct AccountPaddingTest(MintPaddingTest);
#[cfg(test)]
impl Extension for AccountPaddingTest {
    const TYPE: ExtensionType = ExtensionType::AccountPaddingTest;
}
1447
1448/// Packs a fixed-length extension into a TLV space
1449///
1450/// This function reallocates the account as needed to accommodate for the
1451/// change in space.
1452///
1453/// If the extension already exists, it will overwrite the existing extension
1454/// if `overwrite` is `true`, otherwise it will return an error.
1455///
1456/// If the extension does not exist, it will reallocate the account and write
1457/// the extension into the TLV buffer.
1458///
1459/// NOTE: Since this function deals with fixed-size extensions, it does not
1460/// handle _decreasing_ the size of an account's data buffer, like the function
1461/// `alloc_and_serialize_variable_len_extension` does.
1462pub fn alloc_and_serialize<S: BaseState + Pod, V: Default + Extension + Pod>(
1463    account_info: &AccountInfo,
1464    new_extension: &V,
1465    overwrite: bool,
1466) -> Result<(), ProgramError> {
1467    let previous_account_len = account_info.try_data_len()?;
1468    let new_account_len = {
1469        let data = account_info.try_borrow_data()?;
1470        let state = PodStateWithExtensions::<S>::unpack(&data)?;
1471        state.try_get_new_account_len::<V>()?
1472    };
1473
1474    // Realloc the account first, if needed
1475    if new_account_len > previous_account_len {
1476        account_info.resize(new_account_len)?;
1477    }
1478    let mut buffer = account_info.try_borrow_mut_data()?;
1479    if previous_account_len <= BASE_ACCOUNT_LENGTH {
1480        set_account_type::<S>(*buffer)?;
1481    }
1482    let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
1483
1484    // Write the extension
1485    let extension = state.init_extension::<V>(overwrite)?;
1486    *extension = *new_extension;
1487
1488    Ok(())
1489}
1490
1491/// Packs a variable-length extension into a TLV space
1492///
1493/// This function reallocates the account as needed to accommodate for the
1494/// change in space, then reallocates in the TLV buffer, and finally writes the
1495/// bytes.
1496///
1497/// NOTE: Unlike the `reallocate` instruction, this function will reduce the
1498/// size of an account if it has too many bytes allocated for the given value.
1499pub fn alloc_and_serialize_variable_len_extension<
1500    S: BaseState + Pod,
1501    V: Extension + VariableLenPack,
1502>(
1503    account_info: &AccountInfo,
1504    new_extension: &V,
1505    overwrite: bool,
1506) -> Result<(), ProgramError> {
1507    let previous_account_len = account_info.try_data_len()?;
1508    let (new_account_len, extension_already_exists) = {
1509        let data = account_info.try_borrow_data()?;
1510        let state = PodStateWithExtensions::<S>::unpack(&data)?;
1511        let new_account_len =
1512            state.try_get_new_account_len_for_variable_len_extension(new_extension)?;
1513        let extension_already_exists = state.get_extension_bytes::<V>().is_ok();
1514        (new_account_len, extension_already_exists)
1515    };
1516
1517    if extension_already_exists && !overwrite {
1518        return Err(TokenError::ExtensionAlreadyInitialized.into());
1519    }
1520
1521    if previous_account_len < new_account_len {
1522        // account size increased, so realloc the account, then the TLV entry, then
1523        // write data
1524        account_info.resize(new_account_len)?;
1525        let mut buffer = account_info.try_borrow_mut_data()?;
1526        if extension_already_exists {
1527            let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
1528            state.realloc_variable_len_extension(new_extension)?;
1529        } else {
1530            if previous_account_len <= BASE_ACCOUNT_LENGTH {
1531                set_account_type::<S>(*buffer)?;
1532            }
1533            // now alloc in the TLV buffer and write the data
1534            let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
1535            state.init_variable_len_extension(new_extension, false)?;
1536        }
1537    } else {
1538        // do it backwards otherwise, write the state, realloc TLV, then the account
1539        let mut buffer = account_info.try_borrow_mut_data()?;
1540        let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
1541        if extension_already_exists {
1542            state.realloc_variable_len_extension(new_extension)?;
1543        } else {
1544            // this situation can happen if we have an overallocated buffer
1545            state.init_variable_len_extension(new_extension, false)?;
1546        }
1547
1548        let removed_bytes = previous_account_len
1549            .checked_sub(new_account_len)
1550            .ok_or(ProgramError::AccountDataTooSmall)?;
1551        if removed_bytes > 0 {
1552            // this is probably fine, but be safe and avoid invalidating references
1553            drop(buffer);
1554            account_info.resize(new_account_len)?;
1555        }
1556    }
1557    Ok(())
1558}
1559
1560#[cfg(test)]
1561mod test {
1562    use {
1563        super::*,
1564        crate::{
1565            pod::test::{TEST_POD_ACCOUNT, TEST_POD_MINT},
1566            state::test::{TEST_ACCOUNT_SLICE, TEST_MINT_SLICE},
1567        },
1568        bytemuck::Pod,
1569        solana_account_info::{
1570            Account as GetAccount, IntoAccountInfo, MAX_PERMITTED_DATA_INCREASE,
1571        },
1572        solana_address::Address,
1573        solana_nullable::MaybeNull,
1574        solana_zero_copy::unaligned::{Bool, U64},
1575        transfer_fee::test::test_transfer_fee_config,
1576    };
1577
1578    /// Test fixed-length struct
1579    #[repr(C)]
1580    #[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
1581    struct FixedLenMintTest {
1582        data: [u8; 8],
1583    }
1584    impl Extension for FixedLenMintTest {
1585        const TYPE: ExtensionType = ExtensionType::MintPaddingTest;
1586    }
1587
1588    /// Test variable-length struct
1589    #[derive(Clone, Debug, PartialEq)]
1590    struct VariableLenMintTest {
1591        data: Vec<u8>,
1592    }
1593    impl Extension for VariableLenMintTest {
1594        const TYPE: ExtensionType = ExtensionType::VariableLenMintTest;
1595    }
1596    impl VariableLenPack for VariableLenMintTest {
1597        fn pack_into_slice(&self, dst: &mut [u8]) -> Result<(), ProgramError> {
1598            let data_start = size_of::<u64>();
1599            let end = data_start + self.data.len();
1600            if dst.len() < end {
1601                Err(ProgramError::InvalidAccountData)
1602            } else {
1603                dst[..data_start].copy_from_slice(&self.data.len().to_le_bytes());
1604                dst[data_start..end].copy_from_slice(&self.data);
1605                Ok(())
1606            }
1607        }
1608        fn unpack_from_slice(src: &[u8]) -> Result<Self, ProgramError> {
1609            let data_start = size_of::<u64>();
1610            let length = u64::from_le_bytes(src[..data_start].try_into().unwrap()) as usize;
1611            if src[data_start..data_start + length].len() != length {
1612                return Err(ProgramError::InvalidAccountData);
1613            }
1614            let data = Vec::from(&src[data_start..data_start + length]);
1615            Ok(Self { data })
1616        }
1617        fn get_packed_len(&self) -> Result<usize, ProgramError> {
1618            Ok(size_of::<u64>().saturating_add(self.data.len()))
1619        }
1620    }
1621
1622    const MINT_WITH_ACCOUNT_TYPE: &[u8] = &[
1623        1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1624        1, 1, 1, 1, 1, 1, 42, 0, 0, 0, 0, 0, 0, 0, 7, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
1625        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // base mint
1626        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1627        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1628        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // padding
1629        1, // account type
1630    ];
1631
1632    const MINT_WITH_EXTENSION: &[u8] = &[
1633        1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1634        1, 1, 1, 1, 1, 1, 42, 0, 0, 0, 0, 0, 0, 0, 7, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
1635        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // base mint
1636        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1637        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1638        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // padding
1639        1, // account type
1640        3, 0, // extension type
1641        32, 0, // length
1642        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1643        1, 1, // data
1644    ];
1645
1646    const ACCOUNT_WITH_EXTENSION: &[u8] = &[
1647        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1648        1, 1, // mint
1649        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
1650        2, 2, // owner
1651        3, 0, 0, 0, 0, 0, 0, 0, // amount
1652        1, 0, 0, 0, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
1653        4, 4, 4, 4, 4, 4, // delegate
1654        2, // account state
1655        1, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, // is native
1656        6, 0, 0, 0, 0, 0, 0, 0, // delegated amount
1657        1, 0, 0, 0, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
1658        7, 7, 7, 7, 7, 7, // close authority
1659        2, // account type
1660        15, 0, // extension type
1661        1, 0, // length
1662        1, // data
1663    ];
1664
1665    #[test]
1666    fn unpack_opaque_buffer() {
1667        // Mint
1668        let state = PodStateWithExtensions::<PodMint>::unpack(MINT_WITH_ACCOUNT_TYPE).unwrap();
1669        assert_eq!(state.base, &TEST_POD_MINT);
1670        let state = PodStateWithExtensions::<PodMint>::unpack(MINT_WITH_EXTENSION).unwrap();
1671        assert_eq!(state.base, &TEST_POD_MINT);
1672        let extension = state.get_extension::<MintCloseAuthority>().unwrap();
1673        let close_authority: MaybeNull<Address> =
1674            Some(Address::new_from_array([1; 32])).try_into().unwrap();
1675        assert_eq!(extension.close_authority, close_authority);
1676        assert_eq!(
1677            state.get_extension::<TransferFeeConfig>(),
1678            Err(ProgramError::InvalidAccountData)
1679        );
1680        assert_eq!(
1681            PodStateWithExtensions::<PodAccount>::unpack(MINT_WITH_EXTENSION),
1682            Err(ProgramError::UninitializedAccount)
1683        );
1684
1685        let state = PodStateWithExtensions::<PodMint>::unpack(TEST_MINT_SLICE).unwrap();
1686        assert_eq!(state.base, &TEST_POD_MINT);
1687
1688        let mut test_mint = TEST_MINT_SLICE.to_vec();
1689        let state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut test_mint).unwrap();
1690        assert_eq!(state.base, &TEST_POD_MINT);
1691
1692        // Account
1693        let state = PodStateWithExtensions::<PodAccount>::unpack(ACCOUNT_WITH_EXTENSION).unwrap();
1694        assert_eq!(state.base, &TEST_POD_ACCOUNT);
1695        let extension = state.get_extension::<TransferHookAccount>().unwrap();
1696        let transferring = Bool::from(true);
1697        assert_eq!(extension.transferring, transferring);
1698        assert_eq!(
1699            PodStateWithExtensions::<PodMint>::unpack(ACCOUNT_WITH_EXTENSION),
1700            Err(ProgramError::InvalidAccountData)
1701        );
1702
1703        let state = PodStateWithExtensions::<PodAccount>::unpack(TEST_ACCOUNT_SLICE).unwrap();
1704        assert_eq!(state.base, &TEST_POD_ACCOUNT);
1705
1706        let mut test_account = TEST_ACCOUNT_SLICE.to_vec();
1707        let state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut test_account).unwrap();
1708        assert_eq!(state.base, &TEST_POD_ACCOUNT);
1709    }
1710
1711    #[test]
1712    fn mint_fail_unpack_opaque_buffer() {
1713        // input buffer too small
1714        let mut buffer = vec![0, 3];
1715        assert_eq!(
1716            PodStateWithExtensions::<PodMint>::unpack(&buffer),
1717            Err(ProgramError::InvalidAccountData)
1718        );
1719        assert_eq!(
1720            PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer),
1721            Err(ProgramError::InvalidAccountData)
1722        );
1723        assert_eq!(
1724            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer),
1725            Err(ProgramError::InvalidAccountData)
1726        );
1727
1728        // tweak the account type
1729        let mut buffer = MINT_WITH_EXTENSION.to_vec();
1730        buffer[BASE_ACCOUNT_LENGTH] = 3;
1731        assert_eq!(
1732            PodStateWithExtensions::<PodMint>::unpack(&buffer),
1733            Err(ProgramError::InvalidAccountData)
1734        );
1735
1736        // clear the mint initialized byte
1737        let mut buffer = MINT_WITH_EXTENSION.to_vec();
1738        buffer[45] = 0;
1739        assert_eq!(
1740            PodStateWithExtensions::<PodMint>::unpack(&buffer),
1741            Err(ProgramError::UninitializedAccount)
1742        );
1743
1744        // tweak the padding
1745        let mut buffer = MINT_WITH_EXTENSION.to_vec();
1746        buffer[PodMint::SIZE_OF] = 100;
1747        assert_eq!(
1748            PodStateWithExtensions::<PodMint>::unpack(&buffer),
1749            Err(ProgramError::InvalidAccountData)
1750        );
1751
1752        // tweak the extension type
1753        let mut buffer = MINT_WITH_EXTENSION.to_vec();
1754        buffer[BASE_ACCOUNT_LENGTH + 1] = 2;
1755        let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
1756        assert_eq!(
1757            state.get_extension::<TransferFeeConfig>(),
1758            Err(ProgramError::InvalidAccountData)
1759        );
1760
1761        // tweak the length, too big
1762        let mut buffer = MINT_WITH_EXTENSION.to_vec();
1763        buffer[BASE_ACCOUNT_LENGTH + 3] = 100;
1764        let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
1765        assert_eq!(
1766            state.get_extension::<TransferFeeConfig>(),
1767            Err(ProgramError::InvalidAccountData)
1768        );
1769
1770        // tweak the length, too small
1771        let mut buffer = MINT_WITH_EXTENSION.to_vec();
1772        buffer[BASE_ACCOUNT_LENGTH + 3] = 10;
1773        let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
1774        assert_eq!(
1775            state.get_extension::<TransferFeeConfig>(),
1776            Err(ProgramError::InvalidAccountData)
1777        );
1778
1779        // data buffer is too small
1780        let buffer = &MINT_WITH_EXTENSION[..MINT_WITH_EXTENSION.len() - 1];
1781        let state = PodStateWithExtensions::<PodMint>::unpack(buffer).unwrap();
1782        assert_eq!(
1783            state.get_extension::<MintCloseAuthority>(),
1784            Err(ProgramError::InvalidAccountData)
1785        );
1786    }
1787
1788    #[test]
1789    fn account_fail_unpack_opaque_buffer() {
1790        // input buffer too small
1791        let mut buffer = vec![0, 3];
1792        assert_eq!(
1793            PodStateWithExtensions::<PodAccount>::unpack(&buffer),
1794            Err(ProgramError::InvalidAccountData)
1795        );
1796        assert_eq!(
1797            PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer),
1798            Err(ProgramError::InvalidAccountData)
1799        );
1800        assert_eq!(
1801            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
1802            Err(ProgramError::InvalidAccountData)
1803        );
1804
1805        // input buffer invalid
1806        // all 5's - not a valid `AccountState`
1807        let mut buffer = vec![5; BASE_ACCOUNT_LENGTH];
1808        assert_eq!(
1809            PodStateWithExtensions::<PodAccount>::unpack(&buffer),
1810            Err(ProgramError::UninitializedAccount)
1811        );
1812        assert_eq!(
1813            PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer),
1814            Err(ProgramError::UninitializedAccount)
1815        );
1816
1817        // tweak the account type
1818        let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
1819        buffer[BASE_ACCOUNT_LENGTH] = 3;
1820        assert_eq!(
1821            PodStateWithExtensions::<PodAccount>::unpack(&buffer),
1822            Err(ProgramError::InvalidAccountData)
1823        );
1824
1825        // clear the state byte
1826        let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
1827        buffer[108] = 0;
1828        assert_eq!(
1829            PodStateWithExtensions::<PodAccount>::unpack(&buffer),
1830            Err(ProgramError::UninitializedAccount)
1831        );
1832
1833        // tweak the extension type
1834        let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
1835        buffer[BASE_ACCOUNT_LENGTH + 1] = 12;
1836        let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
1837        assert_eq!(
1838            state.get_extension::<TransferHookAccount>(),
1839            Err(ProgramError::InvalidAccountData),
1840        );
1841
1842        // tweak the length, too big
1843        let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
1844        buffer[BASE_ACCOUNT_LENGTH + 3] = 100;
1845        let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
1846        assert_eq!(
1847            state.get_extension::<TransferHookAccount>(),
1848            Err(ProgramError::InvalidAccountData)
1849        );
1850
1851        // tweak the length, too small
1852        let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
1853        buffer[BASE_ACCOUNT_LENGTH + 3] = 10;
1854        let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
1855        assert_eq!(
1856            state.get_extension::<TransferHookAccount>(),
1857            Err(ProgramError::InvalidAccountData)
1858        );
1859
1860        // data buffer is too small
1861        let buffer = &ACCOUNT_WITH_EXTENSION[..ACCOUNT_WITH_EXTENSION.len() - 1];
1862        let state = PodStateWithExtensions::<PodAccount>::unpack(buffer).unwrap();
1863        assert_eq!(
1864            state.get_extension::<TransferHookAccount>(),
1865            Err(ProgramError::InvalidAccountData)
1866        );
1867    }
1868
1869    #[test]
1870    fn get_extension_types_with_opaque_buffer() {
1871        // incorrect due to the length
1872        assert_eq!(
1873            get_tlv_data_info(&[1, 0, 1, 1]).unwrap_err(),
1874            ProgramError::InvalidAccountData,
1875        );
1876        // incorrect due to the huge enum number
1877        assert_eq!(
1878            get_tlv_data_info(&[0, 1, 0, 0]).unwrap_err(),
1879            ProgramError::InvalidAccountData,
1880        );
1881        // correct due to the good enum number and zero length
1882        assert_eq!(
1883            get_tlv_data_info(&[1, 0, 0, 0]).unwrap(),
1884            TlvDataInfo {
1885                extension_types: vec![ExtensionType::try_from(1).unwrap()],
1886                used_len: add_type_and_length_to_len(0),
1887            }
1888        );
1889        // correct since it's just uninitialized data at the end
1890        assert_eq!(
1891            get_tlv_data_info(&[0, 0]).unwrap(),
1892            TlvDataInfo {
1893                extension_types: vec![],
1894                used_len: 0
1895            }
1896        );
1897    }
1898
1899    #[test]
1900    fn mint_with_extension_pack_unpack() {
1901        let mint_size = ExtensionType::try_calculate_account_len::<PodMint>(&[
1902            ExtensionType::MintCloseAuthority,
1903            ExtensionType::TransferFeeConfig,
1904        ])
1905        .unwrap();
1906        let mut buffer = vec![0; mint_size];
1907
1908        // fail unpack
1909        assert_eq!(
1910            PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer),
1911            Err(ProgramError::UninitializedAccount),
1912        );
1913
1914        let mut state =
1915            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
1916        // fail init account extension
1917        assert_eq!(
1918            state.init_extension::<TransferFeeAmount>(true),
1919            Err(ProgramError::InvalidAccountData),
1920        );
1921
1922        // success write extension
1923        let close_authority: MaybeNull<Address> =
1924            Some(Address::new_from_array([1; 32])).try_into().unwrap();
1925        let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
1926        extension.close_authority = close_authority;
1927        assert_eq!(
1928            &state.get_extension_types().unwrap(),
1929            &[ExtensionType::MintCloseAuthority]
1930        );
1931
1932        // fail init extension when already initialized
1933        assert_eq!(
1934            state.init_extension::<MintCloseAuthority>(false),
1935            Err(ProgramError::Custom(
1936                TokenError::ExtensionAlreadyInitialized as u32
1937            ))
1938        );
1939
1940        // fail unpack as account, a mint extension was written
1941        assert_eq!(
1942            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
1943            Err(ProgramError::Custom(
1944                TokenError::ExtensionBaseMismatch as u32
1945            ))
1946        );
1947
1948        // fail unpack again, still no base data
1949        assert_eq!(
1950            PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer.clone()),
1951            Err(ProgramError::UninitializedAccount),
1952        );
1953
1954        // write base mint
1955        let mut state =
1956            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
1957        *state.base = TEST_POD_MINT;
1958        state.init_account_type().unwrap();
1959
1960        // check raw buffer
1961        let mut expect = TEST_MINT_SLICE.to_vec();
1962        expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); // padding
1963        expect.push(AccountType::Mint.into());
1964        expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
1965        expect.extend_from_slice(&(size_of::<MintCloseAuthority>() as u16).to_le_bytes());
1966        expect.extend_from_slice(&[1; 32]); // data
1967        expect.extend_from_slice(&[0; size_of::<ExtensionType>()]);
1968        expect.extend_from_slice(&[0; size_of::<Length>()]);
1969        expect.extend_from_slice(&[0; size_of::<TransferFeeConfig>()]);
1970        assert_eq!(expect, buffer);
1971
1972        // unpack uninitialized will now fail because the PodMint is now initialized
1973        assert_eq!(
1974            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer.clone()),
1975            Err(TokenError::AlreadyInUse.into()),
1976        );
1977
1978        // check unpacking
1979        let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
1980
1981        // update base
1982        *state.base = TEST_POD_MINT;
1983        state.base.supply = (u64::from(state.base.supply) + 100).into();
1984
1985        // check unpacking
1986        let unpacked_extension = state.get_extension_mut::<MintCloseAuthority>().unwrap();
1987        assert_eq!(*unpacked_extension, MintCloseAuthority { close_authority });
1988
1989        // update extension
1990        let close_authority: MaybeNull<Address> = None.try_into().unwrap();
1991        unpacked_extension.close_authority = close_authority;
1992
1993        // check updates are propagated
1994        let base = *state.base;
1995        let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
1996        assert_eq!(state.base, &base);
1997        let unpacked_extension = state.get_extension::<MintCloseAuthority>().unwrap();
1998        assert_eq!(*unpacked_extension, MintCloseAuthority { close_authority });
1999
2000        // check raw buffer
2001        let mut expect = vec![];
2002        expect.extend_from_slice(bytemuck::bytes_of(&base));
2003        expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); // padding
2004        expect.push(AccountType::Mint.into());
2005        expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
2006        expect.extend_from_slice(&(size_of::<MintCloseAuthority>() as u16).to_le_bytes());
2007        expect.extend_from_slice(&[0; 32]);
2008        expect.extend_from_slice(&[0; size_of::<ExtensionType>()]);
2009        expect.extend_from_slice(&[0; size_of::<Length>()]);
2010        expect.extend_from_slice(&[0; size_of::<TransferFeeConfig>()]);
2011        assert_eq!(expect, buffer);
2012
2013        // fail unpack as an account
2014        assert_eq!(
2015            PodStateWithExtensions::<PodAccount>::unpack(&buffer),
2016            Err(ProgramError::UninitializedAccount),
2017        );
2018
2019        let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
2020        // init one more extension
2021        let mint_transfer_fee = test_transfer_fee_config();
2022        let new_extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
2023        new_extension.transfer_fee_config_authority =
2024            mint_transfer_fee.transfer_fee_config_authority;
2025        new_extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
2026        new_extension.withheld_amount = mint_transfer_fee.withheld_amount;
2027        new_extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
2028        new_extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;
2029
2030        assert_eq!(
2031            &state.get_extension_types().unwrap(),
2032            &[
2033                ExtensionType::MintCloseAuthority,
2034                ExtensionType::TransferFeeConfig
2035            ]
2036        );
2037
2038        // check raw buffer
2039        let mut expect = vec![];
2040        expect.extend_from_slice(bytemuck::bytes_of(&base));
2041        expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); // padding
2042        expect.push(AccountType::Mint.into());
2043        expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
2044        expect.extend_from_slice(&(size_of::<MintCloseAuthority>() as u16).to_le_bytes());
2045        expect.extend_from_slice(&[0; 32]); // data
2046        expect.extend_from_slice(&(ExtensionType::TransferFeeConfig as u16).to_le_bytes());
2047        expect.extend_from_slice(&(size_of::<TransferFeeConfig>() as u16).to_le_bytes());
2048        expect.extend_from_slice(bytemuck::bytes_of(&mint_transfer_fee));
2049        assert_eq!(expect, buffer);
2050
2051        // fail to init one more extension that does not fit
2052        let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
2053        assert_eq!(
2054            state.init_extension::<MintPaddingTest>(true),
2055            Err(ProgramError::InvalidAccountData),
2056        );
2057    }
2058
2059    #[test]
2060    fn mint_extension_any_order() {
2061        let mint_size = ExtensionType::try_calculate_account_len::<PodMint>(&[
2062            ExtensionType::MintCloseAuthority,
2063            ExtensionType::TransferFeeConfig,
2064        ])
2065        .unwrap();
2066        let mut buffer = vec![0; mint_size];
2067
2068        let mut state =
2069            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2070        // write extensions
2071        let close_authority: MaybeNull<Address> =
2072            Some(Address::new_from_array([1; 32])).try_into().unwrap();
2073        let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
2074        extension.close_authority = close_authority;
2075
2076        let mint_transfer_fee = test_transfer_fee_config();
2077        let extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
2078        extension.transfer_fee_config_authority = mint_transfer_fee.transfer_fee_config_authority;
2079        extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
2080        extension.withheld_amount = mint_transfer_fee.withheld_amount;
2081        extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
2082        extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;
2083
2084        assert_eq!(
2085            &state.get_extension_types().unwrap(),
2086            &[
2087                ExtensionType::MintCloseAuthority,
2088                ExtensionType::TransferFeeConfig
2089            ]
2090        );
2091
2092        // write base mint
2093        let mut state =
2094            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2095        *state.base = TEST_POD_MINT;
2096        state.init_account_type().unwrap();
2097
2098        let mut other_buffer = vec![0; mint_size];
2099        let mut state =
2100            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut other_buffer).unwrap();
2101
2102        // write base mint
2103        *state.base = TEST_POD_MINT;
2104        state.init_account_type().unwrap();
2105
2106        // write extensions in a different order
2107        let mint_transfer_fee = test_transfer_fee_config();
2108        let extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
2109        extension.transfer_fee_config_authority = mint_transfer_fee.transfer_fee_config_authority;
2110        extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
2111        extension.withheld_amount = mint_transfer_fee.withheld_amount;
2112        extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
2113        extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;
2114
2115        let close_authority: MaybeNull<Address> =
2116            Some(Address::new_from_array([1; 32])).try_into().unwrap();
2117        let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
2118        extension.close_authority = close_authority;
2119
2120        assert_eq!(
2121            &state.get_extension_types().unwrap(),
2122            &[
2123                ExtensionType::TransferFeeConfig,
2124                ExtensionType::MintCloseAuthority
2125            ]
2126        );
2127
2128        // buffers are NOT the same because written in a different order
2129        assert_ne!(buffer, other_buffer);
2130        let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
2131        let other_state = PodStateWithExtensions::<PodMint>::unpack(&other_buffer).unwrap();
2132
2133        // BUT mint and extensions are the same
2134        assert_eq!(
2135            state.get_extension::<TransferFeeConfig>().unwrap(),
2136            other_state.get_extension::<TransferFeeConfig>().unwrap()
2137        );
2138        assert_eq!(
2139            state.get_extension::<MintCloseAuthority>().unwrap(),
2140            other_state.get_extension::<MintCloseAuthority>().unwrap()
2141        );
2142        assert_eq!(state.base, other_state.base);
2143    }
2144
2145    #[test]
2146    fn mint_with_multisig_len() {
2147        let mut buffer = vec![0; Multisig::LEN];
2148        assert_eq!(
2149            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer),
2150            Err(ProgramError::InvalidAccountData),
2151        );
2152        let mint_size =
2153            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MintPaddingTest])
2154                .unwrap();
2155        assert_eq!(mint_size, Multisig::LEN + size_of::<ExtensionType>());
2156        let mut buffer = vec![0; mint_size];
2157
2158        // write base mint
2159        let mut state =
2160            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2161        *state.base = TEST_POD_MINT;
2162        state.init_account_type().unwrap();
2163
2164        // write padding
2165        let extension = state.init_extension::<MintPaddingTest>(true).unwrap();
2166        extension.padding1 = [1; 128];
2167        extension.padding2 = [1; 48];
2168        extension.padding3 = [1; 9];
2169
2170        assert_eq!(
2171            &state.get_extension_types().unwrap(),
2172            &[ExtensionType::MintPaddingTest]
2173        );
2174
2175        // check raw buffer
2176        let mut expect = TEST_MINT_SLICE.to_vec();
2177        expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); // padding
2178        expect.push(AccountType::Mint.into());
2179        expect.extend_from_slice(&(ExtensionType::MintPaddingTest as u16).to_le_bytes());
2180        expect.extend_from_slice(&(size_of::<MintPaddingTest>() as u16).to_le_bytes());
2181        expect.extend_from_slice(&vec![1; size_of::<MintPaddingTest>()]);
2182        expect.extend_from_slice(&(ExtensionType::Uninitialized as u16).to_le_bytes());
2183        assert_eq!(expect, buffer);
2184    }
2185
    /// Round-trips a `PodAccount` buffer carrying a `TransferFeeAmount`
    /// extension. It covers:
    /// - writing the extension before the base account exists;
    /// - checking the raw TLV bytes;
    /// - mutating through an unpacked state;
    /// - rejecting the buffer when it is unpacked as a `PodMint`.
    #[test]
    fn account_with_extension_pack_unpack() {
        // Buffer sized for the base account plus one TransferFeeAmount entry.
        let account_size = ExtensionType::try_calculate_account_len::<PodAccount>(&[
            ExtensionType::TransferFeeAmount,
        ])
        .unwrap();
        let mut buffer = vec![0; account_size];

        // fail unpack: the all-zero buffer has no initialized base account
        assert_eq!(
            PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer),
            Err(ProgramError::UninitializedAccount),
        );

        let mut state =
            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
        // fail init mint extension: TransferFeeConfig is refused on an account buffer
        assert_eq!(
            state.init_extension::<TransferFeeConfig>(true),
            Err(ProgramError::InvalidAccountData),
        );
        // success write extension
        let withheld_amount = U64::from(u64::MAX);
        let extension = state.init_extension::<TransferFeeAmount>(true).unwrap();
        extension.withheld_amount = withheld_amount;

        assert_eq!(
            &state.get_extension_types().unwrap(),
            &[ExtensionType::TransferFeeAmount]
        );

        // fail unpack again, still no base data (checked on a copy of the buffer)
        assert_eq!(
            PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer.clone()),
            Err(ProgramError::UninitializedAccount),
        );

        // write base account; the previous `state` is no longer used, so
        // `buffer` can be borrowed mutably again
        let mut state =
            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
        *state.base = TEST_POD_ACCOUNT;
        state.init_account_type().unwrap();
        let base = *state.base;

        // check raw buffer: base account bytes, the account-type byte, then the
        // TLV entry as little-endian u16 type, u16 length and the u64 value
        let mut expect = TEST_ACCOUNT_SLICE.to_vec();
        expect.push(AccountType::Account.into());
        expect.extend_from_slice(&(ExtensionType::TransferFeeAmount as u16).to_le_bytes());
        expect.extend_from_slice(&(size_of::<TransferFeeAmount>() as u16).to_le_bytes());
        expect.extend_from_slice(&u64::from(withheld_amount).to_le_bytes());
        assert_eq!(expect, buffer);

        // check unpacking: the fully written buffer now unpacks normally
        let mut state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
        assert_eq!(state.base, &base);
        assert_eq!(
            &state.get_extension_types().unwrap(),
            &[ExtensionType::TransferFeeAmount]
        );

        // update base: reset it, then bump `amount` by 100
        *state.base = TEST_POD_ACCOUNT;
        state.base.amount = (u64::from(state.base.amount) + 100).into();

        // check unpacking: the extension still holds the first withheld amount
        let unpacked_extension = state.get_extension_mut::<TransferFeeAmount>().unwrap();
        assert_eq!(*unpacked_extension, TransferFeeAmount { withheld_amount });

        // update extension through the mutable reference
        let withheld_amount = U64::from(u32::MAX as u64);
        unpacked_extension.withheld_amount = withheld_amount;

        // check updates are propagated: re-read the buffer immutably and
        // compare against the in-place edits made above
        let base = *state.base;
        let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
        assert_eq!(state.base, &base);
        let unpacked_extension = state.get_extension::<TransferFeeAmount>().unwrap();
        assert_eq!(*unpacked_extension, TransferFeeAmount { withheld_amount });

        // check raw buffer: same layout, now with the updated base and value
        let mut expect = vec![];
        expect.extend_from_slice(bytemuck::bytes_of(&base));
        expect.push(AccountType::Account.into());
        expect.extend_from_slice(&(ExtensionType::TransferFeeAmount as u16).to_le_bytes());
        expect.extend_from_slice(&(size_of::<TransferFeeAmount>() as u16).to_le_bytes());
        expect.extend_from_slice(&u64::from(withheld_amount).to_le_bytes());
        assert_eq!(expect, buffer);

        // fail unpack as a mint: the account-type byte written above is Account
        assert_eq!(
            PodStateWithExtensions::<PodMint>::unpack(&buffer),
            Err(ProgramError::InvalidAccountData),
        );
    }
2280
2281    #[test]
2282    fn account_with_multisig_len() {
2283        let mut buffer = vec![0; Multisig::LEN];
2284        assert_eq!(
2285            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
2286            Err(ProgramError::InvalidAccountData),
2287        );
2288        let account_size = ExtensionType::try_calculate_account_len::<PodAccount>(&[
2289            ExtensionType::AccountPaddingTest,
2290        ])
2291        .unwrap();
2292        assert_eq!(account_size, Multisig::LEN + size_of::<ExtensionType>());
2293        let mut buffer = vec![0; account_size];
2294
2295        // write base account
2296        let mut state =
2297            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
2298        *state.base = TEST_POD_ACCOUNT;
2299        state.init_account_type().unwrap();
2300
2301        // write padding
2302        let extension = state.init_extension::<AccountPaddingTest>(true).unwrap();
2303        extension.0.padding1 = [2; 128];
2304        extension.0.padding2 = [2; 48];
2305        extension.0.padding3 = [2; 9];
2306
2307        assert_eq!(
2308            &state.get_extension_types().unwrap(),
2309            &[ExtensionType::AccountPaddingTest]
2310        );
2311
2312        // check raw buffer
2313        let mut expect = TEST_ACCOUNT_SLICE.to_vec();
2314        expect.push(AccountType::Account.into());
2315        expect.extend_from_slice(&(ExtensionType::AccountPaddingTest as u16).to_le_bytes());
2316        expect.extend_from_slice(&(size_of::<AccountPaddingTest>() as u16).to_le_bytes());
2317        expect.extend_from_slice(&vec![2; size_of::<AccountPaddingTest>()]);
2318        expect.extend_from_slice(&(ExtensionType::Uninitialized as u16).to_le_bytes());
2319        assert_eq!(expect, buffer);
2320    }
2321
2322    #[test]
2323    fn test_set_account_type() {
2324        // account with buffer big enough for AccountType and Extension
2325        let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
2326        let needed_len = ExtensionType::try_calculate_account_len::<PodAccount>(&[
2327            ExtensionType::ImmutableOwner,
2328        ])
2329        .unwrap()
2330            - buffer.len();
2331        buffer.append(&mut vec![0; needed_len]);
2332        let err = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap_err();
2333        assert_eq!(err, ProgramError::InvalidAccountData);
2334        set_account_type::<PodAccount>(&mut buffer).unwrap();
2335        // unpack is viable after manual set_account_type
2336        let mut state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
2337        assert_eq!(state.base, &TEST_POD_ACCOUNT);
2338        assert_eq!(state.account_type[0], AccountType::Account as u8);
2339        state.init_extension::<ImmutableOwner>(true).unwrap(); // just confirming initialization works
2340
2341        // account with buffer big enough for AccountType only
2342        let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
2343        buffer.append(&mut vec![0; 2]);
2344        let err = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap_err();
2345        assert_eq!(err, ProgramError::InvalidAccountData);
2346        set_account_type::<PodAccount>(&mut buffer).unwrap();
2347        // unpack is viable after manual set_account_type
2348        let state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
2349        assert_eq!(state.base, &TEST_POD_ACCOUNT);
2350        assert_eq!(state.account_type[0], AccountType::Account as u8);
2351
2352        // account with AccountType already set => noop
2353        let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
2354        buffer.append(&mut vec![2, 0]);
2355        let _ = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
2356        set_account_type::<PodAccount>(&mut buffer).unwrap();
2357        let state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
2358        assert_eq!(state.base, &TEST_POD_ACCOUNT);
2359        assert_eq!(state.account_type[0], AccountType::Account as u8);
2360
2361        // account with wrong AccountType fails
2362        let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
2363        buffer.append(&mut vec![1, 0]);
2364        let err = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap_err();
2365        assert_eq!(err, ProgramError::InvalidAccountData);
2366        let err = set_account_type::<PodAccount>(&mut buffer).unwrap_err();
2367        assert_eq!(err, ProgramError::InvalidAccountData);
2368
2369        // mint with buffer big enough for AccountType and Extension
2370        let mut buffer = TEST_MINT_SLICE.to_vec();
2371        let needed_len = ExtensionType::try_calculate_account_len::<PodMint>(&[
2372            ExtensionType::MintCloseAuthority,
2373        ])
2374        .unwrap()
2375            - buffer.len();
2376        buffer.append(&mut vec![0; needed_len]);
2377        let err = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap_err();
2378        assert_eq!(err, ProgramError::InvalidAccountData);
2379        set_account_type::<PodMint>(&mut buffer).unwrap();
2380        // unpack is viable after manual set_account_type
2381        let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
2382        assert_eq!(state.base, &TEST_POD_MINT);
2383        assert_eq!(state.account_type[0], AccountType::Mint as u8);
2384        state.init_extension::<MintCloseAuthority>(true).unwrap();
2385
2386        // mint with buffer big enough for AccountType only
2387        let mut buffer = TEST_MINT_SLICE.to_vec();
2388        buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
2389        buffer.append(&mut vec![0; 2]);
2390        let err = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap_err();
2391        assert_eq!(err, ProgramError::InvalidAccountData);
2392        set_account_type::<PodMint>(&mut buffer).unwrap();
2393        // unpack is viable after manual set_account_type
2394        let state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
2395        assert_eq!(state.base, &TEST_POD_MINT);
2396        assert_eq!(state.account_type[0], AccountType::Mint as u8);
2397
2398        // mint with AccountType already set => noop
2399        let mut buffer = TEST_MINT_SLICE.to_vec();
2400        buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
2401        buffer.append(&mut vec![1, 0]);
2402        set_account_type::<PodMint>(&mut buffer).unwrap();
2403        let state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
2404        assert_eq!(state.base, &TEST_POD_MINT);
2405        assert_eq!(state.account_type[0], AccountType::Mint as u8);
2406
2407        // mint with wrong AccountType fails
2408        let mut buffer = TEST_MINT_SLICE.to_vec();
2409        buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
2410        buffer.append(&mut vec![2, 0]);
2411        let err = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap_err();
2412        assert_eq!(err, ProgramError::InvalidAccountData);
2413        let err = set_account_type::<PodMint>(&mut buffer).unwrap_err();
2414        assert_eq!(err, ProgramError::InvalidAccountData);
2415    }
2416
2417    #[test]
2418    fn test_set_account_type_wrongly() {
2419        // try to set PodAccount account_type to PodMint
2420        let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
2421        buffer.append(&mut vec![0; 2]);
2422        let err = set_account_type::<PodMint>(&mut buffer).unwrap_err();
2423        assert_eq!(err, ProgramError::InvalidAccountData);
2424
2425        // try to set PodMint account_type to PodAccount
2426        let mut buffer = TEST_MINT_SLICE.to_vec();
2427        buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
2428        buffer.append(&mut vec![0; 2]);
2429        let err = set_account_type::<PodAccount>(&mut buffer).unwrap_err();
2430        assert_eq!(err, ProgramError::InvalidAccountData);
2431    }
2432
2433    #[test]
2434    fn test_get_required_init_account_extensions() {
2435        // Some mint extensions with no required account extensions
2436        let mint_extensions = vec![
2437            ExtensionType::MintCloseAuthority,
2438            ExtensionType::Uninitialized,
2439        ];
2440        assert_eq!(
2441            ExtensionType::get_required_init_account_extensions(&mint_extensions),
2442            vec![]
2443        );
2444
2445        // One mint extension with required account extension, one without
2446        let mint_extensions = vec![
2447            ExtensionType::TransferFeeConfig,
2448            ExtensionType::MintCloseAuthority,
2449        ];
2450        assert_eq!(
2451            ExtensionType::get_required_init_account_extensions(&mint_extensions),
2452            vec![ExtensionType::TransferFeeAmount]
2453        );
2454
2455        // Some mint extensions both with required account extensions
2456        let mint_extensions = vec![
2457            ExtensionType::TransferFeeConfig,
2458            ExtensionType::MintPaddingTest,
2459        ];
2460        assert_eq!(
2461            ExtensionType::get_required_init_account_extensions(&mint_extensions),
2462            vec![
2463                ExtensionType::TransferFeeAmount,
2464                ExtensionType::AccountPaddingTest
2465            ]
2466        );
2467
2468        // Demonstrate that method does not dedupe inputs or outputs
2469        let mint_extensions = vec![
2470            ExtensionType::TransferFeeConfig,
2471            ExtensionType::TransferFeeConfig,
2472        ];
2473        assert_eq!(
2474            ExtensionType::get_required_init_account_extensions(&mint_extensions),
2475            vec![
2476                ExtensionType::TransferFeeAmount,
2477                ExtensionType::TransferFeeAmount
2478            ]
2479        );
2480    }
2481
2482    #[test]
2483    fn mint_without_extensions() {
2484        let space = ExtensionType::try_calculate_account_len::<PodMint>(&[]).unwrap();
2485        let mut buffer = vec![0; space];
2486        assert_eq!(
2487            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
2488            Err(ProgramError::InvalidAccountData),
2489        );
2490
2491        // write base account
2492        let mut state =
2493            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2494        *state.base = TEST_POD_MINT;
2495        state.init_account_type().unwrap();
2496
2497        // fail init extension
2498        assert_eq!(
2499            state.init_extension::<TransferFeeConfig>(true),
2500            Err(ProgramError::InvalidAccountData),
2501        );
2502
2503        assert_eq!(TEST_MINT_SLICE, buffer);
2504    }
2505
2506    #[test]
2507    fn test_init_nonzero_default() {
2508        let mint_size =
2509            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MintPaddingTest])
2510                .unwrap();
2511        let mut buffer = vec![0; mint_size];
2512        let mut state =
2513            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2514        *state.base = TEST_POD_MINT;
2515        state.init_account_type().unwrap();
2516        let extension = state.init_extension::<MintPaddingTest>(true).unwrap();
2517        assert_eq!(extension.padding1, [1; 128]);
2518        assert_eq!(extension.padding2, [2; 48]);
2519        assert_eq!(extension.padding3, [3; 9]);
2520    }
2521
2522    #[test]
2523    fn test_init_buffer_too_small() {
2524        let mint_size = ExtensionType::try_calculate_account_len::<PodMint>(&[
2525            ExtensionType::MintCloseAuthority,
2526        ])
2527        .unwrap();
2528        let mut buffer = vec![0; mint_size - 1];
2529        let mut state =
2530            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2531        let err = state
2532            .init_extension::<MintCloseAuthority>(true)
2533            .unwrap_err();
2534        assert_eq!(err, ProgramError::InvalidAccountData);
2535
2536        state.tlv_data[0] = 3;
2537        state.tlv_data[2] = 32;
2538        let err = state.get_extension_mut::<MintCloseAuthority>().unwrap_err();
2539        assert_eq!(err, ProgramError::InvalidAccountData);
2540
2541        let mut buffer = vec![0; PodMint::SIZE_OF + 2];
2542        let err =
2543            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap_err();
2544        assert_eq!(err, ProgramError::InvalidAccountData);
2545
2546        // OK since there are two bytes for the type, which is `Uninitialized`
2547        let mut buffer = vec![0; BASE_ACCOUNT_LENGTH + 3];
2548        let mut state =
2549            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2550        let err = state.get_extension_mut::<MintCloseAuthority>().unwrap_err();
2551        assert_eq!(err, ProgramError::InvalidAccountData);
2552
2553        assert_eq!(state.get_extension_types().unwrap(), vec![]);
2554
2555        // OK, there aren't two bytes for the type, but that's fine
2556        let mut buffer = vec![0; BASE_ACCOUNT_LENGTH + 2];
2557        let state =
2558            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2559        assert_eq!(state.get_extension_types().unwrap(), []);
2560    }
2561
2562    #[test]
2563    fn test_extension_with_no_data() {
2564        let account_size = ExtensionType::try_calculate_account_len::<PodAccount>(&[
2565            ExtensionType::ImmutableOwner,
2566        ])
2567        .unwrap();
2568        let mut buffer = vec![0; account_size];
2569        let mut state =
2570            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
2571        *state.base = TEST_POD_ACCOUNT;
2572        state.init_account_type().unwrap();
2573
2574        let err = state.get_extension::<ImmutableOwner>().unwrap_err();
2575        assert_eq!(
2576            err,
2577            ProgramError::Custom(TokenError::ExtensionNotFound as u32)
2578        );
2579
2580        state.init_extension::<ImmutableOwner>(true).unwrap();
2581        assert_eq!(
2582            get_first_extension_type(state.tlv_data).unwrap(),
2583            Some(ExtensionType::ImmutableOwner)
2584        );
2585        assert_eq!(
2586            get_tlv_data_info(state.tlv_data).unwrap(),
2587            TlvDataInfo {
2588                extension_types: vec![ExtensionType::ImmutableOwner],
2589                used_len: add_type_and_length_to_len(0)
2590            }
2591        );
2592    }
2593
2594    #[test]
2595    fn fail_account_len_with_metadata() {
2596        assert_eq!(
2597            ExtensionType::try_calculate_account_len::<PodMint>(&[
2598                ExtensionType::MintCloseAuthority,
2599                ExtensionType::VariableLenMintTest,
2600                ExtensionType::TransferFeeConfig,
2601            ])
2602            .unwrap_err(),
2603            ProgramError::InvalidArgument
2604        );
2605    }
2606
2607    #[test]
2608    fn alloc() {
2609        let variable_len = VariableLenMintTest { data: vec![1] };
2610        let alloc_size = variable_len.get_packed_len().unwrap();
2611        let account_size =
2612            BASE_ACCOUNT_LENGTH + size_of::<AccountType>() + add_type_and_length_to_len(alloc_size);
2613        let mut buffer = vec![0; account_size];
2614        let mut state =
2615            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2616        state
2617            .init_variable_len_extension(&variable_len, false)
2618            .unwrap();
2619
2620        // can't double alloc
2621        assert_eq!(
2622            state
2623                .init_variable_len_extension(&variable_len, false)
2624                .unwrap_err(),
2625            TokenError::ExtensionAlreadyInitialized.into()
2626        );
2627
2628        // unless overwrite is set
2629        state
2630            .init_variable_len_extension(&variable_len, true)
2631            .unwrap();
2632
2633        // can't change the size during overwrite though
2634        assert_eq!(
2635            state
2636                .init_variable_len_extension(&VariableLenMintTest { data: vec![] }, true)
2637                .unwrap_err(),
2638            TokenError::InvalidLengthForAlloc.into()
2639        );
2640
2641        // try to write too far, fail earlier
2642        assert_eq!(
2643            state
2644                .init_variable_len_extension(&VariableLenMintTest { data: vec![1, 2] }, true)
2645                .unwrap_err(),
2646            ProgramError::InvalidAccountData
2647        );
2648    }
2649
    /// Grows and shrinks a variable-length extension that is followed by a
    /// `MetadataPointer` entry. It checks three things:
    /// - the following entry survives each realloc;
    /// - after a shrink the tail of the buffer is zero;
    /// - growing past the buffer's size is rejected.
    #[test]
    fn realloc() {
        let small_variable_len = VariableLenMintTest {
            data: vec![1, 2, 3],
        };
        let base_variable_len = VariableLenMintTest {
            data: vec![1, 2, 3, 4],
        };
        let big_variable_len = VariableLenMintTest {
            data: vec![1, 2, 3, 4, 5],
        };
        let too_big_variable_len = VariableLenMintTest {
            data: vec![1, 2, 3, 4, 5, 6],
        };
        // Room for a MetadataPointer plus a TLV entry sized for
        // `big_variable_len`, the largest value expected to fit.
        let account_size =
            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MetadataPointer])
                .unwrap()
                + add_type_and_length_to_len(big_variable_len.get_packed_len().unwrap());
        let mut buffer = vec![0; account_size];
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();

        // alloc both types: the variable-length entry first, then a
        // MetadataPointer whose fields hold an all-0xFF address
        state
            .init_variable_len_extension(&base_variable_len, false)
            .unwrap();
        let max_pubkey: MaybeNull<Address> =
            Some(Address::new_from_array([255; 32])).try_into().unwrap();
        let extension = state.init_extension::<MetadataPointer>(false).unwrap();
        extension.authority = max_pubkey;
        extension.metadata_address = max_pubkey;

        // realloc first entry to larger; the MetadataPointer after it must be intact
        state
            .realloc_variable_len_extension(&big_variable_len)
            .unwrap();
        let extension = state
            .get_variable_len_extension::<VariableLenMintTest>()
            .unwrap();
        assert_eq!(extension, big_variable_len);
        let extension = state.get_extension::<MetadataPointer>().unwrap();
        assert_eq!(extension.authority, max_pubkey);
        assert_eq!(extension.metadata_address, max_pubkey);

        // realloc to smaller; the MetadataPointer must still be intact and the
        // last `diff` bytes of the buffer must be zero
        state
            .realloc_variable_len_extension(&small_variable_len)
            .unwrap();
        let extension = state
            .get_variable_len_extension::<VariableLenMintTest>()
            .unwrap();
        assert_eq!(extension, small_variable_len);
        let extension = state.get_extension::<MetadataPointer>().unwrap();
        assert_eq!(extension.authority, max_pubkey);
        assert_eq!(extension.metadata_address, max_pubkey);
        let diff = big_variable_len.get_packed_len().unwrap()
            - small_variable_len.get_packed_len().unwrap();
        assert_eq!(&buffer[account_size - diff..account_size], vec![0; diff]);

        // unpack again: `buffer` was read directly above, so the previous
        // `state` borrow has ended
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        // realloc too much, fails: `too_big_variable_len` does not fit in the buffer
        assert_eq!(
            state
                .realloc_variable_len_extension(&too_big_variable_len)
                .unwrap_err(),
            ProgramError::InvalidAccountData,
        );
    }
2720
2721    #[test]
2722    fn account_len() {
2723        let small_variable_len = VariableLenMintTest {
2724            data: vec![20, 30, 40],
2725        };
2726        let variable_len = VariableLenMintTest {
2727            data: vec![20, 30, 40, 50],
2728        };
2729        let big_variable_len = VariableLenMintTest {
2730            data: vec![20, 30, 40, 50, 60],
2731        };
2732        let value_len = variable_len.get_packed_len().unwrap();
2733        let account_size =
2734            BASE_ACCOUNT_LENGTH + size_of::<AccountType>() + add_type_and_length_to_len(value_len);
2735        let mut buffer = vec![0; account_size];
2736        let mut state =
2737            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2738
2739        // With a new extension, new length must include padding, 1 byte for
2740        // account type, 2 bytes for type, 2 for length
2741        let current_len = state.try_get_account_len().unwrap();
2742        assert_eq!(current_len, PodMint::SIZE_OF);
2743        let new_len = state
2744            .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(
2745                &variable_len,
2746            )
2747            .unwrap();
2748        assert_eq!(
2749            new_len,
2750            BASE_ACCOUNT_AND_TYPE_LENGTH.saturating_add(add_type_and_length_to_len(value_len))
2751        );
2752
2753        state
2754            .init_variable_len_extension::<VariableLenMintTest>(&variable_len, false)
2755            .unwrap();
2756        let current_len = state.try_get_account_len().unwrap();
2757        assert_eq!(current_len, new_len);
2758
2759        // Reduce the extension size
2760        let new_len = state
2761            .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(
2762                &small_variable_len,
2763            )
2764            .unwrap();
2765        assert_eq!(current_len.checked_sub(new_len).unwrap(), 1);
2766
2767        // Increase the extension size
2768        let new_len = state
2769            .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(
2770                &big_variable_len,
2771            )
2772            .unwrap();
2773        assert_eq!(new_len.checked_sub(current_len).unwrap(), 1);
2774
2775        // Maintain the extension size
2776        let new_len = state
2777            .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(
2778                &variable_len,
2779            )
2780            .unwrap();
2781        assert_eq!(new_len, current_len);
2782    }
2783
2784    /// Test helper for mimicking the data layout an on-chain `AccountInfo`,
2785    /// which permits "reallocs" as the Solana runtime does it
2786    struct SolanaAccountData {
2787        data: Vec<u8>,
2788        lamports: u64,
2789        owner: Address,
2790    }
2791    impl SolanaAccountData {
2792        /// Create a new fake solana account data. The underlying vector is
2793        /// overallocated to mimic the runtime
2794        fn new(account_data: &[u8]) -> Self {
2795            let mut data = vec![];
2796            data.extend_from_slice(&(account_data.len() as u64).to_le_bytes());
2797            data.extend_from_slice(account_data);
2798            data.extend_from_slice(&[0; MAX_PERMITTED_DATA_INCREASE]);
2799            Self {
2800                data,
2801                lamports: 10,
2802                owner: Address::new_unique(),
2803            }
2804        }
2805
2806        /// Data lops off the first 8 bytes, since those store the size of the
2807        /// account for the Solana runtime
2808        fn data(&self) -> &[u8] {
2809            let start = size_of::<u64>();
2810            let len = self.len();
2811            &self.data[start..start + len]
2812        }
2813
2814        /// Gets the runtime length of the account data
2815        fn len(&self) -> usize {
2816            self.data
2817                .get(..size_of::<u64>())
2818                .and_then(|slice| slice.try_into().ok())
2819                .map(u64::from_le_bytes)
2820                .unwrap() as usize
2821        }
2822    }
2823    impl GetAccount for SolanaAccountData {
2824        fn get(&mut self) -> (&mut u64, &mut [u8], &Address, bool) {
2825            // need to pull out the data here to avoid a double-mutable borrow
2826            let start = size_of::<u64>();
2827            let len = self.len();
2828            (
2829                &mut self.lamports,
2830                &mut self.data[start..start + len],
2831                &self.owner,
2832                false,
2833            )
2834        }
2835    }
2836
2837    #[test]
2838    fn alloc_new_fixed_len_tlv_in_account_info_from_base_size() {
2839        let fixed_len = FixedLenMintTest {
2840            data: [1, 2, 3, 4, 5, 6, 7, 8],
2841        };
2842        let value_len = size_of::<FixedLenMintTest>();
2843        let base_account_size = PodMint::SIZE_OF;
2844        let mut buffer = vec![0; base_account_size];
2845        let state =
2846            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2847        *state.base = TEST_POD_MINT;
2848
2849        let mut data = SolanaAccountData::new(&buffer);
2850        let key = Address::new_unique();
2851        let account_info = (&key, &mut data).into_account_info();
2852
2853        alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, false).unwrap();
2854        let new_account_len = BASE_ACCOUNT_AND_TYPE_LENGTH + add_type_and_length_to_len(value_len);
2855        assert_eq!(data.len(), new_account_len);
2856        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
2857        assert_eq!(
2858            state.get_extension::<FixedLenMintTest>().unwrap(),
2859            &fixed_len,
2860        );
2861
2862        // alloc again succeeds with "overwrite"
2863        let account_info = (&key, &mut data).into_account_info();
2864        alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, true).unwrap();
2865
2866        // alloc again fails without "overwrite"
2867        let account_info = (&key, &mut data).into_account_info();
2868        assert_eq!(
2869            alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, false).unwrap_err(),
2870            TokenError::ExtensionAlreadyInitialized.into()
2871        );
2872    }
2873
2874    #[test]
2875    fn alloc_new_variable_len_tlv_in_account_info_from_base_size() {
2876        let variable_len = VariableLenMintTest { data: vec![20, 99] };
2877        let value_len = variable_len.get_packed_len().unwrap();
2878        let base_account_size = PodMint::SIZE_OF;
2879        let mut buffer = vec![0; base_account_size];
2880        let state =
2881            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2882        *state.base = TEST_POD_MINT;
2883
2884        let mut data = SolanaAccountData::new(&buffer);
2885        let key = Address::new_unique();
2886        let account_info = (&key, &mut data).into_account_info();
2887
2888        alloc_and_serialize_variable_len_extension::<PodMint, _>(
2889            &account_info,
2890            &variable_len,
2891            false,
2892        )
2893        .unwrap();
2894        let new_account_len = BASE_ACCOUNT_AND_TYPE_LENGTH + add_type_and_length_to_len(value_len);
2895        assert_eq!(data.len(), new_account_len);
2896        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
2897        assert_eq!(
2898            state
2899                .get_variable_len_extension::<VariableLenMintTest>()
2900                .unwrap(),
2901            variable_len
2902        );
2903
2904        // alloc again succeeds with "overwrite"
2905        let account_info = (&key, &mut data).into_account_info();
2906        alloc_and_serialize_variable_len_extension::<PodMint, _>(
2907            &account_info,
2908            &variable_len,
2909            true,
2910        )
2911        .unwrap();
2912
2913        // alloc again fails without "overwrite"
2914        let account_info = (&key, &mut data).into_account_info();
2915        assert_eq!(
2916            alloc_and_serialize_variable_len_extension::<PodMint, _>(
2917                &account_info,
2918                &variable_len,
2919                false,
2920            )
2921            .unwrap_err(),
2922            TokenError::ExtensionAlreadyInitialized.into()
2923        );
2924    }
2925
2926    #[test]
2927    fn alloc_new_fixed_len_tlv_in_account_info_from_extended_size() {
2928        let fixed_len = FixedLenMintTest {
2929            data: [1, 2, 3, 4, 5, 6, 7, 8],
2930        };
2931        let value_len = size_of::<FixedLenMintTest>();
2932        let account_size =
2933            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::GroupPointer])
2934                .unwrap()
2935                + add_type_and_length_to_len(value_len);
2936        let mut buffer = vec![0; account_size];
2937        let mut state =
2938            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2939        *state.base = TEST_POD_MINT;
2940        state.init_account_type().unwrap();
2941
2942        let test_key: MaybeNull<Address> =
2943            Some(Address::new_from_array([20; 32])).try_into().unwrap();
2944        let extension = state.init_extension::<GroupPointer>(false).unwrap();
2945        extension.authority = test_key;
2946        extension.group_address = test_key;
2947
2948        let mut data = SolanaAccountData::new(&buffer);
2949        let key = Address::new_unique();
2950        let account_info = (&key, &mut data).into_account_info();
2951
2952        alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, false).unwrap();
2953        let new_account_len = BASE_ACCOUNT_AND_TYPE_LENGTH
2954            + add_type_and_length_to_len(value_len)
2955            + add_type_and_length_to_len(size_of::<GroupPointer>());
2956        assert_eq!(data.len(), new_account_len);
2957        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
2958        assert_eq!(
2959            state.get_extension::<FixedLenMintTest>().unwrap(),
2960            &fixed_len,
2961        );
2962        let extension = state.get_extension::<GroupPointer>().unwrap();
2963        assert_eq!(extension.authority, test_key);
2964        assert_eq!(extension.group_address, test_key);
2965
2966        // alloc again succeeds with "overwrite"
2967        let account_info = (&key, &mut data).into_account_info();
2968        alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, true).unwrap();
2969
2970        // alloc again fails without "overwrite"
2971        let account_info = (&key, &mut data).into_account_info();
2972        assert_eq!(
2973            alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, false).unwrap_err(),
2974            TokenError::ExtensionAlreadyInitialized.into()
2975        );
2976    }
2977
2978    #[test]
2979    fn alloc_new_variable_len_tlv_in_account_info_from_extended_size() {
2980        let variable_len = VariableLenMintTest { data: vec![42, 6] };
2981        let value_len = variable_len.get_packed_len().unwrap();
2982        let account_size =
2983            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MetadataPointer])
2984                .unwrap()
2985                + add_type_and_length_to_len(value_len);
2986        let mut buffer = vec![0; account_size];
2987        let mut state =
2988            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2989        *state.base = TEST_POD_MINT;
2990        state.init_account_type().unwrap();
2991
2992        let test_key: MaybeNull<Address> =
2993            Some(Address::new_from_array([20; 32])).try_into().unwrap();
2994        let extension = state.init_extension::<MetadataPointer>(false).unwrap();
2995        extension.authority = test_key;
2996        extension.metadata_address = test_key;
2997
2998        let mut data = SolanaAccountData::new(&buffer);
2999        let key = Address::new_unique();
3000        let account_info = (&key, &mut data).into_account_info();
3001
3002        alloc_and_serialize_variable_len_extension::<PodMint, _>(
3003            &account_info,
3004            &variable_len,
3005            false,
3006        )
3007        .unwrap();
3008        let new_account_len = BASE_ACCOUNT_AND_TYPE_LENGTH
3009            + add_type_and_length_to_len(value_len)
3010            + add_type_and_length_to_len(size_of::<MetadataPointer>());
3011        assert_eq!(data.len(), new_account_len);
3012        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
3013        assert_eq!(
3014            state
3015                .get_variable_len_extension::<VariableLenMintTest>()
3016                .unwrap(),
3017            variable_len
3018        );
3019        let extension = state.get_extension::<MetadataPointer>().unwrap();
3020        assert_eq!(extension.authority, test_key);
3021        assert_eq!(extension.metadata_address, test_key);
3022
3023        // alloc again succeeds with "overwrite"
3024        let account_info = (&key, &mut data).into_account_info();
3025        alloc_and_serialize_variable_len_extension::<PodMint, _>(
3026            &account_info,
3027            &variable_len,
3028            true,
3029        )
3030        .unwrap();
3031
3032        // alloc again fails without "overwrite"
3033        let account_info = (&key, &mut data).into_account_info();
3034        assert_eq!(
3035            alloc_and_serialize_variable_len_extension::<PodMint, _>(
3036                &account_info,
3037                &variable_len,
3038                false,
3039            )
3040            .unwrap_err(),
3041            TokenError::ExtensionAlreadyInitialized.into()
3042        );
3043    }
3044
3045    #[test]
3046    fn realloc_variable_len_tlv_in_account_info() {
3047        let variable_len = VariableLenMintTest {
3048            data: vec![1, 2, 3, 4, 5],
3049        };
3050        let alloc_size = variable_len.get_packed_len().unwrap();
3051        let account_size =
3052            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MetadataPointer])
3053                .unwrap()
3054                + add_type_and_length_to_len(alloc_size);
3055        let mut buffer = vec![0; account_size];
3056        let mut state =
3057            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
3058        *state.base = TEST_POD_MINT;
3059        state.init_account_type().unwrap();
3060
3061        // alloc both types
3062        state
3063            .init_variable_len_extension(&variable_len, false)
3064            .unwrap();
3065        let max_pubkey: MaybeNull<Address> =
3066            Some(Address::new_from_array([255; 32])).try_into().unwrap();
3067        let extension = state.init_extension::<MetadataPointer>(false).unwrap();
3068        extension.authority = max_pubkey;
3069        extension.metadata_address = max_pubkey;
3070
3071        // reallocate to smaller, make sure existing extension is fine
3072        let mut data = SolanaAccountData::new(&buffer);
3073        let key = Address::new_unique();
3074        let account_info = (&key, &mut data).into_account_info();
3075        let variable_len = VariableLenMintTest { data: vec![1, 2] };
3076        alloc_and_serialize_variable_len_extension::<PodMint, _>(
3077            &account_info,
3078            &variable_len,
3079            true,
3080        )
3081        .unwrap();
3082
3083        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
3084        let extension = state.get_extension::<MetadataPointer>().unwrap();
3085        assert_eq!(extension.authority, max_pubkey);
3086        assert_eq!(extension.metadata_address, max_pubkey);
3087        let extension = state
3088            .get_variable_len_extension::<VariableLenMintTest>()
3089            .unwrap();
3090        assert_eq!(extension, variable_len);
3091        assert_eq!(data.len(), state.try_get_account_len().unwrap());
3092
3093        // reallocate to larger
3094        let account_info = (&key, &mut data).into_account_info();
3095        let variable_len = VariableLenMintTest {
3096            data: vec![1, 2, 3, 4, 5, 6, 7],
3097        };
3098        alloc_and_serialize_variable_len_extension::<PodMint, _>(
3099            &account_info,
3100            &variable_len,
3101            true,
3102        )
3103        .unwrap();
3104
3105        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
3106        let extension = state.get_extension::<MetadataPointer>().unwrap();
3107        assert_eq!(extension.authority, max_pubkey);
3108        assert_eq!(extension.metadata_address, max_pubkey);
3109        let extension = state
3110            .get_variable_len_extension::<VariableLenMintTest>()
3111            .unwrap();
3112        assert_eq!(extension, variable_len);
3113        assert_eq!(data.len(), state.try_get_account_len().unwrap());
3114
3115        // reallocate to same
3116        let account_info = (&key, &mut data).into_account_info();
3117        let variable_len = VariableLenMintTest {
3118            data: vec![7, 6, 5, 4, 3, 2, 1],
3119        };
3120        alloc_and_serialize_variable_len_extension::<PodMint, _>(
3121            &account_info,
3122            &variable_len,
3123            true,
3124        )
3125        .unwrap();
3126
3127        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
3128        let extension = state.get_extension::<MetadataPointer>().unwrap();
3129        assert_eq!(extension.authority, max_pubkey);
3130        assert_eq!(extension.metadata_address, max_pubkey);
3131        let extension = state
3132            .get_variable_len_extension::<VariableLenMintTest>()
3133            .unwrap();
3134        assert_eq!(extension, variable_len);
3135        assert_eq!(data.len(), state.try_get_account_len().unwrap());
3136    }
3137}