rusty_bind_parser/cpp/
wasm_generator.rs

//
// Wildland Project
//
// Copyright © 2022 Golem Foundation,
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License version 3 as published by
// the Free Software Foundation.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program.  If not, see <https://www.gnu.org/licenses/>.
17
18use std::fmt::Display;
19
20use convert_case::{Case, Casing};
21use syn::punctuated::Punctuated;
22use syn::token::Comma;
23use syn::{Field, ItemEnum, Variant};
24
25use super::templates::TargetLanguageTypeName;
26use crate::binding_types::RustWrapperType;
27use crate::cpp::FunctionVirtualTranslator;
28use crate::enum_helpers::{
29    enum_tag_name,
30    field_name,
31    get_fields,
32    is_ignored_variant,
33    is_many_fields_variant,
34    is_primitive_enum,
35    variant_wrapper_ident,
36};
37use crate::extern_module_translator::ExternModuleTranslator;
38
/// Name string under which the exception class is registered with embind
/// (passed as the `target_name` of `class_template` in `create_wasm_module`).
const WASM_EXCEPTION_CLASS_NAME: &str = "RustException";
/// Name of the C++ exception class whose methods are bound
/// (used as the `class_name` of the generated `emscripten::class_<...>`).
const CPP_WASM_EXCEPTION_CLASS_NAME: &str = "WasmException";
41
42pub fn create_wasm_module(extern_module_translator: &ExternModuleTranslator) -> String {
43    let enum_classes = translate_enums(extern_module_translator);
44    let user_classes = create_user_classes(extern_module_translator);
45    let rust_wrappers = create_rust_wrappers(extern_module_translator);
46    let abstract_classes_wrappers: String = create_custom_traits_wrappers(extern_module_translator);
47    let abstract_classes_gluecode = create_custom_traits_gluecode(extern_module_translator);
48    let classes: String = rust_wrappers
49        .chain(user_classes)
50        .chain(abstract_classes_gluecode)
51        .collect();
52    let global_functions: String = extern_module_translator
53        .global_functions
54        .iter()
55        .map(|f| global_def_template(&f.name.to_string()))
56        .collect();
57    let exception_class_enum = enum_class_bindings(
58        "ExceptionClass",
59        extern_module_translator.exception_names.iter(),
60    );
61    let exception_trait_methods = extern_module_translator
62        .exception_trait_methods
63        .iter()
64        .map(|fun| class_method(&fun.name, CPP_WASM_EXCEPTION_CLASS_NAME))
65        .collect::<Vec<_>>()
66        .join("\n");
67    let exception_class = class_template(
68        CPP_WASM_EXCEPTION_CLASS_NAME,
69        &(class_constructor("unsigned")
70            + &class_method("exception_class", CPP_WASM_EXCEPTION_CLASS_NAME)
71            + &class_method("what", CPP_WASM_EXCEPTION_CLASS_NAME)
72            + &exception_trait_methods),
73        WASM_EXCEPTION_CLASS_NAME,
74    );
75    generate_cpp_code_for_wasm(
76        &abstract_classes_wrappers,
77        &classes,
78        &global_functions,
79        &enum_classes,
80        &exception_class_enum,
81        &exception_class,
82    )
83}
84
85fn translate_enums(emt: &ExternModuleTranslator) -> String {
86    let (fieldless_enums, data_enums): (Vec<&ItemEnum>, Vec<&ItemEnum>) = emt
87        .shared_enums
88        .iter()
89        .partition(|enum_item| is_primitive_enum(enum_item));
90    let fieldless_enums_bindings = fieldless_enums
91        .into_iter()
92        .map(|enum_item| {
93            enum_class_bindings(
94                &enum_item.ident,
95                enum_item.variants.iter().map(|variant| &variant.ident),
96            )
97        })
98        .collect::<String>();
99    let data_enums_bindings = data_enums
100        .into_iter()
101        .map(data_enum_emscripten_bindings)
102        .collect::<String>();
103    format!(
104        "{fieldless_enums_bindings}
105{data_enums_bindings}"
106    )
107}
108
109fn data_enum_emscripten_bindings(enum_item: &ItemEnum) -> String {
110    let tag_enum = tag_enum_bindings(enum_item);
111    let payload_object_values = create_variant_object_values(enum_item);
112    let enum_struct = enum_object_class_bindings(enum_item);
113
114    format!("{enum_struct}{payload_object_values}{tag_enum}")
115}
116
117fn create_variant_object_values(enum_item: &ItemEnum) -> String {
118    let enum_name = &enum_item.ident;
119    enum_item
120        .variants
121        .iter()
122        .filter(|v| is_many_fields_variant(v))
123        .map(|v| create_variant_object_value(v, &variant_wrapper_ident(enum_name, &v.ident)))
124        .collect::<Vec<String>>()
125        .join("\n")
126}
127
128fn create_variant_object_value(
129    variant: &Variant,
130    variant_wrapper_name: impl Display + Copy,
131) -> String {
132    let field_getters = field_getters(get_fields(variant).unwrap(), variant_wrapper_name);
133    format!(
134        "
135EMSCRIPTEN_BINDINGS({variant_wrapper_name}) {{
136    emscripten::class_<{variant_wrapper_name}>(\"{variant_wrapper_name}\")
137{field_getters}
138        ;
139}}
140"
141    )
142}
143
144fn field_getters(fields: &Punctuated<Field, Comma>, struct_name: impl Display) -> String {
145    fields
146        .iter()
147        .enumerate()
148        .map(|(idx, field)| {
149            let field_name = field_name(field).unwrap_or(format!("_{idx}"));
150            let field_name = if let Some(stripped) = field_name.strip_prefix('_') {
151                stripped
152            } else {
153                &field_name
154            };
155            format!("        .function(\"get_{field_name}\", &{struct_name}::get_{field_name})")
156        })
157        .collect::<Vec<_>>()
158        .join("\n")
159}
160
161fn enum_object_class_bindings(enum_item: &ItemEnum) -> String {
162    let enum_name = &enum_item.ident;
163    let variant_getters = create_enum_variant_getters(enum_item);
164    format!(
165        "
166EMSCRIPTEN_BINDINGS({enum_name}) {{
167    emscripten::class_<{enum_name}>(\"{enum_name}\")
168        .function(\"get_tag\", &{enum_name}::get_tag)
169{variant_getters}
170        ;
171}}
172"
173    )
174}
175
176fn create_enum_variant_getters(enum_item: &ItemEnum) -> String {
177    let enum_name = enum_item.ident.to_string();
178    enum_item
179        .variants
180        .iter()
181        .filter_map(|v| create_enum_variant_getter(v, &enum_name))
182        .collect::<Vec<_>>()
183        .join("\n")
184}
185
186fn create_enum_variant_getter(variant: &Variant, enum_name: &str) -> Option<String> {
187    match get_fields(variant) {
188        Some(_) if !is_ignored_variant(variant) => {
189            let variant_name = &variant.ident.to_string().to_case(Case::Snake);
190            Some(format!(
191                "        .function(\"get_{variant_name}\", &{enum_name}::get_{variant_name})"
192            ))
193        }
194        _ => None,
195    }
196}
197
198fn tag_enum_bindings(enum_item: &ItemEnum) -> String {
199    let enum_name_tag = enum_tag_name(enum_item.ident.to_string().as_str());
200    let variants = enum_item
201        .variants
202        .iter()
203        .map(|variant| {
204            let variant_ident = &variant.ident;
205            format!("        .value(\"{variant_ident}\", {enum_name_tag}::{variant_ident})")
206        })
207        .collect::<Vec<_>>()
208        .join("\n");
209    format!(
210        "
211EMSCRIPTEN_BINDINGS({enum_name_tag}) {{
212    emscripten::enum_<{enum_name_tag}>(\"{enum_name_tag}\")
213{variants}
214        ;
215}}
216"
217    )
218}
219
/// Emits an `EMSCRIPTEN_BINDINGS` block registering `enum_name` as an
/// `emscripten::enum_` with one `.value("<V>", <enum_name>::<V>)` line per
/// item of `variants`.
///
/// The result starts and ends with a newline; with no variants the value
/// section is an empty line.
fn enum_class_bindings<T: Display, U: Display>(
    enum_name: T,
    variants: impl Iterator<Item = U>,
) -> String {
    let value_lines: Vec<String> = variants
        .map(|variant| format!("        .value(\"{variant}\", {enum_name}::{variant})"))
        .collect();
    let value_block = value_lines.join("\n");
    format!(
        "\nEMSCRIPTEN_BINDINGS({enum_name}) {{\n    emscripten::enum_<{enum_name}>(\"{enum_name}\")\n{value_block}\n        ;\n}}\n"
    )
}
238
239fn create_user_classes(
240    extern_module_translator: &ExternModuleTranslator,
241) -> impl Iterator<Item = String> + '_ {
242    extern_module_translator
243        .user_custom_types
244        .iter()
245        .map(|(wrapper, functions)| {
246            let class_name = wrapper.wrapper_name.to_string();
247            let functions: String = functions
248                .iter()
249                .map(|f| class_method(&f.name.to_string(), &class_name))
250                .collect();
251            class_template(&class_name, &functions, &class_name)
252        })
253}
254
255fn create_rust_wrappers(
256    extern_module_translator: &ExternModuleTranslator,
257) -> impl Iterator<Item = String> + '_ {
258    extern_module_translator
259        .rust_types_wrappers
260        .unordered_iter()
261        .map(|wrapper| match &wrapper.rust_type {
262            RustWrapperType::Option(inner_type) => {
263                let inner_type_generics = inner_type.get_name();
264                let inner_type = inner_type.wrapper_name.to_string();
265                let class_name = format!("Optional<{inner_type_generics}>");
266                let target_name = format!("Optional{inner_type}");
267                let functions = [
268                    class_constructor(""),
269                    class_constructor(&inner_type_generics),
270                    class_method("unwrap", &class_name),
271                    class_method("is_some", &class_name),
272                ]
273                .join("");
274                class_template(&class_name, &functions, &target_name)
275            }
276            RustWrapperType::Vector(inner_type) => {
277                let inner_type_generics = inner_type.get_name();
278                let inner_type = inner_type.wrapper_name.to_string();
279                let class_name = format!("RustVec<{inner_type_generics}>");
280                let target_name = format!("Vec{inner_type}");
281                let functions = [
282                    class_constructor(""),
283                    class_method("at", &class_name),
284                    class_method("size", &class_name),
285                    class_method("push", &class_name),
286                ]
287                .join("");
288                class_template(&class_name, &functions, &target_name)
289            }
290            RustWrapperType::Result(_, _) => {
291                let class_name = &wrapper.wrapper_name;
292                let functions = [
293                    class_function("from_ok", class_name),
294                    class_function("from_err", class_name),
295                ]
296                .join("");
297                class_template(class_name, &functions, class_name)
298            }
299            _ => "".to_owned(),
300        })
301}
302
303fn create_custom_traits_wrappers(extern_module_translator: &ExternModuleTranslator) -> String {
304    extern_module_translator
305        .user_traits
306        .iter()
307        .map(|(wrapper, functions)| {
308            let class_name = wrapper.wrapper_name.to_string();
309            let functions_calls: String = functions
310                .iter()
311                .map(|f| FunctionVirtualTranslator::from_virtual_function(f, &class_name))
312                .map(|f_helper| {
313                    let return_type_string = if let Some(ref wrapper) = f_helper.return_type {
314                        wrapper.get_name_for_abstract_method()
315                    } else {
316                        "void".to_owned()
317                    };
318                    let function_name = f_helper.function_name;
319                    let function_signature = f_helper.generated_virtual_function_signature;
320                    let args: String = f_helper.arg_names[1..]
321                        .iter()
322                        .map(|arg| format!("std::move({arg})"))
323                        .collect::<Vec<String>>()
324                        .join(", ");
325                    virtual_method_call(
326                        &function_name,
327                        &return_type_string,
328                        &function_signature,
329                        &args,
330                    )
331                })
332                .collect();
333            abstract_class_wrapper(&class_name, &functions_calls)
334        })
335        .collect()
336}
337
338fn create_custom_traits_gluecode(
339    extern_module_translator: &ExternModuleTranslator,
340) -> impl Iterator<Item = String> + '_ {
341    extern_module_translator
342        .user_traits
343        .iter()
344        .map(|(wrapper, functions)| {
345            let class_name = wrapper.wrapper_name.to_string();
346            let virtual_functions: String = functions
347                .iter()
348                .map(|function| virtual_function(&function.name.to_string(), &class_name))
349                .collect();
350            abstract_class(&wrapper.wrapper_name.to_string(), &virtual_functions)
351        })
352}
353
/// Emits a C++ method that forwards to the subclass implementation through
/// `call<return_type>("function_name", args...)`.
///
/// `function_signature` is placed verbatim between the parentheses of the
/// method declaration. When `args` is empty only the function name is passed
/// to `call`.
fn virtual_method_call(
    function_name: &str,
    return_type: &str,
    function_signature: &str,
    args: &str,
) -> String {
    // Only the argument list of `call` differs between the two cases.
    let call_arguments = if args.is_empty() {
        format!("\"{function_name}\"")
    } else {
        format!("\"{function_name}\", {args}")
    };
    format!(
        "    {return_type} {function_name}({function_signature}) {{\n            return call<{return_type}>({call_arguments});\n        }}\n"
    )
}
374
/// Emits the embind line declaring `class_name::function_name` as a pure
/// virtual method, terminated by a newline.
fn virtual_function(function_name: &str, class_name: &str) -> String {
    format!(
        "    .function(\"{name}\", &{owner}::{name}, emscripten::pure_virtual())\n",
        name = function_name,
        owner = class_name,
    )
}
378
/// Emits the `<class_name>Wrapper` struct deriving from
/// `emscripten::wrapper<class_name>`, with `functions_calls` as its body
/// after the `EMSCRIPTEN_WRAPPER` macro line.
fn abstract_class_wrapper(class_name: &str, functions_calls: &str) -> String {
    let wrapper_name = format!("{class_name}Wrapper");
    format!(
        "\n    struct {wrapper_name} : public emscripten::wrapper<{class_name}> {{\n        EMSCRIPTEN_WRAPPER({wrapper_name});\n{functions_calls}\n    }};\n"
    )
}
388
/// Emits the `emscripten::class_` registration for an abstract class:
/// `virtual_functions` followed by an `allow_subclass` entry for the
/// matching `<class_name>Wrapper`.
fn abstract_class(class_name: &str, virtual_functions: &str) -> String {
    let wrapper_name = format!("{class_name}Wrapper");
    format!(
        "\n    emscripten::class_<{class_name}>(\"{class_name}\")\n{virtual_functions}\n        .allow_subclass<{wrapper_name}>(\"{wrapper_name}\")\n    ;\n"
    )
}
398
/// Assembles the final C++ text guarded by `#ifdef WASM`.
///
/// The exception enum, the enum bindings and the trait wrapper structs are
/// placed before the `EMSCRIPTEN_BINDINGS(WasmModule)` block. Inside the block
/// come the fixed `String` class registration, then the exception class, the
/// class registrations and the global functions.
fn generate_cpp_code_for_wasm(
    abstract_classes_wrappers: &str,
    classes: &str,
    global_functions: &str,
    enum_classes: &str,
    exception_class_enum: &str,
    exception_class: &str,
) -> String {
    // One entry per output line (or per pre-rendered section); the empty first
    // and last entries produce the leading and trailing newline.
    let sections = [
        "",
        "#ifdef WASM",
        "#include <emscripten/bind.h>",
        "#include <utility>",
        exception_class_enum,
        enum_classes,
        abstract_classes_wrappers,
        "",
        "EMSCRIPTEN_BINDINGS(WasmModule) {",
        "    emscripten::class_<String>(\"String\")",
        "        .constructor<std::string>()",
        "        .function(\"to_string\", &String::to_string)",
        "        ;",
        exception_class,
        classes,
        global_functions,
        "}",
        "#endif",
        "",
    ];
    sections.join("\n")
}
429
/// Emits an `emscripten::class_<class_name>("target_name")` registration with
/// `functions` as its member bindings, closed by a `;` line.
///
/// `functions` is inserted verbatim and is expected to carry its own line breaks.
fn class_template(class_name: &str, functions: &str, target_name: &str) -> String {
    let mut registration = format!("    emscripten::class_<{class_name}>(\"{target_name}\")\n");
    registration.push_str(functions);
    registration.push_str("    ;\n");
    registration
}
433
/// Emits the embind `.function(...)` line binding the method
/// `class_name::function_name` under the name `function_name`.
fn class_method(function_name: &str, class_name: &str) -> String {
    format!(
        "        .function(\"{name}\", &{owner}::{name})\n",
        name = function_name,
        owner = class_name,
    )
}
437
/// Emits the embind `.class_function(...)` line binding
/// `class_name::function_name` under the name `function_name`.
fn class_function(function_name: &str, class_name: &str) -> String {
    [
        "        .class_function(\"",
        function_name,
        "\", &",
        class_name,
        "::",
        function_name,
        ")\n",
    ]
    .concat()
}
441
/// Emits the embind `.constructor<args>()` line; `args` is the verbatim
/// template argument list and may be empty.
fn class_constructor(args: &str) -> String {
    let mut line = String::from("        .constructor<");
    line.push_str(args);
    line.push_str(">()\n");
    line
}
445
/// Emits the `emscripten::function(...)` statement exporting the free
/// function `function_name` under its own name.
fn global_def_template(function_name: &str) -> String {
    format!("    emscripten::function(\"{0}\", &{0});\n", function_name)
}
449
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;

    use super::*;
    use crate::utils::helpers;

    /// The enum item supplied by the test helpers must be rendered as a single
    /// `emscripten::enum_` block listing its variants in order.
    #[test]
    fn test_enum_class_emscripten_bindings() {
        let enum_item = helpers::get_enum_item();
        let variant_idents = enum_item.variants.iter().map(|variant| &variant.ident);
        let generated = enum_class_bindings(&enum_item.ident, variant_idents);

        let expected = [
            "",
            "EMSCRIPTEN_BINDINGS(En1) {",
            "    emscripten::enum_<En1>(\"En1\")",
            "        .value(\"V1\", En1::V1)",
            "        .value(\"V2\", En1::V2)",
            "        ;",
            "}",
            "",
        ]
        .join("\n");
        assert_eq!(generated, expected);
    }
}