Skip to main content

session_keys_macros_attribute/
lib.rs

1use proc_macro::TokenStream;
2use quote::{quote, ToTokens};
3
4use syn::{
5    parse::{Parse, ParseStream},
6    parse_macro_input, Data, DeriveInput, Fields, GenericArgument, PathArguments, Token, Type,
7    TypePath,
8};
9
10struct SessionArgs {
11    signer: syn::ExprAssign,
12    authority: syn::ExprAssign,
13}
14
15impl Parse for SessionArgs {
16    fn parse(input: ParseStream) -> syn::Result<Self> {
17        let signer = input.parse()?;
18
19        input.parse::<Token![,]>()?;
20
21        let authority = input.parse()?;
22        Ok(SessionArgs { signer, authority })
23    }
24}
25
26fn is_session(attr: &syn::Attribute) -> bool {
27    attr.path.is_ident("session")
28}
29
/// Which session-token account flavour a `session_token` field holds.
#[derive(Clone, Copy, PartialEq, Eq)]
enum SessionTokenType {
    /// `SessionToken`; the derive emits an impl of the `Session` trait.
    V1,
    /// `SessionTokenV2`; the derive emits an impl of the `SessionV2` trait.
    V2,
}
35
36/// Result of parsing the session_token field type.
37struct SessionTokenInfo {
38    token_type: SessionTokenType,
39    /// The lifetime from Account<'info, ...> — used in generated impl.
40    lifetime: syn::Lifetime,
41}
42
43fn get_session_token_info(ty: &Type) -> Option<SessionTokenInfo> {
44    let Type::Path(TypePath { path, .. }) = ty else {
45        return None;
46    };
47    let Some(seg) = path.segments.last() else {
48        return None;
49    };
50    if seg.ident != "Option" {
51        return None;
52    }
53    let PathArguments::AngleBracketed(ref opt_args) = seg.arguments else {
54        return None;
55    };
56    let Some(GenericArgument::Type(Type::Path(TypePath {
57        path: acct_path, ..
58    }))) = opt_args.args.first()
59    else {
60        return None;
61    };
62    let Some(acct_seg) = acct_path.segments.last() else {
63        return None;
64    };
65    if acct_seg.ident != "Account" {
66        return None;
67    }
68    let PathArguments::AngleBracketed(ref acct_args) = acct_seg.arguments else {
69        return None;
70    };
71    let mut args = acct_args.args.iter();
72    let Some(GenericArgument::Lifetime(lifetime)) = args.next() else {
73        return None;
74    };
75    let lifetime = lifetime.clone();
76    let Some(GenericArgument::Type(Type::Path(TypePath { path: st_path, .. }))) = args.next()
77    else {
78        return None;
79    };
80
81    if let Some(st_seg) = st_path.segments.last() {
82        if st_seg.ident == "SessionToken" {
83            return Some(SessionTokenInfo { token_type: SessionTokenType::V1, lifetime });
84        } else if st_seg.ident == "SessionTokenV2" {
85            return Some(SessionTokenInfo { token_type: SessionTokenType::V2, lifetime });
86        }
87    }
88    None
89}
90
91/// Core derive implementation that supports both V1 (`SessionToken`) and V2 (`SessionTokenV2`).
92/// Auto-detects which variant to use based on the `session_token` field type.
93/// If `expected` is `Some`, validates that the detected type matches.
94fn derive_impl(input: TokenStream, expected: Option<SessionTokenType>) -> TokenStream {
95    let input_parsed = parse_macro_input!(input as DeriveInput);
96
97    let fields = match input_parsed.data {
98        Data::Struct(data) => match data.fields {
99            Fields::Named(fields) => fields,
100            _ => panic!("Session trait can only be derived for structs with named fields"),
101        },
102        _ => panic!("Session trait can only be derived for structs"),
103    };
104
105    // Ensure that the struct has a session_token field
106    let session_token_field = fields
107        .named
108        .iter()
109        .find(|field| field.ident.as_ref().unwrap().to_string() == "session_token")
110        .expect("Session trait can only be derived for structs with a session_token field");
111
112    let session_token_type = &session_token_field.ty;
113    let info = get_session_token_info(session_token_type)
114        .expect("Session trait can only be derived for structs with a session_token field of type Option<Account<'info, SessionToken>> or Option<Account<'info, SessionTokenV2>>");
115    let token_type = info.token_type;
116
117    if let Some(expected) = expected {
118        if token_type != expected {
119            return syn::Error::new_spanned(
120                &session_token_field.ty,
121                "#[derive(SessionV2)] requires Option<Account<'info, SessionTokenV2>>",
122            )
123            .to_compile_error()
124            .into();
125        }
126    }
127
128    // Session Token field must have the #[session] attribute
129    let session_attr = session_token_field
130        .attrs
131        .iter()
132        .find(|attr| is_session(attr))
133        .expect("Session trait can only be derived for structs with a session_token field with the #[session] attribute");
134
135    let session_args = session_attr.parse_args::<SessionArgs>().unwrap();
136
137    let session_signer = session_args.signer.right.into_token_stream();
138
139    // Session Authority
140    let session_authority = session_args.authority.right.into_token_stream();
141
142    let struct_name = &input_parsed.ident;
143    let (impl_generics, ty_generics, where_clause) = input_parsed.generics.split_for_impl();
144
145    // Use the lifetime extracted from the session_token field type (Account<'info, ...>).
146    // This ensures we use the exact lifetime from the field, not the struct's first lifetime.
147    let info_lifetime = info.lifetime;
148
149    let output = match token_type {
150        SessionTokenType::V1 => quote! {
151            #[automatically_derived]
152            impl #impl_generics ::session_keys::Session<#info_lifetime> for #struct_name #ty_generics #where_clause {
153
154                fn target_program(&self) -> ::anchor_lang::prelude::Pubkey {
155                    crate::id()
156                }
157
158                fn session_token(&self) -> Option<::anchor_lang::prelude::Account<#info_lifetime, ::session_keys::SessionToken>> {
159                    self.session_token.clone()
160                }
161
162                fn session_authority(&self) -> ::anchor_lang::prelude::Pubkey {
163                    self.#session_authority
164                }
165
166                fn session_signer(&self) -> ::anchor_lang::prelude::Signer<#info_lifetime> {
167                    self.#session_signer.clone()
168                }
169
170            }
171        },
172        SessionTokenType::V2 => quote! {
173            #[automatically_derived]
174            impl #impl_generics ::session_keys::SessionV2<#info_lifetime> for #struct_name #ty_generics #where_clause {
175
176                fn target_program(&self) -> ::anchor_lang::prelude::Pubkey {
177                    crate::id()
178                }
179
180                fn session_token(&self) -> Option<::anchor_lang::prelude::Account<#info_lifetime, ::session_keys::SessionTokenV2>> {
181                    self.session_token.clone()
182                }
183
184                fn session_authority(&self) -> ::anchor_lang::prelude::Pubkey {
185                    self.#session_authority
186                }
187
188                fn session_signer(&self) -> ::anchor_lang::prelude::Signer<#info_lifetime> {
189                    self.#session_signer.clone()
190                }
191
192            }
193        },
194    };
195
196    output.into()
197}
198
199/// Derive macro for the `Session` trait (V1) or `SessionV2` trait.
200/// Auto-detects based on the field type:
201///   - `Option<Account<'info, SessionToken>>` → implements `Session`
202///   - `Option<Account<'info, SessionTokenV2>>` → implements `SessionV2`
203#[proc_macro_derive(Session, attributes(session))]
204pub fn derive_session(input: TokenStream) -> TokenStream {
205    derive_impl(input, None)
206}
207
208/// Explicit V2 derive macro — same implementation as `#[derive(Session)]`,
209/// provided for clarity when using `SessionTokenV2`.
210#[proc_macro_derive(SessionV2, attributes(session))]
211pub fn derive_session_v2(input: TokenStream) -> TokenStream {
212    derive_impl(input, Some(SessionTokenType::V2))
213}
214
215struct SessionAuthArgs(syn::Expr, syn::Expr);
216
217impl Parse for SessionAuthArgs {
218    fn parse(input: ParseStream) -> syn::Result<Self> {
219        let equality_expr = input.parse()?;
220        input.parse::<Token![,]>()?;
221        let error_expr = input.parse()?;
222        Ok(SessionAuthArgs(equality_expr, error_expr))
223    }
224}
225
226#[proc_macro_attribute]
227/// Macro to check if the session (V1 or V2) or the original authority is the signer.
228/// Works with both `Session` and `SessionV2` traits.
229pub fn session_auth_or(attr: TokenStream, item: TokenStream) -> TokenStream {
230    let SessionAuthArgs(auth_expr, error_ty) = parse_macro_input!(attr);
231
232    let input_fn = parse_macro_input!(item as syn::ItemFn);
233    let input_fn_name = input_fn.sig.ident;
234    let input_fn_vis = input_fn.vis;
235    let input_fn_block = input_fn.block;
236    let input_fn_inputs = input_fn.sig.inputs;
237    let input_fn_output = input_fn.sig.output;
238
239    let output = quote! {
240        #input_fn_vis fn #input_fn_name(#input_fn_inputs) #input_fn_output {
241            // Automatically generated by session_auth_or macro
242            // Import both traits so is_valid()/session_token()/session_authority() resolve
243            // regardless of which trait the derive emitted.
244            use ::session_keys::{Session as _, SessionV2 as _};
245            // BEGIN SESSION AUTH
246            // Current signer is the session signer or the original authority
247            let session_token = ctx.accounts.session_token();
248            if let Some(token) = session_token {
249                require!(ctx.accounts.is_valid()?, SessionError::InvalidToken);
250                // Checks that authority of the session is the same as authority of the original account
251                require_eq!(
252                    ctx.accounts.session_authority(),
253                    token.authority.key(),
254                    #error_ty
255                );
256            } else {
257                require!(
258                    #auth_expr,
259                    #error_ty
260                );
261            }
262            // END SESSION AUTH
263            #input_fn_block
264        }
265    };
266    output.into()
267}