// xsd_parser/optimizer/flatten_unions.rs

1use crate::types::{Ident, TypeVariant, UnionInfo, UnionTypeInfo};
2
3use super::{Error, Optimizer};
4
5struct FlattenUnionInfo {
6    count: usize,
7    info: UnionInfo,
8}
9
10impl Optimizer {
11    /// This will flatten the union identified by `ident` to one single union.
12    ///
13    /// # Errors
14    ///
15    /// Returns an error if the passed `ident` could not be found,
16    /// or is not an union.
17    ///
18    /// # Examples
19    ///
20    /// Consider the following XML schema.
21    /// ```xml
22    #[doc = include_str!("../../tests/optimizer/union_flatten.xsd")]
23    /// ```
24    ///
25    /// Without this optimization this will result in the following code:
26    /// ```rust
27    #[doc = include_str!("../../tests/optimizer/expected0/flatten_unions.rs")]
28    /// ```
29    ///
30    /// With this optimization the following code is generated:
31    /// ```rust
32    #[doc = include_str!("../../tests/optimizer/expected1/flatten_unions.rs")]
33    /// ```
34    pub fn flatten_union(mut self, ident: Ident) -> Result<Self, Error> {
35        tracing::debug!("flatten_union(ident={ident:?})");
36
37        let Some(ty) = self.types.get(&ident) else {
38            return Err(Error::UnknownType(ident));
39        };
40
41        let TypeVariant::Union(ui) = &ty.variant else {
42            return Err(Error::ExpectedUnion(ident));
43        };
44
45        let mut info = FlattenUnionInfo {
46            count: 0,
47            info: UnionInfo::default(),
48        };
49
50        self.flatten_union_impl(&ident, None, &mut info);
51
52        if info.count > 1 {
53            info.info.base = ui.base.clone();
54
55            let ty = self.types.get_mut(&ident).unwrap();
56            ty.variant = TypeVariant::Union(info.info);
57        }
58
59        Ok(self)
60    }
61
62    /// This will flatten all union types.
63    ///
64    /// For details see [`flatten_union`](Self::flatten_union).
65    pub fn flatten_unions(mut self) -> Self {
66        tracing::debug!("flatten_unions");
67
68        let idents = self
69            .types
70            .iter()
71            .filter_map(|(ident, type_)| {
72                if matches!(&type_.variant, TypeVariant::Union(_)) {
73                    Some(ident)
74                } else {
75                    None
76                }
77            })
78            .cloned()
79            .collect::<Vec<_>>();
80
81        for ident in idents {
82            self = self.flatten_union(ident).unwrap();
83        }
84
85        self
86    }
87
88    fn flatten_union_impl(
89        &self,
90        ident: &Ident,
91        display_name: Option<&str>,
92        next: &mut FlattenUnionInfo,
93    ) {
94        let Some(type_) = self.types.get(ident) else {
95            return;
96        };
97
98        match &type_.variant {
99            TypeVariant::Union(x) => {
100                next.count += 1;
101                for t in &*x.types {
102                    self.flatten_union_impl(&t.type_, t.display_name.as_deref(), next);
103                }
104            }
105            TypeVariant::Reference(x) if x.is_single() => {
106                self.flatten_union_impl(&x.type_, display_name, next);
107            }
108            _ => {
109                let mut ui = UnionTypeInfo::new(ident.clone());
110                ui.display_name = display_name.map(ToOwned::to_owned);
111
112                next.info.types.push(ui);
113            }
114        }
115    }
116}