1use std::collections::BTreeMap;
4use std::fmt;
5
/// A type in the checker's type language.
///
/// Compound variants own their component types (boxed or in a `Vec`/`RecordType`), so values
/// can be cloned freely and compared structurally via the derived `PartialEq`/`Eq`/`Hash`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Type {
    /// Boolean type; displayed as `Bool`.
    Bool,
    /// `Nat` type; counted as numeric by `Type::is_numeric`.
    Nat,
    /// `Int` type; counted as numeric by `Type::is_numeric`.
    Int,
    /// String type; displayed as `String`.
    String,
    /// Set whose elements have the boxed type; displayed as `Set[T]`.
    Set(Box<Type>),
    /// Sequence whose elements have the boxed type; displayed as `Seq[T]`.
    Seq(Box<Type>),
    /// Map-like type from the first (key) type to the second (value) type.
    /// Displayed as `dict[K, V]` and treated as a collection by `Type::is_collection`.
    Fn(Box<Type>, Box<Type>),
    /// Optional value of the boxed type; displayed as `Option[T]`.
    Option(Box<Type>),
    /// Record type with named, typed fields.
    Record(RecordType),
    /// Fixed-length tuple of element types; displayed as `(A, B, ...)`.
    Tuple(Vec<Type>),
    /// Integer range with lower and upper bounds; displayed as `lo..hi` and counted as numeric.
    /// NOTE(review): whether the upper bound is inclusive is not established here — confirm.
    Range(i64, i64),
    /// Type referred to by name; displayed verbatim.
    /// Presumably resolved against declarations elsewhere — confirm.
    Named(String),
    /// Type variable; replaced by `Type::substitute` when the `Substitution` binds it.
    /// Displayed as `?<id>`.
    Var(TypeVar),
    /// Placeholder type displayed as `<error>`.
    /// Presumably produced after a type error so checking can continue — confirm.
    Error,
}
38
39impl Type {
40 pub fn is_numeric(&self) -> bool {
42 matches!(self, Type::Nat | Type::Int | Type::Range(_, _))
43 }
44
45 pub fn is_collection(&self) -> bool {
47 matches!(self, Type::Set(_) | Type::Seq(_) | Type::Fn(_, _))
48 }
49
50 pub fn has_vars(&self) -> bool {
52 match self {
53 Type::Var(_) => true,
54 Type::Set(t) | Type::Seq(t) | Type::Option(t) => t.has_vars(),
55 Type::Fn(k, v) => k.has_vars() || v.has_vars(),
56 Type::Record(r) => r.fields.values().any(|t| t.has_vars()),
57 Type::Tuple(elems) => elems.iter().any(|t| t.has_vars()),
58 _ => false,
59 }
60 }
61
62 pub fn substitute(&self, subst: &Substitution) -> Type {
64 match self {
65 Type::Var(v) => subst.get(v).cloned().unwrap_or_else(|| self.clone()),
66 Type::Set(t) => Type::Set(Box::new(t.substitute(subst))),
67 Type::Seq(t) => Type::Seq(Box::new(t.substitute(subst))),
68 Type::Option(t) => Type::Option(Box::new(t.substitute(subst))),
69 Type::Fn(k, v) => {
70 Type::Fn(Box::new(k.substitute(subst)), Box::new(v.substitute(subst)))
71 }
72 Type::Record(r) => Type::Record(RecordType {
73 fields: r
74 .fields
75 .iter()
76 .map(|(k, v)| (k.clone(), v.substitute(subst)))
77 .collect(),
78 }),
79 Type::Tuple(elems) => Type::Tuple(elems.iter().map(|t| t.substitute(subst)).collect()),
80 _ => self.clone(),
81 }
82 }
83}
84
85impl fmt::Display for Type {
86 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87 match self {
88 Type::Bool => write!(f, "Bool"),
89 Type::Nat => write!(f, "Nat"),
90 Type::Int => write!(f, "Int"),
91 Type::String => write!(f, "String"),
92 Type::Set(t) => write!(f, "Set[{}]", t),
93 Type::Seq(t) => write!(f, "Seq[{}]", t),
94 Type::Fn(k, v) => write!(f, "dict[{}, {}]", k, v),
95 Type::Option(t) => write!(f, "Option[{}]", t),
96 Type::Record(r) => {
97 write!(f, "Record {{ ")?;
98 for (i, (name, ty)) in r.fields.iter().enumerate() {
99 if i > 0 {
100 write!(f, ", ")?;
101 }
102 write!(f, "{}: {}", name, ty)?;
103 }
104 write!(f, " }}")
105 }
106 Type::Tuple(elems) => {
107 write!(f, "(")?;
108 for (i, ty) in elems.iter().enumerate() {
109 if i > 0 {
110 write!(f, ", ")?;
111 }
112 write!(f, "{}", ty)?;
113 }
114 write!(f, ")")
115 }
116 Type::Range(lo, hi) => write!(f, "{}..{}", lo, hi),
117 Type::Named(name) => write!(f, "{}", name),
118 Type::Var(v) => write!(f, "?{}", v.0),
119 Type::Error => write!(f, "<error>"),
120 }
121 }
122}
123
124#[derive(Debug, Clone, PartialEq, Eq, Hash)]
126pub struct RecordType {
127 pub fields: BTreeMap<String, Type>,
129}
130
131impl RecordType {
132 pub fn new() -> Self {
134 Self {
135 fields: BTreeMap::new(),
136 }
137 }
138
139 pub fn from_fields(fields: impl IntoIterator<Item = (String, Type)>) -> Self {
141 Self {
142 fields: fields.into_iter().collect(),
143 }
144 }
145
146 pub fn get_field(&self, name: &str) -> Option<&Type> {
148 self.fields.get(name)
149 }
150}
151
152impl Default for RecordType {
153 fn default() -> Self {
154 Self::new()
155 }
156}
157
/// A type variable, identified by a numeric id.
///
/// `Ord` is derived so type variables can key the ordered map inside `Substitution`.
/// Fresh, distinct ids are handed out by `TypeVarGen`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct TypeVar(pub u32);

impl TypeVar {
    /// Wraps a raw id; equivalent to the tuple constructor `TypeVar(id)`.
    pub fn new(id: u32) -> Self {
        TypeVar(id)
    }
}
168
169#[derive(Debug, Clone, Default)]
171pub struct Substitution {
172 mappings: BTreeMap<TypeVar, Type>,
173}
174
175impl Substitution {
176 pub fn new() -> Self {
178 Self {
179 mappings: BTreeMap::new(),
180 }
181 }
182
183 pub fn get(&self, var: &TypeVar) -> Option<&Type> {
185 self.mappings.get(var)
186 }
187
188 pub fn insert(&mut self, var: TypeVar, ty: Type) {
190 self.mappings.insert(var, ty);
191 }
192
193 pub fn compose(&self, other: &Substitution) -> Substitution {
195 let mut result = Substitution::new();
196
197 for (var, ty) in &self.mappings {
199 result.insert(*var, ty.substitute(other));
200 }
201
202 for (var, ty) in &other.mappings {
204 if !result.mappings.contains_key(var) {
205 result.insert(*var, ty.clone());
206 }
207 }
208
209 result
210 }
211
212 pub fn is_empty(&self) -> bool {
214 self.mappings.is_empty()
215 }
216}
217
218#[derive(Debug, Clone, Default)]
220pub struct TypeVarGen {
221 next_id: u32,
222}
223
224impl TypeVarGen {
225 pub fn new() -> Self {
227 Self { next_id: 0 }
228 }
229
230 pub fn fresh(&mut self) -> TypeVar {
232 let var = TypeVar(self.next_id);
233 self.next_id += 1;
234 var
235 }
236
237 pub fn fresh_type(&mut self) -> Type {
239 Type::Var(self.fresh())
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_type_display() {
        assert_eq!(Type::Bool.to_string(), "Bool");
        assert_eq!(Type::Set(Box::new(Type::Nat)).to_string(), "Set[Nat]");
        assert_eq!(
            Type::Fn(Box::new(Type::String), Box::new(Type::Int)).to_string(),
            "dict[String, Int]"
        );
    }

    #[test]
    fn test_compound_type_display() {
        let rec =
            RecordType::from_fields([("y".to_string(), Type::Bool), ("x".to_string(), Type::Nat)]);
        // Fields print in BTreeMap (sorted-name) order, not insertion order.
        assert_eq!(Type::Record(rec).to_string(), "Record { x: Nat, y: Bool }");
        assert_eq!(Type::Tuple(vec![Type::Nat, Type::Bool]).to_string(), "(Nat, Bool)");
        assert_eq!(Type::Option(Box::new(Type::Int)).to_string(), "Option[Int]");
        assert_eq!(Type::Range(0, 10).to_string(), "0..10");
        assert_eq!(Type::Var(TypeVar::new(3)).to_string(), "?3");
        assert_eq!(Type::Error.to_string(), "<error>");
    }

    #[test]
    fn test_numeric_and_collection() {
        assert!(Type::Nat.is_numeric());
        assert!(Type::Range(1, 5).is_numeric());
        assert!(!Type::Bool.is_numeric());
        assert!(Type::Seq(Box::new(Type::Nat)).is_collection());
        assert!(!Type::Option(Box::new(Type::Nat)).is_collection());
    }

    #[test]
    fn test_type_has_vars() {
        // `gen` is a reserved keyword in Rust 2024, so the generator is named `vars`.
        let mut vars = TypeVarGen::new();
        assert!(!Type::Bool.has_vars());
        assert!(Type::Var(vars.fresh()).has_vars());
        assert!(Type::Set(Box::new(Type::Var(vars.fresh()))).has_vars());
        assert!(Type::Fn(Box::new(Type::Nat), Box::new(vars.fresh_type())).has_vars());
        let rec = RecordType::from_fields([("f".to_string(), vars.fresh_type())]);
        assert!(Type::Record(rec).has_vars());
        assert!(!Type::Tuple(vec![Type::Nat, Type::Bool]).has_vars());
    }

    #[test]
    fn test_substitution() {
        let mut vars = TypeVarGen::new();
        let v1 = vars.fresh();
        let v2 = vars.fresh();

        let mut subst = Substitution::new();
        subst.insert(v1, Type::Nat);

        assert_eq!(Type::Var(v1).substitute(&subst), Type::Nat);
        assert_eq!(Type::Var(v2).substitute(&subst), Type::Var(v2));

        // Substitution descends into tuples and records.
        let rec = RecordType::from_fields([("f".to_string(), Type::Var(v1))]);
        let ty = Type::Tuple(vec![Type::Var(v1), Type::Record(rec)]);
        let expected = Type::Tuple(vec![
            Type::Nat,
            Type::Record(RecordType::from_fields([("f".to_string(), Type::Nat)])),
        ]);
        assert_eq!(ty.substitute(&subst), expected);
    }

    #[test]
    fn test_compose_applies_self_then_other() {
        let mut vars = TypeVarGen::new();
        let a = vars.fresh();
        let b = vars.fresh();

        let mut first = Substitution::new();
        first.insert(a, Type::Seq(Box::new(Type::Var(b))));
        let mut second = Substitution::new();
        second.insert(b, Type::Int);

        let both = first.compose(&second);
        let ty = Type::Tuple(vec![Type::Var(a), Type::Var(b)]);
        assert_eq!(ty.substitute(&both), ty.substitute(&first).substitute(&second));
        assert_eq!(
            ty.substitute(&both),
            Type::Tuple(vec![Type::Seq(Box::new(Type::Int)), Type::Int])
        );
        assert!(Substitution::new().is_empty());
        assert!(!both.is_empty());
    }

    #[test]
    fn test_fresh_vars_are_distinct() {
        let mut vars = TypeVarGen::new();
        let a = vars.fresh();
        let b = vars.fresh();
        assert_ne!(a, b);
        assert_eq!(vars.fresh_type(), Type::Var(TypeVar::new(2)));
    }

    #[test]
    fn test_record_type() {
        let rec =
            RecordType::from_fields([("x".to_string(), Type::Nat), ("y".to_string(), Type::Bool)]);
        assert_eq!(rec.get_field("x"), Some(&Type::Nat));
        assert_eq!(rec.get_field("z"), None);
    }
}