1use crate::types::{StackType, Type};
9use std::collections::HashMap;
10
11pub type TypeSubst = HashMap<String, Type>;
13
14pub type RowSubst = HashMap<String, StackType>;
16
17#[derive(Debug, Clone, PartialEq)]
19pub struct Subst {
20 pub types: TypeSubst,
21 pub rows: RowSubst,
22}
23
24impl Subst {
25 pub fn empty() -> Self {
27 Subst {
28 types: HashMap::new(),
29 rows: HashMap::new(),
30 }
31 }
32
33 pub fn apply_type(&self, ty: &Type) -> Type {
35 match ty {
36 Type::Var(name) => self.types.get(name).cloned().unwrap_or(ty.clone()),
37 _ => ty.clone(),
38 }
39 }
40
41 pub fn apply_stack(&self, stack: &StackType) -> StackType {
43 match stack {
44 StackType::Empty => StackType::Empty,
45 StackType::Cons { rest, top } => {
46 let new_rest = self.apply_stack(rest);
47 let new_top = self.apply_type(top);
48 StackType::Cons {
49 rest: Box::new(new_rest),
50 top: new_top,
51 }
52 }
53 StackType::RowVar(name) => self.rows.get(name).cloned().unwrap_or(stack.clone()),
54 }
55 }
56
57 pub fn compose(&self, other: &Subst) -> Subst {
60 let mut types = HashMap::new();
61 let mut rows = HashMap::new();
62
63 for (k, v) in &self.types {
65 types.insert(k.clone(), other.apply_type(v));
66 }
67
68 for (k, v) in &other.types {
70 let v_subst = self.apply_type(v);
71 types.insert(k.clone(), v_subst);
72 }
73
74 for (k, v) in &self.rows {
76 rows.insert(k.clone(), other.apply_stack(v));
77 }
78
79 for (k, v) in &other.rows {
81 let v_subst = self.apply_stack(v);
82 rows.insert(k.clone(), v_subst);
83 }
84
85 Subst { types, rows }
86 }
87}
88
89fn occurs_in_type(var: &str, ty: &Type) -> bool {
103 match ty {
104 Type::Var(name) => name == var,
105 Type::Int | Type::Float | Type::Bool | Type::String => false,
106 Type::Quotation(effect) => {
107 occurs_in_stack(var, &effect.inputs) || occurs_in_stack(var, &effect.outputs)
109 }
110 Type::Closure { effect, captures } => {
111 occurs_in_stack(var, &effect.inputs)
113 || occurs_in_stack(var, &effect.outputs)
114 || captures.iter().any(|t| occurs_in_type(var, t))
115 }
116 }
117}
118
119fn occurs_in_stack(var: &str, stack: &StackType) -> bool {
121 match stack {
122 StackType::Empty => false,
123 StackType::RowVar(name) => name == var,
124 StackType::Cons { rest, top: _ } => {
125 occurs_in_stack(var, rest)
128 }
129 }
130}
131
132pub fn unify_types(t1: &Type, t2: &Type) -> Result<Subst, String> {
134 match (t1, t2) {
135 (Type::Int, Type::Int)
137 | (Type::Float, Type::Float)
138 | (Type::Bool, Type::Bool)
139 | (Type::String, Type::String) => Ok(Subst::empty()),
140
141 (Type::Var(name), ty) | (ty, Type::Var(name)) => {
143 if matches!(ty, Type::Var(ty_name) if ty_name == name) {
145 return Ok(Subst::empty());
146 }
147
148 if occurs_in_type(name, ty) {
150 return Err(format!(
151 "Occurs check failed: cannot unify {:?} with {:?} (would create infinite type)",
152 Type::Var(name.clone()),
153 ty
154 ));
155 }
156
157 let mut subst = Subst::empty();
158 subst.types.insert(name.clone(), ty.clone());
159 Ok(subst)
160 }
161
162 (Type::Quotation(effect1), Type::Quotation(effect2)) => {
164 let s_in = unify_stacks(&effect1.inputs, &effect2.inputs)?;
166
167 let out1 = s_in.apply_stack(&effect1.outputs);
169 let out2 = s_in.apply_stack(&effect2.outputs);
170 let s_out = unify_stacks(&out1, &out2)?;
171
172 Ok(s_in.compose(&s_out))
174 }
175
176 (
180 Type::Closure {
181 effect: effect1, ..
182 },
183 Type::Closure {
184 effect: effect2, ..
185 },
186 ) => {
187 let s_in = unify_stacks(&effect1.inputs, &effect2.inputs)?;
189
190 let out1 = s_in.apply_stack(&effect1.outputs);
192 let out2 = s_in.apply_stack(&effect2.outputs);
193 let s_out = unify_stacks(&out1, &out2)?;
194
195 Ok(s_in.compose(&s_out))
197 }
198
199 (Type::Quotation(quot_effect), Type::Closure { effect, .. })
203 | (Type::Closure { effect, .. }, Type::Quotation(quot_effect)) => {
204 let s_in = unify_stacks("_effect.inputs, &effect.inputs)?;
206
207 let out1 = s_in.apply_stack("_effect.outputs);
209 let out2 = s_in.apply_stack(&effect.outputs);
210 let s_out = unify_stacks(&out1, &out2)?;
211
212 Ok(s_in.compose(&s_out))
214 }
215
216 _ => Err(format!(
218 "Type mismatch: cannot unify {:?} with {:?}",
219 t1, t2
220 )),
221 }
222}
223
224pub fn unify_stacks(s1: &StackType, s2: &StackType) -> Result<Subst, String> {
226 match (s1, s2) {
227 (StackType::Empty, StackType::Empty) => Ok(Subst::empty()),
229
230 (StackType::RowVar(name), stack) | (stack, StackType::RowVar(name)) => {
232 if matches!(stack, StackType::RowVar(stack_name) if stack_name == name) {
234 return Ok(Subst::empty());
235 }
236
237 if occurs_in_stack(name, stack) {
239 return Err(format!(
240 "Occurs check failed: cannot unify {:?} with {:?} (would create infinite stack type)",
241 StackType::RowVar(name.clone()),
242 stack
243 ));
244 }
245
246 let mut subst = Subst::empty();
247 subst.rows.insert(name.clone(), stack.clone());
248 Ok(subst)
249 }
250
251 (
253 StackType::Cons {
254 rest: rest1,
255 top: top1,
256 },
257 StackType::Cons {
258 rest: rest2,
259 top: top2,
260 },
261 ) => {
262 let s_top = unify_types(top1, top2)?;
264
265 let rest1_subst = s_top.apply_stack(rest1);
267 let rest2_subst = s_top.apply_stack(rest2);
268 let s_rest = unify_stacks(&rest1_subst, &rest2_subst)?;
269
270 Ok(s_top.compose(&s_rest))
272 }
273
274 _ => Err(format!(
276 "Stack shape mismatch: cannot unify {:?} with {:?}",
277 s1, s2
278 )),
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a type variable with the given name.
    fn tvar(name: &str) -> Type {
        Type::Var(name.to_string())
    }

    /// Builds a bare row variable with the given name.
    fn row(name: &str) -> StackType {
        StackType::RowVar(name.to_string())
    }

    #[test]
    fn test_unify_concrete_types() {
        // Each primitive unifies with itself...
        for ty in [Type::Int, Type::Bool, Type::String] {
            assert!(unify_types(&ty, &ty).is_ok());
        }
        // ...but distinct primitives do not unify.
        assert!(unify_types(&Type::Int, &Type::Bool).is_err());
    }

    #[test]
    fn test_unify_type_variable() {
        // The variable may appear on either side.
        let left = unify_types(&tvar("T"), &Type::Int).unwrap();
        assert_eq!(left.types.get("T"), Some(&Type::Int));

        let right = unify_types(&Type::Bool, &tvar("U")).unwrap();
        assert_eq!(right.types.get("U"), Some(&Type::Bool));
    }

    #[test]
    fn test_unify_empty_stacks() {
        assert!(unify_stacks(&StackType::Empty, &StackType::Empty).is_ok());
    }

    #[test]
    fn test_unify_row_variable() {
        let int_stack = StackType::singleton(Type::Int);
        let subst = unify_stacks(&row("a"), &int_stack).unwrap();
        assert_eq!(subst.rows.get("a"), Some(&int_stack));
    }

    #[test]
    fn test_unify_cons_stacks() {
        let lhs = StackType::singleton(Type::Int);
        let rhs = StackType::singleton(Type::Int);
        assert!(unify_stacks(&lhs, &rhs).is_ok());
    }

    #[test]
    fn test_unify_cons_with_type_var() {
        let subst = unify_stacks(
            &StackType::singleton(tvar("T")),
            &StackType::singleton(Type::Int),
        )
        .unwrap();
        assert_eq!(subst.types.get("T"), Some(&Type::Int));
    }

    #[test]
    fn test_unify_row_poly_stack() {
        // `..a Int` against `Bool Int`: the row absorbs everything below the top.
        let poly = row("a").push(Type::Int);
        let concrete = StackType::Empty.push(Type::Bool).push(Type::Int);

        let subst = unify_stacks(&poly, &concrete).unwrap();
        assert_eq!(subst.rows.get("a"), Some(&StackType::singleton(Type::Bool)));
    }

    #[test]
    fn test_unify_polymorphic_dup() {
        // Declared input `..a T` matched against a stack holding one Int.
        let declared_in = row("a").push(tvar("T"));
        let actual_in = StackType::singleton(Type::Int);

        let subst = unify_stacks(&declared_in, &actual_in).unwrap();
        assert_eq!(subst.rows.get("a"), Some(&StackType::Empty));
        assert_eq!(subst.types.get("T"), Some(&Type::Int));

        // The declared output `..a T T` then specializes to `Int Int`.
        let declared_out = row("a").push(tvar("T")).push(tvar("T"));
        assert_eq!(
            subst.apply_stack(&declared_out),
            StackType::Empty.push(Type::Int).push(Type::Int)
        );
    }

    #[test]
    fn test_subst_compose() {
        let mut first = Subst::empty();
        first.types.insert("T".to_string(), Type::Int);

        let mut second = Subst::empty();
        second.types.insert("U".to_string(), tvar("T"));

        let composed = first.compose(&second);
        assert_eq!(composed.types.get("T"), Some(&Type::Int));
        assert_eq!(composed.types.get("U"), Some(&Type::Int));
    }

    #[test]
    fn test_occurs_check_type_var_with_itself() {
        let subst = unify_types(&tvar("T"), &tvar("T")).expect("T should unify with itself");
        assert!(subst.types.is_empty());
    }

    #[test]
    fn test_occurs_check_row_var_with_itself() {
        let subst = unify_stacks(&row("a"), &row("a")).expect("..a should unify with itself");
        assert!(subst.rows.is_empty());
    }

    #[test]
    fn test_occurs_check_prevents_infinite_stack() {
        // Binding `a := ..a Int` would make `a` contain itself.
        let err = unify_stacks(&row("a"), &row("a").push(Type::Int)).unwrap_err();
        assert!(err.contains("Occurs check failed"));
        assert!(err.contains("infinite"));
    }

    #[test]
    fn test_occurs_check_allows_different_row_vars() {
        let subst = unify_stacks(&row("a"), &row("b")).expect("distinct rows should unify");
        assert_eq!(subst.rows.get("a"), Some(&row("b")));
    }

    #[test]
    fn test_occurs_check_allows_concrete_stack() {
        let concrete = StackType::Empty.push(Type::Int).push(Type::String);
        let subst = unify_stacks(&row("a"), &concrete).expect("a row binds to a concrete stack");
        assert_eq!(subst.rows.get("a"), Some(&concrete));
    }

    #[test]
    fn test_occurs_in_type() {
        assert!(occurs_in_type("T", &tvar("T")));
        assert!(!occurs_in_type("T", &tvar("U")));
        for primitive in [Type::Int, Type::String, Type::Bool] {
            assert!(!occurs_in_type("T", &primitive));
        }
    }

    #[test]
    fn test_occurs_in_stack() {
        assert!(occurs_in_stack("a", &row("a")));
        assert!(!occurs_in_stack("a", &row("b")));
        assert!(!occurs_in_stack("a", &StackType::Empty));
        assert!(occurs_in_stack("a", &row("a").push(Type::Int)));
        assert!(!occurs_in_stack("a", &row("b").push(Type::Int)));
        assert!(!occurs_in_stack(
            "a",
            &StackType::Empty.push(Type::Int).push(Type::String)
        ));
    }
}