zen_expression/variable/types/
util.rs

1use crate::variable::types::VariableType;
2use serde_json::Value;
3use std::collections::hash_map::Entry;
4use std::collections::HashMap;
5use std::rc::Rc;
6
7impl VariableType {
8    pub fn array_item(&self) -> Option<Rc<VariableType>> {
9        match self {
10            VariableType::Array(item) => Some(item.clone()),
11            _ => None,
12        }
13    }
14
15    pub fn as_const_str(&self) -> Option<&str> {
16        match self {
17            VariableType::Constant(c) => match c.as_ref() {
18                Value::String(s) => Some(s.as_str()),
19                _ => None,
20            },
21            _ => None,
22        }
23    }
24
25    pub fn omit_const(&self) -> VariableType {
26        match self {
27            VariableType::Constant(v) => VariableType::from(v.as_ref()),
28            _ => self.clone(),
29        }
30    }
31
32    pub fn get(&self, vt: &VariableType) -> Rc<VariableType> {
33        match self {
34            VariableType::Array(inner) => inner.clone(),
35            VariableType::Object(obj) => match vt.as_const_str() {
36                None => Rc::new(VariableType::Any),
37                Some(key) => obj.get(key).cloned().unwrap_or(Rc::new(VariableType::Any)),
38            },
39            VariableType::Any => Rc::new(VariableType::Any),
40            VariableType::Constant(c) => match c.as_ref() {
41                Value::Array(arr) => {
42                    let arr_type = VariableType::from(arr.clone());
43                    arr_type.array_item().unwrap_or(Rc::new(VariableType::Any))
44                }
45                Value::Object(obj) => match vt.as_const_str() {
46                    None => Rc::new(VariableType::Any),
47                    Some(key) => obj
48                        .get(key)
49                        .map(|v| Rc::new(v.into()))
50                        .unwrap_or(Rc::new(VariableType::Any)),
51                },
52                _ => Rc::from(VariableType::Null),
53            },
54            _ => Rc::from(VariableType::Null),
55        }
56    }
57
58    pub fn satisfies(&self, constraint: &Self) -> bool {
59        match (self, constraint) {
60            (VariableType::Any, _) | (_, VariableType::Any) => true,
61            (VariableType::Null, VariableType::Null) => true,
62            (VariableType::Bool, VariableType::Bool) => true,
63            (VariableType::String, VariableType::String) => true,
64            (VariableType::Number, VariableType::Number) => true,
65            (VariableType::Array(a1), VariableType::Array(a2)) => a1.satisfies(a2),
66            (VariableType::Object(o1), VariableType::Object(o2)) => o1
67                .iter()
68                .all(|(k, v)| o2.get(k).is_some_and(|tv| v.satisfies(tv))),
69            (VariableType::Constant(c1), VariableType::Constant(c2)) => c1 == c2,
70            (VariableType::Constant(c), _) => {
71                let self_kind: VariableType = c.as_ref().into();
72                self_kind.satisfies(constraint)
73            }
74            (_, _) => false,
75        }
76    }
77
78    pub fn satisfies_array(&self) -> bool {
79        match self {
80            VariableType::Any | VariableType::Array(_) => true,
81            VariableType::Constant(c) => match c.as_ref() {
82                Value::Array(_) => true,
83                _ => false,
84            },
85            _ => false,
86        }
87    }
88
89    pub fn satisfies_object(&self) -> bool {
90        match self {
91            VariableType::Any | VariableType::Object(_) => true,
92            VariableType::Constant(c) => match c.as_ref() {
93                Value::Object(_) => true,
94                _ => false,
95            },
96            _ => false,
97        }
98    }
99
100    pub fn merge(&self, other: &Self) -> Self {
101        match (&self, other) {
102            (VariableType::Any, _) | (_, VariableType::Any) => VariableType::Any,
103            (VariableType::Null, VariableType::Null) => VariableType::Null,
104            (VariableType::Bool, VariableType::Bool) => VariableType::Bool,
105            (VariableType::String, VariableType::String) => VariableType::String,
106            (VariableType::Number, VariableType::Number) => VariableType::Number,
107            (VariableType::Array(a1), VariableType::Array(a2)) => {
108                if Rc::ptr_eq(&a1, &a2) {
109                    VariableType::Array(a1.clone())
110                } else {
111                    VariableType::Array(Rc::new(a1.merge(a2)))
112                }
113            }
114            (VariableType::Constant(c1), VariableType::Constant(c2)) => {
115                if Rc::ptr_eq(&c1, &c2) {
116                    VariableType::Constant(c1.clone())
117                } else if c1 == c2 {
118                    VariableType::Constant(c1.clone())
119                } else {
120                    let vt1 = VariableType::from(c1.as_ref());
121                    let vt2 = VariableType::from(c2.as_ref());
122
123                    vt1.merge(&vt2)
124                }
125            }
126            (VariableType::Object(o1), VariableType::Object(o2)) => {
127                let cap = o1.capacity().max(o2.capacity());
128
129                let map = o1.iter().chain(o2.iter()).fold(
130                    HashMap::<String, Rc<VariableType>>::with_capacity(cap),
131                    |mut acc, (k, v)| {
132                        match acc.entry(k.clone()) {
133                            Entry::Occupied(mut occ) => {
134                                let current = occ.get();
135                                let merged = v.merge(current.as_ref());
136                                occ.insert(Rc::new(merged));
137                            }
138                            Entry::Vacant(vac) => {
139                                vac.insert(v.clone());
140                            }
141                        }
142
143                        acc
144                    },
145                );
146
147                VariableType::Object(map)
148            }
149            (_, _) => VariableType::Any,
150        }
151    }
152
153    pub fn is_null(&self) -> bool {
154        match self {
155            VariableType::Null => true,
156            _ => false,
157        }
158    }
159}
160
161#[cfg(test)]
162mod tests {
163    use crate::variable::VariableType;
164    use std::rc::Rc;
165
166    #[test]
167    fn merge_simple() {
168        assert_eq!(
169            VariableType::Number.merge(&VariableType::Number),
170            VariableType::Number
171        );
172        assert_eq!(
173            VariableType::String.merge(&VariableType::String),
174            VariableType::String
175        );
176        assert_eq!(
177            VariableType::Bool.merge(&VariableType::Bool),
178            VariableType::Bool
179        );
180        assert_eq!(
181            VariableType::Null.merge(&VariableType::Null),
182            VariableType::Null
183        );
184        assert_eq!(
185            VariableType::Any.merge(&VariableType::Any),
186            VariableType::Any
187        );
188    }
189
190    #[test]
191    fn merge_array() {
192        assert_eq!(
193            VariableType::Array(Rc::new(VariableType::Number))
194                .merge(&VariableType::Array(Rc::new(VariableType::Number))),
195            VariableType::Array(Rc::new(VariableType::Number))
196        );
197    }
198
199    #[test]
200    fn merge_mixed() {
201        assert_eq!(
202            VariableType::Number.merge(&VariableType::String),
203            VariableType::Any
204        );
205    }
206}