Skip to main content

sage_parser/
ty.rs

//! Type expressions for the Sage language.
2
3use crate::span::Ident;
4use std::fmt;
5
6/// A type expression as it appears in source code.
7#[derive(Debug, Clone, PartialEq, Eq, Hash)]
8pub enum TypeExpr {
9    /// 64-bit signed integer.
10    Int,
11    /// 64-bit IEEE 754 floating point.
12    Float,
13    /// Boolean.
14    Bool,
15    /// UTF-8 string.
16    String,
17    /// Unit type (void equivalent).
18    Unit,
19    /// Error type for error handling (has `.message` and `.kind` fields).
20    Error,
21    /// Homogeneous list: `List<T>`.
22    List(Box<TypeExpr>),
23    /// Optional value: `Option<T>`.
24    Option(Box<TypeExpr>),
25    /// LLM inference result: `Inferred<T>`.
26    Inferred(Box<TypeExpr>),
27    /// Agent handle: `Agent<AgentName>`.
28    Agent(Ident),
29    /// Named type (agent name or future user-defined types).
30    Named(Ident),
31    /// Function type: `Fn(A, B) -> C`.
32    /// The Vec holds parameter types; the Box holds the return type.
33    Fn(Vec<TypeExpr>, Box<TypeExpr>),
34    /// Map type: `Map<K, V>`.
35    Map(Box<TypeExpr>, Box<TypeExpr>),
36    /// Tuple type: `(A, B, C)`.
37    Tuple(Vec<TypeExpr>),
38    /// Result type: `Result<T, E>`.
39    Result(Box<TypeExpr>, Box<TypeExpr>),
40}
41
42impl TypeExpr {
43    /// Check if this is a primitive type.
44    #[must_use]
45    pub fn is_primitive(&self) -> bool {
46        matches!(
47            self,
48            TypeExpr::Int
49                | TypeExpr::Float
50                | TypeExpr::Bool
51                | TypeExpr::String
52                | TypeExpr::Unit
53                | TypeExpr::Error
54        )
55    }
56
57    /// Check if this is a compound type.
58    #[must_use]
59    pub fn is_compound(&self) -> bool {
60        matches!(
61            self,
62            TypeExpr::List(_)
63                | TypeExpr::Option(_)
64                | TypeExpr::Inferred(_)
65                | TypeExpr::Agent(_)
66                | TypeExpr::Fn(_, _)
67                | TypeExpr::Map(_, _)
68                | TypeExpr::Tuple(_)
69                | TypeExpr::Result(_, _)
70        )
71    }
72
73    /// Get the inner type for generic types, if any.
74    #[must_use]
75    pub fn inner_type(&self) -> Option<&TypeExpr> {
76        match self {
77            TypeExpr::List(inner) | TypeExpr::Option(inner) | TypeExpr::Inferred(inner) => {
78                Some(inner)
79            }
80            _ => None,
81        }
82    }
83}
84
85impl fmt::Display for TypeExpr {
86    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87        match self {
88            TypeExpr::Int => write!(f, "Int"),
89            TypeExpr::Float => write!(f, "Float"),
90            TypeExpr::Bool => write!(f, "Bool"),
91            TypeExpr::String => write!(f, "String"),
92            TypeExpr::Unit => write!(f, "Unit"),
93            TypeExpr::Error => write!(f, "Error"),
94            TypeExpr::List(inner) => write!(f, "List<{inner}>"),
95            TypeExpr::Option(inner) => write!(f, "Option<{inner}>"),
96            TypeExpr::Inferred(inner) => write!(f, "Inferred<{inner}>"),
97            TypeExpr::Agent(name) => write!(f, "Agent<{name}>"),
98            TypeExpr::Named(name) => write!(f, "{name}"),
99            TypeExpr::Fn(params, ret) => {
100                write!(f, "Fn(")?;
101                for (i, param) in params.iter().enumerate() {
102                    if i > 0 {
103                        write!(f, ", ")?;
104                    }
105                    write!(f, "{param}")?;
106                }
107                write!(f, ") -> {ret}")
108            }
109            TypeExpr::Map(key, value) => write!(f, "Map<{key}, {value}>"),
110            TypeExpr::Tuple(elems) => {
111                write!(f, "(")?;
112                for (i, elem) in elems.iter().enumerate() {
113                    if i > 0 {
114                        write!(f, ", ")?;
115                    }
116                    write!(f, "{elem}")?;
117                }
118                write!(f, ")")
119            }
120            TypeExpr::Result(ok, err) => write!(f, "Result<{ok}, {err}>"),
121        }
122    }
123}
124
/// Unit tests for `TypeExpr`: display output, primitive/compound
/// classification, `inner_type`, and derived equality.
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for boxing a nested type expression.
    fn boxed(ty: TypeExpr) -> Box<TypeExpr> {
        Box::new(ty)
    }

    #[test]
    fn primitive_display() {
        assert_eq!(TypeExpr::Int.to_string(), "Int");
        assert_eq!(TypeExpr::Float.to_string(), "Float");
        assert_eq!(TypeExpr::Bool.to_string(), "Bool");
        assert_eq!(TypeExpr::String.to_string(), "String");
        assert_eq!(TypeExpr::Unit.to_string(), "Unit");
        assert_eq!(TypeExpr::Error.to_string(), "Error");
    }

    #[test]
    fn compound_display() {
        let list_str = TypeExpr::List(boxed(TypeExpr::String));
        assert_eq!(list_str.to_string(), "List<String>");

        let option_int = TypeExpr::Option(boxed(TypeExpr::Int));
        assert_eq!(option_int.to_string(), "Option<Int>");

        let inferred_str = TypeExpr::Inferred(boxed(TypeExpr::String));
        assert_eq!(inferred_str.to_string(), "Inferred<String>");

        let agent = TypeExpr::Agent(Ident::dummy("Researcher"));
        assert_eq!(agent.to_string(), "Agent<Researcher>");
    }

    #[test]
    fn named_display() {
        // A named type prints as the bare identifier, with no wrapper.
        let named = TypeExpr::Named(Ident::dummy("Researcher"));
        assert_eq!(named.to_string(), "Researcher");
    }

    #[test]
    fn nested_compound_display() {
        // List<List<Int>>
        let nested = TypeExpr::List(boxed(TypeExpr::List(boxed(TypeExpr::Int))));
        assert_eq!(nested.to_string(), "List<List<Int>>");

        // Option<List<String>>
        let nested = TypeExpr::Option(boxed(TypeExpr::List(boxed(TypeExpr::String))));
        assert_eq!(nested.to_string(), "Option<List<String>>");
    }

    #[test]
    fn fn_type_display() {
        // Fn(Int) -> Int
        let fn_type = TypeExpr::Fn(vec![TypeExpr::Int], boxed(TypeExpr::Int));
        assert_eq!(fn_type.to_string(), "Fn(Int) -> Int");

        // Fn(String, Int) -> Bool
        let fn_type = TypeExpr::Fn(
            vec![TypeExpr::String, TypeExpr::Int],
            boxed(TypeExpr::Bool),
        );
        assert_eq!(fn_type.to_string(), "Fn(String, Int) -> Bool");

        // Fn() -> String (no parameters)
        let fn_type = TypeExpr::Fn(vec![], boxed(TypeExpr::String));
        assert_eq!(fn_type.to_string(), "Fn() -> String");

        // Higher-order: Fn(Int) -> Fn(Int) -> Int
        let inner_fn = TypeExpr::Fn(vec![TypeExpr::Int], boxed(TypeExpr::Int));
        let outer_fn = TypeExpr::Fn(vec![TypeExpr::Int], boxed(inner_fn));
        assert_eq!(outer_fn.to_string(), "Fn(Int) -> Fn(Int) -> Int");
    }

    #[test]
    fn map_tuple_result_display() {
        let map = TypeExpr::Map(boxed(TypeExpr::String), boxed(TypeExpr::Int));
        assert_eq!(map.to_string(), "Map<String, Int>");

        // Empty and multi-element tuples.
        assert_eq!(TypeExpr::Tuple(vec![]).to_string(), "()");
        let tuple = TypeExpr::Tuple(vec![TypeExpr::Int, TypeExpr::String, TypeExpr::Bool]);
        assert_eq!(tuple.to_string(), "(Int, String, Bool)");

        let result = TypeExpr::Result(boxed(TypeExpr::Int), boxed(TypeExpr::Error));
        assert_eq!(result.to_string(), "Result<Int, Error>");
    }

    #[test]
    fn fn_type_is_compound() {
        let fn_type = TypeExpr::Fn(vec![TypeExpr::Int], boxed(TypeExpr::Bool));
        assert!(fn_type.is_compound());
        assert!(!fn_type.is_primitive());
    }

    #[test]
    fn is_primitive() {
        assert!(TypeExpr::Int.is_primitive());
        assert!(TypeExpr::Float.is_primitive());
        assert!(TypeExpr::Bool.is_primitive());
        assert!(TypeExpr::String.is_primitive());
        assert!(TypeExpr::Unit.is_primitive());
        assert!(TypeExpr::Error.is_primitive());

        assert!(!TypeExpr::List(boxed(TypeExpr::Int)).is_primitive());
        assert!(!TypeExpr::Option(boxed(TypeExpr::Int)).is_primitive());
        assert!(!TypeExpr::Named(Ident::dummy("Foo")).is_primitive());
    }

    #[test]
    fn is_compound() {
        assert!(!TypeExpr::Int.is_compound());
        assert!(!TypeExpr::Error.is_compound());

        assert!(TypeExpr::List(boxed(TypeExpr::Int)).is_compound());
        assert!(TypeExpr::Option(boxed(TypeExpr::Int)).is_compound());
        assert!(TypeExpr::Inferred(boxed(TypeExpr::String)).is_compound());
        assert!(TypeExpr::Agent(Ident::dummy("Foo")).is_compound());
        assert!(TypeExpr::Map(boxed(TypeExpr::String), boxed(TypeExpr::Int)).is_compound());
        assert!(TypeExpr::Tuple(vec![TypeExpr::Int, TypeExpr::Bool]).is_compound());
        assert!(TypeExpr::Result(boxed(TypeExpr::Int), boxed(TypeExpr::Error)).is_compound());

        // `Named` is neither primitive nor compound.
        assert!(!TypeExpr::Named(Ident::dummy("Foo")).is_compound());
    }

    #[test]
    fn inner_type() {
        let list = TypeExpr::List(boxed(TypeExpr::String));
        assert_eq!(list.inner_type(), Some(&TypeExpr::String));

        let option = TypeExpr::Option(boxed(TypeExpr::Int));
        assert_eq!(option.inner_type(), Some(&TypeExpr::Int));

        let inferred = TypeExpr::Inferred(boxed(TypeExpr::Bool));
        assert_eq!(inferred.inner_type(), Some(&TypeExpr::Bool));

        assert_eq!(TypeExpr::Int.inner_type(), None);

        // Multi-argument and identifier-carrying variants have no single inner type.
        let map = TypeExpr::Map(boxed(TypeExpr::String), boxed(TypeExpr::Int));
        assert!(map.inner_type().is_none());
        let result = TypeExpr::Result(boxed(TypeExpr::Int), boxed(TypeExpr::Error));
        assert!(result.inner_type().is_none());
        assert!(TypeExpr::Tuple(vec![TypeExpr::Int]).inner_type().is_none());
        assert!(TypeExpr::Agent(Ident::dummy("Foo")).inner_type().is_none());
    }

    #[test]
    fn equality() {
        assert_eq!(TypeExpr::Int, TypeExpr::Int);
        assert_ne!(TypeExpr::Int, TypeExpr::Float);

        let list1 = TypeExpr::List(boxed(TypeExpr::String));
        let list2 = TypeExpr::List(boxed(TypeExpr::String));
        let list3 = TypeExpr::List(boxed(TypeExpr::Int));

        assert_eq!(list1, list2);
        assert_ne!(list1, list3);

        // Same payload under a different variant is a different type.
        assert_ne!(list1, TypeExpr::Option(boxed(TypeExpr::String)));
    }
}