rusty_bind_parser/
enum_helpers.rs

1//
2// Wildland Project
3//
4// Copyright © 2022 Golem Foundation,
5//
6// This program is free software: you can redistribute it and/or modify
7// it under the terms of the GNU General Public License version 3 as published by
8// the Free Software Foundation.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13// GNU General Public License for more details.
14//
15// You should have received a copy of the GNU General Public License
16// along with this program.  If not, see <https://www.gnu.org/licenses/>.
17
18use std::fmt::Display;
19
20use anyhow::Context;
21use convert_case::{Case, Casing};
22use proc_macro2::{Ident, Span};
23use syn::punctuated::Punctuated;
24use syn::token::Comma;
25use syn::{
26    parse_quote,
27    Field,
28    Fields,
29    FieldsNamed,
30    FieldsUnnamed,
31    Item,
32    ItemEnum,
33    ItemMod,
34    Type,
35    TypeInfer,
36    Variant,
37};
38
39use crate::extern_module_translator::{
40    Arg,
41    Function,
42    ReferenceParameters,
43    RustWrapperType,
44    WrapperType,
45};
46use crate::utils::{is_primitive, BuildContext};
47
48/// Get enums definitions from a module
49pub fn get_enums_from_module(
50    module: &ItemMod,
51    context: &BuildContext,
52) -> anyhow::Result<Vec<ItemEnum>> {
53    Ok(module
54        .content
55        .as_ref()
56        .context("The module is empty.")?
57        .1
58        .iter()
59        .filter_map(|item| match item {
60            Item::Enum(enum_item) => Some(enum_item),
61            _ => None,
62        })
63        .filter(|item| context.check_cfg_attrs(&item.attrs))
64        .cloned()
65        .collect())
66}
67
/// Builds the tag name for an enum by appending `Tag` to its displayed name
/// (e.g. `Shape` -> `ShapeTag`).
pub fn enum_tag_name(enum_name: impl Display) -> String {
    let mut tag = enum_name.to_string();
    tag.push_str("Tag");
    tag
}
71
72pub fn is_primitive_enum(enum_item: &ItemEnum) -> bool {
73    enum_item.variants.iter().all(|v| v.fields == Fields::Unit)
74}
75
76pub fn variant_wrapper_ident(enum_ident: impl Display, variant_ident: impl Display) -> Ident {
77    Ident::new(&format!("{enum_ident}_{variant_ident}"), Span::call_site())
78}
79
80pub fn field_getter_ident(field: &Field, field_idx: usize) -> Ident {
81    let id_str = match &field.ident {
82        Some(name) => format!("get_{name}"),
83        None => format!("get_{field_idx}"),
84    };
85    Ident::new(&id_str, Span::call_site())
86}
87
88pub fn create_field_getter_function(
89    enum_item: &ItemEnum,
90    variant: &Variant,
91    field: &Field,
92    field_idx: usize,
93) -> Function {
94    let variant_wrapper_name = variant_wrapper_ident(&enum_item.ident, &variant.ident);
95    let return_type: Ident = syn::parse_str(&field_type(field).unwrap_type()).unwrap();
96    Function {
97        arguments: vec![Arg {
98            arg_name: "self".to_owned(),
99            typ: WrapperType {
100                original_type_name: syn::parse_str(&variant_wrapper_name.to_string()).unwrap(),
101                wrapper_name: variant_wrapper_name.to_string(),
102                rust_type: RustWrapperType::Custom,
103                reference_parameters: Some(ReferenceParameters::shared()),
104                impl_traits: vec![],
105            },
106        }],
107        return_type: Some(WrapperType {
108            original_type_name: parse_quote! {#return_type},
109            wrapper_name: return_type.to_string(),
110            rust_type: if is_primitive_field(field) {
111                RustWrapperType::Primitive
112            } else {
113                RustWrapperType::Custom
114            },
115            reference_parameters: None,
116            impl_traits: vec![],
117        }),
118        name: field_getter_ident(field, field_idx).to_string(),
119    }
120}
121
122pub fn create_variant_getter_function(enum_item: &ItemEnum, variant: &Variant) -> Option<Function> {
123    let enum_name = enum_item.ident.to_string();
124    if is_many_fields_variant(variant)
125        || (is_single_field_variant(variant) && !is_ignored_variant(variant))
126    {
127        Some(Function {
128            arguments: vec![Arg {
129                arg_name: "self".to_owned(),
130                typ: WrapperType {
131                    original_type_name: syn::parse_str(&enum_name).unwrap(),
132                    wrapper_name: enum_name.clone(),
133                    rust_type: RustWrapperType::Custom,
134                    reference_parameters: Some(ReferenceParameters::shared()),
135                    impl_traits: vec![],
136                },
137            }],
138            return_type: create_variant_getter_return_type(enum_item, variant),
139            name: variant_getter_ident(variant).to_string(),
140        })
141    } else {
142        None
143    }
144}
145
146pub fn create_variant_getter_return_type(
147    enum_item: &ItemEnum,
148    variant: &Variant,
149) -> Option<WrapperType> {
150    if is_single_field_variant(variant) {
151        single_field_variant_getter(variant)
152    } else if is_many_fields_variant(variant) {
153        Some(many_fields_variant_getter(enum_item, variant))
154    } else {
155        None
156    }
157}
158
159pub fn many_fields_variant_getter(enum_item: &ItemEnum, variant: &Variant) -> WrapperType {
160    let wrapper_name = variant_wrapper_ident(&enum_item.ident, &variant.ident);
161    WrapperType {
162        original_type_name: parse_quote! {#wrapper_name},
163        wrapper_name: wrapper_name.to_string(),
164        rust_type: RustWrapperType::Custom,
165        reference_parameters: None,
166        impl_traits: vec![],
167    }
168}
169
170pub fn single_field_variant_getter(variant: &Variant) -> Option<WrapperType> {
171    let field = &get_fields(variant).unwrap()[0];
172    match &field_type(field) {
173        FieldType::Type(field_type) => {
174            let return_type: Ident = syn::parse_str(field_type).unwrap();
175            Some(WrapperType {
176                original_type_name: parse_quote! {#return_type},
177                wrapper_name: return_type.to_string(),
178                rust_type: if is_primitive_field(field) {
179                    RustWrapperType::Primitive
180                } else {
181                    RustWrapperType::Custom
182                },
183                reference_parameters: None,
184                impl_traits: vec![],
185            })
186        }
187        FieldType::Ignored => None,
188    }
189}
190
191pub fn variant_getter_ident(v: &Variant) -> Ident {
192    Ident::new(
193        &format!("get_{}", v.ident.to_string().to_case(Case::Snake)),
194        Span::call_site(),
195    )
196}
197
/// Classification of an enum variant field's declared type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FieldType {
    /// A type written as a single identifier (e.g. `u32`, `MyStruct`); holds
    /// the identifier's text.
    Type(String),
    /// The inferred type `_`, used to mark the field as ignored.
    Ignored,
}

impl FieldType {
    /// Returns the type name held by [`FieldType::Type`].
    ///
    /// # Panics
    ///
    /// Panics with "Invalid field type" if `self` is [`FieldType::Ignored`].
    pub fn unwrap_type(self) -> String {
        // Exhaustive on purpose: adding a variant should force a decision here.
        match self {
            FieldType::Type(s) => s,
            FieldType::Ignored => panic!("Invalid field type"),
        }
    }
}
212
213pub fn field_type(field: &Field) -> FieldType {
214    match &field.ty {
215        Type::Path(type_path) => FieldType::Type(
216            type_path
217                .path
218                .get_ident()
219                .unwrap_or_else(|| panic!("Invalid ident of an enum variant field {field:?}"))
220                .to_string(),
221        ),
222        Type::Infer(TypeInfer { .. }) => FieldType::Ignored,
223        _ => panic!("Invalid type of an enum variant field {field:?}"),
224    }
225}
226
227pub fn is_primitive_field(field: &Field) -> bool {
228    is_primitive(field_type(field).unwrap_type().as_str())
229}
230
231pub fn field_name(field: &Field) -> Option<String> {
232    field.ident.as_ref().map(|i| i.to_string())
233}
234
235pub fn is_many_fields_variant(v: &Variant) -> bool {
236    match &v.fields {
237        syn::Fields::Named(FieldsNamed { named: fields, .. })
238        | syn::Fields::Unnamed(FieldsUnnamed {
239            unnamed: fields, ..
240        }) => fields.len() > 1,
241        syn::Fields::Unit => false,
242    }
243}
244pub fn is_single_field_variant(v: &Variant) -> bool {
245    match &v.fields {
246        syn::Fields::Named(FieldsNamed { named: fields, .. })
247        | syn::Fields::Unnamed(FieldsUnnamed {
248            unnamed: fields, ..
249        }) => fields.len() == 1,
250        syn::Fields::Unit => false,
251    }
252}
253pub fn is_unit_variant(v: &Variant) -> bool {
254    matches!(&v.fields, syn::Fields::Unit)
255}
256
257pub fn is_ignored_variant(v: &Variant) -> bool {
258    match get_fields(v) {
259        Some(fields) if fields.len() == 1 => field_type(&fields[0]) == FieldType::Ignored,
260        _ => false,
261    }
262}
263
264pub fn get_fields(v: &Variant) -> Option<&Punctuated<Field, Comma>> {
265    match &v.fields {
266        syn::Fields::Named(FieldsNamed { named: fields, .. })
267        | syn::Fields::Unnamed(FieldsUnnamed {
268            unnamed: fields, ..
269        }) => Some(fields),
270        syn::Fields::Unit => None,
271    }
272}
273
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use syn::ItemMod;

    use super::*;
    use crate::utils::helpers;

    #[test]
    fn test_getting_enums_from_module() {
        // The struct must be skipped; only the two enums are expected, in order.
        let rust_code = "
mod ffi {
    enum En1 {
        V1,
    }

    enum En2 {
        V1,
    }

    struct SomeStructToBeIgnored {
        field: String,
    }
}
        ";
        let module: ItemMod = syn::parse_str(rust_code).unwrap();
        let context = helpers::get_context();
        let enum_names: Vec<String> = get_enums_from_module(&module, &context)
            .unwrap()
            .iter()
            .map(|enum_item| enum_item.ident.to_string())
            .collect();

        let expected: Vec<String> = vec!["En1".to_string(), "En2".to_string()];
        assert_eq!(enum_names, expected);
    }
}