Skip to main content

spl_token_2022_interface/extension/
mod.rs

1//! Extensions available to token mints and accounts
2
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5use {
6    crate::{
7        error::TokenError,
8        extension::{
9            account_len::TlvLenAccumulator,
10            confidential_mint_burn::ConfidentialMintBurn,
11            confidential_transfer::{ConfidentialTransferAccount, ConfidentialTransferMint},
12            confidential_transfer_fee::{
13                ConfidentialTransferFeeAmount, ConfidentialTransferFeeConfig,
14            },
15            cpi_guard::CpiGuard,
16            default_account_state::DefaultAccountState,
17            group_member_pointer::GroupMemberPointer,
18            group_pointer::GroupPointer,
19            immutable_owner::ImmutableOwner,
20            interest_bearing_mint::InterestBearingConfig,
21            memo_transfer::MemoTransfer,
22            metadata_pointer::MetadataPointer,
23            mint_close_authority::MintCloseAuthority,
24            non_transferable::{NonTransferable, NonTransferableAccount},
25            pausable::{PausableAccount, PausableConfig},
26            permanent_delegate::PermanentDelegate,
27            permissioned_burn::PermissionedBurnConfig,
28            scaled_ui_amount::ScaledUiAmountConfig,
29            transfer_fee::{TransferFeeAmount, TransferFeeConfig},
30            transfer_hook::{TransferHook, TransferHookAccount},
31        },
32        pod::{PodAccount, PodMint},
33        state::{Account, Mint, Multisig, PackedSizeOf},
34    },
35    alloc::{vec, vec::Vec},
36    bytemuck::{Pod, Zeroable},
37    core::{
38        cmp::Ordering,
39        convert::{TryFrom, TryInto},
40        mem::size_of,
41    },
42    num_enum::{IntoPrimitive, TryFromPrimitive},
43    solana_account_info::AccountInfo,
44    solana_program_error::ProgramError,
45    solana_program_pack::{IsInitialized, Pack},
46    solana_zero_copy::unaligned::U16,
47    spl_token_group_interface::state::{TokenGroup, TokenGroupMember},
48    spl_type_length_value::variable_len_pack::VariableLenPack,
49};
50
51/// Account length calculation helpers
52pub mod account_len;
53/// Confidential Transfer extension
54pub mod confidential_transfer;
55/// Confidential Transfer Fee extension
56pub mod confidential_transfer_fee;
57/// CPI Guard extension
58pub mod cpi_guard;
59/// Default Account State extension
60pub mod default_account_state;
61/// Group Member Pointer extension
62pub mod group_member_pointer;
63/// Group Pointer extension
64pub mod group_pointer;
65/// Immutable Owner extension
66pub mod immutable_owner;
67/// Interest-Bearing Mint extension
68pub mod interest_bearing_mint;
69/// Memo Transfer extension
70pub mod memo_transfer;
71/// Metadata Pointer extension
72pub mod metadata_pointer;
73/// Mint Close Authority extension
74pub mod mint_close_authority;
75/// Non Transferable extension
76pub mod non_transferable;
77/// Pausable extension
78pub mod pausable;
79/// Permanent Delegate extension
80pub mod permanent_delegate;
81/// Permissioned burn extension
82pub mod permissioned_burn;
83/// Scaled UI Amount extension
84pub mod scaled_ui_amount;
85/// Token-group extension
86pub mod token_group;
87/// Token-metadata extension
88pub mod token_metadata;
89/// Transfer Fee extension
90pub mod transfer_fee;
91/// Transfer Hook extension
92pub mod transfer_hook;
93
94/// Confidential mint-burn extension
95pub mod confidential_mint_burn;
96
97/// Length in TLV structure
98#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
99#[repr(transparent)]
100pub struct Length(U16);
101impl From<Length> for usize {
102    fn from(n: Length) -> Self {
103        Self::from(u16::from(n.0))
104    }
105}
106impl TryFrom<usize> for Length {
107    type Error = ProgramError;
108    fn try_from(n: usize) -> Result<Self, Self::Error> {
109        u16::try_from(n)
110            .map(|v| Self(U16::from(v)))
111            .map_err(|_| ProgramError::AccountDataTooSmall)
112    }
113}
114
115/// Helper function to get the current `TlvIndices` from the current spot
116fn get_tlv_indices(type_start: usize) -> TlvIndices {
117    let length_start = type_start.saturating_add(size_of::<ExtensionType>());
118    let value_start = length_start.saturating_add(size_of::<Length>());
119    TlvIndices {
120        type_start,
121        length_start,
122        value_start,
123    }
124}
125
126/// Helper function to tack on the size of an extension bytes if an account with
127/// extensions is exactly the size of a multisig
128const fn adjust_len_for_multisig(account_len: usize) -> usize {
129    if account_len == Multisig::LEN {
130        account_len.saturating_add(size_of::<ExtensionType>())
131    } else {
132        account_len
133    }
134}
135
136/// Helper function to calculate exactly how many bytes a value will take up,
137/// given the value's length
138const fn add_type_and_length_to_len(value_len: usize) -> usize {
139    value_len
140        .saturating_add(size_of::<ExtensionType>())
141        .saturating_add(size_of::<Length>())
142}
143
/// Byte offsets of the three fields of a single TLV entry within a TLV buffer.
#[derive(Debug)]
struct TlvIndices {
    /// Offset of the extension-type field.
    pub type_start: usize,
    /// Offset of the length field (immediately after the type field).
    pub length_start: usize,
    /// Offset of the value bytes (immediately after the length field).
    pub value_start: usize,
}
152fn get_extension_indices<V: Extension>(
153    tlv_data: &[u8],
154    init: bool,
155) -> Result<TlvIndices, ProgramError> {
156    let mut start_index = 0;
157    while start_index < tlv_data.len() {
158        let tlv_indices = get_tlv_indices(start_index);
159        if tlv_data.len() < tlv_indices.value_start {
160            return Err(ProgramError::InvalidAccountData);
161        }
162        let extension_type = u16::from_le_bytes(
163            tlv_data[tlv_indices.type_start..tlv_indices.length_start]
164                .try_into()
165                .map_err(|_| ProgramError::InvalidAccountData)?,
166        );
167        if extension_type == u16::from(V::TYPE) {
168            // found an instance of the extension that we're initializing, return!
169            return Ok(tlv_indices);
170        // got to an empty spot, init here, or error if we're searching, since
171        // nothing is written after an Uninitialized spot
172        } else if extension_type == u16::from(ExtensionType::Uninitialized) {
173            if init {
174                return Ok(tlv_indices);
175            } else {
176                return Err(TokenError::ExtensionNotFound.into());
177            }
178        } else {
179            let length = bytemuck::try_from_bytes::<Length>(
180                &tlv_data[tlv_indices.length_start..tlv_indices.value_start],
181            )
182            .map_err(|_| ProgramError::InvalidArgument)?;
183            let value_end_index = tlv_indices.value_start.saturating_add(usize::from(*length));
184            start_index = value_end_index;
185        }
186    }
187    Err(ProgramError::InvalidAccountData)
188}
189
190/// Basic information about the TLV buffer, collected from iterating through all
191/// entries
192#[derive(Debug, PartialEq)]
193struct TlvDataInfo {
194    /// The extension types written in the TLV buffer
195    extension_types: Vec<ExtensionType>,
196    /// The total number bytes allocated for all TLV entries.
197    ///
198    /// Each TLV entry's allocated bytes comprises two bytes for the `type`, two
199    /// bytes for the `length`, and `length` number of bytes for the `value`.
200    used_len: usize,
201}
202
203/// Iterates through all TLV entries, passing each initialized extension type
204/// to `f`, and returns the total number of bytes used by all TLV entries.
205fn try_for_each_tlv_extension_type<F>(tlv_data: &[u8], mut f: F) -> Result<usize, ProgramError>
206where
207    F: FnMut(ExtensionType) -> Result<(), ProgramError>,
208{
209    let mut start_index = 0;
210    while start_index < tlv_data.len() {
211        let tlv_indices = get_tlv_indices(start_index);
212        if tlv_data.len() < tlv_indices.length_start {
213            // There aren't enough bytes to store the next type, which means we
214            // got to the end. The last byte could be used during a realloc!
215            return Ok(tlv_indices.type_start);
216        }
217        let extension_type =
218            ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
219        if extension_type == ExtensionType::Uninitialized {
220            return Ok(tlv_indices.type_start);
221        } else {
222            if tlv_data.len() < tlv_indices.value_start {
223                // not enough bytes to store the length, malformed
224                return Err(ProgramError::InvalidAccountData);
225            }
226            let length = bytemuck::try_from_bytes::<Length>(
227                &tlv_data[tlv_indices.length_start..tlv_indices.value_start],
228            )
229            .map_err(|_| ProgramError::InvalidArgument)?;
230
231            let value_end_index = tlv_indices.value_start.saturating_add(usize::from(*length));
232            if value_end_index > tlv_data.len() {
233                // value blows past the size of the slice, malformed
234                return Err(ProgramError::InvalidAccountData);
235            }
236            f(extension_type)?;
237            start_index = value_end_index;
238        }
239    }
240    Ok(start_index)
241}
242
243/// Fetches basic information about the TLV buffer by iterating through all
244/// TLV entries.
245fn get_tlv_data_info(tlv_data: &[u8]) -> Result<TlvDataInfo, ProgramError> {
246    let mut extension_types = vec![];
247    let used_len = try_for_each_tlv_extension_type(tlv_data, |extension_type| {
248        extension_types.push(extension_type);
249        Ok(())
250    })?;
251    Ok(TlvDataInfo {
252        extension_types,
253        used_len,
254    })
255}
256
257fn get_first_extension_type(tlv_data: &[u8]) -> Result<Option<ExtensionType>, ProgramError> {
258    if tlv_data.is_empty() {
259        Ok(None)
260    } else {
261        let tlv_indices = get_tlv_indices(0);
262        if tlv_data.len() <= tlv_indices.length_start {
263            return Ok(None);
264        }
265        let extension_type =
266            ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
267        if extension_type == ExtensionType::Uninitialized {
268            Ok(None)
269        } else {
270            Ok(Some(extension_type))
271        }
272    }
273}
274
275fn check_min_len_and_not_multisig(input: &[u8], minimum_len: usize) -> Result<(), ProgramError> {
276    if input.len() == Multisig::LEN || input.len() < minimum_len {
277        Err(ProgramError::InvalidAccountData)
278    } else {
279        Ok(())
280    }
281}
282
283fn check_account_type<S: BaseState>(account_type: AccountType) -> Result<(), ProgramError> {
284    if account_type != S::ACCOUNT_TYPE {
285        Err(ProgramError::InvalidAccountData)
286    } else {
287        Ok(())
288    }
289}
290
291/// Any account with extensions must be at least `Account::LEN`.  Both mints and
292/// accounts can have extensions
293/// A mint with extensions that takes it past 165 could be indiscernible from an
294/// Account with an extension, even if we add the account type. For example,
295/// let's say we have:
296///
297/// ```text
298/// Account: 165 bytes... + [2, 0, 3, 0, 100, ....]
299///                          ^     ^       ^     ^
300///                     acct type  extension length data...
301///
302/// Mint: 82 bytes... + 83 bytes of other extension data
303///     + [2, 0, 3, 0, 100, ....]
304///      (data in extension just happens to look like this)
305/// ```
306///
307/// With this approach, we only start writing the TLV data after `Account::LEN`,
308/// which means we always know that the account type is going to be right after
309/// that. We do a special case checking for a Multisig length, because those
310/// aren't extensible under any circumstances.
311const BASE_ACCOUNT_LENGTH: usize = Account::LEN;
312/// Helper that tacks on the `AccountType` length, which gives the minimum for
313/// any account with extensions
314const BASE_ACCOUNT_AND_TYPE_LENGTH: usize = BASE_ACCOUNT_LENGTH + size_of::<AccountType>();
315
316fn type_and_tlv_indices<S: BaseState>(
317    rest_input: &[u8],
318) -> Result<Option<(usize, usize)>, ProgramError> {
319    if rest_input.is_empty() {
320        Ok(None)
321    } else {
322        let account_type_index = BASE_ACCOUNT_LENGTH.saturating_sub(S::SIZE_OF);
323        // check padding is all zeroes
324        let tlv_start_index = account_type_index.saturating_add(size_of::<AccountType>());
325        if rest_input.len() < tlv_start_index {
326            return Err(ProgramError::InvalidAccountData);
327        }
328        if rest_input[..account_type_index].iter().any(|&b| b != 0) {
329            Err(ProgramError::InvalidAccountData)
330        } else {
331            Ok(Some((account_type_index, tlv_start_index)))
332        }
333    }
334}
335
336/// Checks a base buffer to verify if it is an Account without having to
337/// completely deserialize it
338fn is_initialized_account(input: &[u8]) -> Result<bool, ProgramError> {
339    const ACCOUNT_INITIALIZED_INDEX: usize = 108; // See state.rs#L99
340
341    if input.len() != BASE_ACCOUNT_LENGTH {
342        return Err(ProgramError::InvalidAccountData);
343    }
344    Ok(input[ACCOUNT_INITIALIZED_INDEX] != 0)
345}
346
347fn get_extension_bytes<S: BaseState, V: Extension>(tlv_data: &[u8]) -> Result<&[u8], ProgramError> {
348    if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
349        return Err(ProgramError::InvalidAccountData);
350    }
351    let TlvIndices {
352        type_start: _,
353        length_start,
354        value_start,
355    } = get_extension_indices::<V>(tlv_data, false)?;
356    // get_extension_indices has checked that tlv_data is long enough to include
357    // these indices
358    let length = bytemuck::try_from_bytes::<Length>(&tlv_data[length_start..value_start])
359        .map_err(|_| ProgramError::InvalidArgument)?;
360    let value_end = value_start.saturating_add(usize::from(*length));
361    if tlv_data.len() < value_end {
362        return Err(ProgramError::InvalidAccountData);
363    }
364    Ok(&tlv_data[value_start..value_end])
365}
366
367fn get_extension_bytes_mut<S: BaseState, V: Extension>(
368    tlv_data: &mut [u8],
369) -> Result<&mut [u8], ProgramError> {
370    if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
371        return Err(ProgramError::InvalidAccountData);
372    }
373    let TlvIndices {
374        type_start: _,
375        length_start,
376        value_start,
377    } = get_extension_indices::<V>(tlv_data, false)?;
378    // get_extension_indices has checked that tlv_data is long enough to include
379    // these indices
380    let length = bytemuck::try_from_bytes::<Length>(&tlv_data[length_start..value_start])
381        .map_err(|_| ProgramError::InvalidArgument)?;
382    let value_end = value_start.saturating_add(usize::from(*length));
383    if tlv_data.len() < value_end {
384        return Err(ProgramError::InvalidAccountData);
385    }
386    Ok(&mut tlv_data[value_start..value_end])
387}
388
389/// Calculate the new expected size if the state allocates the given number
390/// of bytes for the given extension type.
391///
392/// Provides the correct answer regardless if the extension is already present
393/// in the TLV data.
394fn try_get_new_account_len_for_extension_len<S: BaseState, V: Extension>(
395    tlv_data: &[u8],
396    new_extension_len: usize,
397) -> Result<usize, ProgramError> {
398    // get the new length used by the extension
399    let new_extension_tlv_len = add_type_and_length_to_len(new_extension_len);
400    let tlv_info = get_tlv_data_info(tlv_data)?;
401    // If we're adding an extension, then we must have at least BASE_ACCOUNT_LENGTH
402    // and account type
403    let current_len = tlv_info
404        .used_len
405        .saturating_add(BASE_ACCOUNT_AND_TYPE_LENGTH);
406    // get the current length used by the extension
407    let current_extension_len = get_extension_bytes::<S, V>(tlv_data)
408        .map(|x| add_type_and_length_to_len(x.len()))
409        .unwrap_or(0);
410    let new_len = current_len
411        .saturating_sub(current_extension_len)
412        .saturating_add(new_extension_tlv_len);
413    Ok(adjust_len_for_multisig(new_len))
414}
415
416/// Trait for base state with extension
417pub trait BaseStateWithExtensions<S: BaseState> {
418    /// Get the buffer containing all extension data
419    fn get_tlv_data(&self) -> &[u8];
420
421    /// Fetch the bytes for a TLV entry
422    fn get_extension_bytes<V: Extension>(&self) -> Result<&[u8], ProgramError> {
423        get_extension_bytes::<S, V>(self.get_tlv_data())
424    }
425
426    /// Unpack a portion of the TLV data as the desired type
427    fn get_extension<V: Extension + Pod>(&self) -> Result<&V, ProgramError> {
428        bytemuck::try_from_bytes::<V>(self.get_extension_bytes::<V>()?)
429            .map_err(|_| ProgramError::InvalidArgument)
430    }
431
432    /// Unpacks a portion of the TLV data as the desired variable-length type
433    fn get_variable_len_extension<V: Extension + VariableLenPack>(
434        &self,
435    ) -> Result<V, ProgramError> {
436        let data = get_extension_bytes::<S, V>(self.get_tlv_data())?;
437        V::unpack_from_slice(data)
438    }
439
440    /// Iterates through the TLV entries, returning only the types
441    fn get_extension_types(&self) -> Result<Vec<ExtensionType>, ProgramError> {
442        get_tlv_data_info(self.get_tlv_data()).map(|x| x.extension_types)
443    }
444
445    /// Get just the first extension type, useful to track mixed initialization
446    fn get_first_extension_type(&self) -> Result<Option<ExtensionType>, ProgramError> {
447        get_first_extension_type(self.get_tlv_data())
448    }
449
450    /// Get the total number of bytes used by TLV entries and the base type
451    fn try_get_account_len(&self) -> Result<usize, ProgramError> {
452        let tlv_info = get_tlv_data_info(self.get_tlv_data())?;
453        if tlv_info.extension_types.is_empty() {
454            Ok(S::SIZE_OF)
455        } else {
456            let total_len = tlv_info
457                .used_len
458                .saturating_add(BASE_ACCOUNT_AND_TYPE_LENGTH);
459            Ok(adjust_len_for_multisig(total_len))
460        }
461    }
462    /// Calculate the new expected size if the state allocates the given
463    /// fixed-length extension instance.
464    /// If the state already has the extension, the resulting account length
465    /// will be unchanged.
466    fn try_get_new_account_len<V: Extension + Pod>(&self) -> Result<usize, ProgramError> {
467        try_get_new_account_len_for_extension_len::<S, V>(self.get_tlv_data(), size_of::<V>())
468    }
469
470    /// Calculate the new expected size if the state allocates the given
471    /// variable-length extension instance.
472    fn try_get_new_account_len_for_variable_len_extension<V: Extension + VariableLenPack>(
473        &self,
474        new_extension: &V,
475    ) -> Result<usize, ProgramError> {
476        try_get_new_account_len_for_extension_len::<S, V>(
477            self.get_tlv_data(),
478            new_extension.get_packed_len()?,
479        )
480    }
481}
482
483/// Encapsulates owned immutable base state data (mint or account) with possible
484/// extensions
485#[derive(Clone, Debug, PartialEq)]
486pub struct StateWithExtensionsOwned<S: BaseState> {
487    /// Unpacked base data
488    pub base: S,
489    /// Raw TLV data, deserialized on demand
490    tlv_data: Vec<u8>,
491}
492impl<S: BaseState + Pack> StateWithExtensionsOwned<S> {
493    /// Unpack base state, leaving the extension data as a slice
494    ///
495    /// Fails if the base state is not initialized.
496    pub fn unpack(mut input: Vec<u8>) -> Result<Self, ProgramError> {
497        check_min_len_and_not_multisig(&input, S::SIZE_OF)?;
498        let mut rest = input.split_off(S::SIZE_OF);
499        let base = S::unpack(&input)?;
500        if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(&rest)? {
501            // type_and_tlv_indices() checks that returned indexes are within range
502            let account_type = AccountType::try_from(rest[account_type_index])
503                .map_err(|_| ProgramError::InvalidAccountData)?;
504            check_account_type::<S>(account_type)?;
505            let tlv_data = rest.split_off(tlv_start_index);
506            Ok(Self { base, tlv_data })
507        } else {
508            Ok(Self {
509                base,
510                tlv_data: vec![],
511            })
512        }
513    }
514}
515
516impl<S: BaseState> BaseStateWithExtensions<S> for StateWithExtensionsOwned<S> {
517    fn get_tlv_data(&self) -> &[u8] {
518        &self.tlv_data
519    }
520}
521
522/// Encapsulates immutable base state data (mint or account) with possible
523/// extensions
524#[derive(Debug, PartialEq)]
525pub struct StateWithExtensions<'data, S: BaseState + Pack> {
526    /// Unpacked base data
527    pub base: S,
528    /// Slice of data containing all TLV data, deserialized on demand
529    tlv_data: &'data [u8],
530}
531impl<'data, S: BaseState + Pack> StateWithExtensions<'data, S> {
532    /// Unpack base state, leaving the extension data as a slice
533    ///
534    /// Fails if the base state is not initialized.
535    pub fn unpack(input: &'data [u8]) -> Result<Self, ProgramError> {
536        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
537        let (base_data, rest) = input.split_at(S::SIZE_OF);
538        let base = S::unpack(base_data)?;
539        let tlv_data = unpack_tlv_data::<S>(rest)?;
540        Ok(Self { base, tlv_data })
541    }
542}
543impl<S: BaseState + Pack> BaseStateWithExtensions<S> for StateWithExtensions<'_, S> {
544    fn get_tlv_data(&self) -> &[u8] {
545        self.tlv_data
546    }
547}
548
549/// Encapsulates immutable base state data (mint or account) with possible
550/// extensions, where the base state is Pod for zero-copy serde.
551#[derive(Debug, PartialEq)]
552pub struct PodStateWithExtensions<'data, S: BaseState + Pod> {
553    /// Unpacked base data
554    pub base: &'data S,
555    /// Slice of data containing all TLV data, deserialized on demand
556    tlv_data: &'data [u8],
557}
558impl<'data, S: BaseState + Pod> PodStateWithExtensions<'data, S> {
559    /// Unpack base state, leaving the extension data as a slice
560    ///
561    /// Fails if the base state is not initialized.
562    pub fn unpack(input: &'data [u8]) -> Result<Self, ProgramError> {
563        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
564        let (base_data, rest) = input.split_at(S::SIZE_OF);
565        let base =
566            bytemuck::try_from_bytes::<S>(base_data).map_err(|_| ProgramError::InvalidArgument)?;
567        if !base.is_initialized() {
568            Err(ProgramError::UninitializedAccount)
569        } else {
570            let tlv_data = unpack_tlv_data::<S>(rest)?;
571            Ok(Self { base, tlv_data })
572        }
573    }
574}
575impl<S: BaseState + Pod> BaseStateWithExtensions<S> for PodStateWithExtensions<'_, S> {
576    fn get_tlv_data(&self) -> &[u8] {
577        self.tlv_data
578    }
579}
580
581/// Trait for mutable base state with extension
582pub trait BaseStateWithExtensionsMut<S: BaseState>: BaseStateWithExtensions<S> {
583    /// Get the underlying TLV data as mutable
584    fn get_tlv_data_mut(&mut self) -> &mut [u8];
585
586    /// Get the underlying account type as mutable
587    fn get_account_type_mut(&mut self) -> &mut [u8];
588
589    /// Unpack a portion of the TLV data as the base mutable bytes
590    fn get_extension_bytes_mut<V: Extension>(&mut self) -> Result<&mut [u8], ProgramError> {
591        get_extension_bytes_mut::<S, V>(self.get_tlv_data_mut())
592    }
593
594    /// Unpack a portion of the TLV data as the desired type that allows
595    /// modifying the type
596    fn get_extension_mut<V: Extension + Pod>(&mut self) -> Result<&mut V, ProgramError> {
597        bytemuck::try_from_bytes_mut::<V>(self.get_extension_bytes_mut::<V>()?)
598            .map_err(|_| ProgramError::InvalidArgument)
599    }
600
601    /// Packs a variable-length extension into its appropriate data segment.
602    /// Fails if space hasn't already been allocated for the given extension
603    fn pack_variable_len_extension<V: Extension + VariableLenPack>(
604        &mut self,
605        extension: &V,
606    ) -> Result<(), ProgramError> {
607        let data = self.get_extension_bytes_mut::<V>()?;
608        // NOTE: Do *not* use `pack`, since the length check will cause
609        // reallocations to smaller sizes to fail
610        extension.pack_into_slice(data)
611    }
612
613    /// Packs the default extension data into an open slot if not already found
614    /// in the data buffer. If extension is already found in the buffer, it
615    /// overwrites the existing extension with the default state if
616    /// `overwrite` is set. If extension found, but `overwrite` is not set,
617    /// it returns error.
618    fn init_extension<V: Extension + Pod + Default>(
619        &mut self,
620        overwrite: bool,
621    ) -> Result<&mut V, ProgramError> {
622        let length = size_of::<V>();
623        let buffer = self.alloc::<V>(length, overwrite)?;
624        let extension_ref =
625            bytemuck::try_from_bytes_mut::<V>(buffer).map_err(|_| ProgramError::InvalidArgument)?;
626        *extension_ref = V::default();
627        Ok(extension_ref)
628    }
629
630    /// Reallocate and overwrite the TLV entry for the given variable-length
631    /// extension.
632    ///
633    /// Returns an error if the extension is not present, or if there is not
634    /// enough space in the buffer.
635    fn realloc_variable_len_extension<V: Extension + VariableLenPack>(
636        &mut self,
637        new_extension: &V,
638    ) -> Result<(), ProgramError> {
639        let data = self.realloc::<V>(new_extension.get_packed_len()?)?;
640        new_extension.pack_into_slice(data)
641    }
642
643    /// Reallocate the TLV entry for the given extension to the given number of
644    /// bytes.
645    ///
646    /// If the new length is smaller, it will compact the rest of the buffer and
647    /// zero out the difference at the end. If it's larger, it will move the
648    /// rest of the buffer data and zero out the new data.
649    ///
650    /// Returns an error if the extension is not present, or if this is not
651    /// enough space in the buffer.
652    fn realloc<V: Extension + VariableLenPack>(
653        &mut self,
654        length: usize,
655    ) -> Result<&mut [u8], ProgramError> {
656        let tlv_data = self.get_tlv_data_mut();
657        let TlvIndices {
658            type_start: _,
659            length_start,
660            value_start,
661        } = get_extension_indices::<V>(tlv_data, false)?;
662        let tlv_len = get_tlv_data_info(tlv_data).map(|x| x.used_len)?;
663        let data_len = tlv_data.len();
664
665        let length_ref =
666            bytemuck::try_from_bytes_mut::<Length>(&mut tlv_data[length_start..value_start])
667                .map_err(|_| ProgramError::InvalidArgument)?;
668        let old_length = usize::from(*length_ref);
669
670        // Length check to avoid a panic later in `copy_within`
671        if old_length < length {
672            let new_tlv_len = tlv_len.saturating_add(length.saturating_sub(old_length));
673            if new_tlv_len > data_len {
674                return Err(ProgramError::InvalidAccountData);
675            }
676        }
677
678        // write new length after the check, to avoid getting into a bad situation
679        // if trying to recover from an error
680        *length_ref = Length::try_from(length)?;
681
682        let old_value_end = value_start.saturating_add(old_length);
683        let new_value_end = value_start.saturating_add(length);
684        tlv_data.copy_within(old_value_end..tlv_len, new_value_end);
685        match old_length.cmp(&length) {
686            Ordering::Greater => {
687                // realloc to smaller, zero out the end
688                let new_tlv_len = tlv_len.saturating_sub(old_length.saturating_sub(length));
689                tlv_data[new_tlv_len..tlv_len].fill(0);
690            }
691            Ordering::Less => {
692                // realloc to bigger, zero out the new bytes
693                tlv_data[old_value_end..new_value_end].fill(0);
694            }
695            Ordering::Equal => {} // nothing needed!
696        }
697
698        Ok(&mut tlv_data[value_start..new_value_end])
699    }
700
701    /// Allocate the given number of bytes for the given variable-length
702    /// extension and write its contents into the TLV buffer.
703    ///
704    /// This can only be used for variable-sized types, such as `String` or
705    /// `Vec`. `Pod` types must use `init_extension`
706    fn init_variable_len_extension<V: Extension + VariableLenPack>(
707        &mut self,
708        extension: &V,
709        overwrite: bool,
710    ) -> Result<(), ProgramError> {
711        let data = self.alloc::<V>(extension.get_packed_len()?, overwrite)?;
712        extension.pack_into_slice(data)
713    }
714
715    /// Allocate some space for the extension in the TLV data
716    fn alloc<V: Extension>(
717        &mut self,
718        length: usize,
719        overwrite: bool,
720    ) -> Result<&mut [u8], ProgramError> {
721        if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
722            return Err(ProgramError::InvalidAccountData);
723        }
724        let tlv_data = self.get_tlv_data_mut();
725        let TlvIndices {
726            type_start,
727            length_start,
728            value_start,
729        } = get_extension_indices::<V>(tlv_data, true)?;
730
731        if tlv_data[type_start..].len() < add_type_and_length_to_len(length) {
732            return Err(ProgramError::InvalidAccountData);
733        }
734        let extension_type = ExtensionType::try_from(&tlv_data[type_start..length_start])?;
735
736        if extension_type == ExtensionType::Uninitialized || overwrite {
737            // write extension type
738            let extension_type_array: [u8; 2] = V::TYPE.into();
739            let extension_type_ref = &mut tlv_data[type_start..length_start];
740            extension_type_ref.copy_from_slice(&extension_type_array);
741            // write length
742            let length_ref =
743                bytemuck::try_from_bytes_mut::<Length>(&mut tlv_data[length_start..value_start])
744                    .map_err(|_| ProgramError::InvalidArgument)?;
745
746            // check that the length is the same if we're doing an alloc
747            // with overwrite, otherwise a realloc should be done
748            if overwrite && extension_type == V::TYPE && usize::from(*length_ref) != length {
749                return Err(TokenError::InvalidLengthForAlloc.into());
750            }
751
752            *length_ref = Length::try_from(length)?;
753
754            let value_end = value_start.saturating_add(length);
755            Ok(&mut tlv_data[value_start..value_end])
756        } else {
757            // extension is already initialized, but no overwrite permission
758            Err(TokenError::ExtensionAlreadyInitialized.into())
759        }
760    }
761
762    /// If `extension_type` is an Account-associated `ExtensionType` that
763    /// requires initialization on `InitializeAccount`, this method packs
764    /// the default relevant `Extension` of an `ExtensionType` into an open
765    /// slot if not already found in the data buffer, otherwise overwrites
766    /// the existing extension with the default state. For all other
767    /// `ExtensionType`s, this is a no-op.
768    fn init_account_extension_from_type(
769        &mut self,
770        extension_type: ExtensionType,
771    ) -> Result<(), ProgramError> {
772        if extension_type.get_account_type() != AccountType::Account {
773            return Ok(());
774        }
775        match extension_type {
776            ExtensionType::TransferFeeAmount => {
777                self.init_extension::<TransferFeeAmount>(true).map(|_| ())
778            }
779            ExtensionType::ImmutableOwner => {
780                self.init_extension::<ImmutableOwner>(true).map(|_| ())
781            }
782            ExtensionType::NonTransferableAccount => self
783                .init_extension::<NonTransferableAccount>(true)
784                .map(|_| ()),
785            ExtensionType::TransferHookAccount => {
786                self.init_extension::<TransferHookAccount>(true).map(|_| ())
787            }
788            // ConfidentialTransfers are currently opt-in only, so this is a no-op for extra safety
789            // on InitializeAccount
790            ExtensionType::ConfidentialTransferAccount => Ok(()),
791            ExtensionType::PausableAccount => {
792                self.init_extension::<PausableAccount>(true).map(|_| ())
793            }
794            #[cfg(test)]
795            ExtensionType::AccountPaddingTest => {
796                self.init_extension::<AccountPaddingTest>(true).map(|_| ())
797            }
798            _ => unreachable!(),
799        }
800    }
801
802    /// Write the account type into the buffer, done during the base
803    /// state initialization
804    /// Noops if there is no room for an extension in the account, needed for
805    /// pure base mints / accounts.
806    fn init_account_type(&mut self) -> Result<(), ProgramError> {
807        let first_extension_type = self.get_first_extension_type()?;
808        let account_type = self.get_account_type_mut();
809        if !account_type.is_empty() {
810            if let Some(extension_type) = first_extension_type {
811                let account_type = extension_type.get_account_type();
812                if account_type != S::ACCOUNT_TYPE {
813                    return Err(TokenError::ExtensionBaseMismatch.into());
814                }
815            }
816            account_type[0] = S::ACCOUNT_TYPE.into();
817        }
818        Ok(())
819    }
820
821    /// Check that the account type on the account (if initialized) matches the
822    /// account type for any extensions initialized on the TLV data
823    fn check_account_type_matches_extension_type(&self) -> Result<(), ProgramError> {
824        if let Some(extension_type) = self.get_first_extension_type()? {
825            let account_type = extension_type.get_account_type();
826            if account_type != S::ACCOUNT_TYPE {
827                return Err(TokenError::ExtensionBaseMismatch.into());
828            }
829        }
830        Ok(())
831    }
832}
833
/// Encapsulates mutable base state data (mint or account) with possible
/// extensions
///
/// `base` is an unpacked copy of the base state. Changes to it reach the
/// underlying bytes in `base_data` only when `pack_base` is called. The
/// account-type slice and the TLV slice are mutable borrows of the same input
/// buffer, split off after the base state. Both may be empty when the buffer
/// has no room for extensions.
#[derive(Debug, PartialEq)]
pub struct StateWithExtensionsMut<'data, S: BaseState> {
    /// Unpacked base data
    pub base: S,
    /// Raw base data
    base_data: &'data mut [u8],
    /// Writable account type
    account_type: &'data mut [u8],
    /// Slice of data containing all TLV data, deserialized on demand
    tlv_data: &'data mut [u8],
}
847impl<'data, S: BaseState + Pack> StateWithExtensionsMut<'data, S> {
848    /// Unpack base state, leaving the extension data as a mutable slice
849    ///
850    /// Fails if the base state is not initialized.
851    pub fn unpack(input: &'data mut [u8]) -> Result<Self, ProgramError> {
852        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
853        let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
854        let base = S::unpack(base_data)?;
855        let (account_type, tlv_data) = unpack_type_and_tlv_data_mut::<S>(rest)?;
856        Ok(Self {
857            base,
858            base_data,
859            account_type,
860            tlv_data,
861        })
862    }
863
864    /// Unpack an uninitialized base state, leaving the extension data as a
865    /// mutable slice
866    ///
867    /// Fails if the base state has already been initialized.
868    pub fn unpack_uninitialized(input: &'data mut [u8]) -> Result<Self, ProgramError> {
869        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
870        let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
871        let base = S::unpack_unchecked(base_data)?;
872        if base.is_initialized() {
873            return Err(TokenError::AlreadyInUse.into());
874        }
875        let (account_type, tlv_data) = unpack_uninitialized_type_and_tlv_data_mut::<S>(rest)?;
876        let state = Self {
877            base,
878            base_data,
879            account_type,
880            tlv_data,
881        };
882        state.check_account_type_matches_extension_type()?;
883        Ok(state)
884    }
885
886    /// Packs base state data into the base data portion
887    pub fn pack_base(&mut self) {
888        S::pack_into_slice(&self.base, self.base_data);
889    }
890}
// Read-only TLV access for `StateWithExtensionsMut`: hands out a shared
// reborrow of the TLV slice captured at unpack time.
impl<S: BaseState> BaseStateWithExtensions<S> for StateWithExtensionsMut<'_, S> {
    fn get_tlv_data(&self) -> &[u8] {
        self.tlv_data
    }
}
// Mutable access for `StateWithExtensionsMut`: reborrows the TLV slice and the
// account-type slice (the latter is empty when there is no room for
// extensions).
impl<S: BaseState> BaseStateWithExtensionsMut<S> for StateWithExtensionsMut<'_, S> {
    fn get_tlv_data_mut(&mut self) -> &mut [u8] {
        self.tlv_data
    }
    fn get_account_type_mut(&mut self) -> &mut [u8] {
        self.account_type
    }
}
904
/// Encapsulates mutable base state data (mint or account) with possible
/// extensions, where the base state is Pod for zero-copy serde.
///
/// Unlike `StateWithExtensionsMut`, `base` is a mutable reference cast
/// directly over the input bytes (via `bytemuck`). Writes to it therefore land
/// in the buffer without a separate pack step.
#[derive(Debug, PartialEq)]
pub struct PodStateWithExtensionsMut<'data, S: BaseState> {
    /// Unpacked base data
    pub base: &'data mut S,
    /// Writable account type
    account_type: &'data mut [u8],
    /// Slice of data containing all TLV data, deserialized on demand
    tlv_data: &'data mut [u8],
}
916impl<'data, S: BaseState + Pod> PodStateWithExtensionsMut<'data, S> {
917    /// Unpack base state, leaving the extension data as a mutable slice
918    ///
919    /// Fails if the base state is not initialized.
920    pub fn unpack(input: &'data mut [u8]) -> Result<Self, ProgramError> {
921        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
922        let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
923        let base = bytemuck::try_from_bytes_mut::<S>(base_data)
924            .map_err(|_| ProgramError::InvalidArgument)?;
925        if !base.is_initialized() {
926            Err(ProgramError::UninitializedAccount)
927        } else {
928            let (account_type, tlv_data) = unpack_type_and_tlv_data_mut::<S>(rest)?;
929            Ok(Self {
930                base,
931                account_type,
932                tlv_data,
933            })
934        }
935    }
936
937    /// Unpack an uninitialized base state, leaving the extension data as a
938    /// mutable slice
939    ///
940    /// Fails if the base state has already been initialized.
941    pub fn unpack_uninitialized(input: &'data mut [u8]) -> Result<Self, ProgramError> {
942        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
943        let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
944        let base = bytemuck::try_from_bytes_mut::<S>(base_data)
945            .map_err(|_| ProgramError::InvalidArgument)?;
946        if base.is_initialized() {
947            return Err(TokenError::AlreadyInUse.into());
948        }
949        let (account_type, tlv_data) = unpack_uninitialized_type_and_tlv_data_mut::<S>(rest)?;
950        let state = Self {
951            base,
952            account_type,
953            tlv_data,
954        };
955        state.check_account_type_matches_extension_type()?;
956        Ok(state)
957    }
958}
959
// Read-only TLV access for `PodStateWithExtensionsMut`: hands out a shared
// reborrow of the TLV slice captured at unpack time.
impl<S: BaseState> BaseStateWithExtensions<S> for PodStateWithExtensionsMut<'_, S> {
    fn get_tlv_data(&self) -> &[u8] {
        self.tlv_data
    }
}
// Mutable access for `PodStateWithExtensionsMut`: reborrows the TLV slice and
// the account-type slice (the latter is empty when there is no room for
// extensions).
impl<S: BaseState> BaseStateWithExtensionsMut<S> for PodStateWithExtensionsMut<'_, S> {
    fn get_tlv_data_mut(&mut self) -> &mut [u8] {
        self.tlv_data
    }
    fn get_account_type_mut(&mut self) -> &mut [u8] {
        self.account_type
    }
}
973
974fn unpack_tlv_data<S: BaseState>(rest: &[u8]) -> Result<&[u8], ProgramError> {
975    if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
976        // type_and_tlv_indices() checks that returned indexes are within range
977        let account_type = AccountType::try_from(rest[account_type_index])
978            .map_err(|_| ProgramError::InvalidAccountData)?;
979        check_account_type::<S>(account_type)?;
980        Ok(&rest[tlv_start_index..])
981    } else {
982        Ok(&[])
983    }
984}
985
986fn unpack_type_and_tlv_data_with_check_mut<
987    S: BaseState,
988    F: Fn(AccountType) -> Result<(), ProgramError>,
989>(
990    rest: &mut [u8],
991    check_fn: F,
992) -> Result<(&mut [u8], &mut [u8]), ProgramError> {
993    if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
994        // type_and_tlv_indices() checks that returned indexes are within range
995        let account_type = AccountType::try_from(rest[account_type_index])
996            .map_err(|_| ProgramError::InvalidAccountData)?;
997        check_fn(account_type)?;
998        let (account_type, tlv_data) = rest.split_at_mut(tlv_start_index);
999        Ok((
1000            &mut account_type[account_type_index..tlv_start_index],
1001            tlv_data,
1002        ))
1003    } else {
1004        Ok((&mut [], &mut []))
1005    }
1006}
1007
1008fn unpack_type_and_tlv_data_mut<S: BaseState>(
1009    rest: &mut [u8],
1010) -> Result<(&mut [u8], &mut [u8]), ProgramError> {
1011    unpack_type_and_tlv_data_with_check_mut::<S, _>(rest, check_account_type::<S>)
1012}
1013
1014fn unpack_uninitialized_type_and_tlv_data_mut<S: BaseState>(
1015    rest: &mut [u8],
1016) -> Result<(&mut [u8], &mut [u8]), ProgramError> {
1017    unpack_type_and_tlv_data_with_check_mut::<S, _>(rest, |account_type| {
1018        if account_type != AccountType::Uninitialized {
1019            Err(ProgramError::InvalidAccountData)
1020        } else {
1021            Ok(())
1022        }
1023    })
1024}
1025
1026/// If `AccountType` is uninitialized, set it to the `BaseState`'s
1027/// `ACCOUNT_TYPE`; if `AccountType` is already set, check is set correctly for
1028/// `BaseState`. This method assumes that the `base_data` has already been
1029/// packed with data of the desired type.
1030pub fn set_account_type<S: BaseState>(input: &mut [u8]) -> Result<(), ProgramError> {
1031    check_min_len_and_not_multisig(input, S::SIZE_OF)?;
1032    let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
1033    if S::ACCOUNT_TYPE == AccountType::Account && !is_initialized_account(base_data)? {
1034        return Err(ProgramError::InvalidAccountData);
1035    }
1036    if let Some((account_type_index, _tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
1037        let mut account_type = AccountType::try_from(rest[account_type_index])
1038            .map_err(|_| ProgramError::InvalidAccountData)?;
1039        if account_type == AccountType::Uninitialized {
1040            rest[account_type_index] = S::ACCOUNT_TYPE.into();
1041            account_type = S::ACCOUNT_TYPE;
1042        }
1043        check_account_type::<S>(account_type)?;
1044        Ok(())
1045    } else {
1046        Err(ProgramError::InvalidAccountData)
1047    }
1048}
1049
1050/// Different kinds of accounts. Note that `Mint`, `Account`, and `Multisig`
1051/// types are determined exclusively by the size of the account, and are not
1052/// included in the account data. `AccountType` is only included if extensions
1053/// have been initialized.
1054#[repr(u8)]
1055#[derive(Clone, Copy, Debug, Default, PartialEq, TryFromPrimitive, IntoPrimitive)]
1056pub enum AccountType {
1057    /// Marker for 0 data
1058    #[default]
1059    Uninitialized,
1060    /// Mint account with additional extensions
1061    Mint,
1062    /// Token holding account with additional extensions
1063    Account,
1064}
1065
/// Extensions that can be applied to mints or accounts. Mint extensions must
/// only be applied to mint accounts, and account extensions must only be
/// applied to token holding accounts.
///
/// Each variant's implicit `u16` discriminant is its TLV type tag, encoded
/// little-endian (see the `TryFrom<&[u8]>` and `From<ExtensionType> for
/// [u8; 2]` impls). Reordering or inserting variants would change the
/// encoding of existing ones.
#[repr(u16)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "camelCase"))]
#[cfg_attr(test, derive(strum_macros::EnumIter))]
#[derive(Clone, Copy, Debug, PartialEq, TryFromPrimitive, IntoPrimitive)]
pub enum ExtensionType {
    /// Used as padding if the account size would otherwise be 355, same as a
    /// multisig
    Uninitialized,
    /// Includes transfer fee rate info and accompanying authorities to withdraw
    /// and set the fee
    TransferFeeConfig,
    /// Includes withheld transfer fees
    TransferFeeAmount,
    /// Includes an optional mint close authority
    MintCloseAuthority,
    /// Auditor configuration for confidential transfers
    ConfidentialTransferMint,
    /// State for confidential transfers
    ConfidentialTransferAccount,
    /// Specifies the default Account::state for new Accounts
    DefaultAccountState,
    /// Indicates that the Account owner authority cannot be changed
    ImmutableOwner,
    /// Require inbound transfers to have memo
    MemoTransfer,
    /// Indicates that the tokens from this mint can't be transferred
    NonTransferable,
    /// Tokens accrue interest over time
    InterestBearingConfig,
    /// Locks privileged token operations from happening via CPI
    CpiGuard,
    /// Includes an optional permanent delegate
    PermanentDelegate,
    /// Indicates that the tokens in this account belong to a non-transferable
    /// mint
    NonTransferableAccount,
    /// Mint requires a CPI to a program implementing the "transfer hook"
    /// interface
    TransferHook,
    /// Indicates that the tokens in this account belong to a mint with a
    /// transfer hook
    TransferHookAccount,
    /// Includes encrypted withheld fees and the encryption public key that they
    /// are encrypted under
    ConfidentialTransferFeeConfig,
    /// Includes confidential withheld transfer fees
    ConfidentialTransferFeeAmount,
    /// Mint contains a pointer to another account (or the same account) that
    /// holds metadata
    MetadataPointer,
    /// Mint contains token-metadata (variable length)
    TokenMetadata,
    /// Mint contains a pointer to another account (or the same account) that
    /// holds group configurations
    GroupPointer,
    /// Mint contains token group configurations
    TokenGroup,
    /// Mint contains a pointer to another account (or the same account) that
    /// holds group member configurations
    GroupMemberPointer,
    /// Mint contains token group member configurations
    TokenGroupMember,
    /// Mint allowing the minting and burning of confidential tokens
    ConfidentialMintBurn,
    /// Tokens whose UI amount is scaled by a given amount
    ScaledUiAmount,
    /// Tokens where minting / burning / transferring can be paused
    Pausable,
    /// Indicates that the account belongs to a pausable mint
    PausableAccount,
    /// Burning tokens requires approval from an authority
    PermissionedBurn,

    /// Test variable-length mint extension
    #[cfg(test)]
    VariableLenMintTest = u16::MAX - 2,
    /// Padding extension used to make an account exactly Multisig::LEN, used
    /// for testing
    #[cfg(test)]
    AccountPaddingTest,
    /// Padding extension used to make a mint exactly Multisig::LEN, used for
    /// testing
    #[cfg(test)]
    MintPaddingTest,
}
1155impl TryFrom<&[u8]> for ExtensionType {
1156    type Error = ProgramError;
1157    fn try_from(a: &[u8]) -> Result<Self, Self::Error> {
1158        Self::try_from(u16::from_le_bytes(
1159            a.try_into().map_err(|_| ProgramError::InvalidAccountData)?,
1160        ))
1161        .map_err(|_| ProgramError::InvalidAccountData)
1162    }
1163}
1164impl From<ExtensionType> for [u8; 2] {
1165    fn from(a: ExtensionType) -> Self {
1166        u16::from(a).to_le_bytes()
1167    }
1168}
1169impl ExtensionType {
1170    /// Returns true if the given extension type is sized
1171    ///
1172    /// Most extension types should be sized, so any variable-length extension
1173    /// types should be added here by hand
1174    const fn sized(&self) -> bool {
1175        match self {
1176            ExtensionType::TokenMetadata => false,
1177            #[cfg(test)]
1178            ExtensionType::VariableLenMintTest => false,
1179            _ => true,
1180        }
1181    }
1182
1183    /// Get the data length of the type associated with the enum
1184    ///
1185    /// Fails if the extension type has a variable length
1186    fn try_get_type_len(&self) -> Result<usize, ProgramError> {
1187        if !self.sized() {
1188            return Err(ProgramError::InvalidArgument);
1189        }
1190        Ok(match self {
1191            ExtensionType::Uninitialized => 0,
1192            ExtensionType::TransferFeeConfig => size_of::<TransferFeeConfig>(),
1193            ExtensionType::TransferFeeAmount => size_of::<TransferFeeAmount>(),
1194            ExtensionType::MintCloseAuthority => size_of::<MintCloseAuthority>(),
1195            ExtensionType::ImmutableOwner => size_of::<ImmutableOwner>(),
1196            ExtensionType::ConfidentialTransferMint => size_of::<ConfidentialTransferMint>(),
1197            ExtensionType::ConfidentialTransferAccount => size_of::<ConfidentialTransferAccount>(),
1198            ExtensionType::DefaultAccountState => size_of::<DefaultAccountState>(),
1199            ExtensionType::MemoTransfer => size_of::<MemoTransfer>(),
1200            ExtensionType::NonTransferable => size_of::<NonTransferable>(),
1201            ExtensionType::InterestBearingConfig => size_of::<InterestBearingConfig>(),
1202            ExtensionType::CpiGuard => size_of::<CpiGuard>(),
1203            ExtensionType::PermanentDelegate => size_of::<PermanentDelegate>(),
1204            ExtensionType::NonTransferableAccount => size_of::<NonTransferableAccount>(),
1205            ExtensionType::TransferHook => size_of::<TransferHook>(),
1206            ExtensionType::TransferHookAccount => size_of::<TransferHookAccount>(),
1207            ExtensionType::ConfidentialTransferFeeConfig => {
1208                size_of::<ConfidentialTransferFeeConfig>()
1209            }
1210            ExtensionType::ConfidentialTransferFeeAmount => {
1211                size_of::<ConfidentialTransferFeeAmount>()
1212            }
1213            ExtensionType::MetadataPointer => size_of::<MetadataPointer>(),
1214            ExtensionType::TokenMetadata => unreachable!(),
1215            ExtensionType::GroupPointer => size_of::<GroupPointer>(),
1216            ExtensionType::TokenGroup => size_of::<TokenGroup>(),
1217            ExtensionType::GroupMemberPointer => size_of::<GroupMemberPointer>(),
1218            ExtensionType::TokenGroupMember => size_of::<TokenGroupMember>(),
1219            ExtensionType::ConfidentialMintBurn => size_of::<ConfidentialMintBurn>(),
1220            ExtensionType::ScaledUiAmount => size_of::<ScaledUiAmountConfig>(),
1221            ExtensionType::Pausable => size_of::<PausableConfig>(),
1222            ExtensionType::PausableAccount => size_of::<PausableAccount>(),
1223            ExtensionType::PermissionedBurn => size_of::<PermissionedBurnConfig>(),
1224            #[cfg(test)]
1225            ExtensionType::AccountPaddingTest => size_of::<AccountPaddingTest>(),
1226            #[cfg(test)]
1227            ExtensionType::MintPaddingTest => size_of::<MintPaddingTest>(),
1228            #[cfg(test)]
1229            ExtensionType::VariableLenMintTest => unreachable!(),
1230        })
1231    }
1232
1233    /// Get the TLV length for an `ExtensionType`
1234    ///
1235    /// Fails if the extension type has a variable length
1236    fn try_get_tlv_len(&self) -> Result<usize, ProgramError> {
1237        Ok(add_type_and_length_to_len(self.try_get_type_len()?))
1238    }
1239
1240    /// Get the required account data length for the given `ExtensionType`s
1241    ///
1242    /// Fails if any of the extension types has a variable length
1243    pub fn try_calculate_account_len<S: BaseState>(
1244        extension_types: &[Self],
1245    ) -> Result<usize, ProgramError> {
1246        let mut tlv_len = TlvLenAccumulator::default();
1247        for &extension_type in extension_types {
1248            tlv_len.insert(extension_type)?;
1249        }
1250        Ok(tlv_len.account_len::<S>())
1251    }
1252
1253    /// Get the associated account type
1254    pub fn get_account_type(&self) -> AccountType {
1255        match self {
1256            ExtensionType::Uninitialized => AccountType::Uninitialized,
1257            ExtensionType::TransferFeeConfig
1258            | ExtensionType::MintCloseAuthority
1259            | ExtensionType::ConfidentialTransferMint
1260            | ExtensionType::DefaultAccountState
1261            | ExtensionType::NonTransferable
1262            | ExtensionType::InterestBearingConfig
1263            | ExtensionType::PermanentDelegate
1264            | ExtensionType::TransferHook
1265            | ExtensionType::ConfidentialTransferFeeConfig
1266            | ExtensionType::MetadataPointer
1267            | ExtensionType::TokenMetadata
1268            | ExtensionType::GroupPointer
1269            | ExtensionType::TokenGroup
1270            | ExtensionType::GroupMemberPointer
1271            | ExtensionType::ConfidentialMintBurn
1272            | ExtensionType::TokenGroupMember
1273            | ExtensionType::ScaledUiAmount
1274            | ExtensionType::Pausable
1275            | ExtensionType::PermissionedBurn => AccountType::Mint,
1276            ExtensionType::ImmutableOwner
1277            | ExtensionType::TransferFeeAmount
1278            | ExtensionType::ConfidentialTransferAccount
1279            | ExtensionType::MemoTransfer
1280            | ExtensionType::NonTransferableAccount
1281            | ExtensionType::TransferHookAccount
1282            | ExtensionType::CpiGuard
1283            | ExtensionType::ConfidentialTransferFeeAmount
1284            | ExtensionType::PausableAccount => AccountType::Account,
1285            #[cfg(test)]
1286            ExtensionType::VariableLenMintTest => AccountType::Mint,
1287            #[cfg(test)]
1288            ExtensionType::AccountPaddingTest => AccountType::Account,
1289            #[cfg(test)]
1290            ExtensionType::MintPaddingTest => AccountType::Mint,
1291        }
1292    }
1293
1294    /// Based on an `AccountType::Mint` `ExtensionType`, get the
1295    /// `AccountType::Account` `ExtensionType`s required on `InitializeAccount`
1296    fn required_init_account_extensions(&self) -> &'static [Self] {
1297        match self {
1298            ExtensionType::TransferFeeConfig => &[ExtensionType::TransferFeeAmount],
1299            ExtensionType::NonTransferable => &[
1300                ExtensionType::NonTransferableAccount,
1301                ExtensionType::ImmutableOwner,
1302            ],
1303            ExtensionType::TransferHook => &[ExtensionType::TransferHookAccount],
1304            ExtensionType::Pausable => &[ExtensionType::PausableAccount],
1305            #[cfg(test)]
1306            ExtensionType::MintPaddingTest => &[ExtensionType::AccountPaddingTest],
1307            _ => &[],
1308        }
1309    }
1310
1311    /// Based on a set of `AccountType::Mint` `ExtensionType`s, get the list of
1312    /// `AccountType::Account` `ExtensionType`s required on `InitializeAccount`
1313    #[deprecated(
1314        since = "3.0.1",
1315        note = "Use `account_len::try_for_each_required_init_account_extension` instead"
1316    )]
1317    pub fn get_required_init_account_extensions(mint_extension_types: &[Self]) -> Vec<Self> {
1318        mint_extension_types
1319            .iter()
1320            .flat_map(|extension_type| extension_type.required_init_account_extensions())
1321            .copied()
1322            .collect()
1323    }
1324
1325    /// Check for invalid combination of mint extensions
1326    pub fn check_for_invalid_mint_extension_combinations(
1327        mint_extension_types: &[Self],
1328    ) -> Result<(), TokenError> {
1329        let mut transfer_fee_config = false;
1330        let mut confidential_transfer_mint = false;
1331        let mut confidential_transfer_fee_config = false;
1332        let mut confidential_mint_burn = false;
1333        let mut interest_bearing = false;
1334        let mut scaled_ui_amount = false;
1335        let mut non_transferable = false;
1336
1337        for extension_type in mint_extension_types {
1338            match extension_type {
1339                ExtensionType::TransferFeeConfig => transfer_fee_config = true,
1340                ExtensionType::ConfidentialTransferMint => confidential_transfer_mint = true,
1341                ExtensionType::ConfidentialTransferFeeConfig => {
1342                    confidential_transfer_fee_config = true
1343                }
1344                ExtensionType::ConfidentialMintBurn => confidential_mint_burn = true,
1345                ExtensionType::InterestBearingConfig => interest_bearing = true,
1346                ExtensionType::ScaledUiAmount => scaled_ui_amount = true,
1347                ExtensionType::NonTransferable => non_transferable = true,
1348                _ => (),
1349            }
1350        }
1351
1352        if confidential_transfer_fee_config && !(transfer_fee_config && confidential_transfer_mint)
1353        {
1354            return Err(TokenError::InvalidExtensionCombination);
1355        }
1356
1357        if transfer_fee_config && confidential_transfer_mint && !confidential_transfer_fee_config {
1358            return Err(TokenError::InvalidExtensionCombination);
1359        }
1360
1361        if confidential_mint_burn && !confidential_transfer_mint {
1362            return Err(TokenError::InvalidExtensionCombination);
1363        }
1364
1365        if scaled_ui_amount && interest_bearing {
1366            return Err(TokenError::InvalidExtensionCombination);
1367        }
1368
1369        if non_transferable && confidential_transfer_mint && !confidential_mint_burn {
1370            return Err(TokenError::InvalidExtensionCombination);
1371        }
1372
1373        Ok(())
1374    }
1375}
1376
/// Trait for base states, specifying the associated account type
///
/// Implemented both for the `Pack`-based representations (`Account`, `Mint`)
/// and for the Pod-based representations (`PodAccount`, `PodMint`).
pub trait BaseState: PackedSizeOf + IsInitialized {
    /// Account type written to, and expected in, the account-type byte that
    /// precedes the TLV data for this base state
    const ACCOUNT_TYPE: AccountType;
}
impl BaseState for Account {
    const ACCOUNT_TYPE: AccountType = AccountType::Account;
}
impl BaseState for Mint {
    const ACCOUNT_TYPE: AccountType = AccountType::Mint;
}
impl BaseState for PodAccount {
    const ACCOUNT_TYPE: AccountType = AccountType::Account;
}
impl BaseState for PodMint {
    const ACCOUNT_TYPE: AccountType = AccountType::Mint;
}
1394
/// Trait to be implemented by all extension states, specifying which extension
/// type they are associated with
///
/// The account type an extension belongs to follows from
/// `TYPE.get_account_type()`.
pub trait Extension {
    /// Associated extension type enum, written as the type tag at the start of
    /// this extension's TLV entry
    const TYPE: ExtensionType;
}
1401
/// Padding a mint account to be exactly `Multisig::LEN`.
/// We need to pad 185 bytes, since `Multisig::LEN = 355`, `Account::LEN = 165`,
/// `size_of::<AccountType>() = 1`, `size_of::<ExtensionType>() = 2`,
/// `size_of::<Length>() = 2`.
///
/// ```
/// assert_eq!(355 - 165 - 1 - 2 - 2, 185);
/// ```
#[cfg(test)]
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Pod, Zeroable)]
pub struct MintPaddingTest {
    /// Largest value under 185 that implements Pod
    pub padding1: [u8; 128],
    /// Largest value under 57 (185 - 128) that implements Pod
    pub padding2: [u8; 48],
    /// Exact value needed to finish the padding (128 + 48 + 9 = 185)
    pub padding3: [u8; 9],
}
#[cfg(test)]
impl Extension for MintPaddingTest {
    const TYPE: ExtensionType = ExtensionType::MintPaddingTest;
}
// Fills each padding field with a distinct non-zero byte (1, 2, 3),
// presumably so tests can tell the padding apart from zeroed memory.
#[cfg(test)]
impl Default for MintPaddingTest {
    fn default() -> Self {
        Self {
            padding1: [1; 128],
            padding2: [2; 48],
            padding3: [3; 9],
        }
    }
}
1430            padding2: [2; 48],
1431            padding3: [3; 9],
1432        }
1433    }
1434}
1435/// Account version of the `MintPadding`
1436#[cfg(test)]
1437#[repr(C)]
1438#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
1439pub struct AccountPaddingTest(MintPaddingTest);
1440#[cfg(test)]
1441impl Extension for AccountPaddingTest {
1442    const TYPE: ExtensionType = ExtensionType::AccountPaddingTest;
1443}
1444
1445/// Packs a fixed-length extension into a TLV space
1446///
1447/// This function reallocates the account as needed to accommodate for the
1448/// change in space.
1449///
1450/// If the extension already exists, it will overwrite the existing extension
1451/// if `overwrite` is `true`, otherwise it will return an error.
1452///
1453/// If the extension does not exist, it will reallocate the account and write
1454/// the extension into the TLV buffer.
1455///
1456/// NOTE: Since this function deals with fixed-size extensions, it does not
1457/// handle _decreasing_ the size of an account's data buffer, like the function
1458/// `alloc_and_serialize_variable_len_extension` does.
1459pub fn alloc_and_serialize<S: BaseState + Pod, V: Default + Extension + Pod>(
1460    account_info: &AccountInfo,
1461    new_extension: &V,
1462    overwrite: bool,
1463) -> Result<(), ProgramError> {
1464    let previous_account_len = account_info.try_data_len()?;
1465    let new_account_len = {
1466        let data = account_info.try_borrow_data()?;
1467        let state = PodStateWithExtensions::<S>::unpack(&data)?;
1468        state.try_get_new_account_len::<V>()?
1469    };
1470
1471    // Realloc the account first, if needed
1472    if new_account_len > previous_account_len {
1473        account_info.resize(new_account_len)?;
1474    }
1475    let mut buffer = account_info.try_borrow_mut_data()?;
1476    if previous_account_len <= BASE_ACCOUNT_LENGTH {
1477        set_account_type::<S>(*buffer)?;
1478    }
1479    let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
1480
1481    // Write the extension
1482    let extension = state.init_extension::<V>(overwrite)?;
1483    *extension = *new_extension;
1484
1485    Ok(())
1486}
1487
1488/// Packs a variable-length extension into a TLV space
1489///
1490/// This function reallocates the account as needed to accommodate for the
1491/// change in space, then reallocates in the TLV buffer, and finally writes the
1492/// bytes.
1493///
1494/// NOTE: Unlike the `reallocate` instruction, this function will reduce the
1495/// size of an account if it has too many bytes allocated for the given value.
1496pub fn alloc_and_serialize_variable_len_extension<
1497    S: BaseState + Pod,
1498    V: Extension + VariableLenPack,
1499>(
1500    account_info: &AccountInfo,
1501    new_extension: &V,
1502    overwrite: bool,
1503) -> Result<(), ProgramError> {
1504    let previous_account_len = account_info.try_data_len()?;
1505    let (new_account_len, extension_already_exists) = {
1506        let data = account_info.try_borrow_data()?;
1507        let state = PodStateWithExtensions::<S>::unpack(&data)?;
1508        let new_account_len =
1509            state.try_get_new_account_len_for_variable_len_extension(new_extension)?;
1510        let extension_already_exists = state.get_extension_bytes::<V>().is_ok();
1511        (new_account_len, extension_already_exists)
1512    };
1513
1514    if extension_already_exists && !overwrite {
1515        return Err(TokenError::ExtensionAlreadyInitialized.into());
1516    }
1517
1518    if previous_account_len < new_account_len {
1519        // account size increased, so realloc the account, then the TLV entry, then
1520        // write data
1521        account_info.resize(new_account_len)?;
1522        let mut buffer = account_info.try_borrow_mut_data()?;
1523        if extension_already_exists {
1524            let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
1525            state.realloc_variable_len_extension(new_extension)?;
1526        } else {
1527            if previous_account_len <= BASE_ACCOUNT_LENGTH {
1528                set_account_type::<S>(*buffer)?;
1529            }
1530            // now alloc in the TLV buffer and write the data
1531            let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
1532            state.init_variable_len_extension(new_extension, false)?;
1533        }
1534    } else {
1535        // do it backwards otherwise, write the state, realloc TLV, then the account
1536        let mut buffer = account_info.try_borrow_mut_data()?;
1537        let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
1538        if extension_already_exists {
1539            state.realloc_variable_len_extension(new_extension)?;
1540        } else {
1541            // this situation can happen if we have an overallocated buffer
1542            state.init_variable_len_extension(new_extension, false)?;
1543        }
1544
1545        let removed_bytes = previous_account_len
1546            .checked_sub(new_account_len)
1547            .ok_or(ProgramError::AccountDataTooSmall)?;
1548        if removed_bytes > 0 {
1549            // this is probably fine, but be safe and avoid invalidating references
1550            drop(buffer);
1551            account_info.resize(new_account_len)?;
1552        }
1553    }
1554    Ok(())
1555}
1556
1557#[cfg(test)]
1558mod test {
1559    use {
1560        super::*,
1561        crate::{
1562            pod::test::{TEST_POD_ACCOUNT, TEST_POD_MINT},
1563            state::test::{TEST_ACCOUNT_SLICE, TEST_MINT_SLICE},
1564        },
1565        bytemuck::Pod,
1566        solana_account_info::{
1567            Account as GetAccount, IntoAccountInfo, MAX_PERMITTED_DATA_INCREASE,
1568        },
1569        solana_address::Address,
1570        solana_nullable::MaybeNull,
1571        solana_zero_copy::unaligned::{Bool, U64},
1572        transfer_fee::test::test_transfer_fee_config,
1573    };
1574
1575    /// Test fixed-length struct
1576    #[repr(C)]
1577    #[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
1578    struct FixedLenMintTest {
1579        data: [u8; 8],
1580    }
1581    impl Extension for FixedLenMintTest {
1582        const TYPE: ExtensionType = ExtensionType::MintPaddingTest;
1583    }
1584
1585    /// Test variable-length struct
1586    #[derive(Clone, Debug, PartialEq)]
1587    struct VariableLenMintTest {
1588        data: Vec<u8>,
1589    }
1590    impl Extension for VariableLenMintTest {
1591        const TYPE: ExtensionType = ExtensionType::VariableLenMintTest;
1592    }
1593    impl VariableLenPack for VariableLenMintTest {
1594        fn pack_into_slice(&self, dst: &mut [u8]) -> Result<(), ProgramError> {
1595            let data_start = size_of::<u64>();
1596            let end = data_start + self.data.len();
1597            if dst.len() < end {
1598                Err(ProgramError::InvalidAccountData)
1599            } else {
1600                dst[..data_start].copy_from_slice(&self.data.len().to_le_bytes());
1601                dst[data_start..end].copy_from_slice(&self.data);
1602                Ok(())
1603            }
1604        }
1605        fn unpack_from_slice(src: &[u8]) -> Result<Self, ProgramError> {
1606            let data_start = size_of::<u64>();
1607            let length = u64::from_le_bytes(src[..data_start].try_into().unwrap()) as usize;
1608            if src[data_start..data_start + length].len() != length {
1609                return Err(ProgramError::InvalidAccountData);
1610            }
1611            let data = Vec::from(&src[data_start..data_start + length]);
1612            Ok(Self { data })
1613        }
1614        fn get_packed_len(&self) -> Result<usize, ProgramError> {
1615            Ok(size_of::<u64>().saturating_add(self.data.len()))
1616        }
1617    }
1618
1619    const MINT_WITH_ACCOUNT_TYPE: &[u8] = &[
1620        1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1621        1, 1, 1, 1, 1, 1, 42, 0, 0, 0, 0, 0, 0, 0, 7, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
1622        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // base mint
1623        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1624        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1625        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // padding
1626        1, // account type
1627    ];
1628
1629    const MINT_WITH_EXTENSION: &[u8] = &[
1630        1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1631        1, 1, 1, 1, 1, 1, 42, 0, 0, 0, 0, 0, 0, 0, 7, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
1632        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // base mint
1633        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1634        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1635        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // padding
1636        1, // account type
1637        3, 0, // extension type
1638        32, 0, // length
1639        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1640        1, 1, // data
1641    ];
1642
1643    const ACCOUNT_WITH_EXTENSION: &[u8] = &[
1644        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1645        1, 1, // mint
1646        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
1647        2, 2, // owner
1648        3, 0, 0, 0, 0, 0, 0, 0, // amount
1649        1, 0, 0, 0, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
1650        4, 4, 4, 4, 4, 4, // delegate
1651        2, // account state
1652        1, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, // is native
1653        6, 0, 0, 0, 0, 0, 0, 0, // delegated amount
1654        1, 0, 0, 0, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
1655        7, 7, 7, 7, 7, 7, // close authority
1656        2, // account type
1657        15, 0, // extension type
1658        1, 0, // length
1659        1, // data
1660    ];
1661
1662    #[test]
1663    fn unpack_opaque_buffer() {
1664        // Mint
1665        let state = PodStateWithExtensions::<PodMint>::unpack(MINT_WITH_ACCOUNT_TYPE).unwrap();
1666        assert_eq!(state.base, &TEST_POD_MINT);
1667        let state = PodStateWithExtensions::<PodMint>::unpack(MINT_WITH_EXTENSION).unwrap();
1668        assert_eq!(state.base, &TEST_POD_MINT);
1669        let extension = state.get_extension::<MintCloseAuthority>().unwrap();
1670        let close_authority: MaybeNull<Address> =
1671            Some(Address::new_from_array([1; 32])).try_into().unwrap();
1672        assert_eq!(extension.close_authority, close_authority);
1673        assert_eq!(
1674            state.get_extension::<TransferFeeConfig>(),
1675            Err(ProgramError::InvalidAccountData)
1676        );
1677        assert_eq!(
1678            PodStateWithExtensions::<PodAccount>::unpack(MINT_WITH_EXTENSION),
1679            Err(ProgramError::UninitializedAccount)
1680        );
1681
1682        let state = PodStateWithExtensions::<PodMint>::unpack(TEST_MINT_SLICE).unwrap();
1683        assert_eq!(state.base, &TEST_POD_MINT);
1684
1685        let mut test_mint = TEST_MINT_SLICE.to_vec();
1686        let state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut test_mint).unwrap();
1687        assert_eq!(state.base, &TEST_POD_MINT);
1688
1689        // Account
1690        let state = PodStateWithExtensions::<PodAccount>::unpack(ACCOUNT_WITH_EXTENSION).unwrap();
1691        assert_eq!(state.base, &TEST_POD_ACCOUNT);
1692        let extension = state.get_extension::<TransferHookAccount>().unwrap();
1693        let transferring = Bool::from(true);
1694        assert_eq!(extension.transferring, transferring);
1695        assert_eq!(
1696            PodStateWithExtensions::<PodMint>::unpack(ACCOUNT_WITH_EXTENSION),
1697            Err(ProgramError::InvalidAccountData)
1698        );
1699
1700        let state = PodStateWithExtensions::<PodAccount>::unpack(TEST_ACCOUNT_SLICE).unwrap();
1701        assert_eq!(state.base, &TEST_POD_ACCOUNT);
1702
1703        let mut test_account = TEST_ACCOUNT_SLICE.to_vec();
1704        let state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut test_account).unwrap();
1705        assert_eq!(state.base, &TEST_POD_ACCOUNT);
1706    }
1707
1708    #[test]
1709    fn mint_fail_unpack_opaque_buffer() {
1710        // input buffer too small
1711        let mut buffer = vec![0, 3];
1712        assert_eq!(
1713            PodStateWithExtensions::<PodMint>::unpack(&buffer),
1714            Err(ProgramError::InvalidAccountData)
1715        );
1716        assert_eq!(
1717            PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer),
1718            Err(ProgramError::InvalidAccountData)
1719        );
1720        assert_eq!(
1721            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer),
1722            Err(ProgramError::InvalidAccountData)
1723        );
1724
1725        // tweak the account type
1726        let mut buffer = MINT_WITH_EXTENSION.to_vec();
1727        buffer[BASE_ACCOUNT_LENGTH] = 3;
1728        assert_eq!(
1729            PodStateWithExtensions::<PodMint>::unpack(&buffer),
1730            Err(ProgramError::InvalidAccountData)
1731        );
1732
1733        // clear the mint initialized byte
1734        let mut buffer = MINT_WITH_EXTENSION.to_vec();
1735        buffer[45] = 0;
1736        assert_eq!(
1737            PodStateWithExtensions::<PodMint>::unpack(&buffer),
1738            Err(ProgramError::UninitializedAccount)
1739        );
1740
1741        // tweak the padding
1742        let mut buffer = MINT_WITH_EXTENSION.to_vec();
1743        buffer[PodMint::SIZE_OF] = 100;
1744        assert_eq!(
1745            PodStateWithExtensions::<PodMint>::unpack(&buffer),
1746            Err(ProgramError::InvalidAccountData)
1747        );
1748
1749        // tweak the extension type
1750        let mut buffer = MINT_WITH_EXTENSION.to_vec();
1751        buffer[BASE_ACCOUNT_LENGTH + 1] = 2;
1752        let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
1753        assert_eq!(
1754            state.get_extension::<TransferFeeConfig>(),
1755            Err(ProgramError::InvalidAccountData)
1756        );
1757
1758        // tweak the length, too big
1759        let mut buffer = MINT_WITH_EXTENSION.to_vec();
1760        buffer[BASE_ACCOUNT_LENGTH + 3] = 100;
1761        let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
1762        assert_eq!(
1763            state.get_extension::<TransferFeeConfig>(),
1764            Err(ProgramError::InvalidAccountData)
1765        );
1766
1767        // tweak the length, too small
1768        let mut buffer = MINT_WITH_EXTENSION.to_vec();
1769        buffer[BASE_ACCOUNT_LENGTH + 3] = 10;
1770        let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
1771        assert_eq!(
1772            state.get_extension::<TransferFeeConfig>(),
1773            Err(ProgramError::InvalidAccountData)
1774        );
1775
1776        // data buffer is too small
1777        let buffer = &MINT_WITH_EXTENSION[..MINT_WITH_EXTENSION.len() - 1];
1778        let state = PodStateWithExtensions::<PodMint>::unpack(buffer).unwrap();
1779        assert_eq!(
1780            state.get_extension::<MintCloseAuthority>(),
1781            Err(ProgramError::InvalidAccountData)
1782        );
1783    }
1784
1785    #[test]
1786    fn account_fail_unpack_opaque_buffer() {
1787        // input buffer too small
1788        let mut buffer = vec![0, 3];
1789        assert_eq!(
1790            PodStateWithExtensions::<PodAccount>::unpack(&buffer),
1791            Err(ProgramError::InvalidAccountData)
1792        );
1793        assert_eq!(
1794            PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer),
1795            Err(ProgramError::InvalidAccountData)
1796        );
1797        assert_eq!(
1798            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
1799            Err(ProgramError::InvalidAccountData)
1800        );
1801
1802        // input buffer invalid
1803        // all 5's - not a valid `AccountState`
1804        let mut buffer = vec![5; BASE_ACCOUNT_LENGTH];
1805        assert_eq!(
1806            PodStateWithExtensions::<PodAccount>::unpack(&buffer),
1807            Err(ProgramError::UninitializedAccount)
1808        );
1809        assert_eq!(
1810            PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer),
1811            Err(ProgramError::UninitializedAccount)
1812        );
1813
1814        // tweak the account type
1815        let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
1816        buffer[BASE_ACCOUNT_LENGTH] = 3;
1817        assert_eq!(
1818            PodStateWithExtensions::<PodAccount>::unpack(&buffer),
1819            Err(ProgramError::InvalidAccountData)
1820        );
1821
1822        // clear the state byte
1823        let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
1824        buffer[108] = 0;
1825        assert_eq!(
1826            PodStateWithExtensions::<PodAccount>::unpack(&buffer),
1827            Err(ProgramError::UninitializedAccount)
1828        );
1829
1830        // tweak the extension type
1831        let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
1832        buffer[BASE_ACCOUNT_LENGTH + 1] = 12;
1833        let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
1834        assert_eq!(
1835            state.get_extension::<TransferHookAccount>(),
1836            Err(ProgramError::InvalidAccountData),
1837        );
1838
1839        // tweak the length, too big
1840        let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
1841        buffer[BASE_ACCOUNT_LENGTH + 3] = 100;
1842        let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
1843        assert_eq!(
1844            state.get_extension::<TransferHookAccount>(),
1845            Err(ProgramError::InvalidAccountData)
1846        );
1847
1848        // tweak the length, too small
1849        let mut buffer = ACCOUNT_WITH_EXTENSION.to_vec();
1850        buffer[BASE_ACCOUNT_LENGTH + 3] = 10;
1851        let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
1852        assert_eq!(
1853            state.get_extension::<TransferHookAccount>(),
1854            Err(ProgramError::InvalidAccountData)
1855        );
1856
1857        // data buffer is too small
1858        let buffer = &ACCOUNT_WITH_EXTENSION[..ACCOUNT_WITH_EXTENSION.len() - 1];
1859        let state = PodStateWithExtensions::<PodAccount>::unpack(buffer).unwrap();
1860        assert_eq!(
1861            state.get_extension::<TransferHookAccount>(),
1862            Err(ProgramError::InvalidAccountData)
1863        );
1864    }
1865
1866    #[test]
1867    fn get_extension_types_with_opaque_buffer() {
1868        // incorrect due to the length
1869        assert_eq!(
1870            get_tlv_data_info(&[1, 0, 1, 1]).unwrap_err(),
1871            ProgramError::InvalidAccountData,
1872        );
1873        // incorrect due to the huge enum number
1874        assert_eq!(
1875            get_tlv_data_info(&[0, 1, 0, 0]).unwrap_err(),
1876            ProgramError::InvalidAccountData,
1877        );
1878        // correct due to the good enum number and zero length
1879        assert_eq!(
1880            get_tlv_data_info(&[1, 0, 0, 0]).unwrap(),
1881            TlvDataInfo {
1882                extension_types: vec![ExtensionType::try_from(1).unwrap()],
1883                used_len: add_type_and_length_to_len(0),
1884            }
1885        );
1886        // correct since it's just uninitialized data at the end
1887        assert_eq!(
1888            get_tlv_data_info(&[0, 0]).unwrap(),
1889            TlvDataInfo {
1890                extension_types: vec![],
1891                used_len: 0
1892            }
1893        );
1894    }
1895
1896    #[test]
1897    fn mint_with_extension_pack_unpack() {
1898        let mint_size = ExtensionType::try_calculate_account_len::<PodMint>(&[
1899            ExtensionType::MintCloseAuthority,
1900            ExtensionType::TransferFeeConfig,
1901        ])
1902        .unwrap();
1903        let mut buffer = vec![0; mint_size];
1904
1905        // fail unpack
1906        assert_eq!(
1907            PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer),
1908            Err(ProgramError::UninitializedAccount),
1909        );
1910
1911        let mut state =
1912            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
1913        // fail init account extension
1914        assert_eq!(
1915            state.init_extension::<TransferFeeAmount>(true),
1916            Err(ProgramError::InvalidAccountData),
1917        );
1918
1919        // success write extension
1920        let close_authority: MaybeNull<Address> =
1921            Some(Address::new_from_array([1; 32])).try_into().unwrap();
1922        let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
1923        extension.close_authority = close_authority;
1924        assert_eq!(
1925            &state.get_extension_types().unwrap(),
1926            &[ExtensionType::MintCloseAuthority]
1927        );
1928
1929        // fail init extension when already initialized
1930        assert_eq!(
1931            state.init_extension::<MintCloseAuthority>(false),
1932            Err(ProgramError::Custom(
1933                TokenError::ExtensionAlreadyInitialized as u32
1934            ))
1935        );
1936
1937        // fail unpack as account, a mint extension was written
1938        assert_eq!(
1939            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
1940            Err(ProgramError::Custom(
1941                TokenError::ExtensionBaseMismatch as u32
1942            ))
1943        );
1944
1945        // fail unpack again, still no base data
1946        assert_eq!(
1947            PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer.clone()),
1948            Err(ProgramError::UninitializedAccount),
1949        );
1950
1951        // write base mint
1952        let mut state =
1953            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
1954        *state.base = TEST_POD_MINT;
1955        state.init_account_type().unwrap();
1956
1957        // check raw buffer
1958        let mut expect = TEST_MINT_SLICE.to_vec();
1959        expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); // padding
1960        expect.push(AccountType::Mint.into());
1961        expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
1962        expect.extend_from_slice(&(size_of::<MintCloseAuthority>() as u16).to_le_bytes());
1963        expect.extend_from_slice(&[1; 32]); // data
1964        expect.extend_from_slice(&[0; size_of::<ExtensionType>()]);
1965        expect.extend_from_slice(&[0; size_of::<Length>()]);
1966        expect.extend_from_slice(&[0; size_of::<TransferFeeConfig>()]);
1967        assert_eq!(expect, buffer);
1968
1969        // unpack uninitialized will now fail because the PodMint is now initialized
1970        assert_eq!(
1971            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer.clone()),
1972            Err(TokenError::AlreadyInUse.into()),
1973        );
1974
1975        // check unpacking
1976        let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
1977
1978        // update base
1979        *state.base = TEST_POD_MINT;
1980        state.base.supply = (u64::from(state.base.supply) + 100).into();
1981
1982        // check unpacking
1983        let unpacked_extension = state.get_extension_mut::<MintCloseAuthority>().unwrap();
1984        assert_eq!(*unpacked_extension, MintCloseAuthority { close_authority });
1985
1986        // update extension
1987        let close_authority: MaybeNull<Address> = None.try_into().unwrap();
1988        unpacked_extension.close_authority = close_authority;
1989
1990        // check updates are propagated
1991        let base = *state.base;
1992        let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
1993        assert_eq!(state.base, &base);
1994        let unpacked_extension = state.get_extension::<MintCloseAuthority>().unwrap();
1995        assert_eq!(*unpacked_extension, MintCloseAuthority { close_authority });
1996
1997        // check raw buffer
1998        let mut expect = vec![];
1999        expect.extend_from_slice(bytemuck::bytes_of(&base));
2000        expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); // padding
2001        expect.push(AccountType::Mint.into());
2002        expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
2003        expect.extend_from_slice(&(size_of::<MintCloseAuthority>() as u16).to_le_bytes());
2004        expect.extend_from_slice(&[0; 32]);
2005        expect.extend_from_slice(&[0; size_of::<ExtensionType>()]);
2006        expect.extend_from_slice(&[0; size_of::<Length>()]);
2007        expect.extend_from_slice(&[0; size_of::<TransferFeeConfig>()]);
2008        assert_eq!(expect, buffer);
2009
2010        // fail unpack as an account
2011        assert_eq!(
2012            PodStateWithExtensions::<PodAccount>::unpack(&buffer),
2013            Err(ProgramError::UninitializedAccount),
2014        );
2015
2016        let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
2017        // init one more extension
2018        let mint_transfer_fee = test_transfer_fee_config();
2019        let new_extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
2020        new_extension.transfer_fee_config_authority =
2021            mint_transfer_fee.transfer_fee_config_authority;
2022        new_extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
2023        new_extension.withheld_amount = mint_transfer_fee.withheld_amount;
2024        new_extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
2025        new_extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;
2026
2027        assert_eq!(
2028            &state.get_extension_types().unwrap(),
2029            &[
2030                ExtensionType::MintCloseAuthority,
2031                ExtensionType::TransferFeeConfig
2032            ]
2033        );
2034
2035        // check raw buffer
2036        let mut expect = vec![];
2037        expect.extend_from_slice(bytemuck::bytes_of(&base));
2038        expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); // padding
2039        expect.push(AccountType::Mint.into());
2040        expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
2041        expect.extend_from_slice(&(size_of::<MintCloseAuthority>() as u16).to_le_bytes());
2042        expect.extend_from_slice(&[0; 32]); // data
2043        expect.extend_from_slice(&(ExtensionType::TransferFeeConfig as u16).to_le_bytes());
2044        expect.extend_from_slice(&(size_of::<TransferFeeConfig>() as u16).to_le_bytes());
2045        expect.extend_from_slice(bytemuck::bytes_of(&mint_transfer_fee));
2046        assert_eq!(expect, buffer);
2047
2048        // fail to init one more extension that does not fit
2049        let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
2050        assert_eq!(
2051            state.init_extension::<MintPaddingTest>(true),
2052            Err(ProgramError::InvalidAccountData),
2053        );
2054    }
2055
2056    #[test]
2057    fn mint_extension_any_order() {
2058        let mint_size = ExtensionType::try_calculate_account_len::<PodMint>(&[
2059            ExtensionType::MintCloseAuthority,
2060            ExtensionType::TransferFeeConfig,
2061        ])
2062        .unwrap();
2063        let mut buffer = vec![0; mint_size];
2064
2065        let mut state =
2066            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2067        // write extensions
2068        let close_authority: MaybeNull<Address> =
2069            Some(Address::new_from_array([1; 32])).try_into().unwrap();
2070        let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
2071        extension.close_authority = close_authority;
2072
2073        let mint_transfer_fee = test_transfer_fee_config();
2074        let extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
2075        extension.transfer_fee_config_authority = mint_transfer_fee.transfer_fee_config_authority;
2076        extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
2077        extension.withheld_amount = mint_transfer_fee.withheld_amount;
2078        extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
2079        extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;
2080
2081        assert_eq!(
2082            &state.get_extension_types().unwrap(),
2083            &[
2084                ExtensionType::MintCloseAuthority,
2085                ExtensionType::TransferFeeConfig
2086            ]
2087        );
2088
2089        // write base mint
2090        let mut state =
2091            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2092        *state.base = TEST_POD_MINT;
2093        state.init_account_type().unwrap();
2094
2095        let mut other_buffer = vec![0; mint_size];
2096        let mut state =
2097            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut other_buffer).unwrap();
2098
2099        // write base mint
2100        *state.base = TEST_POD_MINT;
2101        state.init_account_type().unwrap();
2102
2103        // write extensions in a different order
2104        let mint_transfer_fee = test_transfer_fee_config();
2105        let extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
2106        extension.transfer_fee_config_authority = mint_transfer_fee.transfer_fee_config_authority;
2107        extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
2108        extension.withheld_amount = mint_transfer_fee.withheld_amount;
2109        extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
2110        extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;
2111
2112        let close_authority: MaybeNull<Address> =
2113            Some(Address::new_from_array([1; 32])).try_into().unwrap();
2114        let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
2115        extension.close_authority = close_authority;
2116
2117        assert_eq!(
2118            &state.get_extension_types().unwrap(),
2119            &[
2120                ExtensionType::TransferFeeConfig,
2121                ExtensionType::MintCloseAuthority
2122            ]
2123        );
2124
2125        // buffers are NOT the same because written in a different order
2126        assert_ne!(buffer, other_buffer);
2127        let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
2128        let other_state = PodStateWithExtensions::<PodMint>::unpack(&other_buffer).unwrap();
2129
2130        // BUT mint and extensions are the same
2131        assert_eq!(
2132            state.get_extension::<TransferFeeConfig>().unwrap(),
2133            other_state.get_extension::<TransferFeeConfig>().unwrap()
2134        );
2135        assert_eq!(
2136            state.get_extension::<MintCloseAuthority>().unwrap(),
2137            other_state.get_extension::<MintCloseAuthority>().unwrap()
2138        );
2139        assert_eq!(state.base, other_state.base);
2140    }
2141
2142    #[test]
2143    fn mint_with_multisig_len() {
2144        let mut buffer = vec![0; Multisig::LEN];
2145        assert_eq!(
2146            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer),
2147            Err(ProgramError::InvalidAccountData),
2148        );
2149        let mint_size =
2150            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MintPaddingTest])
2151                .unwrap();
2152        assert_eq!(mint_size, Multisig::LEN + size_of::<ExtensionType>());
2153        let mut buffer = vec![0; mint_size];
2154
2155        // write base mint
2156        let mut state =
2157            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2158        *state.base = TEST_POD_MINT;
2159        state.init_account_type().unwrap();
2160
2161        // write padding
2162        let extension = state.init_extension::<MintPaddingTest>(true).unwrap();
2163        extension.padding1 = [1; 128];
2164        extension.padding2 = [1; 48];
2165        extension.padding3 = [1; 9];
2166
2167        assert_eq!(
2168            &state.get_extension_types().unwrap(),
2169            &[ExtensionType::MintPaddingTest]
2170        );
2171
2172        // check raw buffer
2173        let mut expect = TEST_MINT_SLICE.to_vec();
2174        expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); // padding
2175        expect.push(AccountType::Mint.into());
2176        expect.extend_from_slice(&(ExtensionType::MintPaddingTest as u16).to_le_bytes());
2177        expect.extend_from_slice(&(size_of::<MintPaddingTest>() as u16).to_le_bytes());
2178        expect.extend_from_slice(&vec![1; size_of::<MintPaddingTest>()]);
2179        expect.extend_from_slice(&(ExtensionType::Uninitialized as u16).to_le_bytes());
2180        assert_eq!(expect, buffer);
2181    }
2182
    /// Round-trips a `PodAccount` carrying a single `TransferFeeAmount`
    /// extension: the extension is written while the base account is still
    /// zeroed, then the base is written, the exact byte layout is checked, both
    /// base and extension are mutated through the mutable view, and the result
    /// is re-read through the immutable view.
    #[test]
    fn account_with_extension_pack_unpack() {
        let account_size = ExtensionType::try_calculate_account_len::<PodAccount>(&[
            ExtensionType::TransferFeeAmount,
        ])
        .unwrap();
        let mut buffer = vec![0; account_size];

        // fail unpack: the all-zero buffer has no initialized base account
        assert_eq!(
            PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer),
            Err(ProgramError::UninitializedAccount),
        );

        let mut state =
            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
        // fail init mint extension: a mint-only extension is rejected on an account
        assert_eq!(
            state.init_extension::<TransferFeeConfig>(true),
            Err(ProgramError::InvalidAccountData),
        );
        // success write extension (the base account is still all zeroes here)
        let withheld_amount = U64::from(u64::MAX);
        let extension = state.init_extension::<TransferFeeAmount>(true).unwrap();
        extension.withheld_amount = withheld_amount;

        assert_eq!(
            &state.get_extension_types().unwrap(),
            &[ExtensionType::TransferFeeAmount]
        );

        // fail unpack again, still no base data
        // NOTE(review): the earlier `state` is not used after this point, so the
        // `clone()` looks unnecessary for the borrow checker; kept as-is.
        assert_eq!(
            PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer.clone()),
            Err(ProgramError::UninitializedAccount),
        );

        // write base account
        let mut state =
            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
        *state.base = TEST_POD_ACCOUNT;
        state.init_account_type().unwrap();
        let base = *state.base;

        // check raw buffer: base account bytes, 1-byte account type, then one
        // TLV entry (2-byte type, 2-byte length, 8-byte little-endian amount)
        let mut expect = TEST_ACCOUNT_SLICE.to_vec();
        expect.push(AccountType::Account.into());
        expect.extend_from_slice(&(ExtensionType::TransferFeeAmount as u16).to_le_bytes());
        expect.extend_from_slice(&(size_of::<TransferFeeAmount>() as u16).to_le_bytes());
        expect.extend_from_slice(&u64::from(withheld_amount).to_le_bytes());
        assert_eq!(expect, buffer);

        // check unpacking
        let mut state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
        assert_eq!(state.base, &base);
        assert_eq!(
            &state.get_extension_types().unwrap(),
            &[ExtensionType::TransferFeeAmount]
        );

        // update base: bump the amount by 100 so it differs from TEST_POD_ACCOUNT
        *state.base = TEST_POD_ACCOUNT;
        state.base.amount = (u64::from(state.base.amount) + 100).into();

        // check unpacking
        let unpacked_extension = state.get_extension_mut::<TransferFeeAmount>().unwrap();
        assert_eq!(*unpacked_extension, TransferFeeAmount { withheld_amount });

        // update extension in place through the mutable reference
        let withheld_amount = U64::from(u32::MAX as u64);
        unpacked_extension.withheld_amount = withheld_amount;

        // check updates are propagated to the underlying buffer, as seen by a
        // fresh immutable view
        let base = *state.base;
        let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
        assert_eq!(state.base, &base);
        let unpacked_extension = state.get_extension::<TransferFeeAmount>().unwrap();
        assert_eq!(*unpacked_extension, TransferFeeAmount { withheld_amount });

        // check raw buffer
        let mut expect = vec![];
        expect.extend_from_slice(bytemuck::bytes_of(&base));
        expect.push(AccountType::Account.into());
        expect.extend_from_slice(&(ExtensionType::TransferFeeAmount as u16).to_le_bytes());
        expect.extend_from_slice(&(size_of::<TransferFeeAmount>() as u16).to_le_bytes());
        expect.extend_from_slice(&u64::from(withheld_amount).to_le_bytes());
        assert_eq!(expect, buffer);

        // fail unpack as a mint: the stored account type is `Account`
        assert_eq!(
            PodStateWithExtensions::<PodMint>::unpack(&buffer),
            Err(ProgramError::InvalidAccountData),
        );
    }
2277
2278    #[test]
2279    fn account_with_multisig_len() {
2280        let mut buffer = vec![0; Multisig::LEN];
2281        assert_eq!(
2282            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
2283            Err(ProgramError::InvalidAccountData),
2284        );
2285        let account_size = ExtensionType::try_calculate_account_len::<PodAccount>(&[
2286            ExtensionType::AccountPaddingTest,
2287        ])
2288        .unwrap();
2289        assert_eq!(account_size, Multisig::LEN + size_of::<ExtensionType>());
2290        let mut buffer = vec![0; account_size];
2291
2292        // write base account
2293        let mut state =
2294            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
2295        *state.base = TEST_POD_ACCOUNT;
2296        state.init_account_type().unwrap();
2297
2298        // write padding
2299        let extension = state.init_extension::<AccountPaddingTest>(true).unwrap();
2300        extension.0.padding1 = [2; 128];
2301        extension.0.padding2 = [2; 48];
2302        extension.0.padding3 = [2; 9];
2303
2304        assert_eq!(
2305            &state.get_extension_types().unwrap(),
2306            &[ExtensionType::AccountPaddingTest]
2307        );
2308
2309        // check raw buffer
2310        let mut expect = TEST_ACCOUNT_SLICE.to_vec();
2311        expect.push(AccountType::Account.into());
2312        expect.extend_from_slice(&(ExtensionType::AccountPaddingTest as u16).to_le_bytes());
2313        expect.extend_from_slice(&(size_of::<AccountPaddingTest>() as u16).to_le_bytes());
2314        expect.extend_from_slice(&vec![2; size_of::<AccountPaddingTest>()]);
2315        expect.extend_from_slice(&(ExtensionType::Uninitialized as u16).to_le_bytes());
2316        assert_eq!(expect, buffer);
2317    }
2318
    /// Exercises `set_account_type` on raw buffers for both accounts and mints.
    /// It covers four cases: room for the type byte plus an extension, room for
    /// the type byte only, the type byte already correct (no-op), and the type
    /// byte already set to the other kind (error). In the raw bytes, `2` is
    /// `AccountType::Account` and `1` is `AccountType::Mint`; the assertions
    /// below on `state.account_type[0]` establish this.
    #[test]
    fn test_set_account_type() {
        // account with buffer big enough for AccountType and Extension
        let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
        let needed_len = ExtensionType::try_calculate_account_len::<PodAccount>(&[
            ExtensionType::ImmutableOwner,
        ])
        .unwrap()
            - buffer.len();
        buffer.append(&mut vec![0; needed_len]);
        // the account-type byte is still zero, so a regular unpack is rejected
        let err = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);
        set_account_type::<PodAccount>(&mut buffer).unwrap();
        // unpack is viable after manual set_account_type
        let mut state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
        assert_eq!(state.base, &TEST_POD_ACCOUNT);
        assert_eq!(state.account_type[0], AccountType::Account as u8);
        state.init_extension::<ImmutableOwner>(true).unwrap(); // just confirming initialization works

        // account with buffer big enough for AccountType only
        let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
        buffer.append(&mut vec![0; 2]);
        let err = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);
        set_account_type::<PodAccount>(&mut buffer).unwrap();
        // unpack is viable after manual set_account_type
        let state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
        assert_eq!(state.base, &TEST_POD_ACCOUNT);
        assert_eq!(state.account_type[0], AccountType::Account as u8);

        // account with AccountType already set => noop
        // (byte 2 == AccountType::Account, as the final assertion shows)
        let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
        buffer.append(&mut vec![2, 0]);
        let _ = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
        set_account_type::<PodAccount>(&mut buffer).unwrap();
        let state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
        assert_eq!(state.base, &TEST_POD_ACCOUNT);
        assert_eq!(state.account_type[0], AccountType::Account as u8);

        // account with wrong AccountType fails
        // (byte 1 is the mint tag, so it is not overwritten)
        let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
        buffer.append(&mut vec![1, 0]);
        let err = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);
        let err = set_account_type::<PodAccount>(&mut buffer).unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);

        // mint with buffer big enough for AccountType and Extension
        let mut buffer = TEST_MINT_SLICE.to_vec();
        let needed_len = ExtensionType::try_calculate_account_len::<PodMint>(&[
            ExtensionType::MintCloseAuthority,
        ])
        .unwrap()
            - buffer.len();
        buffer.append(&mut vec![0; needed_len]);
        let err = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);
        set_account_type::<PodMint>(&mut buffer).unwrap();
        // unpack is viable after manual set_account_type
        let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
        assert_eq!(state.base, &TEST_POD_MINT);
        assert_eq!(state.account_type[0], AccountType::Mint as u8);
        state.init_extension::<MintCloseAuthority>(true).unwrap();

        // mint with buffer big enough for AccountType only
        // Mint base data is zero-padded up to the account size, presumably so the
        // account-type byte sits at the same offset as for accounts.
        let mut buffer = TEST_MINT_SLICE.to_vec();
        buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
        buffer.append(&mut vec![0; 2]);
        let err = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);
        set_account_type::<PodMint>(&mut buffer).unwrap();
        // unpack is viable after manual set_account_type
        let state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
        assert_eq!(state.base, &TEST_POD_MINT);
        assert_eq!(state.account_type[0], AccountType::Mint as u8);

        // mint with AccountType already set => noop
        // (byte 1 == AccountType::Mint, as the final assertion shows)
        let mut buffer = TEST_MINT_SLICE.to_vec();
        buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
        buffer.append(&mut vec![1, 0]);
        set_account_type::<PodMint>(&mut buffer).unwrap();
        let state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
        assert_eq!(state.base, &TEST_POD_MINT);
        assert_eq!(state.account_type[0], AccountType::Mint as u8);

        // mint with wrong AccountType fails
        // (byte 2 is the account tag, so it is not overwritten)
        let mut buffer = TEST_MINT_SLICE.to_vec();
        buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
        buffer.append(&mut vec![2, 0]);
        let err = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);
        let err = set_account_type::<PodMint>(&mut buffer).unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);
    }
2413
2414    #[test]
2415    fn test_set_account_type_wrongly() {
2416        // try to set PodAccount account_type to PodMint
2417        let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
2418        buffer.append(&mut vec![0; 2]);
2419        let err = set_account_type::<PodMint>(&mut buffer).unwrap_err();
2420        assert_eq!(err, ProgramError::InvalidAccountData);
2421
2422        // try to set PodMint account_type to PodAccount
2423        let mut buffer = TEST_MINT_SLICE.to_vec();
2424        buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
2425        buffer.append(&mut vec![0; 2]);
2426        let err = set_account_type::<PodAccount>(&mut buffer).unwrap_err();
2427        assert_eq!(err, ProgramError::InvalidAccountData);
2428    }
2429
2430    #[test]
2431    #[allow(deprecated)]
2432    fn test_get_required_init_account_extensions() {
2433        // Some mint extensions with no required account extensions
2434        let mint_extensions = vec![
2435            ExtensionType::MintCloseAuthority,
2436            ExtensionType::Uninitialized,
2437        ];
2438        assert_eq!(
2439            ExtensionType::get_required_init_account_extensions(&mint_extensions),
2440            vec![]
2441        );
2442
2443        // One mint extension with required account extension, one without
2444        let mint_extensions = vec![
2445            ExtensionType::TransferFeeConfig,
2446            ExtensionType::MintCloseAuthority,
2447        ];
2448        assert_eq!(
2449            ExtensionType::get_required_init_account_extensions(&mint_extensions),
2450            vec![ExtensionType::TransferFeeAmount]
2451        );
2452
2453        // Some mint extensions both with required account extensions
2454        let mint_extensions = vec![
2455            ExtensionType::TransferFeeConfig,
2456            ExtensionType::MintPaddingTest,
2457        ];
2458        assert_eq!(
2459            ExtensionType::get_required_init_account_extensions(&mint_extensions),
2460            vec![
2461                ExtensionType::TransferFeeAmount,
2462                ExtensionType::AccountPaddingTest
2463            ]
2464        );
2465
2466        // Demonstrate that method does not dedupe inputs or outputs
2467        let mint_extensions = vec![
2468            ExtensionType::TransferFeeConfig,
2469            ExtensionType::TransferFeeConfig,
2470        ];
2471        assert_eq!(
2472            ExtensionType::get_required_init_account_extensions(&mint_extensions),
2473            vec![
2474                ExtensionType::TransferFeeAmount,
2475                ExtensionType::TransferFeeAmount
2476            ]
2477        );
2478    }
2479
2480    #[test]
2481    fn mint_without_extensions() {
2482        let space = ExtensionType::try_calculate_account_len::<PodMint>(&[]).unwrap();
2483        let mut buffer = vec![0; space];
2484        assert_eq!(
2485            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
2486            Err(ProgramError::InvalidAccountData),
2487        );
2488
2489        // write base account
2490        let mut state =
2491            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2492        *state.base = TEST_POD_MINT;
2493        state.init_account_type().unwrap();
2494
2495        // fail init extension
2496        assert_eq!(
2497            state.init_extension::<TransferFeeConfig>(true),
2498            Err(ProgramError::InvalidAccountData),
2499        );
2500
2501        assert_eq!(TEST_MINT_SLICE, buffer);
2502    }
2503
2504    #[test]
2505    fn test_init_nonzero_default() {
2506        let mint_size =
2507            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MintPaddingTest])
2508                .unwrap();
2509        let mut buffer = vec![0; mint_size];
2510        let mut state =
2511            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2512        *state.base = TEST_POD_MINT;
2513        state.init_account_type().unwrap();
2514        let extension = state.init_extension::<MintPaddingTest>(true).unwrap();
2515        assert_eq!(extension.padding1, [1; 128]);
2516        assert_eq!(extension.padding2, [2; 48]);
2517        assert_eq!(extension.padding3, [3; 9]);
2518    }
2519
    /// Checks that extension reads and writes fail cleanly when the mint
    /// buffer is too small: it is one byte short of a `MintCloseAuthority` TLV
    /// entry, has a hand-written TLV header that points past the end, or has
    /// fewer than two bytes available for a type field.
    #[test]
    fn test_init_buffer_too_small() {
        let mint_size = ExtensionType::try_calculate_account_len::<PodMint>(&[
            ExtensionType::MintCloseAuthority,
        ])
        .unwrap();
        // one byte short of what a `MintCloseAuthority` entry needs
        let mut buffer = vec![0; mint_size - 1];
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        let err = state
            .init_extension::<MintCloseAuthority>(true)
            .unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);

        // Hand-write a TLV header: type low byte 3 (presumably the
        // `MintCloseAuthority` discriminant) and length low byte 32 (presumably
        // its value size). The declared value runs past the truncated buffer,
        // so reading it must fail rather than read out of bounds.
        state.tlv_data[0] = 3;
        state.tlv_data[2] = 32;
        let err = state.get_extension_mut::<MintCloseAuthority>().unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);

        // Longer than a bare mint but (presumably) shorter than
        // `BASE_ACCOUNT_LENGTH`, so it cannot be unpacked at all.
        let mut buffer = vec![0; PodMint::SIZE_OF + 2];
        let err =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);

        // OK since there are two bytes for the type, which is `Uninitialized`
        // (1 account-type byte + 2 type bytes after the base length)
        let mut buffer = vec![0; BASE_ACCOUNT_LENGTH + 3];
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        let err = state.get_extension_mut::<MintCloseAuthority>().unwrap_err();
        assert_eq!(err, ProgramError::InvalidAccountData);

        assert_eq!(state.get_extension_types().unwrap(), vec![]);

        // OK, there aren't two bytes for the type, but that's fine: the lone
        // trailing byte yields an empty extension list
        let mut buffer = vec![0; BASE_ACCOUNT_LENGTH + 2];
        let state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        assert_eq!(state.get_extension_types().unwrap(), []);
    }
2559
2560    #[test]
2561    fn test_extension_with_no_data() {
2562        let account_size = ExtensionType::try_calculate_account_len::<PodAccount>(&[
2563            ExtensionType::ImmutableOwner,
2564        ])
2565        .unwrap();
2566        let mut buffer = vec![0; account_size];
2567        let mut state =
2568            PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
2569        *state.base = TEST_POD_ACCOUNT;
2570        state.init_account_type().unwrap();
2571
2572        let err = state.get_extension::<ImmutableOwner>().unwrap_err();
2573        assert_eq!(
2574            err,
2575            ProgramError::Custom(TokenError::ExtensionNotFound as u32)
2576        );
2577
2578        state.init_extension::<ImmutableOwner>(true).unwrap();
2579        assert_eq!(
2580            get_first_extension_type(state.tlv_data).unwrap(),
2581            Some(ExtensionType::ImmutableOwner)
2582        );
2583        assert_eq!(
2584            get_tlv_data_info(state.tlv_data).unwrap(),
2585            TlvDataInfo {
2586                extension_types: vec![ExtensionType::ImmutableOwner],
2587                used_len: add_type_and_length_to_len(0)
2588            }
2589        );
2590    }
2591
2592    #[test]
2593    fn fail_account_len_with_metadata() {
2594        assert_eq!(
2595            ExtensionType::try_calculate_account_len::<PodMint>(&[
2596                ExtensionType::MintCloseAuthority,
2597                ExtensionType::VariableLenMintTest,
2598                ExtensionType::TransferFeeConfig,
2599            ])
2600            .unwrap_err(),
2601            ProgramError::InvalidArgument
2602        );
2603    }
2604
2605    #[test]
2606    fn alloc() {
2607        let variable_len = VariableLenMintTest { data: vec![1] };
2608        let alloc_size = variable_len.get_packed_len().unwrap();
2609        let account_size =
2610            BASE_ACCOUNT_LENGTH + size_of::<AccountType>() + add_type_and_length_to_len(alloc_size);
2611        let mut buffer = vec![0; account_size];
2612        let mut state =
2613            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2614        state
2615            .init_variable_len_extension(&variable_len, false)
2616            .unwrap();
2617
2618        // can't double alloc
2619        assert_eq!(
2620            state
2621                .init_variable_len_extension(&variable_len, false)
2622                .unwrap_err(),
2623            TokenError::ExtensionAlreadyInitialized.into()
2624        );
2625
2626        // unless overwrite is set
2627        state
2628            .init_variable_len_extension(&variable_len, true)
2629            .unwrap();
2630
2631        // can't change the size during overwrite though
2632        assert_eq!(
2633            state
2634                .init_variable_len_extension(&VariableLenMintTest { data: vec![] }, true)
2635                .unwrap_err(),
2636            TokenError::InvalidLengthForAlloc.into()
2637        );
2638
2639        // try to write too far, fail earlier
2640        assert_eq!(
2641            state
2642                .init_variable_len_extension(&VariableLenMintTest { data: vec![1, 2] }, true)
2643                .unwrap_err(),
2644            ProgramError::InvalidAccountData
2645        );
2646    }
2647
    /// Reallocating a variable-length entry that is followed by a fixed-size
    /// `MetadataPointer` entry. The test grows it, then shrinks it, and checks
    /// that the `MetadataPointer` contents survive each time. After shrinking,
    /// the freed tail of the buffer must be zeroed, and growing past the
    /// buffer's capacity must fail.
    #[test]
    fn realloc() {
        let small_variable_len = VariableLenMintTest {
            data: vec![1, 2, 3],
        };
        let base_variable_len = VariableLenMintTest {
            data: vec![1, 2, 3, 4],
        };
        let big_variable_len = VariableLenMintTest {
            data: vec![1, 2, 3, 4, 5],
        };
        let too_big_variable_len = VariableLenMintTest {
            data: vec![1, 2, 3, 4, 5, 6],
        };
        // Sized so that, next to the `MetadataPointer`, the variable-length entry
        // fits exactly when it holds `big_variable_len`.
        let account_size =
            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MetadataPointer])
                .unwrap()
                + add_type_and_length_to_len(big_variable_len.get_packed_len().unwrap());
        let mut buffer = vec![0; account_size];
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();

        // alloc both types: the variable-length entry first, then the pointer
        state
            .init_variable_len_extension(&base_variable_len, false)
            .unwrap();
        // all-0xFF address, presumably chosen so that any byte corruption from
        // shifting entries is visible in the comparisons below
        let max_pubkey: MaybeNull<Address> =
            Some(Address::new_from_array([255; 32])).try_into().unwrap();
        let extension = state.init_extension::<MetadataPointer>(false).unwrap();
        extension.authority = max_pubkey;
        extension.metadata_address = max_pubkey;

        // realloc first entry to larger; the following MetadataPointer must be
        // unchanged afterwards
        state
            .realloc_variable_len_extension(&big_variable_len)
            .unwrap();
        let extension = state
            .get_variable_len_extension::<VariableLenMintTest>()
            .unwrap();
        assert_eq!(extension, big_variable_len);
        let extension = state.get_extension::<MetadataPointer>().unwrap();
        assert_eq!(extension.authority, max_pubkey);
        assert_eq!(extension.metadata_address, max_pubkey);

        // realloc to smaller
        state
            .realloc_variable_len_extension(&small_variable_len)
            .unwrap();
        let extension = state
            .get_variable_len_extension::<VariableLenMintTest>()
            .unwrap();
        assert_eq!(extension, small_variable_len);
        let extension = state.get_extension::<MetadataPointer>().unwrap();
        assert_eq!(extension.authority, max_pubkey);
        assert_eq!(extension.metadata_address, max_pubkey);
        // the bytes released at the end of the buffer by shrinking are zeroed
        let diff = big_variable_len.get_packed_len().unwrap()
            - small_variable_len.get_packed_len().unwrap();
        assert_eq!(&buffer[account_size - diff..account_size], vec![0; diff]);

        // unpack again since we dropped the last `state` (its mutable borrow of
        // `buffer` ended when `buffer` was read directly above)
        let mut state =
            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
        // realloc too much, fails: the buffer was sized for `big_variable_len`
        assert_eq!(
            state
                .realloc_variable_len_extension(&too_big_variable_len)
                .unwrap_err(),
            ProgramError::InvalidAccountData,
        );
    }
2718
2719    #[test]
2720    fn account_len() {
2721        let small_variable_len = VariableLenMintTest {
2722            data: vec![20, 30, 40],
2723        };
2724        let variable_len = VariableLenMintTest {
2725            data: vec![20, 30, 40, 50],
2726        };
2727        let big_variable_len = VariableLenMintTest {
2728            data: vec![20, 30, 40, 50, 60],
2729        };
2730        let value_len = variable_len.get_packed_len().unwrap();
2731        let account_size =
2732            BASE_ACCOUNT_LENGTH + size_of::<AccountType>() + add_type_and_length_to_len(value_len);
2733        let mut buffer = vec![0; account_size];
2734        let mut state =
2735            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2736
2737        // With a new extension, new length must include padding, 1 byte for
2738        // account type, 2 bytes for type, 2 for length
2739        let current_len = state.try_get_account_len().unwrap();
2740        assert_eq!(current_len, PodMint::SIZE_OF);
2741        let new_len = state
2742            .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(
2743                &variable_len,
2744            )
2745            .unwrap();
2746        assert_eq!(
2747            new_len,
2748            BASE_ACCOUNT_AND_TYPE_LENGTH.saturating_add(add_type_and_length_to_len(value_len))
2749        );
2750
2751        state
2752            .init_variable_len_extension::<VariableLenMintTest>(&variable_len, false)
2753            .unwrap();
2754        let current_len = state.try_get_account_len().unwrap();
2755        assert_eq!(current_len, new_len);
2756
2757        // Reduce the extension size
2758        let new_len = state
2759            .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(
2760                &small_variable_len,
2761            )
2762            .unwrap();
2763        assert_eq!(current_len.checked_sub(new_len).unwrap(), 1);
2764
2765        // Increase the extension size
2766        let new_len = state
2767            .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(
2768                &big_variable_len,
2769            )
2770            .unwrap();
2771        assert_eq!(new_len.checked_sub(current_len).unwrap(), 1);
2772
2773        // Maintain the extension size
2774        let new_len = state
2775            .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(
2776                &variable_len,
2777            )
2778            .unwrap();
2779        assert_eq!(new_len, current_len);
2780    }
2781
2782    /// Test helper for mimicking the data layout an on-chain `AccountInfo`,
2783    /// which permits "reallocs" as the Solana runtime does it
2784    struct SolanaAccountData {
2785        data: Vec<u8>,
2786        lamports: u64,
2787        owner: Address,
2788    }
2789    impl SolanaAccountData {
2790        /// Create a new fake solana account data. The underlying vector is
2791        /// overallocated to mimic the runtime
2792        fn new(account_data: &[u8]) -> Self {
2793            let mut data = vec![];
2794            data.extend_from_slice(&(account_data.len() as u64).to_le_bytes());
2795            data.extend_from_slice(account_data);
2796            data.extend_from_slice(&[0; MAX_PERMITTED_DATA_INCREASE]);
2797            Self {
2798                data,
2799                lamports: 10,
2800                owner: Address::new_unique(),
2801            }
2802        }
2803
2804        /// Data lops off the first 8 bytes, since those store the size of the
2805        /// account for the Solana runtime
2806        fn data(&self) -> &[u8] {
2807            let start = size_of::<u64>();
2808            let len = self.len();
2809            &self.data[start..start + len]
2810        }
2811
2812        /// Gets the runtime length of the account data
2813        fn len(&self) -> usize {
2814            self.data
2815                .get(..size_of::<u64>())
2816                .and_then(|slice| slice.try_into().ok())
2817                .map(u64::from_le_bytes)
2818                .unwrap() as usize
2819        }
2820    }
2821    impl GetAccount for SolanaAccountData {
2822        fn get(&mut self) -> (&mut u64, &mut [u8], &Address, bool) {
2823            // need to pull out the data here to avoid a double-mutable borrow
2824            let start = size_of::<u64>();
2825            let len = self.len();
2826            (
2827                &mut self.lamports,
2828                &mut self.data[start..start + len],
2829                &self.owner,
2830                false,
2831            )
2832        }
2833    }
2834
2835    #[test]
2836    fn alloc_new_fixed_len_tlv_in_account_info_from_base_size() {
2837        let fixed_len = FixedLenMintTest {
2838            data: [1, 2, 3, 4, 5, 6, 7, 8],
2839        };
2840        let value_len = size_of::<FixedLenMintTest>();
2841        let base_account_size = PodMint::SIZE_OF;
2842        let mut buffer = vec![0; base_account_size];
2843        let state =
2844            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2845        *state.base = TEST_POD_MINT;
2846
2847        let mut data = SolanaAccountData::new(&buffer);
2848        let key = Address::new_unique();
2849        let account_info = (&key, &mut data).into_account_info();
2850
2851        alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, false).unwrap();
2852        let new_account_len = BASE_ACCOUNT_AND_TYPE_LENGTH + add_type_and_length_to_len(value_len);
2853        assert_eq!(data.len(), new_account_len);
2854        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
2855        assert_eq!(
2856            state.get_extension::<FixedLenMintTest>().unwrap(),
2857            &fixed_len,
2858        );
2859
2860        // alloc again succeeds with "overwrite"
2861        let account_info = (&key, &mut data).into_account_info();
2862        alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, true).unwrap();
2863
2864        // alloc again fails without "overwrite"
2865        let account_info = (&key, &mut data).into_account_info();
2866        assert_eq!(
2867            alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, false).unwrap_err(),
2868            TokenError::ExtensionAlreadyInitialized.into()
2869        );
2870    }
2871
2872    #[test]
2873    fn alloc_new_variable_len_tlv_in_account_info_from_base_size() {
2874        let variable_len = VariableLenMintTest { data: vec![20, 99] };
2875        let value_len = variable_len.get_packed_len().unwrap();
2876        let base_account_size = PodMint::SIZE_OF;
2877        let mut buffer = vec![0; base_account_size];
2878        let state =
2879            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2880        *state.base = TEST_POD_MINT;
2881
2882        let mut data = SolanaAccountData::new(&buffer);
2883        let key = Address::new_unique();
2884        let account_info = (&key, &mut data).into_account_info();
2885
2886        alloc_and_serialize_variable_len_extension::<PodMint, _>(
2887            &account_info,
2888            &variable_len,
2889            false,
2890        )
2891        .unwrap();
2892        let new_account_len = BASE_ACCOUNT_AND_TYPE_LENGTH + add_type_and_length_to_len(value_len);
2893        assert_eq!(data.len(), new_account_len);
2894        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
2895        assert_eq!(
2896            state
2897                .get_variable_len_extension::<VariableLenMintTest>()
2898                .unwrap(),
2899            variable_len
2900        );
2901
2902        // alloc again succeeds with "overwrite"
2903        let account_info = (&key, &mut data).into_account_info();
2904        alloc_and_serialize_variable_len_extension::<PodMint, _>(
2905            &account_info,
2906            &variable_len,
2907            true,
2908        )
2909        .unwrap();
2910
2911        // alloc again fails without "overwrite"
2912        let account_info = (&key, &mut data).into_account_info();
2913        assert_eq!(
2914            alloc_and_serialize_variable_len_extension::<PodMint, _>(
2915                &account_info,
2916                &variable_len,
2917                false,
2918            )
2919            .unwrap_err(),
2920            TokenError::ExtensionAlreadyInitialized.into()
2921        );
2922    }
2923
2924    #[test]
2925    fn alloc_new_fixed_len_tlv_in_account_info_from_extended_size() {
2926        let fixed_len = FixedLenMintTest {
2927            data: [1, 2, 3, 4, 5, 6, 7, 8],
2928        };
2929        let value_len = size_of::<FixedLenMintTest>();
2930        let account_size =
2931            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::GroupPointer])
2932                .unwrap()
2933                + add_type_and_length_to_len(value_len);
2934        let mut buffer = vec![0; account_size];
2935        let mut state =
2936            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2937        *state.base = TEST_POD_MINT;
2938        state.init_account_type().unwrap();
2939
2940        let test_key: MaybeNull<Address> =
2941            Some(Address::new_from_array([20; 32])).try_into().unwrap();
2942        let extension = state.init_extension::<GroupPointer>(false).unwrap();
2943        extension.authority = test_key;
2944        extension.group_address = test_key;
2945
2946        let mut data = SolanaAccountData::new(&buffer);
2947        let key = Address::new_unique();
2948        let account_info = (&key, &mut data).into_account_info();
2949
2950        alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, false).unwrap();
2951        let new_account_len = BASE_ACCOUNT_AND_TYPE_LENGTH
2952            + add_type_and_length_to_len(value_len)
2953            + add_type_and_length_to_len(size_of::<GroupPointer>());
2954        assert_eq!(data.len(), new_account_len);
2955        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
2956        assert_eq!(
2957            state.get_extension::<FixedLenMintTest>().unwrap(),
2958            &fixed_len,
2959        );
2960        let extension = state.get_extension::<GroupPointer>().unwrap();
2961        assert_eq!(extension.authority, test_key);
2962        assert_eq!(extension.group_address, test_key);
2963
2964        // alloc again succeeds with "overwrite"
2965        let account_info = (&key, &mut data).into_account_info();
2966        alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, true).unwrap();
2967
2968        // alloc again fails without "overwrite"
2969        let account_info = (&key, &mut data).into_account_info();
2970        assert_eq!(
2971            alloc_and_serialize::<PodMint, _>(&account_info, &fixed_len, false).unwrap_err(),
2972            TokenError::ExtensionAlreadyInitialized.into()
2973        );
2974    }
2975
2976    #[test]
2977    fn alloc_new_variable_len_tlv_in_account_info_from_extended_size() {
2978        let variable_len = VariableLenMintTest { data: vec![42, 6] };
2979        let value_len = variable_len.get_packed_len().unwrap();
2980        let account_size =
2981            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MetadataPointer])
2982                .unwrap()
2983                + add_type_and_length_to_len(value_len);
2984        let mut buffer = vec![0; account_size];
2985        let mut state =
2986            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
2987        *state.base = TEST_POD_MINT;
2988        state.init_account_type().unwrap();
2989
2990        let test_key: MaybeNull<Address> =
2991            Some(Address::new_from_array([20; 32])).try_into().unwrap();
2992        let extension = state.init_extension::<MetadataPointer>(false).unwrap();
2993        extension.authority = test_key;
2994        extension.metadata_address = test_key;
2995
2996        let mut data = SolanaAccountData::new(&buffer);
2997        let key = Address::new_unique();
2998        let account_info = (&key, &mut data).into_account_info();
2999
3000        alloc_and_serialize_variable_len_extension::<PodMint, _>(
3001            &account_info,
3002            &variable_len,
3003            false,
3004        )
3005        .unwrap();
3006        let new_account_len = BASE_ACCOUNT_AND_TYPE_LENGTH
3007            + add_type_and_length_to_len(value_len)
3008            + add_type_and_length_to_len(size_of::<MetadataPointer>());
3009        assert_eq!(data.len(), new_account_len);
3010        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
3011        assert_eq!(
3012            state
3013                .get_variable_len_extension::<VariableLenMintTest>()
3014                .unwrap(),
3015            variable_len
3016        );
3017        let extension = state.get_extension::<MetadataPointer>().unwrap();
3018        assert_eq!(extension.authority, test_key);
3019        assert_eq!(extension.metadata_address, test_key);
3020
3021        // alloc again succeeds with "overwrite"
3022        let account_info = (&key, &mut data).into_account_info();
3023        alloc_and_serialize_variable_len_extension::<PodMint, _>(
3024            &account_info,
3025            &variable_len,
3026            true,
3027        )
3028        .unwrap();
3029
3030        // alloc again fails without "overwrite"
3031        let account_info = (&key, &mut data).into_account_info();
3032        assert_eq!(
3033            alloc_and_serialize_variable_len_extension::<PodMint, _>(
3034                &account_info,
3035                &variable_len,
3036                false,
3037            )
3038            .unwrap_err(),
3039            TokenError::ExtensionAlreadyInitialized.into()
3040        );
3041    }
3042
3043    #[test]
3044    fn realloc_variable_len_tlv_in_account_info() {
3045        let variable_len = VariableLenMintTest {
3046            data: vec![1, 2, 3, 4, 5],
3047        };
3048        let alloc_size = variable_len.get_packed_len().unwrap();
3049        let account_size =
3050            ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MetadataPointer])
3051                .unwrap()
3052                + add_type_and_length_to_len(alloc_size);
3053        let mut buffer = vec![0; account_size];
3054        let mut state =
3055            PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
3056        *state.base = TEST_POD_MINT;
3057        state.init_account_type().unwrap();
3058
3059        // alloc both types
3060        state
3061            .init_variable_len_extension(&variable_len, false)
3062            .unwrap();
3063        let max_pubkey: MaybeNull<Address> =
3064            Some(Address::new_from_array([255; 32])).try_into().unwrap();
3065        let extension = state.init_extension::<MetadataPointer>(false).unwrap();
3066        extension.authority = max_pubkey;
3067        extension.metadata_address = max_pubkey;
3068
3069        // reallocate to smaller, make sure existing extension is fine
3070        let mut data = SolanaAccountData::new(&buffer);
3071        let key = Address::new_unique();
3072        let account_info = (&key, &mut data).into_account_info();
3073        let variable_len = VariableLenMintTest { data: vec![1, 2] };
3074        alloc_and_serialize_variable_len_extension::<PodMint, _>(
3075            &account_info,
3076            &variable_len,
3077            true,
3078        )
3079        .unwrap();
3080
3081        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
3082        let extension = state.get_extension::<MetadataPointer>().unwrap();
3083        assert_eq!(extension.authority, max_pubkey);
3084        assert_eq!(extension.metadata_address, max_pubkey);
3085        let extension = state
3086            .get_variable_len_extension::<VariableLenMintTest>()
3087            .unwrap();
3088        assert_eq!(extension, variable_len);
3089        assert_eq!(data.len(), state.try_get_account_len().unwrap());
3090
3091        // reallocate to larger
3092        let account_info = (&key, &mut data).into_account_info();
3093        let variable_len = VariableLenMintTest {
3094            data: vec![1, 2, 3, 4, 5, 6, 7],
3095        };
3096        alloc_and_serialize_variable_len_extension::<PodMint, _>(
3097            &account_info,
3098            &variable_len,
3099            true,
3100        )
3101        .unwrap();
3102
3103        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
3104        let extension = state.get_extension::<MetadataPointer>().unwrap();
3105        assert_eq!(extension.authority, max_pubkey);
3106        assert_eq!(extension.metadata_address, max_pubkey);
3107        let extension = state
3108            .get_variable_len_extension::<VariableLenMintTest>()
3109            .unwrap();
3110        assert_eq!(extension, variable_len);
3111        assert_eq!(data.len(), state.try_get_account_len().unwrap());
3112
3113        // reallocate to same
3114        let account_info = (&key, &mut data).into_account_info();
3115        let variable_len = VariableLenMintTest {
3116            data: vec![7, 6, 5, 4, 3, 2, 1],
3117        };
3118        alloc_and_serialize_variable_len_extension::<PodMint, _>(
3119            &account_info,
3120            &variable_len,
3121            true,
3122        )
3123        .unwrap();
3124
3125        let state = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
3126        let extension = state.get_extension::<MetadataPointer>().unwrap();
3127        assert_eq!(extension.authority, max_pubkey);
3128        assert_eq!(extension.metadata_address, max_pubkey);
3129        let extension = state
3130            .get_variable_len_extension::<VariableLenMintTest>()
3131            .unwrap();
3132        assert_eq!(extension, variable_len);
3133        assert_eq!(data.len(), state.try_get_account_len().unwrap());
3134    }
3135}