typed_transaction_macros/
lib.rs

1extern crate proc_macro;
2use proc_macro::{TokenStream, TokenTree};
3use quote::quote;
4use syn::{parse, parse_macro_input, Data, DeriveInput, Expr, Fields, ItemStruct};
5use sha2::{Sha256, Digest};
6use heck::ToSnakeCase;
7
8#[proc_macro_derive(TypedAccounts, attributes(account))]
9pub fn from_account_metas_derive(input: TokenStream) -> TokenStream {
10    let input = parse_macro_input!(input as DeriveInput);
11    let name = input.ident;
12
13    let fields = if let Data::Struct(data_struct) = input.data {
14        match data_struct.fields {
15            Fields::Named(fields_named) => fields_named.named,
16            _ => panic!("FromAccountMetas can only be derived for structs with named fields"),
17        }
18    } else {
19        panic!("FromAccountMetas can only be derived for structs");
20    };
21
22    let mut field_checks = Vec::new();
23    let mut clones = Vec::new();
24    let mut assignments = Vec::new();
25
26    // Count the number of fields
27    let field_count = fields.len();
28
29    for (field_index, field) in fields.iter().enumerate() {
30        let field_name = &field.ident;
31        let mut is_mutable = false;
32        let mut is_signer = false;
33
34        // Check for #[account(mut)] and #[account(signer)]
35        for attr in &field.attrs {
36            if attr.path().is_ident("account") {
37                // Parse the attribute's nested meta items using `parse_nested_meta`
38                attr.parse_nested_meta(|meta| {
39                    if let Some(ident) = meta.path.get_ident() {
40                        match ident.to_string().as_str() {
41                            "mut" => is_mutable = true,
42                            "signer" => is_signer = true,
43                            _ => panic!("Invalid attribute: {}", ident),
44                        }
45                    } else {
46                        panic!("Invalid attribute format");
47                    }
48                    Ok(())
49                }).unwrap();
50            }
51        }
52
53        // Generate checks for the AccountMeta fields
54        let check_writable = if is_mutable {
55            quote! {
56                if !account_metas[#field_index].is_writable {
57                    return Err(anchor_lang::error::ErrorCode::ConstraintMut.into());
58                }
59            }
60        } else {
61            quote! {}
62        };
63
64        let check_signer = if is_signer {
65            quote! {
66                if !account_metas[#field_index].is_signer {
67                    return Err(anchor_lang::error::ErrorCode::ConstraintSigner.into());
68                }
69            }
70        } else {
71            quote! {}
72        };
73
74        // Only add checks if they are non-empty
75        if !check_writable.is_empty() || !check_signer.is_empty() {
76            field_checks.push(quote! {
77                #check_writable
78                #check_signer
79            });
80        }
81
82        clones.push(quote! {
83            #field_name: self.#field_name.clone()
84        });
85
86        assignments.push(quote! {
87            #field_name: account_metas[#field_index].clone()
88        });
89    }
90
91    // Pass the literal value of `field_count` into the generated code
92    let expanded = quote! {
93        impl FromAccountMetas for #name {
94            fn from_account_metas(account_metas: &[AccountMeta]) -> Result<Self> {
95                if account_metas.len() != #field_count {
96                    return Err(anchor_lang::error::ErrorCode::ConstraintSigner.into());
97                }
98
99                #(#field_checks)*
100
101                Ok(Self {
102                    #(#assignments),*  // Ensure no extra commas
103                })
104            }
105        }
106
107        impl Clone for #name {
108            fn clone(&self) -> Self {
109                Self { 
110                    #(#clones),*
111                }
112            }   
113        }
114    };
115
116    TokenStream::from(expanded)
117}
118
119
120fn parse_instruction_attribute(attr: TokenStream) -> (Option<Vec<u8>>, Option<proc_macro2::TokenStream>) {
121    let mut tokens = attr.into_iter();
122    let mut discriminator_value = None;
123    let mut owner_value = None;
124
125    while let Some(token) = tokens.next() {
126        match token {
127            // Look for the "discriminator" or "owner" identifier
128            TokenTree::Ident(ident) if ident.to_string() == "discriminator" => {
129                // Expect '=' next
130                if let Some(TokenTree::Punct(punct)) = tokens.next() {
131                    if punct.as_char() == '=' {
132                        // Now look for the value, which should be in a group (array of literals)
133                        if let Some(TokenTree::Group(group)) = tokens.next() {
134                            let mut group_tokens = group.stream().into_iter();
135                            let mut values = vec![];
136
137                            while let Some(TokenTree::Literal(lit)) = group_tokens.next() {
138                                if let Ok(parsed_u8) = lit.to_string().trim_start_matches("0x").parse::<u8>() {
139                                    values.push(parsed_u8);
140                                }
141                            }
142                            discriminator_value = Some(values);
143                        }
144                    }
145                }
146            }
147            TokenTree::Ident(ident) if ident.to_string() == "owner" => {
148                // Expect '=' next
149                if let Some(TokenTree::Punct(punct)) = tokens.next() {
150                    if punct.as_char() == '=' {
151                        // Collect all tokens until we hit a comma or another punctuation
152                        let mut owner_tokens = TokenStream::new();
153                        while let Some(next_token) = tokens.next() {
154                            match &next_token {
155                                TokenTree::Punct(punct) if punct.as_char() == ',' => break,
156                                _ => owner_tokens.extend(Some(next_token)),
157                            }
158                        }
159
160                        // Parse the accumulated tokens as a `syn::Expr`
161                        let owner_path_expr: Expr = parse(owner_tokens.into()).expect("Expected a valid owner expression");
162                        owner_value = Some(quote! { #owner_path_expr });
163                    }
164                }
165            }
166            _ => {}
167        }
168    }
169
170    // Output the parsed value
171    (discriminator_value, owner_value)
172}
173
174
175#[proc_macro_attribute]
176pub fn typed_instruction(attr: TokenStream, item: TokenStream) -> TokenStream {
177    // Parse the struct where the attribute is applied
178    let input = parse_macro_input!(item as ItemStruct);
179    let name = &input.ident;
180    let struct_name_snake_case = name.to_string().to_snake_case(); // Convert the struct name to snake_case
181
182    let (discriminator, owner) = parse_instruction_attribute(attr);
183    // Parse the attribute for custom discriminator, e.g. `#[typed_instruction(discriminator = "...")]`
184    let discriminator = match discriminator {
185        Some(d) => quote! { &[#(#d),*] },
186        None => {
187            let default_discriminator = {
188                let mut hasher = Sha256::new();
189                hasher.update(format!("global:{}", struct_name_snake_case));
190                let result = hasher.finalize();
191                result[..8].to_vec() // Truncate to the first 8 bytes
192            };
193            let bytes: Vec<_> = default_discriminator.iter().map(|byte| quote! { #byte }).collect();
194            quote! { &[#(#bytes),*] }
195        },
196    };
197
198    let owner = match owner {
199        Some(o) => quote! {
200            impl InstructionOwner for #name {
201                fn check_owner(pubkey: &Pubkey) -> Result<()> {
202                    match pubkey.eq(&#o) {
203                        true => Ok(()),
204                        false => Err(anchor_lang::error::ErrorCode::IdlInstructionInvalidProgram.into())
205                    }
206                }
207            }
208        },
209        None => quote! {
210            impl InstructionOwner for #name {
211                fn check_owner(pubkey: &Pubkey) -> Result<()> {
212                    Ok(())
213                }
214            }
215        },
216    };
217
218    // Generate the output with `BorshDeserialize` derive and `VariableDiscriminator` trait implementation
219    let expanded = quote! {
220        #[derive(BorshDeserialize, Debug, Clone)]
221        #input
222
223        // Implement the VariableDiscriminator trait
224        impl VariableDiscriminator for #name {
225            const DISCRIMINATOR: &'static [u8] = #discriminator;
226        }
227
228        #owner
229
230        // Implement the try_deserialize function for the struct
231        impl DeserializeWithDiscriminator for #name {
232            fn try_deserialize(bytes: &[u8]) -> Result<Self> {
233
234                // Check that the byte array is at least as long as the discriminator
235                if bytes.len() < Self::DISCRIMINATOR.len() {
236                    return Err(anchor_lang::error::ErrorCode::InstructionDidNotDeserialize.into());
237                }
238
239                // Check if the discriminator matches
240                if &bytes[..Self::DISCRIMINATOR.len()] != Self::DISCRIMINATOR {
241                    return Err(anchor_lang::error::ErrorCode::InstructionDidNotDeserialize.into());
242                }
243
244                // Deserialize the remaining bytes
245                let data = &bytes[Self::DISCRIMINATOR.len()..];
246                Self::try_from_slice(data).map_err(|_| anchor_lang::error::ErrorCode::InstructionDidNotDeserialize.into())
247            }
248        }
249    };
250    TokenStream::from(expanded)
251}
252
253
254#[proc_macro_derive(FromSignedTransaction)]
255pub fn from_signed_transaction_derive(input: TokenStream) -> TokenStream {
256    let input = parse_macro_input!(input as DeriveInput);
257    let name = input.ident;
258
259    let fields = if let Data::Struct(data_struct) = input.data {
260        match data_struct.fields {
261            Fields::Named(fields_named) => fields_named.named,
262            _ => panic!("FromAccountMetas can only be derived for structs with named fields"),
263        }
264    } else {
265        panic!("FromAccountMetas can only be derived for structs");
266    };
267
268    let mut assignments = Vec::new();
269
270    let mut clones = Vec::new();
271
272    // Count the number of fields
273    let mut field_index = 0usize;
274
275    for field in fields.iter() {
276        let field_name = &field.ident;
277
278        if let Some(name) = field_name {
279            if ["header", "recent_blockhash"].contains(&name.to_string().as_str()) {
280                continue;
281            }
282        }
283
284        assignments.push(quote! {
285            #field_name: TypedInstruction::try_from(&value.instructions[#field_index])?
286        });
287        clones.push(quote! {
288            #field_name: self.#field_name.clone()
289        });
290        field_index += 1;
291    }
292
293    // Pass the literal value of `field_count` into the generated code
294    let expanded = quote! {
295        impl Clone for #name {
296            fn clone(&self) -> Self {
297                Self { 
298                    header: self.header.clone(), 
299                    recent_blockhash: self.recent_blockhash.clone(),
300                    #(#clones),*
301                }
302            }
303        }
304
305        impl Discriminator for #name {
306            const DISCRIMINATOR: [u8; 8] = [0u8;8];
307        }
308
309        impl TryFrom<SignedTransaction> for #name {
310
311            type Error = anchor_lang::error::Error;
312
313            fn try_from(value: SignedTransaction) -> Result<#name> {
314                Ok(#name {
315                    header: value.header,
316                    recent_blockhash: value.recent_blockhash,
317                    #(#assignments),*
318                })
319            }
320        }
321
322        impl Owner for #name {
323            fn owner() -> Pubkey {
324                anchor_lang::solana_program::sysvar::ID
325            }
326        }
327        
328        impl AccountDeserialize for #name {
329            fn try_deserialize(buf: &mut &[u8]) -> Result<Self> {
330                Self::try_deserialize_unchecked(buf)
331            }
332        
333            fn try_deserialize_unchecked(buf: &mut &[u8]) -> Result<Self> {
334                Self::try_from(SignedTransaction::try_deserialize_transaction(buf).map_err(|_| ProgramError::InvalidInstructionData)?)
335            }
336        }
337        
338        impl AccountSerialize for #name {
339        
340        }
341
342        #[cfg(feature = "idl-build")]
343        impl anchor_lang::IdlBuild for #name {}
344    };
345
346    TokenStream::from(expanded)
347}