session_keys_macros_attribute/
lib.rs1use proc_macro::TokenStream;
2use quote::{quote, ToTokens};
3
4use syn::{
5 parse::{Parse, ParseStream},
6 parse_macro_input, Data, DeriveInput, Fields, GenericArgument, PathArguments, Token, Type,
7 TypePath,
8};
9
10struct SessionArgs {
11 signer: syn::ExprAssign,
12 authority: syn::ExprAssign,
13}
14
15impl Parse for SessionArgs {
16 fn parse(input: ParseStream) -> syn::Result<Self> {
17 let signer = input.parse()?;
18
19 input.parse::<Token![,]>()?;
20
21 let authority = input.parse()?;
22 Ok(SessionArgs { signer, authority })
23 }
24}
25
26fn is_session(attr: &syn::Attribute) -> bool {
27 attr.path().is_ident("session")
28}
29
/// Which session-token account flavour a `session_token` field holds.
#[derive(Clone, Copy, PartialEq, Eq)]
enum SessionTokenType {
    /// `SessionToken`; the derive emits an `::session_keys::Session` impl.
    V1,
    /// `SessionTokenV2`; the derive emits an `::session_keys::SessionV2` impl.
    V2,
}
35
36struct SessionTokenInfo {
38 token_type: SessionTokenType,
39 lifetime: syn::Lifetime,
41}
42
43fn get_session_token_info(ty: &Type) -> Option<SessionTokenInfo> {
44 let Type::Path(TypePath { path, .. }) = ty else {
45 return None;
46 };
47 let seg = path.segments.last()?;
48 if seg.ident != "Option" {
49 return None;
50 }
51 let PathArguments::AngleBracketed(ref opt_args) = seg.arguments else {
52 return None;
53 };
54 let Some(GenericArgument::Type(Type::Path(TypePath {
55 path: acct_path, ..
56 }))) = opt_args.args.first()
57 else {
58 return None;
59 };
60 let acct_seg = acct_path.segments.last()?;
61 if acct_seg.ident != "Account" {
62 return None;
63 }
64 let PathArguments::AngleBracketed(ref acct_args) = acct_seg.arguments else {
65 return None;
66 };
67 let mut args = acct_args.args.iter();
68 let Some(GenericArgument::Lifetime(lifetime)) = args.next() else {
69 return None;
70 };
71 let lifetime = lifetime.clone();
72 let Some(GenericArgument::Type(Type::Path(TypePath { path: st_path, .. }))) = args.next()
73 else {
74 return None;
75 };
76
77 if let Some(st_seg) = st_path.segments.last() {
78 if st_seg.ident == "SessionToken" {
79 return Some(SessionTokenInfo {
80 token_type: SessionTokenType::V1,
81 lifetime,
82 });
83 } else if st_seg.ident == "SessionTokenV2" {
84 return Some(SessionTokenInfo {
85 token_type: SessionTokenType::V2,
86 lifetime,
87 });
88 }
89 }
90 None
91}
92
93fn derive_impl(input: TokenStream, expected: Option<SessionTokenType>) -> TokenStream {
97 let input_parsed = parse_macro_input!(input as DeriveInput);
98
99 let fields = match input_parsed.data {
100 Data::Struct(data) => match data.fields {
101 Fields::Named(fields) => fields,
102 _ => panic!("Session trait can only be derived for structs with named fields"),
103 },
104 _ => panic!("Session trait can only be derived for structs"),
105 };
106
107 let session_token_field = fields
109 .named
110 .iter()
111 .find(|field| *field.ident.as_ref().unwrap() == "session_token")
112 .expect("Session trait can only be derived for structs with a session_token field");
113
114 let session_token_type = &session_token_field.ty;
115 let info = get_session_token_info(session_token_type)
116 .expect("Session trait can only be derived for structs with a session_token field of type Option<Account<'info, SessionToken>> or Option<Account<'info, SessionTokenV2>>");
117 let token_type = info.token_type;
118
119 if let Some(expected) = expected {
120 if token_type != expected {
121 return syn::Error::new_spanned(
122 &session_token_field.ty,
123 "#[derive(SessionV2)] requires Option<Account<'info, SessionTokenV2>>",
124 )
125 .to_compile_error()
126 .into();
127 }
128 }
129
130 let session_attr = session_token_field
132 .attrs
133 .iter()
134 .find(|attr| is_session(attr))
135 .expect("Session trait can only be derived for structs with a session_token field with the #[session] attribute");
136
137 let session_args = session_attr.parse_args::<SessionArgs>().unwrap();
138
139 let session_signer = session_args.signer.right.into_token_stream();
140
141 let session_authority = session_args.authority.right.into_token_stream();
143
144 let struct_name = &input_parsed.ident;
145 let (impl_generics, ty_generics, where_clause) = input_parsed.generics.split_for_impl();
146
147 let info_lifetime = info.lifetime;
150
151 let output = match token_type {
152 SessionTokenType::V1 => quote! {
153 #[automatically_derived]
154 impl #impl_generics ::session_keys::Session<#info_lifetime> for #struct_name #ty_generics #where_clause {
155
156 fn target_program(&self) -> ::anchor_lang::prelude::Pubkey {
157 crate::id()
158 }
159
160 fn session_token(&self) -> Option<::anchor_lang::prelude::Account<#info_lifetime, ::session_keys::SessionToken>> {
161 self.session_token.clone()
162 }
163
164 fn session_authority(&self) -> ::anchor_lang::prelude::Pubkey {
165 self.#session_authority
166 }
167
168 fn session_signer(&self) -> ::anchor_lang::prelude::Signer<#info_lifetime> {
169 self.#session_signer.clone()
170 }
171
172 }
173 },
174 SessionTokenType::V2 => quote! {
175 #[automatically_derived]
176 impl #impl_generics ::session_keys::SessionV2<#info_lifetime> for #struct_name #ty_generics #where_clause {
177
178 fn target_program(&self) -> ::anchor_lang::prelude::Pubkey {
179 crate::id()
180 }
181
182 fn session_token(&self) -> Option<::anchor_lang::prelude::Account<#info_lifetime, ::session_keys::SessionTokenV2>> {
183 self.session_token.clone()
184 }
185
186 fn session_authority(&self) -> ::anchor_lang::prelude::Pubkey {
187 self.#session_authority
188 }
189
190 fn session_signer(&self) -> ::anchor_lang::prelude::Signer<#info_lifetime> {
191 self.#session_signer.clone()
192 }
193
194 }
195 },
196 };
197
198 output.into()
199}
200
201#[proc_macro_derive(Session, attributes(session))]
206pub fn derive_session(input: TokenStream) -> TokenStream {
207 derive_impl(input, None)
208}
209
210#[proc_macro_derive(SessionV2, attributes(session))]
213pub fn derive_session_v2(input: TokenStream) -> TokenStream {
214 derive_impl(input, Some(SessionTokenType::V2))
215}
216
217struct SessionAuthArgs(syn::Expr, syn::Expr);
218
219impl Parse for SessionAuthArgs {
220 fn parse(input: ParseStream) -> syn::Result<Self> {
221 let equality_expr = input.parse()?;
222 input.parse::<Token![,]>()?;
223 let error_expr = input.parse()?;
224 Ok(SessionAuthArgs(equality_expr, error_expr))
225 }
226}
227
228#[proc_macro_attribute]
229pub fn session_auth_or(attr: TokenStream, item: TokenStream) -> TokenStream {
232 let SessionAuthArgs(auth_expr, error_ty) = parse_macro_input!(attr);
233
234 let input_fn = parse_macro_input!(item as syn::ItemFn);
235 let input_fn_name = input_fn.sig.ident;
236 let input_fn_vis = input_fn.vis;
237 let input_fn_block = input_fn.block;
238 let input_fn_inputs = input_fn.sig.inputs;
239 let input_fn_output = input_fn.sig.output;
240
241 let output = quote! {
242 #input_fn_vis fn #input_fn_name(#input_fn_inputs) #input_fn_output {
243 use ::session_keys::{Session as _, SessionV2 as _};
247 let session_token = ctx.accounts.session_token();
250 if let Some(token) = session_token {
251 require!(ctx.accounts.is_valid()?, SessionError::InvalidToken);
252 require_eq!(
254 ctx.accounts.session_authority(),
255 token.authority.key(),
256 #error_ty
257 );
258 } else {
259 require!(
260 #auth_expr,
261 #error_ty
262 );
263 }
264 #input_fn_block
266 }
267 };
268 output.into()
269}