Skip to main content

typhoon_cpi_generator/anchor/
account.rs

1use {
2    crate::{
3        anchor::{gen_docs, gen_type},
4        idl::{
5            Account, DefinedFields, EnumVariant, Repr, ReprModifier, Serialization, Type, TypeDef,
6            TypeDefTy,
7        },
8    },
9    proc_macro2::{Span, TokenStream},
10    quote::{format_ident, quote},
11    std::collections::HashMap,
12    syn::Ident,
13};
14
15pub fn gen_accounts(accounts: &[Account], types: &[TypeDef]) -> proc_macro2::TokenStream {
16    let mut types: HashMap<String, TokenStream> = types
17        .iter()
18        .map(|ty| (ty.name.to_string(), gen_defined_type(ty)))
19        .collect();
20
21    for account in accounts {
22        let ident = format_ident!("{}", account.name);
23        let discriminator = &account.discriminator;
24        let traits_impl = quote! {
25            impl Owner for #ident {
26                const OWNER: Address = PROGRAM_ID;
27            }
28
29            impl Discriminator for #ident {
30                const DISCRIMINATOR: &'static [u8] = &[#(#discriminator),*];
31            }
32        };
33        if let Some(ty) = &account.ty {
34            let type_def = TypeDef {
35                name: account.name.to_owned(),
36                ty: ty.to_owned(),
37                ..Default::default()
38            };
39            let ty = gen_defined_type(&type_def);
40            types.insert(
41                account.name.clone(),
42                quote! {
43                    #ty
44                    #traits_impl
45                },
46            );
47        } else {
48            let ty = types.get_mut(&account.name).unwrap();
49            ty.extend(traits_impl);
50        }
51    }
52
53    let types = types.values();
54
55    quote! {
56        #(#types)*
57    }
58}
59
60fn gen_defined_type(ty: &TypeDef) -> proc_macro2::TokenStream {
61    let ident = format_ident!("{}", ty.name);
62    let repr = ty.repr.as_ref().map(gen_repr);
63    let docs = gen_docs(&ty.docs);
64    let derive = gen_serialization(&ty.serialization);
65    let item = match &ty.ty {
66        TypeDefTy::Struct { fields } => gen_struct(&ident, fields),
67        TypeDefTy::Enum { variants } => gen_enum(&ident, variants),
68        TypeDefTy::Type { alias } => gen_type_alias(&ident, alias),
69    };
70
71    quote! {
72        #docs
73        #derive
74        #repr
75        #item
76    }
77}
78
79fn gen_struct(ident: &Ident, fields: &Option<DefinedFields>) -> proc_macro2::TokenStream {
80    match fields {
81        Some(struct_fields) => match struct_fields {
82            DefinedFields::Named(f) => {
83                let fields = f.iter().map(|el| {
84                    let docs = gen_docs(&el.docs);
85                    let ident = Ident::new(&el.name, Span::call_site());
86                    let ty = gen_type(&el.ty);
87
88                    quote! {
89                        #docs
90                        pub #ident: #ty,
91                    }
92                });
93                quote! {
94                    pub struct #ident {
95                        #(#fields)*
96                    }
97                }
98            }
99            DefinedFields::Tuple(f) => {
100                let fields = f.iter().map(|el| {
101                    let ty = gen_type(el);
102                    quote!(#ty)
103                });
104                quote! {
105                    pub struct #ident(#(#fields),*)
106                }
107            }
108        },
109        None => quote!(pub struct #ident;),
110    }
111}
112
113fn gen_enum(ident: &Ident, variants: &[EnumVariant]) -> proc_macro2::TokenStream {
114    let fields = variants.iter().map(|el| {
115        let variant_ident = Ident::new(&el.name, Span::call_site());
116        if let Some(ref f) = el.fields {
117            match f {
118                DefinedFields::Named(f) => {
119                    let fields = f.iter().map(|el| {
120                        let docs = gen_docs(&el.docs);
121                        let ident = Ident::new(&el.name, Span::call_site());
122                        let ty = gen_type(&el.ty);
123
124                        quote! {
125                            #docs
126                            #ident: #ty,
127                        }
128                    });
129                    quote! {
130                        #variant_ident {
131                            #(#fields)*
132                        }
133                    }
134                }
135                DefinedFields::Tuple(f) => {
136                    let fields = f.iter().map(|el| {
137                        let ty = gen_type(el);
138                        quote!(#ty)
139                    });
140                    quote! {
141                        #variant_ident(#(#fields),*)
142                    }
143                }
144            }
145        } else {
146            quote!(#variant_ident)
147        }
148    });
149
150    quote! {
151        pub enum #ident {
152            #(#fields),*
153        }
154    }
155}
156
157fn gen_type_alias(ident: &Ident, alias: &Type) -> proc_macro2::TokenStream {
158    let ty = gen_type(alias);
159    quote!(pub type #ident = #ty;)
160}
161
162fn gen_repr(r: &Repr) -> proc_macro2::TokenStream {
163    let gen_repr_with_modifiers = |repr_type: &str, modifier: &ReprModifier| {
164        let ident = Ident::new(repr_type, Span::call_site());
165        let mut attrs = vec![quote!(#ident)];
166
167        if modifier.packed {
168            attrs.push(quote!(packed));
169        }
170        if let Some(size) = modifier.align {
171            attrs.push(quote!(align(#size)));
172        }
173
174        quote!(#[repr(#(#attrs),*)])
175    };
176
177    match r {
178        Repr::Rust(modifier) => gen_repr_with_modifiers("Rust", modifier),
179        Repr::C(modifier) => gen_repr_with_modifiers("C", modifier),
180        Repr::Transparent => quote!(#[repr(transparent)]),
181    }
182}
183
184fn gen_serialization(serialization: &Serialization) -> proc_macro2::TokenStream {
185    match serialization {
186        Serialization::Borsh => {
187            quote!(#[derive(borsh::BorshSerialize, borsh::BorshDeserialize)])
188        }
189        Serialization::BytemuckUnsafe | Serialization::Bytemuck => {
190            quote!(#[derive(bytemuck::Pod, bytemuck::Zeroable, Clone, Copy)])
191        }
192        _ => unimplemented!(),
193    }
194}
195
196#[cfg(test)]
197mod tests {
198    use {super::*, crate::idl::Field, quote::quote};
199
200    #[test]
201    fn test_gen_repr_rust() {
202        let repr = Repr::Rust(ReprModifier {
203            packed: true,
204            align: Some(4),
205        });
206        let result = gen_repr(&repr).to_string();
207        assert_eq!(
208            result,
209            quote!(#[repr(Rust, packed, align(4usize))]).to_string()
210        );
211    }
212
213    #[test]
214    fn test_gen_repr_c() {
215        let repr = Repr::C(ReprModifier {
216            packed: false,
217            align: Some(8),
218        });
219        let result = gen_repr(&repr).to_string();
220        assert_eq!(result, quote!(#[repr(C, align(8usize))]).to_string());
221    }
222
223    #[test]
224    fn test_gen_repr_transparent() {
225        let repr = Repr::Transparent;
226        let result = gen_repr(&repr).to_string();
227        assert_eq!(result, quote!(#[repr(transparent)]).to_string());
228    }
229
230    #[test]
231    fn test_gen_repr_no_modifiers() {
232        let repr = Repr::Rust(ReprModifier {
233            packed: false,
234            align: None,
235        });
236        let result = gen_repr(&repr).to_string();
237        assert_eq!(result, quote!(#[repr(Rust)]).to_string());
238    }
239
240    #[test]
241    fn test_gen_serialization_borsh() {
242        let result = gen_serialization(&Serialization::Borsh).to_string();
243        assert_eq!(
244            result,
245            quote!(#[derive(borsh::BorshSerialize, borsh::BorshDeserialize)]).to_string()
246        );
247    }
248
249    #[test]
250    fn test_gen_serialization_bytemuck() {
251        let result = gen_serialization(&Serialization::Bytemuck).to_string();
252        assert_eq!(
253            result,
254            quote!(#[derive(bytemuck::Pod, bytemuck::Zeroable, Clone, Copy)]).to_string()
255        );
256    }
257
258    #[test]
259    fn test_gen_struct_named() {
260        let ident = Ident::new("TestStruct", Span::call_site());
261        let fields = DefinedFields::Named(vec![Field {
262            name: "field1".to_string(),
263            docs: vec!["Test doc".to_string()],
264            ty: Type::U64,
265        }]);
266        let result = gen_struct(&ident, &Some(fields)).to_string();
267        assert_eq!(
268            result,
269            quote! {
270                pub struct TestStruct {
271                    #[doc = " Test doc"]
272                    pub field1: u64,
273                }
274            }
275            .to_string()
276        );
277    }
278
279    #[test]
280    fn test_gen_struct_tuple() {
281        let ident = Ident::new("TestStruct", Span::call_site());
282        let fields = DefinedFields::Tuple(vec![Type::U64, Type::Bool]);
283        let result = gen_struct(&ident, &Some(fields)).to_string();
284        assert_eq!(result, quote!(pub struct TestStruct(u64, bool)).to_string());
285    }
286
287    #[test]
288    fn test_gen_struct_empty() {
289        let ident = Ident::new("TestStruct", Span::call_site());
290        let result = gen_struct(&ident, &None).to_string();
291        assert_eq!(
292            result,
293            quote!(
294                pub struct TestStruct;
295            )
296            .to_string()
297        );
298    }
299
300    #[test]
301    fn test_gen_enum() {
302        let ident = Ident::new("TestEnum", Span::call_site());
303        let variants = vec![
304            EnumVariant {
305                name: "Variant1".to_string(),
306                fields: None,
307            },
308            EnumVariant {
309                name: "Variant2".to_string(),
310                fields: Some(DefinedFields::Named(vec![Field {
311                    name: "field1".to_string(),
312                    docs: vec![],
313                    ty: Type::U64,
314                }])),
315            },
316            EnumVariant {
317                name: "Variant3".to_string(),
318                fields: Some(DefinedFields::Tuple(vec![Type::Bool, Type::U64])),
319            },
320        ];
321        let result = gen_enum(&ident, &variants).to_string();
322        assert_eq!(
323            result,
324            quote! {
325                pub enum TestEnum {
326                    Variant1,
327                    Variant2 {
328                        field1: u64,
329                    },
330                    Variant3(bool, u64)
331                }
332            }
333            .to_string()
334        );
335    }
336
337    #[test]
338    fn test_gen_type_alias() {
339        let ident = Ident::new("TestAlias", Span::call_site());
340        let alias = Type::U64;
341        let result = gen_type_alias(&ident, &alias).to_string();
342        assert_eq!(
343            result,
344            quote!(
345                pub type TestAlias = u64;
346            )
347            .to_string()
348        );
349    }
350}