1use {
2 crate::{helpers::PathHelper, Encoding},
3 heck::ToSnakeCase,
4 quote::format_ident,
5 syn::{
6 parse::{Parse, Parser},
7 punctuated::Punctuated,
8 visit::Visit,
9 Expr, FnArg, GenericArgument, Ident, LitInt, Pat, Token, Type, TypePath,
10 },
11};
12
/// Description of the data returned by an instruction handler.
pub struct InstructionReturnData {
    /// Inner type extracted from the handler's return type, or `None` when the return type
    /// carries no inner type (e.g. a plain `ProgramResult`).
    pub ty: Option<Type>,
    /// Encoding of the return data. `Instruction::try_from` currently always sets
    /// `Encoding::Bytemuck`.
    pub encoding: Encoding,
}
17
/// A single instruction argument derived from one handler parameter.
pub enum InstructionArg {
    /// Instruction-data argument (an `Arg<T, ..>` parameter): the inner type `T` and the
    /// encoding inferred from its optional strategy parameter.
    Type { ty: Box<Type>, encoding: Encoding },
    /// Context argument, identified by the name of its context type.
    Context(Ident),
}
22
/// An instruction handler extracted from a program function signature.
pub struct Instruction {
    /// Name of the handler function.
    pub name: Ident,
    /// Arguments in declaration order, each paired with its binding name or a generated one
    /// (e.g. `array_0`, `array_1` for an expanded `Array`).
    pub args: Vec<(Ident, InstructionArg)>,
    /// Description of the handler's return data.
    pub return_data: InstructionReturnData,
}
28
29impl TryFrom<&syn::ItemFn> for Instruction {
30 type Error = syn::Error;
31
32 fn try_from(value: &syn::ItemFn) -> Result<Self, Self::Error> {
33 let return_data = value
34 .sig
35 .output
36 .get_element_with_inner()
37 .and_then(|(_, inner, _)| inner);
38
39 let mut args = Vec::with_capacity(value.sig.inputs.len());
40 for fn_arg in &value.sig.inputs {
41 let FnArg::Typed(pat_ty) = fn_arg else {
42 continue;
43 };
44
45 let Type::Path(ref ty_path) = *pat_ty.ty else {
46 continue;
47 };
48
49 let (name, ty, size) = ty_path
50 .get_element_with_inner()
51 .ok_or(syn::Error::new_spanned(fn_arg, "Invalid FnArg."))?;
52
53 if name == "ProgramIdArg" || name == "Remaining" || name == "AccountIter" {
54 continue;
55 }
56
57 let arg_name = extract_name(&pat_ty.pat)
58 .unwrap_or(format_ident!("{}", name.to_string().to_snake_case()));
59
60 if name == "Arg" {
61 args.push((
62 arg_name,
63 InstructionArg::Type {
64 ty: Box::new(
65 ty.ok_or(syn::Error::new_spanned(fn_arg, "Invalid argument type."))?,
66 ),
67 encoding: infer_arg_encoding(ty_path),
68 },
69 ));
70 } else if name == "Array" {
71 let size = size.ok_or(syn::Error::new_spanned(fn_arg, "Invalid Array type."))?;
72 let ty = ty.ok_or(syn::Error::new_spanned(fn_arg, "Invalid argument type."))?;
73 let Type::Path(path) = ty else {
74 return Err(syn::Error::new_spanned(&arg_name, "Invalid ty_path."));
75 };
76 let (name, _, _) = path
77 .get_element_with_inner()
78 .ok_or(syn::Error::new_spanned(&path, "Invalid Array inner type."))?;
79 for i in 0..size {
80 let arg_name = format_ident!("{arg_name}_{i}");
81 args.push((arg_name, InstructionArg::Context(name.clone())));
82 }
83 } else {
84 args.push((arg_name, InstructionArg::Context(name.clone())));
85 }
86 }
87
88 Ok(Instruction {
89 name: value.sig.ident.clone(),
90 args,
91 return_data: InstructionReturnData {
92 ty: return_data,
93 encoding: Encoding::Bytemuck,
94 },
95 })
96 }
97}
98
99fn infer_arg_encoding(ty_path: &TypePath) -> Encoding {
100 let Some(seg) = ty_path.path.segments.last() else {
101 return Encoding::Bytemuck;
102 };
103 let syn::PathArguments::AngleBracketed(args) = &seg.arguments else {
104 return Encoding::Bytemuck;
105 };
106
107 let strategy = args
108 .args
109 .iter()
110 .filter_map(|arg| match arg {
111 GenericArgument::Type(ty) => Some(ty),
112 _ => None,
113 })
114 .nth(1);
115
116 match strategy {
117 None => Encoding::Bytemuck,
118 Some(Type::Path(path)) if path.path.is_ident("BytemuckStrategy") => Encoding::Bytemuck,
119 Some(Type::Path(path)) if path.path.is_ident("BorshStrategy") => Encoding::Borsh,
120 Some(_) => Encoding::Custom,
121 }
122}
123
124fn extract_name(pat: &Pat) -> Option<Ident> {
125 match pat {
126 Pat::Ident(ident) => Some(ident.ident.clone()),
127 Pat::TupleStruct(tuple_struct) => {
128 let pat = tuple_struct.elems.first()?;
129 extract_name(pat)
130 }
131 _ => None,
132 }
133}
134
/// Router table of a program: `(discriminator, handler name)` pairs in source order.
///
/// Populated from a `ROUTER` constant, either directly via `TryFrom<&syn::ItemConst>` or by
/// walking a syntax tree with the `Visit` implementation. The default value is an empty table.
#[derive(Default)]
pub struct InstructionsList(pub Vec<(usize, Ident)>);
137
138struct RouterEntry {
139 discriminator: LitInt,
140 _arrow_eq: Token![=],
141 _arrow_gt: Token![>],
142 handler_name: Ident,
143}
144
145impl Parse for RouterEntry {
146 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
147 Ok(RouterEntry {
148 discriminator: input.parse()?,
149 _arrow_eq: input.parse()?,
150 _arrow_gt: input.parse()?,
151 handler_name: input.parse()?,
152 })
153 }
154}
155
156impl TryFrom<&syn::ItemConst> for InstructionsList {
157 type Error = syn::Error;
158
159 fn try_from(value: &syn::ItemConst) -> syn::Result<Self> {
160 let Expr::Macro(expr_macro) = value.expr.as_ref() else {
161 return Err(syn::Error::new_spanned(value, "Invalid router type."));
162 };
163
164 let instructions = Punctuated::<RouterEntry, syn::Token![,]>::parse_terminated
165 .parse2(expr_macro.mac.tokens.clone())?;
166 Ok(Self(
167 instructions
168 .iter()
169 .map(|entry| {
170 Ok((
171 entry.discriminator.base10_parse::<usize>()?,
172 entry.handler_name.clone(),
173 ))
174 })
175 .collect::<Result<_, syn::Error>>()?,
176 ))
177 }
178}
179
180impl<'ast> Visit<'ast> for InstructionsList {
181 fn visit_item_const(&mut self, i: &'ast syn::ItemConst) {
182 if i.ident != "ROUTER" {
183 return;
184 }
185
186 if let Ok(ix_list) = InstructionsList::try_from(i) {
187 *self = ix_list;
188 }
189 }
190}
191
#[cfg(test)]
mod tests {
    use {
        super::*,
        syn::{parse_quote, ItemConst, ItemFn},
    };

    /// Router entries come back as `(discriminator, handler)` pairs in source order.
    #[test]
    fn test_instruction_list() {
        let router: ItemConst = parse_quote! {
            pub const ROUTER: EntryFn = basic_router! {
                0 => account_iter,
                1 => initialize,
                2 => assert
            };
        };

        let ix_list = InstructionsList::try_from(&router).unwrap();

        let expected = [
            (0usize, "account_iter"),
            (1usize, "initialize"),
            (2usize, "assert"),
        ];
        for (idx, &(discriminator, handler)) in expected.iter().enumerate() {
            assert_eq!(ix_list.0[idx].0, discriminator);
            assert_eq!(ix_list.0[idx].1, handler);
        }
    }

    /// Parameters become named arguments: contexts are kept, `Array<_, N>` expands into `N`
    /// numbered contexts, and `Arg` records its inner type and encoding strategy.
    #[test]
    fn test_instruction_construction() {
        let fn_raw: ItemFn = parse_quote! {
            pub fn instruction_1(ctx: Context1, array: Array<Context2, 2>, arg: Arg<u64>, arg2: Arg<u64, BorshStrategy>) -> ProgramResult {
                Ok(())
            }
        };
        let ix = Instruction::try_from(&fn_raw).unwrap();

        assert_eq!(ix.name, "instruction_1");

        let arg_names: Vec<String> = ix.args.iter().map(|(name, _)| name.to_string()).collect();
        assert_eq!(arg_names, ["ctx", "array_0", "array_1", "arg", "arg2"]);

        for (idx, expected_ctx) in [(0usize, "Context1"), (1usize, "Context2"), (2usize, "Context2")] {
            assert!(
                matches!(&ix.args[idx].1, InstructionArg::Context(x) if x == expected_ctx),
                "argument {idx} should be the context `{expected_ctx}`"
            );
        }

        let is_u64 = |ty: &Type| matches!(ty, Type::Path(path) if path.path.is_ident("u64"));
        assert!(matches!(
            &ix.args[3].1,
            InstructionArg::Type { ty, encoding: Encoding::Bytemuck } if is_u64(&**ty)
        ));
        assert!(matches!(
            &ix.args[4].1,
            InstructionArg::Type { ty, encoding: Encoding::Borsh } if is_u64(&**ty)
        ));

        assert!(ix.return_data.ty.is_none());
        assert!(matches!(ix.return_data.encoding, Encoding::Bytemuck));
    }
}