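//! `session_keys_macros_attribute` (`lib.rs`).
//!
//! Procedural macros (`#[derive(Session)]`, `#[derive(SessionV2)]` and
//! `#[session_auth_or(...)]`) that generate code against the `session_keys` crate.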

1use proc_macro::TokenStream;
2use quote::{quote, ToTokens};
3
4use syn::{
5    parse::{Parse, ParseStream},
6    parse_macro_input, Data, DeriveInput, Fields, GenericArgument, PathArguments, Token, Type,
7    TypePath,
8};
9
10struct SessionArgs {
11    signer: syn::ExprAssign,
12    authority: syn::ExprAssign,
13}
14
15impl Parse for SessionArgs {
16    fn parse(input: ParseStream) -> syn::Result<Self> {
17        let signer = input.parse()?;
18
19        input.parse::<Token![,]>()?;
20
21        let authority = input.parse()?;
22        Ok(SessionArgs { signer, authority })
23    }
24}
25
26fn is_session(attr: &syn::Attribute) -> bool {
27    attr.path().is_ident("session")
28}
29
/// Which session-token account type a `session_token` field is declared with.
#[derive(PartialEq, Eq, Clone, Copy)]
enum SessionTokenType {
    /// `Option<Account<'info, SessionToken>>`, which derives a `Session` impl.
    V1,
    /// `Option<Account<'info, SessionTokenV2>>`, which derives a `SessionV2` impl.
    V2,
}
35
36/// Result of parsing the session_token field type.
37struct SessionTokenInfo {
38    token_type: SessionTokenType,
39    /// The lifetime from Account<'info, ...> — used in generated impl.
40    lifetime: syn::Lifetime,
41}
42
43fn get_session_token_info(ty: &Type) -> Option<SessionTokenInfo> {
44    let Type::Path(TypePath { path, .. }) = ty else {
45        return None;
46    };
47    let seg = path.segments.last()?;
48    if seg.ident != "Option" {
49        return None;
50    }
51    let PathArguments::AngleBracketed(ref opt_args) = seg.arguments else {
52        return None;
53    };
54    let Some(GenericArgument::Type(Type::Path(TypePath {
55        path: acct_path, ..
56    }))) = opt_args.args.first()
57    else {
58        return None;
59    };
60    let acct_seg = acct_path.segments.last()?;
61    if acct_seg.ident != "Account" {
62        return None;
63    }
64    let PathArguments::AngleBracketed(ref acct_args) = acct_seg.arguments else {
65        return None;
66    };
67    let mut args = acct_args.args.iter();
68    let Some(GenericArgument::Lifetime(lifetime)) = args.next() else {
69        return None;
70    };
71    let lifetime = lifetime.clone();
72    let Some(GenericArgument::Type(Type::Path(TypePath { path: st_path, .. }))) = args.next()
73    else {
74        return None;
75    };
76
77    if let Some(st_seg) = st_path.segments.last() {
78        if st_seg.ident == "SessionToken" {
79            return Some(SessionTokenInfo {
80                token_type: SessionTokenType::V1,
81                lifetime,
82            });
83        } else if st_seg.ident == "SessionTokenV2" {
84            return Some(SessionTokenInfo {
85                token_type: SessionTokenType::V2,
86                lifetime,
87            });
88        }
89    }
90    None
91}
92
93/// Core derive implementation that supports both V1 (`SessionToken`) and V2 (`SessionTokenV2`).
94/// Auto-detects which variant to use based on the `session_token` field type.
95/// If `expected` is `Some`, validates that the detected type matches.
96fn derive_impl(input: TokenStream, expected: Option<SessionTokenType>) -> TokenStream {
97    let input_parsed = parse_macro_input!(input as DeriveInput);
98
99    let fields = match input_parsed.data {
100        Data::Struct(data) => match data.fields {
101            Fields::Named(fields) => fields,
102            _ => panic!("Session trait can only be derived for structs with named fields"),
103        },
104        _ => panic!("Session trait can only be derived for structs"),
105    };
106
107    // Ensure that the struct has a session_token field
108    let session_token_field = fields
109        .named
110        .iter()
111        .find(|field| *field.ident.as_ref().unwrap() == "session_token")
112        .expect("Session trait can only be derived for structs with a session_token field");
113
114    let session_token_type = &session_token_field.ty;
115    let info = get_session_token_info(session_token_type)
116        .expect("Session trait can only be derived for structs with a session_token field of type Option<Account<'info, SessionToken>> or Option<Account<'info, SessionTokenV2>>");
117    let token_type = info.token_type;
118
119    if let Some(expected) = expected {
120        if token_type != expected {
121            return syn::Error::new_spanned(
122                &session_token_field.ty,
123                "#[derive(SessionV2)] requires Option<Account<'info, SessionTokenV2>>",
124            )
125            .to_compile_error()
126            .into();
127        }
128    }
129
130    // Session Token field must have the #[session] attribute
131    let session_attr = session_token_field
132        .attrs
133        .iter()
134        .find(|attr| is_session(attr))
135        .expect("Session trait can only be derived for structs with a session_token field with the #[session] attribute");
136
137    let session_args = session_attr.parse_args::<SessionArgs>().unwrap();
138
139    let session_signer = session_args.signer.right.into_token_stream();
140
141    // Session Authority
142    let session_authority = session_args.authority.right.into_token_stream();
143
144    let struct_name = &input_parsed.ident;
145    let (impl_generics, ty_generics, where_clause) = input_parsed.generics.split_for_impl();
146
147    // Use the lifetime extracted from the session_token field type (Account<'info, ...>).
148    // This ensures we use the exact lifetime from the field, not the struct's first lifetime.
149    let info_lifetime = info.lifetime;
150
151    let output = match token_type {
152        SessionTokenType::V1 => quote! {
153            #[automatically_derived]
154            impl #impl_generics ::session_keys::Session<#info_lifetime> for #struct_name #ty_generics #where_clause {
155
156                fn target_program(&self) -> ::anchor_lang::prelude::Pubkey {
157                    crate::id()
158                }
159
160                fn session_token(&self) -> Option<::anchor_lang::prelude::Account<#info_lifetime, ::session_keys::SessionToken>> {
161                    self.session_token.clone()
162                }
163
164                fn session_authority(&self) -> ::anchor_lang::prelude::Pubkey {
165                    self.#session_authority
166                }
167
168                fn session_signer(&self) -> ::anchor_lang::prelude::Signer<#info_lifetime> {
169                    self.#session_signer.clone()
170                }
171
172            }
173        },
174        SessionTokenType::V2 => quote! {
175            #[automatically_derived]
176            impl #impl_generics ::session_keys::SessionV2<#info_lifetime> for #struct_name #ty_generics #where_clause {
177
178                fn target_program(&self) -> ::anchor_lang::prelude::Pubkey {
179                    crate::id()
180                }
181
182                fn session_token(&self) -> Option<::anchor_lang::prelude::Account<#info_lifetime, ::session_keys::SessionTokenV2>> {
183                    self.session_token.clone()
184                }
185
186                fn session_authority(&self) -> ::anchor_lang::prelude::Pubkey {
187                    self.#session_authority
188                }
189
190                fn session_signer(&self) -> ::anchor_lang::prelude::Signer<#info_lifetime> {
191                    self.#session_signer.clone()
192                }
193
194            }
195        },
196    };
197
198    output.into()
199}
200
201/// Derive macro for the `Session` trait (V1) or `SessionV2` trait.
202/// Auto-detects based on the field type:
203///   - `Option<Account<'info, SessionToken>>` → implements `Session`
204///   - `Option<Account<'info, SessionTokenV2>>` → implements `SessionV2`
205#[proc_macro_derive(Session, attributes(session))]
206pub fn derive_session(input: TokenStream) -> TokenStream {
207    derive_impl(input, None)
208}
209
210/// Explicit V2 derive macro — same implementation as `#[derive(Session)]`,
211/// provided for clarity when using `SessionTokenV2`.
212#[proc_macro_derive(SessionV2, attributes(session))]
213pub fn derive_session_v2(input: TokenStream) -> TokenStream {
214    derive_impl(input, Some(SessionTokenType::V2))
215}
216
217struct SessionAuthArgs(syn::Expr, syn::Expr);
218
219impl Parse for SessionAuthArgs {
220    fn parse(input: ParseStream) -> syn::Result<Self> {
221        let equality_expr = input.parse()?;
222        input.parse::<Token![,]>()?;
223        let error_expr = input.parse()?;
224        Ok(SessionAuthArgs(equality_expr, error_expr))
225    }
226}
227
228#[proc_macro_attribute]
229/// Macro to check if the session (V1 or V2) or the original authority is the signer.
230/// Works with both `Session` and `SessionV2` traits.
231pub fn session_auth_or(attr: TokenStream, item: TokenStream) -> TokenStream {
232    let SessionAuthArgs(auth_expr, error_ty) = parse_macro_input!(attr);
233
234    let input_fn = parse_macro_input!(item as syn::ItemFn);
235    let input_fn_name = input_fn.sig.ident;
236    let input_fn_vis = input_fn.vis;
237    let input_fn_block = input_fn.block;
238    let input_fn_inputs = input_fn.sig.inputs;
239    let input_fn_output = input_fn.sig.output;
240
241    let output = quote! {
242        #input_fn_vis fn #input_fn_name(#input_fn_inputs) #input_fn_output {
243            // Automatically generated by session_auth_or macro
244            // Import both traits so is_valid()/session_token()/session_authority() resolve
245            // regardless of which trait the derive emitted.
246            use ::session_keys::{Session as _, SessionV2 as _};
247            // BEGIN SESSION AUTH
248            // Current signer is the session signer or the original authority
249            let session_token = ctx.accounts.session_token();
250            if let Some(token) = session_token {
251                require!(ctx.accounts.is_valid()?, SessionError::InvalidToken);
252                // Checks that authority of the session is the same as authority of the original account
253                require_eq!(
254                    ctx.accounts.session_authority(),
255                    token.authority.key(),
256                    #error_ty
257                );
258            } else {
259                require!(
260                    #auth_expr,
261                    #error_ty
262                );
263            }
264            // END SESSION AUTH
265            #input_fn_block
266        }
267    };
268    output.into()
269}