use std::fmt::Display;
use syn::ItemEnum;
use crate::cpp::function_translator::FunctionTranslator;
use crate::cpp::templates::custom_class_definition;
use crate::enum_helpers::{
create_field_getter_function,
create_variant_getter_function,
enum_tag_name,
get_fields,
is_many_fields_variant,
is_primitive_enum,
variant_wrapper_ident,
};
use crate::extern_module_translator::ExternModuleTranslator;
pub fn enums_cpp_forward_declarations(emt: &ExternModuleTranslator) -> String {
emt.shared_enums
.iter()
.map(forward_declarations_cpp)
.collect()
}
pub fn translate_enums(emt: &ExternModuleTranslator) -> String {
emt.shared_enums.iter().map(translate_enum).collect()
}
fn forward_declarations_cpp(enum_item: &ItemEnum) -> String {
let enum_name = enum_item.ident.to_string();
let enum_tag_name = enum_tag_name(&enum_name);
if is_primitive_enum(enum_item) {
format!(
"
enum class {enum_name};
"
)
} else {
let variants_payloads_structs: String = enum_item
.variants
.iter()
.filter_map(|variant| {
if variant.fields.len() > 1 {
Some(format!(
"class {};\n",
variant_wrapper_ident(&enum_name, variant.ident.to_string())
))
} else {
None
}
})
.collect();
format!(
"
enum class {enum_tag_name};
class {enum_name};
{variants_payloads_structs}"
)
}
}
/// Dispatches one enum to the matching C++ generator.
///
/// Enums rejected by `is_primitive_enum` go through `translate_complex_enum`.
/// The rest become a plain `enum class` via `primitive_enum_class_definition`.
fn translate_enum(enum_item: &ItemEnum) -> String {
    if !is_primitive_enum(enum_item) {
        return translate_complex_enum(enum_item);
    }
    primitive_enum_class_definition(enum_item)
}
/// Builds the C++ classes that back a non-primitive enum.
///
/// The variant wrapper classes come first, followed by the enum's own object class.
pub fn create_enum_class_objects(enum_item: &ItemEnum) -> String {
    let object_class = create_enum_object_class(enum_item);
    let mut output = create_variant_wrappers(enum_item);
    output.push_str(&object_class);
    output
}
/// Emits the C++ class definition for a complex enum, followed by the out-of-class
/// definitions of its methods.
///
/// The class declares:
/// * one getter per variant for which `create_variant_getter_function` returns `Some`;
/// * a `get_tag()` accessor (see `tag_getter_declaration` / `tag_getter_definition`).
///
/// Method declarations are joined with `"\n"` inside the class body. Their definitions
/// are joined with `"\n"` and appended after the class.
fn create_enum_object_class(enum_item: &ItemEnum) -> String {
    // The owning class name is the same for every getter, so build it once rather
    // than re-allocating it for each variant.
    let enum_name = enum_item.ident.to_string();
    let (mut declarations, mut definitions): (Vec<_>, Vec<_>) = enum_item
        .variants
        .iter()
        .filter_map(|variant| create_variant_getter_function(enum_item, variant))
        .map(|getter| {
            let translator = FunctionTranslator::from_class_method(&getter, &enum_name);
            (
                // `generate_declaration` takes the translator by value, so it gets a copy
                // and the original is moved into `generate_definition`.
                FunctionTranslator::generate_declaration(translator.clone()),
                FunctionTranslator::generate_definition(translator),
            )
        })
        .unzip();
    declarations.push(tag_getter_declaration(enum_item));
    definitions.push(tag_getter_definition(enum_item));
    custom_class_definition(&enum_item.ident, declarations.join("\n")) + &definitions.join("\n")
}
/// Emits one C++ wrapper class per variant accepted by `is_many_fields_variant`.
///
/// Each wrapper is named by `variant_wrapper_ident` and declares one getter per
/// field, built by `create_field_getter_function` with the field's positional index.
/// The wrapper's class definition is followed by the getters' out-of-class
/// definitions. All wrappers are concatenated in variant order.
fn create_variant_wrappers(enum_item: &ItemEnum) -> String {
    enum_item
        .variants
        .iter()
        .filter(|variant| is_many_fields_variant(variant))
        .map(|variant| {
            // The filter above only lets through variants with several fields, so
            // `get_fields` is expected to succeed here. A failure is a bug in the helpers.
            let fields = get_fields(variant)
                .expect("variants accepted by is_many_fields_variant must expose their fields");
            let variant_wrapper_name =
                variant_wrapper_ident(&enum_item.ident, &variant.ident).to_string();
            let (getters_declarations, getters_definitions): (Vec<_>, Vec<_>) = fields
                .iter()
                .enumerate()
                .map(|(field_idx, field)| {
                    let getter = create_field_getter_function(enum_item, variant, field, field_idx);
                    let translator =
                        FunctionTranslator::from_class_method(&getter, &variant_wrapper_name);
                    (
                        FunctionTranslator::generate_declaration(translator.clone()),
                        FunctionTranslator::generate_definition(translator),
                    )
                })
                .unzip();
            custom_class_definition(
                variant_wrapper_name.as_str(),
                getters_declarations.join("\n"),
            ) + &getters_definitions.join("\n")
        })
        .collect::<String>()
}
/// In-class declaration of the `get_tag()` method.
///
/// The method's return type is the enum's tag type, as named by `enum_tag_name`.
fn tag_getter_declaration(enum_item: &ItemEnum) -> String {
    let enum_tag_name = enum_tag_name(enum_item.ident.to_string());
    format!(" {enum_tag_name} get_tag();\n")
}
/// Out-of-class C++ definition of `<Enum>::get_tag()`.
///
/// The generated body reinterprets `self` as a pointer to the tag type and returns
/// the value it points at. In other words, the tag is read from the first bytes
/// `self` refers to.
/// NOTE(review): `self` is presumably a pointer member provided by
/// `custom_class_definition` — confirm there.
fn tag_getter_definition(enum_item: &ItemEnum) -> String {
    let name = enum_item.ident.to_string();
    let tag = enum_tag_name(&name);
    format!(
        "{enum_tag_name} {enum_name}::get_tag() {{
auto ptr = reinterpret_cast<{enum_tag_name} *>(self);
return *ptr;
}};\n",
        enum_tag_name = tag,
        enum_name = name,
    )
}
/// Renders a non-primitive enum.
///
/// The output is the tag `enum class`, a newline, and then the wrapper/object
/// classes produced by `create_enum_class_objects`.
fn translate_complex_enum(enum_item: &ItemEnum) -> String {
    let class_objects = create_enum_class_objects(enum_item);
    let mut rendered = create_enum_class_tag(enum_item);
    rendered.push('\n');
    rendered.push_str(&class_objects);
    rendered
}
/// Emits the C++ `enum class` listing one enumerator per variant of a complex enum.
///
/// The enum class is named by `enum_tag_name`.
fn create_enum_class_tag(enum_item: &ItemEnum) -> String {
    let variant_idents = enum_item.variants.iter().map(|variant| &variant.ident);
    create_enum_class(
        &enum_tag_name(enum_item.ident.to_string().as_str()),
        variant_idents,
    )
}
/// Emits a plain C++ `enum class` for an enum accepted by `is_primitive_enum`.
///
/// The C++ enum uses the Rust enum's name and one enumerator per variant identifier.
fn primitive_enum_class_definition(enum_item: &ItemEnum) -> String {
    create_enum_class(
        &enum_item.ident.to_string(),
        enum_item.variants.iter().map(|variant| &variant.ident),
    )
}
/// Renders a C++ `enum class` called `name` with one enumerator per item of `variants`.
///
/// Each enumerator is its `Display` text prefixed by one space. Enumerators are
/// separated by `",\n"`. The result starts with a newline and ends with `"};\n"`.
pub fn create_enum_class<T: Display>(name: &str, variants: impl Iterator<Item = T>) -> String {
    let mut listed = String::new();
    for (position, variant) in variants.enumerate() {
        // Put the separator before every enumerator except the first, so there is
        // no trailing comma.
        if position > 0 {
            listed.push_str(",\n");
        }
        listed.push(' ');
        listed.push_str(&variant.to_string());
    }
    format!(
        "
enum class {name} {{
{variants}
}};
",
        variants = listed
    )
}
// Unit tests for the C++ enum translation helpers in this module.
#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use super::*;
use crate::utils::helpers;
// Checks the plain `enum class` emitted by `primitive_enum_class_definition` for the
// fixture returned by `helpers::get_enum_item()`. Judging by the expected output, that
// is presumably an enum `En1` with variants `V1` and `V2` — confirm in `crate::utils::helpers`.
// NOTE(review): `create_enum_class` prefixes every enumerator with a space
// (`format!(" {id}")`), but the expected lines below show no leading whitespace in this
// view — confirm the literal's indentation matches what is actually emitted.
#[test]
fn test_creating_enum_class_definition() {
let enum_item = helpers::get_enum_item();
let definition = primitive_enum_class_definition(&enum_item);
let expected_definition = "
enum class En1 {
V1,
V2
};
";
assert_eq!(definition, expected_definition);
}
}