session_keys_macros_attribute/
lib.rs1use proc_macro::TokenStream;
2use quote::{quote, ToTokens};
3
4use syn::{
5 parse::{Parse, ParseStream},
6 parse_macro_input, Data, DeriveInput, Fields, GenericArgument, PathArguments, Token, Type,
7 TypePath,
8};
9
10struct SessionArgs {
11 signer: syn::ExprAssign,
12 authority: syn::ExprAssign,
13}
14
15impl Parse for SessionArgs {
16 fn parse(input: ParseStream) -> syn::Result<Self> {
17 let signer = input.parse()?;
18
19 input.parse::<Token![,]>()?;
20
21 let authority = input.parse()?;
22 Ok(SessionArgs { signer, authority })
23 }
24}
25
26fn is_session(attr: &syn::Attribute) -> bool {
27 attr.path.is_ident("session")
28}
29
30fn is_option_account_sessiontoken(ty: &Type) -> bool {
31 let Type::Path(TypePath { path, .. }) = ty else {
32 return false;
33 };
34 let Some(seg) = path.segments.first() else {
35 return false;
36 };
37 if seg.ident != "Option" {
38 return false;
39 }
40 let PathArguments::AngleBracketed(ref opt_args) = seg.arguments else {
41 return false;
42 };
43 let Some(GenericArgument::Type(Type::Path(TypePath {
44 path: acct_path, ..
45 }))) = opt_args.args.first()
46 else {
47 return false;
48 };
49 let Some(acct_seg) = acct_path.segments.first() else {
50 return false;
51 };
52 if acct_seg.ident != "Account" {
53 return false;
54 }
55 let PathArguments::AngleBracketed(ref acct_args) = acct_seg.arguments else {
56 return false;
57 };
58 let mut args = acct_args.args.iter();
59 let Some(GenericArgument::Lifetime(_)) = args.next() else {
60 return false;
61 };
62 let Some(GenericArgument::Type(Type::Path(TypePath { path: st_path, .. }))) = args.next()
63 else {
64 return false;
65 };
66 st_path.segments.len() == 1 && st_path.segments[0].ident == "SessionToken"
67}
68
69#[proc_macro_derive(Session, attributes(session))]
71pub fn derive(input: TokenStream) -> TokenStream {
72 let input_parsed = parse_macro_input!(input as DeriveInput);
73
74 let fields = match input_parsed.data {
75 Data::Struct(data) => match data.fields {
76 Fields::Named(fields) => fields,
77 _ => panic!("Session trait can only be derived for structs with named fields"),
78 },
79 _ => panic!("Session trait can only be derived for structs"),
80 };
81
82 let session_token_field = fields
84 .named
85 .iter()
86 .find(|field| *field.ident.as_ref().unwrap() == "session_token")
87 .expect("Session trait can only be derived for structs with a session_token field");
88 {
89 let session_token_type = &session_token_field.ty;
90 assert!(is_option_account_sessiontoken(session_token_type), "Session trait can only be derived for structs with a session_token field of type Option<Account<'info, SessionToken>>");
91 }
92
93 let session_attr = session_token_field
95 .attrs
96 .iter()
97 .find(|attr| is_session(attr))
98 .expect("Session trait can only be derived for structs with a session_token field with the #[session] attribute");
99
100 let session_args = session_attr.parse_args::<SessionArgs>().unwrap();
101
102 let session_signer = session_args.signer.right.into_token_stream();
103
104 let session_authority = session_args.authority.right.into_token_stream();
106
107 let struct_name = &input_parsed.ident;
108 let (impl_generics, ty_generics, where_clause) = input_parsed.generics.split_for_impl();
109
110 let output = quote! {
111
112 #[automatically_derived]
113 impl #impl_generics Session #ty_generics for #struct_name #ty_generics #where_clause {
114
115 fn target_program(&self) -> Pubkey {
117 crate::id()
118 }
119
120 fn session_token(&self) -> Option<Account<'info, SessionToken>> {
122 self.session_token.clone()
123 }
124
125 fn session_authority(&self) -> Pubkey {
127 self.#session_authority
128 }
129
130 fn session_signer(&self) -> Signer<'info> {
132 self.#session_signer.clone()
133 }
134
135 }
136 };
137
138 output.into()
139}
140
141struct SessionAuthArgs(syn::Expr, syn::Expr);
142
143impl Parse for SessionAuthArgs {
144 fn parse(input: ParseStream) -> syn::Result<Self> {
145 let equality_expr = input.parse()?;
146 input.parse::<Token![,]>()?;
147 let error_expr = input.parse()?;
148 Ok(SessionAuthArgs(equality_expr, error_expr))
149 }
150}
151
152#[proc_macro_attribute]
153pub fn session_auth_or(attr: TokenStream, item: TokenStream) -> TokenStream {
155 let SessionAuthArgs(auth_expr, error_ty) = parse_macro_input!(attr);
156
157 let input_fn = parse_macro_input!(item as syn::ItemFn);
158 let input_fn_name = input_fn.sig.ident;
159 let input_fn_vis = input_fn.vis;
160 let input_fn_block = input_fn.block;
161 let input_fn_inputs = input_fn.sig.inputs;
162 let input_fn_output = input_fn.sig.output;
163
164 let output = quote! {
165 #input_fn_vis fn #input_fn_name(#input_fn_inputs) #input_fn_output {
166 let session_token = ctx.accounts.session_token();
170 if let Some(token) = session_token {
171 require!(ctx.accounts.is_valid()?, SessionError::InvalidToken);
172 require_eq!(
174 ctx.accounts.session_authority(),
175 token.authority.key(),
176 #error_ty
177 );
178 } else {
179 require!(
180 #auth_expr,
181 #error_ty
182 );
183 }
184 #input_fn_block
186 }
187 };
188 output.into()
189}