// smt_lang/problem/typing.rs

1use std::vec;
2
3use super::*;
4
5#[derive(Clone, PartialEq, Eq, Hash, Debug)]
6pub enum Type {
7    Bool,
8    Int,
9    Real,
10    Interval(isize, isize),
11    Structure(StructureId),
12    Class(ClassId),
13    //
14    Unresolved(String, Option<Position>),
15    Undefined,
16    // Error,
17}
18
19impl Type {
20    pub fn is_bounded(&self) -> bool {
21        match self {
22            Type::Bool => true,
23            Type::Int => false,
24            Type::Real => false,
25            Type::Interval(_, _) => true,
26            Type::Structure(_) => true,
27            Type::Class(_) => true,
28            //
29            Type::Unresolved(_, _) => false,
30            Type::Undefined => false,
31            // Type::Error => false,
32        }
33    }
34
35    pub fn resolve_type(&self, entries: &TypeEntries) -> Result<Type, Error> {
36        match self {
37            Type::Unresolved(name, position) => match entries.get(&name) {
38                Some(entry) => match entry.typ() {
39                    TypeEntryType::Structure(id) => Ok(Type::Structure(id)),
40                    TypeEntryType::Class(id) => Ok(Type::Class(id)),
41                },
42                None => Err(Error::Resolve {
43                    category: "type".to_string(),
44                    name: name.clone(),
45                    position: position.clone(),
46                }),
47            },
48            _ => Ok(self.clone()),
49        }
50    }
51
52    pub fn is_bool(&self) -> bool {
53        match self {
54            Type::Bool => true,
55            _ => false,
56        }
57    }
58
59    pub fn is_int(&self) -> bool {
60        match self {
61            Type::Int => true,
62            _ => false,
63        }
64    }
65
66    pub fn is_real(&self) -> bool {
67        match self {
68            Type::Real => true,
69            _ => false,
70        }
71    }
72
73    pub fn is_interval(&self) -> bool {
74        match self {
75            Type::Bool => true,
76            _ => false,
77        }
78    }
79
80    pub fn is_undefined(&self) -> bool {
81        match self {
82            Type::Undefined => true,
83            _ => false,
84        }
85    }
86
87    pub fn is_integer(&self) -> bool {
88        match self {
89            Type::Int => true,
90            Type::Interval(_, _) => true,
91            _ => false,
92        }
93    }
94
95    pub fn is_number(&self) -> bool {
96        match self {
97            Type::Int => true,
98            Type::Real => true,
99            Type::Interval(_, _) => true,
100            _ => false,
101        }
102    }
103
104    pub fn is_structure(&self) -> bool {
105        match self {
106            Type::Structure(_) => true,
107            _ => false,
108        }
109    }
110
111    pub fn is_class(&self) -> bool {
112        match self {
113            Type::Class(_) => true,
114            _ => false,
115        }
116    }
117
118    pub fn class(&self) -> Option<ClassId> {
119        match self {
120            Type::Class(id) => Some(*id),
121            _ => None,
122        }
123    }
124
125    pub fn is_compatible_with(&self, problem: &Problem, other: &Self) -> bool {
126        match (self, other) {
127            // TODO: check
128            // (Type::Interval(min1, max1), Type::Interval(min2, max2)) => {
129            //     (max1 >= min2 && max1 <= max2) || (min1 >= min2 && min1 <= max2)
130            // }
131            (Type::Interval(_, _), Type::Interval(_, _)) => true,
132            (Type::Interval(_, _), Type::Int) => true,
133            (Type::Int, Type::Interval(_, _)) => true,
134            (x, y) => x.is_subtype_of(problem, y) || y.is_subtype_of(problem, x),
135        }
136    }
137
138    pub fn is_subtype_of(&self, problem: &Problem, other: &Self) -> bool {
139        match (self, other) {
140            // TODO: Int :<: Real ??? and Interval :<: Real
141            (Type::Interval(min1, max1), Type::Interval(min2, max2)) => {
142                min1 >= min2 && max1 <= max2
143            }
144            (Type::Interval(_, _), Type::Int) => true,
145            (Type::Class(i1), Type::Class(i2)) => {
146                if i1 == i2 {
147                    true
148                } else {
149                    let c1 = problem.get(*i1).unwrap();
150                    c1.super_classes(problem).contains(i2)
151                }
152            }
153            (x, y) => x == y,
154        }
155    }
156
157    pub fn common_type(&self, problem: &Problem, other: &Self) -> Type {
158        if self == other {
159            self.clone()
160        } else {
161            match (self, other) {
162                (Type::Interval(_, _), Type::Int) => Type::Int,
163                (Type::Int, Type::Interval(_, _)) => Type::Int,
164                (Type::Interval(min1, max1), Type::Interval(min2, max2)) => {
165                    Type::Interval(*min1.min(min2), *max1.max(max2))
166                }
167                (Type::Class(i1), Type::Class(i2)) => {
168                    let c1 = problem.get(*i1).unwrap();
169                    match c1.common_class(problem, *i2) {
170                        Some(id) => Type::Class(id),
171                        _ => Type::Undefined,
172                    }
173                }
174                _ => Type::Undefined,
175            }
176        }
177    }
178
179    pub fn check_interval(
180        &self,
181        problem: &Problem,
182        position: &Option<Position>,
183    ) -> Result<(), Error> {
184        match self {
185            Type::Interval(min, max) => {
186                if min > max {
187                    Err(Error::Interval {
188                        name: self.to_lang(problem),
189                        position: position.clone(),
190                    })
191                } else {
192                    Ok(())
193                }
194            }
195            _ => Ok(()),
196        }
197    }
198
199    pub fn all(&self, problem: &Problem) -> Vec<Expr> {
200        match self {
201            Type::Structure(id) => problem
202                .get(*id)
203                .unwrap()
204                .instances(problem)
205                .iter()
206                .map(|i| Expr::Instance(*i, None))
207                .collect(),
208            Type::Class(id) => problem
209                .get(*id)
210                .unwrap()
211                .all_instances(problem)
212                .iter()
213                .map(|i| Expr::Instance(*i, None))
214                .collect(),
215            Type::Bool => vec![Expr::BoolValue(false, None), Expr::BoolValue(true, None)],
216            Type::Interval(min, max) => (*min..=*max)
217                .into_iter()
218                .map(|i| Expr::IntValue(i, None))
219                .collect(),
220            _ => vec![],
221        }
222    }
223}
//------------------------- To Lang -------------------------
225
226impl ToLang for Type {
227    fn to_lang(&self, problem: &Problem) -> String {
228        match self {
229            Type::Bool => "Bool".into(),
230            Type::Int => "Int".into(),
231            Type::Real => "Real".into(),
232            Type::Interval(min, max) => format!("{}..{}", min, max),
233            Type::Structure(id) => problem.get(*id).unwrap().name().to_string(),
234            Type::Class(id) => problem.get(*id).unwrap().name().to_string(),
235            //
236            Type::Unresolved(name, _) => format!("{}?", name),
237            Type::Undefined => "?".into(),
238            // Type::Error => "type_error".into(),
239        }
240    }
241}