// rusty_bind_parser/cpp/enums.rs

//
// Wildland Project
//
// Copyright © 2022 Golem Foundation,
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License version 3 as published by
// the Free Software Foundation.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program.  If not, see <https://www.gnu.org/licenses/>.

18use std::fmt::Display;
19
20use syn::ItemEnum;
21
22use crate::cpp::function_translator::FunctionTranslator;
23use crate::cpp::templates::custom_class_definition;
24use crate::enum_helpers::{
25    create_field_getter_function,
26    create_variant_getter_function,
27    enum_tag_name,
28    get_fields,
29    is_many_fields_variant,
30    is_primitive_enum,
31    variant_wrapper_ident,
32};
33use crate::extern_module_translator::ExternModuleTranslator;
34
35// Generates forward declarations for c++ enum equivalents
36pub fn enums_cpp_forward_declarations(emt: &ExternModuleTranslator) -> String {
37    emt.shared_enums
38        .iter()
39        .map(forward_declarations_cpp)
40        .collect()
41}
42
43/// Translates rust enum into its cpp equivalents.
44/// For rust enums with unit variants only a c++ enum class representation is generated.
45/// For complex enums it generates a struct (tagged union) that has a tag which is c++ enum class
46/// and payload which is union of all possible variants.
47pub fn translate_enums(emt: &ExternModuleTranslator) -> String {
48    emt.shared_enums.iter().map(translate_enum).collect()
49}
50
51fn forward_declarations_cpp(enum_item: &ItemEnum) -> String {
52    let enum_name = enum_item.ident.to_string();
53    let enum_tag_name = enum_tag_name(&enum_name);
54    if is_primitive_enum(enum_item) {
55        format!(
56            "
57enum class {enum_name};
58"
59        )
60    } else {
61        let variants_payloads_structs: String = enum_item
62            .variants
63            .iter()
64            .filter_map(|variant| {
65                if variant.fields.len() > 1 {
66                    Some(format!(
67                        "class {};\n",
68                        variant_wrapper_ident(&enum_name, variant.ident.to_string())
69                    ))
70                } else {
71                    None
72                }
73            })
74            .collect();
75        format!(
76            "
77enum class {enum_tag_name};
78class {enum_name};
79{variants_payloads_structs}"
80        )
81    }
82}
83
84fn translate_enum(enum_item: &ItemEnum) -> String {
85    if is_primitive_enum(enum_item) {
86        primitive_enum_class_definition(enum_item)
87    } else {
88        translate_complex_enum(enum_item)
89    }
90}
91
92pub fn create_enum_class_objects(enum_item: &ItemEnum) -> String {
93    let enum_object_class = create_enum_object_class(enum_item);
94    let variant_wrappers = create_variant_wrappers(enum_item);
95
96    variant_wrappers + &enum_object_class
97}
98
99fn create_enum_object_class(enum_item: &ItemEnum) -> String {
100    let (mut variant_getters_declarations, mut variant_getters_definitions): (Vec<_>, Vec<_>) =
101        enum_item
102            .variants
103            .iter()
104            .filter_map(|variant| create_variant_getter_function(enum_item, variant))
105            .map(|f| {
106                FunctionTranslator::from_class_method(&f, enum_item.ident.to_string().as_str())
107            })
108            .map(|ft| {
109                (
110                    FunctionTranslator::generate_declaration(ft.clone()),
111                    FunctionTranslator::generate_definition(ft),
112                )
113            })
114            .unzip();
115    variant_getters_declarations.push(tag_getter_declaration(enum_item));
116    variant_getters_definitions.push(tag_getter_definition(enum_item));
117
118    custom_class_definition(&enum_item.ident, variant_getters_declarations.join("\n"))
119        + &variant_getters_definitions.join("\n")
120}
121
122fn create_variant_wrappers(enum_item: &ItemEnum) -> String {
123    enum_item
124        .variants
125        .iter()
126        .filter(|v| is_many_fields_variant(v))
127        .map(|variant| {
128            let fields = get_fields(variant).unwrap();
129            let variant_wrapper_name =
130                variant_wrapper_ident(&enum_item.ident, &variant.ident).to_string();
131            let (getters_declarations, getters_definitions): (Vec<_>, Vec<_>) = fields
132                .iter()
133                .enumerate()
134                .map(|(field_idx, field)| {
135                    create_field_getter_function(enum_item, variant, field, field_idx)
136                })
137                .map(|f| FunctionTranslator::from_class_method(&f, &variant_wrapper_name))
138                .map(|ft| {
139                    (
140                        FunctionTranslator::generate_declaration(ft.clone()),
141                        FunctionTranslator::generate_definition(ft),
142                    )
143                })
144                .unzip();
145            custom_class_definition(
146                variant_wrapper_name.as_str(),
147                getters_declarations.join("\n"),
148            ) + &getters_definitions.join("\n")
149        })
150        .collect::<String>()
151}
152
153fn tag_getter_declaration(enum_item: &ItemEnum) -> String {
154    let enum_name = enum_item.ident.to_string();
155    let enum_tag_name = enum_tag_name(enum_name);
156    format!("    {enum_tag_name} get_tag();\n")
157}
158
159fn tag_getter_definition(enum_item: &ItemEnum) -> String {
160    let enum_name = enum_item.ident.to_string();
161    let enum_tag_name = enum_tag_name(&enum_name);
162    format!(
163        "{enum_tag_name} {enum_name}::get_tag() {{
164    auto ptr = reinterpret_cast<{enum_tag_name} *>(self);
165    return *ptr;
166}};\n"
167    )
168}
169
170fn translate_complex_enum(enum_item: &ItemEnum) -> String {
171    let enum_class_object = create_enum_class_objects(enum_item);
172    let enum_tag = create_enum_class_tag(enum_item);
173    format!(
174        "{enum_tag}
175{enum_class_object}",
176    )
177}
178
179fn create_enum_class_tag(enum_item: &ItemEnum) -> String {
180    let enum_tag_name = enum_tag_name(enum_item.ident.to_string().as_str());
181    let variants = enum_item.variants.iter().map(|variant| &variant.ident);
182    create_enum_class(&enum_tag_name, variants)
183}
184
185fn primitive_enum_class_definition(enum_item: &ItemEnum) -> String {
186    let enum_name = enum_item.ident.to_string();
187    let variants = enum_item.variants.iter().map(|variant| &variant.ident);
188    create_enum_class(&enum_name, variants)
189}
190
/// Formats a C++ `enum class` declaration named `name`.
///
/// Each item yielded by `variants` becomes one enumerator, rendered with its `Display`
/// impl. Enumerators are indented by four spaces and separated by `",\n"`, so the last one
/// has no trailing comma. The result starts and ends with a newline.
pub fn create_enum_class<T: Display>(name: &str, variants: impl Iterator<Item = T>) -> String {
    let mut body = String::new();
    for (position, variant) in variants.enumerate() {
        if position > 0 {
            body.push_str(",\n");
        }
        body.push_str("    ");
        body.push_str(&variant.to_string());
    }
    format!("\nenum class {name} {{\n{body}\n}};\n")
}
204
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;

    use super::*;
    use crate::utils::helpers;

    /// A primitive enum renders as a C++ `enum class` with one enumerator per variant.
    #[test]
    fn test_creating_enum_class_definition() {
        let enum_item = helpers::get_enum_item();
        let definition = primitive_enum_class_definition(&enum_item);

        let expected_definition = "
enum class En1 {
    V1,
    V2
};
";
        assert_eq!(definition, expected_definition);
    }

    /// An enum with no variants still yields a well-formed (empty) C++ `enum class`.
    #[test]
    fn test_create_enum_class_without_variants() {
        let definition = create_enum_class("Empty", std::iter::empty::<&str>());
        assert_eq!(definition, "\nenum class Empty {\n\n};\n");
    }

    /// A single enumerator must not be followed by a separator comma.
    #[test]
    fn test_create_enum_class_single_variant_has_no_trailing_comma() {
        let definition = create_enum_class("Single", ["Only"].iter());
        assert_eq!(definition, "\nenum class Single {\n    Only\n};\n");
    }
}