// typed_transaction_macros/lib.rs
extern crate proc_macro;
use heck::ToSnakeCase;
use proc_macro::{TokenStream, TokenTree};
use quote::quote;
use sha2::{Digest, Sha256};
use syn::{parse, parse_macro_input, Data, DeriveInput, Expr, Fields, ItemStruct, LitInt};

8#[proc_macro_derive(TypedAccounts, attributes(account))]
9pub fn from_account_metas_derive(input: TokenStream) -> TokenStream {
10 let input = parse_macro_input!(input as DeriveInput);
11 let name = input.ident;
12
13 let fields = if let Data::Struct(data_struct) = input.data {
14 match data_struct.fields {
15 Fields::Named(fields_named) => fields_named.named,
16 _ => panic!("FromAccountMetas can only be derived for structs with named fields"),
17 }
18 } else {
19 panic!("FromAccountMetas can only be derived for structs");
20 };
21
22 let mut field_checks = Vec::new();
23 let mut clones = Vec::new();
24 let mut assignments = Vec::new();
25
26 let field_count = fields.len();
28
29 for (field_index, field) in fields.iter().enumerate() {
30 let field_name = &field.ident;
31 let mut is_mutable = false;
32 let mut is_signer = false;
33
34 for attr in &field.attrs {
36 if attr.path().is_ident("account") {
37 attr.parse_nested_meta(|meta| {
39 if let Some(ident) = meta.path.get_ident() {
40 match ident.to_string().as_str() {
41 "mut" => is_mutable = true,
42 "signer" => is_signer = true,
43 _ => panic!("Invalid attribute: {}", ident),
44 }
45 } else {
46 panic!("Invalid attribute format");
47 }
48 Ok(())
49 }).unwrap();
50 }
51 }
52
53 let check_writable = if is_mutable {
55 quote! {
56 if !account_metas[#field_index].is_writable {
57 return Err(anchor_lang::error::ErrorCode::ConstraintMut.into());
58 }
59 }
60 } else {
61 quote! {}
62 };
63
64 let check_signer = if is_signer {
65 quote! {
66 if !account_metas[#field_index].is_signer {
67 return Err(anchor_lang::error::ErrorCode::ConstraintSigner.into());
68 }
69 }
70 } else {
71 quote! {}
72 };
73
74 if !check_writable.is_empty() || !check_signer.is_empty() {
76 field_checks.push(quote! {
77 #check_writable
78 #check_signer
79 });
80 }
81
82 clones.push(quote! {
83 #field_name: self.#field_name.clone()
84 });
85
86 assignments.push(quote! {
87 #field_name: account_metas[#field_index].clone()
88 });
89 }
90
91 let expanded = quote! {
93 impl FromAccountMetas for #name {
94 fn from_account_metas(account_metas: &[AccountMeta]) -> Result<Self> {
95 if account_metas.len() != #field_count {
96 return Err(anchor_lang::error::ErrorCode::ConstraintSigner.into());
97 }
98
99 #(#field_checks)*
100
101 Ok(Self {
102 #(#assignments),* })
104 }
105 }
106
107 impl Clone for #name {
108 fn clone(&self) -> Self {
109 Self {
110 #(#clones),*
111 }
112 }
113 }
114 };
115
116 TokenStream::from(expanded)
117}
118
119
120fn parse_instruction_attribute(attr: TokenStream) -> (Option<Vec<u8>>, Option<proc_macro2::TokenStream>) {
121 let mut tokens = attr.into_iter();
122 let mut discriminator_value = None;
123 let mut owner_value = None;
124
125 while let Some(token) = tokens.next() {
126 match token {
127 TokenTree::Ident(ident) if ident.to_string() == "discriminator" => {
129 if let Some(TokenTree::Punct(punct)) = tokens.next() {
131 if punct.as_char() == '=' {
132 if let Some(TokenTree::Group(group)) = tokens.next() {
134 let mut group_tokens = group.stream().into_iter();
135 let mut values = vec![];
136
137 while let Some(TokenTree::Literal(lit)) = group_tokens.next() {
138 if let Ok(parsed_u8) = lit.to_string().trim_start_matches("0x").parse::<u8>() {
139 values.push(parsed_u8);
140 }
141 }
142 discriminator_value = Some(values);
143 }
144 }
145 }
146 }
147 TokenTree::Ident(ident) if ident.to_string() == "owner" => {
148 if let Some(TokenTree::Punct(punct)) = tokens.next() {
150 if punct.as_char() == '=' {
151 let mut owner_tokens = TokenStream::new();
153 while let Some(next_token) = tokens.next() {
154 match &next_token {
155 TokenTree::Punct(punct) if punct.as_char() == ',' => break,
156 _ => owner_tokens.extend(Some(next_token)),
157 }
158 }
159
160 let owner_path_expr: Expr = parse(owner_tokens.into()).expect("Expected a valid owner expression");
162 owner_value = Some(quote! { #owner_path_expr });
163 }
164 }
165 }
166 _ => {}
167 }
168 }
169
170 (discriminator_value, owner_value)
172}
173
174
175#[proc_macro_attribute]
176pub fn typed_instruction(attr: TokenStream, item: TokenStream) -> TokenStream {
177 let input = parse_macro_input!(item as ItemStruct);
179 let name = &input.ident;
180 let struct_name_snake_case = name.to_string().to_snake_case(); let (discriminator, owner) = parse_instruction_attribute(attr);
183 let discriminator = match discriminator {
185 Some(d) => quote! { &[#(#d),*] },
186 None => {
187 let default_discriminator = {
188 let mut hasher = Sha256::new();
189 hasher.update(format!("global:{}", struct_name_snake_case));
190 let result = hasher.finalize();
191 result[..8].to_vec() };
193 let bytes: Vec<_> = default_discriminator.iter().map(|byte| quote! { #byte }).collect();
194 quote! { &[#(#bytes),*] }
195 },
196 };
197
198 let owner = match owner {
199 Some(o) => quote! {
200 impl InstructionOwner for #name {
201 fn check_owner(pubkey: &Pubkey) -> Result<()> {
202 match pubkey.eq(&#o) {
203 true => Ok(()),
204 false => Err(anchor_lang::error::ErrorCode::IdlInstructionInvalidProgram.into())
205 }
206 }
207 }
208 },
209 None => quote! {
210 impl InstructionOwner for #name {
211 fn check_owner(pubkey: &Pubkey) -> Result<()> {
212 Ok(())
213 }
214 }
215 },
216 };
217
218 let expanded = quote! {
220 #[derive(BorshDeserialize, Debug, Clone)]
221 #input
222
223 impl VariableDiscriminator for #name {
225 const DISCRIMINATOR: &'static [u8] = #discriminator;
226 }
227
228 #owner
229
230 impl DeserializeWithDiscriminator for #name {
232 fn try_deserialize(bytes: &[u8]) -> Result<Self> {
233
234 if bytes.len() < Self::DISCRIMINATOR.len() {
236 return Err(anchor_lang::error::ErrorCode::InstructionDidNotDeserialize.into());
237 }
238
239 if &bytes[..Self::DISCRIMINATOR.len()] != Self::DISCRIMINATOR {
241 return Err(anchor_lang::error::ErrorCode::InstructionDidNotDeserialize.into());
242 }
243
244 let data = &bytes[Self::DISCRIMINATOR.len()..];
246 Self::try_from_slice(data).map_err(|_| anchor_lang::error::ErrorCode::InstructionDidNotDeserialize.into())
247 }
248 }
249 };
250 TokenStream::from(expanded)
251}
252
253
254#[proc_macro_derive(FromSignedTransaction)]
255pub fn from_signed_transaction_derive(input: TokenStream) -> TokenStream {
256 let input = parse_macro_input!(input as DeriveInput);
257 let name = input.ident;
258
259 let fields = if let Data::Struct(data_struct) = input.data {
260 match data_struct.fields {
261 Fields::Named(fields_named) => fields_named.named,
262 _ => panic!("FromAccountMetas can only be derived for structs with named fields"),
263 }
264 } else {
265 panic!("FromAccountMetas can only be derived for structs");
266 };
267
268 let mut assignments = Vec::new();
269
270 let mut clones = Vec::new();
271
272 let mut field_index = 0usize;
274
275 for field in fields.iter() {
276 let field_name = &field.ident;
277
278 if let Some(name) = field_name {
279 if ["header", "recent_blockhash"].contains(&name.to_string().as_str()) {
280 continue;
281 }
282 }
283
284 assignments.push(quote! {
285 #field_name: TypedInstruction::try_from(&value.instructions[#field_index])?
286 });
287 clones.push(quote! {
288 #field_name: self.#field_name.clone()
289 });
290 field_index += 1;
291 }
292
293 let expanded = quote! {
295 impl Clone for #name {
296 fn clone(&self) -> Self {
297 Self {
298 header: self.header.clone(),
299 recent_blockhash: self.recent_blockhash.clone(),
300 #(#clones),*
301 }
302 }
303 }
304
305 impl Discriminator for #name {
306 const DISCRIMINATOR: [u8; 8] = [0u8;8];
307 }
308
309 impl TryFrom<SignedTransaction> for #name {
310
311 type Error = anchor_lang::error::Error;
312
313 fn try_from(value: SignedTransaction) -> Result<#name> {
314 Ok(#name {
315 header: value.header,
316 recent_blockhash: value.recent_blockhash,
317 #(#assignments),*
318 })
319 }
320 }
321
322 impl Owner for #name {
323 fn owner() -> Pubkey {
324 anchor_lang::solana_program::sysvar::ID
325 }
326 }
327
328 impl AccountDeserialize for #name {
329 fn try_deserialize(buf: &mut &[u8]) -> Result<Self> {
330 Self::try_deserialize_unchecked(buf)
331 }
332
333 fn try_deserialize_unchecked(buf: &mut &[u8]) -> Result<Self> {
334 Self::try_from(SignedTransaction::try_deserialize_transaction(buf).map_err(|_| ProgramError::InvalidInstructionData)?)
335 }
336 }
337
338 impl AccountSerialize for #name {
339
340 }
341
342 #[cfg(feature = "idl-build")]
343 impl anchor_lang::IdlBuild for #name {}
344 };
345
346 TokenStream::from(expanded)
347}