spl_token_2022/extension/
mod.rs

1//! Extensions available to token mints and accounts
2
3#[cfg(feature = "serde-traits")]
4use serde::{Deserialize, Serialize};
5use {
6    crate::{
7        error::TokenError,
8        extension::{
9            confidential_mint_burn::ConfidentialMintBurn,
10            confidential_transfer::{ConfidentialTransferAccount, ConfidentialTransferMint},
11            confidential_transfer_fee::{
12                ConfidentialTransferFeeAmount, ConfidentialTransferFeeConfig,
13            },
14            cpi_guard::CpiGuard,
15            default_account_state::DefaultAccountState,
16            group_member_pointer::GroupMemberPointer,
17            group_pointer::GroupPointer,
18            immutable_owner::ImmutableOwner,
19            interest_bearing_mint::InterestBearingConfig,
20            memo_transfer::MemoTransfer,
21            metadata_pointer::MetadataPointer,
22            mint_close_authority::MintCloseAuthority,
23            non_transferable::{NonTransferable, NonTransferableAccount},
24            pausable::{PausableAccount, PausableConfig},
25            permanent_delegate::PermanentDelegate,
26            scaled_ui_amount::ScaledUiAmountConfig,
27            transfer_fee::{TransferFeeAmount, TransferFeeConfig},
28            transfer_hook::{TransferHook, TransferHookAccount},
29        },
30        pod::{PodAccount, PodMint},
31        state::{Account, Mint, Multisig, PackedSizeOf},
32    },
33    bytemuck::{Pod, Zeroable},
34    num_enum::{IntoPrimitive, TryFromPrimitive},
35    solana_account_info::AccountInfo,
36    solana_program_error::ProgramError,
37    solana_program_pack::{IsInitialized, Pack},
38    spl_pod::{
39        bytemuck::{pod_from_bytes, pod_from_bytes_mut, pod_get_packed_len},
40        primitives::PodU16,
41    },
42    spl_token_group_interface::state::{TokenGroup, TokenGroupMember},
43    spl_type_length_value::variable_len_pack::VariableLenPack,
44    std::{
45        cmp::Ordering,
46        convert::{TryFrom, TryInto},
47        mem::size_of,
48    },
49};
50
51/// Confidential Transfer extension
52pub mod confidential_transfer;
53/// Confidential Transfer Fee extension
54pub mod confidential_transfer_fee;
55/// CPI Guard extension
56pub mod cpi_guard;
57/// Default Account State extension
58pub mod default_account_state;
59/// Group Member Pointer extension
60pub mod group_member_pointer;
61/// Group Pointer extension
62pub mod group_pointer;
63/// Immutable Owner extension
64pub mod immutable_owner;
65/// Interest-Bearing Mint extension
66pub mod interest_bearing_mint;
67/// Memo Transfer extension
68pub mod memo_transfer;
69/// Metadata Pointer extension
70pub mod metadata_pointer;
71/// Mint Close Authority extension
72pub mod mint_close_authority;
73/// Non Transferable extension
74pub mod non_transferable;
75/// Pausable extension
76pub mod pausable;
77/// Permanent Delegate extension
78pub mod permanent_delegate;
79/// Utility to reallocate token accounts
80pub mod reallocate;
81/// Scaled UI Amount extension
82pub mod scaled_ui_amount;
83/// Token-group extension
84pub mod token_group;
85/// Token-metadata extension
86pub mod token_metadata;
87/// Transfer Fee extension
88pub mod transfer_fee;
89/// Transfer Hook extension
90pub mod transfer_hook;
91
92/// Confidential mint-burn extension
93pub mod confidential_mint_burn;
94
95/// Length in TLV structure
96#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
97#[repr(transparent)]
98pub struct Length(PodU16);
99impl From<Length> for usize {
100    fn from(n: Length) -> Self {
101        Self::from(u16::from(n.0))
102    }
103}
104impl TryFrom<usize> for Length {
105    type Error = ProgramError;
106    fn try_from(n: usize) -> Result<Self, Self::Error> {
107        u16::try_from(n)
108            .map(|v| Self(PodU16::from(v)))
109            .map_err(|_| ProgramError::AccountDataTooSmall)
110    }
111}
112
113/// Helper function to get the current `TlvIndices` from the current spot
114fn get_tlv_indices(type_start: usize) -> TlvIndices {
115    let length_start = type_start.saturating_add(size_of::<ExtensionType>());
116    let value_start = length_start.saturating_add(pod_get_packed_len::<Length>());
117    TlvIndices {
118        type_start,
119        length_start,
120        value_start,
121    }
122}
123
124/// Helper function to tack on the size of an extension bytes if an account with
125/// extensions is exactly the size of a multisig
126const fn adjust_len_for_multisig(account_len: usize) -> usize {
127    if account_len == Multisig::LEN {
128        account_len.saturating_add(size_of::<ExtensionType>())
129    } else {
130        account_len
131    }
132}
133
134/// Helper function to calculate exactly how many bytes a value will take up,
135/// given the value's length
136const fn add_type_and_length_to_len(value_len: usize) -> usize {
137    value_len
138        .saturating_add(size_of::<ExtensionType>())
139        .saturating_add(pod_get_packed_len::<Length>())
140}
141
/// Byte offsets of the three fields of a single TLV entry, measured from the
/// start of the TLV buffer.
#[derive(Debug)]
struct TlvIndices {
    /// Offset of the `ExtensionType` field
    pub type_start: usize,
    /// Offset of the `Length` field
    pub length_start: usize,
    /// Offset of the first value byte
    pub value_start: usize,
}
150fn get_extension_indices<V: Extension>(
151    tlv_data: &[u8],
152    init: bool,
153) -> Result<TlvIndices, ProgramError> {
154    let mut start_index = 0;
155    let v_account_type = V::TYPE.get_account_type();
156    while start_index < tlv_data.len() {
157        let tlv_indices = get_tlv_indices(start_index);
158        if tlv_data.len() < tlv_indices.value_start {
159            return Err(ProgramError::InvalidAccountData);
160        }
161        let extension_type =
162            ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
163        let account_type = extension_type.get_account_type();
164        if extension_type == V::TYPE {
165            // found an instance of the extension that we're initializing, return!
166            return Ok(tlv_indices);
167        // got to an empty spot, init here, or error if we're searching, since
168        // nothing is written after an Uninitialized spot
169        } else if extension_type == ExtensionType::Uninitialized {
170            if init {
171                return Ok(tlv_indices);
172            } else {
173                return Err(TokenError::ExtensionNotFound.into());
174            }
175        } else if v_account_type != account_type {
176            return Err(TokenError::ExtensionTypeMismatch.into());
177        } else {
178            let length = pod_from_bytes::<Length>(
179                &tlv_data[tlv_indices.length_start..tlv_indices.value_start],
180            )?;
181            let value_end_index = tlv_indices.value_start.saturating_add(usize::from(*length));
182            start_index = value_end_index;
183        }
184    }
185    Err(ProgramError::InvalidAccountData)
186}
187
188/// Basic information about the TLV buffer, collected from iterating through all
189/// entries
190#[derive(Debug, PartialEq)]
191struct TlvDataInfo {
192    /// The extension types written in the TLV buffer
193    extension_types: Vec<ExtensionType>,
194    /// The total number bytes allocated for all TLV entries.
195    ///
196    /// Each TLV entry's allocated bytes comprises two bytes for the `type`, two
197    /// bytes for the `length`, and `length` number of bytes for the `value`.
198    used_len: usize,
199}
200
201/// Fetches basic information about the TLV buffer by iterating through all
202/// TLV entries.
203fn get_tlv_data_info(tlv_data: &[u8]) -> Result<TlvDataInfo, ProgramError> {
204    let mut extension_types = vec![];
205    let mut start_index = 0;
206    while start_index < tlv_data.len() {
207        let tlv_indices = get_tlv_indices(start_index);
208        if tlv_data.len() < tlv_indices.length_start {
209            // There aren't enough bytes to store the next type, which means we
210            // got to the end. The last byte could be used during a realloc!
211            return Ok(TlvDataInfo {
212                extension_types,
213                used_len: tlv_indices.type_start,
214            });
215        }
216        let extension_type =
217            ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
218        if extension_type == ExtensionType::Uninitialized {
219            return Ok(TlvDataInfo {
220                extension_types,
221                used_len: tlv_indices.type_start,
222            });
223        } else {
224            if tlv_data.len() < tlv_indices.value_start {
225                // not enough bytes to store the length, malformed
226                return Err(ProgramError::InvalidAccountData);
227            }
228            extension_types.push(extension_type);
229            let length = pod_from_bytes::<Length>(
230                &tlv_data[tlv_indices.length_start..tlv_indices.value_start],
231            )?;
232
233            let value_end_index = tlv_indices.value_start.saturating_add(usize::from(*length));
234            if value_end_index > tlv_data.len() {
235                // value blows past the size of the slice, malformed
236                return Err(ProgramError::InvalidAccountData);
237            }
238            start_index = value_end_index;
239        }
240    }
241    Ok(TlvDataInfo {
242        extension_types,
243        used_len: start_index,
244    })
245}
246
247fn get_first_extension_type(tlv_data: &[u8]) -> Result<Option<ExtensionType>, ProgramError> {
248    if tlv_data.is_empty() {
249        Ok(None)
250    } else {
251        let tlv_indices = get_tlv_indices(0);
252        if tlv_data.len() <= tlv_indices.length_start {
253            return Ok(None);
254        }
255        let extension_type =
256            ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
257        if extension_type == ExtensionType::Uninitialized {
258            Ok(None)
259        } else {
260            Ok(Some(extension_type))
261        }
262    }
263}
264
265fn check_min_len_and_not_multisig(input: &[u8], minimum_len: usize) -> Result<(), ProgramError> {
266    if input.len() == Multisig::LEN || input.len() < minimum_len {
267        Err(ProgramError::InvalidAccountData)
268    } else {
269        Ok(())
270    }
271}
272
273fn check_account_type<S: BaseState>(account_type: AccountType) -> Result<(), ProgramError> {
274    if account_type != S::ACCOUNT_TYPE {
275        Err(ProgramError::InvalidAccountData)
276    } else {
277        Ok(())
278    }
279}
280
281/// Any account with extensions must be at least `Account::LEN`.  Both mints and
282/// accounts can have extensions
283/// A mint with extensions that takes it past 165 could be indiscernible from an
284/// Account with an extension, even if we add the account type. For example,
285/// let's say we have:
286///
287/// ```text
288/// Account: 165 bytes... + [2, 0, 3, 0, 100, ....]
289///                          ^     ^       ^     ^
290///                     acct type  extension length data...
291///
292/// Mint: 82 bytes... + 83 bytes of other extension data
293///     + [2, 0, 3, 0, 100, ....]
294///      (data in extension just happens to look like this)
295/// ```
296///
297/// With this approach, we only start writing the TLV data after `Account::LEN`,
298/// which means we always know that the account type is going to be right after
299/// that. We do a special case checking for a Multisig length, because those
300/// aren't extensible under any circumstances.
301const BASE_ACCOUNT_LENGTH: usize = Account::LEN;
302/// Helper that tacks on the `AccountType` length, which gives the minimum for
303/// any account with extensions
304const BASE_ACCOUNT_AND_TYPE_LENGTH: usize = BASE_ACCOUNT_LENGTH + size_of::<AccountType>();
305
306fn type_and_tlv_indices<S: BaseState>(
307    rest_input: &[u8],
308) -> Result<Option<(usize, usize)>, ProgramError> {
309    if rest_input.is_empty() {
310        Ok(None)
311    } else {
312        let account_type_index = BASE_ACCOUNT_LENGTH.saturating_sub(S::SIZE_OF);
313        // check padding is all zeroes
314        let tlv_start_index = account_type_index.saturating_add(size_of::<AccountType>());
315        if rest_input.len() <= tlv_start_index {
316            return Err(ProgramError::InvalidAccountData);
317        }
318        if rest_input[..account_type_index] != vec![0; account_type_index] {
319            Err(ProgramError::InvalidAccountData)
320        } else {
321            Ok(Some((account_type_index, tlv_start_index)))
322        }
323    }
324}
325
326/// Checks a base buffer to verify if it is an Account without having to
327/// completely deserialize it
328fn is_initialized_account(input: &[u8]) -> Result<bool, ProgramError> {
329    const ACCOUNT_INITIALIZED_INDEX: usize = 108; // See state.rs#L99
330
331    if input.len() != BASE_ACCOUNT_LENGTH {
332        return Err(ProgramError::InvalidAccountData);
333    }
334    Ok(input[ACCOUNT_INITIALIZED_INDEX] != 0)
335}
336
337fn get_extension_bytes<S: BaseState, V: Extension>(tlv_data: &[u8]) -> Result<&[u8], ProgramError> {
338    if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
339        return Err(ProgramError::InvalidAccountData);
340    }
341    let TlvIndices {
342        type_start: _,
343        length_start,
344        value_start,
345    } = get_extension_indices::<V>(tlv_data, false)?;
346    // get_extension_indices has checked that tlv_data is long enough to include
347    // these indices
348    let length = pod_from_bytes::<Length>(&tlv_data[length_start..value_start])?;
349    let value_end = value_start.saturating_add(usize::from(*length));
350    if tlv_data.len() < value_end {
351        return Err(ProgramError::InvalidAccountData);
352    }
353    Ok(&tlv_data[value_start..value_end])
354}
355
356fn get_extension_bytes_mut<S: BaseState, V: Extension>(
357    tlv_data: &mut [u8],
358) -> Result<&mut [u8], ProgramError> {
359    if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
360        return Err(ProgramError::InvalidAccountData);
361    }
362    let TlvIndices {
363        type_start: _,
364        length_start,
365        value_start,
366    } = get_extension_indices::<V>(tlv_data, false)?;
367    // get_extension_indices has checked that tlv_data is long enough to include
368    // these indices
369    let length = pod_from_bytes::<Length>(&tlv_data[length_start..value_start])?;
370    let value_end = value_start.saturating_add(usize::from(*length));
371    if tlv_data.len() < value_end {
372        return Err(ProgramError::InvalidAccountData);
373    }
374    Ok(&mut tlv_data[value_start..value_end])
375}
376
377/// Calculate the new expected size if the state allocates the given number
378/// of bytes for the given extension type.
379///
380/// Provides the correct answer regardless if the extension is already present
381/// in the TLV data.
382fn try_get_new_account_len_for_extension_len<S: BaseState, V: Extension>(
383    tlv_data: &[u8],
384    new_extension_len: usize,
385) -> Result<usize, ProgramError> {
386    // get the new length used by the extension
387    let new_extension_tlv_len = add_type_and_length_to_len(new_extension_len);
388    let tlv_info = get_tlv_data_info(tlv_data)?;
389    // If we're adding an extension, then we must have at least BASE_ACCOUNT_LENGTH
390    // and account type
391    let current_len = tlv_info
392        .used_len
393        .saturating_add(BASE_ACCOUNT_AND_TYPE_LENGTH);
394    // get the current length used by the extension
395    let current_extension_len = get_extension_bytes::<S, V>(tlv_data)
396        .map(|x| add_type_and_length_to_len(x.len()))
397        .unwrap_or(0);
398    let new_len = current_len
399        .saturating_sub(current_extension_len)
400        .saturating_add(new_extension_tlv_len);
401    Ok(adjust_len_for_multisig(new_len))
402}
403
404/// Trait for base state with extension
405pub trait BaseStateWithExtensions<S: BaseState> {
406    /// Get the buffer containing all extension data
407    fn get_tlv_data(&self) -> &[u8];
408
409    /// Fetch the bytes for a TLV entry
410    fn get_extension_bytes<V: Extension>(&self) -> Result<&[u8], ProgramError> {
411        get_extension_bytes::<S, V>(self.get_tlv_data())
412    }
413
414    /// Unpack a portion of the TLV data as the desired type
415    fn get_extension<V: Extension + Pod>(&self) -> Result<&V, ProgramError> {
416        pod_from_bytes::<V>(self.get_extension_bytes::<V>()?)
417    }
418
419    /// Unpacks a portion of the TLV data as the desired variable-length type
420    fn get_variable_len_extension<V: Extension + VariableLenPack>(
421        &self,
422    ) -> Result<V, ProgramError> {
423        let data = get_extension_bytes::<S, V>(self.get_tlv_data())?;
424        V::unpack_from_slice(data)
425    }
426
427    /// Iterates through the TLV entries, returning only the types
428    fn get_extension_types(&self) -> Result<Vec<ExtensionType>, ProgramError> {
429        get_tlv_data_info(self.get_tlv_data()).map(|x| x.extension_types)
430    }
431
432    /// Get just the first extension type, useful to track mixed initialization
433    fn get_first_extension_type(&self) -> Result<Option<ExtensionType>, ProgramError> {
434        get_first_extension_type(self.get_tlv_data())
435    }
436
437    /// Get the total number of bytes used by TLV entries and the base type
438    fn try_get_account_len(&self) -> Result<usize, ProgramError> {
439        let tlv_info = get_tlv_data_info(self.get_tlv_data())?;
440        if tlv_info.extension_types.is_empty() {
441            Ok(S::SIZE_OF)
442        } else {
443            let total_len = tlv_info
444                .used_len
445                .saturating_add(BASE_ACCOUNT_AND_TYPE_LENGTH);
446            Ok(adjust_len_for_multisig(total_len))
447        }
448    }
449    /// Calculate the new expected size if the state allocates the given
450    /// fixed-length extension instance.
451    /// If the state already has the extension, the resulting account length
452    /// will be unchanged.
453    fn try_get_new_account_len<V: Extension + Pod>(&self) -> Result<usize, ProgramError> {
454        try_get_new_account_len_for_extension_len::<S, V>(
455            self.get_tlv_data(),
456            pod_get_packed_len::<V>(),
457        )
458    }
459
460    /// Calculate the new expected size if the state allocates the given
461    /// variable-length extension instance.
462    fn try_get_new_account_len_for_variable_len_extension<V: Extension + VariableLenPack>(
463        &self,
464        new_extension: &V,
465    ) -> Result<usize, ProgramError> {
466        try_get_new_account_len_for_extension_len::<S, V>(
467            self.get_tlv_data(),
468            new_extension.get_packed_len()?,
469        )
470    }
471}
472
473/// Encapsulates owned immutable base state data (mint or account) with possible
474/// extensions
475#[derive(Clone, Debug, PartialEq)]
476pub struct StateWithExtensionsOwned<S: BaseState> {
477    /// Unpacked base data
478    pub base: S,
479    /// Raw TLV data, deserialized on demand
480    tlv_data: Vec<u8>,
481}
482impl<S: BaseState + Pack> StateWithExtensionsOwned<S> {
483    /// Unpack base state, leaving the extension data as a slice
484    ///
485    /// Fails if the base state is not initialized.
486    pub fn unpack(mut input: Vec<u8>) -> Result<Self, ProgramError> {
487        check_min_len_and_not_multisig(&input, S::SIZE_OF)?;
488        let mut rest = input.split_off(S::SIZE_OF);
489        let base = S::unpack(&input)?;
490        if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(&rest)? {
491            // type_and_tlv_indices() checks that returned indexes are within range
492            let account_type = AccountType::try_from(rest[account_type_index])
493                .map_err(|_| ProgramError::InvalidAccountData)?;
494            check_account_type::<S>(account_type)?;
495            let tlv_data = rest.split_off(tlv_start_index);
496            Ok(Self { base, tlv_data })
497        } else {
498            Ok(Self {
499                base,
500                tlv_data: vec![],
501            })
502        }
503    }
504}
505
506impl<S: BaseState> BaseStateWithExtensions<S> for StateWithExtensionsOwned<S> {
507    fn get_tlv_data(&self) -> &[u8] {
508        &self.tlv_data
509    }
510}
511
512/// Encapsulates immutable base state data (mint or account) with possible
513/// extensions
514#[derive(Debug, PartialEq)]
515pub struct StateWithExtensions<'data, S: BaseState + Pack> {
516    /// Unpacked base data
517    pub base: S,
518    /// Slice of data containing all TLV data, deserialized on demand
519    tlv_data: &'data [u8],
520}
521impl<'data, S: BaseState + Pack> StateWithExtensions<'data, S> {
522    /// Unpack base state, leaving the extension data as a slice
523    ///
524    /// Fails if the base state is not initialized.
525    pub fn unpack(input: &'data [u8]) -> Result<Self, ProgramError> {
526        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
527        let (base_data, rest) = input.split_at(S::SIZE_OF);
528        let base = S::unpack(base_data)?;
529        let tlv_data = unpack_tlv_data::<S>(rest)?;
530        Ok(Self { base, tlv_data })
531    }
532}
533impl<S: BaseState + Pack> BaseStateWithExtensions<S> for StateWithExtensions<'_, S> {
534    fn get_tlv_data(&self) -> &[u8] {
535        self.tlv_data
536    }
537}
538
539/// Encapsulates immutable base state data (mint or account) with possible
540/// extensions, where the base state is Pod for zero-copy serde.
541#[derive(Debug, PartialEq)]
542pub struct PodStateWithExtensions<'data, S: BaseState + Pod> {
543    /// Unpacked base data
544    pub base: &'data S,
545    /// Slice of data containing all TLV data, deserialized on demand
546    tlv_data: &'data [u8],
547}
548impl<'data, S: BaseState + Pod> PodStateWithExtensions<'data, S> {
549    /// Unpack base state, leaving the extension data as a slice
550    ///
551    /// Fails if the base state is not initialized.
552    pub fn unpack(input: &'data [u8]) -> Result<Self, ProgramError> {
553        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
554        let (base_data, rest) = input.split_at(S::SIZE_OF);
555        let base = pod_from_bytes::<S>(base_data)?;
556        if !base.is_initialized() {
557            Err(ProgramError::UninitializedAccount)
558        } else {
559            let tlv_data = unpack_tlv_data::<S>(rest)?;
560            Ok(Self { base, tlv_data })
561        }
562    }
563}
564impl<S: BaseState + Pod> BaseStateWithExtensions<S> for PodStateWithExtensions<'_, S> {
565    fn get_tlv_data(&self) -> &[u8] {
566        self.tlv_data
567    }
568}
569
570/// Trait for mutable base state with extension
571pub trait BaseStateWithExtensionsMut<S: BaseState>: BaseStateWithExtensions<S> {
    /// Returns the TLV extension buffer, mutably.
    fn get_tlv_data_mut(&mut self) -> &mut [u8];

    /// Returns the slice holding the account-type marker, mutably.
    ///
    /// This slice may be empty when the buffer has no room for extensions;
    /// `init_account_type` treats that case as a no-op.
    fn get_account_type_mut(&mut self) -> &mut [u8];
577
578    /// Unpack a portion of the TLV data as the base mutable bytes
579    fn get_extension_bytes_mut<V: Extension>(&mut self) -> Result<&mut [u8], ProgramError> {
580        get_extension_bytes_mut::<S, V>(self.get_tlv_data_mut())
581    }
582
583    /// Unpack a portion of the TLV data as the desired type that allows
584    /// modifying the type
585    fn get_extension_mut<V: Extension + Pod>(&mut self) -> Result<&mut V, ProgramError> {
586        pod_from_bytes_mut::<V>(self.get_extension_bytes_mut::<V>()?)
587    }
588
589    /// Packs a variable-length extension into its appropriate data segment.
590    /// Fails if space hasn't already been allocated for the given extension
591    fn pack_variable_len_extension<V: Extension + VariableLenPack>(
592        &mut self,
593        extension: &V,
594    ) -> Result<(), ProgramError> {
595        let data = self.get_extension_bytes_mut::<V>()?;
596        // NOTE: Do *not* use `pack`, since the length check will cause
597        // reallocations to smaller sizes to fail
598        extension.pack_into_slice(data)
599    }
600
601    /// Packs the default extension data into an open slot if not already found
602    /// in the data buffer. If extension is already found in the buffer, it
603    /// overwrites the existing extension with the default state if
604    /// `overwrite` is set. If extension found, but `overwrite` is not set,
605    /// it returns error.
606    fn init_extension<V: Extension + Pod + Default>(
607        &mut self,
608        overwrite: bool,
609    ) -> Result<&mut V, ProgramError> {
610        let length = pod_get_packed_len::<V>();
611        let buffer = self.alloc::<V>(length, overwrite)?;
612        let extension_ref = pod_from_bytes_mut::<V>(buffer)?;
613        *extension_ref = V::default();
614        Ok(extension_ref)
615    }
616
617    /// Reallocate and overwrite the TLV entry for the given variable-length
618    /// extension.
619    ///
620    /// Returns an error if the extension is not present, or if there is not
621    /// enough space in the buffer.
622    fn realloc_variable_len_extension<V: Extension + VariableLenPack>(
623        &mut self,
624        new_extension: &V,
625    ) -> Result<(), ProgramError> {
626        let data = self.realloc::<V>(new_extension.get_packed_len()?)?;
627        new_extension.pack_into_slice(data)
628    }
629
    /// Reallocate the TLV entry for the given extension to the given number of
    /// bytes.
    ///
    /// If the new length is smaller, it will compact the rest of the buffer and
    /// zero out the difference at the end. If it's larger, it will move the
    /// rest of the buffer data and zero out the new data.
    ///
    /// On success, returns the entry's value bytes at their new size.
    ///
    /// Returns an error if the extension is not present, or if there is not
    /// enough space in the buffer. Specifically:
    /// - errors from `get_extension_indices` / `get_tlv_data_info`, e.g.
    ///   `TokenError::ExtensionNotFound` when `V` has no entry;
    /// - `ProgramError::InvalidAccountData` when growing the entry would push
    ///   the used TLV bytes past the end of the buffer;
    /// - `ProgramError::AccountDataTooSmall` when `length` does not fit in a
    ///   `Length`.
    ///
    /// None of these error paths modifies the buffer.
    fn realloc<V: Extension + VariableLenPack>(
        &mut self,
        length: usize,
    ) -> Result<&mut [u8], ProgramError> {
        let tlv_data = self.get_tlv_data_mut();
        let TlvIndices {
            type_start: _,
            length_start,
            value_start,
        } = get_extension_indices::<V>(tlv_data, false)?;
        // bytes currently occupied by all TLV entries; everything after the
        // resized value up to here must be shifted
        let tlv_len = get_tlv_data_info(tlv_data).map(|x| x.used_len)?;
        let data_len = tlv_data.len();

        let length_ref = pod_from_bytes_mut::<Length>(&mut tlv_data[length_start..value_start])?;
        let old_length = usize::from(*length_ref);

        // Length check to avoid a panic later in `copy_within`
        if old_length < length {
            let new_tlv_len = tlv_len.saturating_add(length.saturating_sub(old_length));
            if new_tlv_len > data_len {
                return Err(ProgramError::InvalidAccountData);
            }
        }

        // write new length after the check, to avoid getting into a bad situation
        // if trying to recover from an error
        *length_ref = Length::try_from(length)?;

        let old_value_end = value_start.saturating_add(old_length);
        let new_value_end = value_start.saturating_add(length);
        // slide every later entry so it starts right after the resized value
        tlv_data.copy_within(old_value_end..tlv_len, new_value_end);
        match old_length.cmp(&length) {
            Ordering::Greater => {
                // realloc to smaller, zero out the end
                // (the stale tail left behind by the shift)
                let new_tlv_len = tlv_len.saturating_sub(old_length.saturating_sub(length));
                tlv_data[new_tlv_len..tlv_len].fill(0);
            }
            Ordering::Less => {
                // realloc to bigger, zero out the new bytes
                // (they still hold the start of the shifted data's old copy)
                tlv_data[old_value_end..new_value_end].fill(0);
            }
            Ordering::Equal => {} // nothing needed!
        }

        Ok(&mut tlv_data[value_start..new_value_end])
    }
685
686    /// Allocate the given number of bytes for the given variable-length
687    /// extension and write its contents into the TLV buffer.
688    ///
689    /// This can only be used for variable-sized types, such as `String` or
690    /// `Vec`. `Pod` types must use `init_extension`
691    fn init_variable_len_extension<V: Extension + VariableLenPack>(
692        &mut self,
693        extension: &V,
694        overwrite: bool,
695    ) -> Result<(), ProgramError> {
696        let data = self.alloc::<V>(extension.get_packed_len()?, overwrite)?;
697        extension.pack_into_slice(data)
698    }
699
    /// Allocate some space for the extension in the TLV data.
    ///
    /// Finds either the existing entry for `V` or the first `Uninitialized`
    /// slot. It writes `V::TYPE` and `length` into that slot's header and
    /// returns the `length` value bytes that follow. The value bytes are not
    /// cleared here; `init_extension` and `init_variable_len_extension`
    /// overwrite them after calling this.
    ///
    /// Errors:
    /// - `ProgramError::InvalidAccountData` if `V` is not an extension for
    ///   `S`'s account type, or if fewer than `length` value bytes plus the
    ///   type/length header remain from the slot's start;
    /// - `TokenError::ExtensionAlreadyInitialized` if `V` already has an entry
    ///   and `overwrite` is false;
    /// - `TokenError::InvalidLengthForAlloc` if overwriting an existing `V`
    ///   entry whose length differs from `length` (use `realloc` for that);
    /// - `ProgramError::AccountDataTooSmall` if `length` does not fit in a
    ///   `Length`;
    /// - any error from `get_extension_indices`.
    fn alloc<V: Extension>(
        &mut self,
        length: usize,
        overwrite: bool,
    ) -> Result<&mut [u8], ProgramError> {
        if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
            return Err(ProgramError::InvalidAccountData);
        }
        let tlv_data = self.get_tlv_data_mut();
        // with `init = true`, this is either the existing `V` entry or the
        // first free slot
        let TlvIndices {
            type_start,
            length_start,
            value_start,
        } = get_extension_indices::<V>(tlv_data, true)?;

        // the header plus `length` value bytes must fit from the slot onwards
        if tlv_data[type_start..].len() < add_type_and_length_to_len(length) {
            return Err(ProgramError::InvalidAccountData);
        }
        let extension_type = ExtensionType::try_from(&tlv_data[type_start..length_start])?;

        if extension_type == ExtensionType::Uninitialized || overwrite {
            // write extension type
            let extension_type_array: [u8; 2] = V::TYPE.into();
            let extension_type_ref = &mut tlv_data[type_start..length_start];
            extension_type_ref.copy_from_slice(&extension_type_array);
            // write length
            let length_ref =
                pod_from_bytes_mut::<Length>(&mut tlv_data[length_start..value_start])?;

            // check that the length is the same if we're doing an alloc
            // with overwrite, otherwise a realloc should be done
            // (the type was already written above, but on this error path the
            // slot already held `V::TYPE`, so that write changed nothing)
            if overwrite && extension_type == V::TYPE && usize::from(*length_ref) != length {
                return Err(TokenError::InvalidLengthForAlloc.into());
            }

            *length_ref = Length::try_from(length)?;

            let value_end = value_start.saturating_add(length);
            Ok(&mut tlv_data[value_start..value_end])
        } else {
            // extension is already initialized, but no overwrite permission
            Err(TokenError::ExtensionAlreadyInitialized.into())
        }
    }
745
746    /// If `extension_type` is an Account-associated `ExtensionType` that
747    /// requires initialization on `InitializeAccount`, this method packs
748    /// the default relevant `Extension` of an `ExtensionType` into an open
749    /// slot if not already found in the data buffer, otherwise overwrites
750    /// the existing extension with the default state. For all other
751    /// `ExtensionType`s, this is a no-op.
752    fn init_account_extension_from_type(
753        &mut self,
754        extension_type: ExtensionType,
755    ) -> Result<(), ProgramError> {
756        if extension_type.get_account_type() != AccountType::Account {
757            return Ok(());
758        }
759        match extension_type {
760            ExtensionType::TransferFeeAmount => {
761                self.init_extension::<TransferFeeAmount>(true).map(|_| ())
762            }
763            ExtensionType::ImmutableOwner => {
764                self.init_extension::<ImmutableOwner>(true).map(|_| ())
765            }
766            ExtensionType::NonTransferableAccount => self
767                .init_extension::<NonTransferableAccount>(true)
768                .map(|_| ()),
769            ExtensionType::TransferHookAccount => {
770                self.init_extension::<TransferHookAccount>(true).map(|_| ())
771            }
772            // ConfidentialTransfers are currently opt-in only, so this is a no-op for extra safety
773            // on InitializeAccount
774            ExtensionType::ConfidentialTransferAccount => Ok(()),
775            ExtensionType::PausableAccount => {
776                self.init_extension::<PausableAccount>(true).map(|_| ())
777            }
778            #[cfg(test)]
779            ExtensionType::AccountPaddingTest => {
780                self.init_extension::<AccountPaddingTest>(true).map(|_| ())
781            }
782            _ => unreachable!(),
783        }
784    }
785
786    /// Write the account type into the buffer, done during the base
787    /// state initialization
788    /// Noops if there is no room for an extension in the account, needed for
789    /// pure base mints / accounts.
790    fn init_account_type(&mut self) -> Result<(), ProgramError> {
791        let first_extension_type = self.get_first_extension_type()?;
792        let account_type = self.get_account_type_mut();
793        if !account_type.is_empty() {
794            if let Some(extension_type) = first_extension_type {
795                let account_type = extension_type.get_account_type();
796                if account_type != S::ACCOUNT_TYPE {
797                    return Err(TokenError::ExtensionBaseMismatch.into());
798                }
799            }
800            account_type[0] = S::ACCOUNT_TYPE.into();
801        }
802        Ok(())
803    }
804
805    /// Check that the account type on the account (if initialized) matches the
806    /// account type for any extensions initialized on the TLV data
807    fn check_account_type_matches_extension_type(&self) -> Result<(), ProgramError> {
808        if let Some(extension_type) = self.get_first_extension_type()? {
809            let account_type = extension_type.get_account_type();
810            if account_type != S::ACCOUNT_TYPE {
811                return Err(TokenError::ExtensionBaseMismatch.into());
812            }
813        }
814        Ok(())
815    }
816}
817
/// Encapsulates mutable base state data (mint or account) with possible
/// extensions
///
/// `base` is an unpacked copy of the base state. Changes made to it are only
/// written back into `base_data` by `pack_base`.
#[derive(Debug, PartialEq)]
pub struct StateWithExtensionsMut<'data, S: BaseState> {
    /// Unpacked base data
    pub base: S,
    /// Raw base data (the first `S::SIZE_OF` bytes of the unpacked buffer)
    base_data: &'data mut [u8],
    /// Writable account type (may be empty if no account-type slot was found
    /// while unpacking)
    account_type: &'data mut [u8],
    /// Slice of data containing all TLV data, deserialized on demand
    tlv_data: &'data mut [u8],
}
831impl<'data, S: BaseState + Pack> StateWithExtensionsMut<'data, S> {
832    /// Unpack base state, leaving the extension data as a mutable slice
833    ///
834    /// Fails if the base state is not initialized.
835    pub fn unpack(input: &'data mut [u8]) -> Result<Self, ProgramError> {
836        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
837        let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
838        let base = S::unpack(base_data)?;
839        let (account_type, tlv_data) = unpack_type_and_tlv_data_mut::<S>(rest)?;
840        Ok(Self {
841            base,
842            base_data,
843            account_type,
844            tlv_data,
845        })
846    }
847
848    /// Unpack an uninitialized base state, leaving the extension data as a
849    /// mutable slice
850    ///
851    /// Fails if the base state has already been initialized.
852    pub fn unpack_uninitialized(input: &'data mut [u8]) -> Result<Self, ProgramError> {
853        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
854        let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
855        let base = S::unpack_unchecked(base_data)?;
856        if base.is_initialized() {
857            return Err(TokenError::AlreadyInUse.into());
858        }
859        let (account_type, tlv_data) = unpack_uninitialized_type_and_tlv_data_mut::<S>(rest)?;
860        let state = Self {
861            base,
862            base_data,
863            account_type,
864            tlv_data,
865        };
866        state.check_account_type_matches_extension_type()?;
867        Ok(state)
868    }
869
870    /// Packs base state data into the base data portion
871    pub fn pack_base(&mut self) {
872        S::pack_into_slice(&self.base, self.base_data);
873    }
874}
875impl<S: BaseState> BaseStateWithExtensions<S> for StateWithExtensionsMut<'_, S> {
876    fn get_tlv_data(&self) -> &[u8] {
877        self.tlv_data
878    }
879}
880impl<S: BaseState> BaseStateWithExtensionsMut<S> for StateWithExtensionsMut<'_, S> {
881    fn get_tlv_data_mut(&mut self) -> &mut [u8] {
882        self.tlv_data
883    }
884    fn get_account_type_mut(&mut self) -> &mut [u8] {
885        self.account_type
886    }
887}
888
/// Encapsulates mutable base state data (mint or account) with possible
/// extensions, where the base state is Pod for zero-copy serde.
///
/// `base` borrows the first `S::SIZE_OF` bytes of the buffer in place, so
/// writes through it change the underlying account data directly.
#[derive(Debug, PartialEq)]
pub struct PodStateWithExtensionsMut<'data, S: BaseState> {
    /// Unpacked base data, borrowed in place from the input buffer
    pub base: &'data mut S,
    /// Writable account type (may be empty if no account-type slot was found
    /// while unpacking)
    account_type: &'data mut [u8],
    /// Slice of data containing all TLV data, deserialized on demand
    tlv_data: &'data mut [u8],
}
900impl<'data, S: BaseState + Pod> PodStateWithExtensionsMut<'data, S> {
901    /// Unpack base state, leaving the extension data as a mutable slice
902    ///
903    /// Fails if the base state is not initialized.
904    pub fn unpack(input: &'data mut [u8]) -> Result<Self, ProgramError> {
905        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
906        let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
907        let base = pod_from_bytes_mut::<S>(base_data)?;
908        if !base.is_initialized() {
909            Err(ProgramError::UninitializedAccount)
910        } else {
911            let (account_type, tlv_data) = unpack_type_and_tlv_data_mut::<S>(rest)?;
912            Ok(Self {
913                base,
914                account_type,
915                tlv_data,
916            })
917        }
918    }
919
920    /// Unpack an uninitialized base state, leaving the extension data as a
921    /// mutable slice
922    ///
923    /// Fails if the base state has already been initialized.
924    pub fn unpack_uninitialized(input: &'data mut [u8]) -> Result<Self, ProgramError> {
925        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
926        let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
927        let base = pod_from_bytes_mut::<S>(base_data)?;
928        if base.is_initialized() {
929            return Err(TokenError::AlreadyInUse.into());
930        }
931        let (account_type, tlv_data) = unpack_uninitialized_type_and_tlv_data_mut::<S>(rest)?;
932        let state = Self {
933            base,
934            account_type,
935            tlv_data,
936        };
937        state.check_account_type_matches_extension_type()?;
938        Ok(state)
939    }
940}
941
942impl<S: BaseState> BaseStateWithExtensions<S> for PodStateWithExtensionsMut<'_, S> {
943    fn get_tlv_data(&self) -> &[u8] {
944        self.tlv_data
945    }
946}
947impl<S: BaseState> BaseStateWithExtensionsMut<S> for PodStateWithExtensionsMut<'_, S> {
948    fn get_tlv_data_mut(&mut self) -> &mut [u8] {
949        self.tlv_data
950    }
951    fn get_account_type_mut(&mut self) -> &mut [u8] {
952        self.account_type
953    }
954}
955
956fn unpack_tlv_data<S: BaseState>(rest: &[u8]) -> Result<&[u8], ProgramError> {
957    if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
958        // type_and_tlv_indices() checks that returned indexes are within range
959        let account_type = AccountType::try_from(rest[account_type_index])
960            .map_err(|_| ProgramError::InvalidAccountData)?;
961        check_account_type::<S>(account_type)?;
962        Ok(&rest[tlv_start_index..])
963    } else {
964        Ok(&[])
965    }
966}
967
968fn unpack_type_and_tlv_data_with_check_mut<
969    S: BaseState,
970    F: Fn(AccountType) -> Result<(), ProgramError>,
971>(
972    rest: &mut [u8],
973    check_fn: F,
974) -> Result<(&mut [u8], &mut [u8]), ProgramError> {
975    if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
976        // type_and_tlv_indices() checks that returned indexes are within range
977        let account_type = AccountType::try_from(rest[account_type_index])
978            .map_err(|_| ProgramError::InvalidAccountData)?;
979        check_fn(account_type)?;
980        let (account_type, tlv_data) = rest.split_at_mut(tlv_start_index);
981        Ok((
982            &mut account_type[account_type_index..tlv_start_index],
983            tlv_data,
984        ))
985    } else {
986        Ok((&mut [], &mut []))
987    }
988}
989
990fn unpack_type_and_tlv_data_mut<S: BaseState>(
991    rest: &mut [u8],
992) -> Result<(&mut [u8], &mut [u8]), ProgramError> {
993    unpack_type_and_tlv_data_with_check_mut::<S, _>(rest, check_account_type::<S>)
994}
995
996fn unpack_uninitialized_type_and_tlv_data_mut<S: BaseState>(
997    rest: &mut [u8],
998) -> Result<(&mut [u8], &mut [u8]), ProgramError> {
999    unpack_type_and_tlv_data_with_check_mut::<S, _>(rest, |account_type| {
1000        if account_type != AccountType::Uninitialized {
1001            Err(ProgramError::InvalidAccountData)
1002        } else {
1003            Ok(())
1004        }
1005    })
1006}
1007
1008/// If `AccountType` is uninitialized, set it to the `BaseState`'s
1009/// `ACCOUNT_TYPE`; if `AccountType` is already set, check is set correctly for
1010/// `BaseState`. This method assumes that the `base_data` has already been
1011/// packed with data of the desired type.
1012pub fn set_account_type<S: BaseState>(input: &mut [u8]) -> Result<(), ProgramError> {
1013    check_min_len_and_not_multisig(input, S::SIZE_OF)?;
1014    let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
1015    if S::ACCOUNT_TYPE == AccountType::Account && !is_initialized_account(base_data)? {
1016        return Err(ProgramError::InvalidAccountData);
1017    }
1018    if let Some((account_type_index, _tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
1019        let mut account_type = AccountType::try_from(rest[account_type_index])
1020            .map_err(|_| ProgramError::InvalidAccountData)?;
1021        if account_type == AccountType::Uninitialized {
1022            rest[account_type_index] = S::ACCOUNT_TYPE.into();
1023            account_type = S::ACCOUNT_TYPE;
1024        }
1025        check_account_type::<S>(account_type)?;
1026        Ok(())
1027    } else {
1028        Err(ProgramError::InvalidAccountData)
1029    }
1030}
1031
/// Different kinds of accounts. Note that `Mint`, `Account`, and `Multisig`
/// types are determined exclusively by the size of the account, and are not
/// included in the account data. `AccountType` is only included if extensions
/// have been initialized.
///
/// The `u8` discriminant (0, 1, 2 in declaration order) is the byte that is
/// written into account data. Existing variants must therefore never be
/// reordered.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, TryFromPrimitive, IntoPrimitive)]
pub enum AccountType {
    /// Marker for 0 data (discriminant 0)
    Uninitialized,
    /// Mint account with additional extensions (discriminant 1)
    Mint,
    /// Token holding account with additional extensions (discriminant 2)
    Account,
}
1046impl Default for AccountType {
1047    fn default() -> Self {
1048        Self::Uninitialized
1049    }
1050}
1051
/// Extensions that can be applied to mints or accounts.  Mint extensions must
/// only be applied to mint accounts, and account extensions must only be
/// applied to token holding accounts.
///
/// Each variant's `u16` discriminant is serialized little-endian at the start
/// of its TLV entry. New variants must therefore only be appended; existing
/// ones must never be reordered.
#[repr(u16)]
#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde-traits", serde(rename_all = "camelCase"))]
#[derive(Clone, Copy, Debug, PartialEq, TryFromPrimitive, IntoPrimitive)]
pub enum ExtensionType {
    /// Used as padding if the account size would otherwise be 355, same as a
    /// multisig
    Uninitialized,
    /// Includes transfer fee rate info and accompanying authorities to withdraw
    /// and set the fee
    TransferFeeConfig,
    /// Includes withheld transfer fees
    TransferFeeAmount,
    /// Includes an optional mint close authority
    MintCloseAuthority,
    /// Auditor configuration for confidential transfers
    ConfidentialTransferMint,
    /// State for confidential transfers
    ConfidentialTransferAccount,
    /// Specifies the default Account::state for new Accounts
    DefaultAccountState,
    /// Indicates that the Account owner authority cannot be changed
    ImmutableOwner,
    /// Require inbound transfers to have memo
    MemoTransfer,
    /// Indicates that the tokens from this mint can't be transferred
    NonTransferable,
    /// Tokens accrue interest over time
    InterestBearingConfig,
    /// Locks privileged token operations from happening via CPI
    CpiGuard,
    /// Includes an optional permanent delegate
    PermanentDelegate,
    /// Indicates that the tokens in this account belong to a non-transferable
    /// mint
    NonTransferableAccount,
    /// Mint requires a CPI to a program implementing the "transfer hook"
    /// interface
    TransferHook,
    /// Indicates that the tokens in this account belong to a mint with a
    /// transfer hook
    TransferHookAccount,
    /// Includes encrypted withheld fees and the encryption public key that
    /// they are encrypted under
    ConfidentialTransferFeeConfig,
    /// Includes confidential withheld transfer fees
    ConfidentialTransferFeeAmount,
    /// Mint contains a pointer to another account (or the same account) that
    /// holds metadata
    MetadataPointer,
    /// Mint contains token-metadata (variable length)
    TokenMetadata,
    /// Mint contains a pointer to another account (or the same account) that
    /// holds group configurations
    GroupPointer,
    /// Mint contains token group configurations
    TokenGroup,
    /// Mint contains a pointer to another account (or the same account) that
    /// holds group member configurations
    GroupMemberPointer,
    /// Mint contains token group member configurations
    TokenGroupMember,
    /// Mint allowing the minting and burning of confidential tokens
    ConfidentialMintBurn,
    /// Tokens whose UI amount is scaled by a given amount
    ScaledUiAmount,
    /// Tokens where minting / burning / transferring can be paused
    Pausable,
    /// Indicates that the account belongs to a pausable mint
    PausableAccount,

    /// Test variable-length mint extension
    #[cfg(test)]
    VariableLenMintTest = u16::MAX - 2,
    /// Padding extension used to make an account exactly Multisig::LEN, used
    /// for testing
    #[cfg(test)]
    AccountPaddingTest,
    /// Padding extension used to make a mint exactly Multisig::LEN, used for
    /// testing
    #[cfg(test)]
    MintPaddingTest,
}
1138impl TryFrom<&[u8]> for ExtensionType {
1139    type Error = ProgramError;
1140    fn try_from(a: &[u8]) -> Result<Self, Self::Error> {
1141        Self::try_from(u16::from_le_bytes(
1142            a.try_into().map_err(|_| ProgramError::InvalidAccountData)?,
1143        ))
1144        .map_err(|_| ProgramError::InvalidAccountData)
1145    }
1146}
1147impl From<ExtensionType> for [u8; 2] {
1148    fn from(a: ExtensionType) -> Self {
1149        u16::from(a).to_le_bytes()
1150    }
1151}
1152impl ExtensionType {
1153    /// Returns true if the given extension type is sized
1154    ///
1155    /// Most extension types should be sized, so any variable-length extension
1156    /// types should be added here by hand
1157    const fn sized(&self) -> bool {
1158        match self {
1159            ExtensionType::TokenMetadata => false,
1160            #[cfg(test)]
1161            ExtensionType::VariableLenMintTest => false,
1162            _ => true,
1163        }
1164    }
1165
1166    /// Get the data length of the type associated with the enum
1167    ///
1168    /// Fails if the extension type has a variable length
1169    fn try_get_type_len(&self) -> Result<usize, ProgramError> {
1170        if !self.sized() {
1171            return Err(ProgramError::InvalidArgument);
1172        }
1173        Ok(match self {
1174            ExtensionType::Uninitialized => 0,
1175            ExtensionType::TransferFeeConfig => pod_get_packed_len::<TransferFeeConfig>(),
1176            ExtensionType::TransferFeeAmount => pod_get_packed_len::<TransferFeeAmount>(),
1177            ExtensionType::MintCloseAuthority => pod_get_packed_len::<MintCloseAuthority>(),
1178            ExtensionType::ImmutableOwner => pod_get_packed_len::<ImmutableOwner>(),
1179            ExtensionType::ConfidentialTransferMint => {
1180                pod_get_packed_len::<ConfidentialTransferMint>()
1181            }
1182            ExtensionType::ConfidentialTransferAccount => {
1183                pod_get_packed_len::<ConfidentialTransferAccount>()
1184            }
1185            ExtensionType::DefaultAccountState => pod_get_packed_len::<DefaultAccountState>(),
1186            ExtensionType::MemoTransfer => pod_get_packed_len::<MemoTransfer>(),
1187            ExtensionType::NonTransferable => pod_get_packed_len::<NonTransferable>(),
1188            ExtensionType::InterestBearingConfig => pod_get_packed_len::<InterestBearingConfig>(),
1189            ExtensionType::CpiGuard => pod_get_packed_len::<CpiGuard>(),
1190            ExtensionType::PermanentDelegate => pod_get_packed_len::<PermanentDelegate>(),
1191            ExtensionType::NonTransferableAccount => pod_get_packed_len::<NonTransferableAccount>(),
1192            ExtensionType::TransferHook => pod_get_packed_len::<TransferHook>(),
1193            ExtensionType::TransferHookAccount => pod_get_packed_len::<TransferHookAccount>(),
1194            ExtensionType::ConfidentialTransferFeeConfig => {
1195                pod_get_packed_len::<ConfidentialTransferFeeConfig>()
1196            }
1197            ExtensionType::ConfidentialTransferFeeAmount => {
1198                pod_get_packed_len::<ConfidentialTransferFeeAmount>()
1199            }
1200            ExtensionType::MetadataPointer => pod_get_packed_len::<MetadataPointer>(),
1201            ExtensionType::TokenMetadata => unreachable!(),
1202            ExtensionType::GroupPointer => pod_get_packed_len::<GroupPointer>(),
1203            ExtensionType::TokenGroup => pod_get_packed_len::<TokenGroup>(),
1204            ExtensionType::GroupMemberPointer => pod_get_packed_len::<GroupMemberPointer>(),
1205            ExtensionType::TokenGroupMember => pod_get_packed_len::<TokenGroupMember>(),
1206            ExtensionType::ConfidentialMintBurn => pod_get_packed_len::<ConfidentialMintBurn>(),
1207            ExtensionType::ScaledUiAmount => pod_get_packed_len::<ScaledUiAmountConfig>(),
1208            ExtensionType::Pausable => pod_get_packed_len::<PausableConfig>(),
1209            ExtensionType::PausableAccount => pod_get_packed_len::<PausableAccount>(),
1210            #[cfg(test)]
1211            ExtensionType::AccountPaddingTest => pod_get_packed_len::<AccountPaddingTest>(),
1212            #[cfg(test)]
1213            ExtensionType::MintPaddingTest => pod_get_packed_len::<MintPaddingTest>(),
1214            #[cfg(test)]
1215            ExtensionType::VariableLenMintTest => unreachable!(),
1216        })
1217    }
1218
1219    /// Get the TLV length for an `ExtensionType`
1220    ///
1221    /// Fails if the extension type has a variable length
1222    fn try_get_tlv_len(&self) -> Result<usize, ProgramError> {
1223        Ok(add_type_and_length_to_len(self.try_get_type_len()?))
1224    }
1225
1226    /// Get the TLV length for a set of `ExtensionType`s
1227    ///
1228    /// Fails if any of the extension types has a variable length
1229    fn try_get_total_tlv_len(extension_types: &[Self]) -> Result<usize, ProgramError> {
1230        // dedupe extensions
1231        let mut extensions = vec![];
1232        for extension_type in extension_types {
1233            if !extensions.contains(&extension_type) {
1234                extensions.push(extension_type);
1235            }
1236        }
1237        extensions.iter().map(|e| e.try_get_tlv_len()).sum()
1238    }
1239
1240    /// Get the required account data length for the given `ExtensionType`s
1241    ///
1242    /// Fails if any of the extension types has a variable length
1243    pub fn try_calculate_account_len<S: BaseState>(
1244        extension_types: &[Self],
1245    ) -> Result<usize, ProgramError> {
1246        if extension_types.is_empty() {
1247            Ok(S::SIZE_OF)
1248        } else {
1249            let extension_size = Self::try_get_total_tlv_len(extension_types)?;
1250            let total_len = extension_size.saturating_add(BASE_ACCOUNT_AND_TYPE_LENGTH);
1251            Ok(adjust_len_for_multisig(total_len))
1252        }
1253    }
1254
1255    /// Get the associated account type
1256    pub fn get_account_type(&self) -> AccountType {
1257        match self {
1258            ExtensionType::Uninitialized => AccountType::Uninitialized,
1259            ExtensionType::TransferFeeConfig
1260            | ExtensionType::MintCloseAuthority
1261            | ExtensionType::ConfidentialTransferMint
1262            | ExtensionType::DefaultAccountState
1263            | ExtensionType::NonTransferable
1264            | ExtensionType::InterestBearingConfig
1265            | ExtensionType::PermanentDelegate
1266            | ExtensionType::TransferHook
1267            | ExtensionType::ConfidentialTransferFeeConfig
1268            | ExtensionType::MetadataPointer
1269            | ExtensionType::TokenMetadata
1270            | ExtensionType::GroupPointer
1271            | ExtensionType::TokenGroup
1272            | ExtensionType::GroupMemberPointer
1273            | ExtensionType::ConfidentialMintBurn
1274            | ExtensionType::TokenGroupMember
1275            | ExtensionType::ScaledUiAmount
1276            | ExtensionType::Pausable => AccountType::Mint,
1277            ExtensionType::ImmutableOwner
1278            | ExtensionType::TransferFeeAmount
1279            | ExtensionType::ConfidentialTransferAccount
1280            | ExtensionType::MemoTransfer
1281            | ExtensionType::NonTransferableAccount
1282            | ExtensionType::TransferHookAccount
1283            | ExtensionType::CpiGuard
1284            | ExtensionType::ConfidentialTransferFeeAmount
1285            | ExtensionType::PausableAccount => AccountType::Account,
1286            #[cfg(test)]
1287            ExtensionType::VariableLenMintTest => AccountType::Mint,
1288            #[cfg(test)]
1289            ExtensionType::AccountPaddingTest => AccountType::Account,
1290            #[cfg(test)]
1291            ExtensionType::MintPaddingTest => AccountType::Mint,
1292        }
1293    }
1294
1295    /// Based on a set of `AccountType::Mint` `ExtensionType`s, get the list of
1296    /// `AccountType::Account` `ExtensionType`s required on `InitializeAccount`
1297    pub fn get_required_init_account_extensions(mint_extension_types: &[Self]) -> Vec<Self> {
1298        let mut account_extension_types = vec![];
1299        for extension_type in mint_extension_types {
1300            match extension_type {
1301                ExtensionType::TransferFeeConfig => {
1302                    account_extension_types.push(ExtensionType::TransferFeeAmount);
1303                }
1304                ExtensionType::NonTransferable => {
1305                    account_extension_types.push(ExtensionType::NonTransferableAccount);
1306                    account_extension_types.push(ExtensionType::ImmutableOwner);
1307                }
1308                ExtensionType::TransferHook => {
1309                    account_extension_types.push(ExtensionType::TransferHookAccount);
1310                }
1311                ExtensionType::Pausable => {
1312                    account_extension_types.push(ExtensionType::PausableAccount);
1313                }
1314                #[cfg(test)]
1315                ExtensionType::MintPaddingTest => {
1316                    account_extension_types.push(ExtensionType::AccountPaddingTest);
1317                }
1318                _ => {}
1319            }
1320        }
1321        account_extension_types
1322    }
1323
1324    /// Check for invalid combination of mint extensions
1325    pub fn check_for_invalid_mint_extension_combinations(
1326        mint_extension_types: &[Self],
1327    ) -> Result<(), TokenError> {
1328        let mut transfer_fee_config = false;
1329        let mut confidential_transfer_mint = false;
1330        let mut confidential_transfer_fee_config = false;
1331        let mut confidential_mint_burn = false;
1332        let mut interest_bearing = false;
1333        let mut scaled_ui_amount = false;
1334
1335        for extension_type in mint_extension_types {
1336            match extension_type {
1337                ExtensionType::TransferFeeConfig => transfer_fee_config = true,
1338                ExtensionType::ConfidentialTransferMint => confidential_transfer_mint = true,
1339                ExtensionType::ConfidentialTransferFeeConfig => {
1340                    confidential_transfer_fee_config = true
1341                }
1342                ExtensionType::ConfidentialMintBurn => confidential_mint_burn = true,
1343                ExtensionType::InterestBearingConfig => interest_bearing = true,
1344                ExtensionType::ScaledUiAmount => scaled_ui_amount = true,
1345                _ => (),
1346            }
1347        }
1348
1349        if confidential_transfer_fee_config && !(transfer_fee_config && confidential_transfer_mint)
1350        {
1351            return Err(TokenError::InvalidExtensionCombination);
1352        }
1353
1354        if transfer_fee_config && confidential_transfer_mint && !confidential_transfer_fee_config {
1355            return Err(TokenError::InvalidExtensionCombination);
1356        }
1357
1358        if confidential_mint_burn && !confidential_transfer_mint {
1359            return Err(TokenError::InvalidExtensionCombination);
1360        }
1361
1362        if scaled_ui_amount && interest_bearing {
1363            return Err(TokenError::InvalidExtensionCombination);
1364        }
1365
1366        Ok(())
1367    }
1368}
1369
1370/// Trait for base states, specifying the associated enum
1371pub trait BaseState: PackedSizeOf + IsInitialized {
1372    /// Associated extension type enum, checked at the start of TLV entries
1373    const ACCOUNT_TYPE: AccountType;
1374}
1375impl BaseState for Account {
1376    const ACCOUNT_TYPE: AccountType = AccountType::Account;
1377}
1378impl BaseState for Mint {
1379    const ACCOUNT_TYPE: AccountType = AccountType::Mint;
1380}
1381impl BaseState for PodAccount {
1382    const ACCOUNT_TYPE: AccountType = AccountType::Account;
1383}
1384impl BaseState for PodMint {
1385    const ACCOUNT_TYPE: AccountType = AccountType::Mint;
1386}
1387
/// Trait to be implemented by all extension states, specifying which extension
/// and account type they are associated with
///
/// The associated account type is derived from `TYPE` via
/// `ExtensionType::get_account_type`.
pub trait Extension {
    /// Associated extension type enum, checked at the start of TLV entries
    const TYPE: ExtensionType;
}
1394
/// Padding a mint account to be exactly `Multisig::LEN`.
/// We need to pad 185 bytes, since `Multisig::LEN = 355`, `Account::LEN = 165`,
/// `size_of::<AccountType>() = 1`, `size_of::<ExtensionType>() = 2`,
/// `size_of::<Length>() = 2`.
///
/// ```
/// assert_eq!(355 - 165 - 1 - 2 - 2, 185);
/// ```
///
/// The three padding fields add up to 128 + 48 + 9 = 185 bytes.
#[cfg(test)]
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Pod, Zeroable)]
pub struct MintPaddingTest {
    /// Largest value under 185 that implements Pod
    pub padding1: [u8; 128],
    /// Largest value under 57 (185 - 128) that implements Pod
    pub padding2: [u8; 48],
    /// Exact value needed to finish the padding (57 - 48)
    pub padding3: [u8; 9],
}
#[cfg(test)]
impl Extension for MintPaddingTest {
    const TYPE: ExtensionType = ExtensionType::MintPaddingTest;
}
#[cfg(test)]
impl Default for MintPaddingTest {
    /// Fills the three padding fields with the bytes 1, 2 and 3 respectively.
    fn default() -> Self {
        let padding3 = [3u8; 9];
        let padding2 = [2u8; 48];
        let padding1 = [1u8; 128];
        MintPaddingTest {
            padding1,
            padding2,
            padding3,
        }
    }
}
/// Account version of the `MintPaddingTest`, padding a token account to be
/// exactly `Multisig::LEN` for testing
#[cfg(test)]
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
pub struct AccountPaddingTest(MintPaddingTest);
#[cfg(test)]
impl Extension for AccountPaddingTest {
    const TYPE: ExtensionType = ExtensionType::AccountPaddingTest;
}
1437
1438/// Packs a fixed-length extension into a TLV space
1439///
1440/// This function reallocates the account as needed to accommodate for the
1441/// change in space.
1442///
1443/// If the extension already exists, it will overwrite the existing extension
1444/// if `overwrite` is `true`, otherwise it will return an error.
1445///
1446/// If the extension does not exist, it will reallocate the account and write
1447/// the extension into the TLV buffer.
1448///
1449/// NOTE: Since this function deals with fixed-size extensions, it does not
1450/// handle _decreasing_ the size of an account's data buffer, like the function
1451/// `alloc_and_serialize_variable_len_extension` does.
1452pub(crate) fn alloc_and_serialize<S: BaseState + Pod, V: Default + Extension + Pod>(
1453    account_info: &AccountInfo,
1454    new_extension: &V,
1455    overwrite: bool,
1456) -> Result<(), ProgramError> {
1457    let previous_account_len = account_info.try_data_len()?;
1458    let new_account_len = {
1459        let data = account_info.try_borrow_data()?;
1460        let state = PodStateWithExtensions::<S>::unpack(&data)?;
1461        state.try_get_new_account_len::<V>()?
1462    };
1463
1464    // Realloc the account first, if needed
1465    if new_account_len > previous_account_len {
1466        account_info.realloc(new_account_len, false)?;
1467    }
1468    let mut buffer = account_info.try_borrow_mut_data()?;
1469    if previous_account_len <= BASE_ACCOUNT_LENGTH {
1470        set_account_type::<S>(*buffer)?;
1471    }
1472    let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
1473
1474    // Write the extension
1475    let extension = state.init_extension::<V>(overwrite)?;
1476    *extension = *new_extension;
1477
1478    Ok(())
1479}
1480
1481/// Packs a variable-length extension into a TLV space
1482///
1483/// This function reallocates the account as needed to accommodate for the
1484/// change in space, then reallocates in the TLV buffer, and finally writes the
1485/// bytes.
1486///
1487/// NOTE: Unlike the `reallocate` instruction, this function will reduce the
1488/// size of an account if it has too many bytes allocated for the given value.
1489pub(crate) fn alloc_and_serialize_variable_len_extension<
1490    S: BaseState + Pod,
1491    V: Extension + VariableLenPack,
1492>(
1493    account_info: &AccountInfo,
1494    new_extension: &V,
1495    overwrite: bool,
1496) -> Result<(), ProgramError> {
1497    let previous_account_len = account_info.try_data_len()?;
1498    let (new_account_len, extension_already_exists) = {
1499        let data = account_info.try_borrow_data()?;
1500        let state = PodStateWithExtensions::<S>::unpack(&data)?;
1501        let new_account_len =
1502            state.try_get_new_account_len_for_variable_len_extension(new_extension)?;
1503        let extension_already_exists = state.get_extension_bytes::<V>().is_ok();
1504        (new_account_len, extension_already_exists)
1505    };
1506
1507    if extension_already_exists && !overwrite {
1508        return Err(TokenError::ExtensionAlreadyInitialized.into());
1509    }
1510
1511    if previous_account_len < new_account_len {
1512        // account size increased, so realloc the account, then the TLV entry, then
1513        // write data
1514        account_info.realloc(new_account_len, false)?;
1515        let mut buffer = account_info.try_borrow_mut_data()?;
1516        if extension_already_exists {
1517            let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
1518            state.realloc_variable_len_extension(new_extension)?;
1519        } else {
1520            if previous_account_len <= BASE_ACCOUNT_LENGTH {
1521                set_account_type::<S>(*buffer)?;
1522            }
1523            // now alloc in the TLV buffer and write the data
1524            let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
1525            state.init_variable_len_extension(new_extension, false)?;
1526        }
1527    } else {
1528        // do it backwards otherwise, write the state, realloc TLV, then the account
1529        let mut buffer = account_info.try_borrow_mut_data()?;
1530        let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
1531        if extension_already_exists {
1532            state.realloc_variable_len_extension(new_extension)?;
1533        } else {
1534            // this situation can happen if we have an overallocated buffer
1535            state.init_variable_len_extension(new_extension, false)?;
1536        }
1537
1538        let removed_bytes = previous_account_len
1539            .checked_sub(new_account_len)
1540            .ok_or(ProgramError::AccountDataTooSmall)?;
1541        if removed_bytes > 0 {
1542            // this is probably fine, but be safe and avoid invalidating references
1543            drop(buffer);
1544            account_info.realloc(new_account_len, false)?;
1545        }
1546    }
1547    Ok(())
1548}
1549
1550#[cfg(test)]
1551mod test {
1552    use {
1553        super::*,
1554        crate::{
1555            pod::test::{TEST_POD_ACCOUNT, TEST_POD_MINT},
1556            state::test::{TEST_ACCOUNT_SLICE, TEST_MINT_SLICE},
1557        },
1558        bytemuck::Pod,
1559        solana_account_info::{
1560            Account as GetAccount, IntoAccountInfo, MAX_PERMITTED_DATA_INCREASE,
1561        },
1562        solana_clock::Epoch,
1563        solana_pubkey::Pubkey,
1564        spl_pod::{
1565            bytemuck::pod_bytes_of,
1566            optional_keys::OptionalNonZeroPubkey,
1567            primitives::{PodBool, PodU64},
1568        },
1569        transfer_fee::test::test_transfer_fee_config,
1570    };
1571
    /// Test fixed-length struct: eight bytes of `Pod` data registered as an
    /// extension under `ExtensionType::MintPaddingTest`.
    // `repr(C)` pins the field layout, which bytemuck's `Pod` derive requires.
    #[repr(C)]
    #[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
    struct FixedLenMintTest {
        data: [u8; 8],
    }
    impl Extension for FixedLenMintTest {
        const TYPE: ExtensionType = ExtensionType::MintPaddingTest;
    }
1581
1582    /// Test variable-length struct
1583    #[derive(Clone, Debug, PartialEq)]
1584    struct VariableLenMintTest {
1585        data: Vec<u8>,
1586    }
1587    impl Extension for VariableLenMintTest {
1588        const TYPE: ExtensionType = ExtensionType::VariableLenMintTest;
1589    }
1590    impl VariableLenPack for VariableLenMintTest {
1591        fn pack_into_slice(&self, dst: &mut [u8]) -> Result<(), ProgramError> {
1592            let data_start = size_of::<u64>();
1593            let end = data_start + self.data.len();
1594            if dst.len() < end {
1595                Err(ProgramError::InvalidAccountData)
1596            } else {
1597                dst[..data_start].copy_from_slice(&self.data.len().to_le_bytes());
1598                dst[data_start..end].copy_from_slice(&self.data);
1599                Ok(())
1600            }
1601        }
1602        fn unpack_from_slice(src: &[u8]) -> Result<Self, ProgramError> {
1603            let data_start = size_of::<u64>();
1604            let length = u64::from_le_bytes(src[..data_start].try_into().unwrap()) as usize;
1605            if src[data_start..data_start + length].len() != length {
1606                return Err(ProgramError::InvalidAccountData);
1607            }
1608            let data = Vec::from(&src[data_start..data_start + length]);
1609            Ok(Self { data })
1610        }
1611        fn get_packed_len(&self) -> Result<usize, ProgramError> {
1612            Ok(size_of::<u64>().saturating_add(self.data.len()))
1613        }
1614    }
1615
    /// Hand-built mint buffer with one extension. It contains:
    /// - the base mint, which unpacks to `TEST_POD_MINT`;
    /// - zero padding up to `BASE_ACCOUNT_LENGTH`;
    /// - the mint account-type byte;
    /// - one TLV entry: extension type 3 (`MintCloseAuthority`), length 32, and
    ///   a close authority of `[1; 32]`.
    const MINT_WITH_EXTENSION: &[u8] = &[
        1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 42, 0, 0, 0, 0, 0, 0, 0, 7, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // base mint
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // padding
        1, // account type
        3, 0, // extension type
        32, 0, // length
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, // data
    ];
1629
    /// Hand-built token-account buffer with one extension. It contains:
    /// - the base account, which unpacks to `TEST_POD_ACCOUNT`;
    /// - the account account-type byte;
    /// - one TLV entry: extension type 15 (`TransferHookAccount`), length 1, and
    ///   `transferring` set to true.
    const ACCOUNT_WITH_EXTENSION: &[u8] = &[
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, // mint
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, // owner
        3, 0, 0, 0, 0, 0, 0, 0, // amount
        1, 0, 0, 0, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, // delegate
        2, // account state
        1, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, // is native
        6, 0, 0, 0, 0, 0, 0, 0, // delegated amount
        1, 0, 0, 0, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        7, 7, 7, 7, 7, 7, // close authority
        2, // account type
        15, 0, // extension type
        1, 0, // length
        1, // data
    ];
1648
    /// Unpacks the hand-built fixtures and the shared test slices, and checks the
    /// decoded base state and extensions. Also checks that each fixture is
    /// rejected when unpacked as the other base type.
    #[test]
    fn unpack_opaque_buffer() {
        // Mint
        let state = PodStateWithExtensions::<PodMint>::unpack(MINT_WITH_EXTENSION).unwrap();
        assert_eq!(state.base, &TEST_POD_MINT);
        let extension = state.get_extension::<MintCloseAuthority>().unwrap();
        let close_authority =
            OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([1; 32]))).unwrap();
        assert_eq!(extension.close_authority, close_authority);
        // An extension that is not present in the TLV data is reported as invalid.
        assert_eq!(
            state.get_extension::<TransferFeeConfig>(),
            Err(ProgramError::InvalidAccountData)
        );
        assert_eq!(
            PodStateWithExtensions::<PodAccount>::unpack(MINT_WITH_EXTENSION),
            Err(ProgramError::UninitializedAccount)
        );

        // A bare base mint without any TLV data also unpacks, immutably and mutably.
        let state = PodStateWithExtensions::<PodMint>::unpack(TEST_MINT_SLICE).unwrap();
        assert_eq!(state.base, &TEST_POD_MINT);

        let mut test_mint = TEST_MINT_SLICE.to_vec();
        let state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut test_mint).unwrap();
        assert_eq!(state.base, &TEST_POD_MINT);

        // Account
        let state = PodStateWithExtensions::<PodAccount>::unpack(ACCOUNT_WITH_EXTENSION).unwrap();
        assert_eq!(state.base, &TEST_POD_ACCOUNT);
        let extension = state.get_extension::<TransferHookAccount>().unwrap();
        let transferring = PodBool::from(true);
        assert_eq!(extension.transferring, transferring);
        assert_eq!(
            PodStateWithExtensions::<PodMint>::unpack(ACCOUNT_WITH_EXTENSION),
            Err(ProgramError::InvalidAccountData)
        );

        // A bare base account without any TLV data also unpacks, immutably and mutably.
        let state = PodStateWithExtensions::<PodAccount>::unpack(TEST_ACCOUNT_SLICE).unwrap();
        assert_eq!(state.base, &TEST_POD_ACCOUNT);

        let mut test_account = TEST_ACCOUNT_SLICE.to_vec();
        let state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut test_account).unwrap();
        assert_eq!(state.base, &TEST_POD_ACCOUNT);
    }
1692
    /// Corrupts single bytes of a fresh copy of `MINT_WITH_EXTENSION` (or
    /// truncates it). Each corruption must produce the expected error, either
    /// from unpacking or from the later extension lookup.
    #[test]
    fn mint_fail_unpack_opaque_buffer() {
        // input buffer too small
        let mut buffer = vec![0, 3];
        assert_eq!(
            PodStateWithExtensions::<PodMint>::unpack(&buffer),
            Err(ProgramError::InvalidAccountData)
        );
        assert_eq!(
            PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer),
            Err(ProgramError::InvalidAccountData)
        );
        assert_eq!(
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer),
            Err(ProgramError::InvalidAccountData)
        );

        // tweak the account type
        let mut buffer = MINT_WITH_EXTENSION.to_vec();
        buffer[BASE_ACCOUNT_LENGTH] = 3;
        assert_eq!(
            PodStateWithExtensions::<PodMint>::unpack(&buffer),
            Err(ProgramError::InvalidAccountData)
        );

        // clear the mint initialized byte
        let mut buffer = MINT_WITH_EXTENSION.to_vec();
        buffer[45] = 0;
        assert_eq!(
            PodStateWithExtensions::<PodMint>::unpack(&buffer),
            Err(ProgramError::UninitializedAccount)
        );

        // tweak the padding (first byte after the base mint)
        let mut buffer = MINT_WITH_EXTENSION.to_vec();
        buffer[PodMint::SIZE_OF] = 100;
        assert_eq!(
            PodStateWithExtensions::<PodMint>::unpack(&buffer),
            Err(ProgramError::InvalidAccountData)
        );

        // tweak the extension type; unpacking still succeeds, the lookup fails
        let mut buffer = MINT_WITH_EXTENSION.to_vec();
        buffer[BASE_ACCOUNT_LENGTH + 1] = 2;
        let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
        assert_eq!(
            state.get_extension::<TransferFeeConfig>(),
            Err(ProgramError::Custom(
                TokenError::ExtensionTypeMismatch as u32
            ))
        );

        // tweak the length, too big
        let mut buffer = MINT_WITH_EXTENSION.to_vec();
        buffer[BASE_ACCOUNT_LENGTH + 3] = 100;
        let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
        assert_eq!(
            state.get_extension::<TransferFeeConfig>(),
            Err(ProgramError::InvalidAccountData)
        );

        // tweak the length, too small
        let mut buffer = MINT_WITH_EXTENSION.to_vec();
        buffer[BASE_ACCOUNT_LENGTH + 3] = 10;
        let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
        assert_eq!(
            state.get_extension::<TransferFeeConfig>(),
            Err(ProgramError::InvalidAccountData)
        );

        // data buffer is too small (last extension byte missing)
        let buffer = &MINT_WITH_EXTENSION[..MINT_WITH_EXTENSION.len() - 1];
        let state = PodStateWithExtensions::<PodMint>::unpack(buffer).unwrap();
        assert_eq!(
            state.get_extension::<MintCloseAuthority>(),
            Err(ProgramError::InvalidAccountData)
        );
    }
1771
    /// Account counterpart of `mint_fail_unpack_opaque_buffer`. It corrupts single
    /// bytes of a fresh copy of `ACCOUNT_WITH_EXTENSION` (or truncates it) and
    /// checks each resulting unpack or extension-lookup error.
    #[test]
    fn account_fail_unpack_opaque_buffer() {
        // input buffer too small
        let mut buffer = vec![0, 3];
        assert_eq!(
            PodStateWithExtensions::<PodAccount>::unpack(&buffer),
            Err(ProgramError::InvalidAccountData)
        );
        assert_eq!(
            PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer),
            Err(ProgramError::InvalidAccountData)
        );
        assert_eq!(
            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
            Err(ProgramError::InvalidAccountData)
        );

        // input buffer invalid
        // all 5's - not a valid `AccountState`
        let mut buffer = vec![5; BASE_ACCOUNT_LENGTH];
        assert_eq!(
            PodStateWithExtensions::<PodAccount>::unpack(&buffer),
            Err(ProgramError::UninitializedAccount)
        );
        assert_eq!(
            PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer),
            Err(ProgramError::UninitializedAccount)
        );

        // tweak the account type
        let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
        buffer[BASE_ACCOUNT_LENGTH] = 3;
        assert_eq!(
            PodStateWithExtensions::<PodAccount>::unpack(&buffer),
            Err(ProgramError::InvalidAccountData)
        );

        // clear the state byte
        let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
        buffer[108] = 0;
        assert_eq!(
            PodStateWithExtensions::<PodAccount>::unpack(&buffer),
            Err(ProgramError::UninitializedAccount)
        );

        // tweak the extension type; unpacking still succeeds, the lookup fails
        let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
        buffer[BASE_ACCOUNT_LENGTH + 1] = 12;
        let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
        assert_eq!(
            state.get_extension::<TransferHookAccount>(),
            Err(ProgramError::Custom(
                TokenError::ExtensionTypeMismatch as u32
            ))
        );

        // tweak the length, too big
        let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
        buffer[BASE_ACCOUNT_LENGTH + 3] = 100;
        let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
        assert_eq!(
            state.get_extension::<TransferHookAccount>(),
            Err(ProgramError::InvalidAccountData)
        );

        // tweak the length, too small
        let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
        buffer[BASE_ACCOUNT_LENGTH + 3] = 10;
        let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
        assert_eq!(
            state.get_extension::<TransferHookAccount>(),
            Err(ProgramError::InvalidAccountData)
        );

        // data buffer is too small (last extension byte missing)
        let buffer = &ACCOUNT_WITH_EXTENSION[..ACCOUNT_WITH_EXTENSION.len() - 1];
        let state = PodStateWithExtensions::<PodAccount>::unpack(buffer).unwrap();
        assert_eq!(
            state.get_extension::<TransferHookAccount>(),
            Err(ProgramError::InvalidAccountData)
        );
    }
1854
1855    #[test]
1856    fn get_extension_types_with_opaque_buffer() {
1857        // incorrect due to the length
1858        assert_eq!(
1859            get_tlv_data_info(&[1, 0, 1, 1]).unwrap_err(),
1860            ProgramError::InvalidAccountData,
1861        );
1862        // incorrect due to the huge enum number
1863        assert_eq!(
1864            get_tlv_data_info(&[0, 1, 0, 0]).unwrap_err(),
1865            ProgramError::InvalidAccountData,
1866        );
1867        // correct due to the good enum number and zero length
1868        assert_eq!(
1869            get_tlv_data_info(&[1, 0, 0, 0]).unwrap(),
1870            TlvDataInfo {
1871                extension_types: vec![ExtensionType::try_from(1).unwrap()],
1872                used_len: add_type_and_length_to_len(0),
1873            }
1874        );
1875        // correct since it's just uninitialized data at the end
1876        assert_eq!(
1877            get_tlv_data_info(&[0, 0]).unwrap(),
1878            TlvDataInfo {
1879                extension_types: vec![],
1880                used_len: 0
1881            }
1882        );
1883    }
1884
1885    #[test]
1886    fn mint_with_extension_pack_unpack() {
1887        let mint_size = ExtensionType::try_calculate_account_len::<PodMint>(&[
1888            ExtensionType::MintCloseAuthority,
1889            ExtensionType::TransferFeeConfig,
1890        ])
1891        .unwrap();
1892        let mut buffer = vec![0; mint_size];
1893
1894        // fail unpack
1895        assert_eq!(
1896            PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer),
1897            Err(ProgramError::UninitializedAccount),
1898        );
1899
1900        let mut state =
1901            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
1902        // fail init account extension
1903        assert_eq!(
1904            state.init_extension::<TransferFeeAmount>(true),
1905            Err(ProgramError::InvalidAccountData),
1906        );
1907
1908        // success write extension
1909        let close_authority =
1910            OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([1; 32]))).unwrap();
1911        let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
1912        extension.close_authority = close_authority;
1913        assert_eq!(
1914            &state.get_extension_types().unwrap(),
1915            &[ExtensionType::MintCloseAuthority]
1916        );
1917
1918        // fail init extension when already initialized
1919        assert_eq!(
1920            state.init_extension::<MintCloseAuthority>(false),
1921            Err(ProgramError::Custom(
1922                TokenError::ExtensionAlreadyInitialized as u32
1923            ))
1924        );
1925
1926        // fail unpack as account, a mint extension was written
1927        assert_eq!(
1928            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
1929            Err(ProgramError::Custom(
1930                TokenError::ExtensionBaseMismatch as u32
1931            ))
1932        );
1933
1934        // fail unpack again, still no base data
1935        assert_eq!(
1936            PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer.clone()),
1937            Err(ProgramError::UninitializedAccount),
1938        );
1939
1940        // write base mint
1941        let mut state =
1942            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
1943        *state.base = TEST_POD_MINT;
1944        state.init_account_type().unwrap();
1945
1946        // check raw buffer
1947        let mut expect = TEST_MINT_SLICE.to_vec();
1948        expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); // padding
1949        expect.push(AccountType::Mint.into());
1950        expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
1951        expect
1952            .extend_from_slice(&(pod_get_packed_len::<MintCloseAuthority>() as u16).to_le_bytes());
1953        expect.extend_from_slice(&[1; 32]); // data
1954        expect.extend_from_slice(&[0; size_of::<ExtensionType>()]);
1955        expect.extend_from_slice(&[0; size_of::<Length>()]);
1956        expect.extend_from_slice(&[0; size_of::<TransferFeeConfig>()]);
1957        assert_eq!(expect, buffer);
1958
1959        // unpack uninitialized will now fail because the PodMint is now initialized
1960        assert_eq!(
1961            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer.clone()),
1962            Err(TokenError::AlreadyInUse.into()),
1963        );
1964
1965        // check unpacking
1966        let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
1967
1968        // update base
1969        *state.base = TEST_POD_MINT;
1970        state.base.supply = (u64::from(state.base.supply) + 100).into();
1971
1972        // check unpacking
1973        let unpacked_extension = state.get_extension_mut::<MintCloseAuthority>().unwrap();
1974        assert_eq!(*unpacked_extension, MintCloseAuthority { close_authority });
1975
1976        // update extension
1977        let close_authority = OptionalNonZeroPubkey::try_from(None).unwrap();
1978        unpacked_extension.close_authority = close_authority;
1979
1980        // check updates are propagated
1981        let base = *state.base;
1982        let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
1983        assert_eq!(state.base, &base);
1984        let unpacked_extension = state.get_extension::<MintCloseAuthority>().unwrap();
1985        assert_eq!(*unpacked_extension, MintCloseAuthority { close_authority });
1986
1987        // check raw buffer
1988        let mut expect = vec![];
1989        expect.extend_from_slice(bytemuck::bytes_of(&base));
1990        expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); // padding
1991        expect.push(AccountType::Mint.into());
1992        expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
1993        expect
1994            .extend_from_slice(&(pod_get_packed_len::<MintCloseAuthority>() as u16).to_le_bytes());
1995        expect.extend_from_slice(&[0; 32]);
1996        expect.extend_from_slice(&[0; size_of::<ExtensionType>()]);
1997        expect.extend_from_slice(&[0; size_of::<Length>()]);
1998        expect.extend_from_slice(&[0; size_of::<TransferFeeConfig>()]);
1999        assert_eq!(expect, buffer);
2000
2001        // fail unpack as an account
2002        assert_eq!(
2003            PodStateWithExtensions::<PodAccount>::unpack(&buffer),
2004            Err(ProgramError::UninitializedAccount),
2005        );
2006
2007        let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
2008        // init one more extension
2009        let mint_transfer_fee = test_transfer_fee_config();
2010        let new_extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
2011        new_extension.transfer_fee_config_authority =
2012            mint_transfer_fee.transfer_fee_config_authority;
2013        new_extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
2014        new_extension.withheld_amount = mint_transfer_fee.withheld_amount;
2015        new_extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
2016        new_extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;
2017
2018        assert_eq!(
2019            &state.get_extension_types().unwrap(),
2020            &[
2021                ExtensionType::MintCloseAuthority,
2022                ExtensionType::TransferFeeConfig
2023            ]
2024        );
2025
2026        // check raw buffer
2027        let mut expect = vec![];
2028        expect.extend_from_slice(pod_bytes_of(&base));
2029        expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); // padding
2030        expect.push(AccountType::Mint.into());
2031        expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
2032        expect
2033            .extend_from_slice(&(pod_get_packed_len::<MintCloseAuthority>() as u16).to_le_bytes());
2034        expect.extend_from_slice(&[0; 32]); // data
2035        expect.extend_from_slice(&(ExtensionType::TransferFeeConfig as u16).to_le_bytes());
2036        expect.extend_from_slice(&(pod_get_packed_len::<TransferFeeConfig>() as u16).to_le_bytes());
2037        expect.extend_from_slice(pod_bytes_of(&mint_transfer_fee));
2038        assert_eq!(expect, buffer);
2039
2040        // fail to init one more extension that does not fit
2041        let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
2042        assert_eq!(
2043            state.init_extension::<MintPaddingTest>(true),
2044            Err(ProgramError::InvalidAccountData),
2045        );
2046    }
2047
    /// Builds two mint buffers with the same base mint and the same two
    /// extensions. In the first, the extensions are written before the base and
    /// in one order; in the second, they are written after the base and in the
    /// opposite order. The raw bytes must differ, but the decoded base and
    /// extensions must be equal.
    #[test]
    fn mint_extension_any_order() {
        let mint_size = ExtensionType::try_calculate_account_len::<PodMint>(&[
            ExtensionType::MintCloseAuthority,
            ExtensionType::TransferFeeConfig,
        ])
        .unwrap();
        let mut buffer = vec![0; mint_size];

        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        // write extensions
        let close_authority =
            OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([1; 32]))).unwrap();
        let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
        extension.close_authority = close_authority;

        let mint_transfer_fee = test_transfer_fee_config();
        let extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
        extension.transfer_fee_config_authority = mint_transfer_fee.transfer_fee_config_authority;
        extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
        extension.withheld_amount = mint_transfer_fee.withheld_amount;
        extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
        extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;

        assert_eq!(
            &state.get_extension_types().unwrap(),
            &[
                ExtensionType::MintCloseAuthority,
                ExtensionType::TransferFeeConfig
            ]
        );

        // write base mint
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        *state.base = TEST_POD_MINT;
        state.init_account_type().unwrap();

        let mut other_buffer = vec![0; mint_size];
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut other_buffer).unwrap();

        // write base mint
        *state.base = TEST_POD_MINT;
        state.init_account_type().unwrap();

        // write extensions in a different order
        let mint_transfer_fee = test_transfer_fee_config();
        let extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
        extension.transfer_fee_config_authority = mint_transfer_fee.transfer_fee_config_authority;
        extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
        extension.withheld_amount = mint_transfer_fee.withheld_amount;
        extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
        extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;

        let close_authority =
            OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([1; 32]))).unwrap();
        let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
        extension.close_authority = close_authority;

        assert_eq!(
            &state.get_extension_types().unwrap(),
            &[
                ExtensionType::TransferFeeConfig,
                ExtensionType::MintCloseAuthority
            ]
        );

        // buffers are NOT the same because written in a different order
        assert_ne!(buffer, other_buffer);
        let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
        let other_state = PodStateWithExtensions::<PodMint>::unpack(&other_buffer).unwrap();

        // BUT mint and extensions are the same
        assert_eq!(
            state.get_extension::<TransferFeeConfig>().unwrap(),
            other_state.get_extension::<TransferFeeConfig>().unwrap()
        );
        assert_eq!(
            state.get_extension::<MintCloseAuthority>().unwrap(),
            other_state.get_extension::<MintCloseAuthority>().unwrap()
        );
        assert_eq!(state.base, other_state.base);
    }
2133
    /// A mint buffer exactly `Multisig::LEN` bytes long is rejected. A mint with
    /// only `MintPaddingTest` is sized one `ExtensionType` larger than that.
    /// The test fills such a mint with padding bytes and checks its raw layout,
    /// which ends in a trailing `Uninitialized` extension-type marker.
    #[test]
    fn mint_with_multisig_len() {
        let mut buffer = vec![0; Multisig::LEN];
        assert_eq!(
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer),
            Err(ProgramError::InvalidAccountData),
        );
        let mint_size =
            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MintPaddingTest])
                .unwrap();
        assert_eq!(mint_size, Multisig::LEN + size_of::<ExtensionType>());
        let mut buffer = vec![0; mint_size];

        // write base mint
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        *state.base = TEST_POD_MINT;
        state.init_account_type().unwrap();

        // write padding
        let extension = state.init_extension::<MintPaddingTest>(true).unwrap();
        extension.padding1 = [1; 128];
        extension.padding2 = [1; 48];
        extension.padding3 = [1; 9];

        assert_eq!(
            &state.get_extension_types().unwrap(),
            &[ExtensionType::MintPaddingTest]
        );

        // check raw buffer
        let mut expect = TEST_MINT_SLICE.to_vec();
        expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); // padding
        expect.push(AccountType::Mint.into());
        expect.extend_from_slice(&(ExtensionType::MintPaddingTest as u16).to_le_bytes());
        expect.extend_from_slice(&(pod_get_packed_len::<MintPaddingTest>() as u16).to_le_bytes());
        expect.extend_from_slice(&vec![1; pod_get_packed_len::<MintPaddingTest>()]);
        expect.extend_from_slice(&(ExtensionType::Uninitialized as u16).to_le_bytes());
        assert_eq!(expect, buffer);
    }
2174
    /// Round-trips a token account carrying a `TransferFeeAmount` extension.
    ///
    /// Checks that:
    /// - extension data can be written before the base account exists;
    /// - the raw bytes follow `base | account type | ext type | ext length | value`;
    /// - writes through the mutable view are visible to a fresh read-only unpack;
    /// - the resulting buffer is rejected when unpacked as a mint.
    #[test]
    fn account_with_extension_pack_unpack() {
        let account_size = ExtensionType::try_calculate_account_len::<PodAccount>(&[
            ExtensionType::TransferFeeAmount,
        ])
        .unwrap();
        let mut buffer = vec![0; account_size];

        // fail unpack: an all-zero buffer holds no initialized base account
        assert_eq!(
            PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer),
            Err(ProgramError::UninitializedAccount),
        );

        let mut state =
            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
        // fail init mint extension: `TransferFeeConfig` is rejected on an account
        assert_eq!(
            state.init_extension::<TransferFeeConfig>(true),
            Err(ProgramError::InvalidAccountData),
        );
        // success write extension (before any base data has been written)
        let withheld_amount = PodU64::from(u64::MAX);
        let extension = state.init_extension::<TransferFeeAmount>(true).unwrap();
        extension.withheld_amount = withheld_amount;

        assert_eq!(
            &state.get_extension_types().unwrap(),
            &[ExtensionType::TransferFeeAmount]
        );

        // fail unpack again, still no base data (the clone keeps `buffer` intact)
        assert_eq!(
            PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer.clone()),
            Err(ProgramError::UninitializedAccount),
        );

        // write base account and stamp the account-type byte
        let mut state =
            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
        *state.base = TEST_POD_ACCOUNT;
        state.init_account_type().unwrap();
        let base = *state.base;

        // check raw buffer: base, account type, then the single TLV entry
        let mut expect = TEST_ACCOUNT_SLICE.to_vec();
        expect.push(AccountType::Account.into());
        expect.extend_from_slice(&(ExtensionType::TransferFeeAmount as u16).to_le_bytes());
        expect.extend_from_slice(&(pod_get_packed_len::<TransferFeeAmount>() as u16).to_le_bytes());
        expect.extend_from_slice(&u64::from(withheld_amount).to_le_bytes());
        assert_eq!(expect, buffer);

        // check unpacking now that the base account is initialized
        let mut state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
        assert_eq!(state.base, &base);
        assert_eq!(
            &state.get_extension_types().unwrap(),
            &[ExtensionType::TransferFeeAmount]
        );

        // update base: bump the amount so the base differs from what was written
        *state.base = TEST_POD_ACCOUNT;
        state.base.amount = (u64::from(state.base.amount) + 100).into();

        // check the extension is still intact and borrow it mutably
        let unpacked_extension = state.get_extension_mut::<TransferFeeAmount>().unwrap();
        assert_eq!(*unpacked_extension, TransferFeeAmount { withheld_amount });

        // update extension through the mutable borrow
        let withheld_amount = PodU64::from(u32::MAX as u64);
        unpacked_extension.withheld_amount = withheld_amount;

        // check updates are propagated to a fresh read-only view of the buffer
        let base = *state.base;
        let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
        assert_eq!(state.base, &base);
        let unpacked_extension = state.get_extension::<TransferFeeAmount>().unwrap();
        assert_eq!(*unpacked_extension, TransferFeeAmount { withheld_amount });

        // check raw buffer reflects both the new base and the new extension value
        let mut expect = vec![];
        expect.extend_from_slice(pod_bytes_of(&base));
        expect.push(AccountType::Account.into());
        expect.extend_from_slice(&(ExtensionType::TransferFeeAmount as u16).to_le_bytes());
        expect.extend_from_slice(&(pod_get_packed_len::<TransferFeeAmount>() as u16).to_le_bytes());
        expect.extend_from_slice(&u64::from(withheld_amount).to_le_bytes());
        assert_eq!(expect, buffer);

        // fail unpack as a mint
        assert_eq!(
            PodStateWithExtensions::<PodMint>::unpack(&buffer),
            Err(ProgramError::InvalidAccountData),
        );
    }
2269
2270    #[test]
2271    fn account_with_multisig_len() {
2272        let mut buffer = vec![0; Multisig::LEN];
2273        assert_eq!(
2274            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
2275            Err(ProgramError::InvalidAccountData),
2276        );
2277        let account_size = ExtensionType::try_calculate_account_len::<PodAccount>(&[
2278            ExtensionType::AccountPaddingTest,
2279        ])
2280        .unwrap();
2281        assert_eq!(account_size, Multisig::LEN + size_of::<ExtensionType>());
2282        let mut buffer = vec![0; account_size];
2283
2284        // write base account
2285        let mut state =
2286            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
2287        *state.base = TEST_POD_ACCOUNT;
2288        state.init_account_type().unwrap();
2289
2290        // write padding
2291        let extension = state.init_extension::<AccountPaddingTest>(true).unwrap();
2292        extension.0.padding1 = [2; 128];
2293        extension.0.padding2 = [2; 48];
2294        extension.0.padding3 = [2; 9];
2295
2296        assert_eq!(
2297            &state.get_extension_types().unwrap(),
2298            &[ExtensionType::AccountPaddingTest]
2299        );
2300
2301        // check raw buffer
2302        let mut expect = TEST_ACCOUNT_SLICE.to_vec();
2303        expect.push(AccountType::Account.into());
2304        expect.extend_from_slice(&(ExtensionType::AccountPaddingTest as u16).to_le_bytes());
2305        expect
2306            .extend_from_slice(&(pod_get_packed_len::<AccountPaddingTest>() as u16).to_le_bytes());
2307        expect.extend_from_slice(&vec![2; pod_get_packed_len::<AccountPaddingTest>()]);
2308        expect.extend_from_slice(&(ExtensionType::Uninitialized as u16).to_le_bytes());
2309        assert_eq!(expect, buffer);
2310    }
2311
2312    #[test]
2313    fn test_set_account_type() {
2314        // account with buffer big enough for AccountType and Extension
2315        let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
2316        let needed_len = ExtensionType::try_calculate_account_len::<PodAccount>(&[
2317            ExtensionType::ImmutableOwner,
2318        ])
2319        .unwrap()
2320            - buffer.len();
2321        buffer.append(&mut vec![0; needed_len]);
2322        let err = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap_err();
2323        assert_eq!(err, ProgramError::InvalidAccountData);
2324        set_account_type::<PodAccount>(&mut buffer).unwrap();
2325        // unpack is viable after manual set_account_type
2326        let mut state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
2327        assert_eq!(state.base, &TEST_POD_ACCOUNT);
2328        assert_eq!(state.account_type[0], AccountType::Account as u8);
2329        state.init_extension::<ImmutableOwner>(true).unwrap(); // just confirming initialization works
2330
2331        // account with buffer big enough for AccountType only
2332        let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
2333        buffer.append(&mut vec![0; 2]);
2334        let err = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap_err();
2335        assert_eq!(err, ProgramError::InvalidAccountData);
2336        set_account_type::<PodAccount>(&mut buffer).unwrap();
2337        // unpack is viable after manual set_account_type
2338        let state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
2339        assert_eq!(state.base, &TEST_POD_ACCOUNT);
2340        assert_eq!(state.account_type[0], AccountType::Account as u8);
2341
2342        // account with AccountType already set => noop
2343        let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
2344        buffer.append(&mut vec![2, 0]);
2345        let _ = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
2346        set_account_type::<PodAccount>(&mut buffer).unwrap();
2347        let state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
2348        assert_eq!(state.base, &TEST_POD_ACCOUNT);
2349        assert_eq!(state.account_type[0], AccountType::Account as u8);
2350
2351        // account with wrong AccountType fails
2352        let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
2353        buffer.append(&mut vec![1, 0]);
2354        let err = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap_err();
2355        assert_eq!(err, ProgramError::InvalidAccountData);
2356        let err = set_account_type::<PodAccount>(&mut buffer).unwrap_err();
2357        assert_eq!(err, ProgramError::InvalidAccountData);
2358
2359        // mint with buffer big enough for AccountType and Extension
2360        let mut buffer = TEST_MINT_SLICE.to_vec();
2361        let needed_len = ExtensionType::try_calculate_account_len::<PodMint>(&[
2362            ExtensionType::MintCloseAuthority,
2363        ])
2364        .unwrap()
2365            - buffer.len();
2366        buffer.append(&mut vec![0; needed_len]);
2367        let err = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap_err();
2368        assert_eq!(err, ProgramError::InvalidAccountData);
2369        set_account_type::<PodMint>(&mut buffer).unwrap();
2370        // unpack is viable after manual set_account_type
2371        let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
2372        assert_eq!(state.base, &TEST_POD_MINT);
2373        assert_eq!(state.account_type[0], AccountType::Mint as u8);
2374        state.init_extension::<MintCloseAuthority>(true).unwrap();
2375
2376        // mint with buffer big enough for AccountType only
2377        let mut buffer = TEST_MINT_SLICE.to_vec();
2378        buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
2379        buffer.append(&mut vec![0; 2]);
2380        let err = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap_err();
2381        assert_eq!(err, ProgramError::InvalidAccountData);
2382        set_account_type::<PodMint>(&mut buffer).unwrap();
2383        // unpack is viable after manual set_account_type
2384        let state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
2385        assert_eq!(state.base, &TEST_POD_MINT);
2386        assert_eq!(state.account_type[0], AccountType::Mint as u8);
2387
2388        // mint with AccountType already set => noop
2389        let mut buffer = TEST_MINT_SLICE.to_vec();
2390        buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
2391        buffer.append(&mut vec![1, 0]);
2392        set_account_type::<PodMint>(&mut buffer).unwrap();
2393        let state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
2394        assert_eq!(state.base, &TEST_POD_MINT);
2395        assert_eq!(state.account_type[0], AccountType::Mint as u8);
2396
2397        // mint with wrong AccountType fails
2398        let mut buffer = TEST_MINT_SLICE.to_vec();
2399        buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
2400        buffer.append(&mut vec![2, 0]);
2401        let err = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap_err();
2402        assert_eq!(err, ProgramError::InvalidAccountData);
2403        let err = set_account_type::<PodMint>(&mut buffer).unwrap_err();
2404        assert_eq!(err, ProgramError::InvalidAccountData);
2405    }
2406
2407    #[test]
2408    fn test_set_account_type_wrongly() {
2409        // try to set PodAccount account_type to PodMint
2410        let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
2411        buffer.append(&mut vec![0; 2]);
2412        let err = set_account_type::<PodMint>(&mut buffer).unwrap_err();
2413        assert_eq!(err, ProgramError::InvalidAccountData);
2414
2415        // try to set PodMint account_type to PodAccount
2416        let mut buffer = TEST_MINT_SLICE.to_vec();
2417        buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
2418        buffer.append(&mut vec![0; 2]);
2419        let err = set_account_type::<PodAccount>(&mut buffer).unwrap_err();
2420        assert_eq!(err, ProgramError::InvalidAccountData);
2421    }
2422
2423    #[test]
2424    fn test_get_required_init_account_extensions() {
2425        // Some mint extensions with no required account extensions
2426        let mint_extensions = vec![
2427            ExtensionType::MintCloseAuthority,
2428            ExtensionType::Uninitialized,
2429        ];
2430        assert_eq!(
2431            ExtensionType::get_required_init_account_extensions(&mint_extensions),
2432            vec![]
2433        );
2434
2435        // One mint extension with required account extension, one without
2436        let mint_extensions = vec![
2437            ExtensionType::TransferFeeConfig,
2438            ExtensionType::MintCloseAuthority,
2439        ];
2440        assert_eq!(
2441            ExtensionType::get_required_init_account_extensions(&mint_extensions),
2442            vec![ExtensionType::TransferFeeAmount]
2443        );
2444
2445        // Some mint extensions both with required account extensions
2446        let mint_extensions = vec![
2447            ExtensionType::TransferFeeConfig,
2448            ExtensionType::MintPaddingTest,
2449        ];
2450        assert_eq!(
2451            ExtensionType::get_required_init_account_extensions(&mint_extensions),
2452            vec![
2453                ExtensionType::TransferFeeAmount,
2454                ExtensionType::AccountPaddingTest
2455            ]
2456        );
2457
2458        // Demonstrate that method does not dedupe inputs or outputs
2459        let mint_extensions = vec![
2460            ExtensionType::TransferFeeConfig,
2461            ExtensionType::TransferFeeConfig,
2462        ];
2463        assert_eq!(
2464            ExtensionType::get_required_init_account_extensions(&mint_extensions),
2465            vec![
2466                ExtensionType::TransferFeeAmount,
2467                ExtensionType::TransferFeeAmount
2468            ]
2469        );
2470    }
2471
2472    #[test]
2473    fn mint_without_extensions() {
2474        let space = ExtensionType::try_calculate_account_len::<PodMint>(&[]).unwrap();
2475        let mut buffer = vec![0; space];
2476        assert_eq!(
2477            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
2478            Err(ProgramError::InvalidAccountData),
2479        );
2480
2481        // write base account
2482        let mut state =
2483            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2484        *state.base = TEST_POD_MINT;
2485        state.init_account_type().unwrap();
2486
2487        // fail init extension
2488        assert_eq!(
2489            state.init_extension::<TransferFeeConfig>(true),
2490            Err(ProgramError::InvalidAccountData),
2491        );
2492
2493        assert_eq!(TEST_MINT_SLICE, buffer);
2494    }
2495
2496    #[test]
2497    fn test_init_nonzero_default() {
2498        let mint_size =
2499            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MintPaddingTest])
2500                .unwrap();
2501        let mut buffer = vec![0; mint_size];
2502        let mut state =
2503            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2504        *state.base = TEST_POD_MINT;
2505        state.init_account_type().unwrap();
2506        let extension = state.init_extension::<MintPaddingTest>(true).unwrap();
2507        assert_eq!(extension.padding1, [1; 128]);
2508        assert_eq!(extension.padding2, [2; 48]);
2509        assert_eq!(extension.padding3, [3; 9]);
2510    }
2511
2512    #[test]
2513    fn test_init_buffer_too_small() {
2514        let mint_size = ExtensionType::try_calculate_account_len::<PodMint>(&[
2515            ExtensionType::MintCloseAuthority,
2516        ])
2517        .unwrap();
2518        let mut buffer = vec![0; mint_size - 1];
2519        let mut state =
2520            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2521        let err = state
2522            .init_extension::<MintCloseAuthority>(true)
2523            .unwrap_err();
2524        assert_eq!(err, ProgramError::InvalidAccountData);
2525
2526        state.tlv_data[0] = 3;
2527        state.tlv_data[2] = 32;
2528        let err = state.get_extension_mut::<MintCloseAuthority>().unwrap_err();
2529        assert_eq!(err, ProgramError::InvalidAccountData);
2530
2531        let mut buffer = vec![0; PodMint::SIZE_OF + 2];
2532        let err =
2533            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap_err();
2534        assert_eq!(err, ProgramError::InvalidAccountData);
2535
2536        // OK since there are two bytes for the type, which is `Uninitialized`
2537        let mut buffer = vec![0; BASE_ACCOUNT_LENGTH + 3];
2538        let mut state =
2539            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2540        let err = state.get_extension_mut::<MintCloseAuthority>().unwrap_err();
2541        assert_eq!(err, ProgramError::InvalidAccountData);
2542
2543        assert_eq!(state.get_extension_types().unwrap(), vec![]);
2544
2545        // OK, there aren't two bytes for the type, but that's fine
2546        let mut buffer = vec![0; BASE_ACCOUNT_LENGTH + 2];
2547        let state =
2548            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2549        assert_eq!(state.get_extension_types().unwrap(), []);
2550    }
2551
2552    #[test]
2553    fn test_extension_with_no_data() {
2554        let account_size = ExtensionType::try_calculate_account_len::<PodAccount>(&[
2555            ExtensionType::ImmutableOwner,
2556        ])
2557        .unwrap();
2558        let mut buffer = vec![0; account_size];
2559        let mut state =
2560            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
2561        *state.base = TEST_POD_ACCOUNT;
2562        state.init_account_type().unwrap();
2563
2564        let err = state.get_extension::<ImmutableOwner>().unwrap_err();
2565        assert_eq!(
2566            err,
2567            ProgramError::Custom(TokenError::ExtensionNotFound as u32)
2568        );
2569
2570        state.init_extension::<ImmutableOwner>(true).unwrap();
2571        assert_eq!(
2572            get_first_extension_type(state.tlv_data).unwrap(),
2573            Some(ExtensionType::ImmutableOwner)
2574        );
2575        assert_eq!(
2576            get_tlv_data_info(state.tlv_data).unwrap(),
2577            TlvDataInfo {
2578                extension_types: vec![ExtensionType::ImmutableOwner],
2579                used_len: add_type_and_length_to_len(0)
2580            }
2581        );
2582    }
2583
2584    #[test]
2585    fn fail_account_len_with_metadata() {
2586        assert_eq!(
2587            ExtensionType::try_calculate_account_len::<PodMint>(&[
2588                ExtensionType::MintCloseAuthority,
2589                ExtensionType::VariableLenMintTest,
2590                ExtensionType::TransferFeeConfig,
2591            ])
2592            .unwrap_err(),
2593            ProgramError::InvalidArgument
2594        );
2595    }
2596
    /// Allocation rules for a variable-length extension in a buffer sized
    /// exactly for one `VariableLenMintTest { data: vec![1] }`.
    ///
    /// Checks that double-allocation is rejected unless `overwrite` is set, and
    /// that overwriting with a different length is rejected.
    #[test]
    fn alloc() {
        let variable_len = VariableLenMintTest { data: vec![1] };
        let alloc_size = variable_len.get_packed_len().unwrap();
        // base + account type + one TLV entry, with no slack
        let account_size =
            BASE_ACCOUNT_LENGTH + size_of::<AccountType>() + add_type_and_length_to_len(alloc_size);
        let mut buffer = vec![0; account_size];
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        state
            .init_variable_len_extension(&variable_len, false)
            .unwrap();

        // can't double alloc
        assert_eq!(
            state
                .init_variable_len_extension(&variable_len, false)
                .unwrap_err(),
            TokenError::ExtensionAlreadyInitialized.into()
        );

        // unless overwrite is set
        state
            .init_variable_len_extension(&variable_len, true)
            .unwrap();

        // can't change the size during overwrite though (shrinking here)
        assert_eq!(
            state
                .init_variable_len_extension(&VariableLenMintTest { data: vec![] }, true)
                .unwrap_err(),
            TokenError::InvalidLengthForAlloc.into()
        );

        // try to write too far, fail earlier: growing past the end of the
        // buffer is reported as `InvalidAccountData`, not `InvalidLengthForAlloc`
        assert_eq!(
            state
                .init_variable_len_extension(&VariableLenMintTest { data: vec![1, 2] }, true)
                .unwrap_err(),
            ProgramError::InvalidAccountData
        );
    }
2639
    /// Reallocating a variable-length extension that sits before a fixed-size
    /// one (`MetadataPointer`) keeps the fixed-size extension intact.
    ///
    /// Growing and shrinking are both checked, and freed trailing bytes are
    /// zeroed. Growing past the end of the buffer fails.
    #[test]
    fn realloc() {
        let small_variable_len = VariableLenMintTest {
            data: vec![1, 2, 3],
        };
        let base_variable_len = VariableLenMintTest {
            data: vec![1, 2, 3, 4],
        };
        let big_variable_len = VariableLenMintTest {
            data: vec![1, 2, 3, 4, 5],
        };
        let too_big_variable_len = VariableLenMintTest {
            data: vec![1, 2, 3, 4, 5, 6],
        };
        // room for `MetadataPointer` plus the largest variable value that should fit
        let account_size =
            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MetadataPointer])
                .unwrap()
                + add_type_and_length_to_len(big_variable_len.get_packed_len().unwrap());
        let mut buffer = vec![0; account_size];
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();

        // alloc both types: variable-length first, so realloc must move the pointer
        state
            .init_variable_len_extension(&base_variable_len, false)
            .unwrap();
        // all-0xff keys make any corruption of the pointer extension easy to spot
        let max_pubkey =
            OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([255; 32]))).unwrap();
        let extension = state.init_extension::<MetadataPointer>(false).unwrap();
        extension.authority = max_pubkey;
        extension.metadata_address = max_pubkey;

        // realloc first entry to larger
        state
            .realloc_variable_len_extension(&big_variable_len)
            .unwrap();
        let extension = state
            .get_variable_len_extension::<VariableLenMintTest>()
            .unwrap();
        assert_eq!(extension, big_variable_len);
        let extension = state.get_extension::<MetadataPointer>().unwrap();
        assert_eq!(extension.authority, max_pubkey);
        assert_eq!(extension.metadata_address, max_pubkey);

        // realloc to smaller
        state
            .realloc_variable_len_extension(&small_variable_len)
            .unwrap();
        let extension = state
            .get_variable_len_extension::<VariableLenMintTest>()
            .unwrap();
        assert_eq!(extension, small_variable_len);
        let extension = state.get_extension::<MetadataPointer>().unwrap();
        assert_eq!(extension.authority, max_pubkey);
        assert_eq!(extension.metadata_address, max_pubkey);
        // the bytes freed at the end of the buffer by shrinking must be zeroed
        let diff = big_variable_len.get_packed_len().unwrap()
            - small_variable_len.get_packed_len().unwrap();
        assert_eq!(&buffer[account_size - diff..account_size], vec![0; diff]);

        // unpack again since we dropped the last `state`
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        // realloc too much, fails
        assert_eq!(
            state
                .realloc_variable_len_extension(&too_big_variable_len)
                .unwrap_err(),
            ProgramError::InvalidAccountData,
        );
    }
2710
    /// `try_get_account_len` and
    /// `try_get_new_account_len_for_variable_len_extension` agree with the
    /// TLV layout.
    ///
    /// The reported length grows from the bare mint size once an extension is
    /// added. Replacing the value with a one-byte-shorter, one-byte-longer or
    /// same-length value changes the predicted length by exactly that amount.
    #[test]
    fn account_len() {
        let small_variable_len = VariableLenMintTest {
            data: vec![20, 30, 40],
        };
        let variable_len = VariableLenMintTest {
            data: vec![20, 30, 40, 50],
        };
        let big_variable_len = VariableLenMintTest {
            data: vec![20, 30, 40, 50, 60],
        };
        let value_len = variable_len.get_packed_len().unwrap();
        let account_size =
            BASE_ACCOUNT_LENGTH + size_of::<AccountType>() + add_type_and_length_to_len(value_len);
        let mut buffer = vec![0; account_size];
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();

        // With a new extension, new length must include padding (from the mint
        // size up to the base account length), 1 byte for account type, 2 bytes
        // for type, 2 for length
        let current_len = state.try_get_account_len().unwrap();
        // no extension yet, so only the bare mint size is reported
        assert_eq!(current_len, PodMint::SIZE_OF);
        let new_len = state
            .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(
                &variable_len,
            )
            .unwrap();
        assert_eq!(
            new_len,
            BASE_ACCOUNT_AND_TYPE_LENGTH.saturating_add(add_type_and_length_to_len(value_len))
        );

        state
            .init_variable_len_extension::<VariableLenMintTest>(&variable_len, false)
            .unwrap();
        // the prediction matches the length after actually allocating
        let current_len = state.try_get_account_len().unwrap();
        assert_eq!(current_len, new_len);

        // Reduce the extension size
        let new_len = state
            .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(
                &small_variable_len,
            )
            .unwrap();
        assert_eq!(current_len.checked_sub(new_len).unwrap(), 1);

        // Increase the extension size
        let new_len = state
            .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(
                &big_variable_len,
            )
            .unwrap();
        assert_eq!(new_len.checked_sub(current_len).unwrap(), 1);

        // Maintain the extension size
        let new_len = state
            .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(
                &variable_len,
            )
            .unwrap();
        assert_eq!(new_len, current_len);
    }
2773
2774    /// Test helper for mimicking the data layout an on-chain `AccountInfo`,
2775    /// which permits "reallocs" as the Solana runtime does it
2776    struct SolanaAccountData {
2777        data: Vec<u8>,
2778        lamports: u64,
2779        owner: Pubkey,
2780    }
2781    impl SolanaAccountData {
2782        /// Create a new fake solana account data. The underlying vector is
2783        /// overallocated to mimic the runtime
2784        fn new(account_data: &[u8]) -> Self {
2785            let mut data = vec![];
2786            data.extend_from_slice(&(account_data.len() as u64).to_le_bytes());
2787            data.extend_from_slice(account_data);
2788            data.extend_from_slice(&[0; MAX_PERMITTED_DATA_INCREASE]);
2789            Self {
2790                data,
2791                lamports: 10,
2792                owner: Pubkey::new_unique(),
2793            }
2794        }
2795
2796        /// Data lops off the first 8 bytes, since those store the size of the
2797        /// account for the Solana runtime
2798        fn data(&self) -> &[u8] {
2799            let start = size_of::<u64>();
2800            let len = self.len();
2801            &self.data[start..start + len]
2802        }
2803
2804        /// Gets the runtime length of the account data
2805        fn len(&self) -> usize {
2806            self.data
2807                .get(..size_of::<u64>())
2808                .and_then(|slice| slice.try_into().ok())
2809                .map(u64::from_le_bytes)
2810                .unwrap() as usize
2811        }
2812    }
2813    impl GetAccount for SolanaAccountData {
2814        fn get(&mut self) -> (&mut u64, &mut [u8], &Pubkey, bool, Epoch) {
2815            // need to pull out the data here to avoid a double-mutable borrow
2816            let start = size_of::<u64>();
2817            let len = self.len();
2818            (
2819                &mut self.lamports,
2820                &mut self.data[start..start + len],
2821                &self.owner,
2822                false,
2823                Epoch::default(),
2824            )
2825        }
2826    }
2827
    /// `alloc_and_serialize` can grow a bare mint inside an `AccountInfo` to
    /// hold a new fixed-length extension.
    ///
    /// The account is resized to exactly base + type byte + one TLV entry. A
    /// second call succeeds only with `overwrite` set.
    #[test]
    fn alloc_new_fixed_len_tlv_in_account_info_from_base_size() {
        let fixed_len = FixedLenMintTest {
            data: [1, 2, 3, 4, 5, 6, 7, 8],
        };
        let value_len = pod_get_packed_len::<FixedLenMintTest>();
        // start from a mint with no room for an account type or extensions
        let base_account_size = PodMint::SIZE_OF;
        let mut buffer = vec![0; base_account_size];
        let state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        *state.base = TEST_POD_MINT;

        let mut data = SolanaAccountData::new(&buffer);
        let key = Pubkey::new_unique();
        let account_info = (&key, &mut data).into_account_info();

        alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, false).unwrap();
        // the runtime length prefix was updated to the grown size
        let new_account_len = BASE_ACCOUNT_AND_TYPE_LENGTH + add_type_and_length_to_len(value_len);
        assert_eq!(data.len(), new_account_len);
        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
        assert_eq!(
            state.get_extension::<FixedLenMintTest>().unwrap(),
            &fixed_len,
        );

        // alloc again succeeds with "overwrite"
        let account_info = (&key, &mut data).into_account_info();
        alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, true).unwrap();

        // alloc again fails without "overwrite"
        let account_info = (&key, &mut data).into_account_info();
        assert_eq!(
            alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, false).unwrap_err(),
            TokenError::ExtensionAlreadyInitialized.into()
        );
    }
2864
    /// `alloc_and_serialize_variable_len_extension` can grow a bare mint inside
    /// an `AccountInfo` to hold a new variable-length extension.
    ///
    /// The account is resized to exactly base + type byte + one TLV entry. A
    /// second call succeeds only with `overwrite` set.
    #[test]
    fn alloc_new_variable_len_tlv_in_account_info_from_base_size() {
        let variable_len = VariableLenMintTest { data: vec![20, 99] };
        let value_len = variable_len.get_packed_len().unwrap();
        // start from a mint with no room for an account type or extensions
        let base_account_size = PodMint::SIZE_OF;
        let mut buffer = vec![0; base_account_size];
        let state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        *state.base = TEST_POD_MINT;

        let mut data = SolanaAccountData::new(&buffer);
        let key = Pubkey::new_unique();
        let account_info = (&key, &mut data).into_account_info();

        alloc_and_serialize_variable_len_extension::<PodMint, _>(
            &account_info,
            &variable_len,
            false,
        )
        .unwrap();
        // the runtime length prefix was updated to the grown size
        let new_account_len = BASE_ACCOUNT_AND_TYPE_LENGTH + add_type_and_length_to_len(value_len);
        assert_eq!(data.len(), new_account_len);
        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
        assert_eq!(
            state
                .get_variable_len_extension::<VariableLenMintTest>()
                .unwrap(),
            variable_len
        );

        // alloc again succeeds with "overwrite"
        let account_info = (&key, &mut data).into_account_info();
        alloc_and_serialize_variable_len_extension::<PodMint, _>(
            &account_info,
            &variable_len,
            true,
        )
        .unwrap();

        // alloc again fails without "overwrite"
        let account_info = (&key, &mut data).into_account_info();
        assert_eq!(
            alloc_and_serialize_variable_len_extension::<PodMint, _>(
                &account_info,
                &variable_len,
                false,
            )
            .unwrap_err(),
            TokenError::ExtensionAlreadyInitialized.into()
        );
    }
2916
    /// `alloc_and_serialize` appends a fixed-length extension to a mint that
    /// already carries a `GroupPointer`, leaving the existing extension's data
    /// unchanged. A second call succeeds only with `overwrite` set.
    #[test]
    fn alloc_new_fixed_len_tlv_in_account_info_from_extended_size() {
        let fixed_len = FixedLenMintTest {
            data: [1, 2, 3, 4, 5, 6, 7, 8],
        };
        let value_len = pod_get_packed_len::<FixedLenMintTest>();
        // sized for `GroupPointer` plus one more fixed-length TLV entry
        let account_size =
            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::GroupPointer])
                .unwrap()
                + add_type_and_length_to_len(value_len);
        let mut buffer = vec![0; account_size];
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        *state.base = TEST_POD_MINT;
        state.init_account_type().unwrap();

        // distinctive key bytes so the pointer's survival is easy to verify
        let test_key =
            OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([20; 32]))).unwrap();
        let extension = state.init_extension::<GroupPointer>(false).unwrap();
        extension.authority = test_key;
        extension.group_address = test_key;

        let mut data = SolanaAccountData::new(&buffer);
        let key = Pubkey::new_unique();
        let account_info = (&key, &mut data).into_account_info();

        alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, false).unwrap();
        // both TLV entries are accounted for in the new runtime length
        let new_account_len = BASE_ACCOUNT_AND_TYPE_LENGTH
            + add_type_and_length_to_len(value_len)
            + add_type_and_length_to_len(size_of::<GroupPointer>());
        assert_eq!(data.len(), new_account_len);
        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
        assert_eq!(
            state.get_extension::<FixedLenMintTest>().unwrap(),
            &fixed_len,
        );
        let extension = state.get_extension::<GroupPointer>().unwrap();
        assert_eq!(extension.authority, test_key);
        assert_eq!(extension.group_address, test_key);

        // alloc again succeeds with "overwrite"
        let account_info = (&key, &mut data).into_account_info();
        alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, true).unwrap();

        // alloc again fails without "overwrite"
        let account_info = (&key, &mut data).into_account_info();
        assert_eq!(
            alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, false).unwrap_err(),
            TokenError::ExtensionAlreadyInitialized.into()
        );
    }
2968
2969    #[test]
2970    fn alloc_new_variable_len_tlv_in_account_info_from_extended_size() {
2971        let variable_len = VariableLenMintTest { data: vec![42, 6] };
2972        let value_len = variable_len.get_packed_len().unwrap();
2973        let account_size =
2974            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MetadataPointer])
2975                .unwrap()
2976                + add_type_and_length_to_len(value_len);
2977        let mut buffer = vec![0; account_size];
2978        let mut state =
2979            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2980        *state.base = TEST_POD_MINT;
2981        state.init_account_type().unwrap();
2982
2983        let test_key =
2984            OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([20; 32]))).unwrap();
2985        let extension = state.init_extension::<MetadataPointer>(false).unwrap();
2986        extension.authority = test_key;
2987        extension.metadata_address = test_key;
2988
2989        let mut data = SolanaAccountData::new(&buffer);
2990        let key = Pubkey::new_unique();
2991        let account_info = (&key, &mut data).into_account_info();
2992
2993        alloc_and_serialize_variable_len_extension::<PodMint, _>(
2994            &account_info,
2995            &variable_len,
2996            false,
2997        )
2998        .unwrap();
2999        let new_account_len = BASE_ACCOUNT_AND_TYPE_LENGTH
3000            + add_type_and_length_to_len(value_len)
3001            + add_type_and_length_to_len(size_of::<MetadataPointer>());
3002        assert_eq!(data.len(), new_account_len);
3003        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
3004        assert_eq!(
3005            state
3006                .get_variable_len_extension::<VariableLenMintTest>()
3007                .unwrap(),
3008            variable_len
3009        );
3010        let extension = state.get_extension::<MetadataPointer>().unwrap();
3011        assert_eq!(extension.authority, test_key);
3012        assert_eq!(extension.metadata_address, test_key);
3013
3014        // alloc again succeeds with "overwrite"
3015        let account_info = (&key, &mut data).into_account_info();
3016        alloc_and_serialize_variable_len_extension::<PodMint, _>(
3017            &account_info,
3018            &variable_len,
3019            true,
3020        )
3021        .unwrap();
3022
3023        // alloc again fails without "overwrite"
3024        let account_info = (&key, &mut data).into_account_info();
3025        assert_eq!(
3026            alloc_and_serialize_variable_len_extension::<PodMint, _>(
3027                &account_info,
3028                &variable_len,
3029                false,
3030            )
3031            .unwrap_err(),
3032            TokenError::ExtensionAlreadyInitialized.into()
3033        );
3034    }
3035
3036    #[test]
3037    fn realloc_variable_len_tlv_in_account_info() {
3038        let variable_len = VariableLenMintTest {
3039            data: vec![1, 2, 3, 4, 5],
3040        };
3041        let alloc_size = variable_len.get_packed_len().unwrap();
3042        let account_size =
3043            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MetadataPointer])
3044                .unwrap()
3045                + add_type_and_length_to_len(alloc_size);
3046        let mut buffer = vec![0; account_size];
3047        let mut state =
3048            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
3049        *state.base = TEST_POD_MINT;
3050        state.init_account_type().unwrap();
3051
3052        // alloc both types
3053        state
3054            .init_variable_len_extension(&variable_len, false)
3055            .unwrap();
3056        let max_pubkey =
3057            OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([255; 32]))).unwrap();
3058        let extension = state.init_extension::<MetadataPointer>(false).unwrap();
3059        extension.authority = max_pubkey;
3060        extension.metadata_address = max_pubkey;
3061
3062        // reallocate to smaller, make sure existing extension is fine
3063        let mut data = SolanaAccountData::new(&buffer);
3064        let key = Pubkey::new_unique();
3065        let account_info = (&key, &mut data).into_account_info();
3066        let variable_len = VariableLenMintTest { data: vec![1, 2] };
3067        alloc_and_serialize_variable_len_extension::<PodMint, _>(
3068            &account_info,
3069            &variable_len,
3070            true,
3071        )
3072        .unwrap();
3073
3074        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
3075        let extension = state.get_extension::<MetadataPointer>().unwrap();
3076        assert_eq!(extension.authority, max_pubkey);
3077        assert_eq!(extension.metadata_address, max_pubkey);
3078        let extension = state
3079            .get_variable_len_extension::<VariableLenMintTest>()
3080            .unwrap();
3081        assert_eq!(extension, variable_len);
3082        assert_eq!(data.len(), state.try_get_account_len().unwrap());
3083
3084        // reallocate to larger
3085        let account_info = (&key, &mut data).into_account_info();
3086        let variable_len = VariableLenMintTest {
3087            data: vec![1, 2, 3, 4, 5, 6, 7],
3088        };
3089        alloc_and_serialize_variable_len_extension::<PodMint, _>(
3090            &account_info,
3091            &variable_len,
3092            true,
3093        )
3094        .unwrap();
3095
3096        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
3097        let extension = state.get_extension::<MetadataPointer>().unwrap();
3098        assert_eq!(extension.authority, max_pubkey);
3099        assert_eq!(extension.metadata_address, max_pubkey);
3100        let extension = state
3101            .get_variable_len_extension::<VariableLenMintTest>()
3102            .unwrap();
3103        assert_eq!(extension, variable_len);
3104        assert_eq!(data.len(), state.try_get_account_len().unwrap());
3105
3106        // reallocate to same
3107        let account_info = (&key, &mut data).into_account_info();
3108        let variable_len = VariableLenMintTest {
3109            data: vec![7, 6, 5, 4, 3, 2, 1],
3110        };
3111        alloc_and_serialize_variable_len_extension::<PodMint, _>(
3112            &account_info,
3113            &variable_len,
3114            true,
3115        )
3116        .unwrap();
3117
3118        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
3119        let extension = state.get_extension::<MetadataPointer>().unwrap();
3120        assert_eq!(extension.authority, max_pubkey);
3121        assert_eq!(extension.metadata_address, max_pubkey);
3122        let extension = state
3123            .get_variable_len_extension::<VariableLenMintTest>()
3124            .unwrap();
3125        assert_eq!(extension, variable_len);
3126        assert_eq!(data.len(), state.try_get_account_len().unwrap());
3127    }
3128}