Skip to main content

typhoon_context_macro/
lib.rs

1use {
2    crate::context::ParsingContext,
3    generators::*,
4    injector::FieldInjector,
5    proc_macro::TokenStream,
6    proc_macro2::TokenStream as TokenStream2,
7    quote::{format_ident, quote, ToTokens},
8    sorter::sort_accounts,
9    syn::{
10        parse_macro_input, parse_quote, visit_mut::VisitMut, Attribute, Field, Ident, ItemStruct,
11    },
12};
13
14mod context;
15mod generators;
16mod injector;
17mod remover;
18mod sorter;
19mod visitor;
20
21#[proc_macro_attribute]
22pub fn context(_attr: TokenStream, item: TokenStream) -> TokenStream {
23    let context = parse_macro_input!(item as ParsingContext);
24    let generator = match TokenGenerator::new(context) {
25        Ok(gen) => gen,
26        Err(err) => return TokenStream::from(err.into_compile_error()),
27    };
28
29    TokenStream::from(generator.into_token_stream())
30}
31
32type BumpsStruct = (ItemStruct, TokenStream2);
33
34struct TokenGenerator {
35    item_struct: ItemStruct,
36    accounts_token: Vec<TokenStream2>,
37    bumps: Option<BumpsStruct>,
38    args: Option<(Ident, Option<TokenStream2>)>,
39    needs_rent: bool,
40}
41
42impl TokenGenerator {
43    pub fn new(mut context: ParsingContext) -> Result<Self, syn::Error> {
44        sort_accounts(&mut context)?;
45
46        let global_context = GlobalContext::from_parsing_context(&context)?;
47
48        for program in &global_context.program_checks {
49            if context.accounts.iter().all(|el| el.inner_ty != *program) {
50                return Err(syn::Error::new_spanned(
51                    &context.item_struct,
52                    format!("One constraint requires including the `Program<{program}>` account."),
53                ));
54            }
55        }
56
57        let bumps = global_context.generate_bumps(&context);
58        let args = global_context.generate_args(&context);
59
60        let accounts_token = global_context
61            .accounts
62            .into_iter()
63            .map(|acc| acc.generate())
64            .collect::<Result<_, _>>()?;
65
66        Ok(TokenGenerator {
67            needs_rent: global_context.need_rent,
68            item_struct: context.item_struct,
69            accounts_token,
70            bumps,
71            args,
72        })
73    }
74}
75
76impl ToTokens for TokenGenerator {
77    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
78        let name = &self.item_struct.ident;
79        let generics = &self.item_struct.generics;
80
81        let (_, ty_generics, _) = generics.split_for_impl();
82
83        // patch the lifetime of the new context here
84        let generics = &mut generics.to_owned();
85        generics.params.push(parse_quote!('c));
86        if let Some(where_clause) = &mut generics.where_clause {
87            where_clause.predicates.push(parse_quote!('c: 'info));
88        } else {
89            generics.where_clause = Some(parse_quote!(where 'c: 'info));
90        }
91        let (impl_generics, _, where_clause) = generics.split_for_impl();
92
93        let name_list: Vec<&Ident> = self
94            .item_struct
95            .fields
96            .iter()
97            .filter_map(|f| f.ident.as_ref())
98            .collect();
99        let accounts_token = &self.accounts_token;
100        let (bumps_struct, bumps_var) = self.bumps.clone().unzip();
101
102        let mut struct_fields: Vec<&Ident> = name_list.clone();
103
104        let account_struct = &mut self.item_struct.to_owned();
105
106        let bumps_ident = format_ident!("bumps");
107        if let Some(ref bumps) = bumps_struct {
108            let name = &bumps.ident;
109            let bumps_field: Field = parse_quote!(pub #bumps_ident: #name);
110            struct_fields.push(&bumps_ident);
111            FieldInjector::new(bumps_field).visit_item_struct_mut(account_struct);
112        }
113
114        let args_ident = format_ident!("args");
115        let (args_assign, args_struct) = self.args.as_ref().map(|(name, args_struct)| {
116            let args_field: Field = parse_quote!(pub #args_ident: &'info #name);
117            struct_fields.push(&args_ident);
118            FieldInjector::new(args_field).visit_item_struct_mut(account_struct);
119
120            let args_assign = quote!(let Arg(args) = Arg::<#name>::from_entrypoint(program_id, accounts, instruction_data)?;);
121
122            (args_assign, args_struct)
123        }).unzip();
124
125        let rent = self
126            .needs_rent
127            .then_some(quote!(let rent = <Rent as Sysvar>::get()?;));
128
129        let impl_context = quote! {
130            impl #impl_generics HandlerContext<'_, 'info, 'c> for #name #ty_generics #where_clause {
131                #[inline(always)]
132                fn from_entrypoint(
133                    program_id: &Address,
134                    accounts: &mut &'info [AccountView],
135                    instruction_data: &mut &'c [u8],
136                ) -> ProgramResult<Self> {
137                    let [#(#name_list,)* rem @ ..] = accounts else {
138                        return Err(ProgramError::NotEnoughAccountKeys.into());
139                    };
140
141                    #args_assign
142                    #rent
143
144                    #(#accounts_token)*
145
146                    #bumps_var
147                    *accounts = rem;
148
149                    Ok(#name { #(#struct_fields),* })
150                }
151            }
152
153            impl #impl_generics Context for #name #ty_generics #where_clause {}
154        };
155
156        let doc = prettyplease::unparse(
157            &syn::parse2::<syn::File>(quote! {
158                #bumps_struct
159                #args_struct
160
161                #impl_context
162            })
163            .unwrap(),
164        );
165
166        let mut doc_attrs: Vec<Attribute> = parse_quote! {
167            /// # Generated
168            /// ```ignore
169            #[doc = #doc]
170            /// ```
171        };
172
173        account_struct.attrs.append(&mut doc_attrs);
174
175        let expanded = quote! {
176            #bumps_struct
177            #args_struct
178
179            #account_struct
180
181            #impl_context
182
183        };
184        expanded.to_tokens(tokens);
185    }
186}