Skip to main content

specl_types/
types.rs

1//! Type representation for the Specl type system.
2
3use std::collections::BTreeMap;
4use std::fmt;
5
6/// A Specl type.
7#[derive(Debug, Clone, PartialEq, Eq, Hash)]
8pub enum Type {
9    /// Boolean type.
10    Bool,
11    /// Natural number (non-negative integer).
12    Nat,
13    /// Integer.
14    Int,
15    /// String.
16    String,
17    /// Set type `Set[T]`.
18    Set(Box<Type>),
19    /// Sequence type `Seq[T]`.
20    Seq(Box<Type>),
21    /// Dict type `dict[K, V]` (finite map).
22    Fn(Box<Type>, Box<Type>),
23    /// Option type `Option[T]`.
24    Option(Box<Type>),
25    /// Record type with named fields.
26    Record(RecordType),
27    /// Tuple type `(T1, T2, ...)`.
28    Tuple(Vec<Type>),
29    /// Finite range type `lo..hi`.
30    Range(i64, i64),
31    /// Named type (reference to a type alias).
32    Named(String),
33    /// Type variable (for inference).
34    Var(TypeVar),
35    /// Error type (used for error recovery).
36    Error,
37}
38
39impl Type {
40    /// Check if this is a numeric type (Nat, Int, or Range).
41    pub fn is_numeric(&self) -> bool {
42        matches!(self, Type::Nat | Type::Int | Type::Range(_, _))
43    }
44
45    /// Check if this is a collection type (Set, Seq, or Fn).
46    pub fn is_collection(&self) -> bool {
47        matches!(self, Type::Set(_) | Type::Seq(_) | Type::Fn(_, _))
48    }
49
50    /// Check if this type contains any type variables.
51    pub fn has_vars(&self) -> bool {
52        match self {
53            Type::Var(_) => true,
54            Type::Set(t) | Type::Seq(t) | Type::Option(t) => t.has_vars(),
55            Type::Fn(k, v) => k.has_vars() || v.has_vars(),
56            Type::Record(r) => r.fields.values().any(|t| t.has_vars()),
57            Type::Tuple(elems) => elems.iter().any(|t| t.has_vars()),
58            _ => false,
59        }
60    }
61
62    /// Substitute type variables according to a substitution.
63    pub fn substitute(&self, subst: &Substitution) -> Type {
64        match self {
65            Type::Var(v) => subst.get(v).cloned().unwrap_or_else(|| self.clone()),
66            Type::Set(t) => Type::Set(Box::new(t.substitute(subst))),
67            Type::Seq(t) => Type::Seq(Box::new(t.substitute(subst))),
68            Type::Option(t) => Type::Option(Box::new(t.substitute(subst))),
69            Type::Fn(k, v) => {
70                Type::Fn(Box::new(k.substitute(subst)), Box::new(v.substitute(subst)))
71            }
72            Type::Record(r) => Type::Record(RecordType {
73                fields: r
74                    .fields
75                    .iter()
76                    .map(|(k, v)| (k.clone(), v.substitute(subst)))
77                    .collect(),
78            }),
79            Type::Tuple(elems) => Type::Tuple(elems.iter().map(|t| t.substitute(subst)).collect()),
80            _ => self.clone(),
81        }
82    }
83}
84
85impl fmt::Display for Type {
86    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87        match self {
88            Type::Bool => write!(f, "Bool"),
89            Type::Nat => write!(f, "Nat"),
90            Type::Int => write!(f, "Int"),
91            Type::String => write!(f, "String"),
92            Type::Set(t) => write!(f, "Set[{}]", t),
93            Type::Seq(t) => write!(f, "Seq[{}]", t),
94            Type::Fn(k, v) => write!(f, "dict[{}, {}]", k, v),
95            Type::Option(t) => write!(f, "Option[{}]", t),
96            Type::Record(r) => {
97                write!(f, "Record {{ ")?;
98                for (i, (name, ty)) in r.fields.iter().enumerate() {
99                    if i > 0 {
100                        write!(f, ", ")?;
101                    }
102                    write!(f, "{}: {}", name, ty)?;
103                }
104                write!(f, " }}")
105            }
106            Type::Tuple(elems) => {
107                write!(f, "(")?;
108                for (i, ty) in elems.iter().enumerate() {
109                    if i > 0 {
110                        write!(f, ", ")?;
111                    }
112                    write!(f, "{}", ty)?;
113                }
114                write!(f, ")")
115            }
116            Type::Range(lo, hi) => write!(f, "{}..{}", lo, hi),
117            Type::Named(name) => write!(f, "{}", name),
118            Type::Var(v) => write!(f, "?{}", v.0),
119            Type::Error => write!(f, "<error>"),
120        }
121    }
122}
123
124/// A record type with named fields.
125#[derive(Debug, Clone, PartialEq, Eq, Hash)]
126pub struct RecordType {
127    /// Ordered map of field names to types.
128    pub fields: BTreeMap<String, Type>,
129}
130
131impl RecordType {
132    /// Create a new empty record type.
133    pub fn new() -> Self {
134        Self {
135            fields: BTreeMap::new(),
136        }
137    }
138
139    /// Create a record type from field definitions.
140    pub fn from_fields(fields: impl IntoIterator<Item = (String, Type)>) -> Self {
141        Self {
142            fields: fields.into_iter().collect(),
143        }
144    }
145
146    /// Get the type of a field.
147    pub fn get_field(&self, name: &str) -> Option<&Type> {
148        self.fields.get(name)
149    }
150}
151
152impl Default for RecordType {
153    fn default() -> Self {
154        Self::new()
155    }
156}
157
/// Identifier of a type variable introduced during type inference.
///
/// The wrapped number is the variable's id; `Type`'s `Display` impl renders it as `?id`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct TypeVar(pub u32);

impl TypeVar {
    /// Wraps the raw number `id` as a type variable.
    ///
    /// No freshness check is performed; a `TypeVarGen` is the way to obtain unused ids.
    pub fn new(id: u32) -> Self {
        TypeVar(id)
    }
}
168
169/// A type substitution mapping type variables to types.
170#[derive(Debug, Clone, Default)]
171pub struct Substitution {
172    mappings: BTreeMap<TypeVar, Type>,
173}
174
175impl Substitution {
176    /// Create an empty substitution.
177    pub fn new() -> Self {
178        Self {
179            mappings: BTreeMap::new(),
180        }
181    }
182
183    /// Get the type bound to a variable, if any.
184    pub fn get(&self, var: &TypeVar) -> Option<&Type> {
185        self.mappings.get(var)
186    }
187
188    /// Bind a type variable to a type.
189    pub fn insert(&mut self, var: TypeVar, ty: Type) {
190        self.mappings.insert(var, ty);
191    }
192
193    /// Compose two substitutions: self then other.
194    pub fn compose(&self, other: &Substitution) -> Substitution {
195        let mut result = Substitution::new();
196
197        // Apply other to all types in self
198        for (var, ty) in &self.mappings {
199            result.insert(*var, ty.substitute(other));
200        }
201
202        // Add bindings from other that aren't in self
203        for (var, ty) in &other.mappings {
204            if !result.mappings.contains_key(var) {
205                result.insert(*var, ty.clone());
206            }
207        }
208
209        result
210    }
211
212    /// Check if the substitution is empty.
213    pub fn is_empty(&self) -> bool {
214        self.mappings.is_empty()
215    }
216}
217
218/// Type variable generator for fresh variables.
219#[derive(Debug, Clone, Default)]
220pub struct TypeVarGen {
221    next_id: u32,
222}
223
224impl TypeVarGen {
225    /// Create a new generator.
226    pub fn new() -> Self {
227        Self { next_id: 0 }
228    }
229
230    /// Generate a fresh type variable.
231    pub fn fresh(&mut self) -> TypeVar {
232        let var = TypeVar(self.next_id);
233        self.next_id += 1;
234        var
235    }
236
237    /// Generate a fresh type variable wrapped in a Type.
238    pub fn fresh_type(&mut self) -> Type {
239        Type::Var(self.fresh())
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_type_display() {
        assert_eq!(Type::Bool.to_string(), "Bool");
        assert_eq!(Type::Set(Box::new(Type::Nat)).to_string(), "Set[Nat]");
        assert_eq!(
            Type::Fn(Box::new(Type::String), Box::new(Type::Int)).to_string(),
            "dict[String, Int]"
        );
    }

    #[test]
    fn test_compound_type_display() {
        // Fields are supplied out of order; RecordType keeps them sorted by name.
        let rec =
            RecordType::from_fields([("y".to_string(), Type::Bool), ("x".to_string(), Type::Nat)]);
        assert_eq!(Type::Record(rec).to_string(), "Record { x: Nat, y: Bool }");
        assert_eq!(Type::Tuple(vec![Type::Int, Type::Bool]).to_string(), "(Int, Bool)");
        assert_eq!(Type::Range(0, 5).to_string(), "0..5");
        assert_eq!(
            Type::Option(Box::new(Type::Seq(Box::new(Type::Nat)))).to_string(),
            "Option[Seq[Nat]]"
        );
        assert_eq!(Type::Named("Pid".to_string()).to_string(), "Pid");
        assert_eq!(Type::Var(TypeVar(3)).to_string(), "?3");
        assert_eq!(Type::Error.to_string(), "<error>");
    }

    #[test]
    fn test_type_predicates() {
        assert!(Type::Nat.is_numeric());
        assert!(Type::Range(1, 3).is_numeric());
        assert!(!Type::Bool.is_numeric());
        assert!(Type::Seq(Box::new(Type::Int)).is_collection());
        assert!(Type::Fn(Box::new(Type::Nat), Box::new(Type::Nat)).is_collection());
        assert!(!Type::Option(Box::new(Type::Nat)).is_collection());
    }

    #[test]
    fn test_type_has_vars() {
        let mut var_gen = TypeVarGen::new();
        assert!(!Type::Bool.has_vars());
        assert!(Type::Var(var_gen.fresh()).has_vars());
        assert!(Type::Set(Box::new(Type::Var(var_gen.fresh()))).has_vars());
    }

    #[test]
    fn test_has_vars_nested() {
        let mut var_gen = TypeVarGen::new();
        let dict = Type::Fn(
            Box::new(Type::Nat),
            Box::new(Type::Option(Box::new(var_gen.fresh_type()))),
        );
        assert!(dict.has_vars());
        assert!(Type::Tuple(vec![Type::Bool, var_gen.fresh_type()]).has_vars());
        assert!(!Type::Tuple(vec![Type::Bool, Type::Range(0, 3)]).has_vars());
        let rec = RecordType::from_fields([("f".to_string(), var_gen.fresh_type())]);
        assert!(Type::Record(rec).has_vars());
    }

    #[test]
    fn test_substitution() {
        let mut var_gen = TypeVarGen::new();
        let v1 = var_gen.fresh();
        let v2 = var_gen.fresh();

        let mut subst = Substitution::new();
        subst.insert(v1, Type::Nat);

        assert_eq!(Type::Var(v1).substitute(&subst), Type::Nat);
        assert_eq!(Type::Var(v2).substitute(&subst), Type::Var(v2));
    }

    #[test]
    fn test_substitution_nested() {
        let mut var_gen = TypeVarGen::new();
        let a = var_gen.fresh();
        let b = var_gen.fresh();

        let mut subst = Substitution::new();
        subst.insert(a, Type::Int);

        let ty = Type::Tuple(vec![Type::Set(Box::new(Type::Var(a))), Type::Var(b)]);
        assert_eq!(
            ty.substitute(&subst),
            Type::Tuple(vec![Type::Set(Box::new(Type::Int)), Type::Var(b)])
        );
    }

    #[test]
    fn test_substitution_compose() {
        let mut var_gen = TypeVarGen::new();
        let a = var_gen.fresh();
        let b = var_gen.fresh();
        let c = var_gen.fresh();

        let mut first = Substitution::new();
        first.insert(a, Type::Set(Box::new(Type::Var(b))));

        let mut second = Substitution::new();
        second.insert(b, Type::Nat);
        second.insert(c, Type::Bool);
        // Shadowed in the composition: `first` already binds `a`.
        second.insert(a, Type::Int);

        let composed = first.compose(&second);
        assert_eq!(composed.get(&a), Some(&Type::Set(Box::new(Type::Nat))));
        assert_eq!(composed.get(&b), Some(&Type::Nat));
        assert_eq!(composed.get(&c), Some(&Type::Bool));

        // Applying the composition equals applying `first` and then `second`.
        let ty = Type::Tuple(vec![Type::Var(a), Type::Var(b)]);
        assert_eq!(ty.substitute(&composed), ty.substitute(&first).substitute(&second));

        assert!(Substitution::new().is_empty());
        assert!(!composed.is_empty());
    }

    #[test]
    fn test_type_var_gen_is_sequential() {
        let mut var_gen = TypeVarGen::new();
        assert_eq!(var_gen.fresh(), TypeVar(0));
        assert_eq!(var_gen.fresh(), TypeVar::new(1));
        assert_eq!(var_gen.fresh_type(), Type::Var(TypeVar(2)));
    }

    #[test]
    fn test_record_type() {
        let rec =
            RecordType::from_fields([("x".to_string(), Type::Nat), ("y".to_string(), Type::Bool)]);
        assert_eq!(rec.get_field("x"), Some(&Type::Nat));
        assert_eq!(rec.get_field("z"), None);
    }
}
285}