1use std::fmt::Display;
19
20use convert_case::{Case, Casing};
21use syn::punctuated::Punctuated;
22use syn::token::Comma;
23use syn::{Field, ItemEnum, Variant};
24
25use super::templates::TargetLanguageTypeName;
26use crate::binding_types::RustWrapperType;
27use crate::cpp::FunctionVirtualTranslator;
28use crate::enum_helpers::{
29 enum_tag_name,
30 field_name,
31 get_fields,
32 is_ignored_variant,
33 is_many_fields_variant,
34 is_primitive_enum,
35 variant_wrapper_ident,
36};
37use crate::extern_module_translator::ExternModuleTranslator;
38
/// Name given to the exception class in its embind registration: it is passed as the
/// `target_name` of `class_template`, i.e. the string in `emscripten::class_<...>("...")`.
const WASM_EXCEPTION_CLASS_NAME: &str = "RustException";
/// C++ type name of the exception class in the generated glue code: used as the
/// `class_<...>` template argument and as the qualifier of its member-function pointers.
const CPP_WASM_EXCEPTION_CLASS_NAME: &str = "WasmException";
41
42pub fn create_wasm_module(extern_module_translator: &ExternModuleTranslator) -> String {
43 let enum_classes = translate_enums(extern_module_translator);
44 let user_classes = create_user_classes(extern_module_translator);
45 let rust_wrappers = create_rust_wrappers(extern_module_translator);
46 let abstract_classes_wrappers: String = create_custom_traits_wrappers(extern_module_translator);
47 let abstract_classes_gluecode = create_custom_traits_gluecode(extern_module_translator);
48 let classes: String = rust_wrappers
49 .chain(user_classes)
50 .chain(abstract_classes_gluecode)
51 .collect();
52 let global_functions: String = extern_module_translator
53 .global_functions
54 .iter()
55 .map(|f| global_def_template(&f.name.to_string()))
56 .collect();
57 let exception_class_enum = enum_class_bindings(
58 "ExceptionClass",
59 extern_module_translator.exception_names.iter(),
60 );
61 let exception_trait_methods = extern_module_translator
62 .exception_trait_methods
63 .iter()
64 .map(|fun| class_method(&fun.name, CPP_WASM_EXCEPTION_CLASS_NAME))
65 .collect::<Vec<_>>()
66 .join("\n");
67 let exception_class = class_template(
68 CPP_WASM_EXCEPTION_CLASS_NAME,
69 &(class_constructor("unsigned")
70 + &class_method("exception_class", CPP_WASM_EXCEPTION_CLASS_NAME)
71 + &class_method("what", CPP_WASM_EXCEPTION_CLASS_NAME)
72 + &exception_trait_methods),
73 WASM_EXCEPTION_CLASS_NAME,
74 );
75 generate_cpp_code_for_wasm(
76 &abstract_classes_wrappers,
77 &classes,
78 &global_functions,
79 &enum_classes,
80 &exception_class_enum,
81 &exception_class,
82 )
83}
84
85fn translate_enums(emt: &ExternModuleTranslator) -> String {
86 let (fieldless_enums, data_enums): (Vec<&ItemEnum>, Vec<&ItemEnum>) = emt
87 .shared_enums
88 .iter()
89 .partition(|enum_item| is_primitive_enum(enum_item));
90 let fieldless_enums_bindings = fieldless_enums
91 .into_iter()
92 .map(|enum_item| {
93 enum_class_bindings(
94 &enum_item.ident,
95 enum_item.variants.iter().map(|variant| &variant.ident),
96 )
97 })
98 .collect::<String>();
99 let data_enums_bindings = data_enums
100 .into_iter()
101 .map(data_enum_emscripten_bindings)
102 .collect::<String>();
103 format!(
104 "{fieldless_enums_bindings}
105{data_enums_bindings}"
106 )
107}
108
109fn data_enum_emscripten_bindings(enum_item: &ItemEnum) -> String {
110 let tag_enum = tag_enum_bindings(enum_item);
111 let payload_object_values = create_variant_object_values(enum_item);
112 let enum_struct = enum_object_class_bindings(enum_item);
113
114 format!("{enum_struct}{payload_object_values}{tag_enum}")
115}
116
117fn create_variant_object_values(enum_item: &ItemEnum) -> String {
118 let enum_name = &enum_item.ident;
119 enum_item
120 .variants
121 .iter()
122 .filter(|v| is_many_fields_variant(v))
123 .map(|v| create_variant_object_value(v, &variant_wrapper_ident(enum_name, &v.ident)))
124 .collect::<Vec<String>>()
125 .join("\n")
126}
127
128fn create_variant_object_value(
129 variant: &Variant,
130 variant_wrapper_name: impl Display + Copy,
131) -> String {
132 let field_getters = field_getters(get_fields(variant).unwrap(), variant_wrapper_name);
133 format!(
134 "
135EMSCRIPTEN_BINDINGS({variant_wrapper_name}) {{
136 emscripten::class_<{variant_wrapper_name}>(\"{variant_wrapper_name}\")
137{field_getters}
138 ;
139}}
140"
141 )
142}
143
144fn field_getters(fields: &Punctuated<Field, Comma>, struct_name: impl Display) -> String {
145 fields
146 .iter()
147 .enumerate()
148 .map(|(idx, field)| {
149 let field_name = field_name(field).unwrap_or(format!("_{idx}"));
150 let field_name = if let Some(stripped) = field_name.strip_prefix('_') {
151 stripped
152 } else {
153 &field_name
154 };
155 format!(" .function(\"get_{field_name}\", &{struct_name}::get_{field_name})")
156 })
157 .collect::<Vec<_>>()
158 .join("\n")
159}
160
161fn enum_object_class_bindings(enum_item: &ItemEnum) -> String {
162 let enum_name = &enum_item.ident;
163 let variant_getters = create_enum_variant_getters(enum_item);
164 format!(
165 "
166EMSCRIPTEN_BINDINGS({enum_name}) {{
167 emscripten::class_<{enum_name}>(\"{enum_name}\")
168 .function(\"get_tag\", &{enum_name}::get_tag)
169{variant_getters}
170 ;
171}}
172"
173 )
174}
175
176fn create_enum_variant_getters(enum_item: &ItemEnum) -> String {
177 let enum_name = enum_item.ident.to_string();
178 enum_item
179 .variants
180 .iter()
181 .filter_map(|v| create_enum_variant_getter(v, &enum_name))
182 .collect::<Vec<_>>()
183 .join("\n")
184}
185
186fn create_enum_variant_getter(variant: &Variant, enum_name: &str) -> Option<String> {
187 match get_fields(variant) {
188 Some(_) if !is_ignored_variant(variant) => {
189 let variant_name = &variant.ident.to_string().to_case(Case::Snake);
190 Some(format!(
191 " .function(\"get_{variant_name}\", &{enum_name}::get_{variant_name})"
192 ))
193 }
194 _ => None,
195 }
196}
197
198fn tag_enum_bindings(enum_item: &ItemEnum) -> String {
199 let enum_name_tag = enum_tag_name(enum_item.ident.to_string().as_str());
200 let variants = enum_item
201 .variants
202 .iter()
203 .map(|variant| {
204 let variant_ident = &variant.ident;
205 format!(" .value(\"{variant_ident}\", {enum_name_tag}::{variant_ident})")
206 })
207 .collect::<Vec<_>>()
208 .join("\n");
209 format!(
210 "
211EMSCRIPTEN_BINDINGS({enum_name_tag}) {{
212 emscripten::enum_<{enum_name_tag}>(\"{enum_name_tag}\")
213{variants}
214 ;
215}}
216"
217 )
218}
219
/// Emits an `EMSCRIPTEN_BINDINGS` block registering the C++ enum `enum_name`
/// via `emscripten::enum_`, with one `.value("V", Enum::V)` line per item of `variants`.
fn enum_class_bindings<T: Display, U: Display>(
    enum_name: T,
    variants: impl Iterator<Item = U>,
) -> String {
    let mut value_lines = Vec::new();
    for variant in variants {
        value_lines.push(format!(" .value(\"{variant}\", {enum_name}::{variant})"));
    }
    let values = value_lines.join("\n");
    format!(
        "\nEMSCRIPTEN_BINDINGS({enum_name}) {{\n emscripten::enum_<{enum_name}>(\"{enum_name}\")\n{values}\n ;\n}}\n"
    )
}
238
239fn create_user_classes(
240 extern_module_translator: &ExternModuleTranslator,
241) -> impl Iterator<Item = String> + '_ {
242 extern_module_translator
243 .user_custom_types
244 .iter()
245 .map(|(wrapper, functions)| {
246 let class_name = wrapper.wrapper_name.to_string();
247 let functions: String = functions
248 .iter()
249 .map(|f| class_method(&f.name.to_string(), &class_name))
250 .collect();
251 class_template(&class_name, &functions, &class_name)
252 })
253}
254
255fn create_rust_wrappers(
256 extern_module_translator: &ExternModuleTranslator,
257) -> impl Iterator<Item = String> + '_ {
258 extern_module_translator
259 .rust_types_wrappers
260 .unordered_iter()
261 .map(|wrapper| match &wrapper.rust_type {
262 RustWrapperType::Option(inner_type) => {
263 let inner_type_generics = inner_type.get_name();
264 let inner_type = inner_type.wrapper_name.to_string();
265 let class_name = format!("Optional<{inner_type_generics}>");
266 let target_name = format!("Optional{inner_type}");
267 let functions = [
268 class_constructor(""),
269 class_constructor(&inner_type_generics),
270 class_method("unwrap", &class_name),
271 class_method("is_some", &class_name),
272 ]
273 .join("");
274 class_template(&class_name, &functions, &target_name)
275 }
276 RustWrapperType::Vector(inner_type) => {
277 let inner_type_generics = inner_type.get_name();
278 let inner_type = inner_type.wrapper_name.to_string();
279 let class_name = format!("RustVec<{inner_type_generics}>");
280 let target_name = format!("Vec{inner_type}");
281 let functions = [
282 class_constructor(""),
283 class_method("at", &class_name),
284 class_method("size", &class_name),
285 class_method("push", &class_name),
286 ]
287 .join("");
288 class_template(&class_name, &functions, &target_name)
289 }
290 RustWrapperType::Result(_, _) => {
291 let class_name = &wrapper.wrapper_name;
292 let functions = [
293 class_function("from_ok", class_name),
294 class_function("from_err", class_name),
295 ]
296 .join("");
297 class_template(class_name, &functions, class_name)
298 }
299 _ => "".to_owned(),
300 })
301}
302
303fn create_custom_traits_wrappers(extern_module_translator: &ExternModuleTranslator) -> String {
304 extern_module_translator
305 .user_traits
306 .iter()
307 .map(|(wrapper, functions)| {
308 let class_name = wrapper.wrapper_name.to_string();
309 let functions_calls: String = functions
310 .iter()
311 .map(|f| FunctionVirtualTranslator::from_virtual_function(f, &class_name))
312 .map(|f_helper| {
313 let return_type_string = if let Some(ref wrapper) = f_helper.return_type {
314 wrapper.get_name_for_abstract_method()
315 } else {
316 "void".to_owned()
317 };
318 let function_name = f_helper.function_name;
319 let function_signature = f_helper.generated_virtual_function_signature;
320 let args: String = f_helper.arg_names[1..]
321 .iter()
322 .map(|arg| format!("std::move({arg})"))
323 .collect::<Vec<String>>()
324 .join(", ");
325 virtual_method_call(
326 &function_name,
327 &return_type_string,
328 &function_signature,
329 &args,
330 )
331 })
332 .collect();
333 abstract_class_wrapper(&class_name, &functions_calls)
334 })
335 .collect()
336}
337
338fn create_custom_traits_gluecode(
339 extern_module_translator: &ExternModuleTranslator,
340) -> impl Iterator<Item = String> + '_ {
341 extern_module_translator
342 .user_traits
343 .iter()
344 .map(|(wrapper, functions)| {
345 let class_name = wrapper.wrapper_name.to_string();
346 let virtual_functions: String = functions
347 .iter()
348 .map(|function| virtual_function(&function.name.to_string(), &class_name))
349 .collect();
350 abstract_class(&wrapper.wrapper_name.to_string(), &virtual_functions)
351 })
352}
353
/// Renders a C++ method override whose body is `return call<R>("name", args...);`.
///
/// `args` is the already-joined argument list. When it is empty, no trailing
/// comma is emitted after the method name.
fn virtual_method_call(
    function_name: &str,
    return_type: &str,
    function_signature: &str,
    args: &str,
) -> String {
    let forwarded = if args.is_empty() {
        String::new()
    } else {
        format!(", {args}")
    };
    format!(
        " {return_type} {function_name}({function_signature}) {{\n return call<{return_type}>(\"{function_name}\"{forwarded});\n }}\n"
    )
}
374
/// Renders a `.function(...)` registration that marks `class_name::function_name`
/// as `emscripten::pure_virtual()`.
fn virtual_function(function_name: &str, class_name: &str) -> String {
    format!(
        " .function(\"{name}\", &{owner}::{name}, emscripten::pure_virtual())\n",
        name = function_name,
        owner = class_name,
    )
}
378
/// Renders the `<class_name>Wrapper` struct that derives from
/// `emscripten::wrapper<class_name>`, with `functions_calls` as its body.
fn abstract_class_wrapper(class_name: &str, functions_calls: &str) -> String {
    let wrapper_name = format!("{class_name}Wrapper");
    format!(
        "\n struct {wrapper_name} : public emscripten::wrapper<{class_name}> {{\n EMSCRIPTEN_WRAPPER({wrapper_name});\n{functions_calls}\n }};\n"
    )
}
388
/// Renders the embind `class_` registration for an abstract (trait) class.
/// It contains the given pure virtual registrations and
/// `allow_subclass<<class_name>Wrapper>`.
fn abstract_class(class_name: &str, virtual_functions: &str) -> String {
    let subclass = format!("{class_name}Wrapper");
    format!(
        "\n emscripten::class_<{class_name}>(\"{class_name}\")\n{virtual_functions}\n .allow_subclass<{subclass}>(\"{subclass}\")\n ;\n"
    )
}
398
/// Assembles the final `#ifdef WASM`-guarded C++ translation unit.
///
/// The exception enum, the shared enums and the trait wrapper structs are placed
/// before the bindings block. Inside a single `EMSCRIPTEN_BINDINGS(WasmModule)`
/// block come the `String` class, then the exception class, the other classes
/// and the global functions.
fn generate_cpp_code_for_wasm(
    abstract_classes_wrappers: &str,
    classes: &str,
    global_functions: &str,
    enum_classes: &str,
    exception_class_enum: &str,
    exception_class: &str,
) -> String {
    [
        "",
        "#ifdef WASM",
        "#include <emscripten/bind.h>",
        "#include <utility>",
        exception_class_enum,
        enum_classes,
        abstract_classes_wrappers,
        "",
        "EMSCRIPTEN_BINDINGS(WasmModule) {",
        " emscripten::class_<String>(\"String\")",
        " .constructor<std::string>()",
        " .function(\"to_string\", &String::to_string)",
        " ;",
        exception_class,
        classes,
        global_functions,
        "}",
        "#endif",
        "",
    ]
    .join("\n")
}
429
/// Renders an embind `class_<class_name>("target_name")` registration followed by
/// the already-rendered member lines in `functions`, terminated by ` ;`.
fn class_template(class_name: &str, functions: &str, target_name: &str) -> String {
    let mut registration = format!(" emscripten::class_<{class_name}>(\"{target_name}\")\n");
    registration.push_str(functions);
    registration.push_str(" ;\n");
    registration
}
433
/// Renders a `.function("name", &Class::name)` member-function registration line.
fn class_method(function_name: &str, class_name: &str) -> String {
    let (method, owner) = (function_name, class_name);
    format!(" .function(\"{method}\", &{owner}::{method})\n")
}
437
/// Renders a `.class_function("name", &Class::name)` static-function registration line.
fn class_function(function_name: &str, class_name: &str) -> String {
    let (function, owner) = (function_name, class_name);
    format!(" .class_function(\"{function}\", &{owner}::{function})\n")
}
441
/// Renders a `.constructor<args>()` registration line. `args` is the C++ template
/// argument list; pass `""` for a default constructor.
fn class_constructor(args: &str) -> String {
    [" .constructor<", args, ">()\n"].concat()
}
445
/// Renders an `emscripten::function("name", &name);` statement that registers a free function.
fn global_def_template(function_name: &str) -> String {
    let name = function_name;
    format!(" emscripten::function(\"{name}\", &{name});\n")
}
449
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;

    use super::*;
    use crate::utils::helpers;

    /// A fieldless enum is registered via `emscripten::enum_` with one `.value` per variant.
    #[test]
    fn test_enum_class_emscripten_bindings() {
        let expected = [
            "",
            "EMSCRIPTEN_BINDINGS(En1) {",
            " emscripten::enum_<En1>(\"En1\")",
            " .value(\"V1\", En1::V1)",
            " .value(\"V2\", En1::V2)",
            " ;",
            "}",
            "",
        ]
        .join("\n");

        let item = helpers::get_enum_item();
        let actual =
            enum_class_bindings(item.ident, item.variants.iter().map(|variant| &variant.ident));

        assert_eq!(actual, expected);
    }
}