session_keys_macros_attribute/
lib.rs1use proc_macro::TokenStream;
2use quote::{quote, ToTokens};
3
4use syn::{
5 parse::{Parse, ParseStream},
6 parse_macro_input, Data, DeriveInput, Fields, GenericArgument, PathArguments, Token, Type,
7 TypePath,
8};
9
10struct SessionArgs {
11 signer: syn::ExprAssign,
12 authority: syn::ExprAssign,
13}
14
15impl Parse for SessionArgs {
16 fn parse(input: ParseStream) -> syn::Result<Self> {
17 let signer = input.parse()?;
18
19 input.parse::<Token![,]>()?;
20
21 let authority = input.parse()?;
22 Ok(SessionArgs { signer, authority })
23 }
24}
25
26fn is_session(attr: &syn::Attribute) -> bool {
27 attr.path().is_ident("session")
28}
29
/// Which session token account a `session_token` field holds, which in turn selects the
/// trait the derive emits.
#[derive(Clone, Copy, PartialEq, Eq)]
enum SessionTokenType {
    /// `Account<'_, SessionToken>`: the derive emits `::session_keys::Session`.
    V1,
    /// `Account<'_, SessionTokenV2>`: the derive emits `::session_keys::SessionV2`.
    V2,
}
35
36struct SessionTokenInfo {
38 token_type: SessionTokenType,
39 lifetime: syn::Lifetime,
41}
42
43fn get_session_token_info(ty: &Type) -> Option<SessionTokenInfo> {
44 let Type::Path(TypePath { path, .. }) = ty else {
45 return None;
46 };
47 let seg = path.segments.last()?;
48 if seg.ident != "Option" {
49 return None;
50 }
51 let PathArguments::AngleBracketed(ref opt_args) = seg.arguments else {
52 return None;
53 };
54 let Some(GenericArgument::Type(Type::Path(TypePath {
55 path: acct_path, ..
56 }))) = opt_args.args.first()
57 else {
58 return None;
59 };
60 let acct_seg = acct_path.segments.last()?;
61 if acct_seg.ident != "Account" {
62 return None;
63 }
64 let PathArguments::AngleBracketed(ref acct_args) = acct_seg.arguments else {
65 return None;
66 };
67 let mut args = acct_args.args.iter();
68 let Some(GenericArgument::Lifetime(lifetime)) = args.next() else {
69 return None;
70 };
71 let lifetime = lifetime.clone();
72 let Some(GenericArgument::Type(Type::Path(TypePath { path: st_path, .. }))) = args.next()
73 else {
74 return None;
75 };
76
77 if let Some(st_seg) = st_path.segments.last() {
78 if st_seg.ident == "SessionToken" {
79 return Some(SessionTokenInfo { token_type: SessionTokenType::V1, lifetime });
80 } else if st_seg.ident == "SessionTokenV2" {
81 return Some(SessionTokenInfo { token_type: SessionTokenType::V2, lifetime });
82 }
83 }
84 None
85}
86
87fn derive_impl(input: TokenStream, expected: Option<SessionTokenType>) -> TokenStream {
91 let input_parsed = parse_macro_input!(input as DeriveInput);
92
93 let fields = match input_parsed.data {
94 Data::Struct(data) => match data.fields {
95 Fields::Named(fields) => fields,
96 _ => panic!("Session trait can only be derived for structs with named fields"),
97 },
98 _ => panic!("Session trait can only be derived for structs"),
99 };
100
101 let session_token_field = fields
103 .named
104 .iter()
105 .find(|field| *field.ident.as_ref().unwrap() == "session_token")
106 .expect("Session trait can only be derived for structs with a session_token field");
107
108 let session_token_type = &session_token_field.ty;
109 let info = get_session_token_info(session_token_type)
110 .expect("Session trait can only be derived for structs with a session_token field of type Option<Account<'info, SessionToken>> or Option<Account<'info, SessionTokenV2>>");
111 let token_type = info.token_type;
112
113 if let Some(expected) = expected {
114 if token_type != expected {
115 return syn::Error::new_spanned(
116 &session_token_field.ty,
117 "#[derive(SessionV2)] requires Option<Account<'info, SessionTokenV2>>",
118 )
119 .to_compile_error()
120 .into();
121 }
122 }
123
124 let session_attr = session_token_field
126 .attrs
127 .iter()
128 .find(|attr| is_session(attr))
129 .expect("Session trait can only be derived for structs with a session_token field with the #[session] attribute");
130
131 let session_args = session_attr.parse_args::<SessionArgs>().unwrap();
132
133 let session_signer = session_args.signer.right.into_token_stream();
134
135 let session_authority = session_args.authority.right.into_token_stream();
137
138 let struct_name = &input_parsed.ident;
139 let (impl_generics, ty_generics, where_clause) = input_parsed.generics.split_for_impl();
140
141 let info_lifetime = info.lifetime;
144
145 let output = match token_type {
146 SessionTokenType::V1 => quote! {
147 #[automatically_derived]
148 impl #impl_generics ::session_keys::Session<#info_lifetime> for #struct_name #ty_generics #where_clause {
149
150 fn target_program(&self) -> ::anchor_lang::prelude::Pubkey {
151 crate::id()
152 }
153
154 fn session_token(&self) -> Option<::anchor_lang::prelude::Account<#info_lifetime, ::session_keys::SessionToken>> {
155 self.session_token.clone()
156 }
157
158 fn session_authority(&self) -> ::anchor_lang::prelude::Pubkey {
159 self.#session_authority
160 }
161
162 fn session_signer(&self) -> ::anchor_lang::prelude::Signer<#info_lifetime> {
163 self.#session_signer.clone()
164 }
165
166 }
167 },
168 SessionTokenType::V2 => quote! {
169 #[automatically_derived]
170 impl #impl_generics ::session_keys::SessionV2<#info_lifetime> for #struct_name #ty_generics #where_clause {
171
172 fn target_program(&self) -> ::anchor_lang::prelude::Pubkey {
173 crate::id()
174 }
175
176 fn session_token(&self) -> Option<::anchor_lang::prelude::Account<#info_lifetime, ::session_keys::SessionTokenV2>> {
177 self.session_token.clone()
178 }
179
180 fn session_authority(&self) -> ::anchor_lang::prelude::Pubkey {
181 self.#session_authority
182 }
183
184 fn session_signer(&self) -> ::anchor_lang::prelude::Signer<#info_lifetime> {
185 self.#session_signer.clone()
186 }
187
188 }
189 },
190 };
191
192 output.into()
193}
194
195#[proc_macro_derive(Session, attributes(session))]
200pub fn derive_session(input: TokenStream) -> TokenStream {
201 derive_impl(input, None)
202}
203
204#[proc_macro_derive(SessionV2, attributes(session))]
207pub fn derive_session_v2(input: TokenStream) -> TokenStream {
208 derive_impl(input, Some(SessionTokenType::V2))
209}
210
211struct SessionAuthArgs(syn::Expr, syn::Expr);
212
213impl Parse for SessionAuthArgs {
214 fn parse(input: ParseStream) -> syn::Result<Self> {
215 let equality_expr = input.parse()?;
216 input.parse::<Token![,]>()?;
217 let error_expr = input.parse()?;
218 Ok(SessionAuthArgs(equality_expr, error_expr))
219 }
220}
221
222#[proc_macro_attribute]
223pub fn session_auth_or(attr: TokenStream, item: TokenStream) -> TokenStream {
226 let SessionAuthArgs(auth_expr, error_ty) = parse_macro_input!(attr);
227
228 let input_fn = parse_macro_input!(item as syn::ItemFn);
229 let input_fn_name = input_fn.sig.ident;
230 let input_fn_vis = input_fn.vis;
231 let input_fn_block = input_fn.block;
232 let input_fn_inputs = input_fn.sig.inputs;
233 let input_fn_output = input_fn.sig.output;
234
235 let output = quote! {
236 #input_fn_vis fn #input_fn_name(#input_fn_inputs) #input_fn_output {
237 use ::session_keys::{Session as _, SessionV2 as _};
241 let session_token = ctx.accounts.session_token();
244 if let Some(token) = session_token {
245 require!(ctx.accounts.is_valid()?, SessionError::InvalidToken);
246 require_eq!(
248 ctx.accounts.session_authority(),
249 token.authority.key(),
250 #error_ty
251 );
252 } else {
253 require!(
254 #auth_expr,
255 #error_ty
256 );
257 }
258 #input_fn_block
260 }
261 };
262 output.into()
263}