Skip to main content

rusteron_code_gen/
parser.rs

1use crate::generator::{parse_custom_methods, CBinding, CWrapper, Method};
2use crate::{Arg, ArgProcessing, CHandler};
3use itertools::Itertools;
4use quote::ToTokens;
5use std::collections::{BTreeMap, BTreeSet};
6use std::fs;
7use std::path::PathBuf;
8use syn::{Attribute, Item, ItemForeignMod, ItemStruct, ItemType, Lit, Meta, MetaNameValue};
9
10pub fn parse_bindings(out: &PathBuf) -> CBinding {
11    let file_content = fs::read_to_string(out.clone()).expect("Unable to read file");
12    let syntax_tree = syn::parse_file(&file_content).expect("Unable to parse file");
13    let mut wrappers = BTreeMap::new();
14    let mut methods = Vec::new();
15    let mut handlers = Vec::new();
16
17    // Iterate through the items in the file
18    for item in syntax_tree.items {
19        match item {
20            Item::Struct(s) => {
21                process_struct(&mut wrappers, &s);
22            }
23            Item::Type(ty) => {
24                process_type(&mut wrappers, &mut handlers, &ty);
25            }
26            Item::ForeignMod(fm) => {
27                process_c_method(&mut wrappers, &mut methods, fm);
28            }
29            _ => {}
30        }
31    }
32
33    /*    // need to filter out args which don't match
34        for wrapper in wrappers.values_mut() {
35          for method in wrapper.methods.iter_mut() {
36              let method_debug = format!("{:?}", method);
37              for arg in method.arguments.iter_mut() {
38                if let ArgProcessing::Handler(args) = &arg.processing {
39                    let handler = args.get(0).unwrap();
40                    if !handlers.iter().any(|h| h.type_name == handler.c_type) {
41                      log::info!("replacing {} back to default", method_debug);
42                      // arg.processing = ArgProcessing::Default;
43                    }
44                }
45              }
46          }
47        }
48    */
49    let mut bindings = CBinding {
50        wrappers: wrappers
51            .into_iter()
52            .filter(|(_, wrapper)| {
53                // these are from media driver and do not follow convention
54                ![
55                    "aeron_thread",
56                    "aeron_command",
57                    "aeron_executor",
58                    "aeron_name_resolver",
59                    "aeron_udp_channel_transport", // this one I have issues with handlers
60                    "aeron_udp_transport",         // this one I have issues with handlers
61                ]
62                .iter()
63                .any(|&filter| wrapper.type_name.starts_with(filter))
64            })
65            .collect(),
66        methods,
67        handlers: handlers
68            .into_iter()
69            .filter(|h| {
70                !["aeron_udp_channel", "aeron_udp_transport"]
71                    .iter()
72                    .any(|&filter| h.type_name.starts_with(filter))
73            })
74            .collect(),
75    };
76
77    let mismatched_types = bindings
78        .wrappers
79        .iter()
80        .filter(|(key, w)| key.as_str() != w.type_name)
81        .map(|(a, b)| (a.clone(), b.clone()))
82        .collect_vec();
83    assert_eq!(Vec::<(String, CWrapper)>::new(), mismatched_types);
84
85    let custom = parse_custom_methods(crate::CUSTOM_AERON_CODE);
86    for wrapper in bindings.wrappers.values_mut() {
87        if let Some(methods) = custom.get(&wrapper.class_name) {
88            wrapper.skipped_methods = methods.clone();
89        }
90    }
91
92    bindings
93}
94
95fn process_c_method(
96    wrappers: &mut BTreeMap<String, CWrapper>,
97    methods: &mut Vec<Method>,
98    fm: ItemForeignMod,
99) {
100    // Extract functions inside extern "C" blocks
101    if fm.abi.name.is_some() && fm.abi.name.as_ref().unwrap().value() == "C" {
102        for foreign_item in fm.items {
103            if let syn::ForeignItem::Fn(f) = foreign_item {
104                let docs = get_doc_comments(&f.attrs);
105                let fn_name = f.sig.ident.to_string();
106
107                // Get function arguments and return type as Rust code
108                let args = extract_function_arguments(&f.sig.inputs);
109                let ret = extract_return_type(&f.sig.output);
110
111                let option = if let Some(arg) = args
112                    .iter()
113                    .skip_while(|a| a.is_mut_pointer() && a.is_primitive())
114                    .next()
115                {
116                    let ty = &arg.c_type;
117                    let ty = ty.split(' ').last().map(|t| t.to_string()).unwrap();
118                    if wrappers.contains_key(&ty) {
119                        Some(ty)
120                    } else {
121                        find_closest_wrapper_from_method_name(wrappers, &fn_name)
122                    }
123                } else {
124                    find_closest_wrapper_from_method_name(wrappers, &fn_name)
125                };
126
127                match option {
128                    Some(key) => {
129                        let wrapper = wrappers.get_mut(&key).unwrap();
130                        wrapper.methods.push(Method {
131                            fn_name: fn_name.clone(),
132                            struct_method_name: fn_name
133                                .replace(&wrapper.type_name[..wrapper.type_name.len() - 1], "")
134                                .to_string(),
135                            return_type: Arg {
136                                name: "".to_string(),
137                                c_type: ret.clone(),
138                                processing: ArgProcessing::Default,
139                            },
140                            arguments: process_types(args.clone()),
141                            docs: docs.clone(),
142                        });
143                    }
144                    None => methods.push(Method {
145                        fn_name: fn_name.clone(),
146                        struct_method_name: "".to_string(),
147                        return_type: Arg {
148                            name: "".to_string(),
149                            c_type: ret.clone(),
150                            processing: ArgProcessing::Default,
151                        },
152                        arguments: process_types(args.clone()),
153                        docs: docs.clone(),
154                    }),
155                }
156            }
157        }
158    }
159}
160
161fn find_closest_wrapper_from_method_name(
162    wrappers: &mut BTreeMap<String, CWrapper>,
163    fn_name: &String,
164) -> Option<String> {
165    let type_names = get_possible_wrappers(&fn_name);
166
167    let mut value = None;
168    for ty in type_names {
169        if wrappers.contains_key(&ty) {
170            value = Some(ty);
171            break;
172        }
173    }
174
175    value
176}
177
/// Lists the wrapper type names a C function name could belong to.
///
/// For every underscore in `fn_name`, the text before that underscore is suffixed with `_t`.
/// Candidates are ordered from the longest prefix to the shortest, so
/// `"aeron_image_poll"` yields `["aeron_image_t", "aeron_t"]`. A name without underscores
/// yields an empty list.
pub fn get_possible_wrappers(fn_name: &str) -> Vec<String> {
    // Walk the underscores from the right so longer prefixes come first.
    fn_name
        .rmatch_indices('_')
        .map(|(underscore_at, _)| format!("{}_t", &fn_name[..underscore_at]))
        .collect()
}
186
187fn process_type(
188    wrappers: &mut BTreeMap<String, CWrapper>,
189    handlers: &mut Vec<CHandler>,
190    ty: &ItemType,
191) {
192    // Handle type definitions and get docs
193    let docs = get_doc_comments(&ty.attrs);
194
195    let type_name = ty.ident.to_string();
196    let class_name = snake_to_pascal_case(&type_name);
197
198    if ty.to_token_stream().to_string().contains("_stct") {
199        wrappers
200            .entry(type_name.clone())
201            .or_insert(CWrapper {
202                class_name,
203                without_name: type_name[..type_name.len() - 2].to_string(),
204                type_name,
205                ..Default::default()
206            })
207            .docs
208            .extend(docs);
209    } else {
210        // Parse the function pointer type -> it is typically used for handlers/callbacks
211        if let syn::Type::Path(type_path) = &*ty.ty {
212            if let Some(segment) = type_path.path.segments.last() {
213                if segment.ident.to_string() == "Option" {
214                    if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
215                        if let Some(syn::GenericArgument::Type(syn::Type::BareFn(bare_fn))) =
216                            args.args.first()
217                        {
218                            let args: Vec<Arg> = bare_fn
219                                .inputs
220                                .iter()
221                                .map(|arg| {
222                                    let arg_name = match &arg.name {
223                                        Some((ident, _)) => ident.to_string(),
224                                        None => "".to_string(),
225                                    };
226                                    let arg_type = arg.ty.to_token_stream().to_string();
227                                    (arg_name, arg_type)
228                                })
229                                .map(|(field_name, field_type)| Arg {
230                                    name: field_name,
231                                    c_type: field_type,
232                                    processing: ArgProcessing::Default,
233                                })
234                                .collect();
235                            let string = bare_fn.output.to_token_stream().to_string();
236                            let mut return_type = string.trim();
237
238                            if return_type.starts_with("-> ") {
239                                return_type = &return_type[3..];
240                            }
241
242                            if return_type.is_empty() {
243                                return_type = "()";
244                            }
245
246                            if args.iter().filter(|a| a.is_c_void()).count() == 1 {
247                                let value = CHandler {
248                                    type_name: ty.ident.to_string(),
249                                    args: process_types(args),
250                                    return_type: Arg {
251                                        name: "".to_string(),
252                                        c_type: return_type.to_string(),
253                                        processing: ArgProcessing::Default,
254                                    },
255                                    docs: docs.clone(),
256                                    fn_mut_signature: Default::default(),
257                                    closure_type_name: Default::default(),
258                                };
259                                handlers.push(value);
260                            }
261                        }
262                    }
263                }
264            }
265        }
266    }
267}
268
269fn process_struct(wrappers: &mut BTreeMap<String, CWrapper>, s: &ItemStruct) {
270    // Print the struct name and its doc comments
271    let docs = get_doc_comments(&s.attrs);
272    let type_name = s.ident.to_string().replace("_stct", "_t");
273    let class_name = snake_to_pascal_case(&type_name);
274
275    let fields: Vec<Arg> = s
276        .fields
277        .iter()
278        .map(|f| {
279            let field_name = f.ident.as_ref().unwrap().to_string();
280            let field_type = f.ty.to_token_stream().to_string();
281            (field_name, field_type)
282        })
283        .map(|(field_name, field_type)| Arg {
284            name: field_name,
285            c_type: field_type,
286            processing: ArgProcessing::Default,
287        })
288        .collect();
289
290    let w = wrappers.entry(type_name.to_string()).or_insert(CWrapper {
291        class_name,
292        without_name: type_name[..type_name.len() - 2].to_string(),
293        type_name,
294        ..Default::default()
295    });
296    w.docs.extend(docs);
297    w.fields = process_types(fields);
298}
299
300fn process_types(mut name_and_type: Vec<Arg>) -> Vec<Arg> {
301    // now mark arguments which can be reduced
302    for i in 1..name_and_type.len() {
303        let param1 = &name_and_type[i - 1];
304        let param2 = &name_and_type[i];
305
306        let is_int = param2.c_type == "usize" || param2.c_type == "i32";
307        let length_field = param2.name == "length"
308            || param2.name == "len"
309            || (param2.name.ends_with("_length") && param2.name.starts_with(&param1.name));
310        if param2.is_c_void() && !param1.is_mut_pointer() && param1.c_type.ends_with("_t") {
311            // closures
312            //         handler: aeron_on_available_counter_t,
313            //         clientd: *mut ::std::os::raw::c_void,
314            let processing = ArgProcessing::Handler(vec![param1.clone(), param2.clone()]);
315            name_and_type[i - 1].processing = processing.clone();
316            name_and_type[i].processing = processing.clone();
317        } else if param1.is_c_string_any() && !param1.is_mut_pointer() && is_int && length_field {
318            //     pub stripped_channel: *mut ::std::os::raw::c_char,
319            //     pub stripped_channel_length: usize,
320            let processing = ArgProcessing::StringWithLength(vec![param1.clone(), param2.clone()]);
321            name_and_type[i - 1].processing = processing.clone();
322            name_and_type[i].processing = processing.clone();
323        } else if param1.is_byte_array()
324            // && !param1.is_mut_pointer()
325            && is_int
326            && length_field
327        {
328            //         key_buffer: *const u8,
329            //         key_buffer_length: usize,
330            let processing =
331                ArgProcessing::ByteArrayWithLength(vec![param1.clone(), param2.clone()]);
332            name_and_type[i - 1].processing = processing.clone();
333            name_and_type[i].processing = processing.clone();
334        }
335
336        //
337    }
338
339    name_and_type
340}
341
342// Helper function to extract doc comments
343fn get_doc_comments(attrs: &[Attribute]) -> BTreeSet<String> {
344    attrs
345        .iter()
346        .filter_map(|attr| {
347            // Parse the attribute meta to check if it is a `Meta::NameValue`
348            if let Meta::NameValue(MetaNameValue {
349                path,
350                value: syn::Expr::Lit(expr_lit),
351                ..
352            }) = &attr.meta
353            {
354                // Check if the path is "doc"
355                if path.is_ident("doc") {
356                    // Check if the literal is a string and return its value
357                    if let Lit::Str(lit_str) = &expr_lit.lit {
358                        return Some(lit_str.value().trim().to_string());
359                    }
360                }
361            }
362            None
363        })
364        .collect()
365}
366
/// Converts a snake_case C type name into a PascalCase class name.
///
/// A trailing `_t` is removed first. The remainder is split on underscores, words equal to
/// `on` are dropped, and the first character of every remaining word is upper-cased, e.g.
/// `aeron_on_available_counter_t` becomes `AeronAvailableCounter`.
pub fn snake_to_pascal_case(snake: &str) -> String {
    let stem = snake.strip_suffix("_t").unwrap_or(snake);

    let mut pascal = String::with_capacity(stem.len());
    for word in stem.split('_') {
        if word == "on" {
            continue;
        }
        let mut letters = word.chars();
        // Empty words (from consecutive underscores) contribute nothing.
        if let Some(initial) = letters.next() {
            pascal.extend(initial.to_uppercase());
            pascal.push_str(letters.as_str());
        }
    }
    pascal
}
384
385// Helper function to extract function arguments as Rust code
386fn extract_function_arguments(
387    inputs: &syn::punctuated::Punctuated<syn::FnArg, syn::token::Comma>,
388) -> Vec<Arg> {
389    inputs
390        .iter()
391        .map(|arg| match arg {
392            syn::FnArg::Receiver(_) => "self".to_string(), // Handle self receiver
393            syn::FnArg::Typed(pat_type) => pat_type.to_token_stream().to_string(), // Convert the pattern and type to Rust code
394        })
395        .map(|arg| {
396            arg.splitn(2, ':')
397                .map(|s| s.trim().to_string())
398                .collect_tuple()
399                .unwrap()
400        })
401        .map(|(name, ty)| Arg {
402            name,
403            c_type: ty,
404            processing: ArgProcessing::Default,
405        })
406        .collect_vec()
407}
408
409// Helper function to extract return type as Rust code
410fn extract_return_type(output: &syn::ReturnType) -> String {
411    match output {
412        syn::ReturnType::Default => "()".to_string(), // No return type, equivalent to ()
413        syn::ReturnType::Type(_, ty) => ty.to_token_stream().to_string(), // Convert the type to Rust code
414    }
415}
416
#[cfg(test)]
mod tests {
    use crate::parser::parse_bindings;
    use std::path::PathBuf;

    /// Class name expected for the `aeron_image_fragment_assembler_t` wrapper.
    const FRAGMENT_ASSEMBLER_CLASS: &str = "AeronImageFragmentAssembler";

    /// True when the `RUSTERON_VALGRIND` environment variable is set; the tests below
    /// return early in that case.
    fn running_under_valgrind() -> bool {
        std::env::var_os("RUSTERON_VALGRIND").is_some()
    }

    /// Location of a bindings file inside this crate's `bindings` directory.
    fn bindings_file(file_name: &str) -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("bindings")
            .join(file_name)
    }

    #[test]
    fn media_driver() {
        if running_under_valgrind() {
            return;
        }

        let bindings = parse_bindings(&bindings_file("media-driver.rs"));
        let wrapper = bindings
            .wrappers
            .get("aeron_image_fragment_assembler_t")
            .unwrap();
        assert_eq!(FRAGMENT_ASSEMBLER_CLASS, wrapper.class_name);
    }

    #[test]
    fn client() {
        if running_under_valgrind() {
            return;
        }

        let bindings = parse_bindings(&bindings_file("client.rs"));
        let wrapper = bindings
            .wrappers
            .get("aeron_image_fragment_assembler_t")
            .unwrap();
        assert_eq!(FRAGMENT_ASSEMBLER_CLASS, wrapper.class_name);
        assert!(bindings.handlers.len() > 1);
    }
}