spl_token_2022/extension/
mod.rs

1//! Extensions available to token mints and accounts
2
3#[cfg(feature = "serde-traits")]
4use serde::{Deserialize, Serialize};
5use {
6    crate::{
7        error::TokenError,
8        extension::{
9            confidential_mint_burn::ConfidentialMintBurn,
10            confidential_transfer::{ConfidentialTransferAccount, ConfidentialTransferMint},
11            confidential_transfer_fee::{
12                ConfidentialTransferFeeAmount, ConfidentialTransferFeeConfig,
13            },
14            cpi_guard::CpiGuard,
15            default_account_state::DefaultAccountState,
16            group_member_pointer::GroupMemberPointer,
17            group_pointer::GroupPointer,
18            immutable_owner::ImmutableOwner,
19            interest_bearing_mint::InterestBearingConfig,
20            memo_transfer::MemoTransfer,
21            metadata_pointer::MetadataPointer,
22            mint_close_authority::MintCloseAuthority,
23            non_transferable::{NonTransferable, NonTransferableAccount},
24            pausable::{PausableAccount, PausableConfig},
25            permanent_delegate::PermanentDelegate,
26            scaled_ui_amount::ScaledUiAmountConfig,
27            transfer_fee::{TransferFeeAmount, TransferFeeConfig},
28            transfer_hook::{TransferHook, TransferHookAccount},
29        },
30        pod::{PodAccount, PodMint},
31        state::{Account, Mint, Multisig, PackedSizeOf},
32    },
33    bytemuck::{Pod, Zeroable},
34    num_enum::{IntoPrimitive, TryFromPrimitive},
35    solana_account_info::AccountInfo,
36    solana_program_error::ProgramError,
37    solana_program_pack::{IsInitialized, Pack},
38    spl_pod::{
39        bytemuck::{pod_from_bytes, pod_from_bytes_mut, pod_get_packed_len},
40        primitives::PodU16,
41    },
42    spl_token_group_interface::state::{TokenGroup, TokenGroupMember},
43    spl_type_length_value::variable_len_pack::VariableLenPack,
44    std::{
45        cmp::Ordering,
46        convert::{TryFrom, TryInto},
47        mem::size_of,
48    },
49};
50
51/// Confidential Transfer extension
52pub mod confidential_transfer;
53/// Confidential Transfer Fee extension
54pub mod confidential_transfer_fee;
55/// CPI Guard extension
56pub mod cpi_guard;
57/// Default Account State extension
58pub mod default_account_state;
59/// Group Member Pointer extension
60pub mod group_member_pointer;
61/// Group Pointer extension
62pub mod group_pointer;
63/// Immutable Owner extension
64pub mod immutable_owner;
65/// Interest-Bearing Mint extension
66pub mod interest_bearing_mint;
67/// Memo Transfer extension
68pub mod memo_transfer;
69/// Metadata Pointer extension
70pub mod metadata_pointer;
71/// Mint Close Authority extension
72pub mod mint_close_authority;
73/// Non Transferable extension
74pub mod non_transferable;
75/// Pausable extension
76pub mod pausable;
77/// Permanent Delegate extension
78pub mod permanent_delegate;
79/// Utility to reallocate token accounts
80pub mod reallocate;
81/// Scaled UI Amount extension
82pub mod scaled_ui_amount;
83/// Token-group extension
84pub mod token_group;
85/// Token-metadata extension
86pub mod token_metadata;
87/// Transfer Fee extension
88pub mod transfer_fee;
89/// Transfer Hook extension
90pub mod transfer_hook;
91
92/// Confidential mint-burn extension
93pub mod confidential_mint_burn;
94
95/// Length in TLV structure
96#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
97#[repr(transparent)]
98pub struct Length(PodU16);
99impl From<Length> for usize {
100    fn from(n: Length) -> Self {
101        Self::from(u16::from(n.0))
102    }
103}
104impl TryFrom<usize> for Length {
105    type Error = ProgramError;
106    fn try_from(n: usize) -> Result<Self, Self::Error> {
107        u16::try_from(n)
108            .map(|v| Self(PodU16::from(v)))
109            .map_err(|_| ProgramError::AccountDataTooSmall)
110    }
111}
112
113/// Helper function to get the current `TlvIndices` from the current spot
114fn get_tlv_indices(type_start: usize) -> TlvIndices {
115    let length_start = type_start.saturating_add(size_of::<ExtensionType>());
116    let value_start = length_start.saturating_add(pod_get_packed_len::<Length>());
117    TlvIndices {
118        type_start,
119        length_start,
120        value_start,
121    }
122}
123
124/// Helper function to tack on the size of an extension bytes if an account with
125/// extensions is exactly the size of a multisig
126const fn adjust_len_for_multisig(account_len: usize) -> usize {
127    if account_len == Multisig::LEN {
128        account_len.saturating_add(size_of::<ExtensionType>())
129    } else {
130        account_len
131    }
132}
133
134/// Helper function to calculate exactly how many bytes a value will take up,
135/// given the value's length
136const fn add_type_and_length_to_len(value_len: usize) -> usize {
137    value_len
138        .saturating_add(size_of::<ExtensionType>())
139        .saturating_add(pod_get_packed_len::<Length>())
140}
141
142/// Helper struct for returning the indices of the type, length, and value in
143/// a TLV entry
144#[derive(Debug)]
145struct TlvIndices {
146    pub type_start: usize,
147    pub length_start: usize,
148    pub value_start: usize,
149}
150fn get_extension_indices<V: Extension>(
151    tlv_data: &[u8],
152    init: bool,
153) -> Result<TlvIndices, ProgramError> {
154    let mut start_index = 0;
155    let v_account_type = V::TYPE.get_account_type();
156    while start_index < tlv_data.len() {
157        let tlv_indices = get_tlv_indices(start_index);
158        if tlv_data.len() < tlv_indices.value_start {
159            return Err(ProgramError::InvalidAccountData);
160        }
161        let extension_type =
162            ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
163        let account_type = extension_type.get_account_type();
164        if extension_type == V::TYPE {
165            // found an instance of the extension that we're initializing, return!
166            return Ok(tlv_indices);
167        // got to an empty spot, init here, or error if we're searching, since
168        // nothing is written after an Uninitialized spot
169        } else if extension_type == ExtensionType::Uninitialized {
170            if init {
171                return Ok(tlv_indices);
172            } else {
173                return Err(TokenError::ExtensionNotFound.into());
174            }
175        } else if v_account_type != account_type {
176            return Err(TokenError::ExtensionTypeMismatch.into());
177        } else {
178            let length = pod_from_bytes::<Length>(
179                &tlv_data[tlv_indices.length_start..tlv_indices.value_start],
180            )?;
181            let value_end_index = tlv_indices.value_start.saturating_add(usize::from(*length));
182            start_index = value_end_index;
183        }
184    }
185    Err(ProgramError::InvalidAccountData)
186}
187
188/// Basic information about the TLV buffer, collected from iterating through all
189/// entries
190#[derive(Debug, PartialEq)]
191struct TlvDataInfo {
192    /// The extension types written in the TLV buffer
193    extension_types: Vec<ExtensionType>,
194    /// The total number bytes allocated for all TLV entries.
195    ///
196    /// Each TLV entry's allocated bytes comprises two bytes for the `type`, two
197    /// bytes for the `length`, and `length` number of bytes for the `value`.
198    used_len: usize,
199}
200
201/// Fetches basic information about the TLV buffer by iterating through all
202/// TLV entries.
203fn get_tlv_data_info(tlv_data: &[u8]) -> Result<TlvDataInfo, ProgramError> {
204    let mut extension_types = vec![];
205    let mut start_index = 0;
206    while start_index < tlv_data.len() {
207        let tlv_indices = get_tlv_indices(start_index);
208        if tlv_data.len() < tlv_indices.length_start {
209            // There aren't enough bytes to store the next type, which means we
210            // got to the end. The last byte could be used during a realloc!
211            return Ok(TlvDataInfo {
212                extension_types,
213                used_len: tlv_indices.type_start,
214            });
215        }
216        let extension_type =
217            ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
218        if extension_type == ExtensionType::Uninitialized {
219            return Ok(TlvDataInfo {
220                extension_types,
221                used_len: tlv_indices.type_start,
222            });
223        } else {
224            if tlv_data.len() < tlv_indices.value_start {
225                // not enough bytes to store the length, malformed
226                return Err(ProgramError::InvalidAccountData);
227            }
228            extension_types.push(extension_type);
229            let length = pod_from_bytes::<Length>(
230                &tlv_data[tlv_indices.length_start..tlv_indices.value_start],
231            )?;
232
233            let value_end_index = tlv_indices.value_start.saturating_add(usize::from(*length));
234            if value_end_index > tlv_data.len() {
235                // value blows past the size of the slice, malformed
236                return Err(ProgramError::InvalidAccountData);
237            }
238            start_index = value_end_index;
239        }
240    }
241    Ok(TlvDataInfo {
242        extension_types,
243        used_len: start_index,
244    })
245}
246
247fn get_first_extension_type(tlv_data: &[u8]) -> Result<Option<ExtensionType>, ProgramError> {
248    if tlv_data.is_empty() {
249        Ok(None)
250    } else {
251        let tlv_indices = get_tlv_indices(0);
252        if tlv_data.len() <= tlv_indices.length_start {
253            return Ok(None);
254        }
255        let extension_type =
256            ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
257        if extension_type == ExtensionType::Uninitialized {
258            Ok(None)
259        } else {
260            Ok(Some(extension_type))
261        }
262    }
263}
264
265fn check_min_len_and_not_multisig(input: &[u8], minimum_len: usize) -> Result<(), ProgramError> {
266    if input.len() == Multisig::LEN || input.len() < minimum_len {
267        Err(ProgramError::InvalidAccountData)
268    } else {
269        Ok(())
270    }
271}
272
273fn check_account_type<S: BaseState>(account_type: AccountType) -> Result<(), ProgramError> {
274    if account_type != S::ACCOUNT_TYPE {
275        Err(ProgramError::InvalidAccountData)
276    } else {
277        Ok(())
278    }
279}
280
/// Any account with extensions must be at least `Account::LEN` bytes long.
/// Both mints and accounts can have extensions.
///
/// A mint whose extensions take it past 165 bytes could be indistinguishable
/// from an Account with an extension, even if we add the account type. For
/// example, let's say we have:
///
/// ```text
/// Account: 165 bytes... + [2, 0, 3, 0, 100, ....]
///                          ^     ^       ^     ^
///                     acct type  extension length data...
///
/// Mint: 82 bytes... + 83 bytes of other extension data
///     + [2, 0, 3, 0, 100, ....]
///      (data in extension just happens to look like this)
/// ```
///
/// To avoid this, TLV data is only ever written after `Account::LEN`, which
/// means the account type is always located right after that offset. Buffers
/// of exactly the Multisig length are special-cased because multisigs are
/// not extensible under any circumstances.
const BASE_ACCOUNT_LENGTH: usize = Account::LEN;
/// `BASE_ACCOUNT_LENGTH` plus the size of the `AccountType` marker: the
/// minimum length of any account with extensions
const BASE_ACCOUNT_AND_TYPE_LENGTH: usize = BASE_ACCOUNT_LENGTH + size_of::<AccountType>();
305
306fn type_and_tlv_indices<S: BaseState>(
307    rest_input: &[u8],
308) -> Result<Option<(usize, usize)>, ProgramError> {
309    if rest_input.is_empty() {
310        Ok(None)
311    } else {
312        let account_type_index = BASE_ACCOUNT_LENGTH.saturating_sub(S::SIZE_OF);
313        // check padding is all zeroes
314        let tlv_start_index = account_type_index.saturating_add(size_of::<AccountType>());
315        if rest_input.len() < tlv_start_index {
316            return Err(ProgramError::InvalidAccountData);
317        }
318        if rest_input[..account_type_index] != vec![0; account_type_index] {
319            Err(ProgramError::InvalidAccountData)
320        } else {
321            Ok(Some((account_type_index, tlv_start_index)))
322        }
323    }
324}
325
326/// Checks a base buffer to verify if it is an Account without having to
327/// completely deserialize it
328fn is_initialized_account(input: &[u8]) -> Result<bool, ProgramError> {
329    const ACCOUNT_INITIALIZED_INDEX: usize = 108; // See state.rs#L99
330
331    if input.len() != BASE_ACCOUNT_LENGTH {
332        return Err(ProgramError::InvalidAccountData);
333    }
334    Ok(input[ACCOUNT_INITIALIZED_INDEX] != 0)
335}
336
337fn get_extension_bytes<S: BaseState, V: Extension>(tlv_data: &[u8]) -> Result<&[u8], ProgramError> {
338    if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
339        return Err(ProgramError::InvalidAccountData);
340    }
341    let TlvIndices {
342        type_start: _,
343        length_start,
344        value_start,
345    } = get_extension_indices::<V>(tlv_data, false)?;
346    // get_extension_indices has checked that tlv_data is long enough to include
347    // these indices
348    let length = pod_from_bytes::<Length>(&tlv_data[length_start..value_start])?;
349    let value_end = value_start.saturating_add(usize::from(*length));
350    if tlv_data.len() < value_end {
351        return Err(ProgramError::InvalidAccountData);
352    }
353    Ok(&tlv_data[value_start..value_end])
354}
355
356fn get_extension_bytes_mut<S: BaseState, V: Extension>(
357    tlv_data: &mut [u8],
358) -> Result<&mut [u8], ProgramError> {
359    if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
360        return Err(ProgramError::InvalidAccountData);
361    }
362    let TlvIndices {
363        type_start: _,
364        length_start,
365        value_start,
366    } = get_extension_indices::<V>(tlv_data, false)?;
367    // get_extension_indices has checked that tlv_data is long enough to include
368    // these indices
369    let length = pod_from_bytes::<Length>(&tlv_data[length_start..value_start])?;
370    let value_end = value_start.saturating_add(usize::from(*length));
371    if tlv_data.len() < value_end {
372        return Err(ProgramError::InvalidAccountData);
373    }
374    Ok(&mut tlv_data[value_start..value_end])
375}
376
377/// Calculate the new expected size if the state allocates the given number
378/// of bytes for the given extension type.
379///
380/// Provides the correct answer regardless if the extension is already present
381/// in the TLV data.
382fn try_get_new_account_len_for_extension_len<S: BaseState, V: Extension>(
383    tlv_data: &[u8],
384    new_extension_len: usize,
385) -> Result<usize, ProgramError> {
386    // get the new length used by the extension
387    let new_extension_tlv_len = add_type_and_length_to_len(new_extension_len);
388    let tlv_info = get_tlv_data_info(tlv_data)?;
389    // If we're adding an extension, then we must have at least BASE_ACCOUNT_LENGTH
390    // and account type
391    let current_len = tlv_info
392        .used_len
393        .saturating_add(BASE_ACCOUNT_AND_TYPE_LENGTH);
394    // get the current length used by the extension
395    let current_extension_len = get_extension_bytes::<S, V>(tlv_data)
396        .map(|x| add_type_and_length_to_len(x.len()))
397        .unwrap_or(0);
398    let new_len = current_len
399        .saturating_sub(current_extension_len)
400        .saturating_add(new_extension_tlv_len);
401    Ok(adjust_len_for_multisig(new_len))
402}
403
404/// Trait for base state with extension
405pub trait BaseStateWithExtensions<S: BaseState> {
406    /// Get the buffer containing all extension data
407    fn get_tlv_data(&self) -> &[u8];
408
409    /// Fetch the bytes for a TLV entry
410    fn get_extension_bytes<V: Extension>(&self) -> Result<&[u8], ProgramError> {
411        get_extension_bytes::<S, V>(self.get_tlv_data())
412    }
413
414    /// Unpack a portion of the TLV data as the desired type
415    fn get_extension<V: Extension + Pod>(&self) -> Result<&V, ProgramError> {
416        pod_from_bytes::<V>(self.get_extension_bytes::<V>()?)
417    }
418
419    /// Unpacks a portion of the TLV data as the desired variable-length type
420    fn get_variable_len_extension<V: Extension + VariableLenPack>(
421        &self,
422    ) -> Result<V, ProgramError> {
423        let data = get_extension_bytes::<S, V>(self.get_tlv_data())?;
424        V::unpack_from_slice(data)
425    }
426
427    /// Iterates through the TLV entries, returning only the types
428    fn get_extension_types(&self) -> Result<Vec<ExtensionType>, ProgramError> {
429        get_tlv_data_info(self.get_tlv_data()).map(|x| x.extension_types)
430    }
431
432    /// Get just the first extension type, useful to track mixed initialization
433    fn get_first_extension_type(&self) -> Result<Option<ExtensionType>, ProgramError> {
434        get_first_extension_type(self.get_tlv_data())
435    }
436
437    /// Get the total number of bytes used by TLV entries and the base type
438    fn try_get_account_len(&self) -> Result<usize, ProgramError> {
439        let tlv_info = get_tlv_data_info(self.get_tlv_data())?;
440        if tlv_info.extension_types.is_empty() {
441            Ok(S::SIZE_OF)
442        } else {
443            let total_len = tlv_info
444                .used_len
445                .saturating_add(BASE_ACCOUNT_AND_TYPE_LENGTH);
446            Ok(adjust_len_for_multisig(total_len))
447        }
448    }
449    /// Calculate the new expected size if the state allocates the given
450    /// fixed-length extension instance.
451    /// If the state already has the extension, the resulting account length
452    /// will be unchanged.
453    fn try_get_new_account_len<V: Extension + Pod>(&self) -> Result<usize, ProgramError> {
454        try_get_new_account_len_for_extension_len::<S, V>(
455            self.get_tlv_data(),
456            pod_get_packed_len::<V>(),
457        )
458    }
459
460    /// Calculate the new expected size if the state allocates the given
461    /// variable-length extension instance.
462    fn try_get_new_account_len_for_variable_len_extension<V: Extension + VariableLenPack>(
463        &self,
464        new_extension: &V,
465    ) -> Result<usize, ProgramError> {
466        try_get_new_account_len_for_extension_len::<S, V>(
467            self.get_tlv_data(),
468            new_extension.get_packed_len()?,
469        )
470    }
471}
472
473/// Encapsulates owned immutable base state data (mint or account) with possible
474/// extensions
475#[derive(Clone, Debug, PartialEq)]
476pub struct StateWithExtensionsOwned<S: BaseState> {
477    /// Unpacked base data
478    pub base: S,
479    /// Raw TLV data, deserialized on demand
480    tlv_data: Vec<u8>,
481}
482impl<S: BaseState + Pack> StateWithExtensionsOwned<S> {
483    /// Unpack base state, leaving the extension data as a slice
484    ///
485    /// Fails if the base state is not initialized.
486    pub fn unpack(mut input: Vec<u8>) -> Result<Self, ProgramError> {
487        check_min_len_and_not_multisig(&input, S::SIZE_OF)?;
488        let mut rest = input.split_off(S::SIZE_OF);
489        let base = S::unpack(&input)?;
490        if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(&rest)? {
491            // type_and_tlv_indices() checks that returned indexes are within range
492            let account_type = AccountType::try_from(rest[account_type_index])
493                .map_err(|_| ProgramError::InvalidAccountData)?;
494            check_account_type::<S>(account_type)?;
495            let tlv_data = rest.split_off(tlv_start_index);
496            Ok(Self { base, tlv_data })
497        } else {
498            Ok(Self {
499                base,
500                tlv_data: vec![],
501            })
502        }
503    }
504}
505
506impl<S: BaseState> BaseStateWithExtensions<S> for StateWithExtensionsOwned<S> {
507    fn get_tlv_data(&self) -> &[u8] {
508        &self.tlv_data
509    }
510}
511
512/// Encapsulates immutable base state data (mint or account) with possible
513/// extensions
514#[derive(Debug, PartialEq)]
515pub struct StateWithExtensions<'data, S: BaseState + Pack> {
516    /// Unpacked base data
517    pub base: S,
518    /// Slice of data containing all TLV data, deserialized on demand
519    tlv_data: &'data [u8],
520}
521impl<'data, S: BaseState + Pack> StateWithExtensions<'data, S> {
522    /// Unpack base state, leaving the extension data as a slice
523    ///
524    /// Fails if the base state is not initialized.
525    pub fn unpack(input: &'data [u8]) -> Result<Self, ProgramError> {
526        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
527        let (base_data, rest) = input.split_at(S::SIZE_OF);
528        let base = S::unpack(base_data)?;
529        let tlv_data = unpack_tlv_data::<S>(rest)?;
530        Ok(Self { base, tlv_data })
531    }
532}
533impl<S: BaseState + Pack> BaseStateWithExtensions<S> for StateWithExtensions<'_, S> {
534    fn get_tlv_data(&self) -> &[u8] {
535        self.tlv_data
536    }
537}
538
539/// Encapsulates immutable base state data (mint or account) with possible
540/// extensions, where the base state is Pod for zero-copy serde.
541#[derive(Debug, PartialEq)]
542pub struct PodStateWithExtensions<'data, S: BaseState + Pod> {
543    /// Unpacked base data
544    pub base: &'data S,
545    /// Slice of data containing all TLV data, deserialized on demand
546    tlv_data: &'data [u8],
547}
548impl<'data, S: BaseState + Pod> PodStateWithExtensions<'data, S> {
549    /// Unpack base state, leaving the extension data as a slice
550    ///
551    /// Fails if the base state is not initialized.
552    pub fn unpack(input: &'data [u8]) -> Result<Self, ProgramError> {
553        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
554        let (base_data, rest) = input.split_at(S::SIZE_OF);
555        let base = pod_from_bytes::<S>(base_data)?;
556        if !base.is_initialized() {
557            Err(ProgramError::UninitializedAccount)
558        } else {
559            let tlv_data = unpack_tlv_data::<S>(rest)?;
560            Ok(Self { base, tlv_data })
561        }
562    }
563}
564impl<S: BaseState + Pod> BaseStateWithExtensions<S> for PodStateWithExtensions<'_, S> {
565    fn get_tlv_data(&self) -> &[u8] {
566        self.tlv_data
567    }
568}
569
570/// Trait for mutable base state with extension
571pub trait BaseStateWithExtensionsMut<S: BaseState>: BaseStateWithExtensions<S> {
    /// Returns the TLV buffer mutably, so extension entries can be written in
    /// place.
    fn get_tlv_data_mut(&mut self) -> &mut [u8];

    /// Returns the account-type byte(s) mutably.
    ///
    /// An empty slice means there is no room for an account type;
    /// `init_account_type` then writes nothing.
    fn get_account_type_mut(&mut self) -> &mut [u8];
577
578    /// Unpack a portion of the TLV data as the base mutable bytes
579    fn get_extension_bytes_mut<V: Extension>(&mut self) -> Result<&mut [u8], ProgramError> {
580        get_extension_bytes_mut::<S, V>(self.get_tlv_data_mut())
581    }
582
583    /// Unpack a portion of the TLV data as the desired type that allows
584    /// modifying the type
585    fn get_extension_mut<V: Extension + Pod>(&mut self) -> Result<&mut V, ProgramError> {
586        pod_from_bytes_mut::<V>(self.get_extension_bytes_mut::<V>()?)
587    }
588
589    /// Packs a variable-length extension into its appropriate data segment.
590    /// Fails if space hasn't already been allocated for the given extension
591    fn pack_variable_len_extension<V: Extension + VariableLenPack>(
592        &mut self,
593        extension: &V,
594    ) -> Result<(), ProgramError> {
595        let data = self.get_extension_bytes_mut::<V>()?;
596        // NOTE: Do *not* use `pack`, since the length check will cause
597        // reallocations to smaller sizes to fail
598        extension.pack_into_slice(data)
599    }
600
601    /// Packs the default extension data into an open slot if not already found
602    /// in the data buffer. If extension is already found in the buffer, it
603    /// overwrites the existing extension with the default state if
604    /// `overwrite` is set. If extension found, but `overwrite` is not set,
605    /// it returns error.
606    fn init_extension<V: Extension + Pod + Default>(
607        &mut self,
608        overwrite: bool,
609    ) -> Result<&mut V, ProgramError> {
610        let length = pod_get_packed_len::<V>();
611        let buffer = self.alloc::<V>(length, overwrite)?;
612        let extension_ref = pod_from_bytes_mut::<V>(buffer)?;
613        *extension_ref = V::default();
614        Ok(extension_ref)
615    }
616
617    /// Reallocate and overwrite the TLV entry for the given variable-length
618    /// extension.
619    ///
620    /// Returns an error if the extension is not present, or if there is not
621    /// enough space in the buffer.
622    fn realloc_variable_len_extension<V: Extension + VariableLenPack>(
623        &mut self,
624        new_extension: &V,
625    ) -> Result<(), ProgramError> {
626        let data = self.realloc::<V>(new_extension.get_packed_len()?)?;
627        new_extension.pack_into_slice(data)
628    }
629
    /// Reallocate the TLV entry for the given extension to `length` bytes and
    /// return its resized value slice.
    ///
    /// If the new length is smaller, the entries after this one are shifted
    /// down to close the gap, and the freed bytes at the end of the used TLV
    /// region are zeroed. If it's larger, the following entries are shifted up
    /// and the newly added value bytes are zeroed.
    ///
    /// # Errors
    /// - Any error from `get_extension_indices`, e.g.
    ///   `TokenError::ExtensionNotFound` if the extension is not present.
    /// - `ProgramError::InvalidAccountData` if there is not enough space in the
    ///   buffer to grow the entry.
    /// - The error from `Length::try_from` if `length` does not fit in a `u16`.
    fn realloc<V: Extension + VariableLenPack>(
        &mut self,
        length: usize,
    ) -> Result<&mut [u8], ProgramError> {
        let tlv_data = self.get_tlv_data_mut();
        let TlvIndices {
            type_start: _,
            length_start,
            value_start,
        } = get_extension_indices::<V>(tlv_data, false)?;
        // number of bytes currently occupied by all TLV entries
        let tlv_len = get_tlv_data_info(tlv_data).map(|x| x.used_len)?;
        let data_len = tlv_data.len();

        let length_ref = pod_from_bytes_mut::<Length>(&mut tlv_data[length_start..value_start])?;
        let old_length = usize::from(*length_ref);

        // Length check to avoid a panic later in `copy_within`
        if old_length < length {
            let new_tlv_len = tlv_len.saturating_add(length.saturating_sub(old_length));
            if new_tlv_len > data_len {
                return Err(ProgramError::InvalidAccountData);
            }
        }

        // write new length after the check, to avoid getting into a bad situation
        // if trying to recover from an error
        *length_ref = Length::try_from(length)?;

        let old_value_end = value_start.saturating_add(old_length);
        let new_value_end = value_start.saturating_add(length);
        // shift every entry after this one so it starts right after the new value end
        tlv_data.copy_within(old_value_end..tlv_len, new_value_end);
        match old_length.cmp(&length) {
            Ordering::Greater => {
                // realloc to smaller, zero out the end
                let new_tlv_len = tlv_len.saturating_sub(old_length.saturating_sub(length));
                tlv_data[new_tlv_len..tlv_len].fill(0);
            }
            Ordering::Less => {
                // realloc to bigger, zero out the new bytes
                tlv_data[old_value_end..new_value_end].fill(0);
            }
            Ordering::Equal => {} // nothing needed!
        }

        Ok(&mut tlv_data[value_start..new_value_end])
    }
685
686    /// Allocate the given number of bytes for the given variable-length
687    /// extension and write its contents into the TLV buffer.
688    ///
689    /// This can only be used for variable-sized types, such as `String` or
690    /// `Vec`. `Pod` types must use `init_extension`
691    fn init_variable_len_extension<V: Extension + VariableLenPack>(
692        &mut self,
693        extension: &V,
694        overwrite: bool,
695    ) -> Result<(), ProgramError> {
696        let data = self.alloc::<V>(extension.get_packed_len()?, overwrite)?;
697        extension.pack_into_slice(data)
698    }
699
    /// Reserve `length` bytes of value space for the `V` entry and return that
    /// value slice.
    ///
    /// The entry is written at the first `Uninitialized` slot. If a `V` entry
    /// already exists and `overwrite` is set, it is written over that entry,
    /// whose current length must then equal `length` (use `realloc` to resize).
    /// Only the type and length fields are written here. The returned value
    /// bytes are not cleared, so callers are responsible for filling them.
    ///
    /// # Errors
    /// - `ProgramError::InvalidAccountData` if `V` does not belong to `S`'s
    ///   account type, or if fewer than `length` plus the type/length header
    ///   bytes remain from the entry's start.
    /// - `TokenError::ExtensionAlreadyInitialized` if `V` exists and
    ///   `overwrite` is false.
    /// - `TokenError::InvalidLengthForAlloc` if overwriting an existing `V`
    ///   entry with a different length.
    /// - Errors from `get_extension_indices`, type decoding, or
    ///   `Length::try_from`.
    fn alloc<V: Extension>(
        &mut self,
        length: usize,
        overwrite: bool,
    ) -> Result<&mut [u8], ProgramError> {
        if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
            return Err(ProgramError::InvalidAccountData);
        }
        let tlv_data = self.get_tlv_data_mut();
        // With `init = true`, this points at either the existing `V` entry or
        // the first `Uninitialized` slot.
        let TlvIndices {
            type_start,
            length_start,
            value_start,
        } = get_extension_indices::<V>(tlv_data, true)?;

        if tlv_data[type_start..].len() < add_type_and_length_to_len(length) {
            return Err(ProgramError::InvalidAccountData);
        }
        let extension_type = ExtensionType::try_from(&tlv_data[type_start..length_start])?;

        if extension_type == ExtensionType::Uninitialized || overwrite {
            // write extension type (assumes ExtensionType serializes to 2 bytes)
            let extension_type_array: [u8; 2] = V::TYPE.into();
            let extension_type_ref = &mut tlv_data[type_start..length_start];
            extension_type_ref.copy_from_slice(&extension_type_array);
            // write length
            let length_ref =
                pod_from_bytes_mut::<Length>(&mut tlv_data[length_start..value_start])?;

            // check that the length is the same if we're doing an alloc
            // with overwrite, otherwise a realloc should be done
            if overwrite && extension_type == V::TYPE && usize::from(*length_ref) != length {
                return Err(TokenError::InvalidLengthForAlloc.into());
            }

            // NOTE(review): if this conversion fails, the type bytes above have
            // already been written into what may have been an empty slot.
            *length_ref = Length::try_from(length)?;

            let value_end = value_start.saturating_add(length);
            Ok(&mut tlv_data[value_start..value_end])
        } else {
            // extension is already initialized, but no overwrite permission
            Err(TokenError::ExtensionAlreadyInitialized.into())
        }
    }
745
    /// Initialize the account-side extension for `extension_type` with its
    /// default value, as done during `InitializeAccount`.
    ///
    /// - Types whose account type is not `AccountType::Account` are ignored
    ///   (returns `Ok(())`).
    /// - `TransferFeeAmount`, `ImmutableOwner`, `NonTransferableAccount`,
    ///   `TransferHookAccount` and `PausableAccount` (plus
    ///   `AccountPaddingTest` in test builds) are packed with their default
    ///   value. An existing entry is overwritten (`init_extension` with
    ///   `overwrite = true`).
    /// - `ConfidentialTransferAccount` is a no-op, since confidential transfers
    ///   are opt-in.
    ///
    /// # Panics
    /// Panics (via `unreachable!`) for any other `AccountType::Account`
    /// extension type. Contrary to what might be assumed, such types are *not*
    /// silently ignored. NOTE(review): callers presumably pass only extension
    /// types that require initialization on `InitializeAccount`; confirm at
    /// the call sites.
    fn init_account_extension_from_type(
        &mut self,
        extension_type: ExtensionType,
    ) -> Result<(), ProgramError> {
        if extension_type.get_account_type() != AccountType::Account {
            return Ok(());
        }
        match extension_type {
            ExtensionType::TransferFeeAmount => {
                self.init_extension::<TransferFeeAmount>(true).map(|_| ())
            }
            ExtensionType::ImmutableOwner => {
                self.init_extension::<ImmutableOwner>(true).map(|_| ())
            }
            ExtensionType::NonTransferableAccount => self
                .init_extension::<NonTransferableAccount>(true)
                .map(|_| ()),
            ExtensionType::TransferHookAccount => {
                self.init_extension::<TransferHookAccount>(true).map(|_| ())
            }
            // ConfidentialTransfers are currently opt-in only, so this is a no-op for extra safety
            // on InitializeAccount
            ExtensionType::ConfidentialTransferAccount => Ok(()),
            ExtensionType::PausableAccount => {
                self.init_extension::<PausableAccount>(true).map(|_| ())
            }
            #[cfg(test)]
            ExtensionType::AccountPaddingTest => {
                self.init_extension::<AccountPaddingTest>(true).map(|_| ())
            }
            // Any other Account-type extension lands here and panics.
            _ => unreachable!(),
        }
    }
785
786    /// Write the account type into the buffer, done during the base
787    /// state initialization
788    /// Noops if there is no room for an extension in the account, needed for
789    /// pure base mints / accounts.
790    fn init_account_type(&mut self) -> Result<(), ProgramError> {
791        let first_extension_type = self.get_first_extension_type()?;
792        let account_type = self.get_account_type_mut();
793        if !account_type.is_empty() {
794            if let Some(extension_type) = first_extension_type {
795                let account_type = extension_type.get_account_type();
796                if account_type != S::ACCOUNT_TYPE {
797                    return Err(TokenError::ExtensionBaseMismatch.into());
798                }
799            }
800            account_type[0] = S::ACCOUNT_TYPE.into();
801        }
802        Ok(())
803    }
804
805    /// Check that the account type on the account (if initialized) matches the
806    /// account type for any extensions initialized on the TLV data
807    fn check_account_type_matches_extension_type(&self) -> Result<(), ProgramError> {
808        if let Some(extension_type) = self.get_first_extension_type()? {
809            let account_type = extension_type.get_account_type();
810            if account_type != S::ACCOUNT_TYPE {
811                return Err(TokenError::ExtensionBaseMismatch.into());
812            }
813        }
814        Ok(())
815    }
816}
817
/// Encapsulates mutable base state data (mint or account) with possible
/// extensions
///
/// The base state is held as an owned, unpacked copy. Edits to `base` only
/// reach the underlying buffer once `pack_base` is called. The account-type
/// byte and the TLV region are borrowed mutably straight from the buffer.
#[derive(Debug, PartialEq)]
pub struct StateWithExtensionsMut<'data, S: BaseState> {
    /// Unpacked base data (an owned copy; write back with `pack_base`)
    pub base: S,
    /// Raw base data: the first `S::SIZE_OF` bytes of the input buffer
    base_data: &'data mut [u8],
    /// Writable account type; empty when the buffer has no room for
    /// extensions
    account_type: &'data mut [u8],
    /// Slice of data containing all TLV data, deserialized on demand
    tlv_data: &'data mut [u8],
}
831impl<'data, S: BaseState + Pack> StateWithExtensionsMut<'data, S> {
832    /// Unpack base state, leaving the extension data as a mutable slice
833    ///
834    /// Fails if the base state is not initialized.
835    pub fn unpack(input: &'data mut [u8]) -> Result<Self, ProgramError> {
836        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
837        let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
838        let base = S::unpack(base_data)?;
839        let (account_type, tlv_data) = unpack_type_and_tlv_data_mut::<S>(rest)?;
840        Ok(Self {
841            base,
842            base_data,
843            account_type,
844            tlv_data,
845        })
846    }
847
848    /// Unpack an uninitialized base state, leaving the extension data as a
849    /// mutable slice
850    ///
851    /// Fails if the base state has already been initialized.
852    pub fn unpack_uninitialized(input: &'data mut [u8]) -> Result<Self, ProgramError> {
853        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
854        let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
855        let base = S::unpack_unchecked(base_data)?;
856        if base.is_initialized() {
857            return Err(TokenError::AlreadyInUse.into());
858        }
859        let (account_type, tlv_data) = unpack_uninitialized_type_and_tlv_data_mut::<S>(rest)?;
860        let state = Self {
861            base,
862            base_data,
863            account_type,
864            tlv_data,
865        };
866        state.check_account_type_matches_extension_type()?;
867        Ok(state)
868    }
869
870    /// Packs base state data into the base data portion
871    pub fn pack_base(&mut self) {
872        S::pack_into_slice(&self.base, self.base_data);
873    }
874}
875impl<S: BaseState> BaseStateWithExtensions<S> for StateWithExtensionsMut<'_, S> {
876    fn get_tlv_data(&self) -> &[u8] {
877        self.tlv_data
878    }
879}
880impl<S: BaseState> BaseStateWithExtensionsMut<S> for StateWithExtensionsMut<'_, S> {
881    fn get_tlv_data_mut(&mut self) -> &mut [u8] {
882        self.tlv_data
883    }
884    fn get_account_type_mut(&mut self) -> &mut [u8] {
885        self.account_type
886    }
887}
888
/// Encapsulates mutable base state data (mint or account) with possible
/// extensions, where the base state is Pod for zero-copy serde.
///
/// Unlike `StateWithExtensionsMut`, `base` is a reference into the buffer
/// itself, so edits to it are written in place with no separate pack step.
#[derive(Debug, PartialEq)]
pub struct PodStateWithExtensionsMut<'data, S: BaseState> {
    /// Base data, reinterpreted in place from the first `S::SIZE_OF` bytes
    pub base: &'data mut S,
    /// Writable account type; empty when the buffer has no room for
    /// extensions
    account_type: &'data mut [u8],
    /// Slice of data containing all TLV data, deserialized on demand
    tlv_data: &'data mut [u8],
}
900impl<'data, S: BaseState + Pod> PodStateWithExtensionsMut<'data, S> {
901    /// Unpack base state, leaving the extension data as a mutable slice
902    ///
903    /// Fails if the base state is not initialized.
904    pub fn unpack(input: &'data mut [u8]) -> Result<Self, ProgramError> {
905        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
906        let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
907        let base = pod_from_bytes_mut::<S>(base_data)?;
908        if !base.is_initialized() {
909            Err(ProgramError::UninitializedAccount)
910        } else {
911            let (account_type, tlv_data) = unpack_type_and_tlv_data_mut::<S>(rest)?;
912            Ok(Self {
913                base,
914                account_type,
915                tlv_data,
916            })
917        }
918    }
919
920    /// Unpack an uninitialized base state, leaving the extension data as a
921    /// mutable slice
922    ///
923    /// Fails if the base state has already been initialized.
924    pub fn unpack_uninitialized(input: &'data mut [u8]) -> Result<Self, ProgramError> {
925        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
926        let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
927        let base = pod_from_bytes_mut::<S>(base_data)?;
928        if base.is_initialized() {
929            return Err(TokenError::AlreadyInUse.into());
930        }
931        let (account_type, tlv_data) = unpack_uninitialized_type_and_tlv_data_mut::<S>(rest)?;
932        let state = Self {
933            base,
934            account_type,
935            tlv_data,
936        };
937        state.check_account_type_matches_extension_type()?;
938        Ok(state)
939    }
940}
941
942impl<S: BaseState> BaseStateWithExtensions<S> for PodStateWithExtensionsMut<'_, S> {
943    fn get_tlv_data(&self) -> &[u8] {
944        self.tlv_data
945    }
946}
947impl<S: BaseState> BaseStateWithExtensionsMut<S> for PodStateWithExtensionsMut<'_, S> {
948    fn get_tlv_data_mut(&mut self) -> &mut [u8] {
949        self.tlv_data
950    }
951    fn get_account_type_mut(&mut self) -> &mut [u8] {
952        self.account_type
953    }
954}
955
956fn unpack_tlv_data<S: BaseState>(rest: &[u8]) -> Result<&[u8], ProgramError> {
957    if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
958        // type_and_tlv_indices() checks that returned indexes are within range
959        let account_type = AccountType::try_from(rest[account_type_index])
960            .map_err(|_| ProgramError::InvalidAccountData)?;
961        check_account_type::<S>(account_type)?;
962        Ok(&rest[tlv_start_index..])
963    } else {
964        Ok(&[])
965    }
966}
967
968fn unpack_type_and_tlv_data_with_check_mut<
969    S: BaseState,
970    F: Fn(AccountType) -> Result<(), ProgramError>,
971>(
972    rest: &mut [u8],
973    check_fn: F,
974) -> Result<(&mut [u8], &mut [u8]), ProgramError> {
975    if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
976        // type_and_tlv_indices() checks that returned indexes are within range
977        let account_type = AccountType::try_from(rest[account_type_index])
978            .map_err(|_| ProgramError::InvalidAccountData)?;
979        check_fn(account_type)?;
980        let (account_type, tlv_data) = rest.split_at_mut(tlv_start_index);
981        Ok((
982            &mut account_type[account_type_index..tlv_start_index],
983            tlv_data,
984        ))
985    } else {
986        Ok((&mut [], &mut []))
987    }
988}
989
990fn unpack_type_and_tlv_data_mut<S: BaseState>(
991    rest: &mut [u8],
992) -> Result<(&mut [u8], &mut [u8]), ProgramError> {
993    unpack_type_and_tlv_data_with_check_mut::<S, _>(rest, check_account_type::<S>)
994}
995
996fn unpack_uninitialized_type_and_tlv_data_mut<S: BaseState>(
997    rest: &mut [u8],
998) -> Result<(&mut [u8], &mut [u8]), ProgramError> {
999    unpack_type_and_tlv_data_with_check_mut::<S, _>(rest, |account_type| {
1000        if account_type != AccountType::Uninitialized {
1001            Err(ProgramError::InvalidAccountData)
1002        } else {
1003            Ok(())
1004        }
1005    })
1006}
1007
1008/// If `AccountType` is uninitialized, set it to the `BaseState`'s
1009/// `ACCOUNT_TYPE`; if `AccountType` is already set, check is set correctly for
1010/// `BaseState`. This method assumes that the `base_data` has already been
1011/// packed with data of the desired type.
1012pub fn set_account_type<S: BaseState>(input: &mut [u8]) -> Result<(), ProgramError> {
1013    check_min_len_and_not_multisig(input, S::SIZE_OF)?;
1014    let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
1015    if S::ACCOUNT_TYPE == AccountType::Account && !is_initialized_account(base_data)? {
1016        return Err(ProgramError::InvalidAccountData);
1017    }
1018    if let Some((account_type_index, _tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
1019        let mut account_type = AccountType::try_from(rest[account_type_index])
1020            .map_err(|_| ProgramError::InvalidAccountData)?;
1021        if account_type == AccountType::Uninitialized {
1022            rest[account_type_index] = S::ACCOUNT_TYPE.into();
1023            account_type = S::ACCOUNT_TYPE;
1024        }
1025        check_account_type::<S>(account_type)?;
1026        Ok(())
1027    } else {
1028        Err(ProgramError::InvalidAccountData)
1029    }
1030}
1031
/// Different kinds of accounts. Note that `Mint`, `Account`, and `Multisig`
/// types are determined exclusively by the size of the account, and are not
/// included in the account data. `AccountType` is only included if extensions
/// have been initialized.
///
/// When present, it is stored as a single byte in the account data after the
/// base state. The `u8` discriminants (0, 1, 2 in declaration order) are
/// therefore part of the on-chain format; never reorder the variants.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, TryFromPrimitive, IntoPrimitive)]
pub enum AccountType {
    /// Marker for 0 data
    Uninitialized,
    /// Mint account with additional extensions
    Mint,
    /// Token holding account with additional extensions
    Account,
}
// NOTE(review): written by hand rather than `#[derive(Default)]` + `#[default]`,
// presumably because num_enum may treat a `#[default]` variant as a catch-all
// for `TryFromPrimitive` (making unknown bytes decode successfully). Confirm
// before changing.
impl Default for AccountType {
    fn default() -> Self {
        Self::Uninitialized
    }
}
1051
/// Extensions that can be applied to mints or accounts.  Mint extensions must
/// only be applied to mint accounts, and account extensions must only be
/// applied to token holding accounts.
///
/// Each variant's `u16` discriminant is stored little-endian as the type field
/// of a TLV entry. Existing variants must therefore never be reordered or
/// renumbered.
#[repr(u16)]
#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde-traits", serde(rename_all = "camelCase"))]
#[derive(Clone, Copy, Debug, PartialEq, TryFromPrimitive, IntoPrimitive)]
pub enum ExtensionType {
    /// Used as padding if the account size would otherwise be 355, same as a
    /// multisig
    Uninitialized,
    /// Includes transfer fee rate info and accompanying authorities to withdraw
    /// and set the fee
    TransferFeeConfig,
    /// Includes withheld transfer fees
    TransferFeeAmount,
    /// Includes an optional mint close authority
    MintCloseAuthority,
    /// Auditor configuration for confidential transfers
    ConfidentialTransferMint,
    /// State for confidential transfers
    ConfidentialTransferAccount,
    /// Specifies the default Account::state for new Accounts
    DefaultAccountState,
    /// Indicates that the Account owner authority cannot be changed
    ImmutableOwner,
    /// Require inbound transfers to have memo
    MemoTransfer,
    /// Indicates that the tokens from this mint can't be transferred
    NonTransferable,
    /// Tokens accrue interest over time
    InterestBearingConfig,
    /// Locks privileged token operations from happening via CPI
    CpiGuard,
    /// Includes an optional permanent delegate
    PermanentDelegate,
    /// Indicates that the tokens in this account belong to a non-transferable
    /// mint
    NonTransferableAccount,
    /// Mint requires a CPI to a program implementing the "transfer hook"
    /// interface
    TransferHook,
    /// Indicates that the tokens in this account belong to a mint with a
    /// transfer hook
    TransferHookAccount,
    /// Includes encrypted withheld fees and the encryption public key that
    /// they are encrypted under
    ConfidentialTransferFeeConfig,
    /// Includes confidential withheld transfer fees
    ConfidentialTransferFeeAmount,
    /// Mint contains a pointer to another account (or the same account) that
    /// holds metadata
    MetadataPointer,
    /// Mint contains token-metadata (variable length)
    TokenMetadata,
    /// Mint contains a pointer to another account (or the same account) that
    /// holds group configurations
    GroupPointer,
    /// Mint contains token group configurations
    TokenGroup,
    /// Mint contains a pointer to another account (or the same account) that
    /// holds group member configurations
    GroupMemberPointer,
    /// Mint contains token group member configurations
    TokenGroupMember,
    /// Mint allowing the minting and burning of confidential tokens
    ConfidentialMintBurn,
    /// Tokens whose UI amount is scaled by a given amount
    ScaledUiAmount,
    /// Tokens where minting / burning / transferring can be paused
    Pausable,
    /// Indicates that the account belongs to a pausable mint
    PausableAccount,

    /// Test variable-length mint extension (discriminant `u16::MAX - 2`)
    #[cfg(test)]
    VariableLenMintTest = u16::MAX - 2,
    /// Padding extension used to make an account exactly Multisig::LEN, used
    /// for testing (discriminant `u16::MAX - 1`)
    #[cfg(test)]
    AccountPaddingTest,
    /// Padding extension used to make a mint exactly Multisig::LEN, used for
    /// testing (discriminant `u16::MAX`)
    #[cfg(test)]
    MintPaddingTest,
}
1138impl TryFrom<&[u8]> for ExtensionType {
1139    type Error = ProgramError;
1140    fn try_from(a: &[u8]) -> Result<Self, Self::Error> {
1141        Self::try_from(u16::from_le_bytes(
1142            a.try_into().map_err(|_| ProgramError::InvalidAccountData)?,
1143        ))
1144        .map_err(|_| ProgramError::InvalidAccountData)
1145    }
1146}
1147impl From<ExtensionType> for [u8; 2] {
1148    fn from(a: ExtensionType) -> Self {
1149        u16::from(a).to_le_bytes()
1150    }
1151}
1152impl ExtensionType {
1153    /// Returns true if the given extension type is sized
1154    ///
1155    /// Most extension types should be sized, so any variable-length extension
1156    /// types should be added here by hand
1157    const fn sized(&self) -> bool {
1158        match self {
1159            ExtensionType::TokenMetadata => false,
1160            #[cfg(test)]
1161            ExtensionType::VariableLenMintTest => false,
1162            _ => true,
1163        }
1164    }
1165
1166    /// Get the data length of the type associated with the enum
1167    ///
1168    /// Fails if the extension type has a variable length
1169    fn try_get_type_len(&self) -> Result<usize, ProgramError> {
1170        if !self.sized() {
1171            return Err(ProgramError::InvalidArgument);
1172        }
1173        Ok(match self {
1174            ExtensionType::Uninitialized => 0,
1175            ExtensionType::TransferFeeConfig => pod_get_packed_len::<TransferFeeConfig>(),
1176            ExtensionType::TransferFeeAmount => pod_get_packed_len::<TransferFeeAmount>(),
1177            ExtensionType::MintCloseAuthority => pod_get_packed_len::<MintCloseAuthority>(),
1178            ExtensionType::ImmutableOwner => pod_get_packed_len::<ImmutableOwner>(),
1179            ExtensionType::ConfidentialTransferMint => {
1180                pod_get_packed_len::<ConfidentialTransferMint>()
1181            }
1182            ExtensionType::ConfidentialTransferAccount => {
1183                pod_get_packed_len::<ConfidentialTransferAccount>()
1184            }
1185            ExtensionType::DefaultAccountState => pod_get_packed_len::<DefaultAccountState>(),
1186            ExtensionType::MemoTransfer => pod_get_packed_len::<MemoTransfer>(),
1187            ExtensionType::NonTransferable => pod_get_packed_len::<NonTransferable>(),
1188            ExtensionType::InterestBearingConfig => pod_get_packed_len::<InterestBearingConfig>(),
1189            ExtensionType::CpiGuard => pod_get_packed_len::<CpiGuard>(),
1190            ExtensionType::PermanentDelegate => pod_get_packed_len::<PermanentDelegate>(),
1191            ExtensionType::NonTransferableAccount => pod_get_packed_len::<NonTransferableAccount>(),
1192            ExtensionType::TransferHook => pod_get_packed_len::<TransferHook>(),
1193            ExtensionType::TransferHookAccount => pod_get_packed_len::<TransferHookAccount>(),
1194            ExtensionType::ConfidentialTransferFeeConfig => {
1195                pod_get_packed_len::<ConfidentialTransferFeeConfig>()
1196            }
1197            ExtensionType::ConfidentialTransferFeeAmount => {
1198                pod_get_packed_len::<ConfidentialTransferFeeAmount>()
1199            }
1200            ExtensionType::MetadataPointer => pod_get_packed_len::<MetadataPointer>(),
1201            ExtensionType::TokenMetadata => unreachable!(),
1202            ExtensionType::GroupPointer => pod_get_packed_len::<GroupPointer>(),
1203            ExtensionType::TokenGroup => pod_get_packed_len::<TokenGroup>(),
1204            ExtensionType::GroupMemberPointer => pod_get_packed_len::<GroupMemberPointer>(),
1205            ExtensionType::TokenGroupMember => pod_get_packed_len::<TokenGroupMember>(),
1206            ExtensionType::ConfidentialMintBurn => pod_get_packed_len::<ConfidentialMintBurn>(),
1207            ExtensionType::ScaledUiAmount => pod_get_packed_len::<ScaledUiAmountConfig>(),
1208            ExtensionType::Pausable => pod_get_packed_len::<PausableConfig>(),
1209            ExtensionType::PausableAccount => pod_get_packed_len::<PausableAccount>(),
1210            #[cfg(test)]
1211            ExtensionType::AccountPaddingTest => pod_get_packed_len::<AccountPaddingTest>(),
1212            #[cfg(test)]
1213            ExtensionType::MintPaddingTest => pod_get_packed_len::<MintPaddingTest>(),
1214            #[cfg(test)]
1215            ExtensionType::VariableLenMintTest => unreachable!(),
1216        })
1217    }
1218
1219    /// Get the TLV length for an `ExtensionType`
1220    ///
1221    /// Fails if the extension type has a variable length
1222    fn try_get_tlv_len(&self) -> Result<usize, ProgramError> {
1223        Ok(add_type_and_length_to_len(self.try_get_type_len()?))
1224    }
1225
1226    /// Get the TLV length for a set of `ExtensionType`s
1227    ///
1228    /// Fails if any of the extension types has a variable length
1229    fn try_get_total_tlv_len(extension_types: &[Self]) -> Result<usize, ProgramError> {
1230        // dedupe extensions
1231        let mut extensions = vec![];
1232        for extension_type in extension_types {
1233            if !extensions.contains(&extension_type) {
1234                extensions.push(extension_type);
1235            }
1236        }
1237        extensions.iter().map(|e| e.try_get_tlv_len()).sum()
1238    }
1239
1240    /// Get the required account data length for the given `ExtensionType`s
1241    ///
1242    /// Fails if any of the extension types has a variable length
1243    pub fn try_calculate_account_len<S: BaseState>(
1244        extension_types: &[Self],
1245    ) -> Result<usize, ProgramError> {
1246        if extension_types.is_empty() {
1247            Ok(S::SIZE_OF)
1248        } else {
1249            let extension_size = Self::try_get_total_tlv_len(extension_types)?;
1250            let total_len = extension_size.saturating_add(BASE_ACCOUNT_AND_TYPE_LENGTH);
1251            Ok(adjust_len_for_multisig(total_len))
1252        }
1253    }
1254
1255    /// Get the associated account type
1256    pub fn get_account_type(&self) -> AccountType {
1257        match self {
1258            ExtensionType::Uninitialized => AccountType::Uninitialized,
1259            ExtensionType::TransferFeeConfig
1260            | ExtensionType::MintCloseAuthority
1261            | ExtensionType::ConfidentialTransferMint
1262            | ExtensionType::DefaultAccountState
1263            | ExtensionType::NonTransferable
1264            | ExtensionType::InterestBearingConfig
1265            | ExtensionType::PermanentDelegate
1266            | ExtensionType::TransferHook
1267            | ExtensionType::ConfidentialTransferFeeConfig
1268            | ExtensionType::MetadataPointer
1269            | ExtensionType::TokenMetadata
1270            | ExtensionType::GroupPointer
1271            | ExtensionType::TokenGroup
1272            | ExtensionType::GroupMemberPointer
1273            | ExtensionType::ConfidentialMintBurn
1274            | ExtensionType::TokenGroupMember
1275            | ExtensionType::ScaledUiAmount
1276            | ExtensionType::Pausable => AccountType::Mint,
1277            ExtensionType::ImmutableOwner
1278            | ExtensionType::TransferFeeAmount
1279            | ExtensionType::ConfidentialTransferAccount
1280            | ExtensionType::MemoTransfer
1281            | ExtensionType::NonTransferableAccount
1282            | ExtensionType::TransferHookAccount
1283            | ExtensionType::CpiGuard
1284            | ExtensionType::ConfidentialTransferFeeAmount
1285            | ExtensionType::PausableAccount => AccountType::Account,
1286            #[cfg(test)]
1287            ExtensionType::VariableLenMintTest => AccountType::Mint,
1288            #[cfg(test)]
1289            ExtensionType::AccountPaddingTest => AccountType::Account,
1290            #[cfg(test)]
1291            ExtensionType::MintPaddingTest => AccountType::Mint,
1292        }
1293    }
1294
1295    /// Based on a set of `AccountType::Mint` `ExtensionType`s, get the list of
1296    /// `AccountType::Account` `ExtensionType`s required on `InitializeAccount`
1297    pub fn get_required_init_account_extensions(mint_extension_types: &[Self]) -> Vec<Self> {
1298        let mut account_extension_types = vec![];
1299        for extension_type in mint_extension_types {
1300            match extension_type {
1301                ExtensionType::TransferFeeConfig => {
1302                    account_extension_types.push(ExtensionType::TransferFeeAmount);
1303                }
1304                ExtensionType::NonTransferable => {
1305                    account_extension_types.push(ExtensionType::NonTransferableAccount);
1306                    account_extension_types.push(ExtensionType::ImmutableOwner);
1307                }
1308                ExtensionType::TransferHook => {
1309                    account_extension_types.push(ExtensionType::TransferHookAccount);
1310                }
1311                ExtensionType::Pausable => {
1312                    account_extension_types.push(ExtensionType::PausableAccount);
1313                }
1314                #[cfg(test)]
1315                ExtensionType::MintPaddingTest => {
1316                    account_extension_types.push(ExtensionType::AccountPaddingTest);
1317                }
1318                _ => {}
1319            }
1320        }
1321        account_extension_types
1322    }
1323
1324    /// Check for invalid combination of mint extensions
1325    pub fn check_for_invalid_mint_extension_combinations(
1326        mint_extension_types: &[Self],
1327    ) -> Result<(), TokenError> {
1328        let mut transfer_fee_config = false;
1329        let mut confidential_transfer_mint = false;
1330        let mut confidential_transfer_fee_config = false;
1331        let mut confidential_mint_burn = false;
1332        let mut interest_bearing = false;
1333        let mut scaled_ui_amount = false;
1334
1335        for extension_type in mint_extension_types {
1336            match extension_type {
1337                ExtensionType::TransferFeeConfig => transfer_fee_config = true,
1338                ExtensionType::ConfidentialTransferMint => confidential_transfer_mint = true,
1339                ExtensionType::ConfidentialTransferFeeConfig => {
1340                    confidential_transfer_fee_config = true
1341                }
1342                ExtensionType::ConfidentialMintBurn => confidential_mint_burn = true,
1343                ExtensionType::InterestBearingConfig => interest_bearing = true,
1344                ExtensionType::ScaledUiAmount => scaled_ui_amount = true,
1345                _ => (),
1346            }
1347        }
1348
1349        if confidential_transfer_fee_config && !(transfer_fee_config && confidential_transfer_mint)
1350        {
1351            return Err(TokenError::InvalidExtensionCombination);
1352        }
1353
1354        if transfer_fee_config && confidential_transfer_mint && !confidential_transfer_fee_config {
1355            return Err(TokenError::InvalidExtensionCombination);
1356        }
1357
1358        if confidential_mint_burn && !confidential_transfer_mint {
1359            return Err(TokenError::InvalidExtensionCombination);
1360        }
1361
1362        if scaled_ui_amount && interest_bearing {
1363            return Err(TokenError::InvalidExtensionCombination);
1364        }
1365
1366        Ok(())
1367    }
1368}
1369
/// Trait for base states (mints and token accounts, in both their `Pack` and
/// Pod forms), specifying the associated `AccountType`.
///
/// `PackedSizeOf` supplies `SIZE_OF`, the length of the base-state region at
/// the start of the account data. `IsInitialized` is used when unpacking, to
/// reject base states that are uninitialized (or already initialized, for the
/// `unpack_uninitialized` variants).
pub trait BaseState: PackedSizeOf + IsInitialized {
    /// Account type written to, and checked against, the account-type byte
    /// that follows the base state when extensions are present
    const ACCOUNT_TYPE: AccountType;
}
// Token accounts (`Pack`-based)
impl BaseState for Account {
    const ACCOUNT_TYPE: AccountType = AccountType::Account;
}
// Mints (`Pack`-based)
impl BaseState for Mint {
    const ACCOUNT_TYPE: AccountType = AccountType::Mint;
}
// Token accounts (zero-copy)
impl BaseState for PodAccount {
    const ACCOUNT_TYPE: AccountType = AccountType::Account;
}
// Mints (zero-copy)
impl BaseState for PodMint {
    const ACCOUNT_TYPE: AccountType = AccountType::Mint;
}
1387
/// Trait to be implemented by all extension states, specifying which extension
/// and account type they are associated with
///
/// Only the extension type is declared here; the account type follows from
/// `TYPE.get_account_type()`.
pub trait Extension {
    /// Associated extension type enum, checked at the start of TLV entries
    const TYPE: ExtensionType;
}
1394
/// Test extension that pads a mint account to be exactly `Multisig::LEN`.
/// We need to pad 185 bytes, since `Multisig::LEN = 355`, `Account::LEN = 165`,
/// `size_of::<AccountType>() = 1`, `size_of::<ExtensionType>() = 2`,
/// `size_of::<Length>() = 2`.
///
/// ```
/// assert_eq!(355 - 165 - 1 - 2 - 2, 185);
/// ```
#[cfg(test)]
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Pod, Zeroable)]
pub struct MintPaddingTest {
    /// Largest value under 185 that implements Pod
    pub padding1: [u8; 128],
    /// Largest value under 57 (= 185 - 128) that implements Pod
    pub padding2: [u8; 48],
    /// Exact value needed to finish the padding (128 + 48 + 9 = 185)
    pub padding3: [u8; 9],
}
#[cfg(test)]
impl Extension for MintPaddingTest {
    const TYPE: ExtensionType = ExtensionType::MintPaddingTest;
}
// Fills each padding field with a distinct non-zero byte (1, 2, 3), presumably
// so the default value is distinguishable from zeroed data in tests.
#[cfg(test)]
impl Default for MintPaddingTest {
    fn default() -> Self {
        Self {
            padding1: [1; 128],
            padding2: [2; 48],
            padding3: [3; 9],
        }
    }
}
/// Account version of `MintPaddingTest`, with the same layout and default
#[cfg(test)]
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
pub struct AccountPaddingTest(MintPaddingTest);
#[cfg(test)]
impl Extension for AccountPaddingTest {
    const TYPE: ExtensionType = ExtensionType::AccountPaddingTest;
}
1434impl Extension for AccountPaddingTest {
1435    const TYPE: ExtensionType = ExtensionType::AccountPaddingTest;
1436}
1437
1438/// Packs a fixed-length extension into a TLV space
1439///
1440/// This function reallocates the account as needed to accommodate for the
1441/// change in space.
1442///
1443/// If the extension already exists, it will overwrite the existing extension
1444/// if `overwrite` is `true`, otherwise it will return an error.
1445///
1446/// If the extension does not exist, it will reallocate the account and write
1447/// the extension into the TLV buffer.
1448///
1449/// NOTE: Since this function deals with fixed-size extensions, it does not
1450/// handle _decreasing_ the size of an account's data buffer, like the function
1451/// `alloc_and_serialize_variable_len_extension` does.
1452pub(crate) fn alloc_and_serialize<S: BaseState + Pod, V: Default + Extension + Pod>(
1453    account_info: &AccountInfo,
1454    new_extension: &V,
1455    overwrite: bool,
1456) -> Result<(), ProgramError> {
1457    let previous_account_len = account_info.try_data_len()?;
1458    let new_account_len = {
1459        let data = account_info.try_borrow_data()?;
1460        let state = PodStateWithExtensions::<S>::unpack(&data)?;
1461        state.try_get_new_account_len::<V>()?
1462    };
1463
1464    // Realloc the account first, if needed
1465    if new_account_len > previous_account_len {
1466        account_info.realloc(new_account_len, false)?;
1467    }
1468    let mut buffer = account_info.try_borrow_mut_data()?;
1469    if previous_account_len <= BASE_ACCOUNT_LENGTH {
1470        set_account_type::<S>(*buffer)?;
1471    }
1472    let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
1473
1474    // Write the extension
1475    let extension = state.init_extension::<V>(overwrite)?;
1476    *extension = *new_extension;
1477
1478    Ok(())
1479}
1480
1481/// Packs a variable-length extension into a TLV space
1482///
1483/// This function reallocates the account as needed to accommodate for the
1484/// change in space, then reallocates in the TLV buffer, and finally writes the
1485/// bytes.
1486///
1487/// NOTE: Unlike the `reallocate` instruction, this function will reduce the
1488/// size of an account if it has too many bytes allocated for the given value.
1489pub(crate) fn alloc_and_serialize_variable_len_extension<
1490    S: BaseState + Pod,
1491    V: Extension + VariableLenPack,
1492>(
1493    account_info: &AccountInfo,
1494    new_extension: &V,
1495    overwrite: bool,
1496) -> Result<(), ProgramError> {
1497    let previous_account_len = account_info.try_data_len()?;
1498    let (new_account_len, extension_already_exists) = {
1499        let data = account_info.try_borrow_data()?;
1500        let state = PodStateWithExtensions::<S>::unpack(&data)?;
1501        let new_account_len =
1502            state.try_get_new_account_len_for_variable_len_extension(new_extension)?;
1503        let extension_already_exists = state.get_extension_bytes::<V>().is_ok();
1504        (new_account_len, extension_already_exists)
1505    };
1506
1507    if extension_already_exists && !overwrite {
1508        return Err(TokenError::ExtensionAlreadyInitialized.into());
1509    }
1510
1511    if previous_account_len < new_account_len {
1512        // account size increased, so realloc the account, then the TLV entry, then
1513        // write data
1514        account_info.realloc(new_account_len, false)?;
1515        let mut buffer = account_info.try_borrow_mut_data()?;
1516        if extension_already_exists {
1517            let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
1518            state.realloc_variable_len_extension(new_extension)?;
1519        } else {
1520            if previous_account_len <= BASE_ACCOUNT_LENGTH {
1521                set_account_type::<S>(*buffer)?;
1522            }
1523            // now alloc in the TLV buffer and write the data
1524            let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
1525            state.init_variable_len_extension(new_extension, false)?;
1526        }
1527    } else {
1528        // do it backwards otherwise, write the state, realloc TLV, then the account
1529        let mut buffer = account_info.try_borrow_mut_data()?;
1530        let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
1531        if extension_already_exists {
1532            state.realloc_variable_len_extension(new_extension)?;
1533        } else {
1534            // this situation can happen if we have an overallocated buffer
1535            state.init_variable_len_extension(new_extension, false)?;
1536        }
1537
1538        let removed_bytes = previous_account_len
1539            .checked_sub(new_account_len)
1540            .ok_or(ProgramError::AccountDataTooSmall)?;
1541        if removed_bytes > 0 {
1542            // this is probably fine, but be safe and avoid invalidating references
1543            drop(buffer);
1544            account_info.realloc(new_account_len, false)?;
1545        }
1546    }
1547    Ok(())
1548}
1549
1550#[cfg(test)]
1551mod test {
1552    use {
1553        super::*,
1554        crate::{
1555            pod::test::{TEST_POD_ACCOUNT, TEST_POD_MINT},
1556            state::test::{TEST_ACCOUNT_SLICE, TEST_MINT_SLICE},
1557        },
1558        bytemuck::Pod,
1559        solana_account_info::{
1560            Account as GetAccount, IntoAccountInfo, MAX_PERMITTED_DATA_INCREASE,
1561        },
1562        solana_clock::Epoch,
1563        solana_pubkey::Pubkey,
1564        spl_pod::{
1565            bytemuck::pod_bytes_of,
1566            optional_keys::OptionalNonZeroPubkey,
1567            primitives::{PodBool, PodU64},
1568        },
1569        transfer_fee::test::test_transfer_fee_config,
1570    };
1571
1572    /// Test fixed-length struct
1573    #[repr(C)]
1574    #[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
1575    struct FixedLenMintTest {
1576        data: [u8; 8],
1577    }
1578    impl Extension for FixedLenMintTest {
1579        const TYPE: ExtensionType = ExtensionType::MintPaddingTest;
1580    }
1581
1582    /// Test variable-length struct
1583    #[derive(Clone, Debug, PartialEq)]
1584    struct VariableLenMintTest {
1585        data: Vec<u8>,
1586    }
1587    impl Extension for VariableLenMintTest {
1588        const TYPE: ExtensionType = ExtensionType::VariableLenMintTest;
1589    }
1590    impl VariableLenPack for VariableLenMintTest {
1591        fn pack_into_slice(&self, dst: &mut [u8]) -> Result<(), ProgramError> {
1592            let data_start = size_of::<u64>();
1593            let end = data_start + self.data.len();
1594            if dst.len() < end {
1595                Err(ProgramError::InvalidAccountData)
1596            } else {
1597                dst[..data_start].copy_from_slice(&self.data.len().to_le_bytes());
1598                dst[data_start..end].copy_from_slice(&self.data);
1599                Ok(())
1600            }
1601        }
1602        fn unpack_from_slice(src: &[u8]) -> Result<Self, ProgramError> {
1603            let data_start = size_of::<u64>();
1604            let length = u64::from_le_bytes(src[..data_start].try_into().unwrap()) as usize;
1605            if src[data_start..data_start + length].len() != length {
1606                return Err(ProgramError::InvalidAccountData);
1607            }
1608            let data = Vec::from(&src[data_start..data_start + length]);
1609            Ok(Self { data })
1610        }
1611        fn get_packed_len(&self) -> Result<usize, ProgramError> {
1612            Ok(size_of::<u64>().saturating_add(self.data.len()))
1613        }
1614    }
1615
    /// Encoded mint with no extensions, extended just far enough to hold the
    /// account-type byte.
    ///
    /// Layout: the 82-byte base mint (mint authority = `Some([1; 32])`,
    /// supply 42, decimals 7, initialized, freeze authority = `Some([2; 32])`),
    /// zero padding up to `BASE_ACCOUNT_LENGTH`, then account type `1` (mint).
    const MINT_WITH_ACCOUNT_TYPE: &[u8] = &[
        1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 42, 0, 0, 0, 0, 0, 0, 0, 7, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // base mint
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // padding
        1, // account type
    ];

    /// Same bytes as `MINT_WITH_ACCOUNT_TYPE`, followed by one TLV entry:
    /// extension type `3` (`MintCloseAuthority`) as a little-endian `u16`,
    /// a little-endian `u16` length of 32, and a 32-byte close authority of
    /// all ones.
    const MINT_WITH_EXTENSION: &[u8] = &[
        1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 42, 0, 0, 0, 0, 0, 0, 0, 7, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // base mint
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // padding
        1, // account type
        3, 0, // extension type
        32, 0, // length
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, // data
    ];

    /// Encoded token account with a single extension.
    ///
    /// Layout: a full `BASE_ACCOUNT_LENGTH`-byte base account (field
    /// boundaries are marked inline; the account-state byte is at index 108),
    /// then account type `2` (account), followed by one TLV entry: extension
    /// type `15` (`TransferHookAccount`), length 1, and a single data byte `1`
    /// (`transferring = true`).
    const ACCOUNT_WITH_EXTENSION: &[u8] = &[
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, // mint
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, // owner
        3, 0, 0, 0, 0, 0, 0, 0, // amount
        1, 0, 0, 0, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, // delegate
        2, // account state
        1, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, // is native
        6, 0, 0, 0, 0, 0, 0, 0, // delegated amount
        1, 0, 0, 0, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        7, 7, 7, 7, 7, 7, // close authority
        2, // account type
        15, 0, // extension type
        1, 0, // length
        1, // data
    ];
1658
1659    #[test]
1660    fn unpack_opaque_buffer() {
1661        // Mint
1662        let state = PodStateWithExtensions::<PodMint>::unpack(MINT_WITH_ACCOUNT_TYPE).unwrap();
1663        assert_eq!(state.base, &TEST_POD_MINT);
1664        let state = PodStateWithExtensions::<PodMint>::unpack(MINT_WITH_EXTENSION).unwrap();
1665        assert_eq!(state.base, &TEST_POD_MINT);
1666        let extension = state.get_extension::<MintCloseAuthority>().unwrap();
1667        let close_authority =
1668            OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([1; 32]))).unwrap();
1669        assert_eq!(extension.close_authority, close_authority);
1670        assert_eq!(
1671            state.get_extension::<TransferFeeConfig>(),
1672            Err(ProgramError::InvalidAccountData)
1673        );
1674        assert_eq!(
1675            PodStateWithExtensions::<PodAccount>::unpack(MINT_WITH_EXTENSION),
1676            Err(ProgramError::UninitializedAccount)
1677        );
1678
1679        let state = PodStateWithExtensions::<PodMint>::unpack(TEST_MINT_SLICE).unwrap();
1680        assert_eq!(state.base, &TEST_POD_MINT);
1681
1682        let mut test_mint = TEST_MINT_SLICE.to_vec();
1683        let state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut test_mint).unwrap();
1684        assert_eq!(state.base, &TEST_POD_MINT);
1685
1686        // Account
1687        let state = PodStateWithExtensions::<PodAccount>::unpack(ACCOUNT_WITH_EXTENSION).unwrap();
1688        assert_eq!(state.base, &TEST_POD_ACCOUNT);
1689        let extension = state.get_extension::<TransferHookAccount>().unwrap();
1690        let transferring = PodBool::from(true);
1691        assert_eq!(extension.transferring, transferring);
1692        assert_eq!(
1693            PodStateWithExtensions::<PodMint>::unpack(ACCOUNT_WITH_EXTENSION),
1694            Err(ProgramError::InvalidAccountData)
1695        );
1696
1697        let state = PodStateWithExtensions::<PodAccount>::unpack(TEST_ACCOUNT_SLICE).unwrap();
1698        assert_eq!(state.base, &TEST_POD_ACCOUNT);
1699
1700        let mut test_account = TEST_ACCOUNT_SLICE.to_vec();
1701        let state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut test_account).unwrap();
1702        assert_eq!(state.base, &TEST_POD_ACCOUNT);
1703    }
1704
1705    #[test]
1706    fn mint_fail_unpack_opaque_buffer() {
1707        // input buffer too small
1708        let mut buffer = vec![0, 3];
1709        assert_eq!(
1710            PodStateWithExtensions::<PodMint>::unpack(&buffer),
1711            Err(ProgramError::InvalidAccountData)
1712        );
1713        assert_eq!(
1714            PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer),
1715            Err(ProgramError::InvalidAccountData)
1716        );
1717        assert_eq!(
1718            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer),
1719            Err(ProgramError::InvalidAccountData)
1720        );
1721
1722        // tweak the account type
1723        let mut buffer = MINT_WITH_EXTENSION.to_vec();
1724        buffer[BASE_ACCOUNT_LENGTH] = 3;
1725        assert_eq!(
1726            PodStateWithExtensions::<PodMint>::unpack(&buffer),
1727            Err(ProgramError::InvalidAccountData)
1728        );
1729
1730        // clear the mint initialized byte
1731        let mut buffer = MINT_WITH_EXTENSION.to_vec();
1732        buffer[45] = 0;
1733        assert_eq!(
1734            PodStateWithExtensions::<PodMint>::unpack(&buffer),
1735            Err(ProgramError::UninitializedAccount)
1736        );
1737
1738        // tweak the padding
1739        let mut buffer = MINT_WITH_EXTENSION.to_vec();
1740        buffer[PodMint::SIZE_OF] = 100;
1741        assert_eq!(
1742            PodStateWithExtensions::<PodMint>::unpack(&buffer),
1743            Err(ProgramError::InvalidAccountData)
1744        );
1745
1746        // tweak the extension type
1747        let mut buffer = MINT_WITH_EXTENSION.to_vec();
1748        buffer[BASE_ACCOUNT_LENGTH + 1] = 2;
1749        let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
1750        assert_eq!(
1751            state.get_extension::<TransferFeeConfig>(),
1752            Err(ProgramError::Custom(
1753                TokenError::ExtensionTypeMismatch as u32
1754            ))
1755        );
1756
1757        // tweak the length, too big
1758        let mut buffer = MINT_WITH_EXTENSION.to_vec();
1759        buffer[BASE_ACCOUNT_LENGTH + 3] = 100;
1760        let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
1761        assert_eq!(
1762            state.get_extension::<TransferFeeConfig>(),
1763            Err(ProgramError::InvalidAccountData)
1764        );
1765
1766        // tweak the length, too small
1767        let mut buffer = MINT_WITH_EXTENSION.to_vec();
1768        buffer[BASE_ACCOUNT_LENGTH + 3] = 10;
1769        let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
1770        assert_eq!(
1771            state.get_extension::<TransferFeeConfig>(),
1772            Err(ProgramError::InvalidAccountData)
1773        );
1774
1775        // data buffer is too small
1776        let buffer = &MINT_WITH_EXTENSION[..MINT_WITH_EXTENSION.len() - 1];
1777        let state = PodStateWithExtensions::<PodMint>::unpack(buffer).unwrap();
1778        assert_eq!(
1779            state.get_extension::<MintCloseAuthority>(),
1780            Err(ProgramError::InvalidAccountData)
1781        );
1782    }
1783
1784    #[test]
1785    fn account_fail_unpack_opaque_buffer() {
1786        // input buffer too small
1787        let mut buffer = vec![0, 3];
1788        assert_eq!(
1789            PodStateWithExtensions::<PodAccount>::unpack(&buffer),
1790            Err(ProgramError::InvalidAccountData)
1791        );
1792        assert_eq!(
1793            PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer),
1794            Err(ProgramError::InvalidAccountData)
1795        );
1796        assert_eq!(
1797            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
1798            Err(ProgramError::InvalidAccountData)
1799        );
1800
1801        // input buffer invalid
1802        // all 5's - not a valid `AccountState`
1803        let mut buffer = vec![5; BASE_ACCOUNT_LENGTH];
1804        assert_eq!(
1805            PodStateWithExtensions::<PodAccount>::unpack(&buffer),
1806            Err(ProgramError::UninitializedAccount)
1807        );
1808        assert_eq!(
1809            PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer),
1810            Err(ProgramError::UninitializedAccount)
1811        );
1812
1813        // tweak the account type
1814        let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
1815        buffer[BASE_ACCOUNT_LENGTH] = 3;
1816        assert_eq!(
1817            PodStateWithExtensions::<PodAccount>::unpack(&buffer),
1818            Err(ProgramError::InvalidAccountData)
1819        );
1820
1821        // clear the state byte
1822        let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
1823        buffer[108] = 0;
1824        assert_eq!(
1825            PodStateWithExtensions::<PodAccount>::unpack(&buffer),
1826            Err(ProgramError::UninitializedAccount)
1827        );
1828
1829        // tweak the extension type
1830        let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
1831        buffer[BASE_ACCOUNT_LENGTH + 1] = 12;
1832        let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
1833        assert_eq!(
1834            state.get_extension::<TransferHookAccount>(),
1835            Err(ProgramError::Custom(
1836                TokenError::ExtensionTypeMismatch as u32
1837            ))
1838        );
1839
1840        // tweak the length, too big
1841        let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
1842        buffer[BASE_ACCOUNT_LENGTH + 3] = 100;
1843        let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
1844        assert_eq!(
1845            state.get_extension::<TransferHookAccount>(),
1846            Err(ProgramError::InvalidAccountData)
1847        );
1848
1849        // tweak the length, too small
1850        let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
1851        buffer[BASE_ACCOUNT_LENGTH + 3] = 10;
1852        let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
1853        assert_eq!(
1854            state.get_extension::<TransferHookAccount>(),
1855            Err(ProgramError::InvalidAccountData)
1856        );
1857
1858        // data buffer is too small
1859        let buffer = &ACCOUNT_WITH_EXTENSION[..ACCOUNT_WITH_EXTENSION.len() - 1];
1860        let state = PodStateWithExtensions::<PodAccount>::unpack(buffer).unwrap();
1861        assert_eq!(
1862            state.get_extension::<TransferHookAccount>(),
1863            Err(ProgramError::InvalidAccountData)
1864        );
1865    }
1866
1867    #[test]
1868    fn get_extension_types_with_opaque_buffer() {
1869        // incorrect due to the length
1870        assert_eq!(
1871            get_tlv_data_info(&[1, 0, 1, 1]).unwrap_err(),
1872            ProgramError::InvalidAccountData,
1873        );
1874        // incorrect due to the huge enum number
1875        assert_eq!(
1876            get_tlv_data_info(&[0, 1, 0, 0]).unwrap_err(),
1877            ProgramError::InvalidAccountData,
1878        );
1879        // correct due to the good enum number and zero length
1880        assert_eq!(
1881            get_tlv_data_info(&[1, 0, 0, 0]).unwrap(),
1882            TlvDataInfo {
1883                extension_types: vec![ExtensionType::try_from(1).unwrap()],
1884                used_len: add_type_and_length_to_len(0),
1885            }
1886        );
1887        // correct since it's just uninitialized data at the end
1888        assert_eq!(
1889            get_tlv_data_info(&[0, 0]).unwrap(),
1890            TlvDataInfo {
1891                extension_types: vec![],
1892                used_len: 0
1893            }
1894        );
1895    }
1896
    /// Round-trips a mint with `MintCloseAuthority` and `TransferFeeConfig`
    /// extensions through the mutable and immutable unpackers. At each step it
    /// checks both the accessor view and the exact raw byte layout.
    ///
    /// The buffer is sized for both extensions up front. While only
    /// `MintCloseAuthority` has been written, the second TLV slot is still
    /// all zeroes.
    #[test]
    fn mint_with_extension_pack_unpack() {
        let mint_size = ExtensionType::try_calculate_account_len::<PodMint>(&[
            ExtensionType::MintCloseAuthority,
            ExtensionType::TransferFeeConfig,
        ])
        .unwrap();
        let mut buffer = vec![0; mint_size];

        // fail unpack
        // (an all-zero buffer has no initialized base mint yet)
        assert_eq!(
            PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer),
            Err(ProgramError::UninitializedAccount),
        );

        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        // fail init account extension
        // (`TransferFeeAmount` belongs on token accounts, not mints)
        assert_eq!(
            state.init_extension::<TransferFeeAmount>(true),
            Err(ProgramError::InvalidAccountData),
        );

        // success write extension
        let close_authority =
            OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([1; 32]))).unwrap();
        let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
        extension.close_authority = close_authority;
        assert_eq!(
            &state.get_extension_types().unwrap(),
            &[ExtensionType::MintCloseAuthority]
        );

        // fail init extension when already initialized
        // (`overwrite = false` refuses to replace an existing entry)
        assert_eq!(
            state.init_extension::<MintCloseAuthority>(false),
            Err(ProgramError::Custom(
                TokenError::ExtensionAlreadyInitialized as u32
            ))
        );

        // fail unpack as account, a mint extension was written
        assert_eq!(
            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
            Err(ProgramError::Custom(
                TokenError::ExtensionBaseMismatch as u32
            ))
        );

        // fail unpack again, still no base data
        assert_eq!(
            PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer.clone()),
            Err(ProgramError::UninitializedAccount),
        );

        // write base mint
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        *state.base = TEST_POD_MINT;
        state.init_account_type().unwrap();

        // check raw buffer
        // Layout: base mint | zero padding | account type | TLV header + data
        // for MintCloseAuthority | still-zeroed slot reserved for TransferFeeConfig.
        let mut expect = TEST_MINT_SLICE.to_vec();
        expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); // padding
        expect.push(AccountType::Mint.into());
        expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
        expect
            .extend_from_slice(&(pod_get_packed_len::<MintCloseAuthority>() as u16).to_le_bytes());
        expect.extend_from_slice(&[1; 32]); // data
        expect.extend_from_slice(&[0; size_of::<ExtensionType>()]);
        expect.extend_from_slice(&[0; size_of::<Length>()]);
        expect.extend_from_slice(&[0; size_of::<TransferFeeConfig>()]);
        assert_eq!(expect, buffer);

        // unpack uninitialized will now fail because the PodMint is now initialized
        assert_eq!(
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer.clone()),
            Err(TokenError::AlreadyInUse.into()),
        );

        // check unpacking
        let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();

        // update base
        // (supply is converted through `u64` for the arithmetic, then back)
        *state.base = TEST_POD_MINT;
        state.base.supply = (u64::from(state.base.supply) + 100).into();

        // check unpacking
        let unpacked_extension = state.get_extension_mut::<MintCloseAuthority>().unwrap();
        assert_eq!(*unpacked_extension, MintCloseAuthority { close_authority });

        // update extension
        // (clear the close authority through the mutable reference)
        let close_authority = OptionalNonZeroPubkey::try_from(None).unwrap();
        unpacked_extension.close_authority = close_authority;

        // check updates are propagated
        // (writes through the mutable view are visible in a fresh read-only unpack)
        let base = *state.base;
        let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
        assert_eq!(state.base, &base);
        let unpacked_extension = state.get_extension::<MintCloseAuthority>().unwrap();
        assert_eq!(*unpacked_extension, MintCloseAuthority { close_authority });

        // check raw buffer
        // Same layout as before, but with the updated base and a zeroed authority.
        let mut expect = vec![];
        expect.extend_from_slice(bytemuck::bytes_of(&base));
        expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); // padding
        expect.push(AccountType::Mint.into());
        expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
        expect
            .extend_from_slice(&(pod_get_packed_len::<MintCloseAuthority>() as u16).to_le_bytes());
        expect.extend_from_slice(&[0; 32]);
        expect.extend_from_slice(&[0; size_of::<ExtensionType>()]);
        expect.extend_from_slice(&[0; size_of::<Length>()]);
        expect.extend_from_slice(&[0; size_of::<TransferFeeConfig>()]);
        assert_eq!(expect, buffer);

        // fail unpack as an account
        assert_eq!(
            PodStateWithExtensions::<PodAccount>::unpack(&buffer),
            Err(ProgramError::UninitializedAccount),
        );

        let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
        // init one more extension
        // (fills the slot that was reserved when the buffer was sized)
        let mint_transfer_fee = test_transfer_fee_config();
        let new_extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
        new_extension.transfer_fee_config_authority =
            mint_transfer_fee.transfer_fee_config_authority;
        new_extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
        new_extension.withheld_amount = mint_transfer_fee.withheld_amount;
        new_extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
        new_extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;

        assert_eq!(
            &state.get_extension_types().unwrap(),
            &[
                ExtensionType::MintCloseAuthority,
                ExtensionType::TransferFeeConfig
            ]
        );

        // check raw buffer
        let mut expect = vec![];
        expect.extend_from_slice(pod_bytes_of(&base));
        expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); // padding
        expect.push(AccountType::Mint.into());
        expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
        expect
            .extend_from_slice(&(pod_get_packed_len::<MintCloseAuthority>() as u16).to_le_bytes());
        expect.extend_from_slice(&[0; 32]); // data
        expect.extend_from_slice(&(ExtensionType::TransferFeeConfig as u16).to_le_bytes());
        expect.extend_from_slice(&(pod_get_packed_len::<TransferFeeConfig>() as u16).to_le_bytes());
        expect.extend_from_slice(pod_bytes_of(&mint_transfer_fee));
        assert_eq!(expect, buffer);

        // fail to init one more extension that does not fit
        // (the buffer was sized for exactly the two extensions above)
        let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
        assert_eq!(
            state.init_extension::<MintPaddingTest>(true),
            Err(ProgramError::InvalidAccountData),
        );
    }
2059
    /// Writing the same extensions in a different order produces different raw
    /// bytes but identical decoded values.
    ///
    /// In `buffer`, the extensions are written before the base mint. In
    /// `other_buffer`, the base mint is written first and the extensions
    /// follow in reverse order.
    #[test]
    fn mint_extension_any_order() {
        let mint_size = ExtensionType::try_calculate_account_len::<PodMint>(&[
            ExtensionType::MintCloseAuthority,
            ExtensionType::TransferFeeConfig,
        ])
        .unwrap();
        let mut buffer = vec![0; mint_size];

        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        // write extensions
        // (MintCloseAuthority first, then TransferFeeConfig)
        let close_authority =
            OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([1; 32]))).unwrap();
        let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
        extension.close_authority = close_authority;

        let mint_transfer_fee = test_transfer_fee_config();
        let extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
        extension.transfer_fee_config_authority = mint_transfer_fee.transfer_fee_config_authority;
        extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
        extension.withheld_amount = mint_transfer_fee.withheld_amount;
        extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
        extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;

        assert_eq!(
            &state.get_extension_types().unwrap(),
            &[
                ExtensionType::MintCloseAuthority,
                ExtensionType::TransferFeeConfig
            ]
        );

        // write base mint
        // (after the extensions, via a fresh uninitialized unpack)
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        *state.base = TEST_POD_MINT;
        state.init_account_type().unwrap();

        let mut other_buffer = vec![0; mint_size];
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut other_buffer).unwrap();

        // write base mint
        *state.base = TEST_POD_MINT;
        state.init_account_type().unwrap();

        // write extensions in a different order
        let mint_transfer_fee = test_transfer_fee_config();
        let extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
        extension.transfer_fee_config_authority = mint_transfer_fee.transfer_fee_config_authority;
        extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
        extension.withheld_amount = mint_transfer_fee.withheld_amount;
        extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
        extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;

        let close_authority =
            OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([1; 32]))).unwrap();
        let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
        extension.close_authority = close_authority;

        assert_eq!(
            &state.get_extension_types().unwrap(),
            &[
                ExtensionType::TransferFeeConfig,
                ExtensionType::MintCloseAuthority
            ]
        );

        // buffers are NOT the same because written in a different order
        assert_ne!(buffer, other_buffer);
        let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
        let other_state = PodStateWithExtensions::<PodMint>::unpack(&other_buffer).unwrap();

        // BUT mint and extensions are the same
        assert_eq!(
            state.get_extension::<TransferFeeConfig>().unwrap(),
            other_state.get_extension::<TransferFeeConfig>().unwrap()
        );
        assert_eq!(
            state.get_extension::<MintCloseAuthority>().unwrap(),
            other_state.get_extension::<MintCloseAuthority>().unwrap()
        );
        assert_eq!(state.base, other_state.base);
    }
2145
2146    #[test]
2147    fn mint_with_multisig_len() {
2148        let mut buffer = vec![0; Multisig::LEN];
2149        assert_eq!(
2150            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer),
2151            Err(ProgramError::InvalidAccountData),
2152        );
2153        let mint_size =
2154            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MintPaddingTest])
2155                .unwrap();
2156        assert_eq!(mint_size, Multisig::LEN + size_of::<ExtensionType>());
2157        let mut buffer = vec![0; mint_size];
2158
2159        // write base mint
2160        let mut state =
2161            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2162        *state.base = TEST_POD_MINT;
2163        state.init_account_type().unwrap();
2164
2165        // write padding
2166        let extension = state.init_extension::<MintPaddingTest>(true).unwrap();
2167        extension.padding1 = [1; 128];
2168        extension.padding2 = [1; 48];
2169        extension.padding3 = [1; 9];
2170
2171        assert_eq!(
2172            &state.get_extension_types().unwrap(),
2173            &[ExtensionType::MintPaddingTest]
2174        );
2175
2176        // check raw buffer
2177        let mut expect = TEST_MINT_SLICE.to_vec();
2178        expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); // padding
2179        expect.push(AccountType::Mint.into());
2180        expect.extend_from_slice(&(ExtensionType::MintPaddingTest as u16).to_le_bytes());
2181        expect.extend_from_slice(&(pod_get_packed_len::<MintPaddingTest>() as u16).to_le_bytes());
2182        expect.extend_from_slice(&vec![1; pod_get_packed_len::<MintPaddingTest>()]);
2183        expect.extend_from_slice(&(ExtensionType::Uninitialized as u16).to_le_bytes());
2184        assert_eq!(expect, buffer);
2185    }
2186
    /// Round-trips a `PodAccount` buffer carrying a single `TransferFeeAmount`
    /// extension. The extension is written before any base account data
    /// exists, then the base is written, and the raw bytes plus the mutable
    /// and immutable views are checked after each step.
    #[test]
    fn account_with_extension_pack_unpack() {
        let account_size = ExtensionType::try_calculate_account_len::<PodAccount>(&[
            ExtensionType::TransferFeeAmount,
        ])
        .unwrap();
        let mut buffer = vec![0; account_size];

        // fail unpack: the all-zero buffer has no initialized base account
        assert_eq!(
            PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer),
            Err(ProgramError::UninitializedAccount),
        );

        let mut state =
            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
        // fail init mint extension: a mint-only extension is rejected on an account
        assert_eq!(
            state.init_extension::<TransferFeeConfig>(true),
            Err(ProgramError::InvalidAccountData),
        );
        // success write extension
        let withheld_amount = PodU64::from(u64::MAX);
        let extension = state.init_extension::<TransferFeeAmount>(true).unwrap();
        extension.withheld_amount = withheld_amount;

        assert_eq!(
            &state.get_extension_types().unwrap(),
            &[ExtensionType::TransferFeeAmount]
        );

        // fail unpack again, still no base data
        // (only the extension has been written so far; this unpacks a copy)
        assert_eq!(
            PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer.clone()),
            Err(ProgramError::UninitializedAccount),
        );

        // write base account
        let mut state =
            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
        *state.base = TEST_POD_ACCOUNT;
        state.init_account_type().unwrap();
        // snapshot the base to compare against after re-unpacking
        let base = *state.base;

        // check raw buffer: base bytes, account type, then one TLV entry
        // laid out as u16 LE type, u16 LE length, value
        let mut expect = TEST_ACCOUNT_SLICE.to_vec();
        expect.push(AccountType::Account.into());
        expect.extend_from_slice(&(ExtensionType::TransferFeeAmount as u16).to_le_bytes());
        expect.extend_from_slice(&(pod_get_packed_len::<TransferFeeAmount>() as u16).to_le_bytes());
        expect.extend_from_slice(&u64::from(withheld_amount).to_le_bytes());
        assert_eq!(expect, buffer);

        // check unpacking
        let mut state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
        assert_eq!(state.base, &base);
        assert_eq!(
            &state.get_extension_types().unwrap(),
            &[ExtensionType::TransferFeeAmount]
        );

        // update base: reset it, then bump the amount by 100
        *state.base = TEST_POD_ACCOUNT;
        state.base.amount = (u64::from(state.base.amount) + 100).into();

        // check unpacking
        let unpacked_extension = state.get_extension_mut::<TransferFeeAmount>().unwrap();
        assert_eq!(*unpacked_extension, TransferFeeAmount { withheld_amount });

        // update extension through the mutable reference
        let withheld_amount = PodU64::from(u32::MAX as u64);
        unpacked_extension.withheld_amount = withheld_amount;

        // check updates are propagated
        // (snapshot the modified base before the mutable view is replaced by
        // an immutable unpack of the same buffer)
        let base = *state.base;
        let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
        assert_eq!(state.base, &base);
        let unpacked_extension = state.get_extension::<TransferFeeAmount>().unwrap();
        assert_eq!(*unpacked_extension, TransferFeeAmount { withheld_amount });

        // check raw buffer, now built from the modified base's bytes
        let mut expect = vec![];
        expect.extend_from_slice(pod_bytes_of(&base));
        expect.push(AccountType::Account.into());
        expect.extend_from_slice(&(ExtensionType::TransferFeeAmount as u16).to_le_bytes());
        expect.extend_from_slice(&(pod_get_packed_len::<TransferFeeAmount>() as u16).to_le_bytes());
        expect.extend_from_slice(&u64::from(withheld_amount).to_le_bytes());
        assert_eq!(expect, buffer);

        // fail unpack as a mint: the buffer is tagged `AccountType::Account`
        assert_eq!(
            PodStateWithExtensions::<PodMint>::unpack(&buffer),
            Err(ProgramError::InvalidAccountData),
        );
    }
2281
2282    #[test]
2283    fn account_with_multisig_len() {
2284        let mut buffer = vec![0; Multisig::LEN];
2285        assert_eq!(
2286            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
2287            Err(ProgramError::InvalidAccountData),
2288        );
2289        let account_size = ExtensionType::try_calculate_account_len::<PodAccount>(&[
2290            ExtensionType::AccountPaddingTest,
2291        ])
2292        .unwrap();
2293        assert_eq!(account_size, Multisig::LEN + size_of::<ExtensionType>());
2294        let mut buffer = vec![0; account_size];
2295
2296        // write base account
2297        let mut state =
2298            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
2299        *state.base = TEST_POD_ACCOUNT;
2300        state.init_account_type().unwrap();
2301
2302        // write padding
2303        let extension = state.init_extension::<AccountPaddingTest>(true).unwrap();
2304        extension.0.padding1 = [2; 128];
2305        extension.0.padding2 = [2; 48];
2306        extension.0.padding3 = [2; 9];
2307
2308        assert_eq!(
2309            &state.get_extension_types().unwrap(),
2310            &[ExtensionType::AccountPaddingTest]
2311        );
2312
2313        // check raw buffer
2314        let mut expect = TEST_ACCOUNT_SLICE.to_vec();
2315        expect.push(AccountType::Account.into());
2316        expect.extend_from_slice(&(ExtensionType::AccountPaddingTest as u16).to_le_bytes());
2317        expect
2318            .extend_from_slice(&(pod_get_packed_len::<AccountPaddingTest>() as u16).to_le_bytes());
2319        expect.extend_from_slice(&vec![2; pod_get_packed_len::<AccountPaddingTest>()]);
2320        expect.extend_from_slice(&(ExtensionType::Uninitialized as u16).to_le_bytes());
2321        assert_eq!(expect, buffer);
2322    }
2323
    /// `set_account_type` on raw base-account and base-mint buffers.
    /// It makes an untagged buffer unpackable, is a no-op when the matching
    /// type is already present, and rejects a buffer tagged with the other type.
    #[test]
    fn test_set_account_type() {
        // account with buffer big enough for AccountType and Extension
        let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
        let needed_len = ExtensionType::try_calculate_account_len::<PodAccount>(&[
            ExtensionType::ImmutableOwner,
        ])
        .unwrap()
            - buffer.len();
        buffer.append(&mut vec![0; needed_len]);
        // the account-type byte is still zero, so a checked unpack is rejected
        let err = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);
        set_account_type::<PodAccount>(&mut buffer).unwrap();
        // unpack is viable after manual set_account_type
        let mut state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
        assert_eq!(state.base, &TEST_POD_ACCOUNT);
        assert_eq!(state.account_type[0], AccountType::Account as u8);
        state.init_extension::<ImmutableOwner>(true).unwrap(); // just confirming initialization works

        // account with buffer big enough for AccountType only
        let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
        buffer.append(&mut vec![0; 2]);
        let err = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);
        set_account_type::<PodAccount>(&mut buffer).unwrap();
        // unpack is viable after manual set_account_type
        let state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
        assert_eq!(state.base, &TEST_POD_ACCOUNT);
        assert_eq!(state.account_type[0], AccountType::Account as u8);

        // account with AccountType already set => noop
        // (byte value 2 is `AccountType::Account`, as the assertions below confirm)
        let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
        buffer.append(&mut vec![2, 0]);
        let _ = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
        set_account_type::<PodAccount>(&mut buffer).unwrap();
        let state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
        assert_eq!(state.base, &TEST_POD_ACCOUNT);
        assert_eq!(state.account_type[0], AccountType::Account as u8);

        // account with wrong AccountType fails
        // (byte value 1 is the mint tag; see the mint cases below)
        let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
        buffer.append(&mut vec![1, 0]);
        let err = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);
        let err = set_account_type::<PodAccount>(&mut buffer).unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);

        // mint with buffer big enough for AccountType and Extension
        let mut buffer = TEST_MINT_SLICE.to_vec();
        let needed_len = ExtensionType::try_calculate_account_len::<PodMint>(&[
            ExtensionType::MintCloseAuthority,
        ])
        .unwrap()
            - buffer.len();
        buffer.append(&mut vec![0; needed_len]);
        let err = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);
        set_account_type::<PodMint>(&mut buffer).unwrap();
        // unpack is viable after manual set_account_type
        let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
        assert_eq!(state.base, &TEST_POD_MINT);
        assert_eq!(state.account_type[0], AccountType::Mint as u8);
        state.init_extension::<MintCloseAuthority>(true).unwrap();

        // mint with buffer big enough for AccountType only
        // (a mint is zero-padded up to the account size before its type byte)
        let mut buffer = TEST_MINT_SLICE.to_vec();
        buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
        buffer.append(&mut vec![0; 2]);
        let err = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);
        set_account_type::<PodMint>(&mut buffer).unwrap();
        // unpack is viable after manual set_account_type
        let state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
        assert_eq!(state.base, &TEST_POD_MINT);
        assert_eq!(state.account_type[0], AccountType::Mint as u8);

        // mint with AccountType already set => noop
        let mut buffer = TEST_MINT_SLICE.to_vec();
        buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
        buffer.append(&mut vec![1, 0]);
        set_account_type::<PodMint>(&mut buffer).unwrap();
        let state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
        assert_eq!(state.base, &TEST_POD_MINT);
        assert_eq!(state.account_type[0], AccountType::Mint as u8);

        // mint with wrong AccountType fails
        // (byte value 2 is the account tag)
        let mut buffer = TEST_MINT_SLICE.to_vec();
        buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
        buffer.append(&mut vec![2, 0]);
        let err = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);
        let err = set_account_type::<PodMint>(&mut buffer).unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);
    }
2418
2419    #[test]
2420    fn test_set_account_type_wrongly() {
2421        // try to set PodAccount account_type to PodMint
2422        let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
2423        buffer.append(&mut vec![0; 2]);
2424        let err = set_account_type::<PodMint>(&mut buffer).unwrap_err();
2425        assert_eq!(err, ProgramError::InvalidAccountData);
2426
2427        // try to set PodMint account_type to PodAccount
2428        let mut buffer = TEST_MINT_SLICE.to_vec();
2429        buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
2430        buffer.append(&mut vec![0; 2]);
2431        let err = set_account_type::<PodAccount>(&mut buffer).unwrap_err();
2432        assert_eq!(err, ProgramError::InvalidAccountData);
2433    }
2434
2435    #[test]
2436    fn test_get_required_init_account_extensions() {
2437        // Some mint extensions with no required account extensions
2438        let mint_extensions = vec![
2439            ExtensionType::MintCloseAuthority,
2440            ExtensionType::Uninitialized,
2441        ];
2442        assert_eq!(
2443            ExtensionType::get_required_init_account_extensions(&mint_extensions),
2444            vec![]
2445        );
2446
2447        // One mint extension with required account extension, one without
2448        let mint_extensions = vec![
2449            ExtensionType::TransferFeeConfig,
2450            ExtensionType::MintCloseAuthority,
2451        ];
2452        assert_eq!(
2453            ExtensionType::get_required_init_account_extensions(&mint_extensions),
2454            vec![ExtensionType::TransferFeeAmount]
2455        );
2456
2457        // Some mint extensions both with required account extensions
2458        let mint_extensions = vec![
2459            ExtensionType::TransferFeeConfig,
2460            ExtensionType::MintPaddingTest,
2461        ];
2462        assert_eq!(
2463            ExtensionType::get_required_init_account_extensions(&mint_extensions),
2464            vec![
2465                ExtensionType::TransferFeeAmount,
2466                ExtensionType::AccountPaddingTest
2467            ]
2468        );
2469
2470        // Demonstrate that method does not dedupe inputs or outputs
2471        let mint_extensions = vec![
2472            ExtensionType::TransferFeeConfig,
2473            ExtensionType::TransferFeeConfig,
2474        ];
2475        assert_eq!(
2476            ExtensionType::get_required_init_account_extensions(&mint_extensions),
2477            vec![
2478                ExtensionType::TransferFeeAmount,
2479                ExtensionType::TransferFeeAmount
2480            ]
2481        );
2482    }
2483
2484    #[test]
2485    fn mint_without_extensions() {
2486        let space = ExtensionType::try_calculate_account_len::<PodMint>(&[]).unwrap();
2487        let mut buffer = vec![0; space];
2488        assert_eq!(
2489            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
2490            Err(ProgramError::InvalidAccountData),
2491        );
2492
2493        // write base account
2494        let mut state =
2495            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2496        *state.base = TEST_POD_MINT;
2497        state.init_account_type().unwrap();
2498
2499        // fail init extension
2500        assert_eq!(
2501            state.init_extension::<TransferFeeConfig>(true),
2502            Err(ProgramError::InvalidAccountData),
2503        );
2504
2505        assert_eq!(TEST_MINT_SLICE, buffer);
2506    }
2507
2508    #[test]
2509    fn test_init_nonzero_default() {
2510        let mint_size =
2511            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MintPaddingTest])
2512                .unwrap();
2513        let mut buffer = vec![0; mint_size];
2514        let mut state =
2515            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2516        *state.base = TEST_POD_MINT;
2517        state.init_account_type().unwrap();
2518        let extension = state.init_extension::<MintPaddingTest>(true).unwrap();
2519        assert_eq!(extension.padding1, [1; 128]);
2520        assert_eq!(extension.padding2, [2; 48]);
2521        assert_eq!(extension.padding3, [3; 9]);
2522    }
2523
    /// Extensions must not be initialized or read past the end of an
    /// undersized mint buffer. This test also pins which short buffer lengths
    /// `unpack_uninitialized` accepts for a mint.
    #[test]
    fn test_init_buffer_too_small() {
        // one byte short of the space a `MintCloseAuthority` extension needs
        let mint_size = ExtensionType::try_calculate_account_len::<PodMint>(&[
            ExtensionType::MintCloseAuthority,
        ])
        .unwrap();
        let mut buffer = vec![0; mint_size - 1];
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        let err = state
            .init_extension::<MintCloseAuthority>(true)
            .unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);

        // Hand-write a TLV header: u16 LE type = 3 and u16 LE length = 32.
        // Presumably 3 is the `MintCloseAuthority` discriminant and 32 its
        // value size. The value would overrun the buffer, so the read fails.
        state.tlv_data[0] = 3;
        state.tlv_data[2] = 32;
        let err = state.get_extension_mut::<MintCloseAuthority>().unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);

        // Two bytes past the base mint presumably cannot reach the account-type
        // slot, which follows padding up to `BASE_ACCOUNT_LENGTH`.
        let mut buffer = vec![0; PodMint::SIZE_OF + 2];
        let err =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);

        // OK since there are two bytes for the type, which is `Uninitialized`
        let mut buffer = vec![0; BASE_ACCOUNT_LENGTH + 3];
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        let err = state.get_extension_mut::<MintCloseAuthority>().unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);

        assert_eq!(state.get_extension_types().unwrap(), vec![]);

        // OK, there aren't two bytes for the type, but that's fine
        let mut buffer = vec![0; BASE_ACCOUNT_LENGTH + 2];
        let state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        assert_eq!(state.get_extension_types().unwrap(), []);
    }
2563
2564    #[test]
2565    fn test_extension_with_no_data() {
2566        let account_size = ExtensionType::try_calculate_account_len::<PodAccount>(&[
2567            ExtensionType::ImmutableOwner,
2568        ])
2569        .unwrap();
2570        let mut buffer = vec![0; account_size];
2571        let mut state =
2572            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
2573        *state.base = TEST_POD_ACCOUNT;
2574        state.init_account_type().unwrap();
2575
2576        let err = state.get_extension::<ImmutableOwner>().unwrap_err();
2577        assert_eq!(
2578            err,
2579            ProgramError::Custom(TokenError::ExtensionNotFound as u32)
2580        );
2581
2582        state.init_extension::<ImmutableOwner>(true).unwrap();
2583        assert_eq!(
2584            get_first_extension_type(state.tlv_data).unwrap(),
2585            Some(ExtensionType::ImmutableOwner)
2586        );
2587        assert_eq!(
2588            get_tlv_data_info(state.tlv_data).unwrap(),
2589            TlvDataInfo {
2590                extension_types: vec![ExtensionType::ImmutableOwner],
2591                used_len: add_type_and_length_to_len(0)
2592            }
2593        );
2594    }
2595
2596    #[test]
2597    fn fail_account_len_with_metadata() {
2598        assert_eq!(
2599            ExtensionType::try_calculate_account_len::<PodMint>(&[
2600                ExtensionType::MintCloseAuthority,
2601                ExtensionType::VariableLenMintTest,
2602                ExtensionType::TransferFeeConfig,
2603            ])
2604            .unwrap_err(),
2605            ProgramError::InvalidArgument
2606        );
2607    }
2608
    /// Allocation rules for `init_variable_len_extension` on a buffer sized for
    /// exactly one `VariableLenMintTest { data: vec![1] }` entry: no double
    /// allocation without `overwrite`, and no size change during an overwrite.
    #[test]
    fn alloc() {
        let variable_len = VariableLenMintTest { data: vec![1] };
        let alloc_size = variable_len.get_packed_len().unwrap();
        // base + account type + one TLV entry (header plus value), no slack
        let account_size =
            BASE_ACCOUNT_LENGTH + size_of::<AccountType>() + add_type_and_length_to_len(alloc_size);
        let mut buffer = vec![0; account_size];
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        state
            .init_variable_len_extension(&variable_len, false)
            .unwrap();

        // can't double alloc
        assert_eq!(
            state
                .init_variable_len_extension(&variable_len, false)
                .unwrap_err(),
            TokenError::ExtensionAlreadyInitialized.into()
        );

        // unless overwrite is set
        state
            .init_variable_len_extension(&variable_len, true)
            .unwrap();

        // can't change the size during overwrite though
        assert_eq!(
            state
                .init_variable_len_extension(&VariableLenMintTest { data: vec![] }, true)
                .unwrap_err(),
            TokenError::InvalidLengthForAlloc.into()
        );

        // try to write too far, fail earlier
        // (the larger value would not fit in the buffer, and the error is
        // InvalidAccountData rather than InvalidLengthForAlloc)
        assert_eq!(
            state
                .init_variable_len_extension(&VariableLenMintTest { data: vec![1, 2] }, true)
                .unwrap_err(),
            ProgramError::InvalidAccountData
        );
    }
2651
    /// Grows and shrinks a variable-length extension that was allocated before
    /// a fixed-length `MetadataPointer`. It checks that the pointer's data
    /// survives every resize, that freed trailing space is zeroed, and that
    /// growing past the buffer fails.
    #[test]
    fn realloc() {
        let small_variable_len = VariableLenMintTest {
            data: vec![1, 2, 3],
        };
        let base_variable_len = VariableLenMintTest {
            data: vec![1, 2, 3, 4],
        };
        let big_variable_len = VariableLenMintTest {
            data: vec![1, 2, 3, 4, 5],
        };
        let too_big_variable_len = VariableLenMintTest {
            data: vec![1, 2, 3, 4, 5, 6],
        };
        // room for the `MetadataPointer` plus the largest value expected to fit
        let account_size =
            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MetadataPointer])
                .unwrap()
                + add_type_and_length_to_len(big_variable_len.get_packed_len().unwrap());
        let mut buffer = vec![0; account_size];
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();

        // alloc both types: the variable-length entry first, then the pointer
        state
            .init_variable_len_extension(&base_variable_len, false)
            .unwrap();
        // all-0xFF keys make any corruption of the pointer's bytes visible
        let max_pubkey =
            OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([255; 32]))).unwrap();
        let extension = state.init_extension::<MetadataPointer>(false).unwrap();
        extension.authority = max_pubkey;
        extension.metadata_address = max_pubkey;

        // realloc first entry to larger
        state
            .realloc_variable_len_extension(&big_variable_len)
            .unwrap();
        let extension = state
            .get_variable_len_extension::<VariableLenMintTest>()
            .unwrap();
        assert_eq!(extension, big_variable_len);
        let extension = state.get_extension::<MetadataPointer>().unwrap();
        assert_eq!(extension.authority, max_pubkey);
        assert_eq!(extension.metadata_address, max_pubkey);

        // realloc to smaller
        state
            .realloc_variable_len_extension(&small_variable_len)
            .unwrap();
        let extension = state
            .get_variable_len_extension::<VariableLenMintTest>()
            .unwrap();
        assert_eq!(extension, small_variable_len);
        let extension = state.get_extension::<MetadataPointer>().unwrap();
        assert_eq!(extension.authority, max_pubkey);
        assert_eq!(extension.metadata_address, max_pubkey);
        // the bytes released at the end of the buffer must be zeroed
        let diff = big_variable_len.get_packed_len().unwrap()
            - small_variable_len.get_packed_len().unwrap();
        assert_eq!(&buffer[account_size - diff..account_size], vec![0; diff]);

        // unpack again since we dropped the last `state`
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        // realloc too much, fails
        assert_eq!(
            state
                .realloc_variable_len_extension(&too_big_variable_len)
                .unwrap_err(),
            ProgramError::InvalidAccountData,
        );
    }
2722
    /// Checks `try_get_account_len` and
    /// `try_get_new_account_len_for_variable_len_extension` before and after a
    /// variable-length extension is initialized. Replacement values are one
    /// element smaller, one element larger, and equal in size.
    #[test]
    fn account_len() {
        let small_variable_len = VariableLenMintTest {
            data: vec![20, 30, 40],
        };
        let variable_len = VariableLenMintTest {
            data: vec![20, 30, 40, 50],
        };
        let big_variable_len = VariableLenMintTest {
            data: vec![20, 30, 40, 50, 60],
        };
        let value_len = variable_len.get_packed_len().unwrap();
        let account_size =
            BASE_ACCOUNT_LENGTH + size_of::<AccountType>() + add_type_and_length_to_len(value_len);
        let mut buffer = vec![0; account_size];
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();

        // With a new extension, new length must include padding, 1 byte for
        // account type, 2 bytes for type, 2 for length
        // (before any extension exists, the reported length is just the base mint)
        let current_len = state.try_get_account_len().unwrap();
        assert_eq!(current_len, PodMint::SIZE_OF);
        let new_len = state
            .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(
                &variable_len,
            )
            .unwrap();
        assert_eq!(
            new_len,
            BASE_ACCOUNT_AND_TYPE_LENGTH.saturating_add(add_type_and_length_to_len(value_len))
        );

        // after initializing, the actual length matches the prediction
        state
            .init_variable_len_extension::<VariableLenMintTest>(&variable_len, false)
            .unwrap();
        let current_len = state.try_get_account_len().unwrap();
        assert_eq!(current_len, new_len);

        // Reduce the extension size
        let new_len = state
            .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(
                &small_variable_len,
            )
            .unwrap();
        assert_eq!(current_len.checked_sub(new_len).unwrap(), 1);

        // Increase the extension size
        let new_len = state
            .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(
                &big_variable_len,
            )
            .unwrap();
        assert_eq!(new_len.checked_sub(current_len).unwrap(), 1);

        // Maintain the extension size
        let new_len = state
            .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(
                &variable_len,
            )
            .unwrap();
        assert_eq!(new_len, current_len);
    }
2785
2786    /// Test helper for mimicking the data layout an on-chain `AccountInfo`,
2787    /// which permits "reallocs" as the Solana runtime does it
2788    struct SolanaAccountData {
2789        data: Vec<u8>,
2790        lamports: u64,
2791        owner: Pubkey,
2792    }
2793    impl SolanaAccountData {
2794        /// Create a new fake solana account data. The underlying vector is
2795        /// overallocated to mimic the runtime
2796        fn new(account_data: &[u8]) -> Self {
2797            let mut data = vec![];
2798            data.extend_from_slice(&(account_data.len() as u64).to_le_bytes());
2799            data.extend_from_slice(account_data);
2800            data.extend_from_slice(&[0; MAX_PERMITTED_DATA_INCREASE]);
2801            Self {
2802                data,
2803                lamports: 10,
2804                owner: Pubkey::new_unique(),
2805            }
2806        }
2807
2808        /// Data lops off the first 8 bytes, since those store the size of the
2809        /// account for the Solana runtime
2810        fn data(&self) -> &[u8] {
2811            let start = size_of::<u64>();
2812            let len = self.len();
2813            &self.data[start..start + len]
2814        }
2815
2816        /// Gets the runtime length of the account data
2817        fn len(&self) -> usize {
2818            self.data
2819                .get(..size_of::<u64>())
2820                .and_then(|slice| slice.try_into().ok())
2821                .map(u64::from_le_bytes)
2822                .unwrap() as usize
2823        }
2824    }
2825    impl GetAccount for SolanaAccountData {
2826        fn get(&mut self) -> (&mut u64, &mut [u8], &Pubkey, bool, Epoch) {
2827            // need to pull out the data here to avoid a double-mutable borrow
2828            let start = size_of::<u64>();
2829            let len = self.len();
2830            (
2831                &mut self.lamports,
2832                &mut self.data[start..start + len],
2833                &self.owner,
2834                false,
2835                Epoch::default(),
2836            )
2837        }
2838    }
2839
    /// `alloc_and_serialize` on a mint account holding only its base state
    /// must grow the account to fit the account type and one fixed-length TLV
    /// entry. Re-allocation then succeeds only with `overwrite`.
    #[test]
    fn alloc_new_fixed_len_tlv_in_account_info_from_base_size() {
        let fixed_len = FixedLenMintTest {
            data: [1, 2, 3, 4, 5, 6, 7, 8],
        };
        let value_len = pod_get_packed_len::<FixedLenMintTest>();
        // start from a buffer holding only the base mint, with no extension space
        let base_account_size = PodMint::SIZE_OF;
        let mut buffer = vec![0; base_account_size];
        let state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        *state.base = TEST_POD_MINT;

        let mut data = SolanaAccountData::new(&buffer);
        let key = Pubkey::new_unique();
        // the AccountInfo mutably borrows `data`, so it is rebuilt for each call
        let account_info = (&key, &mut data).into_account_info();

        alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, false).unwrap();
        // the account must have been resized for the account type and the new entry
        let new_account_len = BASE_ACCOUNT_AND_TYPE_LENGTH + add_type_and_length_to_len(value_len);
        assert_eq!(data.len(), new_account_len);
        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
        assert_eq!(
            state.get_extension::<FixedLenMintTest>().unwrap(),
            &fixed_len,
        );

        // alloc again succeeds with "overwrite"
        let account_info = (&key, &mut data).into_account_info();
        alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, true).unwrap();

        // alloc again fails without "overwrite"
        let account_info = (&key, &mut data).into_account_info();
        assert_eq!(
            alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, false).unwrap_err(),
            TokenError::ExtensionAlreadyInitialized.into()
        );
    }
2876
    /// `alloc_and_serialize_variable_len_extension` on a mint account holding
    /// only its base state must grow the account to fit the account type and
    /// one variable-length TLV entry. Re-allocation then succeeds only with
    /// `overwrite`.
    #[test]
    fn alloc_new_variable_len_tlv_in_account_info_from_base_size() {
        let variable_len = VariableLenMintTest { data: vec![20, 99] };
        let value_len = variable_len.get_packed_len().unwrap();
        // start from a buffer holding only the base mint, with no extension space
        let base_account_size = PodMint::SIZE_OF;
        let mut buffer = vec![0; base_account_size];
        let state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        *state.base = TEST_POD_MINT;

        let mut data = SolanaAccountData::new(&buffer);
        let key = Pubkey::new_unique();
        // the AccountInfo mutably borrows `data`, so it is rebuilt for each call
        let account_info = (&key, &mut data).into_account_info();

        alloc_and_serialize_variable_len_extension::<PodMint, _>(
            &account_info,
            &variable_len,
            false,
        )
        .unwrap();
        // the account must have been resized for the account type and the new entry
        let new_account_len = BASE_ACCOUNT_AND_TYPE_LENGTH + add_type_and_length_to_len(value_len);
        assert_eq!(data.len(), new_account_len);
        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
        assert_eq!(
            state
                .get_variable_len_extension::<VariableLenMintTest>()
                .unwrap(),
            variable_len
        );

        // alloc again succeeds with "overwrite"
        let account_info = (&key, &mut data).into_account_info();
        alloc_and_serialize_variable_len_extension::<PodMint, _>(
            &account_info,
            &variable_len,
            true,
        )
        .unwrap();

        // alloc again fails without "overwrite"
        let account_info = (&key, &mut data).into_account_info();
        assert_eq!(
            alloc_and_serialize_variable_len_extension::<PodMint, _>(
                &account_info,
                &variable_len,
                false,
            )
            .unwrap_err(),
            TokenError::ExtensionAlreadyInitialized.into()
        );
    }
2928
2929    #[test]
2930    fn alloc_new_fixed_len_tlv_in_account_info_from_extended_size() {
2931        let fixed_len = FixedLenMintTest {
2932            data: [1, 2, 3, 4, 5, 6, 7, 8],
2933        };
2934        let value_len = pod_get_packed_len::<FixedLenMintTest>();
2935        let account_size =
2936            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::GroupPointer])
2937                .unwrap()
2938                + add_type_and_length_to_len(value_len);
2939        let mut buffer = vec![0; account_size];
2940        let mut state =
2941            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2942        *state.base = TEST_POD_MINT;
2943        state.init_account_type().unwrap();
2944
2945        let test_key =
2946            OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([20; 32]))).unwrap();
2947        let extension = state.init_extension::<GroupPointer>(false).unwrap();
2948        extension.authority = test_key;
2949        extension.group_address = test_key;
2950
2951        let mut data = SolanaAccountData::new(&buffer);
2952        let key = Pubkey::new_unique();
2953        let account_info = (&key, &mut data).into_account_info();
2954
2955        alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, false).unwrap();
2956        let new_account_len = BASE_ACCOUNT_AND_TYPE_LENGTH
2957            + add_type_and_length_to_len(value_len)
2958            + add_type_and_length_to_len(size_of::<GroupPointer>());
2959        assert_eq!(data.len(), new_account_len);
2960        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
2961        assert_eq!(
2962            state.get_extension::<FixedLenMintTest>().unwrap(),
2963            &fixed_len,
2964        );
2965        let extension = state.get_extension::<GroupPointer>().unwrap();
2966        assert_eq!(extension.authority, test_key);
2967        assert_eq!(extension.group_address, test_key);
2968
2969        // alloc again succeeds with "overwrite"
2970        let account_info = (&key, &mut data).into_account_info();
2971        alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, true).unwrap();
2972
2973        // alloc again fails without "overwrite"
2974        let account_info = (&key, &mut data).into_account_info();
2975        assert_eq!(
2976            alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, false).unwrap_err(),
2977            TokenError::ExtensionAlreadyInitialized.into()
2978        );
2979    }
2980
2981    #[test]
2982    fn alloc_new_variable_len_tlv_in_account_info_from_extended_size() {
2983        let variable_len = VariableLenMintTest { data: vec![42, 6] };
2984        let value_len = variable_len.get_packed_len().unwrap();
2985        let account_size =
2986            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MetadataPointer])
2987                .unwrap()
2988                + add_type_and_length_to_len(value_len);
2989        let mut buffer = vec![0; account_size];
2990        let mut state =
2991            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2992        *state.base = TEST_POD_MINT;
2993        state.init_account_type().unwrap();
2994
2995        let test_key =
2996            OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([20; 32]))).unwrap();
2997        let extension = state.init_extension::<MetadataPointer>(false).unwrap();
2998        extension.authority = test_key;
2999        extension.metadata_address = test_key;
3000
3001        let mut data = SolanaAccountData::new(&buffer);
3002        let key = Pubkey::new_unique();
3003        let account_info = (&key, &mut data).into_account_info();
3004
3005        alloc_and_serialize_variable_len_extension::<PodMint, _>(
3006            &account_info,
3007            &variable_len,
3008            false,
3009        )
3010        .unwrap();
3011        let new_account_len = BASE_ACCOUNT_AND_TYPE_LENGTH
3012            + add_type_and_length_to_len(value_len)
3013            + add_type_and_length_to_len(size_of::<MetadataPointer>());
3014        assert_eq!(data.len(), new_account_len);
3015        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
3016        assert_eq!(
3017            state
3018                .get_variable_len_extension::<VariableLenMintTest>()
3019                .unwrap(),
3020            variable_len
3021        );
3022        let extension = state.get_extension::<MetadataPointer>().unwrap();
3023        assert_eq!(extension.authority, test_key);
3024        assert_eq!(extension.metadata_address, test_key);
3025
3026        // alloc again succeeds with "overwrite"
3027        let account_info = (&key, &mut data).into_account_info();
3028        alloc_and_serialize_variable_len_extension::<PodMint, _>(
3029            &account_info,
3030            &variable_len,
3031            true,
3032        )
3033        .unwrap();
3034
3035        // alloc again fails without "overwrite"
3036        let account_info = (&key, &mut data).into_account_info();
3037        assert_eq!(
3038            alloc_and_serialize_variable_len_extension::<PodMint, _>(
3039                &account_info,
3040                &variable_len,
3041                false,
3042            )
3043            .unwrap_err(),
3044            TokenError::ExtensionAlreadyInitialized.into()
3045        );
3046    }
3047
3048    #[test]
3049    fn realloc_variable_len_tlv_in_account_info() {
3050        let variable_len = VariableLenMintTest {
3051            data: vec![1, 2, 3, 4, 5],
3052        };
3053        let alloc_size = variable_len.get_packed_len().unwrap();
3054        let account_size =
3055            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MetadataPointer])
3056                .unwrap()
3057                + add_type_and_length_to_len(alloc_size);
3058        let mut buffer = vec![0; account_size];
3059        let mut state =
3060            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
3061        *state.base = TEST_POD_MINT;
3062        state.init_account_type().unwrap();
3063
3064        // alloc both types
3065        state
3066            .init_variable_len_extension(&variable_len, false)
3067            .unwrap();
3068        let max_pubkey =
3069            OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([255; 32]))).unwrap();
3070        let extension = state.init_extension::<MetadataPointer>(false).unwrap();
3071        extension.authority = max_pubkey;
3072        extension.metadata_address = max_pubkey;
3073
3074        // reallocate to smaller, make sure existing extension is fine
3075        let mut data = SolanaAccountData::new(&buffer);
3076        let key = Pubkey::new_unique();
3077        let account_info = (&key, &mut data).into_account_info();
3078        let variable_len = VariableLenMintTest { data: vec![1, 2] };
3079        alloc_and_serialize_variable_len_extension::<PodMint, _>(
3080            &account_info,
3081            &variable_len,
3082            true,
3083        )
3084        .unwrap();
3085
3086        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
3087        let extension = state.get_extension::<MetadataPointer>().unwrap();
3088        assert_eq!(extension.authority, max_pubkey);
3089        assert_eq!(extension.metadata_address, max_pubkey);
3090        let extension = state
3091            .get_variable_len_extension::<VariableLenMintTest>()
3092            .unwrap();
3093        assert_eq!(extension, variable_len);
3094        assert_eq!(data.len(), state.try_get_account_len().unwrap());
3095
3096        // reallocate to larger
3097        let account_info = (&key, &mut data).into_account_info();
3098        let variable_len = VariableLenMintTest {
3099            data: vec![1, 2, 3, 4, 5, 6, 7],
3100        };
3101        alloc_and_serialize_variable_len_extension::<PodMint, _>(
3102            &account_info,
3103            &variable_len,
3104            true,
3105        )
3106        .unwrap();
3107
3108        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
3109        let extension = state.get_extension::<MetadataPointer>().unwrap();
3110        assert_eq!(extension.authority, max_pubkey);
3111        assert_eq!(extension.metadata_address, max_pubkey);
3112        let extension = state
3113            .get_variable_len_extension::<VariableLenMintTest>()
3114            .unwrap();
3115        assert_eq!(extension, variable_len);
3116        assert_eq!(data.len(), state.try_get_account_len().unwrap());
3117
3118        // reallocate to same
3119        let account_info = (&key, &mut data).into_account_info();
3120        let variable_len = VariableLenMintTest {
3121            data: vec![7, 6, 5, 4, 3, 2, 1],
3122        };
3123        alloc_and_serialize_variable_len_extension::<PodMint, _>(
3124            &account_info,
3125            &variable_len,
3126            true,
3127        )
3128        .unwrap();
3129
3130        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
3131        let extension = state.get_extension::<MetadataPointer>().unwrap();
3132        assert_eq!(extension.authority, max_pubkey);
3133        assert_eq!(extension.metadata_address, max_pubkey);
3134        let extension = state
3135            .get_variable_len_extension::<VariableLenMintTest>()
3136            .unwrap();
3137        assert_eq!(extension, variable_len);
3138        assert_eq!(data.len(), state.try_get_account_len().unwrap());
3139    }
3140}