1use std::collections::HashSet;
2use std::fmt::Display;
3
4use super::{function_helper::*, templates::*, translate_c_enums};
8use crate::binding_types::{Exceptions, RustWrapperType, WrapperType};
9use crate::cpp::externs::create_extern_imports;
10use crate::cpp::generator::create_classes_forward_declarations;
11use crate::enum_helpers::{
12 create_field_getter_function,
13 create_variant_getter_function,
14 enum_tag_name,
15 get_fields,
16 is_many_fields_variant,
17 is_primitive_enum,
18 variant_wrapper_ident,
19};
20use crate::extern_module_translator::{
21 ExternFunction,
22 ExternModuleTranslator,
23 Function,
24 RustTrait,
25};
26use crate::swift::ffi_protocols::get_supported_ffi_protocols;
27use crate::EXPORTED_SYMBOLS_PREFIX;
28
/// Swift-side conformances derived from the Rust traits implemented by a custom type.
#[derive(Debug, Default)]
struct TranslatableTraits {
    /// Protocol names listed directly in the generated class header
    /// (produced from traits without methods).
    protocols: Vec<String>,
    /// `extension` declarations emitted together with the class
    /// (produced from traits that have methods).
    extensions: Vec<String>,
}
33
34fn create_class_methods_definitions(extern_module_translator: &ExternModuleTranslator) -> String {
35 extern_module_translator
36 .user_custom_types
37 .iter()
38 .map(|(wrapper_type, vec_of_functions)| {
39 let class_name = wrapper_type.wrapper_name.to_string();
40 let class_functions = vec_of_functions
41 .iter()
42 .map(|f| FunctionTranslator::from_class_method(f, &class_name))
43 .map(FunctionTranslator::generate_definition)
44 .collect::<String>();
45 let mut translatable_traits =
46 get_translatable_traits(class_name.to_owned(), &wrapper_type.impl_traits);
47 translatable_traits
48 .protocols
49 .insert(0, "Opaque".to_string()); custom_class_definition(
51 &class_name,
52 translatable_traits.protocols.join(", ").as_str(),
53 &class_functions,
54 translatable_traits.extensions.join("\n").as_str(),
55 )
56 })
57 .collect::<String>()
58}
59
60fn get_translatable_traits(class_name: String, traits: &Vec<RustTrait>) -> TranslatableTraits {
61 let mut protocols = HashSet::new();
62 let mut extensions = HashSet::new();
63 for trait_ in traits {
64 if !trait_.has_methods {
65 if let Some(swift_name) = translate_marker_name(trait_.name.to_owned()) {
67 protocols.insert(swift_name.to_string());
68 }
69 } else {
70 if let Some(swift_protocol_name) = translate_marker_name(trait_.name.to_owned()) {
72 if let Some(extension_declaration) =
73 extension_declaration_template(class_name.as_str(), swift_protocol_name)
74 {
75 extensions.insert(extension_declaration);
76 }
77 }
78 }
79 }
80
81 TranslatableTraits {
82 protocols: protocols.into_iter().collect(),
83 extensions: extensions.into_iter().collect(),
84 }
85}
86
/// Returns the Swift protocol that corresponds to a Rust trait name, or `None` when the trait
/// has no Swift counterpart.
///
/// Accepts any string-like value (`&str`, `String`, `&String`, ...), so callers do not need to
/// allocate just to perform the lookup.
fn translate_marker_name(rust_trait_name: impl AsRef<str>) -> Option<&'static str> {
    match rust_trait_name.as_ref() {
        // Both thread-safety markers collapse into a single Swift conformance.
        "Send" | "Sync" => Some("@unchecked Sendable"),
        "Eq" | "PartialEq" => Some("Equatable"),
        "Hash" => Some("Hashable"),
        "Ord" | "PartialOrd" => Some("Comparable"),
        "Read" => Some("FFIRead"),
        "Write" => Some("FFIWrite"),
        _ => None,
    }
}
98
99fn translate_type_names(mut fun: Function) -> Function {
100 if let Some(ret_type) = &mut fun.return_type {
101 if ret_type.wrapper_name.as_str() == "String" {
102 ret_type.wrapper_name = "RustString".to_string()
103 }
104 }
105 fun
106}
107
/// Generates Swift wrapper classes for the shared enums that are not primitive.
///
/// Enums for which `is_primitive_enum` returns true are skipped here. For every other enum this
/// emits:
///
/// 1. An `Opaque` class named after the enum. It has one getter per variant for which
///    `create_variant_getter_function` yields a function, plus a `getTag()` method. `getTag()`
///    returns the tag type named by `enum_tag_name` by loading it from `self._self`.
/// 2. Immediately after it, one `Opaque` class per variant accepted by `is_many_fields_variant`.
///    Each is named by `variant_wrapper_ident` and has one getter per field of that variant.
///
/// Getters whose return type is `String` are renamed to `RustString` via `translate_type_names`.
fn create_complex_enum_wrappers(extern_module_translator: &ExternModuleTranslator) -> String {
    extern_module_translator
        .shared_enums
        .iter()
        .filter(|e| !is_primitive_enum(e))
        .map(|enum_item| {
            let class_name = enum_item.ident.to_string();

            // Lazy iterator; it is consumed further down when the enum class body is assembled.
            let variant_getters = enum_item
                .variants
                .iter()
                .filter_map(|variant| create_variant_getter_function(enum_item, variant))
                .map(translate_type_names);

            // Separate wrapper classes for the variants accepted by `is_many_fields_variant`.
            let many_fields_variants_wrapper: String = enum_item
                .variants
                .iter()
                .filter(|v| is_many_fields_variant(v))
                .map(|variant| {
                    // Assumes `get_fields` returns `Some` for every variant that passes
                    // `is_many_fields_variant`; panics otherwise.
                    let fields = get_fields(variant).unwrap();
                    let variant_wrapper_name =
                        variant_wrapper_ident(&enum_item.ident, &variant.ident).to_string();
                    let variant_wrapper_getters = fields
                        .iter()
                        .enumerate()
                        .map(|(field_idx, field)| {
                            translate_type_names(create_field_getter_function(
                                enum_item, variant, field, field_idx,
                            ))
                        })
                        .map(|f| FunctionTranslator::from_class_method(&f, &variant_wrapper_name))
                        .map(FunctionTranslator::generate_definition)
                        .collect::<String>();
                    custom_class_definition(
                        variant_wrapper_name.as_str(),
                        "Opaque",
                        &variant_wrapper_getters,
                        "",
                    )
                })
                .collect();

            let class_functions = variant_getters
                .map(|f| FunctionTranslator::from_class_method(&f, &class_name))
                .map(FunctionTranslator::generate_definition)
                .collect::<String>();
            let enum_tag_name = enum_tag_name(&enum_item.ident);
            // Swift getter that reads the tag straight out of the wrapped pointer.
            // NOTE(review): relies on the tag being stored at offset 0 of the Rust enum — confirm.
            let tag_getter_fn = format!(
                " public func getTag() -> {enum_tag_name} {{
 return self._self.load(as: {enum_tag_name}.self)
 }}\n"
            );
            // The enum class comes first, followed by its many-fields variant wrappers.
            custom_class_definition(
                &class_name,
                "Opaque",
                &(class_functions + &tag_getter_fn),
                "",
            ) + &many_fields_variants_wrapper
        })
        .collect::<String>()
}
169
170fn create_protocols_declarations(extern_module_translator: &ExternModuleTranslator) -> String {
171 extern_module_translator
172 .user_traits
173 .iter()
174 .map(|(wrapper_type, vec_of_functions)| {
175 let class_name = wrapper_type.wrapper_name.to_string();
176 let functions_declaration: String = vec_of_functions
177 .iter()
178 .map(|f| FunctionHelperVirtual::from_virtual_function(f, &class_name))
179 .map(FunctionHelperVirtual::generate_virtual_declaration)
180 .collect();
181 protocol_declaration(&class_name, &functions_declaration)
182 })
183 .collect::<String>()
184}
185
186fn create_virtual_method_calls(extern_module_translator: &ExternModuleTranslator) -> String {
187 extern_module_translator
188 .user_traits
189 .iter()
190 .map(|(wrapper_type, vec_of_functions)| {
191 let class_name = wrapper_type.wrapper_name.to_string();
192 vec_of_functions
193 .iter()
194 .map(|f| FunctionHelperVirtual::from_virtual_function(f, &class_name))
195 .map(FunctionHelperVirtual::generate_virtual_definition)
196 .collect::<String>()
197 })
198 .collect::<String>()
199}
200
201fn create_rust_types_wrappers(extern_module_translator: &ExternModuleTranslator) -> String {
202 extern_module_translator
203 .rust_types_wrappers
204 .ordered_iter()
205 .filter_map(|wrapper| match wrapper {
206 WrapperType {
207 rust_type: RustWrapperType::Vector(inner_type),
208 ..
209 } => {
210 let inner_type_name = inner_type.get_name();
211 let is_generic = matches!(inner_type.rust_type, RustWrapperType::Option(_))
212 || matches!(inner_type.rust_type, RustWrapperType::Vector(_));
213 Some(vector_impl(
214 &inner_type_name,
215 &inner_type.wrapper_name,
216 is_generic,
217 matches!(inner_type.rust_type, RustWrapperType::Primitive),
218 ))
219 }
220 WrapperType {
221 rust_type: RustWrapperType::Option(inner_type),
222 ..
223 } => {
224 let inner_type_name = inner_type.get_name();
225 let is_generic = matches!(inner_type.rust_type, RustWrapperType::Option(_))
226 || matches!(inner_type.rust_type, RustWrapperType::Vector(_));
227 Some(option_class(
228 &inner_type_name,
229 &inner_type.wrapper_name,
230 is_generic,
231 ))
232 }
233 WrapperType {
234 rust_type: RustWrapperType::Exceptions(Exceptions::NonPrimitive(idents)),
235 wrapper_name,
236 ..
237 } => Some(
238 idents
239 .iter()
240 .map(|exception| {
241 create_non_primitive_exception_class(
242 &exception.to_string(),
243 wrapper_name,
244 extern_module_translator.exception_trait_methods.iter(),
245 )
246 })
247 .collect::<String>(),
248 ),
249 WrapperType {
250 rust_type: RustWrapperType::Exceptions(Exceptions::Primitive(idents)),
251 wrapper_name,
252 ..
253 } => Some(
254 idents
255 .iter()
256 .map(|exception| {
257 create_primitive_exception_class(
258 &exception.to_string(),
259 wrapper_name,
260 extern_module_translator.exception_trait_methods.iter(),
261 )
262 })
263 .collect::<String>(),
264 ),
265 _ => None,
266 })
267 .collect()
268}
269
270fn create_global_functions_definitions(
271 extern_module_translator: &ExternModuleTranslator,
272) -> String {
273 extern_module_translator
274 .global_functions
275 .iter()
276 .map(FunctionTranslator::from_global_function)
277 .map(FunctionTranslator::generate_definition)
278 .collect()
279}
280
281pub fn create_non_primitive_exception_class<'a>(
283 exception: &impl Display,
284 err_name: &impl Display,
285 custom_methods: impl Iterator<Item = &'a Function>,
286) -> String {
287 let custom_methods = create_exception_custom_methods(custom_methods, err_name, "err._self");
288 format_exception_class(exception, err_name, &custom_methods)
289}
290
291pub fn create_primitive_exception_class<'a>(
293 exception: &impl Display,
294 err_name: &impl Display,
295 custom_methods: impl Iterator<Item = &'a Function>,
296) -> String {
297 let custom_methods = create_exception_custom_methods(custom_methods, err_name, "&err");
298 format_exception_class(exception, err_name, &custom_methods)
299}
300
301fn format_exception_class(
302 exception: &impl Display,
303 err_name: &impl Display,
304 custom_methods: &impl Display,
305) -> String {
306 let exception_name = format!("{err_name}_{exception}Exception");
307 format!(
308 "
309public class {exception_name} : {RUST_EXCEPTION_BASE_CLASS_NAME} {{
310 private(set) var err: {err_name}
311 init(_ err: {err_name}) {{ self.err = err }}
312{custom_methods}
313}}
314"
315 )
316}
317
318fn create_enum_init_method(extern_module_translator: &ExternModuleTranslator) -> String {
319 extern_module_translator
320 .shared_enums
321 .iter()
322 .filter(|e| is_primitive_enum(e))
323 .map(|enum_class| {
324 let enum_name = &enum_class.ident;
325 format!(
326 "extension {enum_name} {{
327 init(_ enumObj: {enum_name}) {{
328 self = enumObj
329 }}
330}}\n"
331 )
332 })
333 .collect()
334}
335
336fn create_result_wrappers(extern_module_translator: &ExternModuleTranslator) -> String {
337 extern_module_translator
338 .rust_types_wrappers
339 .ordered_iter()
340 .filter_map(|wrapper| match wrapper {
341 WrapperType {
342 rust_type: RustWrapperType::Result(ok_type, exceptions_type),
343 ..
344 } => {
345 let ok_type = ok_type.get_name();
346 let error_enum_name = &exceptions_type.wrapper_name;
347 Some(result_class(
348 &wrapper.wrapper_name,
349 &ok_type,
350 error_enum_name,
351 ))
352 }
353 _ => None,
354 })
355 .collect()
356}
357
358fn create_exception_custom_methods<'a>(
359 custom_methods: impl Iterator<Item = &'a Function>,
360 err_name: &impl Display,
361 rust_obj_ptr: impl Display,
362) -> impl Display {
363 custom_methods
364 .map(|fun| {
365 let return_type = fun
366 .return_type
367 .as_ref()
368 .map(|wrapper| wrapper.wrapper_name.as_str())
369 .unwrap_or("");
370 let function_name = &fun.name;
371 let ffi_call = format!("{EXPORTED_SYMBOLS_PREFIX}${err_name}${function_name}");
372 let ffi_call = format!("{ffi_call}({rust_obj_ptr})");
373 let ffi_call = match &fun.return_type {
374 None
375 | Some(WrapperType {
376 rust_type: RustWrapperType::Primitive | RustWrapperType::FieldlessEnum,
377 ..
378 }) => ffi_call,
379 Some(WrapperType { wrapper_name, .. }) => {
380 format!("{wrapper_name}({ffi_call})")
381 }
382 };
383 format!(
384 " public func {function_name}() -> {return_type} {{
385 return {ffi_call}
386 }}"
387 )
388 })
389 .collect::<Vec<_>>()
390 .join("\n")
391}
392
393fn base_exception_method(function: &Function) -> String {
394 let return_type = &function
395 .return_type
396 .as_ref()
397 .map(|t| t.get_name())
398 .unwrap_or_else(|| "".to_string());
399 let name = &function.name;
400 format!(" func {name}() -> {return_type};")
401}
402
403fn base_exception_class(emt: &ExternModuleTranslator) -> String {
404 let exception_trait_methods = emt
405 .exception_trait_methods
406 .iter()
407 .map(base_exception_method)
408 .collect::<Vec<_>>()
409 .join("\n");
410 format!("public protocol {RUST_EXCEPTION_BASE_CLASS_NAME} : Error {{\n{exception_trait_methods}\n}}\n")
411}
412
/// Assembles the complete generated Swift source for the module.
///
/// The sections are emitted in this order, separated by newlines:
/// 1. `PREDEFINED`
/// 2. `init` extensions for primitive enums
/// 3. wrapper classes for non-primitive enums
/// 4. `Result` wrapper classes
/// 5. the base exception protocol
/// 6. `Vec`/`Option`/exception wrapper classes
/// 7. custom-type classes
/// 8. global functions
/// 9. virtual-method definitions for user traits
/// 10. protocol declarations for user traits
/// 11. the declarations returned by `get_supported_ffi_protocols`
pub fn generate_swift_file(extern_module_translator: &ExternModuleTranslator) -> String {
    let classes_definition = create_class_methods_definitions(extern_module_translator);
    let complex_enum_classes_definitions = create_complex_enum_wrappers(extern_module_translator);
    let protocols_declaration = create_protocols_declarations(extern_module_translator);
    let ffi_protocols_declaration = get_supported_ffi_protocols();
    let virtual_methods_calls = create_virtual_method_calls(extern_module_translator);
    let rust_types_wrappers = create_rust_types_wrappers(extern_module_translator);
    let global_functions_definition: String =
        create_global_functions_definitions(extern_module_translator);
    let base_exception_class = base_exception_class(extern_module_translator);
    let result_wrapper = create_result_wrappers(extern_module_translator);
    let enum_init_methods = create_enum_init_method(extern_module_translator);
    // The output order is fixed by this template, not by the order of the calls above.
    format!(
        "{PREDEFINED}
{enum_init_methods}
{complex_enum_classes_definitions}
{result_wrapper}
{base_exception_class}
{rust_types_wrappers}
{classes_definition}
{global_functions_definition}
{virtual_methods_calls}
{protocols_declaration}
{ffi_protocols_declaration}"
    )
}
442
/// Generates the C source that declares the exported FFI surface for the Swift side.
///
/// The output starts with `stdbool.h`/`stdint.h` and typedefs mapping Rust primitive names
/// (`u8`..`u64`, `i8`..`i64`, `f32`, `f64`, `isize`, `usize`) to C types. These are presumably
/// there so the generated declarations can keep Rust spellings — verify against
/// `create_extern_imports`. After that come the C enum definitions, the class forward
/// declarations, and the extern function declarations for `extern_functions`.
pub fn generate_c_externs_file(
    extern_module_translator: &ExternModuleTranslator,
    extern_functions: &[ExternFunction],
) -> String {
    let externs = create_extern_imports(extern_functions);
    let classes_forward_declarations =
        create_classes_forward_declarations(extern_module_translator);
    let enum_classes_definitions = translate_c_enums(extern_module_translator);
    format!(
        "#include <stdbool.h>
#include <stdint.h>

typedef uint8_t u8;
typedef uint16_t u16;
typedef uint32_t u32;
typedef uint64_t u64;

typedef int8_t i8;
typedef int16_t i16;
typedef int32_t i32;
typedef int64_t i64;

typedef float f32;
typedef double f64;

typedef intptr_t isize;
typedef uintptr_t usize;

{enum_classes_definitions}
{classes_forward_declarations}
{externs}
"
    )
}