xsd_parser/optimizer/
flatten_unions.rs1use crate::types::{Ident, TypeVariant, UnionInfo, UnionTypeInfo};
2
3use super::{Error, Optimizer};
4
5struct FlattenUnionInfo {
6 count: usize,
7 info: UnionInfo,
8}
9
10impl Optimizer {
11 #[doc = include_str!("../../tests/optimizer/union_flatten.xsd")]
23 #[doc = include_str!("../../tests/optimizer/expected0/flatten_unions.rs")]
28 #[doc = include_str!("../../tests/optimizer/expected1/flatten_unions.rs")]
33 pub fn flatten_union(mut self, ident: Ident) -> Result<Self, Error> {
35 tracing::debug!("flatten_union(ident={ident:?})");
36
37 let Some(ty) = self.types.get(&ident) else {
38 return Err(Error::UnknownType(ident));
39 };
40
41 let TypeVariant::Union(ui) = &ty.variant else {
42 return Err(Error::ExpectedUnion(ident));
43 };
44
45 let mut info = FlattenUnionInfo {
46 count: 0,
47 info: UnionInfo::default(),
48 };
49
50 self.flatten_union_impl(&ident, None, &mut info);
51
52 if info.count > 1 {
53 info.info.base = ui.base.clone();
54
55 let ty = self.types.get_mut(&ident).unwrap();
56 ty.variant = TypeVariant::Union(info.info);
57 }
58
59 Ok(self)
60 }
61
62 pub fn flatten_unions(mut self) -> Self {
66 tracing::debug!("flatten_unions");
67
68 let idents = self
69 .types
70 .iter()
71 .filter_map(|(ident, type_)| {
72 if matches!(&type_.variant, TypeVariant::Union(_)) {
73 Some(ident)
74 } else {
75 None
76 }
77 })
78 .cloned()
79 .collect::<Vec<_>>();
80
81 for ident in idents {
82 self = self.flatten_union(ident).unwrap();
83 }
84
85 self
86 }
87
88 fn flatten_union_impl(
89 &self,
90 ident: &Ident,
91 display_name: Option<&str>,
92 next: &mut FlattenUnionInfo,
93 ) {
94 let Some(type_) = self.types.get(ident) else {
95 return;
96 };
97
98 match &type_.variant {
99 TypeVariant::Union(x) => {
100 next.count += 1;
101 for t in &*x.types {
102 self.flatten_union_impl(&t.type_, t.display_name.as_deref(), next);
103 }
104 }
105 TypeVariant::Reference(x) if x.is_single() => {
106 self.flatten_union_impl(&x.type_, display_name, next);
107 }
108 _ => {
109 let mut ui = UnionTypeInfo::new(ident.clone());
110 ui.display_name = display_name.map(ToOwned::to_owned);
111
112 next.info.types.push(ui);
113 }
114 }
115 }
116}