rusty_bind_parser/swift/
generator.rs

1use std::collections::HashSet;
2use std::fmt::Display;
3
// This module contains the functions and structures that generate the Swift bindings
// (and the companion C header of extern declarations) for an `ExternModuleTranslator`.
//
7use super::{function_helper::*, templates::*, translate_c_enums};
8use crate::binding_types::{Exceptions, RustWrapperType, WrapperType};
9use crate::cpp::externs::create_extern_imports;
10use crate::cpp::generator::create_classes_forward_declarations;
11use crate::enum_helpers::{
12    create_field_getter_function,
13    create_variant_getter_function,
14    enum_tag_name,
15    get_fields,
16    is_many_fields_variant,
17    is_primitive_enum,
18    variant_wrapper_ident,
19};
20use crate::extern_module_translator::{
21    ExternFunction,
22    ExternModuleTranslator,
23    Function,
24    RustTrait,
25};
26use crate::swift::ffi_protocols::get_supported_ffi_protocols;
27use crate::EXPORTED_SYMBOLS_PREFIX;
28
/// Swift-side conformances derived from the Rust traits implemented by an exported type.
struct TranslatableTraits {
    /// Swift protocol names to list in the generated class's conformance clause.
    protocols: Vec<String>,
    /// Extension declarations (as produced by `extension_declaration_template`) to emit
    /// alongside the class.
    extensions: Vec<String>,
}
33
34fn create_class_methods_definitions(extern_module_translator: &ExternModuleTranslator) -> String {
35    extern_module_translator
36        .user_custom_types
37        .iter()
38        .map(|(wrapper_type, vec_of_functions)| {
39            let class_name = wrapper_type.wrapper_name.to_string();
40            let class_functions = vec_of_functions
41                .iter()
42                .map(|f| FunctionTranslator::from_class_method(f, &class_name))
43                .map(FunctionTranslator::generate_definition)
44                .collect::<String>();
45            let mut translatable_traits =
46                get_translatable_traits(class_name.to_owned(), &wrapper_type.impl_traits);
47            translatable_traits
48                .protocols
49                .insert(0, "Opaque".to_string()); // All types should be opaque
50            custom_class_definition(
51                &class_name,
52                translatable_traits.protocols.join(", ").as_str(),
53                &class_functions,
54                translatable_traits.extensions.join("\n").as_str(),
55            )
56        })
57        .collect::<String>()
58}
59
60fn get_translatable_traits(class_name: String, traits: &Vec<RustTrait>) -> TranslatableTraits {
61    let mut protocols = HashSet::new();
62    let mut extensions = HashSet::new();
63    for trait_ in traits {
64        if !trait_.has_methods {
65            // marker protocol
66            if let Some(swift_name) = translate_marker_name(trait_.name.to_owned()) {
67                protocols.insert(swift_name.to_string());
68            }
69        } else {
70            // extension
71            if let Some(swift_protocol_name) = translate_marker_name(trait_.name.to_owned()) {
72                if let Some(extension_declaration) =
73                    extension_declaration_template(class_name.as_str(), swift_protocol_name)
74                {
75                    extensions.insert(extension_declaration);
76                }
77            }
78        }
79    }
80
81    TranslatableTraits {
82        protocols: protocols.into_iter().collect(),
83        extensions: extensions.into_iter().collect(),
84    }
85}
86
/// Returns the Swift protocol corresponding to the Rust trait named `rust_trait_name`, or
/// `None` when the trait has no Swift counterpart.
///
/// Several Rust traits collapse onto one protocol: `Send`/`Sync`, `Eq`/`PartialEq` and
/// `Ord`/`PartialOrd`.
fn translate_marker_name(rust_trait_name: String) -> Option<&'static str> {
    const TRAIT_TO_PROTOCOL: [(&str, &str); 9] = [
        ("Send", "@unchecked Sendable"),
        ("Sync", "@unchecked Sendable"),
        ("Eq", "Equatable"),
        ("PartialEq", "Equatable"),
        ("Hash", "Hashable"),
        ("Ord", "Comparable"),
        ("PartialOrd", "Comparable"),
        ("Read", "FFIRead"),
        ("Write", "FFIWrite"),
    ];
    let wanted = rust_trait_name.as_str();
    TRAIT_TO_PROTOCOL
        .iter()
        .find(|(rust_name, _)| *rust_name == wanted)
        .map(|&(_, swift_name)| swift_name)
}
98
99fn translate_type_names(mut fun: Function) -> Function {
100    if let Some(ret_type) = &mut fun.return_type {
101        if ret_type.wrapper_name.as_str() == "String" {
102            ret_type.wrapper_name = "RustString".to_string()
103        }
104    }
105    fun
106}
107
108fn create_complex_enum_wrappers(extern_module_translator: &ExternModuleTranslator) -> String {
109    extern_module_translator
110        .shared_enums
111        .iter()
112        .filter(|e| !is_primitive_enum(e))
113        .map(|enum_item| {
114            let class_name = enum_item.ident.to_string();
115
116            let variant_getters = enum_item
117                .variants
118                .iter()
119                .filter_map(|variant| create_variant_getter_function(enum_item, variant))
120                .map(translate_type_names);
121
122            let many_fields_variants_wrapper: String = enum_item
123                .variants
124                .iter()
125                .filter(|v| is_many_fields_variant(v))
126                .map(|variant| {
127                    let fields = get_fields(variant).unwrap();
128                    let variant_wrapper_name =
129                        variant_wrapper_ident(&enum_item.ident, &variant.ident).to_string();
130                    let variant_wrapper_getters = fields
131                        .iter()
132                        .enumerate()
133                        .map(|(field_idx, field)| {
134                            translate_type_names(create_field_getter_function(
135                                enum_item, variant, field, field_idx,
136                            ))
137                        })
138                        .map(|f| FunctionTranslator::from_class_method(&f, &variant_wrapper_name))
139                        .map(FunctionTranslator::generate_definition)
140                        .collect::<String>();
141                    custom_class_definition(
142                        variant_wrapper_name.as_str(),
143                        "Opaque",
144                        &variant_wrapper_getters,
145                        "",
146                    )
147                })
148                .collect();
149
150            let class_functions = variant_getters
151                .map(|f| FunctionTranslator::from_class_method(&f, &class_name))
152                .map(FunctionTranslator::generate_definition)
153                .collect::<String>();
154            let enum_tag_name = enum_tag_name(&enum_item.ident);
155            let tag_getter_fn = format!(
156                "    public func getTag() -> {enum_tag_name} {{
157        return self._self.load(as: {enum_tag_name}.self)
158    }}\n"
159            );
160            custom_class_definition(
161                &class_name,
162                "Opaque",
163                &(class_functions + &tag_getter_fn),
164                "",
165            ) + &many_fields_variants_wrapper
166        })
167        .collect::<String>()
168}
169
170fn create_protocols_declarations(extern_module_translator: &ExternModuleTranslator) -> String {
171    extern_module_translator
172        .user_traits
173        .iter()
174        .map(|(wrapper_type, vec_of_functions)| {
175            let class_name = wrapper_type.wrapper_name.to_string();
176            let functions_declaration: String = vec_of_functions
177                .iter()
178                .map(|f| FunctionHelperVirtual::from_virtual_function(f, &class_name))
179                .map(FunctionHelperVirtual::generate_virtual_declaration)
180                .collect();
181            protocol_declaration(&class_name, &functions_declaration)
182        })
183        .collect::<String>()
184}
185
186fn create_virtual_method_calls(extern_module_translator: &ExternModuleTranslator) -> String {
187    extern_module_translator
188        .user_traits
189        .iter()
190        .map(|(wrapper_type, vec_of_functions)| {
191            let class_name = wrapper_type.wrapper_name.to_string();
192            vec_of_functions
193                .iter()
194                .map(|f| FunctionHelperVirtual::from_virtual_function(f, &class_name))
195                .map(FunctionHelperVirtual::generate_virtual_definition)
196                .collect::<String>()
197        })
198        .collect::<String>()
199}
200
201fn create_rust_types_wrappers(extern_module_translator: &ExternModuleTranslator) -> String {
202    extern_module_translator
203        .rust_types_wrappers
204        .ordered_iter()
205        .filter_map(|wrapper| match wrapper {
206            WrapperType {
207                rust_type: RustWrapperType::Vector(inner_type),
208                ..
209            } => {
210                let inner_type_name = inner_type.get_name();
211                let is_generic = matches!(inner_type.rust_type, RustWrapperType::Option(_))
212                    || matches!(inner_type.rust_type, RustWrapperType::Vector(_));
213                Some(vector_impl(
214                    &inner_type_name,
215                    &inner_type.wrapper_name,
216                    is_generic,
217                    matches!(inner_type.rust_type, RustWrapperType::Primitive),
218                ))
219            }
220            WrapperType {
221                rust_type: RustWrapperType::Option(inner_type),
222                ..
223            } => {
224                let inner_type_name = inner_type.get_name();
225                let is_generic = matches!(inner_type.rust_type, RustWrapperType::Option(_))
226                    || matches!(inner_type.rust_type, RustWrapperType::Vector(_));
227                Some(option_class(
228                    &inner_type_name,
229                    &inner_type.wrapper_name,
230                    is_generic,
231                ))
232            }
233            WrapperType {
234                rust_type: RustWrapperType::Exceptions(Exceptions::NonPrimitive(idents)),
235                wrapper_name,
236                ..
237            } => Some(
238                idents
239                    .iter()
240                    .map(|exception| {
241                        create_non_primitive_exception_class(
242                            &exception.to_string(),
243                            wrapper_name,
244                            extern_module_translator.exception_trait_methods.iter(),
245                        )
246                    })
247                    .collect::<String>(),
248            ),
249            WrapperType {
250                rust_type: RustWrapperType::Exceptions(Exceptions::Primitive(idents)),
251                wrapper_name,
252                ..
253            } => Some(
254                idents
255                    .iter()
256                    .map(|exception| {
257                        create_primitive_exception_class(
258                            &exception.to_string(),
259                            wrapper_name,
260                            extern_module_translator.exception_trait_methods.iter(),
261                        )
262                    })
263                    .collect::<String>(),
264            ),
265            _ => None,
266        })
267        .collect()
268}
269
270fn create_global_functions_definitions(
271    extern_module_translator: &ExternModuleTranslator,
272) -> String {
273    extern_module_translator
274        .global_functions
275        .iter()
276        .map(FunctionTranslator::from_global_function)
277        .map(FunctionTranslator::generate_definition)
278        .collect()
279}
280
281/// Creates exception class for an error variant that may be returned from rust
282pub fn create_non_primitive_exception_class<'a>(
283    exception: &impl Display,
284    err_name: &impl Display,
285    custom_methods: impl Iterator<Item = &'a Function>,
286) -> String {
287    let custom_methods = create_exception_custom_methods(custom_methods, err_name, "err._self");
288    format_exception_class(exception, err_name, &custom_methods)
289}
290
291/// Creates exception class for an error variant of primitive enum that may be returned from rust
292pub fn create_primitive_exception_class<'a>(
293    exception: &impl Display,
294    err_name: &impl Display,
295    custom_methods: impl Iterator<Item = &'a Function>,
296) -> String {
297    let custom_methods = create_exception_custom_methods(custom_methods, err_name, "&err");
298    format_exception_class(exception, err_name, &custom_methods)
299}
300
301fn format_exception_class(
302    exception: &impl Display,
303    err_name: &impl Display,
304    custom_methods: &impl Display,
305) -> String {
306    let exception_name = format!("{err_name}_{exception}Exception");
307    format!(
308        "
309public class {exception_name} : {RUST_EXCEPTION_BASE_CLASS_NAME} {{
310    private(set) var err: {err_name}
311    init(_ err: {err_name}) {{ self.err = err }}
312{custom_methods}
313}}
314"
315    )
316}
317
318fn create_enum_init_method(extern_module_translator: &ExternModuleTranslator) -> String {
319    extern_module_translator
320        .shared_enums
321        .iter()
322        .filter(|e| is_primitive_enum(e))
323        .map(|enum_class| {
324            let enum_name = &enum_class.ident;
325            format!(
326                "extension {enum_name} {{
327    init(_ enumObj: {enum_name}) {{
328        self = enumObj
329    }}
330}}\n"
331            )
332        })
333        .collect()
334}
335
336fn create_result_wrappers(extern_module_translator: &ExternModuleTranslator) -> String {
337    extern_module_translator
338        .rust_types_wrappers
339        .ordered_iter()
340        .filter_map(|wrapper| match wrapper {
341            WrapperType {
342                rust_type: RustWrapperType::Result(ok_type, exceptions_type),
343                ..
344            } => {
345                let ok_type = ok_type.get_name();
346                let error_enum_name = &exceptions_type.wrapper_name;
347                Some(result_class(
348                    &wrapper.wrapper_name,
349                    &ok_type,
350                    error_enum_name,
351                ))
352            }
353            _ => None,
354        })
355        .collect()
356}
357
358fn create_exception_custom_methods<'a>(
359    custom_methods: impl Iterator<Item = &'a Function>,
360    err_name: &impl Display,
361    rust_obj_ptr: impl Display,
362) -> impl Display {
363    custom_methods
364        .map(|fun| {
365            let return_type = fun
366                .return_type
367                .as_ref()
368                .map(|wrapper| wrapper.wrapper_name.as_str())
369                .unwrap_or("");
370            let function_name = &fun.name;
371            let ffi_call = format!("{EXPORTED_SYMBOLS_PREFIX}${err_name}${function_name}");
372            let ffi_call = format!("{ffi_call}({rust_obj_ptr})");
373            let ffi_call = match &fun.return_type {
374                None
375                | Some(WrapperType {
376                    rust_type: RustWrapperType::Primitive | RustWrapperType::FieldlessEnum,
377                    ..
378                }) => ffi_call,
379                Some(WrapperType { wrapper_name, .. }) => {
380                    format!("{wrapper_name}({ffi_call})")
381                }
382            };
383            format!(
384                "        public func {function_name}() -> {return_type} {{
385            return {ffi_call}
386        }}"
387            )
388        })
389        .collect::<Vec<_>>()
390        .join("\n")
391}
392
393fn base_exception_method(function: &Function) -> String {
394    let return_type = &function
395        .return_type
396        .as_ref()
397        .map(|t| t.get_name())
398        .unwrap_or_else(|| "".to_string());
399    let name = &function.name;
400    format!("    func {name}() -> {return_type};")
401}
402
403fn base_exception_class(emt: &ExternModuleTranslator) -> String {
404    let exception_trait_methods = emt
405        .exception_trait_methods
406        .iter()
407        .map(base_exception_method)
408        .collect::<Vec<_>>()
409        .join("\n");
410    format!("public protocol {RUST_EXCEPTION_BASE_CLASS_NAME} : Error {{\n{exception_trait_methods}\n}}\n")
411}
412
413/// Function generates a C header that can be used as an Objective-C bridging
414/// layer to the compiled Rust static library.
415///
416pub fn generate_swift_file(extern_module_translator: &ExternModuleTranslator) -> String {
417    let classes_definition = create_class_methods_definitions(extern_module_translator);
418    let complex_enum_classes_definitions = create_complex_enum_wrappers(extern_module_translator);
419    let protocols_declaration = create_protocols_declarations(extern_module_translator);
420    let ffi_protocols_declaration = get_supported_ffi_protocols();
421    let virtual_methods_calls = create_virtual_method_calls(extern_module_translator);
422    let rust_types_wrappers = create_rust_types_wrappers(extern_module_translator);
423    let global_functions_definition: String =
424        create_global_functions_definitions(extern_module_translator);
425    let base_exception_class = base_exception_class(extern_module_translator);
426    let result_wrapper = create_result_wrappers(extern_module_translator);
427    let enum_init_methods = create_enum_init_method(extern_module_translator);
428    format!(
429        "{PREDEFINED}
430{enum_init_methods}
431{complex_enum_classes_definitions}
432{result_wrapper}
433{base_exception_class}
434{rust_types_wrappers}
435{classes_definition}
436{global_functions_definition}
437{virtual_methods_calls}
438{protocols_declaration}
439{ffi_protocols_declaration}"
440    )
441}
442
443/// Extern functions can be saved in another header file. Particularly
444/// useful while importing C methods in Swift.
445///
446pub fn generate_c_externs_file(
447    extern_module_translator: &ExternModuleTranslator,
448    extern_functions: &[ExternFunction],
449) -> String {
450    let externs = create_extern_imports(extern_functions);
451    let classes_forward_declarations =
452        create_classes_forward_declarations(extern_module_translator);
453    let enum_classes_definitions = translate_c_enums(extern_module_translator);
454    format!(
455        "#include <stdbool.h>
456#include <stdint.h>
457
458typedef uint8_t u8;
459typedef uint16_t u16;
460typedef uint32_t u32;
461typedef uint64_t u64;
462
463typedef int8_t i8;
464typedef int16_t i16;
465typedef int32_t i32;
466typedef int64_t i64;
467
468typedef float f32;
469typedef double f64;
470
471typedef intptr_t isize;
472typedef uintptr_t usize;
473
474{enum_classes_definitions}
475{classes_forward_declarations}
476{externs}
477"
478    )
479}