session_keys_macros_attribute/
lib.rs1use proc_macro::TokenStream;
2use quote::{quote, ToTokens};
3
4use syn::{
5 parse::{Parse, ParseStream},
6 parse_macro_input, Data, DeriveInput, Fields, GenericArgument, PathArguments, Token, Type,
7 TypePath,
8};
9
10struct SessionArgs {
11 signer: syn::ExprAssign,
12 authority: syn::ExprAssign,
13}
14
15impl Parse for SessionArgs {
16 fn parse(input: ParseStream) -> syn::Result<Self> {
17 let signer = input.parse()?;
18
19 input.parse::<Token![,]>()?;
20
21 let authority = input.parse()?;
22 Ok(SessionArgs { signer, authority })
23 }
24}
25
26fn is_session(attr: &syn::Attribute) -> bool {
27 attr.path.is_ident("session")
28}
29
/// Version of the session token account wrapped by the `session_token` field.
#[derive(Clone, Copy, PartialEq, Eq)]
enum SessionTokenType {
    /// The field wraps `SessionToken`; the derive implements `Session`.
    V1,
    /// The field wraps `SessionTokenV2`; the derive implements `SessionV2`.
    V2,
}
35
36struct SessionTokenInfo {
38 token_type: SessionTokenType,
39 lifetime: syn::Lifetime,
41}
42
43fn get_session_token_info(ty: &Type) -> Option<SessionTokenInfo> {
44 let Type::Path(TypePath { path, .. }) = ty else {
45 return None;
46 };
47 let Some(seg) = path.segments.last() else {
48 return None;
49 };
50 if seg.ident != "Option" {
51 return None;
52 }
53 let PathArguments::AngleBracketed(ref opt_args) = seg.arguments else {
54 return None;
55 };
56 let Some(GenericArgument::Type(Type::Path(TypePath {
57 path: acct_path, ..
58 }))) = opt_args.args.first()
59 else {
60 return None;
61 };
62 let Some(acct_seg) = acct_path.segments.last() else {
63 return None;
64 };
65 if acct_seg.ident != "Account" {
66 return None;
67 }
68 let PathArguments::AngleBracketed(ref acct_args) = acct_seg.arguments else {
69 return None;
70 };
71 let mut args = acct_args.args.iter();
72 let Some(GenericArgument::Lifetime(lifetime)) = args.next() else {
73 return None;
74 };
75 let lifetime = lifetime.clone();
76 let Some(GenericArgument::Type(Type::Path(TypePath { path: st_path, .. }))) = args.next()
77 else {
78 return None;
79 };
80
81 if let Some(st_seg) = st_path.segments.last() {
82 if st_seg.ident == "SessionToken" {
83 return Some(SessionTokenInfo { token_type: SessionTokenType::V1, lifetime });
84 } else if st_seg.ident == "SessionTokenV2" {
85 return Some(SessionTokenInfo { token_type: SessionTokenType::V2, lifetime });
86 }
87 }
88 None
89}
90
91fn derive_impl(input: TokenStream, expected: Option<SessionTokenType>) -> TokenStream {
95 let input_parsed = parse_macro_input!(input as DeriveInput);
96
97 let fields = match input_parsed.data {
98 Data::Struct(data) => match data.fields {
99 Fields::Named(fields) => fields,
100 _ => panic!("Session trait can only be derived for structs with named fields"),
101 },
102 _ => panic!("Session trait can only be derived for structs"),
103 };
104
105 let session_token_field = fields
107 .named
108 .iter()
109 .find(|field| field.ident.as_ref().unwrap().to_string() == "session_token")
110 .expect("Session trait can only be derived for structs with a session_token field");
111
112 let session_token_type = &session_token_field.ty;
113 let info = get_session_token_info(session_token_type)
114 .expect("Session trait can only be derived for structs with a session_token field of type Option<Account<'info, SessionToken>> or Option<Account<'info, SessionTokenV2>>");
115 let token_type = info.token_type;
116
117 if let Some(expected) = expected {
118 if token_type != expected {
119 return syn::Error::new_spanned(
120 &session_token_field.ty,
121 "#[derive(SessionV2)] requires Option<Account<'info, SessionTokenV2>>",
122 )
123 .to_compile_error()
124 .into();
125 }
126 }
127
128 let session_attr = session_token_field
130 .attrs
131 .iter()
132 .find(|attr| is_session(attr))
133 .expect("Session trait can only be derived for structs with a session_token field with the #[session] attribute");
134
135 let session_args = session_attr.parse_args::<SessionArgs>().unwrap();
136
137 let session_signer = session_args.signer.right.into_token_stream();
138
139 let session_authority = session_args.authority.right.into_token_stream();
141
142 let struct_name = &input_parsed.ident;
143 let (impl_generics, ty_generics, where_clause) = input_parsed.generics.split_for_impl();
144
145 let info_lifetime = info.lifetime;
148
149 let output = match token_type {
150 SessionTokenType::V1 => quote! {
151 #[automatically_derived]
152 impl #impl_generics ::session_keys::Session<#info_lifetime> for #struct_name #ty_generics #where_clause {
153
154 fn target_program(&self) -> ::anchor_lang::prelude::Pubkey {
155 crate::id()
156 }
157
158 fn session_token(&self) -> Option<::anchor_lang::prelude::Account<#info_lifetime, ::session_keys::SessionToken>> {
159 self.session_token.clone()
160 }
161
162 fn session_authority(&self) -> ::anchor_lang::prelude::Pubkey {
163 self.#session_authority
164 }
165
166 fn session_signer(&self) -> ::anchor_lang::prelude::Signer<#info_lifetime> {
167 self.#session_signer.clone()
168 }
169
170 }
171 },
172 SessionTokenType::V2 => quote! {
173 #[automatically_derived]
174 impl #impl_generics ::session_keys::SessionV2<#info_lifetime> for #struct_name #ty_generics #where_clause {
175
176 fn target_program(&self) -> ::anchor_lang::prelude::Pubkey {
177 crate::id()
178 }
179
180 fn session_token(&self) -> Option<::anchor_lang::prelude::Account<#info_lifetime, ::session_keys::SessionTokenV2>> {
181 self.session_token.clone()
182 }
183
184 fn session_authority(&self) -> ::anchor_lang::prelude::Pubkey {
185 self.#session_authority
186 }
187
188 fn session_signer(&self) -> ::anchor_lang::prelude::Signer<#info_lifetime> {
189 self.#session_signer.clone()
190 }
191
192 }
193 },
194 };
195
196 output.into()
197}
198
199#[proc_macro_derive(Session, attributes(session))]
204pub fn derive_session(input: TokenStream) -> TokenStream {
205 derive_impl(input, None)
206}
207
208#[proc_macro_derive(SessionV2, attributes(session))]
211pub fn derive_session_v2(input: TokenStream) -> TokenStream {
212 derive_impl(input, Some(SessionTokenType::V2))
213}
214
215struct SessionAuthArgs(syn::Expr, syn::Expr);
216
217impl Parse for SessionAuthArgs {
218 fn parse(input: ParseStream) -> syn::Result<Self> {
219 let equality_expr = input.parse()?;
220 input.parse::<Token![,]>()?;
221 let error_expr = input.parse()?;
222 Ok(SessionAuthArgs(equality_expr, error_expr))
223 }
224}
225
226#[proc_macro_attribute]
227pub fn session_auth_or(attr: TokenStream, item: TokenStream) -> TokenStream {
230 let SessionAuthArgs(auth_expr, error_ty) = parse_macro_input!(attr);
231
232 let input_fn = parse_macro_input!(item as syn::ItemFn);
233 let input_fn_name = input_fn.sig.ident;
234 let input_fn_vis = input_fn.vis;
235 let input_fn_block = input_fn.block;
236 let input_fn_inputs = input_fn.sig.inputs;
237 let input_fn_output = input_fn.sig.output;
238
239 let output = quote! {
240 #input_fn_vis fn #input_fn_name(#input_fn_inputs) #input_fn_output {
241 use ::session_keys::{Session as _, SessionV2 as _};
245 let session_token = ctx.accounts.session_token();
248 if let Some(token) = session_token {
249 require!(ctx.accounts.is_valid()?, SessionError::InvalidToken);
250 require_eq!(
252 ctx.accounts.session_authority(),
253 token.authority.key(),
254 #error_ty
255 );
256 } else {
257 require!(
258 #auth_expr,
259 #error_ty
260 );
261 }
262 #input_fn_block
264 }
265 };
266 output.into()
267}