Skip to main content

session_keys_macros_attribute/
lib.rs

1use proc_macro::TokenStream;
2use quote::{quote, ToTokens};
3
4use syn::{
5    parse::{Parse, ParseStream},
6    parse_macro_input, Data, DeriveInput, Fields, GenericArgument, PathArguments, Token, Type,
7    TypePath,
8};
9
10struct SessionArgs {
11    signer: syn::ExprAssign,
12    authority: syn::ExprAssign,
13}
14
15impl Parse for SessionArgs {
16    fn parse(input: ParseStream) -> syn::Result<Self> {
17        let signer = input.parse()?;
18
19        input.parse::<Token![,]>()?;
20
21        let authority = input.parse()?;
22        Ok(SessionArgs { signer, authority })
23    }
24}
25
26fn is_session(attr: &syn::Attribute) -> bool {
27    attr.path().is_ident("session")
28}
29
/// Which session-token account type a `session_token` field refers to.
#[derive(Clone, Copy, PartialEq, Eq)]
enum SessionTokenType {
    /// The field wraps `SessionToken`.
    V1,
    /// The field wraps `SessionTokenV2`.
    V2,
}
35
36/// Result of parsing the session_token field type.
37struct SessionTokenInfo {
38    token_type: SessionTokenType,
39    /// The lifetime from Account<'info, ...> — used in generated impl.
40    lifetime: syn::Lifetime,
41}
42
43fn get_session_token_info(ty: &Type) -> Option<SessionTokenInfo> {
44    let Type::Path(TypePath { path, .. }) = ty else {
45        return None;
46    };
47    let seg = path.segments.last()?;
48    if seg.ident != "Option" {
49        return None;
50    }
51    let PathArguments::AngleBracketed(ref opt_args) = seg.arguments else {
52        return None;
53    };
54    let Some(GenericArgument::Type(Type::Path(TypePath {
55        path: acct_path, ..
56    }))) = opt_args.args.first()
57    else {
58        return None;
59    };
60    let acct_seg = acct_path.segments.last()?;
61    if acct_seg.ident != "Account" {
62        return None;
63    }
64    let PathArguments::AngleBracketed(ref acct_args) = acct_seg.arguments else {
65        return None;
66    };
67    let mut args = acct_args.args.iter();
68    let Some(GenericArgument::Lifetime(lifetime)) = args.next() else {
69        return None;
70    };
71    let lifetime = lifetime.clone();
72    let Some(GenericArgument::Type(Type::Path(TypePath { path: st_path, .. }))) = args.next()
73    else {
74        return None;
75    };
76
77    if let Some(st_seg) = st_path.segments.last() {
78        if st_seg.ident == "SessionToken" {
79            return Some(SessionTokenInfo { token_type: SessionTokenType::V1, lifetime });
80        } else if st_seg.ident == "SessionTokenV2" {
81            return Some(SessionTokenInfo { token_type: SessionTokenType::V2, lifetime });
82        }
83    }
84    None
85}
86
87/// Core derive implementation that supports both V1 (`SessionToken`) and V2 (`SessionTokenV2`).
88/// Auto-detects which variant to use based on the `session_token` field type.
89/// If `expected` is `Some`, validates that the detected type matches.
90fn derive_impl(input: TokenStream, expected: Option<SessionTokenType>) -> TokenStream {
91    let input_parsed = parse_macro_input!(input as DeriveInput);
92
93    let fields = match input_parsed.data {
94        Data::Struct(data) => match data.fields {
95            Fields::Named(fields) => fields,
96            _ => panic!("Session trait can only be derived for structs with named fields"),
97        },
98        _ => panic!("Session trait can only be derived for structs"),
99    };
100
101    // Ensure that the struct has a session_token field
102    let session_token_field = fields
103        .named
104        .iter()
105        .find(|field| *field.ident.as_ref().unwrap() == "session_token")
106        .expect("Session trait can only be derived for structs with a session_token field");
107
108    let session_token_type = &session_token_field.ty;
109    let info = get_session_token_info(session_token_type)
110        .expect("Session trait can only be derived for structs with a session_token field of type Option<Account<'info, SessionToken>> or Option<Account<'info, SessionTokenV2>>");
111    let token_type = info.token_type;
112
113    if let Some(expected) = expected {
114        if token_type != expected {
115            return syn::Error::new_spanned(
116                &session_token_field.ty,
117                "#[derive(SessionV2)] requires Option<Account<'info, SessionTokenV2>>",
118            )
119            .to_compile_error()
120            .into();
121        }
122    }
123
124    // Session Token field must have the #[session] attribute
125    let session_attr = session_token_field
126        .attrs
127        .iter()
128        .find(|attr| is_session(attr))
129        .expect("Session trait can only be derived for structs with a session_token field with the #[session] attribute");
130
131    let session_args = session_attr.parse_args::<SessionArgs>().unwrap();
132
133    let session_signer = session_args.signer.right.into_token_stream();
134
135    // Session Authority
136    let session_authority = session_args.authority.right.into_token_stream();
137
138    let struct_name = &input_parsed.ident;
139    let (impl_generics, ty_generics, where_clause) = input_parsed.generics.split_for_impl();
140
141    // Use the lifetime extracted from the session_token field type (Account<'info, ...>).
142    // This ensures we use the exact lifetime from the field, not the struct's first lifetime.
143    let info_lifetime = info.lifetime;
144
145    let output = match token_type {
146        SessionTokenType::V1 => quote! {
147            #[automatically_derived]
148            impl #impl_generics ::session_keys::Session<#info_lifetime> for #struct_name #ty_generics #where_clause {
149
150                fn target_program(&self) -> ::anchor_lang::prelude::Pubkey {
151                    crate::id()
152                }
153
154                fn session_token(&self) -> Option<::anchor_lang::prelude::Account<#info_lifetime, ::session_keys::SessionToken>> {
155                    self.session_token.clone()
156                }
157
158                fn session_authority(&self) -> ::anchor_lang::prelude::Pubkey {
159                    self.#session_authority
160                }
161
162                fn session_signer(&self) -> ::anchor_lang::prelude::Signer<#info_lifetime> {
163                    self.#session_signer.clone()
164                }
165
166            }
167        },
168        SessionTokenType::V2 => quote! {
169            #[automatically_derived]
170            impl #impl_generics ::session_keys::SessionV2<#info_lifetime> for #struct_name #ty_generics #where_clause {
171
172                fn target_program(&self) -> ::anchor_lang::prelude::Pubkey {
173                    crate::id()
174                }
175
176                fn session_token(&self) -> Option<::anchor_lang::prelude::Account<#info_lifetime, ::session_keys::SessionTokenV2>> {
177                    self.session_token.clone()
178                }
179
180                fn session_authority(&self) -> ::anchor_lang::prelude::Pubkey {
181                    self.#session_authority
182                }
183
184                fn session_signer(&self) -> ::anchor_lang::prelude::Signer<#info_lifetime> {
185                    self.#session_signer.clone()
186                }
187
188            }
189        },
190    };
191
192    output.into()
193}
194
195/// Derive macro for the `Session` trait (V1) or `SessionV2` trait.
196/// Auto-detects based on the field type:
197///   - `Option<Account<'info, SessionToken>>` → implements `Session`
198///   - `Option<Account<'info, SessionTokenV2>>` → implements `SessionV2`
199#[proc_macro_derive(Session, attributes(session))]
200pub fn derive_session(input: TokenStream) -> TokenStream {
201    derive_impl(input, None)
202}
203
204/// Explicit V2 derive macro — same implementation as `#[derive(Session)]`,
205/// provided for clarity when using `SessionTokenV2`.
206#[proc_macro_derive(SessionV2, attributes(session))]
207pub fn derive_session_v2(input: TokenStream) -> TokenStream {
208    derive_impl(input, Some(SessionTokenType::V2))
209}
210
211struct SessionAuthArgs(syn::Expr, syn::Expr);
212
213impl Parse for SessionAuthArgs {
214    fn parse(input: ParseStream) -> syn::Result<Self> {
215        let equality_expr = input.parse()?;
216        input.parse::<Token![,]>()?;
217        let error_expr = input.parse()?;
218        Ok(SessionAuthArgs(equality_expr, error_expr))
219    }
220}
221
222#[proc_macro_attribute]
223/// Macro to check if the session (V1 or V2) or the original authority is the signer.
224/// Works with both `Session` and `SessionV2` traits.
225pub fn session_auth_or(attr: TokenStream, item: TokenStream) -> TokenStream {
226    let SessionAuthArgs(auth_expr, error_ty) = parse_macro_input!(attr);
227
228    let input_fn = parse_macro_input!(item as syn::ItemFn);
229    let input_fn_name = input_fn.sig.ident;
230    let input_fn_vis = input_fn.vis;
231    let input_fn_block = input_fn.block;
232    let input_fn_inputs = input_fn.sig.inputs;
233    let input_fn_output = input_fn.sig.output;
234
235    let output = quote! {
236        #input_fn_vis fn #input_fn_name(#input_fn_inputs) #input_fn_output {
237            // Automatically generated by session_auth_or macro
238            // Import both traits so is_valid()/session_token()/session_authority() resolve
239            // regardless of which trait the derive emitted.
240            use ::session_keys::{Session as _, SessionV2 as _};
241            // BEGIN SESSION AUTH
242            // Current signer is the session signer or the original authority
243            let session_token = ctx.accounts.session_token();
244            if let Some(token) = session_token {
245                require!(ctx.accounts.is_valid()?, SessionError::InvalidToken);
246                // Checks that authority of the session is the same as authority of the original account
247                require_eq!(
248                    ctx.accounts.session_authority(),
249                    token.authority.key(),
250                    #error_ty
251                );
252            } else {
253                require!(
254                    #auth_expr,
255                    #error_ty
256                );
257            }
258            // END SESSION AUTH
259            #input_fn_block
260        }
261    };
262    output.into()
263}