tensor_toolbox/token_2022/
extension.rs1#[allow(deprecated)]
7use anchor_lang::{
8 solana_program::{
9 borsh0_10::try_from_slice_unchecked, program_error::ProgramError, program_pack::Pack,
10 },
11 AnchorDeserialize,
12};
13use anchor_spl::{
14 token::spl_token::state::{Account, Multisig},
15 token_interface::spl_token_2022::extension::{
16 AccountType, BaseState, Extension, ExtensionType, Length,
17 },
18};
19use bytemuck::Pod;
20use std::mem::size_of;
21
22const BASE_ACCOUNT_LENGTH: usize = Account::LEN;
23
24const BASE_ACCOUNT_AND_TYPE_LENGTH: usize = BASE_ACCOUNT_LENGTH + size_of::<AccountType>();
25
/// Byte offsets of the three parts of a single TLV entry inside an extension buffer.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct TlvIndices {
    /// Offset of the entry's `ExtensionType` code.
    pub type_start: usize,
    /// Offset of the entry's `Length` header.
    pub length_start: usize,
    /// Offset of the first byte of the entry's value.
    pub value_start: usize,
}
31
32pub fn get_extension<V: Extension + Pod>(
34 tlv_data: &[u8],
35) -> core::result::Result<&V, ProgramError> {
36 bytemuck::try_from_bytes::<V>(get_extension_bytes::<V>(tlv_data)?)
37 .map_err(|_error| ProgramError::InvalidAccountData)
38}
39
40fn get_extension_bytes<V: Extension>(tlv_data: &[u8]) -> core::result::Result<&[u8], ProgramError> {
41 let TlvIndices {
42 type_start: _,
43 length_start,
44 value_start,
45 } = get_extension_indices::<V>(tlv_data)?;
46 let length = bytemuck::try_from_bytes::<Length>(&tlv_data[length_start..value_start])
48 .map_err(|_error| ProgramError::InvalidAccountData)?;
49 let value_end = value_start.saturating_add(usize::from(*length));
50 if tlv_data.len() < value_end {
51 return Err(ProgramError::InvalidAccountData);
52 }
53 Ok(&tlv_data[value_start..value_end])
54}
55
56fn get_extension_indices<V: Extension>(
57 tlv_data: &[u8],
58) -> core::result::Result<TlvIndices, ProgramError> {
59 let mut start_index = 0;
60 while start_index < tlv_data.len() {
61 let tlv_indices = get_tlv_indices(start_index);
62 if tlv_data.len() < tlv_indices.value_start {
63 return Err(ProgramError::InvalidAccountData);
64 }
65 let extension_type =
66 ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start]);
67 if extension_type.is_ok() && extension_type.unwrap() == V::TYPE {
71 return Ok(tlv_indices);
73 }
74 let length = bytemuck::try_from_bytes::<Length>(
75 &tlv_data[tlv_indices.length_start..tlv_indices.value_start],
76 )
77 .map_err(|_| ProgramError::InvalidArgument)?;
78 let value_end_index = tlv_indices.value_start.saturating_add(usize::from(*length));
79 start_index = value_end_index;
80 }
81 Err(ProgramError::InvalidAccountData)
82}
83
84fn get_tlv_indices(type_start: usize) -> TlvIndices {
86 let length_start = type_start.saturating_add(size_of::<ExtensionType>());
87 let value_start = length_start.saturating_add(size_of::<Length>());
88 TlvIndices {
89 type_start,
90 length_start,
91 value_start,
92 }
93}
94
95pub fn get_extension_types(tlv_data: &[u8]) -> Result<Vec<IExtensionType>, ProgramError> {
97 let mut extension_types = vec![];
98 let mut start_index = 0;
99 while start_index < tlv_data.len() {
100 let tlv_indices = get_tlv_indices(start_index);
101 if tlv_data.len() < tlv_indices.length_start {
102 return Ok(extension_types);
105 }
106 let extension_type = u16::from_le_bytes(
107 (&tlv_data[tlv_indices.type_start..tlv_indices.length_start])
108 .try_into()
109 .map_err(|_| ProgramError::InvalidAccountData)?,
110 );
111 if let Ok(extension_type) = IExtensionType::try_from(extension_type) {
113 extension_types.push(extension_type);
114 }
115
116 if tlv_data.len() < tlv_indices.value_start {
117 return Err(ProgramError::InvalidAccountData);
119 }
120
121 let length = bytemuck::try_from_bytes::<Length>(
122 &tlv_data[tlv_indices.length_start..tlv_indices.value_start],
123 )
124 .map_err(|_| ProgramError::InvalidAccountData)?;
125
126 let value_end_index = tlv_indices.value_start.saturating_add(usize::from(*length));
127 if value_end_index > tlv_data.len() {
128 return Err(ProgramError::InvalidAccountData);
130 }
131 start_index = value_end_index;
132 }
133 Ok(extension_types)
134}
135
136pub fn get_variable_len_extension<V: Extension + AnchorDeserialize>(
137 tlv_data: &[u8],
138) -> core::result::Result<V, ProgramError> {
139 let data = get_extension_bytes::<V>(tlv_data)?;
140 #[allow(deprecated)]
141 try_from_slice_unchecked::<V>(data).map_err(|_error| ProgramError::InvalidAccountData)
142}
143
/// Subset of Token-2022 extension types understood by this module.
///
/// Each discriminant equals the raw `u16` type code stored in a TLV entry; these are
/// the same codes accepted by `IExtensionType::try_from(u16)`.
#[repr(u16)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum IExtensionType {
    MintCloseAuthority = 3,
    DefaultAccountState = 6,
    ImmutableOwner = 7,
    NonTransferable = 9,
    CpiGuard = 11,
    PermanentDelegate = 12,
    NonTransferableAccount = 13,
    TransferHook = 14,
    TransferHookAccount = 15,
    MetadataPointer = 18,
    GroupPointer = 20,
    TokenGroup = 21,
    GroupMemberPointer = 22,
    TokenGroupMember = 23,
}
182
183impl IExtensionType {
184 fn get_type_len(&self) -> usize {
185 match self {
186 IExtensionType::MintCloseAuthority => 32,
187 IExtensionType::DefaultAccountState => 1,
188 IExtensionType::ImmutableOwner => 0,
189 IExtensionType::NonTransferable => 0,
190 IExtensionType::CpiGuard => 1,
191 IExtensionType::PermanentDelegate => 32,
192 IExtensionType::NonTransferableAccount => 0,
193 IExtensionType::TransferHook => 64,
194 IExtensionType::TransferHookAccount => 1,
195 IExtensionType::MetadataPointer => 64,
196 IExtensionType::GroupPointer => 64,
197 IExtensionType::TokenGroup => 72,
198 IExtensionType::GroupMemberPointer => 64,
199 IExtensionType::TokenGroupMember => 68,
200 }
201 }
202
203 fn try_get_tlv_len(&self) -> Result<usize, ProgramError> {
207 Ok(add_type_and_length_to_len(self.get_type_len()))
208 }
209
210 fn try_get_total_tlv_len(extension_types: &[Self]) -> Result<usize, ProgramError> {
214 let mut extensions = vec![];
216 for extension_type in extension_types {
217 if !extensions.contains(&extension_type) {
218 extensions.push(extension_type);
219 }
220 }
221 extensions.iter().map(|e| e.try_get_tlv_len()).sum()
222 }
223
224 pub fn try_calculate_account_len<S: BaseState>(
228 extension_types: &[Self],
229 ) -> Result<usize, ProgramError> {
230 if extension_types.is_empty() {
231 Ok(S::LEN)
232 } else {
233 let extension_size = Self::try_get_total_tlv_len(extension_types)?;
234 let total_len = extension_size.saturating_add(BASE_ACCOUNT_AND_TYPE_LENGTH);
235 Ok(adjust_len_for_multisig(total_len))
236 }
237 }
238
239 pub fn get_required_init_account_extensions(mint_extension_types: &[Self]) -> Vec<Self> {
242 let mut account_extension_types = vec![];
243 for extension_type in mint_extension_types {
244 match extension_type {
245 IExtensionType::NonTransferable => {
246 account_extension_types.push(IExtensionType::NonTransferableAccount);
247 }
248 IExtensionType::TransferHook => {
249 account_extension_types.push(IExtensionType::TransferHookAccount);
250 }
251 _ => {}
252 }
253 }
254 account_extension_types
255 }
256}
257
258impl TryFrom<u16> for IExtensionType {
259 type Error = ProgramError;
260 fn try_from(value: u16) -> Result<Self, Self::Error> {
261 let extension = match value {
262 3 => IExtensionType::MintCloseAuthority,
263 6 => IExtensionType::DefaultAccountState,
264 7 => IExtensionType::ImmutableOwner,
265 9 => IExtensionType::NonTransferable,
266 11 => IExtensionType::CpiGuard,
267 12 => IExtensionType::PermanentDelegate,
268 13 => IExtensionType::NonTransferableAccount,
269 14 => IExtensionType::TransferHook,
270 15 => IExtensionType::TransferHookAccount,
271 18 => IExtensionType::MetadataPointer,
272 20 => IExtensionType::GroupPointer,
273 21 => IExtensionType::TokenGroup,
274 22 => IExtensionType::GroupMemberPointer,
275 23 => IExtensionType::TokenGroupMember,
276 _ => return Err(ProgramError::InvalidArgument),
277 };
278 Ok(extension)
279 }
280}
281
282const fn add_type_and_length_to_len(value_len: usize) -> usize {
285 value_len
286 .saturating_add(size_of::<ExtensionType>())
287 .saturating_add(size_of::<Length>())
288}
289
290const fn adjust_len_for_multisig(account_len: usize) -> usize {
291 if account_len == Multisig::LEN {
292 account_len.saturating_add(size_of::<ExtensionType>())
293 } else {
294 account_len
295 }
296}