session_keys_macros_attribute/
lib.rs

1use proc_macro::TokenStream;
2use quote::{quote, ToTokens};
3
4use syn::{
5    parse::{Parse, ParseStream},
6    parse_macro_input, Data, DeriveInput, Fields, GenericArgument, PathArguments, Token, Type,
7    TypePath,
8};
9
10struct SessionArgs {
11    signer: syn::ExprAssign,
12    authority: syn::ExprAssign,
13}
14
15impl Parse for SessionArgs {
16    fn parse(input: ParseStream) -> syn::Result<Self> {
17        let signer = input.parse()?;
18
19        input.parse::<Token![,]>()?;
20
21        let authority = input.parse()?;
22        Ok(SessionArgs { signer, authority })
23    }
24}
25
26fn is_session(attr: &syn::Attribute) -> bool {
27    attr.path.is_ident("session")
28}
29
30fn is_option_account_sessiontoken(ty: &Type) -> bool {
31    let Type::Path(TypePath { path, .. }) = ty else {
32        return false;
33    };
34    let Some(seg) = path.segments.first() else {
35        return false;
36    };
37    if seg.ident != "Option" {
38        return false;
39    }
40    let PathArguments::AngleBracketed(ref opt_args) = seg.arguments else {
41        return false;
42    };
43    let Some(GenericArgument::Type(Type::Path(TypePath {
44        path: acct_path, ..
45    }))) = opt_args.args.first()
46    else {
47        return false;
48    };
49    let Some(acct_seg) = acct_path.segments.first() else {
50        return false;
51    };
52    if acct_seg.ident != "Account" {
53        return false;
54    }
55    let PathArguments::AngleBracketed(ref acct_args) = acct_seg.arguments else {
56        return false;
57    };
58    let mut args = acct_args.args.iter();
59    let Some(GenericArgument::Lifetime(_)) = args.next() else {
60        return false;
61    };
62    let Some(GenericArgument::Type(Type::Path(TypePath { path: st_path, .. }))) = args.next()
63    else {
64        return false;
65    };
66    st_path.segments.len() == 1 && st_path.segments[0].ident == "SessionToken"
67}
68
69// Macro to derive Session Trait
70#[proc_macro_derive(Session, attributes(session))]
71pub fn derive(input: TokenStream) -> TokenStream {
72    let input_parsed = parse_macro_input!(input as DeriveInput);
73
74    let fields = match input_parsed.data {
75        Data::Struct(data) => match data.fields {
76            Fields::Named(fields) => fields,
77            _ => panic!("Session trait can only be derived for structs with named fields"),
78        },
79        _ => panic!("Session trait can only be derived for structs"),
80    };
81
82    // Ensure that the struct has a session_token field
83    let session_token_field = fields
84        .named
85        .iter()
86        .find(|field| *field.ident.as_ref().unwrap() == "session_token")
87        .expect("Session trait can only be derived for structs with a session_token field");
88    {
89        let session_token_type = &session_token_field.ty;
90        assert!(is_option_account_sessiontoken(session_token_type), "Session trait can only be derived for structs with a session_token field of type Option<Account<'info, SessionToken>>");
91    }
92
93    // Session Token field must have the #[session] attribute
94    let session_attr = session_token_field
95        .attrs
96        .iter()
97        .find(|attr| is_session(attr))
98        .expect("Session trait can only be derived for structs with a session_token field with the #[session] attribute");
99
100    let session_args = session_attr.parse_args::<SessionArgs>().unwrap();
101
102    let session_signer = session_args.signer.right.into_token_stream();
103
104    // Session Authority
105    let session_authority = session_args.authority.right.into_token_stream();
106
107    let struct_name = &input_parsed.ident;
108    let (impl_generics, ty_generics, where_clause) = input_parsed.generics.split_for_impl();
109
110    let output = quote! {
111
112        #[automatically_derived]
113        impl #impl_generics Session #ty_generics for #struct_name #ty_generics #where_clause {
114
115            // Target Program
116            fn target_program(&self) -> Pubkey {
117                crate::id()
118            }
119
120            // Session Token
121            fn session_token(&self) -> Option<Account<'info, SessionToken>> {
122                self.session_token.clone()
123            }
124
125            // Session Authority
126            fn session_authority(&self) -> Pubkey {
127                self.#session_authority
128            }
129
130            // Session Signer
131            fn session_signer(&self) -> Signer<'info> {
132                self.#session_signer.clone()
133            }
134
135        }
136    };
137
138    output.into()
139}
140
141struct SessionAuthArgs(syn::Expr, syn::Expr);
142
143impl Parse for SessionAuthArgs {
144    fn parse(input: ParseStream) -> syn::Result<Self> {
145        let equality_expr = input.parse()?;
146        input.parse::<Token![,]>()?;
147        let error_expr = input.parse()?;
148        Ok(SessionAuthArgs(equality_expr, error_expr))
149    }
150}
151
152#[proc_macro_attribute]
153/// Macro to check if the session or the original authority is the signer
154pub fn session_auth_or(attr: TokenStream, item: TokenStream) -> TokenStream {
155    let SessionAuthArgs(auth_expr, error_ty) = parse_macro_input!(attr);
156
157    let input_fn = parse_macro_input!(item as syn::ItemFn);
158    let input_fn_name = input_fn.sig.ident;
159    let input_fn_vis = input_fn.vis;
160    let input_fn_block = input_fn.block;
161    let input_fn_inputs = input_fn.sig.inputs;
162    let input_fn_output = input_fn.sig.output;
163
164    let output = quote! {
165        #input_fn_vis fn #input_fn_name(#input_fn_inputs) #input_fn_output {
166            // Automatically generated by session_auth_or macro
167            // BEGIN SESSION AUTH
168            // Current signer is the session signer or the original authority
169            let session_token = ctx.accounts.session_token();
170            if let Some(token) = session_token {
171                require!(ctx.accounts.is_valid()?, SessionError::InvalidToken);
172                // Checks that authority of the session is the same as authority of the original account
173                require_eq!(
174                    ctx.accounts.session_authority(),
175                    token.authority.key(),
176                    #error_ty
177                );
178            } else {
179                require!(
180                    #auth_expr,
181                    #error_ty
182                );
183            }
184            // END SESSION AUTH
185            #input_fn_block
186        }
187    };
188    output.into()
189}