// typhoon_syn/instruction.rs

1use {
2    crate::{helpers::PathHelper, Encoding},
3    heck::ToSnakeCase,
4    quote::format_ident,
5    syn::{
6        parse::{Parse, Parser},
7        punctuated::Punctuated,
8        visit::Visit,
9        Expr, FnArg, GenericArgument, Ident, LitInt, Pat, Token, Type, TypePath,
10    },
11};
12
13pub struct InstructionReturnData {
14    pub ty: Option<Type>,
15    pub encoding: Encoding,
16}
17
18pub enum InstructionArg {
19    Type { ty: Box<Type>, encoding: Encoding },
20    Context(Ident),
21}
22
23pub struct Instruction {
24    pub name: Ident,
25    pub args: Vec<(Ident, InstructionArg)>,
26    pub return_data: InstructionReturnData,
27}
28
29impl TryFrom<&syn::ItemFn> for Instruction {
30    type Error = syn::Error;
31
32    fn try_from(value: &syn::ItemFn) -> Result<Self, Self::Error> {
33        let return_data = value
34            .sig
35            .output
36            .get_element_with_inner()
37            .and_then(|(_, inner, _)| inner);
38
39        let mut args = Vec::with_capacity(value.sig.inputs.len());
40        for fn_arg in &value.sig.inputs {
41            let FnArg::Typed(pat_ty) = fn_arg else {
42                continue;
43            };
44
45            let Type::Path(ref ty_path) = *pat_ty.ty else {
46                continue;
47            };
48
49            let (name, ty, size) = ty_path
50                .get_element_with_inner()
51                .ok_or(syn::Error::new_spanned(fn_arg, "Invalid FnArg."))?;
52
53            if name == "ProgramIdArg" || name == "Remaining" || name == "AccountIter" {
54                continue;
55            }
56
57            let arg_name = extract_name(&pat_ty.pat)
58                .unwrap_or(format_ident!("{}", name.to_string().to_snake_case()));
59
60            if name == "Arg" {
61                args.push((
62                    arg_name,
63                    InstructionArg::Type {
64                        ty: Box::new(
65                            ty.ok_or(syn::Error::new_spanned(fn_arg, "Invalid argument type."))?,
66                        ),
67                        encoding: infer_arg_encoding(ty_path),
68                    },
69                ));
70            } else if name == "Array" {
71                let size = size.ok_or(syn::Error::new_spanned(fn_arg, "Invalid Array type."))?;
72                let ty = ty.ok_or(syn::Error::new_spanned(fn_arg, "Invalid argument type."))?;
73                let Type::Path(path) = ty else {
74                    return Err(syn::Error::new_spanned(&arg_name, "Invalid ty_path."));
75                };
76                let (name, _, _) = path
77                    .get_element_with_inner()
78                    .ok_or(syn::Error::new_spanned(&path, "Invalid Array inner type."))?;
79                for i in 0..size {
80                    let arg_name = format_ident!("{arg_name}_{i}");
81                    args.push((arg_name, InstructionArg::Context(name.clone())));
82                }
83            } else {
84                args.push((arg_name, InstructionArg::Context(name.clone())));
85            }
86        }
87
88        Ok(Instruction {
89            name: value.sig.ident.clone(),
90            args,
91            return_data: InstructionReturnData {
92                ty: return_data,
93                encoding: Encoding::Bytemuck,
94            },
95        })
96    }
97}
98
99fn infer_arg_encoding(ty_path: &TypePath) -> Encoding {
100    let Some(seg) = ty_path.path.segments.last() else {
101        return Encoding::Bytemuck;
102    };
103    let syn::PathArguments::AngleBracketed(args) = &seg.arguments else {
104        return Encoding::Bytemuck;
105    };
106
107    let strategy = args
108        .args
109        .iter()
110        .filter_map(|arg| match arg {
111            GenericArgument::Type(ty) => Some(ty),
112            _ => None,
113        })
114        .nth(1);
115
116    match strategy {
117        None => Encoding::Bytemuck,
118        Some(Type::Path(path)) if path.path.is_ident("BytemuckStrategy") => Encoding::Bytemuck,
119        Some(Type::Path(path)) if path.path.is_ident("BorshStrategy") => Encoding::Borsh,
120        Some(_) => Encoding::Custom,
121    }
122}
123
124fn extract_name(pat: &Pat) -> Option<Ident> {
125    match pat {
126        Pat::Ident(ident) => Some(ident.ident.clone()),
127        Pat::TupleStruct(tuple_struct) => {
128            let pat = tuple_struct.elems.first()?;
129            extract_name(pat)
130        }
131        _ => None,
132    }
133}
134
135#[derive(Default)]
136pub struct InstructionsList(pub Vec<(usize, Ident)>);
137
138struct RouterEntry {
139    discriminator: LitInt,
140    _arrow_eq: Token![=],
141    _arrow_gt: Token![>],
142    handler_name: Ident,
143}
144
145impl Parse for RouterEntry {
146    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
147        Ok(RouterEntry {
148            discriminator: input.parse()?,
149            _arrow_eq: input.parse()?,
150            _arrow_gt: input.parse()?,
151            handler_name: input.parse()?,
152        })
153    }
154}
155
156impl TryFrom<&syn::ItemConst> for InstructionsList {
157    type Error = syn::Error;
158
159    fn try_from(value: &syn::ItemConst) -> syn::Result<Self> {
160        let Expr::Macro(expr_macro) = value.expr.as_ref() else {
161            return Err(syn::Error::new_spanned(value, "Invalid router type."));
162        };
163
164        let instructions = Punctuated::<RouterEntry, syn::Token![,]>::parse_terminated
165            .parse2(expr_macro.mac.tokens.clone())?;
166        Ok(Self(
167            instructions
168                .iter()
169                .map(|entry| {
170                    Ok((
171                        entry.discriminator.base10_parse::<usize>()?,
172                        entry.handler_name.clone(),
173                    ))
174                })
175                .collect::<Result<_, syn::Error>>()?,
176        ))
177    }
178}
179
180impl<'ast> Visit<'ast> for InstructionsList {
181    fn visit_item_const(&mut self, i: &'ast syn::ItemConst) {
182        if i.ident != "ROUTER" {
183            return;
184        }
185
186        if let Ok(ix_list) = InstructionsList::try_from(i) {
187            *self = ix_list;
188        }
189    }
190}
191
#[cfg(test)]
mod tests {
    use {
        super::*,
        syn::{parse_quote, ItemConst, ItemFn},
    };

    /// The router macro body is parsed into ordered `(discriminator, handler)` pairs.
    #[test]
    fn test_instruction_list() {
        let router: ItemConst = parse_quote! {
            pub const ROUTER: EntryFn = basic_router! {
                0 => account_iter,
                1 => initialize,
                2 => assert
            };
        };

        let InstructionsList(entries) = InstructionsList::try_from(&router).unwrap();

        let expected: [(usize, &str); 3] = [(0, "account_iter"), (1, "initialize"), (2, "assert")];
        for (index, (discriminator, handler)) in expected.iter().enumerate() {
            assert_eq!(entries[index].0, *discriminator);
            assert_eq!(entries[index].1, *handler);
        }
    }

    /// Contexts, arrays and typed arguments are classified and named as expected.
    #[test]
    fn test_instruction_construction() {
        let fn_raw: ItemFn = parse_quote! {
            pub fn instruction_1(ctx: Context1, array: Array<Context2, 2>, arg: Arg<u64>, arg2: Arg<u64, BorshStrategy>) -> ProgramResult {
                Ok(())
            }
        };
        let ix = Instruction::try_from(&fn_raw).unwrap();

        assert_eq!(ix.name, "instruction_1");
        assert_eq!(ix.args.len(), 5);

        // Parameter names, with the array expanded into indexed entries.
        let names: Vec<String> = ix.args.iter().map(|(name, _)| name.to_string()).collect();
        assert_eq!(names, ["ctx", "array_0", "array_1", "arg", "arg2"]);

        // The first three entries are contexts identified by their type name.
        for &(index, expected_ctx) in &[(0usize, "Context1"), (1, "Context2"), (2, "Context2")] {
            assert!(matches!(&ix.args[index].1, InstructionArg::Context(x) if x == expected_ctx));
        }

        // The last two entries are `u64` data arguments with different encodings.
        let is_u64 = |ty: &Type| matches!(ty, Type::Path(path) if path.path.is_ident("u64"));

        match &ix.args[3].1 {
            InstructionArg::Type { ty, encoding } => {
                assert!(is_u64(&**ty));
                assert!(matches!(encoding, Encoding::Bytemuck));
            }
            InstructionArg::Context(_) => panic!("expected a typed argument"),
        }

        match &ix.args[4].1 {
            InstructionArg::Type { ty, encoding } => {
                assert!(is_u64(&**ty));
                assert!(matches!(encoding, Encoding::Borsh));
            }
            InstructionArg::Context(_) => panic!("expected a typed argument"),
        }

        assert!(ix.return_data.ty.is_none());
        assert!(matches!(ix.return_data.encoding, Encoding::Bytemuck));
    }
}