Skip to main content

sage_parser/
ty.rs

//! Type expressions for the Sage language.
2
3use crate::span::Ident;
4use std::fmt;
5
6/// A type expression as it appears in source code.
7#[derive(Debug, Clone, PartialEq, Eq, Hash)]
8pub enum TypeExpr {
9    /// 64-bit signed integer.
10    Int,
11    /// 64-bit IEEE 754 floating point.
12    Float,
13    /// Boolean.
14    Bool,
15    /// UTF-8 string.
16    String,
17    /// Unit type (void equivalent).
18    Unit,
19    /// Error type for error handling (has `.message` and `.kind` fields).
20    Error,
21    /// Homogeneous list: `List<T>`.
22    List(Box<TypeExpr>),
23    /// Optional value: `Option<T>`.
24    Option(Box<TypeExpr>),
25    /// LLM oracle result: `Oracle<T>`.
26    Oracle(Box<TypeExpr>),
27    /// Agent handle: `Agent<AgentName>`.
28    Agent(Ident),
29    /// Named type with optional type arguments (e.g., `Point`, `Pair<Int, String>`).
30    /// The Vec is empty for non-generic types.
31    Named(Ident, Vec<TypeExpr>),
32    /// Function type: `Fn(A, B) -> C`.
33    /// The Vec holds parameter types; the Box holds the return type.
34    Fn(Vec<TypeExpr>, Box<TypeExpr>),
35    /// Map type: `Map<K, V>`.
36    Map(Box<TypeExpr>, Box<TypeExpr>),
37    /// Tuple type: `(A, B, C)`.
38    Tuple(Vec<TypeExpr>),
39    /// Result type: `Result<T, E>`.
40    Result(Box<TypeExpr>, Box<TypeExpr>),
41}
42
43impl TypeExpr {
44    /// Check if this is a primitive type.
45    #[must_use]
46    pub fn is_primitive(&self) -> bool {
47        matches!(
48            self,
49            TypeExpr::Int
50                | TypeExpr::Float
51                | TypeExpr::Bool
52                | TypeExpr::String
53                | TypeExpr::Unit
54                | TypeExpr::Error
55        )
56    }
57
58    /// Check if this is a compound type.
59    #[must_use]
60    pub fn is_compound(&self) -> bool {
61        matches!(
62            self,
63            TypeExpr::List(_)
64                | TypeExpr::Option(_)
65                | TypeExpr::Oracle(_)
66                | TypeExpr::Agent(_)
67                | TypeExpr::Fn(_, _)
68                | TypeExpr::Map(_, _)
69                | TypeExpr::Tuple(_)
70                | TypeExpr::Result(_, _)
71        )
72    }
73
74    /// Get the inner type for generic types, if any.
75    #[must_use]
76    pub fn inner_type(&self) -> Option<&TypeExpr> {
77        match self {
78            TypeExpr::List(inner) | TypeExpr::Option(inner) | TypeExpr::Oracle(inner) => {
79                Some(inner)
80            }
81            _ => None,
82        }
83    }
84}
85
86impl fmt::Display for TypeExpr {
87    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88        match self {
89            TypeExpr::Int => write!(f, "Int"),
90            TypeExpr::Float => write!(f, "Float"),
91            TypeExpr::Bool => write!(f, "Bool"),
92            TypeExpr::String => write!(f, "String"),
93            TypeExpr::Unit => write!(f, "Unit"),
94            TypeExpr::Error => write!(f, "Error"),
95            TypeExpr::List(inner) => write!(f, "List<{inner}>"),
96            TypeExpr::Option(inner) => write!(f, "Option<{inner}>"),
97            TypeExpr::Oracle(inner) => write!(f, "Oracle<{inner}>"),
98            TypeExpr::Agent(name) => write!(f, "Agent<{name}>"),
99            TypeExpr::Named(name, type_args) => {
100                write!(f, "{name}")?;
101                if !type_args.is_empty() {
102                    write!(f, "<")?;
103                    for (i, arg) in type_args.iter().enumerate() {
104                        if i > 0 {
105                            write!(f, ", ")?;
106                        }
107                        write!(f, "{arg}")?;
108                    }
109                    write!(f, ">")?;
110                }
111                Ok(())
112            }
113            TypeExpr::Fn(params, ret) => {
114                write!(f, "Fn(")?;
115                for (i, param) in params.iter().enumerate() {
116                    if i > 0 {
117                        write!(f, ", ")?;
118                    }
119                    write!(f, "{param}")?;
120                }
121                write!(f, ") -> {ret}")
122            }
123            TypeExpr::Map(key, value) => write!(f, "Map<{key}, {value}>"),
124            TypeExpr::Tuple(elems) => {
125                write!(f, "(")?;
126                for (i, elem) in elems.iter().enumerate() {
127                    if i > 0 {
128                        write!(f, ", ")?;
129                    }
130                    write!(f, "{elem}")?;
131                }
132                write!(f, ")")
133            }
134            TypeExpr::Result(ok, err) => write!(f, "Result<{ok}, {err}>"),
135        }
136    }
137}
138
#[cfg(test)]
mod tests {
    //! Unit tests for `TypeExpr`: `Display` rendering, classification
    //! (`is_primitive` / `is_compound`), `inner_type`, and equality.

    use super::*;

    #[test]
    fn primitive_display() {
        assert_eq!(format!("{}", TypeExpr::Int), "Int");
        assert_eq!(format!("{}", TypeExpr::Float), "Float");
        assert_eq!(format!("{}", TypeExpr::Bool), "Bool");
        assert_eq!(format!("{}", TypeExpr::String), "String");
        assert_eq!(format!("{}", TypeExpr::Unit), "Unit");
        assert_eq!(format!("{}", TypeExpr::Error), "Error");
    }

    #[test]
    fn compound_display() {
        let list_str = TypeExpr::List(Box::new(TypeExpr::String));
        assert_eq!(format!("{list_str}"), "List<String>");

        let option_int = TypeExpr::Option(Box::new(TypeExpr::Int));
        assert_eq!(format!("{option_int}"), "Option<Int>");

        let oracle_str = TypeExpr::Oracle(Box::new(TypeExpr::String));
        assert_eq!(format!("{oracle_str}"), "Oracle<String>");

        let agent = TypeExpr::Agent(Ident::dummy("Researcher"));
        assert_eq!(format!("{agent}"), "Agent<Researcher>");
    }

    #[test]
    fn nested_compound_display() {
        // List<List<Int>>
        let nested = TypeExpr::List(Box::new(TypeExpr::List(Box::new(TypeExpr::Int))));
        assert_eq!(format!("{nested}"), "List<List<Int>>");

        // Option<List<String>>
        let nested = TypeExpr::Option(Box::new(TypeExpr::List(Box::new(TypeExpr::String))));
        assert_eq!(format!("{nested}"), "Option<List<String>>");
    }

    #[test]
    fn named_type_display() {
        // Non-generic named types print as just the name.
        let point = TypeExpr::Named(Ident::dummy("Point"), vec![]);
        assert_eq!(format!("{point}"), "Point");

        // Generic named types list their arguments in angle brackets.
        let pair = TypeExpr::Named(Ident::dummy("Pair"), vec![TypeExpr::Int, TypeExpr::String]);
        assert_eq!(format!("{pair}"), "Pair<Int, String>");
    }

    #[test]
    fn map_tuple_result_display() {
        let map = TypeExpr::Map(
            Box::new(TypeExpr::String),
            Box::new(TypeExpr::List(Box::new(TypeExpr::Int))),
        );
        assert_eq!(format!("{map}"), "Map<String, List<Int>>");

        let tuple = TypeExpr::Tuple(vec![TypeExpr::Int, TypeExpr::Bool, TypeExpr::String]);
        assert_eq!(format!("{tuple}"), "(Int, Bool, String)");

        let empty_tuple = TypeExpr::Tuple(vec![]);
        assert_eq!(format!("{empty_tuple}"), "()");

        let result = TypeExpr::Result(Box::new(TypeExpr::Int), Box::new(TypeExpr::Error));
        assert_eq!(format!("{result}"), "Result<Int, Error>");
    }

    #[test]
    fn fn_type_display() {
        // Fn(Int) -> Int
        let fn_type = TypeExpr::Fn(vec![TypeExpr::Int], Box::new(TypeExpr::Int));
        assert_eq!(format!("{fn_type}"), "Fn(Int) -> Int");

        // Fn(String, Int) -> Bool
        let fn_type = TypeExpr::Fn(
            vec![TypeExpr::String, TypeExpr::Int],
            Box::new(TypeExpr::Bool),
        );
        assert_eq!(format!("{fn_type}"), "Fn(String, Int) -> Bool");

        // Fn() -> String (no parameters)
        let fn_type = TypeExpr::Fn(vec![], Box::new(TypeExpr::String));
        assert_eq!(format!("{fn_type}"), "Fn() -> String");

        // Higher-order: Fn(Int) -> Fn(Int) -> Int
        let inner_fn = TypeExpr::Fn(vec![TypeExpr::Int], Box::new(TypeExpr::Int));
        let outer_fn = TypeExpr::Fn(vec![TypeExpr::Int], Box::new(inner_fn));
        assert_eq!(format!("{outer_fn}"), "Fn(Int) -> Fn(Int) -> Int");
    }

    #[test]
    fn fn_type_is_compound() {
        let fn_type = TypeExpr::Fn(vec![TypeExpr::Int], Box::new(TypeExpr::Bool));
        assert!(fn_type.is_compound());
        assert!(!fn_type.is_primitive());
    }

    #[test]
    fn is_primitive() {
        assert!(TypeExpr::Int.is_primitive());
        assert!(TypeExpr::Float.is_primitive());
        assert!(TypeExpr::Bool.is_primitive());
        assert!(TypeExpr::String.is_primitive());
        assert!(TypeExpr::Unit.is_primitive());
        assert!(TypeExpr::Error.is_primitive());

        assert!(!TypeExpr::List(Box::new(TypeExpr::Int)).is_primitive());
        assert!(!TypeExpr::Option(Box::new(TypeExpr::Int)).is_primitive());
    }

    #[test]
    fn is_compound() {
        assert!(!TypeExpr::Int.is_compound());
        assert!(!TypeExpr::Error.is_compound());

        assert!(TypeExpr::List(Box::new(TypeExpr::Int)).is_compound());
        assert!(TypeExpr::Option(Box::new(TypeExpr::Int)).is_compound());
        assert!(TypeExpr::Oracle(Box::new(TypeExpr::String)).is_compound());
        assert!(TypeExpr::Agent(Ident::dummy("Foo")).is_compound());
        assert!(TypeExpr::Map(Box::new(TypeExpr::String), Box::new(TypeExpr::Int)).is_compound());
        assert!(TypeExpr::Tuple(vec![TypeExpr::Int, TypeExpr::Bool]).is_compound());
        assert!(
            TypeExpr::Result(Box::new(TypeExpr::Int), Box::new(TypeExpr::Error)).is_compound()
        );
    }

    #[test]
    fn named_type_is_neither_primitive_nor_compound() {
        // Pins the current classification: user-named types fall outside both groups.
        let point = TypeExpr::Named(Ident::dummy("Point"), vec![]);
        assert!(!point.is_primitive());
        assert!(!point.is_compound());

        let pair = TypeExpr::Named(Ident::dummy("Pair"), vec![TypeExpr::Int, TypeExpr::String]);
        assert!(!pair.is_primitive());
        assert!(!pair.is_compound());
    }

    #[test]
    fn inner_type() {
        let list = TypeExpr::List(Box::new(TypeExpr::String));
        assert_eq!(list.inner_type(), Some(&TypeExpr::String));

        let option = TypeExpr::Option(Box::new(TypeExpr::Int));
        assert_eq!(option.inner_type(), Some(&TypeExpr::Int));

        let oracle = TypeExpr::Oracle(Box::new(TypeExpr::Bool));
        assert_eq!(oracle.inner_type(), Some(&TypeExpr::Bool));

        assert_eq!(TypeExpr::Int.inner_type(), None);
    }

    #[test]
    fn inner_type_is_none_for_multi_part_types() {
        let map = TypeExpr::Map(Box::new(TypeExpr::String), Box::new(TypeExpr::Int));
        assert_eq!(map.inner_type(), None);

        let result = TypeExpr::Result(Box::new(TypeExpr::Int), Box::new(TypeExpr::Error));
        assert_eq!(result.inner_type(), None);

        let pair = TypeExpr::Named(Ident::dummy("Pair"), vec![TypeExpr::Int, TypeExpr::String]);
        assert_eq!(pair.inner_type(), None);

        assert_eq!(TypeExpr::Agent(Ident::dummy("Foo")).inner_type(), None);
    }

    #[test]
    fn equality() {
        assert_eq!(TypeExpr::Int, TypeExpr::Int);
        assert_ne!(TypeExpr::Int, TypeExpr::Float);

        let list1 = TypeExpr::List(Box::new(TypeExpr::String));
        let list2 = TypeExpr::List(Box::new(TypeExpr::String));
        let list3 = TypeExpr::List(Box::new(TypeExpr::Int));

        assert_eq!(list1, list2);
        assert_ne!(list1, list3);
    }
}