1use std::fmt::Display;
19
20use anyhow::Context;
21use convert_case::{Case, Casing};
22use proc_macro2::{Ident, Span};
23use syn::punctuated::Punctuated;
24use syn::token::Comma;
25use syn::{
26 parse_quote,
27 Field,
28 Fields,
29 FieldsNamed,
30 FieldsUnnamed,
31 Item,
32 ItemEnum,
33 ItemMod,
34 Type,
35 TypeInfer,
36 Variant,
37};
38
39use crate::extern_module_translator::{
40 Arg,
41 Function,
42 ReferenceParameters,
43 RustWrapperType,
44 WrapperType,
45};
46use crate::utils::{is_primitive, BuildContext};
47
48pub fn get_enums_from_module(
50 module: &ItemMod,
51 context: &BuildContext,
52) -> anyhow::Result<Vec<ItemEnum>> {
53 Ok(module
54 .content
55 .as_ref()
56 .context("The module is empty.")?
57 .1
58 .iter()
59 .filter_map(|item| match item {
60 Item::Enum(enum_item) => Some(enum_item),
61 _ => None,
62 })
63 .filter(|item| context.check_cfg_attrs(&item.attrs))
64 .cloned()
65 .collect())
66}
67
/// Returns the display form of `enum_name` with a `Tag` suffix
/// (e.g. `Shape` -> `ShapeTag`).
pub fn enum_tag_name(enum_name: impl Display) -> String {
    let mut tag_name = enum_name.to_string();
    tag_name.push_str("Tag");
    tag_name
}
71
72pub fn is_primitive_enum(enum_item: &ItemEnum) -> bool {
73 enum_item.variants.iter().all(|v| v.fields == Fields::Unit)
74}
75
76pub fn variant_wrapper_ident(enum_ident: impl Display, variant_ident: impl Display) -> Ident {
77 Ident::new(&format!("{enum_ident}_{variant_ident}"), Span::call_site())
78}
79
80pub fn field_getter_ident(field: &Field, field_idx: usize) -> Ident {
81 let id_str = match &field.ident {
82 Some(name) => format!("get_{name}"),
83 None => format!("get_{field_idx}"),
84 };
85 Ident::new(&id_str, Span::call_site())
86}
87
88pub fn create_field_getter_function(
89 enum_item: &ItemEnum,
90 variant: &Variant,
91 field: &Field,
92 field_idx: usize,
93) -> Function {
94 let variant_wrapper_name = variant_wrapper_ident(&enum_item.ident, &variant.ident);
95 let return_type: Ident = syn::parse_str(&field_type(field).unwrap_type()).unwrap();
96 Function {
97 arguments: vec![Arg {
98 arg_name: "self".to_owned(),
99 typ: WrapperType {
100 original_type_name: syn::parse_str(&variant_wrapper_name.to_string()).unwrap(),
101 wrapper_name: variant_wrapper_name.to_string(),
102 rust_type: RustWrapperType::Custom,
103 reference_parameters: Some(ReferenceParameters::shared()),
104 impl_traits: vec![],
105 },
106 }],
107 return_type: Some(WrapperType {
108 original_type_name: parse_quote! {#return_type},
109 wrapper_name: return_type.to_string(),
110 rust_type: if is_primitive_field(field) {
111 RustWrapperType::Primitive
112 } else {
113 RustWrapperType::Custom
114 },
115 reference_parameters: None,
116 impl_traits: vec![],
117 }),
118 name: field_getter_ident(field, field_idx).to_string(),
119 }
120}
121
122pub fn create_variant_getter_function(enum_item: &ItemEnum, variant: &Variant) -> Option<Function> {
123 let enum_name = enum_item.ident.to_string();
124 if is_many_fields_variant(variant)
125 || (is_single_field_variant(variant) && !is_ignored_variant(variant))
126 {
127 Some(Function {
128 arguments: vec![Arg {
129 arg_name: "self".to_owned(),
130 typ: WrapperType {
131 original_type_name: syn::parse_str(&enum_name).unwrap(),
132 wrapper_name: enum_name.clone(),
133 rust_type: RustWrapperType::Custom,
134 reference_parameters: Some(ReferenceParameters::shared()),
135 impl_traits: vec![],
136 },
137 }],
138 return_type: create_variant_getter_return_type(enum_item, variant),
139 name: variant_getter_ident(variant).to_string(),
140 })
141 } else {
142 None
143 }
144}
145
146pub fn create_variant_getter_return_type(
147 enum_item: &ItemEnum,
148 variant: &Variant,
149) -> Option<WrapperType> {
150 if is_single_field_variant(variant) {
151 single_field_variant_getter(variant)
152 } else if is_many_fields_variant(variant) {
153 Some(many_fields_variant_getter(enum_item, variant))
154 } else {
155 None
156 }
157}
158
159pub fn many_fields_variant_getter(enum_item: &ItemEnum, variant: &Variant) -> WrapperType {
160 let wrapper_name = variant_wrapper_ident(&enum_item.ident, &variant.ident);
161 WrapperType {
162 original_type_name: parse_quote! {#wrapper_name},
163 wrapper_name: wrapper_name.to_string(),
164 rust_type: RustWrapperType::Custom,
165 reference_parameters: None,
166 impl_traits: vec![],
167 }
168}
169
170pub fn single_field_variant_getter(variant: &Variant) -> Option<WrapperType> {
171 let field = &get_fields(variant).unwrap()[0];
172 match &field_type(field) {
173 FieldType::Type(field_type) => {
174 let return_type: Ident = syn::parse_str(field_type).unwrap();
175 Some(WrapperType {
176 original_type_name: parse_quote! {#return_type},
177 wrapper_name: return_type.to_string(),
178 rust_type: if is_primitive_field(field) {
179 RustWrapperType::Primitive
180 } else {
181 RustWrapperType::Custom
182 },
183 reference_parameters: None,
184 impl_traits: vec![],
185 })
186 }
187 FieldType::Ignored => None,
188 }
189}
190
191pub fn variant_getter_ident(v: &Variant) -> Ident {
192 Ident::new(
193 &format!("get_{}", v.ident.to_string().to_case(Case::Snake)),
194 Span::call_site(),
195 )
196}
197
/// Classification of an enum-variant field's type, as produced by `field_type`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FieldType {
    /// The field's type is a single identifier. Holds that identifier as text
    /// (e.g. `u32`, `MyStruct`).
    Type(String),
    /// The field's type was the `_` placeholder.
    Ignored,
}

impl FieldType {
    /// Returns the type name held by [`FieldType::Type`].
    ///
    /// # Panics
    /// Panics with "Invalid field type" if `self` is [`FieldType::Ignored`].
    /// `#[track_caller]` makes the panic report the caller's location, which is
    /// where the unexpected `_` field actually matters.
    #[track_caller]
    pub fn unwrap_type(self) -> String {
        match self {
            FieldType::Type(type_name) => type_name,
            // Matched explicitly (no `_` arm) so adding a variant forces a decision here.
            FieldType::Ignored => panic!("Invalid field type"),
        }
    }
}
212
213pub fn field_type(field: &Field) -> FieldType {
214 match &field.ty {
215 Type::Path(type_path) => FieldType::Type(
216 type_path
217 .path
218 .get_ident()
219 .unwrap_or_else(|| panic!("Invalid ident of an enum variant field {field:?}"))
220 .to_string(),
221 ),
222 Type::Infer(TypeInfer { .. }) => FieldType::Ignored,
223 _ => panic!("Invalid type of an enum variant field {field:?}"),
224 }
225}
226
227pub fn is_primitive_field(field: &Field) -> bool {
228 is_primitive(field_type(field).unwrap_type().as_str())
229}
230
231pub fn field_name(field: &Field) -> Option<String> {
232 field.ident.as_ref().map(|i| i.to_string())
233}
234
235pub fn is_many_fields_variant(v: &Variant) -> bool {
236 match &v.fields {
237 syn::Fields::Named(FieldsNamed { named: fields, .. })
238 | syn::Fields::Unnamed(FieldsUnnamed {
239 unnamed: fields, ..
240 }) => fields.len() > 1,
241 syn::Fields::Unit => false,
242 }
243}
244pub fn is_single_field_variant(v: &Variant) -> bool {
245 match &v.fields {
246 syn::Fields::Named(FieldsNamed { named: fields, .. })
247 | syn::Fields::Unnamed(FieldsUnnamed {
248 unnamed: fields, ..
249 }) => fields.len() == 1,
250 syn::Fields::Unit => false,
251 }
252}
253pub fn is_unit_variant(v: &Variant) -> bool {
254 matches!(&v.fields, syn::Fields::Unit)
255}
256
257pub fn is_ignored_variant(v: &Variant) -> bool {
258 match get_fields(v) {
259 Some(fields) if fields.len() == 1 => field_type(&fields[0]) == FieldType::Ignored,
260 _ => false,
261 }
262}
263
264pub fn get_fields(v: &Variant) -> Option<&Punctuated<Field, Comma>> {
265 match &v.fields {
266 syn::Fields::Named(FieldsNamed { named: fields, .. })
267 | syn::Fields::Unnamed(FieldsUnnamed {
268 unnamed: fields, ..
269 }) => Some(fields),
270 syn::Fields::Unit => None,
271 }
272}
273
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use syn::ItemMod;

    use super::*;
    use crate::utils::helpers;

    /// Only `enum` items are collected, in declaration order; the struct is skipped.
    #[test]
    fn test_getting_enums_from_module() {
        let rust_code = "
mod ffi {
    enum En1 {
        V1,
    }

    enum En2 {
        V1,
    }

    struct SomeStructToBeIgnored {
        field: String,
    }
}
    ";
        let module: ItemMod = syn::parse_str(rust_code).unwrap();
        let context = helpers::get_context();

        let found = get_enums_from_module(&module, &context).unwrap();
        let names: Vec<String> = found
            .iter()
            .map(|enum_item| enum_item.ident.to_string())
            .collect();

        assert_eq!(names, vec!["En1".to_owned(), "En2".to_owned()]);
    }
}