1use crate::span::Ident;
4use std::fmt;
5
/// A type expression: the syntactic description of a type.
///
/// Variants without a payload are the built-in types that
/// [`TypeExpr::is_primitive`] reports as primitive. The remaining variants
/// carry the type expressions (and/or identifiers) they are built from.
/// The `Display` implementation in this file renders each variant in the
/// form shown in the per-variant docs below.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TypeExpr {
    /// The built-in `Int` type.
    Int,
    /// The built-in `Float` type.
    Float,
    /// The built-in `Bool` type.
    Bool,
    /// The built-in `String` type.
    String,
    /// The built-in `Unit` type.
    Unit,
    /// The built-in `Error` type.
    ///
    /// NOTE(review): whether this is a user-visible error type or a
    /// placeholder for a failed type resolution is not visible here; confirm.
    Error,
    /// A list with the given element type; displayed as `List<T>`.
    List(Box<TypeExpr>),
    /// An optional value of the given inner type; displayed as `Option<T>`.
    Option(Box<TypeExpr>),
    /// An oracle wrapping the given inner type; displayed as `Oracle<T>`.
    ///
    /// NOTE(review): the runtime meaning of an oracle is defined elsewhere.
    Oracle(Box<TypeExpr>),
    /// An agent type identified by name; displayed as `Agent<Name>`.
    Agent(Ident),
    /// A named type with zero or more type arguments; displayed as `Name`
    /// when there are no arguments and `Name<A, B>` otherwise.
    Named(Ident, Vec<TypeExpr>),
    /// A function type: parameter types followed by the return type;
    /// displayed as `Fn(A, B) -> R`.
    Fn(Vec<TypeExpr>, Box<TypeExpr>),
    /// A map type: key type followed by value type; displayed as `Map<K, V>`.
    Map(Box<TypeExpr>, Box<TypeExpr>),
    /// A tuple of element types; displayed as `(A, B)`.
    Tuple(Vec<TypeExpr>),
    /// A result type: success type followed by error type; displayed as
    /// `Result<T, E>`.
    Result(Box<TypeExpr>, Box<TypeExpr>),
}
42
43impl TypeExpr {
44 #[must_use]
46 pub fn is_primitive(&self) -> bool {
47 matches!(
48 self,
49 TypeExpr::Int
50 | TypeExpr::Float
51 | TypeExpr::Bool
52 | TypeExpr::String
53 | TypeExpr::Unit
54 | TypeExpr::Error
55 )
56 }
57
58 #[must_use]
60 pub fn is_compound(&self) -> bool {
61 matches!(
62 self,
63 TypeExpr::List(_)
64 | TypeExpr::Option(_)
65 | TypeExpr::Oracle(_)
66 | TypeExpr::Agent(_)
67 | TypeExpr::Fn(_, _)
68 | TypeExpr::Map(_, _)
69 | TypeExpr::Tuple(_)
70 | TypeExpr::Result(_, _)
71 )
72 }
73
74 #[must_use]
76 pub fn inner_type(&self) -> Option<&TypeExpr> {
77 match self {
78 TypeExpr::List(inner) | TypeExpr::Option(inner) | TypeExpr::Oracle(inner) => {
79 Some(inner)
80 }
81 _ => None,
82 }
83 }
84}
85
86impl fmt::Display for TypeExpr {
87 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88 match self {
89 TypeExpr::Int => write!(f, "Int"),
90 TypeExpr::Float => write!(f, "Float"),
91 TypeExpr::Bool => write!(f, "Bool"),
92 TypeExpr::String => write!(f, "String"),
93 TypeExpr::Unit => write!(f, "Unit"),
94 TypeExpr::Error => write!(f, "Error"),
95 TypeExpr::List(inner) => write!(f, "List<{inner}>"),
96 TypeExpr::Option(inner) => write!(f, "Option<{inner}>"),
97 TypeExpr::Oracle(inner) => write!(f, "Oracle<{inner}>"),
98 TypeExpr::Agent(name) => write!(f, "Agent<{name}>"),
99 TypeExpr::Named(name, type_args) => {
100 write!(f, "{name}")?;
101 if !type_args.is_empty() {
102 write!(f, "<")?;
103 for (i, arg) in type_args.iter().enumerate() {
104 if i > 0 {
105 write!(f, ", ")?;
106 }
107 write!(f, "{arg}")?;
108 }
109 write!(f, ">")?;
110 }
111 Ok(())
112 }
113 TypeExpr::Fn(params, ret) => {
114 write!(f, "Fn(")?;
115 for (i, param) in params.iter().enumerate() {
116 if i > 0 {
117 write!(f, ", ")?;
118 }
119 write!(f, "{param}")?;
120 }
121 write!(f, ") -> {ret}")
122 }
123 TypeExpr::Map(key, value) => write!(f, "Map<{key}, {value}>"),
124 TypeExpr::Tuple(elems) => {
125 write!(f, "(")?;
126 for (i, elem) in elems.iter().enumerate() {
127 if i > 0 {
128 write!(f, ", ")?;
129 }
130 write!(f, "{elem}")?;
131 }
132 write!(f, ")")
133 }
134 TypeExpr::Result(ok, err) => write!(f, "Result<{ok}, {err}>"),
135 }
136 }
137}
138
/// Unit tests for `TypeExpr`: rendering via `Display`, the
/// `is_primitive` / `is_compound` classification, `inner_type`, and the
/// derived equality.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn primitive_display() {
        assert_eq!(TypeExpr::Int.to_string(), "Int");
        assert_eq!(TypeExpr::Float.to_string(), "Float");
        assert_eq!(TypeExpr::Bool.to_string(), "Bool");
        assert_eq!(TypeExpr::String.to_string(), "String");
        assert_eq!(TypeExpr::Unit.to_string(), "Unit");
        assert_eq!(TypeExpr::Error.to_string(), "Error");
    }

    #[test]
    fn compound_display() {
        let list_str = TypeExpr::List(Box::new(TypeExpr::String));
        assert_eq!(list_str.to_string(), "List<String>");

        let option_int = TypeExpr::Option(Box::new(TypeExpr::Int));
        assert_eq!(option_int.to_string(), "Option<Int>");

        let oracle_str = TypeExpr::Oracle(Box::new(TypeExpr::String));
        assert_eq!(oracle_str.to_string(), "Oracle<String>");

        let agent = TypeExpr::Agent(Ident::dummy("Researcher"));
        assert_eq!(agent.to_string(), "Agent<Researcher>");
    }

    #[test]
    fn nested_compound_display() {
        let nested = TypeExpr::List(Box::new(TypeExpr::List(Box::new(TypeExpr::Int))));
        assert_eq!(nested.to_string(), "List<List<Int>>");

        let nested = TypeExpr::Option(Box::new(TypeExpr::List(Box::new(TypeExpr::String))));
        assert_eq!(nested.to_string(), "Option<List<String>>");
    }

    #[test]
    fn map_and_result_display() {
        let map = TypeExpr::Map(Box::new(TypeExpr::String), Box::new(TypeExpr::Int));
        assert_eq!(map.to_string(), "Map<String, Int>");

        let result = TypeExpr::Result(Box::new(TypeExpr::Int), Box::new(TypeExpr::Error));
        assert_eq!(result.to_string(), "Result<Int, Error>");
    }

    #[test]
    fn tuple_display() {
        // Zero, one and many elements exercise the separator logic.
        assert_eq!(TypeExpr::Tuple(vec![]).to_string(), "()");
        assert_eq!(TypeExpr::Tuple(vec![TypeExpr::Int]).to_string(), "(Int)");

        let pair = TypeExpr::Tuple(vec![TypeExpr::Int, TypeExpr::String]);
        assert_eq!(pair.to_string(), "(Int, String)");
    }

    #[test]
    fn named_display() {
        // Without type arguments the angle brackets are omitted.
        let plain = TypeExpr::Named(Ident::dummy("Foo"), vec![]);
        assert_eq!(plain.to_string(), "Foo");

        let single = TypeExpr::Named(Ident::dummy("Wrapper"), vec![TypeExpr::Int]);
        assert_eq!(single.to_string(), "Wrapper<Int>");

        let generic = TypeExpr::Named(
            Ident::dummy("Pair"),
            vec![TypeExpr::Int, TypeExpr::List(Box::new(TypeExpr::String))],
        );
        assert_eq!(generic.to_string(), "Pair<Int, List<String>>");
    }

    #[test]
    fn fn_type_display() {
        let fn_type = TypeExpr::Fn(vec![TypeExpr::Int], Box::new(TypeExpr::Int));
        assert_eq!(fn_type.to_string(), "Fn(Int) -> Int");

        let fn_type = TypeExpr::Fn(
            vec![TypeExpr::String, TypeExpr::Int],
            Box::new(TypeExpr::Bool),
        );
        assert_eq!(fn_type.to_string(), "Fn(String, Int) -> Bool");

        let fn_type = TypeExpr::Fn(vec![], Box::new(TypeExpr::String));
        assert_eq!(fn_type.to_string(), "Fn() -> String");

        let inner_fn = TypeExpr::Fn(vec![TypeExpr::Int], Box::new(TypeExpr::Int));
        let outer_fn = TypeExpr::Fn(vec![TypeExpr::Int], Box::new(inner_fn));
        assert_eq!(outer_fn.to_string(), "Fn(Int) -> Fn(Int) -> Int");
    }

    #[test]
    fn fn_type_is_compound() {
        let fn_type = TypeExpr::Fn(vec![TypeExpr::Int], Box::new(TypeExpr::Bool));
        assert!(fn_type.is_compound());
        assert!(!fn_type.is_primitive());
    }

    #[test]
    fn is_primitive() {
        assert!(TypeExpr::Int.is_primitive());
        assert!(TypeExpr::Float.is_primitive());
        assert!(TypeExpr::Bool.is_primitive());
        assert!(TypeExpr::String.is_primitive());
        assert!(TypeExpr::Unit.is_primitive());
        assert!(TypeExpr::Error.is_primitive());

        assert!(!TypeExpr::List(Box::new(TypeExpr::Int)).is_primitive());
        assert!(!TypeExpr::Option(Box::new(TypeExpr::Int)).is_primitive());
        assert!(!TypeExpr::Tuple(vec![]).is_primitive());
        assert!(!TypeExpr::Named(Ident::dummy("Foo"), vec![]).is_primitive());
    }

    #[test]
    fn is_compound() {
        assert!(!TypeExpr::Int.is_compound());
        assert!(!TypeExpr::Error.is_compound());

        assert!(TypeExpr::List(Box::new(TypeExpr::Int)).is_compound());
        assert!(TypeExpr::Option(Box::new(TypeExpr::Int)).is_compound());
        assert!(TypeExpr::Oracle(Box::new(TypeExpr::String)).is_compound());
        assert!(TypeExpr::Agent(Ident::dummy("Foo")).is_compound());
        assert!(TypeExpr::Map(Box::new(TypeExpr::String), Box::new(TypeExpr::Int)).is_compound());
        assert!(TypeExpr::Tuple(vec![TypeExpr::Int]).is_compound());
        assert!(TypeExpr::Result(Box::new(TypeExpr::Int), Box::new(TypeExpr::Error)).is_compound());
    }

    #[test]
    fn inner_type() {
        let list = TypeExpr::List(Box::new(TypeExpr::String));
        assert_eq!(list.inner_type(), Some(&TypeExpr::String));

        let option = TypeExpr::Option(Box::new(TypeExpr::Int));
        assert_eq!(option.inner_type(), Some(&TypeExpr::Int));

        let oracle = TypeExpr::Oracle(Box::new(TypeExpr::Bool));
        assert_eq!(oracle.inner_type(), Some(&TypeExpr::Bool));

        assert_eq!(TypeExpr::Int.inner_type(), None);

        // Variants with more than one component have no single inner type.
        let map = TypeExpr::Map(Box::new(TypeExpr::String), Box::new(TypeExpr::Int));
        assert_eq!(map.inner_type(), None);

        let func = TypeExpr::Fn(vec![TypeExpr::Int], Box::new(TypeExpr::Int));
        assert_eq!(func.inner_type(), None);

        assert_eq!(TypeExpr::Tuple(vec![TypeExpr::Int]).inner_type(), None);
    }

    #[test]
    fn equality() {
        assert_eq!(TypeExpr::Int, TypeExpr::Int);
        assert_ne!(TypeExpr::Int, TypeExpr::Float);

        let list1 = TypeExpr::List(Box::new(TypeExpr::String));
        let list2 = TypeExpr::List(Box::new(TypeExpr::String));
        let list3 = TypeExpr::List(Box::new(TypeExpr::Int));

        assert_eq!(list1, list2);
        assert_ne!(list1, list3);

        // Same payload under a different wrapper must not compare equal.
        let option = TypeExpr::Option(Box::new(TypeExpr::String));
        assert_ne!(list1, option);
    }
}