tensor_toolbox/token_2022/
extension.rs

1//! Helpers for working with the token extension TLV format.
2//! Taken + adapted from https://github.com/solana-labs/solana-program-library/blob/2124f7562ed27a5f03f29c0ea0b77ef0167d5028/token/program-2022/src/extension/mod.rs
3//!
//! The main purpose of these helpers is to gracefully handle unknown extensions.
5
6#[allow(deprecated)]
7use anchor_lang::{
8    solana_program::{
9        borsh0_10::try_from_slice_unchecked, program_error::ProgramError, program_pack::Pack,
10    },
11    AnchorDeserialize,
12};
13use anchor_spl::{
14    token::spl_token::state::{Account, Multisig},
15    token_interface::spl_token_2022::extension::{
16        AccountType, BaseState, Extension, ExtensionType, Length,
17    },
18};
19use bytemuck::Pod;
20use std::mem::size_of;
21
22const BASE_ACCOUNT_LENGTH: usize = Account::LEN;
23
24const BASE_ACCOUNT_AND_TYPE_LENGTH: usize = BASE_ACCOUNT_LENGTH + size_of::<AccountType>();
25
/// Byte offsets of a single TLV entry inside the extension data: where its type tag,
/// its length field and its value begin.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct TlvIndices {
    /// Offset of the `ExtensionType` tag.
    pub type_start: usize,
    /// Offset of the `Length` field (immediately after the type tag).
    pub length_start: usize,
    /// Offset of the first value byte (immediately after the length field).
    pub value_start: usize,
}
31
32/// Unpack a portion of the TLV data as the desired type
33pub fn get_extension<V: Extension + Pod>(
34    tlv_data: &[u8],
35) -> core::result::Result<&V, ProgramError> {
36    bytemuck::try_from_bytes::<V>(get_extension_bytes::<V>(tlv_data)?)
37        .map_err(|_error| ProgramError::InvalidAccountData)
38}
39
40fn get_extension_bytes<V: Extension>(tlv_data: &[u8]) -> core::result::Result<&[u8], ProgramError> {
41    let TlvIndices {
42        type_start: _,
43        length_start,
44        value_start,
45    } = get_extension_indices::<V>(tlv_data)?;
46    // get_extension_indices has checked that tlv_data is long enough to include these indices
47    let length = bytemuck::try_from_bytes::<Length>(&tlv_data[length_start..value_start])
48        .map_err(|_error| ProgramError::InvalidAccountData)?;
49    let value_end = value_start.saturating_add(usize::from(*length));
50    if tlv_data.len() < value_end {
51        return Err(ProgramError::InvalidAccountData);
52    }
53    Ok(&tlv_data[value_start..value_end])
54}
55
56fn get_extension_indices<V: Extension>(
57    tlv_data: &[u8],
58) -> core::result::Result<TlvIndices, ProgramError> {
59    let mut start_index = 0;
60    while start_index < tlv_data.len() {
61        let tlv_indices = get_tlv_indices(start_index);
62        if tlv_data.len() < tlv_indices.value_start {
63            return Err(ProgramError::InvalidAccountData);
64        }
65        let extension_type =
66            ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start]);
67        // [FEBO] Make sure we don't bubble the error in case we don't recognize
68        // the extension type; the best we can do when we don't recognize the extension is
69        // to keep looking for the one we're interested in
70        if extension_type.is_ok() && extension_type.unwrap() == V::TYPE {
71            // found an instance of the extension that we're looking, return!
72            return Ok(tlv_indices);
73        }
74        let length = bytemuck::try_from_bytes::<Length>(
75            &tlv_data[tlv_indices.length_start..tlv_indices.value_start],
76        )
77        .map_err(|_| ProgramError::InvalidArgument)?;
78        let value_end_index = tlv_indices.value_start.saturating_add(usize::from(*length));
79        start_index = value_end_index;
80    }
81    Err(ProgramError::InvalidAccountData)
82}
83
84/// Helper function to get the current TlvIndices from the current spot
85fn get_tlv_indices(type_start: usize) -> TlvIndices {
86    let length_start = type_start.saturating_add(size_of::<ExtensionType>());
87    let value_start = length_start.saturating_add(size_of::<Length>());
88    TlvIndices {
89        type_start,
90        length_start,
91        value_start,
92    }
93}
94
95/// Fetches the "known" extension types from the TLV data.
96pub fn get_extension_types(tlv_data: &[u8]) -> Result<Vec<IExtensionType>, ProgramError> {
97    let mut extension_types = vec![];
98    let mut start_index = 0;
99    while start_index < tlv_data.len() {
100        let tlv_indices = get_tlv_indices(start_index);
101        if tlv_data.len() < tlv_indices.length_start {
102            // There aren't enough bytes to store the next type, which means we
103            // got to the end. The last byte could be used during a realloc!
104            return Ok(extension_types);
105        }
106        let extension_type = u16::from_le_bytes(
107            (&tlv_data[tlv_indices.type_start..tlv_indices.length_start])
108                .try_into()
109                .map_err(|_| ProgramError::InvalidAccountData)?,
110        );
111        // we recognize the extension type, add it to the list
112        if let Ok(extension_type) = IExtensionType::try_from(extension_type) {
113            extension_types.push(extension_type);
114        }
115
116        if tlv_data.len() < tlv_indices.value_start {
117            // not enough bytes to store the length, malformed
118            return Err(ProgramError::InvalidAccountData);
119        }
120
121        let length = bytemuck::try_from_bytes::<Length>(
122            &tlv_data[tlv_indices.length_start..tlv_indices.value_start],
123        )
124        .map_err(|_| ProgramError::InvalidAccountData)?;
125
126        let value_end_index = tlv_indices.value_start.saturating_add(usize::from(*length));
127        if value_end_index > tlv_data.len() {
128            // value blows past the size of the slice, malformed
129            return Err(ProgramError::InvalidAccountData);
130        }
131        start_index = value_end_index;
132    }
133    Ok(extension_types)
134}
135
136pub fn get_variable_len_extension<V: Extension + AnchorDeserialize>(
137    tlv_data: &[u8],
138) -> core::result::Result<V, ProgramError> {
139    let data = get_extension_bytes::<V>(tlv_data)?;
140    #[allow(deprecated)]
141    try_from_slice_unchecked::<V>(data).map_err(|_error| ProgramError::InvalidAccountData)
142}
143
/// Token-2022 extension types recognized by these helpers.
///
/// Each discriminant is the raw `u16` TLV type value that maps to the variant in the
/// `TryFrom<u16>` conversion; any other type value is treated as unknown.
#[repr(u16)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IExtensionType {
    /// [MINT] Includes an optional mint close authority
    MintCloseAuthority = 3,
    /// [MINT] Specifies the default Account::state for new Accounts
    DefaultAccountState = 6,
    /// [ACCOUNT] Indicates that the Account owner authority cannot be changed
    ImmutableOwner = 7,
    /// [MINT] Indicates that the tokens from this mint can't be transferred
    NonTransferable = 9,
    /// [ACCOUNT] Locks privileged token operations from happening via CPI
    CpiGuard = 11,
    /// [MINT] Includes an optional permanent delegate
    PermanentDelegate = 12,
    /// [ACCOUNT] Indicates that the tokens in this account belong to a non-transferable
    /// mint
    NonTransferableAccount = 13,
    /// [MINT] Mint requires a CPI to a program implementing the "transfer hook"
    /// interface
    TransferHook = 14,
    /// [ACCOUNT] Indicates that the tokens in this account belong to a mint with a
    /// transfer hook
    TransferHookAccount = 15,
    /// [MINT] Mint contains a pointer to another account (or the same account) that
    /// holds metadata
    MetadataPointer = 18,
    /// [MINT] Mint contains a pointer to another account (or the same account) that
    /// holds group configurations
    GroupPointer = 20,
    /// [MINT] Mint contains token group configurations
    TokenGroup = 21,
    /// [MINT] Mint contains a pointer to another account (or the same account) that
    /// holds group member configurations
    GroupMemberPointer = 22,
    /// [MINT] Mint contains token group member configurations
    TokenGroupMember = 23,
}
182
183impl IExtensionType {
184    fn get_type_len(&self) -> usize {
185        match self {
186            IExtensionType::MintCloseAuthority => 32,
187            IExtensionType::DefaultAccountState => 1,
188            IExtensionType::ImmutableOwner => 0,
189            IExtensionType::NonTransferable => 0,
190            IExtensionType::CpiGuard => 1,
191            IExtensionType::PermanentDelegate => 32,
192            IExtensionType::NonTransferableAccount => 0,
193            IExtensionType::TransferHook => 64,
194            IExtensionType::TransferHookAccount => 1,
195            IExtensionType::MetadataPointer => 64,
196            IExtensionType::GroupPointer => 64,
197            IExtensionType::TokenGroup => 72,
198            IExtensionType::GroupMemberPointer => 64,
199            IExtensionType::TokenGroupMember => 68,
200        }
201    }
202
203    /// Get the TLV length for an ExtensionType
204    ///
205    /// Fails if the extension type has a variable length
206    fn try_get_tlv_len(&self) -> Result<usize, ProgramError> {
207        Ok(add_type_and_length_to_len(self.get_type_len()))
208    }
209
210    /// Get the TLV length for a set of ExtensionTypes
211    ///
212    /// Fails if any of the extension types has a variable length
213    fn try_get_total_tlv_len(extension_types: &[Self]) -> Result<usize, ProgramError> {
214        // dedupe extensions
215        let mut extensions = vec![];
216        for extension_type in extension_types {
217            if !extensions.contains(&extension_type) {
218                extensions.push(extension_type);
219            }
220        }
221        extensions.iter().map(|e| e.try_get_tlv_len()).sum()
222    }
223
224    /// Get the required account data length for the given ExtensionTypes
225    ///
226    /// Fails if any of the extension types has a variable length
227    pub fn try_calculate_account_len<S: BaseState>(
228        extension_types: &[Self],
229    ) -> Result<usize, ProgramError> {
230        if extension_types.is_empty() {
231            Ok(S::LEN)
232        } else {
233            let extension_size = Self::try_get_total_tlv_len(extension_types)?;
234            let total_len = extension_size.saturating_add(BASE_ACCOUNT_AND_TYPE_LENGTH);
235            Ok(adjust_len_for_multisig(total_len))
236        }
237    }
238
239    /// Based on a set of [MINT] ExtensionTypes, get the list of
240    /// [ACCOUNT] ExtensionTypes required on InitializeAccount
241    pub fn get_required_init_account_extensions(mint_extension_types: &[Self]) -> Vec<Self> {
242        let mut account_extension_types = vec![];
243        for extension_type in mint_extension_types {
244            match extension_type {
245                IExtensionType::NonTransferable => {
246                    account_extension_types.push(IExtensionType::NonTransferableAccount);
247                }
248                IExtensionType::TransferHook => {
249                    account_extension_types.push(IExtensionType::TransferHookAccount);
250                }
251                _ => {}
252            }
253        }
254        account_extension_types
255    }
256}
257
258impl TryFrom<u16> for IExtensionType {
259    type Error = ProgramError;
260    fn try_from(value: u16) -> Result<Self, Self::Error> {
261        let extension = match value {
262            3 => IExtensionType::MintCloseAuthority,
263            6 => IExtensionType::DefaultAccountState,
264            7 => IExtensionType::ImmutableOwner,
265            9 => IExtensionType::NonTransferable,
266            11 => IExtensionType::CpiGuard,
267            12 => IExtensionType::PermanentDelegate,
268            13 => IExtensionType::NonTransferableAccount,
269            14 => IExtensionType::TransferHook,
270            15 => IExtensionType::TransferHookAccount,
271            18 => IExtensionType::MetadataPointer,
272            20 => IExtensionType::GroupPointer,
273            21 => IExtensionType::TokenGroup,
274            22 => IExtensionType::GroupMemberPointer,
275            23 => IExtensionType::TokenGroupMember,
276            _ => return Err(ProgramError::InvalidArgument),
277        };
278        Ok(extension)
279    }
280}
281
282/// Helper function to calculate exactly how many bytes a value will take up,
283/// given the value's length
284const fn add_type_and_length_to_len(value_len: usize) -> usize {
285    value_len
286        .saturating_add(size_of::<ExtensionType>())
287        .saturating_add(size_of::<Length>())
288}
289
290const fn adjust_len_for_multisig(account_len: usize) -> usize {
291    if account_len == Multisig::LEN {
292        account_len.saturating_add(size_of::<ExtensionType>())
293    } else {
294        account_len
295    }
296}