spl_token_2022_interface/extension/
mod.rs

1//! Extensions available to token mints and accounts
2
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5use {
6    crate::{
7        error::TokenError,
8        extension::{
9            confidential_mint_burn::ConfidentialMintBurn,
10            confidential_transfer::{ConfidentialTransferAccount, ConfidentialTransferMint},
11            confidential_transfer_fee::{
12                ConfidentialTransferFeeAmount, ConfidentialTransferFeeConfig,
13            },
14            cpi_guard::CpiGuard,
15            default_account_state::DefaultAccountState,
16            group_member_pointer::GroupMemberPointer,
17            group_pointer::GroupPointer,
18            immutable_owner::ImmutableOwner,
19            interest_bearing_mint::InterestBearingConfig,
20            memo_transfer::MemoTransfer,
21            metadata_pointer::MetadataPointer,
22            mint_close_authority::MintCloseAuthority,
23            non_transferable::{NonTransferable, NonTransferableAccount},
24            pausable::{PausableAccount, PausableConfig},
25            permanent_delegate::PermanentDelegate,
26            scaled_ui_amount::ScaledUiAmountConfig,
27            transfer_fee::{TransferFeeAmount, TransferFeeConfig},
28            transfer_hook::{TransferHook, TransferHookAccount},
29        },
30        pod::{PodAccount, PodMint},
31        state::{Account, Mint, Multisig, PackedSizeOf},
32    },
33    bytemuck::{Pod, Zeroable},
34    num_enum::{IntoPrimitive, TryFromPrimitive},
35    solana_account_info::AccountInfo,
36    solana_program_error::ProgramError,
37    solana_program_pack::{IsInitialized, Pack},
38    spl_pod::{
39        bytemuck::{pod_from_bytes, pod_from_bytes_mut, pod_get_packed_len},
40        primitives::PodU16,
41    },
42    spl_token_group_interface::state::{TokenGroup, TokenGroupMember},
43    spl_type_length_value::variable_len_pack::VariableLenPack,
44    std::{
45        cmp::Ordering,
46        convert::{TryFrom, TryInto},
47        mem::size_of,
48    },
49};
50
51/// Confidential Transfer extension
52pub mod confidential_transfer;
53/// Confidential Transfer Fee extension
54pub mod confidential_transfer_fee;
55/// CPI Guard extension
56pub mod cpi_guard;
57/// Default Account State extension
58pub mod default_account_state;
59/// Group Member Pointer extension
60pub mod group_member_pointer;
61/// Group Pointer extension
62pub mod group_pointer;
63/// Immutable Owner extension
64pub mod immutable_owner;
65/// Interest-Bearing Mint extension
66pub mod interest_bearing_mint;
67/// Memo Transfer extension
68pub mod memo_transfer;
69/// Metadata Pointer extension
70pub mod metadata_pointer;
71/// Mint Close Authority extension
72pub mod mint_close_authority;
73/// Non Transferable extension
74pub mod non_transferable;
75/// Pausable extension
76pub mod pausable;
77/// Permanent Delegate extension
78pub mod permanent_delegate;
79/// Scaled UI Amount extension
80pub mod scaled_ui_amount;
81/// Token-group extension
82pub mod token_group;
83/// Token-metadata extension
84pub mod token_metadata;
85/// Transfer Fee extension
86pub mod transfer_fee;
87/// Transfer Hook extension
88pub mod transfer_hook;
89
90/// Confidential mint-burn extension
91pub mod confidential_mint_burn;
92
93/// Length in TLV structure
94#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
95#[repr(transparent)]
96pub struct Length(PodU16);
97impl From<Length> for usize {
98    fn from(n: Length) -> Self {
99        Self::from(u16::from(n.0))
100    }
101}
102impl TryFrom<usize> for Length {
103    type Error = ProgramError;
104    fn try_from(n: usize) -> Result<Self, Self::Error> {
105        u16::try_from(n)
106            .map(|v| Self(PodU16::from(v)))
107            .map_err(|_| ProgramError::AccountDataTooSmall)
108    }
109}
110
111/// Helper function to get the current `TlvIndices` from the current spot
112fn get_tlv_indices(type_start: usize) -> TlvIndices {
113    let length_start = type_start.saturating_add(size_of::<ExtensionType>());
114    let value_start = length_start.saturating_add(pod_get_packed_len::<Length>());
115    TlvIndices {
116        type_start,
117        length_start,
118        value_start,
119    }
120}
121
122/// Helper function to tack on the size of an extension bytes if an account with
123/// extensions is exactly the size of a multisig
124const fn adjust_len_for_multisig(account_len: usize) -> usize {
125    if account_len == Multisig::LEN {
126        account_len.saturating_add(size_of::<ExtensionType>())
127    } else {
128        account_len
129    }
130}
131
132/// Helper function to calculate exactly how many bytes a value will take up,
133/// given the value's length
134const fn add_type_and_length_to_len(value_len: usize) -> usize {
135    value_len
136        .saturating_add(size_of::<ExtensionType>())
137        .saturating_add(pod_get_packed_len::<Length>())
138}
139
/// Byte offsets of the three fields of one TLV entry, relative to the start
/// of the TLV buffer
#[derive(Debug)]
struct TlvIndices {
    /// Where the extension-type field begins
    pub type_start: usize,
    /// Where the length field begins (right after the type field)
    pub length_start: usize,
    /// Where the value bytes begin (right after the length field)
    pub value_start: usize,
}
148fn get_extension_indices<V: Extension>(
149    tlv_data: &[u8],
150    init: bool,
151) -> Result<TlvIndices, ProgramError> {
152    let mut start_index = 0;
153    while start_index < tlv_data.len() {
154        let tlv_indices = get_tlv_indices(start_index);
155        if tlv_data.len() < tlv_indices.value_start {
156            return Err(ProgramError::InvalidAccountData);
157        }
158        let extension_type = u16::from_le_bytes(
159            tlv_data[tlv_indices.type_start..tlv_indices.length_start]
160                .try_into()
161                .map_err(|_| ProgramError::InvalidAccountData)?,
162        );
163        if extension_type == u16::from(V::TYPE) {
164            // found an instance of the extension that we're initializing, return!
165            return Ok(tlv_indices);
166        // got to an empty spot, init here, or error if we're searching, since
167        // nothing is written after an Uninitialized spot
168        } else if extension_type == u16::from(ExtensionType::Uninitialized) {
169            if init {
170                return Ok(tlv_indices);
171            } else {
172                return Err(TokenError::ExtensionNotFound.into());
173            }
174        } else {
175            let length = pod_from_bytes::<Length>(
176                &tlv_data[tlv_indices.length_start..tlv_indices.value_start],
177            )?;
178            let value_end_index = tlv_indices.value_start.saturating_add(usize::from(*length));
179            start_index = value_end_index;
180        }
181    }
182    Err(ProgramError::InvalidAccountData)
183}
184
185/// Basic information about the TLV buffer, collected from iterating through all
186/// entries
187#[derive(Debug, PartialEq)]
188struct TlvDataInfo {
189    /// The extension types written in the TLV buffer
190    extension_types: Vec<ExtensionType>,
191    /// The total number bytes allocated for all TLV entries.
192    ///
193    /// Each TLV entry's allocated bytes comprises two bytes for the `type`, two
194    /// bytes for the `length`, and `length` number of bytes for the `value`.
195    used_len: usize,
196}
197
198/// Fetches basic information about the TLV buffer by iterating through all
199/// TLV entries.
200fn get_tlv_data_info(tlv_data: &[u8]) -> Result<TlvDataInfo, ProgramError> {
201    let mut extension_types = vec![];
202    let mut start_index = 0;
203    while start_index < tlv_data.len() {
204        let tlv_indices = get_tlv_indices(start_index);
205        if tlv_data.len() < tlv_indices.length_start {
206            // There aren't enough bytes to store the next type, which means we
207            // got to the end. The last byte could be used during a realloc!
208            return Ok(TlvDataInfo {
209                extension_types,
210                used_len: tlv_indices.type_start,
211            });
212        }
213        let extension_type =
214            ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
215        if extension_type == ExtensionType::Uninitialized {
216            return Ok(TlvDataInfo {
217                extension_types,
218                used_len: tlv_indices.type_start,
219            });
220        } else {
221            if tlv_data.len() < tlv_indices.value_start {
222                // not enough bytes to store the length, malformed
223                return Err(ProgramError::InvalidAccountData);
224            }
225            extension_types.push(extension_type);
226            let length = pod_from_bytes::<Length>(
227                &tlv_data[tlv_indices.length_start..tlv_indices.value_start],
228            )?;
229
230            let value_end_index = tlv_indices.value_start.saturating_add(usize::from(*length));
231            if value_end_index > tlv_data.len() {
232                // value blows past the size of the slice, malformed
233                return Err(ProgramError::InvalidAccountData);
234            }
235            start_index = value_end_index;
236        }
237    }
238    Ok(TlvDataInfo {
239        extension_types,
240        used_len: start_index,
241    })
242}
243
244fn get_first_extension_type(tlv_data: &[u8]) -> Result<Option<ExtensionType>, ProgramError> {
245    if tlv_data.is_empty() {
246        Ok(None)
247    } else {
248        let tlv_indices = get_tlv_indices(0);
249        if tlv_data.len() <= tlv_indices.length_start {
250            return Ok(None);
251        }
252        let extension_type =
253            ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
254        if extension_type == ExtensionType::Uninitialized {
255            Ok(None)
256        } else {
257            Ok(Some(extension_type))
258        }
259    }
260}
261
262fn check_min_len_and_not_multisig(input: &[u8], minimum_len: usize) -> Result<(), ProgramError> {
263    if input.len() == Multisig::LEN || input.len() < minimum_len {
264        Err(ProgramError::InvalidAccountData)
265    } else {
266        Ok(())
267    }
268}
269
270fn check_account_type<S: BaseState>(account_type: AccountType) -> Result<(), ProgramError> {
271    if account_type != S::ACCOUNT_TYPE {
272        Err(ProgramError::InvalidAccountData)
273    } else {
274        Ok(())
275    }
276}
277
278/// Any account with extensions must be at least `Account::LEN`.  Both mints and
279/// accounts can have extensions
280/// A mint with extensions that takes it past 165 could be indiscernible from an
281/// Account with an extension, even if we add the account type. For example,
282/// let's say we have:
283///
284/// ```text
285/// Account: 165 bytes... + [2, 0, 3, 0, 100, ....]
286///                          ^     ^       ^     ^
287///                     acct type  extension length data...
288///
289/// Mint: 82 bytes... + 83 bytes of other extension data
290///     + [2, 0, 3, 0, 100, ....]
291///      (data in extension just happens to look like this)
292/// ```
293///
294/// With this approach, we only start writing the TLV data after `Account::LEN`,
295/// which means we always know that the account type is going to be right after
296/// that. We do a special case checking for a Multisig length, because those
297/// aren't extensible under any circumstances.
298const BASE_ACCOUNT_LENGTH: usize = Account::LEN;
299/// Helper that tacks on the `AccountType` length, which gives the minimum for
300/// any account with extensions
301const BASE_ACCOUNT_AND_TYPE_LENGTH: usize = BASE_ACCOUNT_LENGTH + size_of::<AccountType>();
302
303fn type_and_tlv_indices<S: BaseState>(
304    rest_input: &[u8],
305) -> Result<Option<(usize, usize)>, ProgramError> {
306    if rest_input.is_empty() {
307        Ok(None)
308    } else {
309        let account_type_index = BASE_ACCOUNT_LENGTH.saturating_sub(S::SIZE_OF);
310        // check padding is all zeroes
311        let tlv_start_index = account_type_index.saturating_add(size_of::<AccountType>());
312        if rest_input.len() < tlv_start_index {
313            return Err(ProgramError::InvalidAccountData);
314        }
315        if rest_input[..account_type_index] != vec![0; account_type_index] {
316            Err(ProgramError::InvalidAccountData)
317        } else {
318            Ok(Some((account_type_index, tlv_start_index)))
319        }
320    }
321}
322
323/// Checks a base buffer to verify if it is an Account without having to
324/// completely deserialize it
325fn is_initialized_account(input: &[u8]) -> Result<bool, ProgramError> {
326    const ACCOUNT_INITIALIZED_INDEX: usize = 108; // See state.rs#L99
327
328    if input.len() != BASE_ACCOUNT_LENGTH {
329        return Err(ProgramError::InvalidAccountData);
330    }
331    Ok(input[ACCOUNT_INITIALIZED_INDEX] != 0)
332}
333
334fn get_extension_bytes<S: BaseState, V: Extension>(tlv_data: &[u8]) -> Result<&[u8], ProgramError> {
335    if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
336        return Err(ProgramError::InvalidAccountData);
337    }
338    let TlvIndices {
339        type_start: _,
340        length_start,
341        value_start,
342    } = get_extension_indices::<V>(tlv_data, false)?;
343    // get_extension_indices has checked that tlv_data is long enough to include
344    // these indices
345    let length = pod_from_bytes::<Length>(&tlv_data[length_start..value_start])?;
346    let value_end = value_start.saturating_add(usize::from(*length));
347    if tlv_data.len() < value_end {
348        return Err(ProgramError::InvalidAccountData);
349    }
350    Ok(&tlv_data[value_start..value_end])
351}
352
353fn get_extension_bytes_mut<S: BaseState, V: Extension>(
354    tlv_data: &mut [u8],
355) -> Result<&mut [u8], ProgramError> {
356    if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
357        return Err(ProgramError::InvalidAccountData);
358    }
359    let TlvIndices {
360        type_start: _,
361        length_start,
362        value_start,
363    } = get_extension_indices::<V>(tlv_data, false)?;
364    // get_extension_indices has checked that tlv_data is long enough to include
365    // these indices
366    let length = pod_from_bytes::<Length>(&tlv_data[length_start..value_start])?;
367    let value_end = value_start.saturating_add(usize::from(*length));
368    if tlv_data.len() < value_end {
369        return Err(ProgramError::InvalidAccountData);
370    }
371    Ok(&mut tlv_data[value_start..value_end])
372}
373
374/// Calculate the new expected size if the state allocates the given number
375/// of bytes for the given extension type.
376///
377/// Provides the correct answer regardless if the extension is already present
378/// in the TLV data.
379fn try_get_new_account_len_for_extension_len<S: BaseState, V: Extension>(
380    tlv_data: &[u8],
381    new_extension_len: usize,
382) -> Result<usize, ProgramError> {
383    // get the new length used by the extension
384    let new_extension_tlv_len = add_type_and_length_to_len(new_extension_len);
385    let tlv_info = get_tlv_data_info(tlv_data)?;
386    // If we're adding an extension, then we must have at least BASE_ACCOUNT_LENGTH
387    // and account type
388    let current_len = tlv_info
389        .used_len
390        .saturating_add(BASE_ACCOUNT_AND_TYPE_LENGTH);
391    // get the current length used by the extension
392    let current_extension_len = get_extension_bytes::<S, V>(tlv_data)
393        .map(|x| add_type_and_length_to_len(x.len()))
394        .unwrap_or(0);
395    let new_len = current_len
396        .saturating_sub(current_extension_len)
397        .saturating_add(new_extension_tlv_len);
398    Ok(adjust_len_for_multisig(new_len))
399}
400
401/// Trait for base state with extension
402pub trait BaseStateWithExtensions<S: BaseState> {
403    /// Get the buffer containing all extension data
404    fn get_tlv_data(&self) -> &[u8];
405
406    /// Fetch the bytes for a TLV entry
407    fn get_extension_bytes<V: Extension>(&self) -> Result<&[u8], ProgramError> {
408        get_extension_bytes::<S, V>(self.get_tlv_data())
409    }
410
411    /// Unpack a portion of the TLV data as the desired type
412    fn get_extension<V: Extension + Pod>(&self) -> Result<&V, ProgramError> {
413        pod_from_bytes::<V>(self.get_extension_bytes::<V>()?)
414    }
415
416    /// Unpacks a portion of the TLV data as the desired variable-length type
417    fn get_variable_len_extension<V: Extension + VariableLenPack>(
418        &self,
419    ) -> Result<V, ProgramError> {
420        let data = get_extension_bytes::<S, V>(self.get_tlv_data())?;
421        V::unpack_from_slice(data)
422    }
423
424    /// Iterates through the TLV entries, returning only the types
425    fn get_extension_types(&self) -> Result<Vec<ExtensionType>, ProgramError> {
426        get_tlv_data_info(self.get_tlv_data()).map(|x| x.extension_types)
427    }
428
429    /// Get just the first extension type, useful to track mixed initialization
430    fn get_first_extension_type(&self) -> Result<Option<ExtensionType>, ProgramError> {
431        get_first_extension_type(self.get_tlv_data())
432    }
433
434    /// Get the total number of bytes used by TLV entries and the base type
435    fn try_get_account_len(&self) -> Result<usize, ProgramError> {
436        let tlv_info = get_tlv_data_info(self.get_tlv_data())?;
437        if tlv_info.extension_types.is_empty() {
438            Ok(S::SIZE_OF)
439        } else {
440            let total_len = tlv_info
441                .used_len
442                .saturating_add(BASE_ACCOUNT_AND_TYPE_LENGTH);
443            Ok(adjust_len_for_multisig(total_len))
444        }
445    }
446    /// Calculate the new expected size if the state allocates the given
447    /// fixed-length extension instance.
448    /// If the state already has the extension, the resulting account length
449    /// will be unchanged.
450    fn try_get_new_account_len<V: Extension + Pod>(&self) -> Result<usize, ProgramError> {
451        try_get_new_account_len_for_extension_len::<S, V>(
452            self.get_tlv_data(),
453            pod_get_packed_len::<V>(),
454        )
455    }
456
457    /// Calculate the new expected size if the state allocates the given
458    /// variable-length extension instance.
459    fn try_get_new_account_len_for_variable_len_extension<V: Extension + VariableLenPack>(
460        &self,
461        new_extension: &V,
462    ) -> Result<usize, ProgramError> {
463        try_get_new_account_len_for_extension_len::<S, V>(
464            self.get_tlv_data(),
465            new_extension.get_packed_len()?,
466        )
467    }
468}
469
470/// Encapsulates owned immutable base state data (mint or account) with possible
471/// extensions
472#[derive(Clone, Debug, PartialEq)]
473pub struct StateWithExtensionsOwned<S: BaseState> {
474    /// Unpacked base data
475    pub base: S,
476    /// Raw TLV data, deserialized on demand
477    tlv_data: Vec<u8>,
478}
479impl<S: BaseState + Pack> StateWithExtensionsOwned<S> {
480    /// Unpack base state, leaving the extension data as a slice
481    ///
482    /// Fails if the base state is not initialized.
483    pub fn unpack(mut input: Vec<u8>) -> Result<Self, ProgramError> {
484        check_min_len_and_not_multisig(&input, S::SIZE_OF)?;
485        let mut rest = input.split_off(S::SIZE_OF);
486        let base = S::unpack(&input)?;
487        if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(&rest)? {
488            // type_and_tlv_indices() checks that returned indexes are within range
489            let account_type = AccountType::try_from(rest[account_type_index])
490                .map_err(|_| ProgramError::InvalidAccountData)?;
491            check_account_type::<S>(account_type)?;
492            let tlv_data = rest.split_off(tlv_start_index);
493            Ok(Self { base, tlv_data })
494        } else {
495            Ok(Self {
496                base,
497                tlv_data: vec![],
498            })
499        }
500    }
501}
502
503impl<S: BaseState> BaseStateWithExtensions<S> for StateWithExtensionsOwned<S> {
504    fn get_tlv_data(&self) -> &[u8] {
505        &self.tlv_data
506    }
507}
508
509/// Encapsulates immutable base state data (mint or account) with possible
510/// extensions
511#[derive(Debug, PartialEq)]
512pub struct StateWithExtensions<'data, S: BaseState + Pack> {
513    /// Unpacked base data
514    pub base: S,
515    /// Slice of data containing all TLV data, deserialized on demand
516    tlv_data: &'data [u8],
517}
518impl<'data, S: BaseState + Pack> StateWithExtensions<'data, S> {
519    /// Unpack base state, leaving the extension data as a slice
520    ///
521    /// Fails if the base state is not initialized.
522    pub fn unpack(input: &'data [u8]) -> Result<Self, ProgramError> {
523        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
524        let (base_data, rest) = input.split_at(S::SIZE_OF);
525        let base = S::unpack(base_data)?;
526        let tlv_data = unpack_tlv_data::<S>(rest)?;
527        Ok(Self { base, tlv_data })
528    }
529}
530impl<S: BaseState + Pack> BaseStateWithExtensions<S> for StateWithExtensions<'_, S> {
531    fn get_tlv_data(&self) -> &[u8] {
532        self.tlv_data
533    }
534}
535
536/// Encapsulates immutable base state data (mint or account) with possible
537/// extensions, where the base state is Pod for zero-copy serde.
538#[derive(Debug, PartialEq)]
539pub struct PodStateWithExtensions<'data, S: BaseState + Pod> {
540    /// Unpacked base data
541    pub base: &'data S,
542    /// Slice of data containing all TLV data, deserialized on demand
543    tlv_data: &'data [u8],
544}
545impl<'data, S: BaseState + Pod> PodStateWithExtensions<'data, S> {
546    /// Unpack base state, leaving the extension data as a slice
547    ///
548    /// Fails if the base state is not initialized.
549    pub fn unpack(input: &'data [u8]) -> Result<Self, ProgramError> {
550        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
551        let (base_data, rest) = input.split_at(S::SIZE_OF);
552        let base = pod_from_bytes::<S>(base_data)?;
553        if !base.is_initialized() {
554            Err(ProgramError::UninitializedAccount)
555        } else {
556            let tlv_data = unpack_tlv_data::<S>(rest)?;
557            Ok(Self { base, tlv_data })
558        }
559    }
560}
561impl<S: BaseState + Pod> BaseStateWithExtensions<S> for PodStateWithExtensions<'_, S> {
562    fn get_tlv_data(&self) -> &[u8] {
563        self.tlv_data
564    }
565}
566
567/// Trait for mutable base state with extension
568pub trait BaseStateWithExtensionsMut<S: BaseState>: BaseStateWithExtensions<S> {
569    /// Get the underlying TLV data as mutable
570    fn get_tlv_data_mut(&mut self) -> &mut [u8];
571
572    /// Get the underlying account type as mutable
573    fn get_account_type_mut(&mut self) -> &mut [u8];
574
575    /// Unpack a portion of the TLV data as the base mutable bytes
576    fn get_extension_bytes_mut<V: Extension>(&mut self) -> Result<&mut [u8], ProgramError> {
577        get_extension_bytes_mut::<S, V>(self.get_tlv_data_mut())
578    }
579
580    /// Unpack a portion of the TLV data as the desired type that allows
581    /// modifying the type
582    fn get_extension_mut<V: Extension + Pod>(&mut self) -> Result<&mut V, ProgramError> {
583        pod_from_bytes_mut::<V>(self.get_extension_bytes_mut::<V>()?)
584    }
585
586    /// Packs a variable-length extension into its appropriate data segment.
587    /// Fails if space hasn't already been allocated for the given extension
588    fn pack_variable_len_extension<V: Extension + VariableLenPack>(
589        &mut self,
590        extension: &V,
591    ) -> Result<(), ProgramError> {
592        let data = self.get_extension_bytes_mut::<V>()?;
593        // NOTE: Do *not* use `pack`, since the length check will cause
594        // reallocations to smaller sizes to fail
595        extension.pack_into_slice(data)
596    }
597
598    /// Packs the default extension data into an open slot if not already found
599    /// in the data buffer. If extension is already found in the buffer, it
600    /// overwrites the existing extension with the default state if
601    /// `overwrite` is set. If extension found, but `overwrite` is not set,
602    /// it returns error.
603    fn init_extension<V: Extension + Pod + Default>(
604        &mut self,
605        overwrite: bool,
606    ) -> Result<&mut V, ProgramError> {
607        let length = pod_get_packed_len::<V>();
608        let buffer = self.alloc::<V>(length, overwrite)?;
609        let extension_ref = pod_from_bytes_mut::<V>(buffer)?;
610        *extension_ref = V::default();
611        Ok(extension_ref)
612    }
613
614    /// Reallocate and overwrite the TLV entry for the given variable-length
615    /// extension.
616    ///
617    /// Returns an error if the extension is not present, or if there is not
618    /// enough space in the buffer.
619    fn realloc_variable_len_extension<V: Extension + VariableLenPack>(
620        &mut self,
621        new_extension: &V,
622    ) -> Result<(), ProgramError> {
623        let data = self.realloc::<V>(new_extension.get_packed_len()?)?;
624        new_extension.pack_into_slice(data)
625    }
626
627    /// Reallocate the TLV entry for the given extension to the given number of
628    /// bytes.
629    ///
630    /// If the new length is smaller, it will compact the rest of the buffer and
631    /// zero out the difference at the end. If it's larger, it will move the
632    /// rest of the buffer data and zero out the new data.
633    ///
634    /// Returns an error if the extension is not present, or if this is not
635    /// enough space in the buffer.
636    fn realloc<V: Extension + VariableLenPack>(
637        &mut self,
638        length: usize,
639    ) -> Result<&mut [u8], ProgramError> {
640        let tlv_data = self.get_tlv_data_mut();
641        let TlvIndices {
642            type_start: _,
643            length_start,
644            value_start,
645        } = get_extension_indices::<V>(tlv_data, false)?;
646        let tlv_len = get_tlv_data_info(tlv_data).map(|x| x.used_len)?;
647        let data_len = tlv_data.len();
648
649        let length_ref = pod_from_bytes_mut::<Length>(&mut tlv_data[length_start..value_start])?;
650        let old_length = usize::from(*length_ref);
651
652        // Length check to avoid a panic later in `copy_within`
653        if old_length < length {
654            let new_tlv_len = tlv_len.saturating_add(length.saturating_sub(old_length));
655            if new_tlv_len > data_len {
656                return Err(ProgramError::InvalidAccountData);
657            }
658        }
659
660        // write new length after the check, to avoid getting into a bad situation
661        // if trying to recover from an error
662        *length_ref = Length::try_from(length)?;
663
664        let old_value_end = value_start.saturating_add(old_length);
665        let new_value_end = value_start.saturating_add(length);
666        tlv_data.copy_within(old_value_end..tlv_len, new_value_end);
667        match old_length.cmp(&length) {
668            Ordering::Greater => {
669                // realloc to smaller, zero out the end
670                let new_tlv_len = tlv_len.saturating_sub(old_length.saturating_sub(length));
671                tlv_data[new_tlv_len..tlv_len].fill(0);
672            }
673            Ordering::Less => {
674                // realloc to bigger, zero out the new bytes
675                tlv_data[old_value_end..new_value_end].fill(0);
676            }
677            Ordering::Equal => {} // nothing needed!
678        }
679
680        Ok(&mut tlv_data[value_start..new_value_end])
681    }
682
683    /// Allocate the given number of bytes for the given variable-length
684    /// extension and write its contents into the TLV buffer.
685    ///
686    /// This can only be used for variable-sized types, such as `String` or
687    /// `Vec`. `Pod` types must use `init_extension`
688    fn init_variable_len_extension<V: Extension + VariableLenPack>(
689        &mut self,
690        extension: &V,
691        overwrite: bool,
692    ) -> Result<(), ProgramError> {
693        let data = self.alloc::<V>(extension.get_packed_len()?, overwrite)?;
694        extension.pack_into_slice(data)
695    }
696
697    /// Allocate some space for the extension in the TLV data
698    fn alloc<V: Extension>(
699        &mut self,
700        length: usize,
701        overwrite: bool,
702    ) -> Result<&mut [u8], ProgramError> {
703        if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
704            return Err(ProgramError::InvalidAccountData);
705        }
706        let tlv_data = self.get_tlv_data_mut();
707        let TlvIndices {
708            type_start,
709            length_start,
710            value_start,
711        } = get_extension_indices::<V>(tlv_data, true)?;
712
713        if tlv_data[type_start..].len() < add_type_and_length_to_len(length) {
714            return Err(ProgramError::InvalidAccountData);
715        }
716        let extension_type = ExtensionType::try_from(&tlv_data[type_start..length_start])?;
717
718        if extension_type == ExtensionType::Uninitialized || overwrite {
719            // write extension type
720            let extension_type_array: [u8; 2] = V::TYPE.into();
721            let extension_type_ref = &mut tlv_data[type_start..length_start];
722            extension_type_ref.copy_from_slice(&extension_type_array);
723            // write length
724            let length_ref =
725                pod_from_bytes_mut::<Length>(&mut tlv_data[length_start..value_start])?;
726
727            // check that the length is the same if we're doing an alloc
728            // with overwrite, otherwise a realloc should be done
729            if overwrite && extension_type == V::TYPE && usize::from(*length_ref) != length {
730                return Err(TokenError::InvalidLengthForAlloc.into());
731            }
732
733            *length_ref = Length::try_from(length)?;
734
735            let value_end = value_start.saturating_add(length);
736            Ok(&mut tlv_data[value_start..value_end])
737        } else {
738            // extension is already initialized, but no overwrite permission
739            Err(TokenError::ExtensionAlreadyInitialized.into())
740        }
741    }
742
743    /// If `extension_type` is an Account-associated `ExtensionType` that
744    /// requires initialization on `InitializeAccount`, this method packs
745    /// the default relevant `Extension` of an `ExtensionType` into an open
746    /// slot if not already found in the data buffer, otherwise overwrites
747    /// the existing extension with the default state. For all other
748    /// `ExtensionType`s, this is a no-op.
749    fn init_account_extension_from_type(
750        &mut self,
751        extension_type: ExtensionType,
752    ) -> Result<(), ProgramError> {
753        if extension_type.get_account_type() != AccountType::Account {
754            return Ok(());
755        }
756        match extension_type {
757            ExtensionType::TransferFeeAmount => {
758                self.init_extension::<TransferFeeAmount>(true).map(|_| ())
759            }
760            ExtensionType::ImmutableOwner => {
761                self.init_extension::<ImmutableOwner>(true).map(|_| ())
762            }
763            ExtensionType::NonTransferableAccount => self
764                .init_extension::<NonTransferableAccount>(true)
765                .map(|_| ()),
766            ExtensionType::TransferHookAccount => {
767                self.init_extension::<TransferHookAccount>(true).map(|_| ())
768            }
769            // ConfidentialTransfers are currently opt-in only, so this is a no-op for extra safety
770            // on InitializeAccount
771            ExtensionType::ConfidentialTransferAccount => Ok(()),
772            ExtensionType::PausableAccount => {
773                self.init_extension::<PausableAccount>(true).map(|_| ())
774            }
775            #[cfg(test)]
776            ExtensionType::AccountPaddingTest => {
777                self.init_extension::<AccountPaddingTest>(true).map(|_| ())
778            }
779            _ => unreachable!(),
780        }
781    }
782
783    /// Write the account type into the buffer, done during the base
784    /// state initialization
785    /// Noops if there is no room for an extension in the account, needed for
786    /// pure base mints / accounts.
787    fn init_account_type(&mut self) -> Result<(), ProgramError> {
788        let first_extension_type = self.get_first_extension_type()?;
789        let account_type = self.get_account_type_mut();
790        if !account_type.is_empty() {
791            if let Some(extension_type) = first_extension_type {
792                let account_type = extension_type.get_account_type();
793                if account_type != S::ACCOUNT_TYPE {
794                    return Err(TokenError::ExtensionBaseMismatch.into());
795                }
796            }
797            account_type[0] = S::ACCOUNT_TYPE.into();
798        }
799        Ok(())
800    }
801
802    /// Check that the account type on the account (if initialized) matches the
803    /// account type for any extensions initialized on the TLV data
804    fn check_account_type_matches_extension_type(&self) -> Result<(), ProgramError> {
805        if let Some(extension_type) = self.get_first_extension_type()? {
806            let account_type = extension_type.get_account_type();
807            if account_type != S::ACCOUNT_TYPE {
808                return Err(TokenError::ExtensionBaseMismatch.into());
809            }
810        }
811        Ok(())
812    }
813}
814
/// Encapsulates mutable base state data (mint or account) with possible
/// extensions
///
/// The base state is held as an unpacked copy in `base`; call `pack_base` to
/// write changes back into the underlying buffer. Extension data is accessed
/// in place through the mutable TLV slice.
#[derive(Debug, PartialEq)]
pub struct StateWithExtensionsMut<'data, S: BaseState> {
    /// Unpacked base data (an owned copy; see `pack_base`)
    pub base: S,
    /// Raw base data: the first `S::SIZE_OF` bytes of the input buffer
    base_data: &'data mut [u8],
    /// Writable account type; empty when the buffer has no room for extensions
    account_type: &'data mut [u8],
    /// Slice of data containing all TLV data, deserialized on demand
    tlv_data: &'data mut [u8],
}
828impl<'data, S: BaseState + Pack> StateWithExtensionsMut<'data, S> {
829    /// Unpack base state, leaving the extension data as a mutable slice
830    ///
831    /// Fails if the base state is not initialized.
832    pub fn unpack(input: &'data mut [u8]) -> Result<Self, ProgramError> {
833        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
834        let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
835        let base = S::unpack(base_data)?;
836        let (account_type, tlv_data) = unpack_type_and_tlv_data_mut::<S>(rest)?;
837        Ok(Self {
838            base,
839            base_data,
840            account_type,
841            tlv_data,
842        })
843    }
844
845    /// Unpack an uninitialized base state, leaving the extension data as a
846    /// mutable slice
847    ///
848    /// Fails if the base state has already been initialized.
849    pub fn unpack_uninitialized(input: &'data mut [u8]) -> Result<Self, ProgramError> {
850        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
851        let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
852        let base = S::unpack_unchecked(base_data)?;
853        if base.is_initialized() {
854            return Err(TokenError::AlreadyInUse.into());
855        }
856        let (account_type, tlv_data) = unpack_uninitialized_type_and_tlv_data_mut::<S>(rest)?;
857        let state = Self {
858            base,
859            base_data,
860            account_type,
861            tlv_data,
862        };
863        state.check_account_type_matches_extension_type()?;
864        Ok(state)
865    }
866
867    /// Packs base state data into the base data portion
868    pub fn pack_base(&mut self) {
869        S::pack_into_slice(&self.base, self.base_data);
870    }
871}
872impl<S: BaseState> BaseStateWithExtensions<S> for StateWithExtensionsMut<'_, S> {
873    fn get_tlv_data(&self) -> &[u8] {
874        self.tlv_data
875    }
876}
877impl<S: BaseState> BaseStateWithExtensionsMut<S> for StateWithExtensionsMut<'_, S> {
878    fn get_tlv_data_mut(&mut self) -> &mut [u8] {
879        self.tlv_data
880    }
881    fn get_account_type_mut(&mut self) -> &mut [u8] {
882        self.account_type
883    }
884}
885
/// Encapsulates mutable base state data (mint or account) with possible
/// extensions, where the base state is Pod for zero-copy serde.
///
/// Unlike `StateWithExtensionsMut`, `base` points directly into the account
/// buffer, so writes to it take effect without a separate pack step.
#[derive(Debug, PartialEq)]
pub struct PodStateWithExtensionsMut<'data, S: BaseState> {
    /// Unpacked base data, borrowed in place from the input buffer
    pub base: &'data mut S,
    /// Writable account type; empty when the buffer has no room for extensions
    account_type: &'data mut [u8],
    /// Slice of data containing all TLV data, deserialized on demand
    tlv_data: &'data mut [u8],
}
897impl<'data, S: BaseState + Pod> PodStateWithExtensionsMut<'data, S> {
898    /// Unpack base state, leaving the extension data as a mutable slice
899    ///
900    /// Fails if the base state is not initialized.
901    pub fn unpack(input: &'data mut [u8]) -> Result<Self, ProgramError> {
902        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
903        let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
904        let base = pod_from_bytes_mut::<S>(base_data)?;
905        if !base.is_initialized() {
906            Err(ProgramError::UninitializedAccount)
907        } else {
908            let (account_type, tlv_data) = unpack_type_and_tlv_data_mut::<S>(rest)?;
909            Ok(Self {
910                base,
911                account_type,
912                tlv_data,
913            })
914        }
915    }
916
917    /// Unpack an uninitialized base state, leaving the extension data as a
918    /// mutable slice
919    ///
920    /// Fails if the base state has already been initialized.
921    pub fn unpack_uninitialized(input: &'data mut [u8]) -> Result<Self, ProgramError> {
922        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
923        let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
924        let base = pod_from_bytes_mut::<S>(base_data)?;
925        if base.is_initialized() {
926            return Err(TokenError::AlreadyInUse.into());
927        }
928        let (account_type, tlv_data) = unpack_uninitialized_type_and_tlv_data_mut::<S>(rest)?;
929        let state = Self {
930            base,
931            account_type,
932            tlv_data,
933        };
934        state.check_account_type_matches_extension_type()?;
935        Ok(state)
936    }
937}
938
939impl<S: BaseState> BaseStateWithExtensions<S> for PodStateWithExtensionsMut<'_, S> {
940    fn get_tlv_data(&self) -> &[u8] {
941        self.tlv_data
942    }
943}
944impl<S: BaseState> BaseStateWithExtensionsMut<S> for PodStateWithExtensionsMut<'_, S> {
945    fn get_tlv_data_mut(&mut self) -> &mut [u8] {
946        self.tlv_data
947    }
948    fn get_account_type_mut(&mut self) -> &mut [u8] {
949        self.account_type
950    }
951}
952
953fn unpack_tlv_data<S: BaseState>(rest: &[u8]) -> Result<&[u8], ProgramError> {
954    if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
955        // type_and_tlv_indices() checks that returned indexes are within range
956        let account_type = AccountType::try_from(rest[account_type_index])
957            .map_err(|_| ProgramError::InvalidAccountData)?;
958        check_account_type::<S>(account_type)?;
959        Ok(&rest[tlv_start_index..])
960    } else {
961        Ok(&[])
962    }
963}
964
965fn unpack_type_and_tlv_data_with_check_mut<
966    S: BaseState,
967    F: Fn(AccountType) -> Result<(), ProgramError>,
968>(
969    rest: &mut [u8],
970    check_fn: F,
971) -> Result<(&mut [u8], &mut [u8]), ProgramError> {
972    if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
973        // type_and_tlv_indices() checks that returned indexes are within range
974        let account_type = AccountType::try_from(rest[account_type_index])
975            .map_err(|_| ProgramError::InvalidAccountData)?;
976        check_fn(account_type)?;
977        let (account_type, tlv_data) = rest.split_at_mut(tlv_start_index);
978        Ok((
979            &mut account_type[account_type_index..tlv_start_index],
980            tlv_data,
981        ))
982    } else {
983        Ok((&mut [], &mut []))
984    }
985}
986
987fn unpack_type_and_tlv_data_mut<S: BaseState>(
988    rest: &mut [u8],
989) -> Result<(&mut [u8], &mut [u8]), ProgramError> {
990    unpack_type_and_tlv_data_with_check_mut::<S, _>(rest, check_account_type::<S>)
991}
992
993fn unpack_uninitialized_type_and_tlv_data_mut<S: BaseState>(
994    rest: &mut [u8],
995) -> Result<(&mut [u8], &mut [u8]), ProgramError> {
996    unpack_type_and_tlv_data_with_check_mut::<S, _>(rest, |account_type| {
997        if account_type != AccountType::Uninitialized {
998            Err(ProgramError::InvalidAccountData)
999        } else {
1000            Ok(())
1001        }
1002    })
1003}
1004
1005/// If `AccountType` is uninitialized, set it to the `BaseState`'s
1006/// `ACCOUNT_TYPE`; if `AccountType` is already set, check is set correctly for
1007/// `BaseState`. This method assumes that the `base_data` has already been
1008/// packed with data of the desired type.
1009pub fn set_account_type<S: BaseState>(input: &mut [u8]) -> Result<(), ProgramError> {
1010    check_min_len_and_not_multisig(input, S::SIZE_OF)?;
1011    let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
1012    if S::ACCOUNT_TYPE == AccountType::Account && !is_initialized_account(base_data)? {
1013        return Err(ProgramError::InvalidAccountData);
1014    }
1015    if let Some((account_type_index, _tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
1016        let mut account_type = AccountType::try_from(rest[account_type_index])
1017            .map_err(|_| ProgramError::InvalidAccountData)?;
1018        if account_type == AccountType::Uninitialized {
1019            rest[account_type_index] = S::ACCOUNT_TYPE.into();
1020            account_type = S::ACCOUNT_TYPE;
1021        }
1022        check_account_type::<S>(account_type)?;
1023        Ok(())
1024    } else {
1025        Err(ProgramError::InvalidAccountData)
1026    }
1027}
1028
/// Different kinds of accounts. Note that `Mint`, `Account`, and `Multisig`
/// types are determined exclusively by the size of the account, and are not
/// included in the account data. `AccountType` is only included if extensions
/// have been initialized.
///
/// When present, it is stored as a `repr(u8)` byte preceding the TLV data, so
/// the discriminant values are part of the account data layout and the
/// variants must not be reordered.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, TryFromPrimitive, IntoPrimitive)]
pub enum AccountType {
    /// Marker for 0 data: no account type has been written yet
    Uninitialized,
    /// Mint account with additional extensions
    Mint,
    /// Token holding account with additional extensions
    Account,
}
1043impl Default for AccountType {
1044    fn default() -> Self {
1045        Self::Uninitialized
1046    }
1047}
1048
/// Extensions that can be applied to mints or accounts.  Mint extensions must
/// only be applied to mint accounts, and account extensions must only be
/// applied to token holding accounts.
///
/// Each variant is encoded as a little-endian `u16` (see the `TryFrom<&[u8]>`
/// and `From<ExtensionType> for [u8; 2]` impls) and that value is stored in
/// account data. Existing discriminants must therefore never change: append
/// new variants rather than inserting or reordering them.
#[repr(u16)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "camelCase"))]
#[derive(Clone, Copy, Debug, PartialEq, TryFromPrimitive, IntoPrimitive)]
pub enum ExtensionType {
    /// Used as padding if the account size would otherwise be 355, same as a
    /// multisig
    Uninitialized,
    /// Includes transfer fee rate info and accompanying authorities to withdraw
    /// and set the fee
    TransferFeeConfig,
    /// Includes withheld transfer fees
    TransferFeeAmount,
    /// Includes an optional mint close authority
    MintCloseAuthority,
    /// Auditor configuration for confidential transfers
    ConfidentialTransferMint,
    /// State for confidential transfers
    ConfidentialTransferAccount,
    /// Specifies the default Account::state for new Accounts
    DefaultAccountState,
    /// Indicates that the Account owner authority cannot be changed
    ImmutableOwner,
    /// Require inbound transfers to have memo
    MemoTransfer,
    /// Indicates that the tokens from this mint can't be transferred
    NonTransferable,
    /// Tokens accrue interest over time
    InterestBearingConfig,
    /// Locks privileged token operations from happening via CPI
    CpiGuard,
    /// Includes an optional permanent delegate
    PermanentDelegate,
    /// Indicates that the tokens in this account belong to a non-transferable
    /// mint
    NonTransferableAccount,
    /// Mint requires a CPI to a program implementing the "transfer hook"
    /// interface
    TransferHook,
    /// Indicates that the tokens in this account belong to a mint with a
    /// transfer hook
    TransferHookAccount,
    /// Includes encrypted withheld fees and the encryption public key that they
    /// are encrypted under
    ConfidentialTransferFeeConfig,
    /// Includes confidential withheld transfer fees
    ConfidentialTransferFeeAmount,
    /// Mint contains a pointer to another account (or the same account) that
    /// holds metadata
    MetadataPointer,
    /// Mint contains token-metadata
    TokenMetadata,
    /// Mint contains a pointer to another account (or the same account) that
    /// holds group configurations
    GroupPointer,
    /// Mint contains token group configurations
    TokenGroup,
    /// Mint contains a pointer to another account (or the same account) that
    /// holds group member configurations
    GroupMemberPointer,
    /// Mint contains token group member configurations
    TokenGroupMember,
    /// Mint allowing the minting and burning of confidential tokens
    ConfidentialMintBurn,
    /// Tokens whose UI amount is scaled by a given amount
    ScaledUiAmount,
    /// Tokens where minting / burning / transferring can be paused
    Pausable,
    /// Indicates that the account belongs to a pausable mint
    PausableAccount,

    // Test-only variants sit at the top of the `u16` range, away from the
    // production discriminants above.
    /// Test variable-length mint extension
    #[cfg(test)]
    VariableLenMintTest = u16::MAX - 2,
    /// Padding extension used to make an account exactly Multisig::LEN, used
    /// for testing
    #[cfg(test)]
    AccountPaddingTest,
    /// Padding extension used to make a mint exactly Multisig::LEN, used for
    /// testing
    #[cfg(test)]
    MintPaddingTest,
}
1135impl TryFrom<&[u8]> for ExtensionType {
1136    type Error = ProgramError;
1137    fn try_from(a: &[u8]) -> Result<Self, Self::Error> {
1138        Self::try_from(u16::from_le_bytes(
1139            a.try_into().map_err(|_| ProgramError::InvalidAccountData)?,
1140        ))
1141        .map_err(|_| ProgramError::InvalidAccountData)
1142    }
1143}
1144impl From<ExtensionType> for [u8; 2] {
1145    fn from(a: ExtensionType) -> Self {
1146        u16::from(a).to_le_bytes()
1147    }
1148}
impl ExtensionType {
    /// Returns true if the given extension type is sized
    ///
    /// Most extension types should be sized, so any variable-length extension
    /// types should be added here by hand
    const fn sized(&self) -> bool {
        match self {
            ExtensionType::TokenMetadata => false,
            #[cfg(test)]
            ExtensionType::VariableLenMintTest => false,
            _ => true,
        }
    }

    /// Get the data length of the type associated with the enum
    ///
    /// Fails with `ProgramError::InvalidArgument` if the extension type has a
    /// variable length (see `sized`)
    fn try_get_type_len(&self) -> Result<usize, ProgramError> {
        if !self.sized() {
            return Err(ProgramError::InvalidArgument);
        }
        // Each arm is the packed size of the extension's Pod state. This table
        // feeds `try_calculate_account_len`, so it must stay in sync with the
        // extension structs.
        Ok(match self {
            ExtensionType::Uninitialized => 0,
            ExtensionType::TransferFeeConfig => pod_get_packed_len::<TransferFeeConfig>(),
            ExtensionType::TransferFeeAmount => pod_get_packed_len::<TransferFeeAmount>(),
            ExtensionType::MintCloseAuthority => pod_get_packed_len::<MintCloseAuthority>(),
            ExtensionType::ImmutableOwner => pod_get_packed_len::<ImmutableOwner>(),
            ExtensionType::ConfidentialTransferMint => {
                pod_get_packed_len::<ConfidentialTransferMint>()
            }
            ExtensionType::ConfidentialTransferAccount => {
                pod_get_packed_len::<ConfidentialTransferAccount>()
            }
            ExtensionType::DefaultAccountState => pod_get_packed_len::<DefaultAccountState>(),
            ExtensionType::MemoTransfer => pod_get_packed_len::<MemoTransfer>(),
            ExtensionType::NonTransferable => pod_get_packed_len::<NonTransferable>(),
            ExtensionType::InterestBearingConfig => pod_get_packed_len::<InterestBearingConfig>(),
            ExtensionType::CpiGuard => pod_get_packed_len::<CpiGuard>(),
            ExtensionType::PermanentDelegate => pod_get_packed_len::<PermanentDelegate>(),
            ExtensionType::NonTransferableAccount => pod_get_packed_len::<NonTransferableAccount>(),
            ExtensionType::TransferHook => pod_get_packed_len::<TransferHook>(),
            ExtensionType::TransferHookAccount => pod_get_packed_len::<TransferHookAccount>(),
            ExtensionType::ConfidentialTransferFeeConfig => {
                pod_get_packed_len::<ConfidentialTransferFeeConfig>()
            }
            ExtensionType::ConfidentialTransferFeeAmount => {
                pod_get_packed_len::<ConfidentialTransferFeeAmount>()
            }
            ExtensionType::MetadataPointer => pod_get_packed_len::<MetadataPointer>(),
            // Variable-length: already rejected by the `sized()` check above
            ExtensionType::TokenMetadata => unreachable!(),
            ExtensionType::GroupPointer => pod_get_packed_len::<GroupPointer>(),
            ExtensionType::TokenGroup => pod_get_packed_len::<TokenGroup>(),
            ExtensionType::GroupMemberPointer => pod_get_packed_len::<GroupMemberPointer>(),
            ExtensionType::TokenGroupMember => pod_get_packed_len::<TokenGroupMember>(),
            ExtensionType::ConfidentialMintBurn => pod_get_packed_len::<ConfidentialMintBurn>(),
            ExtensionType::ScaledUiAmount => pod_get_packed_len::<ScaledUiAmountConfig>(),
            ExtensionType::Pausable => pod_get_packed_len::<PausableConfig>(),
            ExtensionType::PausableAccount => pod_get_packed_len::<PausableAccount>(),
            #[cfg(test)]
            ExtensionType::AccountPaddingTest => pod_get_packed_len::<AccountPaddingTest>(),
            #[cfg(test)]
            ExtensionType::MintPaddingTest => pod_get_packed_len::<MintPaddingTest>(),
            // Variable-length: already rejected by the `sized()` check above
            #[cfg(test)]
            ExtensionType::VariableLenMintTest => unreachable!(),
        })
    }

    /// Get the TLV length for an `ExtensionType`
    ///
    /// This is the value length from `try_get_type_len` plus the type/length
    /// header added by `add_type_and_length_to_len`.
    ///
    /// Fails if the extension type has a variable length
    fn try_get_tlv_len(&self) -> Result<usize, ProgramError> {
        Ok(add_type_and_length_to_len(self.try_get_type_len()?))
    }

    /// Get the TLV length for a set of `ExtensionType`s
    ///
    /// Duplicate entries in `extension_types` are counted only once.
    ///
    /// Fails if any of the extension types has a variable length
    fn try_get_total_tlv_len(extension_types: &[Self]) -> Result<usize, ProgramError> {
        // dedupe extensions
        let mut extensions = vec![];
        for extension_type in extension_types {
            if !extensions.contains(&extension_type) {
                extensions.push(extension_type);
            }
        }
        // Summing `Result`s stops at the first error
        extensions.iter().map(|e| e.try_get_tlv_len()).sum()
    }

    /// Get the required account data length for the given `ExtensionType`s
    ///
    /// With no extensions this is just `S::SIZE_OF`. Otherwise it is the total
    /// TLV length plus `BASE_ACCOUNT_AND_TYPE_LENGTH` (the same base length is
    /// used for every `S`), passed through `adjust_len_for_multisig`.
    ///
    /// Fails if any of the extension types has a variable length
    pub fn try_calculate_account_len<S: BaseState>(
        extension_types: &[Self],
    ) -> Result<usize, ProgramError> {
        if extension_types.is_empty() {
            Ok(S::SIZE_OF)
        } else {
            let extension_size = Self::try_get_total_tlv_len(extension_types)?;
            let total_len = extension_size.saturating_add(BASE_ACCOUNT_AND_TYPE_LENGTH);
            Ok(adjust_len_for_multisig(total_len))
        }
    }

    /// Get the associated account type
    ///
    /// `Uninitialized` maps to `AccountType::Uninitialized`; every other
    /// variant belongs to exactly one of `Mint` or `Account`.
    pub fn get_account_type(&self) -> AccountType {
        match self {
            ExtensionType::Uninitialized => AccountType::Uninitialized,
            ExtensionType::TransferFeeConfig
            | ExtensionType::MintCloseAuthority
            | ExtensionType::ConfidentialTransferMint
            | ExtensionType::DefaultAccountState
            | ExtensionType::NonTransferable
            | ExtensionType::InterestBearingConfig
            | ExtensionType::PermanentDelegate
            | ExtensionType::TransferHook
            | ExtensionType::ConfidentialTransferFeeConfig
            | ExtensionType::MetadataPointer
            | ExtensionType::TokenMetadata
            | ExtensionType::GroupPointer
            | ExtensionType::TokenGroup
            | ExtensionType::GroupMemberPointer
            | ExtensionType::ConfidentialMintBurn
            | ExtensionType::TokenGroupMember
            | ExtensionType::ScaledUiAmount
            | ExtensionType::Pausable => AccountType::Mint,
            ExtensionType::ImmutableOwner
            | ExtensionType::TransferFeeAmount
            | ExtensionType::ConfidentialTransferAccount
            | ExtensionType::MemoTransfer
            | ExtensionType::NonTransferableAccount
            | ExtensionType::TransferHookAccount
            | ExtensionType::CpiGuard
            | ExtensionType::ConfidentialTransferFeeAmount
            | ExtensionType::PausableAccount => AccountType::Account,
            #[cfg(test)]
            ExtensionType::VariableLenMintTest => AccountType::Mint,
            #[cfg(test)]
            ExtensionType::AccountPaddingTest => AccountType::Account,
            #[cfg(test)]
            ExtensionType::MintPaddingTest => AccountType::Mint,
        }
    }

    /// Based on a set of `AccountType::Mint` `ExtensionType`s, get the list of
    /// `AccountType::Account` `ExtensionType`s required on `InitializeAccount`
    ///
    /// Mint extensions without an account-side counterpart are ignored. The
    /// result is not deduplicated if `mint_extension_types` contains repeats.
    pub fn get_required_init_account_extensions(mint_extension_types: &[Self]) -> Vec<Self> {
        let mut account_extension_types = vec![];
        for extension_type in mint_extension_types {
            match extension_type {
                ExtensionType::TransferFeeConfig => {
                    account_extension_types.push(ExtensionType::TransferFeeAmount);
                }
                ExtensionType::NonTransferable => {
                    account_extension_types.push(ExtensionType::NonTransferableAccount);
                    account_extension_types.push(ExtensionType::ImmutableOwner);
                }
                ExtensionType::TransferHook => {
                    account_extension_types.push(ExtensionType::TransferHookAccount);
                }
                ExtensionType::Pausable => {
                    account_extension_types.push(ExtensionType::PausableAccount);
                }
                #[cfg(test)]
                ExtensionType::MintPaddingTest => {
                    account_extension_types.push(ExtensionType::AccountPaddingTest);
                }
                _ => {}
            }
        }
        account_extension_types
    }

    /// Check for invalid combination of mint extensions
    ///
    /// Returns `TokenError::InvalidExtensionCombination` when:
    /// - `ConfidentialTransferFeeConfig` is present without both
    ///   `TransferFeeConfig` and `ConfidentialTransferMint`;
    /// - `TransferFeeConfig` and `ConfidentialTransferMint` are both present
    ///   without `ConfidentialTransferFeeConfig`;
    /// - `ConfidentialMintBurn` is present without `ConfidentialTransferMint`;
    /// - `ScaledUiAmount` and `InterestBearingConfig` are both present.
    pub fn check_for_invalid_mint_extension_combinations(
        mint_extension_types: &[Self],
    ) -> Result<(), TokenError> {
        let mut transfer_fee_config = false;
        let mut confidential_transfer_mint = false;
        let mut confidential_transfer_fee_config = false;
        let mut confidential_mint_burn = false;
        let mut interest_bearing = false;
        let mut scaled_ui_amount = false;

        for extension_type in mint_extension_types {
            match extension_type {
                ExtensionType::TransferFeeConfig => transfer_fee_config = true,
                ExtensionType::ConfidentialTransferMint => confidential_transfer_mint = true,
                ExtensionType::ConfidentialTransferFeeConfig => {
                    confidential_transfer_fee_config = true
                }
                ExtensionType::ConfidentialMintBurn => confidential_mint_burn = true,
                ExtensionType::InterestBearingConfig => interest_bearing = true,
                ExtensionType::ScaledUiAmount => scaled_ui_amount = true,
                _ => (),
            }
        }

        if confidential_transfer_fee_config && !(transfer_fee_config && confidential_transfer_mint)
        {
            return Err(TokenError::InvalidExtensionCombination);
        }

        if transfer_fee_config && confidential_transfer_mint && !confidential_transfer_fee_config {
            return Err(TokenError::InvalidExtensionCombination);
        }

        if confidential_mint_burn && !confidential_transfer_mint {
            return Err(TokenError::InvalidExtensionCombination);
        }

        if scaled_ui_amount && interest_bearing {
            return Err(TokenError::InvalidExtensionCombination);
        }

        Ok(())
    }
}
1366
/// Trait for base states (mints and token accounts), specifying the
/// `AccountType` that marks their extension data
pub trait BaseState: PackedSizeOf + IsInitialized {
    /// Account type associated with this base state: written into the
    /// account-type byte when extensions are initialized, and compared against
    /// the stored byte and against the account type of extensions found in the
    /// TLV data
    const ACCOUNT_TYPE: AccountType;
}
1372impl BaseState for Account {
1373    const ACCOUNT_TYPE: AccountType = AccountType::Account;
1374}
1375impl BaseState for Mint {
1376    const ACCOUNT_TYPE: AccountType = AccountType::Mint;
1377}
1378impl BaseState for PodAccount {
1379    const ACCOUNT_TYPE: AccountType = AccountType::Account;
1380}
1381impl BaseState for PodMint {
1382    const ACCOUNT_TYPE: AccountType = AccountType::Mint;
1383}
1384
/// Trait to be implemented by all extension states, specifying which
/// `ExtensionType` they are associated with (the associated account type
/// follows from `ExtensionType::get_account_type`)
pub trait Extension {
    /// Associated extension type enum, checked at the start of TLV entries
    const TYPE: ExtensionType;
}
1391
/// Test-only extension that pads a mint account to exactly `Multisig::LEN`.
/// We need to pad 185 bytes, since `Multisig::LEN = 355`, `Account::LEN = 165`,
/// `size_of::<AccountType>() = 1`, `size_of::<ExtensionType>() = 2`,
/// `size_of::<Length>() = 2`.
///
/// `Account::LEN` is used even for a mint because the TLV length calculation
/// uses the same base length for every base state (see
/// `ExtensionType::try_calculate_account_len`).
///
/// ```
/// assert_eq!(355 - 165 - 1 - 2 - 2, 185);
/// ```
#[cfg(test)]
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Pod, Zeroable)]
pub struct MintPaddingTest {
    /// Largest value under 185 that implements Pod
    pub padding1: [u8; 128],
    /// Largest value under 57 (= 185 - 128) that implements Pod
    pub padding2: [u8; 48],
    /// Exact value needed to finish the padding (128 + 48 + 9 = 185)
    pub padding3: [u8; 9],
}
#[cfg(test)]
impl Extension for MintPaddingTest {
    const TYPE: ExtensionType = ExtensionType::MintPaddingTest;
}
#[cfg(test)]
impl Default for MintPaddingTest {
    /// Fill each padding segment with its own non-zero byte (1, 2 and 3).
    fn default() -> Self {
        let padding1 = [1; 128];
        let padding2 = [2; 48];
        let padding3 = [3; 9];
        Self {
            padding1,
            padding2,
            padding3,
        }
    }
}
/// Account version of `MintPaddingTest`: a test-only extension that pads a
/// token account to exactly `Multisig::LEN`
#[cfg(test)]
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
pub struct AccountPaddingTest(MintPaddingTest);
#[cfg(test)]
impl Extension for AccountPaddingTest {
    const TYPE: ExtensionType = ExtensionType::AccountPaddingTest;
}
1434
1435/// Packs a fixed-length extension into a TLV space
1436///
1437/// This function reallocates the account as needed to accommodate for the
1438/// change in space.
1439///
1440/// If the extension already exists, it will overwrite the existing extension
1441/// if `overwrite` is `true`, otherwise it will return an error.
1442///
1443/// If the extension does not exist, it will reallocate the account and write
1444/// the extension into the TLV buffer.
1445///
1446/// NOTE: Since this function deals with fixed-size extensions, it does not
1447/// handle _decreasing_ the size of an account's data buffer, like the function
1448/// `alloc_and_serialize_variable_len_extension` does.
1449pub fn alloc_and_serialize<S: BaseState + Pod, V: Default + Extension + Pod>(
1450    account_info: &AccountInfo,
1451    new_extension: &V,
1452    overwrite: bool,
1453) -> Result<(), ProgramError> {
1454    let previous_account_len = account_info.try_data_len()?;
1455    let new_account_len = {
1456        let data = account_info.try_borrow_data()?;
1457        let state = PodStateWithExtensions::<S>::unpack(&data)?;
1458        state.try_get_new_account_len::<V>()?
1459    };
1460
1461    // Realloc the account first, if needed
1462    if new_account_len > previous_account_len {
1463        account_info.resize(new_account_len)?;
1464    }
1465    let mut buffer = account_info.try_borrow_mut_data()?;
1466    if previous_account_len <= BASE_ACCOUNT_LENGTH {
1467        set_account_type::<S>(*buffer)?;
1468    }
1469    let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
1470
1471    // Write the extension
1472    let extension = state.init_extension::<V>(overwrite)?;
1473    *extension = *new_extension;
1474
1475    Ok(())
1476}
1477
/// Packs a variable-length extension into a TLV space
///
/// This function reallocates the account as needed to accommodate for the
/// change in space, then reallocates in the TLV buffer, and finally writes the
/// bytes.
///
/// NOTE: Unlike the `reallocate` instruction, this function will reduce the
/// size of an account if it has too many bytes allocated for the given value.
///
/// Returns `TokenError::ExtensionAlreadyInitialized` if the extension is
/// already present and `overwrite` is `false`.
pub fn alloc_and_serialize_variable_len_extension<
    S: BaseState + Pod,
    V: Extension + VariableLenPack,
>(
    account_info: &AccountInfo,
    new_extension: &V,
    overwrite: bool,
) -> Result<(), ProgramError> {
    let previous_account_len = account_info.try_data_len()?;
    // Read-only pass: the data borrow is dropped at the end of this block so
    // that the account can be resized afterwards.
    let (new_account_len, extension_already_exists) = {
        let data = account_info.try_borrow_data()?;
        let state = PodStateWithExtensions::<S>::unpack(&data)?;
        let new_account_len =
            state.try_get_new_account_len_for_variable_len_extension(new_extension)?;
        let extension_already_exists = state.get_extension_bytes::<V>().is_ok();
        (new_account_len, extension_already_exists)
    };

    if extension_already_exists && !overwrite {
        return Err(TokenError::ExtensionAlreadyInitialized.into());
    }

    if previous_account_len < new_account_len {
        // account size increased, so realloc the account, then the TLV entry, then
        // write data
        account_info.resize(new_account_len)?;
        let mut buffer = account_info.try_borrow_mut_data()?;
        if extension_already_exists {
            let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
            state.realloc_variable_len_extension(new_extension)?;
        } else {
            // An account that was no larger than the base length may still have
            // an unset account-type byte; `set_account_type` writes it (or
            // validates it when already set).
            if previous_account_len <= BASE_ACCOUNT_LENGTH {
                set_account_type::<S>(*buffer)?;
            }
            // now alloc in the TLV buffer and write the data
            let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
            state.init_variable_len_extension(new_extension, false)?;
        }
    } else {
        // do it backwards otherwise, write the state, realloc TLV, then the account
        let mut buffer = account_info.try_borrow_mut_data()?;
        let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
        if extension_already_exists {
            state.realloc_variable_len_extension(new_extension)?;
        } else {
            // this situation can happen if we have an overallocated buffer
            state.init_variable_len_extension(new_extension, false)?;
        }

        // In this branch previous_account_len >= new_account_len, so the
        // subtraction cannot underflow; the checked form is purely defensive.
        let removed_bytes = previous_account_len
            .checked_sub(new_account_len)
            .ok_or(ProgramError::AccountDataTooSmall)?;
        if removed_bytes > 0 {
            // this is probably fine, but be safe and avoid invalidating references
            drop(buffer);
            account_info.resize(new_account_len)?;
        }
    }
    Ok(())
}
1546
1547#[cfg(test)]
1548mod test {
1549    use {
1550        super::*,
1551        crate::{
1552            pod::test::{TEST_POD_ACCOUNT, TEST_POD_MINT},
1553            state::test::{TEST_ACCOUNT_SLICE, TEST_MINT_SLICE},
1554        },
1555        bytemuck::Pod,
1556        solana_account_info::{
1557            Account as GetAccount, IntoAccountInfo, MAX_PERMITTED_DATA_INCREASE,
1558        },
1559        solana_pubkey::Pubkey,
1560        spl_pod::{
1561            bytemuck::pod_bytes_of,
1562            optional_keys::OptionalNonZeroPubkey,
1563            primitives::{PodBool, PodU64},
1564        },
1565        transfer_fee::test::test_transfer_fee_config,
1566    };
1567
    /// Test fixed-length struct
    ///
    /// An 8-byte `Pod` extension used to exercise fixed-size extension
    /// handling. It is registered under `ExtensionType::MintPaddingTest`.
    ///
    /// NOTE(review): the `MintPaddingTest` struct used elsewhere in this module
    /// (with `padding1`/`padding2`/`padding3` fields) shares this type tag but
    /// has a different size. Presumably the two are never used on the same
    /// buffer; confirm at the call sites.
    #[repr(C)]
    #[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
    struct FixedLenMintTest {
        data: [u8; 8],
    }
    impl Extension for FixedLenMintTest {
        const TYPE: ExtensionType = ExtensionType::MintPaddingTest;
    }
1577
1578    /// Test variable-length struct
1579    #[derive(Clone, Debug, PartialEq)]
1580    struct VariableLenMintTest {
1581        data: Vec<u8>,
1582    }
1583    impl Extension for VariableLenMintTest {
1584        const TYPE: ExtensionType = ExtensionType::VariableLenMintTest;
1585    }
1586    impl VariableLenPack for VariableLenMintTest {
1587        fn pack_into_slice(&self, dst: &mut [u8]) -> Result<(), ProgramError> {
1588            let data_start = size_of::<u64>();
1589            let end = data_start + self.data.len();
1590            if dst.len() < end {
1591                Err(ProgramError::InvalidAccountData)
1592            } else {
1593                dst[..data_start].copy_from_slice(&self.data.len().to_le_bytes());
1594                dst[data_start..end].copy_from_slice(&self.data);
1595                Ok(())
1596            }
1597        }
1598        fn unpack_from_slice(src: &[u8]) -> Result<Self, ProgramError> {
1599            let data_start = size_of::<u64>();
1600            let length = u64::from_le_bytes(src[..data_start].try_into().unwrap()) as usize;
1601            if src[data_start..data_start + length].len() != length {
1602                return Err(ProgramError::InvalidAccountData);
1603            }
1604            let data = Vec::from(&src[data_start..data_start + length]);
1605            Ok(Self { data })
1606        }
1607        fn get_packed_len(&self) -> Result<usize, ProgramError> {
1608            Ok(size_of::<u64>().saturating_add(self.data.len()))
1609        }
1610    }
1611
    /// Raw mint bytes with an account-type byte but no TLV entries.
    ///
    /// Layout: the base mint (decoded as `TEST_POD_MINT` in
    /// `unpack_opaque_buffer`), zero padding up to `BASE_ACCOUNT_LENGTH`, then
    /// the account-type byte `1`.
    const MINT_WITH_ACCOUNT_TYPE: &[u8] = &[
        1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 42, 0, 0, 0, 0, 0, 0, 0, 7, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // base mint
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // padding
        1, // account type
    ];
1621
    /// Raw mint bytes with a single TLV entry.
    ///
    /// Layout: the same base mint, padding and account-type byte as
    /// `MINT_WITH_ACCOUNT_TYPE`, then one TLV entry. The entry has extension
    /// type `3` (decoded as `MintCloseAuthority` in `unpack_opaque_buffer`), a
    /// little-endian length of 32, and 32 bytes of `1`s.
    const MINT_WITH_EXTENSION: &[u8] = &[
        1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 42, 0, 0, 0, 0, 0, 0, 0, 7, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // base mint
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // padding
        1, // account type
        3, 0, // extension type
        32, 0, // length
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, // data
    ];
1635
    /// Raw token-account bytes with a single TLV entry.
    ///
    /// Layout: the base account fields (decoded as `TEST_POD_ACCOUNT` in
    /// `unpack_opaque_buffer`), the account-type byte `2`, then one TLV entry.
    /// The entry has extension type `15` (decoded as `TransferHookAccount`), a
    /// length of 1, and a single data byte `1`.
    const ACCOUNT_WITH_EXTENSION: &[u8] = &[
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, // mint
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, // owner
        3, 0, 0, 0, 0, 0, 0, 0, // amount
        1, 0, 0, 0, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, // delegate
        2, // account state
        1, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, // is native
        6, 0, 0, 0, 0, 0, 0, 0, // delegated amount
        1, 0, 0, 0, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        7, 7, 7, 7, 7, 7, // close authority
        2, // account type
        15, 0, // extension type
        1, 0, // length
        1, // data
    ];
1654
    /// Unpacking the hand-written fixtures and the `state::test` base slices
    /// yields the expected base state and extensions. A buffer is rejected
    /// when unpacked as the other base type.
    #[test]
    fn unpack_opaque_buffer() {
        // Mint
        let state = PodStateWithExtensions::<PodMint>::unpack(MINT_WITH_ACCOUNT_TYPE).unwrap();
        assert_eq!(state.base, &TEST_POD_MINT);
        let state = PodStateWithExtensions::<PodMint>::unpack(MINT_WITH_EXTENSION).unwrap();
        assert_eq!(state.base, &TEST_POD_MINT);
        let extension = state.get_extension::<MintCloseAuthority>().unwrap();
        let close_authority =
            OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([1; 32]))).unwrap();
        assert_eq!(extension.close_authority, close_authority);
        // an extension that was never written cannot be read
        assert_eq!(
            state.get_extension::<TransferFeeConfig>(),
            Err(ProgramError::InvalidAccountData)
        );
        // mint bytes are rejected when unpacked as an account
        assert_eq!(
            PodStateWithExtensions::<PodAccount>::unpack(MINT_WITH_EXTENSION),
            Err(ProgramError::UninitializedAccount)
        );

        // a bare base mint (no account-type byte, no TLV data) also unpacks,
        // through both the immutable and the mutable view
        let state = PodStateWithExtensions::<PodMint>::unpack(TEST_MINT_SLICE).unwrap();
        assert_eq!(state.base, &TEST_POD_MINT);

        let mut test_mint = TEST_MINT_SLICE.to_vec();
        let state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut test_mint).unwrap();
        assert_eq!(state.base, &TEST_POD_MINT);

        // Account
        let state = PodStateWithExtensions::<PodAccount>::unpack(ACCOUNT_WITH_EXTENSION).unwrap();
        assert_eq!(state.base, &TEST_POD_ACCOUNT);
        let extension = state.get_extension::<TransferHookAccount>().unwrap();
        let transferring = PodBool::from(true);
        assert_eq!(extension.transferring, transferring);
        // account bytes are rejected when unpacked as a mint
        assert_eq!(
            PodStateWithExtensions::<PodMint>::unpack(ACCOUNT_WITH_EXTENSION),
            Err(ProgramError::InvalidAccountData)
        );

        // a bare base account also unpacks, through both views
        let state = PodStateWithExtensions::<PodAccount>::unpack(TEST_ACCOUNT_SLICE).unwrap();
        assert_eq!(state.base, &TEST_POD_ACCOUNT);

        let mut test_account = TEST_ACCOUNT_SLICE.to_vec();
        let state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut test_account).unwrap();
        assert_eq!(state.base, &TEST_POD_ACCOUNT);
    }
1700
    /// Corrupted variants of `MINT_WITH_EXTENSION` are rejected, either at
    /// unpack time or when an extension is read.
    #[test]
    fn mint_fail_unpack_opaque_buffer() {
        // input buffer too small
        let mut buffer = vec![0, 3];
        assert_eq!(
            PodStateWithExtensions::<PodMint>::unpack(&buffer),
            Err(ProgramError::InvalidAccountData)
        );
        assert_eq!(
            PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer),
            Err(ProgramError::InvalidAccountData)
        );
        assert_eq!(
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer),
            Err(ProgramError::InvalidAccountData)
        );

        // tweak the account type
        // (the byte at `BASE_ACCOUNT_LENGTH` is the account-type byte)
        let mut buffer = MINT_WITH_EXTENSION.to_vec();
        buffer[BASE_ACCOUNT_LENGTH] = 3;
        assert_eq!(
            PodStateWithExtensions::<PodMint>::unpack(&buffer),
            Err(ProgramError::InvalidAccountData)
        );

        // clear the mint initialized byte
        let mut buffer = MINT_WITH_EXTENSION.to_vec();
        buffer[45] = 0;
        assert_eq!(
            PodStateWithExtensions::<PodMint>::unpack(&buffer),
            Err(ProgramError::UninitializedAccount)
        );

        // tweak the padding
        // (the first byte after the base mint must stay zero)
        let mut buffer = MINT_WITH_EXTENSION.to_vec();
        buffer[PodMint::SIZE_OF] = 100;
        assert_eq!(
            PodStateWithExtensions::<PodMint>::unpack(&buffer),
            Err(ProgramError::InvalidAccountData)
        );

        // NOTE(review): the three checks below query `TransferFeeConfig`, which
        // the untouched `MINT_WITH_EXTENSION` buffer already lacks (the same
        // `InvalidAccountData` is asserted in `unpack_opaque_buffer`). They may
        // therefore not prove that the tweak itself is detected. Consider also
        // asserting on `MintCloseAuthority`, the extension actually present.

        // tweak the extension type
        // (low byte of the first TLV entry's type)
        let mut buffer = MINT_WITH_EXTENSION.to_vec();
        buffer[BASE_ACCOUNT_LENGTH + 1] = 2;
        let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
        assert_eq!(
            state.get_extension::<TransferFeeConfig>(),
            Err(ProgramError::InvalidAccountData)
        );

        // tweak the length, too big
        // (low byte of the first TLV entry's length)
        let mut buffer = MINT_WITH_EXTENSION.to_vec();
        buffer[BASE_ACCOUNT_LENGTH + 3] = 100;
        let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
        assert_eq!(
            state.get_extension::<TransferFeeConfig>(),
            Err(ProgramError::InvalidAccountData)
        );

        // tweak the length, too small
        let mut buffer = MINT_WITH_EXTENSION.to_vec();
        buffer[BASE_ACCOUNT_LENGTH + 3] = 10;
        let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
        assert_eq!(
            state.get_extension::<TransferFeeConfig>(),
            Err(ProgramError::InvalidAccountData)
        );

        // data buffer is too small
        // (dropping the last byte truncates the 32-byte close-authority value)
        let buffer = &MINT_WITH_EXTENSION[..MINT_WITH_EXTENSION.len() - 1];
        let state = PodStateWithExtensions::<PodMint>::unpack(buffer).unwrap();
        assert_eq!(
            state.get_extension::<MintCloseAuthority>(),
            Err(ProgramError::InvalidAccountData)
        );
    }
1777
    /// Corrupted variants of `ACCOUNT_WITH_EXTENSION` are rejected, either at
    /// unpack time or when the `TransferHookAccount` extension is read.
    #[test]
    fn account_fail_unpack_opaque_buffer() {
        // input buffer too small
        let mut buffer = vec![0, 3];
        assert_eq!(
            PodStateWithExtensions::<PodAccount>::unpack(&buffer),
            Err(ProgramError::InvalidAccountData)
        );
        assert_eq!(
            PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer),
            Err(ProgramError::InvalidAccountData)
        );
        assert_eq!(
            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
            Err(ProgramError::InvalidAccountData)
        );

        // input buffer invalid
        // all 5's - not a valid `AccountState`
        let mut buffer = vec![5; BASE_ACCOUNT_LENGTH];
        assert_eq!(
            PodStateWithExtensions::<PodAccount>::unpack(&buffer),
            Err(ProgramError::UninitializedAccount)
        );
        assert_eq!(
            PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer),
            Err(ProgramError::UninitializedAccount)
        );

        // tweak the account type
        let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
        buffer[BASE_ACCOUNT_LENGTH] = 3;
        assert_eq!(
            PodStateWithExtensions::<PodAccount>::unpack(&buffer),
            Err(ProgramError::InvalidAccountData)
        );

        // clear the state byte
        // (offset 108 = mint 32 + owner 32 + amount 8 + delegate 36)
        let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
        buffer[108] = 0;
        assert_eq!(
            PodStateWithExtensions::<PodAccount>::unpack(&buffer),
            Err(ProgramError::UninitializedAccount)
        );

        // tweak the extension type
        // (15 -> 12, so the stored entry is no longer `TransferHookAccount`)
        let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
        buffer[BASE_ACCOUNT_LENGTH + 1] = 12;
        let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
        assert_eq!(
            state.get_extension::<TransferHookAccount>(),
            Err(ProgramError::InvalidAccountData),
        );

        // tweak the length, too big
        let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
        buffer[BASE_ACCOUNT_LENGTH + 3] = 100;
        let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
        assert_eq!(
            state.get_extension::<TransferHookAccount>(),
            Err(ProgramError::InvalidAccountData)
        );

        // tweak the length, too small
        let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
        buffer[BASE_ACCOUNT_LENGTH + 3] = 10;
        let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
        assert_eq!(
            state.get_extension::<TransferHookAccount>(),
            Err(ProgramError::InvalidAccountData)
        );

        // data buffer is too small
        // (dropping the last byte removes the extension's only data byte)
        let buffer = &ACCOUNT_WITH_EXTENSION[..ACCOUNT_WITH_EXTENSION.len() - 1];
        let state = PodStateWithExtensions::<PodAccount>::unpack(buffer).unwrap();
        assert_eq!(
            state.get_extension::<TransferHookAccount>(),
            Err(ProgramError::InvalidAccountData)
        );
    }
1858
1859    #[test]
1860    fn get_extension_types_with_opaque_buffer() {
1861        // incorrect due to the length
1862        assert_eq!(
1863            get_tlv_data_info(&[1, 0, 1, 1]).unwrap_err(),
1864            ProgramError::InvalidAccountData,
1865        );
1866        // incorrect due to the huge enum number
1867        assert_eq!(
1868            get_tlv_data_info(&[0, 1, 0, 0]).unwrap_err(),
1869            ProgramError::InvalidAccountData,
1870        );
1871        // correct due to the good enum number and zero length
1872        assert_eq!(
1873            get_tlv_data_info(&[1, 0, 0, 0]).unwrap(),
1874            TlvDataInfo {
1875                extension_types: vec![ExtensionType::try_from(1).unwrap()],
1876                used_len: add_type_and_length_to_len(0),
1877            }
1878        );
1879        // correct since it's just uninitialized data at the end
1880        assert_eq!(
1881            get_tlv_data_info(&[0, 0]).unwrap(),
1882            TlvDataInfo {
1883                extension_types: vec![],
1884                used_len: 0
1885            }
1886        );
1887    }
1888
    /// End-to-end mint round trip: extensions and base state are written into
    /// a buffer sized for `MintCloseAuthority` + `TransferFeeConfig`, and the
    /// raw bytes are checked after each step. The test also checks that
    /// account-only extensions, duplicate inits and extra extensions that do
    /// not fit are all rejected.
    #[test]
    fn mint_with_extension_pack_unpack() {
        let mint_size = ExtensionType::try_calculate_account_len::<PodMint>(&[
            ExtensionType::MintCloseAuthority,
            ExtensionType::TransferFeeConfig,
        ])
        .unwrap();
        let mut buffer = vec![0; mint_size];

        // fail unpack
        assert_eq!(
            PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer),
            Err(ProgramError::UninitializedAccount),
        );

        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        // fail init account extension
        assert_eq!(
            state.init_extension::<TransferFeeAmount>(true),
            Err(ProgramError::InvalidAccountData),
        );

        // success write extension
        let close_authority =
            OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([1; 32]))).unwrap();
        let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
        extension.close_authority = close_authority;
        assert_eq!(
            &state.get_extension_types().unwrap(),
            &[ExtensionType::MintCloseAuthority]
        );

        // fail init extension when already initialized
        assert_eq!(
            state.init_extension::<MintCloseAuthority>(false),
            Err(ProgramError::Custom(
                TokenError::ExtensionAlreadyInitialized as u32
            ))
        );

        // fail unpack as account, a mint extension was written
        assert_eq!(
            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
            Err(ProgramError::Custom(
                TokenError::ExtensionBaseMismatch as u32
            ))
        );

        // fail unpack again, still no base data
        assert_eq!(
            PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer.clone()),
            Err(ProgramError::UninitializedAccount),
        );

        // write base mint
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        *state.base = TEST_POD_MINT;
        state.init_account_type().unwrap();

        // check raw buffer
        // (base mint, zero padding, account type, the MintCloseAuthority TLV
        // entry, then an all-zero slot still reserved for TransferFeeConfig)
        let mut expect = TEST_MINT_SLICE.to_vec();
        expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); // padding
        expect.push(AccountType::Mint.into());
        expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
        expect
            .extend_from_slice(&(pod_get_packed_len::<MintCloseAuthority>() as u16).to_le_bytes());
        expect.extend_from_slice(&[1; 32]); // data
        expect.extend_from_slice(&[0; size_of::<ExtensionType>()]);
        expect.extend_from_slice(&[0; size_of::<Length>()]);
        expect.extend_from_slice(&[0; size_of::<TransferFeeConfig>()]);
        assert_eq!(expect, buffer);

        // unpack uninitialized will now fail because the PodMint is now initialized
        assert_eq!(
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer.clone()),
            Err(TokenError::AlreadyInUse.into()),
        );

        // check unpacking
        let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();

        // update base
        *state.base = TEST_POD_MINT;
        state.base.supply = (u64::from(state.base.supply) + 100).into();

        // check unpacking
        let unpacked_extension = state.get_extension_mut::<MintCloseAuthority>().unwrap();
        assert_eq!(*unpacked_extension, MintCloseAuthority { close_authority });

        // update extension
        let close_authority = OptionalNonZeroPubkey::try_from(None).unwrap();
        unpacked_extension.close_authority = close_authority;

        // check updates are propagated
        let base = *state.base;
        let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
        assert_eq!(state.base, &base);
        let unpacked_extension = state.get_extension::<MintCloseAuthority>().unwrap();
        assert_eq!(*unpacked_extension, MintCloseAuthority { close_authority });

        // check raw buffer
        // (a `None` close authority is stored as 32 zero bytes)
        let mut expect = vec![];
        expect.extend_from_slice(bytemuck::bytes_of(&base));
        expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); // padding
        expect.push(AccountType::Mint.into());
        expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
        expect
            .extend_from_slice(&(pod_get_packed_len::<MintCloseAuthority>() as u16).to_le_bytes());
        expect.extend_from_slice(&[0; 32]);
        expect.extend_from_slice(&[0; size_of::<ExtensionType>()]);
        expect.extend_from_slice(&[0; size_of::<Length>()]);
        expect.extend_from_slice(&[0; size_of::<TransferFeeConfig>()]);
        assert_eq!(expect, buffer);

        // fail unpack as an account
        assert_eq!(
            PodStateWithExtensions::<PodAccount>::unpack(&buffer),
            Err(ProgramError::UninitializedAccount),
        );

        let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
        // init one more extension
        let mint_transfer_fee = test_transfer_fee_config();
        let new_extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
        new_extension.transfer_fee_config_authority =
            mint_transfer_fee.transfer_fee_config_authority;
        new_extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
        new_extension.withheld_amount = mint_transfer_fee.withheld_amount;
        new_extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
        new_extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;

        assert_eq!(
            &state.get_extension_types().unwrap(),
            &[
                ExtensionType::MintCloseAuthority,
                ExtensionType::TransferFeeConfig
            ]
        );

        // check raw buffer
        // (the reserved slot is now the TransferFeeConfig TLV entry)
        let mut expect = vec![];
        expect.extend_from_slice(pod_bytes_of(&base));
        expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); // padding
        expect.push(AccountType::Mint.into());
        expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
        expect
            .extend_from_slice(&(pod_get_packed_len::<MintCloseAuthority>() as u16).to_le_bytes());
        expect.extend_from_slice(&[0; 32]); // data
        expect.extend_from_slice(&(ExtensionType::TransferFeeConfig as u16).to_le_bytes());
        expect.extend_from_slice(&(pod_get_packed_len::<TransferFeeConfig>() as u16).to_le_bytes());
        expect.extend_from_slice(pod_bytes_of(&mint_transfer_fee));
        assert_eq!(expect, buffer);

        // fail to init one more extension that does not fit
        // (the buffer was sized for exactly the two extensions above)
        let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
        assert_eq!(
            state.init_extension::<MintPaddingTest>(true),
            Err(ProgramError::InvalidAccountData),
        );
    }
2051
    /// Writing the same extensions in different orders gives different raw
    /// bytes but the same decoded base mint and extension values. The first
    /// buffer writes extensions before the base; the second writes the base
    /// first and then the extensions in reverse order.
    #[test]
    fn mint_extension_any_order() {
        let mint_size = ExtensionType::try_calculate_account_len::<PodMint>(&[
            ExtensionType::MintCloseAuthority,
            ExtensionType::TransferFeeConfig,
        ])
        .unwrap();
        let mut buffer = vec![0; mint_size];

        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        // write extensions
        let close_authority =
            OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([1; 32]))).unwrap();
        let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
        extension.close_authority = close_authority;

        let mint_transfer_fee = test_transfer_fee_config();
        let extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
        extension.transfer_fee_config_authority = mint_transfer_fee.transfer_fee_config_authority;
        extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
        extension.withheld_amount = mint_transfer_fee.withheld_amount;
        extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
        extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;

        assert_eq!(
            &state.get_extension_types().unwrap(),
            &[
                ExtensionType::MintCloseAuthority,
                ExtensionType::TransferFeeConfig
            ]
        );

        // write base mint
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        *state.base = TEST_POD_MINT;
        state.init_account_type().unwrap();

        let mut other_buffer = vec![0; mint_size];
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut other_buffer).unwrap();

        // write base mint
        *state.base = TEST_POD_MINT;
        state.init_account_type().unwrap();

        // write extensions in a different order
        let mint_transfer_fee = test_transfer_fee_config();
        let extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
        extension.transfer_fee_config_authority = mint_transfer_fee.transfer_fee_config_authority;
        extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
        extension.withheld_amount = mint_transfer_fee.withheld_amount;
        extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
        extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;

        let close_authority =
            OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([1; 32]))).unwrap();
        let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
        extension.close_authority = close_authority;

        // extension types are listed in the order they were written
        assert_eq!(
            &state.get_extension_types().unwrap(),
            &[
                ExtensionType::TransferFeeConfig,
                ExtensionType::MintCloseAuthority
            ]
        );

        // buffers are NOT the same because written in a different order
        assert_ne!(buffer, other_buffer);
        let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
        let other_state = PodStateWithExtensions::<PodMint>::unpack(&other_buffer).unwrap();

        // BUT mint and extensions are the same
        assert_eq!(
            state.get_extension::<TransferFeeConfig>().unwrap(),
            other_state.get_extension::<TransferFeeConfig>().unwrap()
        );
        assert_eq!(
            state.get_extension::<MintCloseAuthority>().unwrap(),
            other_state.get_extension::<MintCloseAuthority>().unwrap()
        );
        assert_eq!(state.base, other_state.base);
    }
2137
    /// A buffer of exactly `Multisig::LEN` bytes cannot be unpacked as a mint.
    /// For a mint whose extensions would otherwise land on that length, the
    /// calculated size gains one extra extension-type width, which ends up as
    /// a trailing `Uninitialized` type marker. Presumably this keeps mint
    /// lengths distinct from `Multisig::LEN`.
    #[test]
    fn mint_with_multisig_len() {
        let mut buffer = vec![0; Multisig::LEN];
        assert_eq!(
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer),
            Err(ProgramError::InvalidAccountData),
        );
        let mint_size =
            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MintPaddingTest])
                .unwrap();
        assert_eq!(mint_size, Multisig::LEN + size_of::<ExtensionType>());
        let mut buffer = vec![0; mint_size];

        // write base mint
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        *state.base = TEST_POD_MINT;
        state.init_account_type().unwrap();

        // write padding
        let extension = state.init_extension::<MintPaddingTest>(true).unwrap();
        extension.padding1 = [1; 128];
        extension.padding2 = [1; 48];
        extension.padding3 = [1; 9];

        assert_eq!(
            &state.get_extension_types().unwrap(),
            &[ExtensionType::MintPaddingTest]
        );

        // check raw buffer
        // (the extension value is all `1`s, followed by the extra
        // `Uninitialized` type bytes)
        let mut expect = TEST_MINT_SLICE.to_vec();
        expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); // padding
        expect.push(AccountType::Mint.into());
        expect.extend_from_slice(&(ExtensionType::MintPaddingTest as u16).to_le_bytes());
        expect.extend_from_slice(&(pod_get_packed_len::<MintPaddingTest>() as u16).to_le_bytes());
        expect.extend_from_slice(&vec![1; pod_get_packed_len::<MintPaddingTest>()]);
        expect.extend_from_slice(&(ExtensionType::Uninitialized as u16).to_le_bytes());
        assert_eq!(expect, buffer);
    }
2178
2179    #[test]
2180    fn account_with_extension_pack_unpack() {
2181        let account_size = ExtensionType::try_calculate_account_len::<PodAccount>(&[
2182            ExtensionType::TransferFeeAmount,
2183        ])
2184        .unwrap();
2185        let mut buffer = vec![0; account_size];
2186
2187        // fail unpack
2188        assert_eq!(
2189            PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer),
2190            Err(ProgramError::UninitializedAccount),
2191        );
2192
2193        let mut state =
2194            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
2195        // fail init mint extension
2196        assert_eq!(
2197            state.init_extension::<TransferFeeConfig>(true),
2198            Err(ProgramError::InvalidAccountData),
2199        );
2200        // success write extension
2201        let withheld_amount = PodU64::from(u64::MAX);
2202        let extension = state.init_extension::<TransferFeeAmount>(true).unwrap();
2203        extension.withheld_amount = withheld_amount;
2204
2205        assert_eq!(
2206            &state.get_extension_types().unwrap(),
2207            &[ExtensionType::TransferFeeAmount]
2208        );
2209
2210        // fail unpack again, still no base data
2211        assert_eq!(
2212            PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer.clone()),
2213            Err(ProgramError::UninitializedAccount),
2214        );
2215
2216        // write base account
2217        let mut state =
2218            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
2219        *state.base = TEST_POD_ACCOUNT;
2220        state.init_account_type().unwrap();
2221        let base = *state.base;
2222
2223        // check raw buffer
2224        let mut expect = TEST_ACCOUNT_SLICE.to_vec();
2225        expect.push(AccountType::Account.into());
2226        expect.extend_from_slice(&(ExtensionType::TransferFeeAmount as u16).to_le_bytes());
2227        expect.extend_from_slice(&(pod_get_packed_len::<TransferFeeAmount>() as u16).to_le_bytes());
2228        expect.extend_from_slice(&u64::from(withheld_amount).to_le_bytes());
2229        assert_eq!(expect, buffer);
2230
2231        // check unpacking
2232        let mut state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
2233        assert_eq!(state.base, &base);
2234        assert_eq!(
2235            &state.get_extension_types().unwrap(),
2236            &[ExtensionType::TransferFeeAmount]
2237        );
2238
2239        // update base
2240        *state.base = TEST_POD_ACCOUNT;
2241        state.base.amount = (u64::from(state.base.amount) + 100).into();
2242
2243        // check unpacking
2244        let unpacked_extension = state.get_extension_mut::<TransferFeeAmount>().unwrap();
2245        assert_eq!(*unpacked_extension, TransferFeeAmount { withheld_amount });
2246
2247        // update extension
2248        let withheld_amount = PodU64::from(u32::MAX as u64);
2249        unpacked_extension.withheld_amount = withheld_amount;
2250
2251        // check updates are propagated
2252        let base = *state.base;
2253        let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
2254        assert_eq!(state.base, &base);
2255        let unpacked_extension = state.get_extension::<TransferFeeAmount>().unwrap();
2256        assert_eq!(*unpacked_extension, TransferFeeAmount { withheld_amount });
2257
2258        // check raw buffer
2259        let mut expect = vec![];
2260        expect.extend_from_slice(pod_bytes_of(&base));
2261        expect.push(AccountType::Account.into());
2262        expect.extend_from_slice(&(ExtensionType::TransferFeeAmount as u16).to_le_bytes());
2263        expect.extend_from_slice(&(pod_get_packed_len::<TransferFeeAmount>() as u16).to_le_bytes());
2264        expect.extend_from_slice(&u64::from(withheld_amount).to_le_bytes());
2265        assert_eq!(expect, buffer);
2266
2267        // fail unpack as a mint
2268        assert_eq!(
2269            PodStateWithExtensions::<PodMint>::unpack(&buffer),
2270            Err(ProgramError::InvalidAccountData),
2271        );
2272    }
2273
2274    #[test]
2275    fn account_with_multisig_len() {
2276        let mut buffer = vec![0; Multisig::LEN];
2277        assert_eq!(
2278            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
2279            Err(ProgramError::InvalidAccountData),
2280        );
2281        let account_size = ExtensionType::try_calculate_account_len::<PodAccount>(&[
2282            ExtensionType::AccountPaddingTest,
2283        ])
2284        .unwrap();
2285        assert_eq!(account_size, Multisig::LEN + size_of::<ExtensionType>());
2286        let mut buffer = vec![0; account_size];
2287
2288        // write base account
2289        let mut state =
2290            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
2291        *state.base = TEST_POD_ACCOUNT;
2292        state.init_account_type().unwrap();
2293
2294        // write padding
2295        let extension = state.init_extension::<AccountPaddingTest>(true).unwrap();
2296        extension.0.padding1 = [2; 128];
2297        extension.0.padding2 = [2; 48];
2298        extension.0.padding3 = [2; 9];
2299
2300        assert_eq!(
2301            &state.get_extension_types().unwrap(),
2302            &[ExtensionType::AccountPaddingTest]
2303        );
2304
2305        // check raw buffer
2306        let mut expect = TEST_ACCOUNT_SLICE.to_vec();
2307        expect.push(AccountType::Account.into());
2308        expect.extend_from_slice(&(ExtensionType::AccountPaddingTest as u16).to_le_bytes());
2309        expect
2310            .extend_from_slice(&(pod_get_packed_len::<AccountPaddingTest>() as u16).to_le_bytes());
2311        expect.extend_from_slice(&vec![2; pod_get_packed_len::<AccountPaddingTest>()]);
2312        expect.extend_from_slice(&(ExtensionType::Uninitialized as u16).to_le_bytes());
2313        assert_eq!(expect, buffer);
2314    }
2315
    /// Exercises `set_account_type` on account and mint buffers. It stamps the
    /// account-type byte when there is room for it, is a no-op when the correct
    /// type is already present, and rejects a buffer carrying the other type.
    #[test]
    fn test_set_account_type() {
        // account with buffer big enough for AccountType and Extension
        let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
        let needed_len = ExtensionType::try_calculate_account_len::<PodAccount>(&[
            ExtensionType::ImmutableOwner,
        ])
        .unwrap()
            - buffer.len();
        buffer.append(&mut vec![0; needed_len]);
        // Every appended byte is zero (no account type yet): strict unpack rejects it
        let err = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);
        set_account_type::<PodAccount>(&mut buffer).unwrap();
        // unpack is viable after manual set_account_type
        let mut state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
        assert_eq!(state.base, &TEST_POD_ACCOUNT);
        assert_eq!(state.account_type[0], AccountType::Account as u8);
        state.init_extension::<ImmutableOwner>(true).unwrap(); // just confirming initialization works

        // account with buffer big enough for AccountType only
        let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
        buffer.append(&mut vec![0; 2]);
        let err = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);
        set_account_type::<PodAccount>(&mut buffer).unwrap();
        // unpack is viable after manual set_account_type
        let state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
        assert_eq!(state.base, &TEST_POD_ACCOUNT);
        assert_eq!(state.account_type[0], AccountType::Account as u8);

        // account with AccountType already set => noop
        // (type byte preset to 2, i.e. `AccountType::Account` per the assertion below)
        let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
        buffer.append(&mut vec![2, 0]);
        let _ = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
        set_account_type::<PodAccount>(&mut buffer).unwrap();
        let state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
        assert_eq!(state.base, &TEST_POD_ACCOUNT);
        assert_eq!(state.account_type[0], AccountType::Account as u8);

        // account with wrong AccountType fails
        // (type byte preset to 1, which the mint no-op case below shows is `AccountType::Mint`)
        let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
        buffer.append(&mut vec![1, 0]);
        let err = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);
        let err = set_account_type::<PodAccount>(&mut buffer).unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);

        // mint with buffer big enough for AccountType and Extension
        let mut buffer = TEST_MINT_SLICE.to_vec();
        let needed_len = ExtensionType::try_calculate_account_len::<PodMint>(&[
            ExtensionType::MintCloseAuthority,
        ])
        .unwrap()
            - buffer.len();
        buffer.append(&mut vec![0; needed_len]);
        let err = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);
        set_account_type::<PodMint>(&mut buffer).unwrap();
        // unpack is viable after manual set_account_type
        let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
        assert_eq!(state.base, &TEST_POD_MINT);
        assert_eq!(state.account_type[0], AccountType::Mint as u8);
        state.init_extension::<MintCloseAuthority>(true).unwrap();

        // mint with buffer big enough for AccountType only
        // (the mint is zero-padded up to `PodAccount::SIZE_OF` before the type bytes)
        let mut buffer = TEST_MINT_SLICE.to_vec();
        buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
        buffer.append(&mut vec![0; 2]);
        let err = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);
        set_account_type::<PodMint>(&mut buffer).unwrap();
        // unpack is viable after manual set_account_type
        let state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
        assert_eq!(state.base, &TEST_POD_MINT);
        assert_eq!(state.account_type[0], AccountType::Mint as u8);

        // mint with AccountType already set => noop
        let mut buffer = TEST_MINT_SLICE.to_vec();
        buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
        buffer.append(&mut vec![1, 0]);
        set_account_type::<PodMint>(&mut buffer).unwrap();
        let state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
        assert_eq!(state.base, &TEST_POD_MINT);
        assert_eq!(state.account_type[0], AccountType::Mint as u8);

        // mint with wrong AccountType fails
        // (type byte preset to 2, the account discriminant)
        let mut buffer = TEST_MINT_SLICE.to_vec();
        buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
        buffer.append(&mut vec![2, 0]);
        let err = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);
        let err = set_account_type::<PodMint>(&mut buffer).unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);
    }
2410
2411    #[test]
2412    fn test_set_account_type_wrongly() {
2413        // try to set PodAccount account_type to PodMint
2414        let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
2415        buffer.append(&mut vec![0; 2]);
2416        let err = set_account_type::<PodMint>(&mut buffer).unwrap_err();
2417        assert_eq!(err, ProgramError::InvalidAccountData);
2418
2419        // try to set PodMint account_type to PodAccount
2420        let mut buffer = TEST_MINT_SLICE.to_vec();
2421        buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
2422        buffer.append(&mut vec![0; 2]);
2423        let err = set_account_type::<PodAccount>(&mut buffer).unwrap_err();
2424        assert_eq!(err, ProgramError::InvalidAccountData);
2425    }
2426
2427    #[test]
2428    fn test_get_required_init_account_extensions() {
2429        // Some mint extensions with no required account extensions
2430        let mint_extensions = vec![
2431            ExtensionType::MintCloseAuthority,
2432            ExtensionType::Uninitialized,
2433        ];
2434        assert_eq!(
2435            ExtensionType::get_required_init_account_extensions(&mint_extensions),
2436            vec![]
2437        );
2438
2439        // One mint extension with required account extension, one without
2440        let mint_extensions = vec![
2441            ExtensionType::TransferFeeConfig,
2442            ExtensionType::MintCloseAuthority,
2443        ];
2444        assert_eq!(
2445            ExtensionType::get_required_init_account_extensions(&mint_extensions),
2446            vec![ExtensionType::TransferFeeAmount]
2447        );
2448
2449        // Some mint extensions both with required account extensions
2450        let mint_extensions = vec![
2451            ExtensionType::TransferFeeConfig,
2452            ExtensionType::MintPaddingTest,
2453        ];
2454        assert_eq!(
2455            ExtensionType::get_required_init_account_extensions(&mint_extensions),
2456            vec![
2457                ExtensionType::TransferFeeAmount,
2458                ExtensionType::AccountPaddingTest
2459            ]
2460        );
2461
2462        // Demonstrate that method does not dedupe inputs or outputs
2463        let mint_extensions = vec![
2464            ExtensionType::TransferFeeConfig,
2465            ExtensionType::TransferFeeConfig,
2466        ];
2467        assert_eq!(
2468            ExtensionType::get_required_init_account_extensions(&mint_extensions),
2469            vec![
2470                ExtensionType::TransferFeeAmount,
2471                ExtensionType::TransferFeeAmount
2472            ]
2473        );
2474    }
2475
2476    #[test]
2477    fn mint_without_extensions() {
2478        let space = ExtensionType::try_calculate_account_len::<PodMint>(&[]).unwrap();
2479        let mut buffer = vec![0; space];
2480        assert_eq!(
2481            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
2482            Err(ProgramError::InvalidAccountData),
2483        );
2484
2485        // write base account
2486        let mut state =
2487            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2488        *state.base = TEST_POD_MINT;
2489        state.init_account_type().unwrap();
2490
2491        // fail init extension
2492        assert_eq!(
2493            state.init_extension::<TransferFeeConfig>(true),
2494            Err(ProgramError::InvalidAccountData),
2495        );
2496
2497        assert_eq!(TEST_MINT_SLICE, buffer);
2498    }
2499
2500    #[test]
2501    fn test_init_nonzero_default() {
2502        let mint_size =
2503            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MintPaddingTest])
2504                .unwrap();
2505        let mut buffer = vec![0; mint_size];
2506        let mut state =
2507            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2508        *state.base = TEST_POD_MINT;
2509        state.init_account_type().unwrap();
2510        let extension = state.init_extension::<MintPaddingTest>(true).unwrap();
2511        assert_eq!(extension.padding1, [1; 128]);
2512        assert_eq!(extension.padding2, [2; 48]);
2513        assert_eq!(extension.padding3, [3; 9]);
2514    }
2515
    /// Checks that TLV initialization and lookup return errors when the buffer is
    /// too small for the entry involved, and that near-minimal buffers behave
    /// as expected.
    #[test]
    fn test_init_buffer_too_small() {
        // One byte short of what a mint with `MintCloseAuthority` requires
        let mint_size = ExtensionType::try_calculate_account_len::<PodMint>(&[
            ExtensionType::MintCloseAuthority,
        ])
        .unwrap();
        let mut buffer = vec![0; mint_size - 1];
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        let err = state
            .init_extension::<MintCloseAuthority>(true)
            .unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);

        // Hand-write a TLV header: type in bytes [0..2], little-endian length in
        // bytes [2..4]. `3` is presumably the `MintCloseAuthority` discriminant
        // (TODO confirm), and the declared length is 32. The buffer is one byte
        // short of a full entry, so reading it back must fail.
        state.tlv_data[0] = 3;
        state.tlv_data[2] = 32;
        let err = state.get_extension_mut::<MintCloseAuthority>().unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);

        // A buffer only 2 bytes longer than a bare mint is rejected outright
        let mut buffer = vec![0; PodMint::SIZE_OF + 2];
        let err =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);

        // OK since there are two bytes for the type, which is `Uninitialized`
        // (layout after the base: 1 account-type byte + 2 zeroed type bytes)
        let mut buffer = vec![0; BASE_ACCOUNT_LENGTH + 3];
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        let err = state.get_extension_mut::<MintCloseAuthority>().unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);

        assert_eq!(state.get_extension_types().unwrap(), vec![]);

        // OK, there aren't two bytes for the type, but that's fine
        // (layout after the base: 1 account-type byte + a single trailing byte)
        let mut buffer = vec![0; BASE_ACCOUNT_LENGTH + 2];
        let state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        assert_eq!(state.get_extension_types().unwrap(), []);
    }
2555
2556    #[test]
2557    fn test_extension_with_no_data() {
2558        let account_size = ExtensionType::try_calculate_account_len::<PodAccount>(&[
2559            ExtensionType::ImmutableOwner,
2560        ])
2561        .unwrap();
2562        let mut buffer = vec![0; account_size];
2563        let mut state =
2564            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
2565        *state.base = TEST_POD_ACCOUNT;
2566        state.init_account_type().unwrap();
2567
2568        let err = state.get_extension::<ImmutableOwner>().unwrap_err();
2569        assert_eq!(
2570            err,
2571            ProgramError::Custom(TokenError::ExtensionNotFound as u32)
2572        );
2573
2574        state.init_extension::<ImmutableOwner>(true).unwrap();
2575        assert_eq!(
2576            get_first_extension_type(state.tlv_data).unwrap(),
2577            Some(ExtensionType::ImmutableOwner)
2578        );
2579        assert_eq!(
2580            get_tlv_data_info(state.tlv_data).unwrap(),
2581            TlvDataInfo {
2582                extension_types: vec![ExtensionType::ImmutableOwner],
2583                used_len: add_type_and_length_to_len(0)
2584            }
2585        );
2586    }
2587
2588    #[test]
2589    fn fail_account_len_with_metadata() {
2590        assert_eq!(
2591            ExtensionType::try_calculate_account_len::<PodMint>(&[
2592                ExtensionType::MintCloseAuthority,
2593                ExtensionType::VariableLenMintTest,
2594                ExtensionType::TransferFeeConfig,
2595            ])
2596            .unwrap_err(),
2597            ProgramError::InvalidArgument
2598        );
2599    }
2600
2601    #[test]
2602    fn alloc() {
2603        let variable_len = VariableLenMintTest { data: vec![1] };
2604        let alloc_size = variable_len.get_packed_len().unwrap();
2605        let account_size =
2606            BASE_ACCOUNT_LENGTH + size_of::<AccountType>() + add_type_and_length_to_len(alloc_size);
2607        let mut buffer = vec![0; account_size];
2608        let mut state =
2609            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2610        state
2611            .init_variable_len_extension(&variable_len, false)
2612            .unwrap();
2613
2614        // can't double alloc
2615        assert_eq!(
2616            state
2617                .init_variable_len_extension(&variable_len, false)
2618                .unwrap_err(),
2619            TokenError::ExtensionAlreadyInitialized.into()
2620        );
2621
2622        // unless overwrite is set
2623        state
2624            .init_variable_len_extension(&variable_len, true)
2625            .unwrap();
2626
2627        // can't change the size during overwrite though
2628        assert_eq!(
2629            state
2630                .init_variable_len_extension(&VariableLenMintTest { data: vec![] }, true)
2631                .unwrap_err(),
2632            TokenError::InvalidLengthForAlloc.into()
2633        );
2634
2635        // try to write too far, fail earlier
2636        assert_eq!(
2637            state
2638                .init_variable_len_extension(&VariableLenMintTest { data: vec![1, 2] }, true)
2639                .unwrap_err(),
2640            ProgramError::InvalidAccountData
2641        );
2642    }
2643
2644    #[test]
2645    fn realloc() {
2646        let small_variable_len = VariableLenMintTest {
2647            data: vec![1, 2, 3],
2648        };
2649        let base_variable_len = VariableLenMintTest {
2650            data: vec![1, 2, 3, 4],
2651        };
2652        let big_variable_len = VariableLenMintTest {
2653            data: vec![1, 2, 3, 4, 5],
2654        };
2655        let too_big_variable_len = VariableLenMintTest {
2656            data: vec![1, 2, 3, 4, 5, 6],
2657        };
2658        let account_size =
2659            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MetadataPointer])
2660                .unwrap()
2661                + add_type_and_length_to_len(big_variable_len.get_packed_len().unwrap());
2662        let mut buffer = vec![0; account_size];
2663        let mut state =
2664            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2665
2666        // alloc both types
2667        state
2668            .init_variable_len_extension(&base_variable_len, false)
2669            .unwrap();
2670        let max_pubkey =
2671            OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([255; 32]))).unwrap();
2672        let extension = state.init_extension::<MetadataPointer>(false).unwrap();
2673        extension.authority = max_pubkey;
2674        extension.metadata_address = max_pubkey;
2675
2676        // realloc first entry to larger
2677        state
2678            .realloc_variable_len_extension(&big_variable_len)
2679            .unwrap();
2680        let extension = state
2681            .get_variable_len_extension::<VariableLenMintTest>()
2682            .unwrap();
2683        assert_eq!(extension, big_variable_len);
2684        let extension = state.get_extension::<MetadataPointer>().unwrap();
2685        assert_eq!(extension.authority, max_pubkey);
2686        assert_eq!(extension.metadata_address, max_pubkey);
2687
2688        // realloc to smaller
2689        state
2690            .realloc_variable_len_extension(&small_variable_len)
2691            .unwrap();
2692        let extension = state
2693            .get_variable_len_extension::<VariableLenMintTest>()
2694            .unwrap();
2695        assert_eq!(extension, small_variable_len);
2696        let extension = state.get_extension::<MetadataPointer>().unwrap();
2697        assert_eq!(extension.authority, max_pubkey);
2698        assert_eq!(extension.metadata_address, max_pubkey);
2699        let diff = big_variable_len.get_packed_len().unwrap()
2700            - small_variable_len.get_packed_len().unwrap();
2701        assert_eq!(&buffer[account_size - diff..account_size], vec![0; diff]);
2702
2703        // unpack again since we dropped the last `state`
2704        let mut state =
2705            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2706        // realloc too much, fails
2707        assert_eq!(
2708            state
2709                .realloc_variable_len_extension(&too_big_variable_len)
2710                .unwrap_err(),
2711            ProgramError::InvalidAccountData,
2712        );
2713    }
2714
2715    #[test]
2716    fn account_len() {
2717        let small_variable_len = VariableLenMintTest {
2718            data: vec![20, 30, 40],
2719        };
2720        let variable_len = VariableLenMintTest {
2721            data: vec![20, 30, 40, 50],
2722        };
2723        let big_variable_len = VariableLenMintTest {
2724            data: vec![20, 30, 40, 50, 60],
2725        };
2726        let value_len = variable_len.get_packed_len().unwrap();
2727        let account_size =
2728            BASE_ACCOUNT_LENGTH + size_of::<AccountType>() + add_type_and_length_to_len(value_len);
2729        let mut buffer = vec![0; account_size];
2730        let mut state =
2731            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2732
2733        // With a new extension, new length must include padding, 1 byte for
2734        // account type, 2 bytes for type, 2 for length
2735        let current_len = state.try_get_account_len().unwrap();
2736        assert_eq!(current_len, PodMint::SIZE_OF);
2737        let new_len = state
2738            .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(
2739                &variable_len,
2740            )
2741            .unwrap();
2742        assert_eq!(
2743            new_len,
2744            BASE_ACCOUNT_AND_TYPE_LENGTH.saturating_add(add_type_and_length_to_len(value_len))
2745        );
2746
2747        state
2748            .init_variable_len_extension::<VariableLenMintTest>(&variable_len, false)
2749            .unwrap();
2750        let current_len = state.try_get_account_len().unwrap();
2751        assert_eq!(current_len, new_len);
2752
2753        // Reduce the extension size
2754        let new_len = state
2755            .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(
2756                &small_variable_len,
2757            )
2758            .unwrap();
2759        assert_eq!(current_len.checked_sub(new_len).unwrap(), 1);
2760
2761        // Increase the extension size
2762        let new_len = state
2763            .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(
2764                &big_variable_len,
2765            )
2766            .unwrap();
2767        assert_eq!(new_len.checked_sub(current_len).unwrap(), 1);
2768
2769        // Maintain the extension size
2770        let new_len = state
2771            .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(
2772                &variable_len,
2773            )
2774            .unwrap();
2775        assert_eq!(new_len, current_len);
2776    }
2777
2778    /// Test helper for mimicking the data layout an on-chain `AccountInfo`,
2779    /// which permits "reallocs" as the Solana runtime does it
2780    struct SolanaAccountData {
2781        data: Vec<u8>,
2782        lamports: u64,
2783        owner: Pubkey,
2784    }
2785    impl SolanaAccountData {
2786        /// Create a new fake solana account data. The underlying vector is
2787        /// overallocated to mimic the runtime
2788        fn new(account_data: &[u8]) -> Self {
2789            let mut data = vec![];
2790            data.extend_from_slice(&(account_data.len() as u64).to_le_bytes());
2791            data.extend_from_slice(account_data);
2792            data.extend_from_slice(&[0; MAX_PERMITTED_DATA_INCREASE]);
2793            Self {
2794                data,
2795                lamports: 10,
2796                owner: Pubkey::new_unique(),
2797            }
2798        }
2799
2800        /// Data lops off the first 8 bytes, since those store the size of the
2801        /// account for the Solana runtime
2802        fn data(&self) -> &[u8] {
2803            let start = size_of::<u64>();
2804            let len = self.len();
2805            &self.data[start..start + len]
2806        }
2807
2808        /// Gets the runtime length of the account data
2809        fn len(&self) -> usize {
2810            self.data
2811                .get(..size_of::<u64>())
2812                .and_then(|slice| slice.try_into().ok())
2813                .map(u64::from_le_bytes)
2814                .unwrap() as usize
2815        }
2816    }
2817    impl GetAccount for SolanaAccountData {
2818        fn get(&mut self) -> (&mut u64, &mut [u8], &Pubkey, bool) {
2819            // need to pull out the data here to avoid a double-mutable borrow
2820            let start = size_of::<u64>();
2821            let len = self.len();
2822            (
2823                &mut self.lamports,
2824                &mut self.data[start..start + len],
2825                &self.owner,
2826                false,
2827            )
2828        }
2829    }
2830
2831    #[test]
2832    fn alloc_new_fixed_len_tlv_in_account_info_from_base_size() {
2833        let fixed_len = FixedLenMintTest {
2834            data: [1, 2, 3, 4, 5, 6, 7, 8],
2835        };
2836        let value_len = pod_get_packed_len::<FixedLenMintTest>();
2837        let base_account_size = PodMint::SIZE_OF;
2838        let mut buffer = vec![0; base_account_size];
2839        let state =
2840            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2841        *state.base = TEST_POD_MINT;
2842
2843        let mut data = SolanaAccountData::new(&buffer);
2844        let key = Pubkey::new_unique();
2845        let account_info = (&key, &mut data).into_account_info();
2846
2847        alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, false).unwrap();
2848        let new_account_len = BASE_ACCOUNT_AND_TYPE_LENGTH + add_type_and_length_to_len(value_len);
2849        assert_eq!(data.len(), new_account_len);
2850        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
2851        assert_eq!(
2852            state.get_extension::<FixedLenMintTest>().unwrap(),
2853            &fixed_len,
2854        );
2855
2856        // alloc again succeeds with "overwrite"
2857        let account_info = (&key, &mut data).into_account_info();
2858        alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, true).unwrap();
2859
2860        // alloc again fails without "overwrite"
2861        let account_info = (&key, &mut data).into_account_info();
2862        assert_eq!(
2863            alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, false).unwrap_err(),
2864            TokenError::ExtensionAlreadyInitialized.into()
2865        );
2866    }
2867
2868    #[test]
2869    fn alloc_new_variable_len_tlv_in_account_info_from_base_size() {
2870        let variable_len = VariableLenMintTest { data: vec![20, 99] };
2871        let value_len = variable_len.get_packed_len().unwrap();
2872        let base_account_size = PodMint::SIZE_OF;
2873        let mut buffer = vec![0; base_account_size];
2874        let state =
2875            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2876        *state.base = TEST_POD_MINT;
2877
2878        let mut data = SolanaAccountData::new(&buffer);
2879        let key = Pubkey::new_unique();
2880        let account_info = (&key, &mut data).into_account_info();
2881
2882        alloc_and_serialize_variable_len_extension::<PodMint, _>(
2883            &account_info,
2884            &variable_len,
2885            false,
2886        )
2887        .unwrap();
2888        let new_account_len = BASE_ACCOUNT_AND_TYPE_LENGTH + add_type_and_length_to_len(value_len);
2889        assert_eq!(data.len(), new_account_len);
2890        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
2891        assert_eq!(
2892            state
2893                .get_variable_len_extension::<VariableLenMintTest>()
2894                .unwrap(),
2895            variable_len
2896        );
2897
2898        // alloc again succeeds with "overwrite"
2899        let account_info = (&key, &mut data).into_account_info();
2900        alloc_and_serialize_variable_len_extension::<PodMint, _>(
2901            &account_info,
2902            &variable_len,
2903            true,
2904        )
2905        .unwrap();
2906
2907        // alloc again fails without "overwrite"
2908        let account_info = (&key, &mut data).into_account_info();
2909        assert_eq!(
2910            alloc_and_serialize_variable_len_extension::<PodMint, _>(
2911                &account_info,
2912                &variable_len,
2913                false,
2914            )
2915            .unwrap_err(),
2916            TokenError::ExtensionAlreadyInitialized.into()
2917        );
2918    }
2919
2920    #[test]
2921    fn alloc_new_fixed_len_tlv_in_account_info_from_extended_size() {
2922        let fixed_len = FixedLenMintTest {
2923            data: [1, 2, 3, 4, 5, 6, 7, 8],
2924        };
2925        let value_len = pod_get_packed_len::<FixedLenMintTest>();
2926        let account_size =
2927            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::GroupPointer])
2928                .unwrap()
2929                + add_type_and_length_to_len(value_len);
2930        let mut buffer = vec![0; account_size];
2931        let mut state =
2932            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2933        *state.base = TEST_POD_MINT;
2934        state.init_account_type().unwrap();
2935
2936        let test_key =
2937            OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([20; 32]))).unwrap();
2938        let extension = state.init_extension::<GroupPointer>(false).unwrap();
2939        extension.authority = test_key;
2940        extension.group_address = test_key;
2941
2942        let mut data = SolanaAccountData::new(&buffer);
2943        let key = Pubkey::new_unique();
2944        let account_info = (&key, &mut data).into_account_info();
2945
2946        alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, false).unwrap();
2947        let new_account_len = BASE_ACCOUNT_AND_TYPE_LENGTH
2948            + add_type_and_length_to_len(value_len)
2949            + add_type_and_length_to_len(size_of::<GroupPointer>());
2950        assert_eq!(data.len(), new_account_len);
2951        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
2952        assert_eq!(
2953            state.get_extension::<FixedLenMintTest>().unwrap(),
2954            &fixed_len,
2955        );
2956        let extension = state.get_extension::<GroupPointer>().unwrap();
2957        assert_eq!(extension.authority, test_key);
2958        assert_eq!(extension.group_address, test_key);
2959
2960        // alloc again succeeds with "overwrite"
2961        let account_info = (&key, &mut data).into_account_info();
2962        alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, true).unwrap();
2963
2964        // alloc again fails without "overwrite"
2965        let account_info = (&key, &mut data).into_account_info();
2966        assert_eq!(
2967            alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, false).unwrap_err(),
2968            TokenError::ExtensionAlreadyInitialized.into()
2969        );
2970    }
2971
2972    #[test]
2973    fn alloc_new_variable_len_tlv_in_account_info_from_extended_size() {
2974        let variable_len = VariableLenMintTest { data: vec![42, 6] };
2975        let value_len = variable_len.get_packed_len().unwrap();
2976        let account_size =
2977            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MetadataPointer])
2978                .unwrap()
2979                + add_type_and_length_to_len(value_len);
2980        let mut buffer = vec![0; account_size];
2981        let mut state =
2982            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2983        *state.base = TEST_POD_MINT;
2984        state.init_account_type().unwrap();
2985
2986        let test_key =
2987            OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([20; 32]))).unwrap();
2988        let extension = state.init_extension::<MetadataPointer>(false).unwrap();
2989        extension.authority = test_key;
2990        extension.metadata_address = test_key;
2991
2992        let mut data = SolanaAccountData::new(&buffer);
2993        let key = Pubkey::new_unique();
2994        let account_info = (&key, &mut data).into_account_info();
2995
2996        alloc_and_serialize_variable_len_extension::<PodMint, _>(
2997            &account_info,
2998            &variable_len,
2999            false,
3000        )
3001        .unwrap();
3002        let new_account_len = BASE_ACCOUNT_AND_TYPE_LENGTH
3003            + add_type_and_length_to_len(value_len)
3004            + add_type_and_length_to_len(size_of::<MetadataPointer>());
3005        assert_eq!(data.len(), new_account_len);
3006        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
3007        assert_eq!(
3008            state
3009                .get_variable_len_extension::<VariableLenMintTest>()
3010                .unwrap(),
3011            variable_len
3012        );
3013        let extension = state.get_extension::<MetadataPointer>().unwrap();
3014        assert_eq!(extension.authority, test_key);
3015        assert_eq!(extension.metadata_address, test_key);
3016
3017        // alloc again succeeds with "overwrite"
3018        let account_info = (&key, &mut data).into_account_info();
3019        alloc_and_serialize_variable_len_extension::<PodMint, _>(
3020            &account_info,
3021            &variable_len,
3022            true,
3023        )
3024        .unwrap();
3025
3026        // alloc again fails without "overwrite"
3027        let account_info = (&key, &mut data).into_account_info();
3028        assert_eq!(
3029            alloc_and_serialize_variable_len_extension::<PodMint, _>(
3030                &account_info,
3031                &variable_len,
3032                false,
3033            )
3034            .unwrap_err(),
3035            TokenError::ExtensionAlreadyInitialized.into()
3036        );
3037    }
3038
3039    #[test]
3040    fn realloc_variable_len_tlv_in_account_info() {
3041        let variable_len = VariableLenMintTest {
3042            data: vec![1, 2, 3, 4, 5],
3043        };
3044        let alloc_size = variable_len.get_packed_len().unwrap();
3045        let account_size =
3046            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MetadataPointer])
3047                .unwrap()
3048                + add_type_and_length_to_len(alloc_size);
3049        let mut buffer = vec![0; account_size];
3050        let mut state =
3051            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
3052        *state.base = TEST_POD_MINT;
3053        state.init_account_type().unwrap();
3054
3055        // alloc both types
3056        state
3057            .init_variable_len_extension(&variable_len, false)
3058            .unwrap();
3059        let max_pubkey =
3060            OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([255; 32]))).unwrap();
3061        let extension = state.init_extension::<MetadataPointer>(false).unwrap();
3062        extension.authority = max_pubkey;
3063        extension.metadata_address = max_pubkey;
3064
3065        // reallocate to smaller, make sure existing extension is fine
3066        let mut data = SolanaAccountData::new(&buffer);
3067        let key = Pubkey::new_unique();
3068        let account_info = (&key, &mut data).into_account_info();
3069        let variable_len = VariableLenMintTest { data: vec![1, 2] };
3070        alloc_and_serialize_variable_len_extension::<PodMint, _>(
3071            &account_info,
3072            &variable_len,
3073            true,
3074        )
3075        .unwrap();
3076
3077        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
3078        let extension = state.get_extension::<MetadataPointer>().unwrap();
3079        assert_eq!(extension.authority, max_pubkey);
3080        assert_eq!(extension.metadata_address, max_pubkey);
3081        let extension = state
3082            .get_variable_len_extension::<VariableLenMintTest>()
3083            .unwrap();
3084        assert_eq!(extension, variable_len);
3085        assert_eq!(data.len(), state.try_get_account_len().unwrap());
3086
3087        // reallocate to larger
3088        let account_info = (&key, &mut data).into_account_info();
3089        let variable_len = VariableLenMintTest {
3090            data: vec![1, 2, 3, 4, 5, 6, 7],
3091        };
3092        alloc_and_serialize_variable_len_extension::<PodMint, _>(
3093            &account_info,
3094            &variable_len,
3095            true,
3096        )
3097        .unwrap();
3098
3099        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
3100        let extension = state.get_extension::<MetadataPointer>().unwrap();
3101        assert_eq!(extension.authority, max_pubkey);
3102        assert_eq!(extension.metadata_address, max_pubkey);
3103        let extension = state
3104            .get_variable_len_extension::<VariableLenMintTest>()
3105            .unwrap();
3106        assert_eq!(extension, variable_len);
3107        assert_eq!(data.len(), state.try_get_account_len().unwrap());
3108
3109        // reallocate to same
3110        let account_info = (&key, &mut data).into_account_info();
3111        let variable_len = VariableLenMintTest {
3112            data: vec![7, 6, 5, 4, 3, 2, 1],
3113        };
3114        alloc_and_serialize_variable_len_extension::<PodMint, _>(
3115            &account_info,
3116            &variable_len,
3117            true,
3118        )
3119        .unwrap();
3120
3121        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
3122        let extension = state.get_extension::<MetadataPointer>().unwrap();
3123        assert_eq!(extension.authority, max_pubkey);
3124        assert_eq!(extension.metadata_address, max_pubkey);
3125        let extension = state
3126            .get_variable_len_extension::<VariableLenMintTest>()
3127            .unwrap();
3128        assert_eq!(extension, variable_len);
3129        assert_eq!(data.len(), state.try_get_account_len().unwrap());
3130    }
3131}