1use super::checking::MethodTable;
7use super::unification::Unifier;
8use super::*;
9use shape_ast::ast::{ObjectTypeField, TypeAnnotation};
10use std::collections::{HashMap, HashSet};
11
12fn is_array_or_vec_base(base: &Type) -> bool {
14 match base {
15 Type::Concrete(TypeAnnotation::Reference(name))
16 | Type::Concrete(TypeAnnotation::Basic(name)) => name == "Array" || name == "Vec",
17 _ => false,
18 }
19}
20
21pub struct ConstraintSolver {
22 unifier: Unifier,
24 _deferred: Vec<(Type, Type)>,
27 bounds: HashMap<TypeVar, TypeConstraint>,
29 method_table: Option<MethodTable>,
31 trait_impls: HashSet<String>,
33}
34
35impl Default for ConstraintSolver {
36 fn default() -> Self {
37 Self::new()
38 }
39}
40
41impl ConstraintSolver {
42 pub fn new() -> Self {
43 ConstraintSolver {
44 unifier: Unifier::new(),
45 _deferred: Vec::new(),
46 bounds: HashMap::new(),
47 method_table: None,
48 trait_impls: HashSet::new(),
49 }
50 }
51
52 pub fn set_method_table(&mut self, table: MethodTable) {
56 self.method_table = Some(table);
57 }
58
59 pub fn set_trait_impls(&mut self, impls: HashSet<String>) {
62 self.trait_impls = impls;
63 }
64
65 pub fn solve(&mut self, constraints: &mut Vec<(Type, Type)>) -> TypeResult<()> {
67 let mut unsolved = Vec::new();
69
70 for (t1, t2) in constraints.drain(..) {
71 if self.solve_constraint(t1.clone(), t2.clone()).is_err() {
72 unsolved.push((t1, t2));
74 }
75 }
76
77 let mut made_progress = true;
79 while made_progress && !unsolved.is_empty() {
80 made_progress = false;
81 let mut still_unsolved = Vec::new();
82
83 for (t1, t2) in unsolved.drain(..) {
84 if self.solve_constraint(t1.clone(), t2.clone()).is_err() {
85 still_unsolved.push((t1, t2));
86 } else {
87 made_progress = true;
88 }
89 }
90
91 unsolved = still_unsolved;
92 }
93
94 if !unsolved.is_empty() {
96 return Err(TypeError::UnsolvedConstraints(unsolved));
97 }
98
99 self.apply_bounds()?;
101
102 Ok(())
103 }
104
105 fn solve_constraint(&mut self, t1: Type, t2: Type) -> TypeResult<()> {
107 let t1 = self.unifier.apply_substitutions(&t1);
111 let t2 = self.unifier.apply_substitutions(&t2);
112
113 match (&t1, &t2) {
114 (Type::Variable(v1), Type::Variable(v2)) if v1 == v2 => Ok(()),
116
117 (Type::Constrained { var, constraint }, ty)
121 | (ty, Type::Constrained { var, constraint }) => {
122 self.bounds.insert(var.clone(), *constraint.clone());
124
125 self.solve_constraint(Type::Variable(var.clone()), ty.clone())
127 }
128
129 (Type::Variable(var), ty) | (ty, Type::Variable(var)) => {
130 if self.occurs_in(var, ty) {
132 return Err(TypeError::InfiniteType(var.clone()));
133 }
134
135 self.unifier.bind(var.clone(), ty.clone());
136 Ok(())
137 }
138
139 (Type::Concrete(ann1), Type::Concrete(ann2)) => {
141 if self.unify_annotations(ann1, ann2)? {
142 Ok(())
143 } else if Self::can_numeric_widen(ann1, ann2) {
144 Ok(())
146 } else {
147 Err(TypeError::TypeMismatch(
148 format!("{:?}", ann1),
149 format!("{:?}", ann2),
150 ))
151 }
152 }
153
154 (Type::Generic { base: b1, args: a1 }, Type::Generic { base: b2, args: a2 }) => {
156 self.solve_constraint(*b1.clone(), *b2.clone())?;
157
158 let is_result_base = |base: &Type| {
159 matches!(
160 base,
161 Type::Concrete(TypeAnnotation::Reference(name))
162 | Type::Concrete(TypeAnnotation::Basic(name))
163 if name == "Result"
164 )
165 };
166
167 if a1.len() != a2.len() {
168 if is_result_base(&b1) && is_result_base(&b2) {
169 match (a1.len(), a2.len()) {
170 (1, 2) | (2, 1) => {
173 self.solve_constraint(a1[0].clone(), a2[0].clone())?;
174 return Ok(());
175 }
176 _ => return Err(TypeError::ArityMismatch(a1.len(), a2.len())),
177 }
178 } else {
179 return Err(TypeError::ArityMismatch(a1.len(), a2.len()));
180 }
181 }
182
183 for (arg1, arg2) in a1.iter().zip(a2.iter()) {
184 self.solve_constraint(arg1.clone(), arg2.clone())?;
185 }
186
187 Ok(())
188 }
189
190 (
192 Type::Function {
193 params: p1,
194 returns: r1,
195 },
196 Type::Function {
197 params: p2,
198 returns: r2,
199 },
200 ) => {
201 if p1.len() != p2.len() {
202 return Err(TypeError::ArityMismatch(p1.len(), p2.len()));
203 }
204 for (param1, param2) in p1.iter().zip(p2.iter()) {
205 self.solve_constraint(param2.clone(), param1.clone())?;
209 }
210 self.solve_constraint(*r1.clone(), *r2.clone())
211 }
212
213 (
215 Type::Function {
216 params: fp,
217 returns: fr,
218 },
219 Type::Concrete(TypeAnnotation::Function {
220 params: cp,
221 returns: cr,
222 }),
223 )
224 | (
225 Type::Concrete(TypeAnnotation::Function {
226 params: cp,
227 returns: cr,
228 }),
229 Type::Function {
230 params: fp,
231 returns: fr,
232 },
233 ) => {
234 if fp.len() != cp.len() {
235 return Err(TypeError::ArityMismatch(fp.len(), cp.len()));
236 }
237 for (f_param, c_param) in fp.iter().zip(cp.iter()) {
238 self.solve_constraint(
239 f_param.clone(),
240 Type::Concrete(c_param.type_annotation.clone()),
241 )?;
242 }
243 self.solve_constraint(*fr.clone(), Type::Concrete(*cr.clone()))
244 }
245
246 (Type::Generic { base, args }, Type::Concrete(TypeAnnotation::Array(elem)))
248 | (Type::Concrete(TypeAnnotation::Array(elem)), Type::Generic { base, args })
249 if args.len() == 1 && is_array_or_vec_base(base) =>
250 {
251 self.solve_constraint(args[0].clone(), Type::Concrete((**elem).clone()))
252 }
253
254 _ => Err(TypeError::TypeMismatch(
255 format!("{:?}", t1),
256 format!("{:?}", t2),
257 )),
258 }
259 }
260
261 fn occurs_in(&self, var: &TypeVar, ty: &Type) -> bool {
263 match ty {
264 Type::Variable(v) => v == var,
265 Type::Generic { base, args } => {
266 self.occurs_in(var, base) || args.iter().any(|arg| self.occurs_in(var, arg))
267 }
268 Type::Constrained { var: v, .. } => v == var,
269 Type::Function { params, returns } => {
270 params.iter().any(|p| self.occurs_in(var, p)) || self.occurs_in(var, returns)
271 }
272 Type::Concrete(_) => false,
273 }
274 }
275
276 fn can_numeric_widen(from: &TypeAnnotation, to: &TypeAnnotation) -> bool {
283 let from_name = match from {
284 TypeAnnotation::Basic(name) | TypeAnnotation::Reference(name) => Some(name.as_str()),
285 _ => None,
286 };
287 let to_name = match to {
288 TypeAnnotation::Basic(name) | TypeAnnotation::Reference(name) => Some(name.as_str()),
289 _ => None,
290 };
291
292 match (from_name, to_name) {
293 (Some(f), Some(t)) => {
294 BuiltinTypes::is_integer_type_name(f) && BuiltinTypes::is_number_type_name(t)
295 }
296 _ => false,
297 }
298 }
299
300 fn unify_annotations(&self, ann1: &TypeAnnotation, ann2: &TypeAnnotation) -> TypeResult<bool> {
302 match (ann1, ann2) {
303 (TypeAnnotation::Basic(_), TypeAnnotation::Basic(_)) => {
305 Ok(ann1 == ann2 || Self::can_numeric_widen(ann1, ann2))
306 }
307 (TypeAnnotation::Reference(n1), TypeAnnotation::Reference(n2)) => Ok(n1 == n2),
308 (TypeAnnotation::Basic(_), TypeAnnotation::Reference(_))
309 | (TypeAnnotation::Reference(_), TypeAnnotation::Basic(_)) => {
310 Ok(ann1 == ann2 || Self::can_numeric_widen(ann1, ann2))
311 }
312
313 (TypeAnnotation::Array(e1), TypeAnnotation::Array(e2)) => {
315 self.unify_annotations(e1, e2)
316 }
317
318 (TypeAnnotation::Tuple(t1), TypeAnnotation::Tuple(t2)) => {
320 if t1.len() != t2.len() {
321 return Ok(false);
322 }
323
324 for (elem1, elem2) in t1.iter().zip(t2.iter()) {
325 if !self.unify_annotations(elem1, elem2)? {
326 return Ok(false);
327 }
328 }
329
330 Ok(true)
331 }
332
333 (TypeAnnotation::Object(f1), TypeAnnotation::Object(f2)) => {
335 self.object_fields_compatible(f1, f2)
336 }
337
338 (
340 TypeAnnotation::Function {
341 params: p1,
342 returns: r1,
343 },
344 TypeAnnotation::Function {
345 params: p2,
346 returns: r2,
347 },
348 ) => {
349 if p1.len() != p2.len() {
350 return Ok(false);
351 }
352
353 for (param1, param2) in p1.iter().zip(p2.iter()) {
354 if !self.unify_annotations(¶m1.type_annotation, ¶m2.type_annotation)? {
355 return Ok(false);
356 }
357 }
358
359 self.unify_annotations(r1, r2)
360 }
361
362 (TypeAnnotation::Union(u1), TypeAnnotation::Union(u2)) => {
365 for t1 in u1 {
367 let mut found_match = false;
368 for t2 in u2 {
369 if self.unify_annotations(t1, t2)? {
370 found_match = true;
371 break;
372 }
373 }
374 if !found_match {
375 return Ok(false);
376 }
377 }
378 for t2 in u2 {
380 let mut found_match = false;
381 for t1 in u1 {
382 if self.unify_annotations(t1, t2)? {
383 found_match = true;
384 break;
385 }
386 }
387 if !found_match {
388 return Ok(false);
389 }
390 }
391 Ok(true)
392 }
393
394 (TypeAnnotation::Union(union_types), other)
396 | (other, TypeAnnotation::Union(union_types)) => {
397 for union_type in union_types {
398 if self.unify_annotations(union_type, other)? {
399 return Ok(true);
400 }
401 }
402 Ok(false)
403 }
404
405 (TypeAnnotation::Optional(o1), TypeAnnotation::Optional(o2)) => {
407 self.unify_annotations(o1, o2)
408 }
409
410 (TypeAnnotation::Intersection(i1), TypeAnnotation::Intersection(i2)) => {
412 self.unify_annotation_sets(i1, i2)
413 }
414
415 (TypeAnnotation::Any, _) | (_, TypeAnnotation::Any) => Ok(true),
417
418 (TypeAnnotation::Void, TypeAnnotation::Void) => Ok(true),
420 (TypeAnnotation::Null, TypeAnnotation::Null) => Ok(true),
421 (TypeAnnotation::Undefined, TypeAnnotation::Undefined) => Ok(true),
422
423 (TypeAnnotation::Dyn(traits1), TypeAnnotation::Dyn(traits2)) => {
426 Ok(traits1.len() == traits2.len() && traits1.iter().all(|t| traits2.contains(t)))
427 }
428
429 (TypeAnnotation::Generic { name, args }, TypeAnnotation::Array(elem))
431 | (TypeAnnotation::Array(elem), TypeAnnotation::Generic { name, args })
432 if name == "Array" && args.len() == 1 =>
433 {
434 self.unify_annotations(&args[0], elem)
435 }
436
437 _ => Ok(false),
439 }
440 }
441
442 fn object_fields_compatible(
443 &self,
444 left: &[ObjectTypeField],
445 right: &[ObjectTypeField],
446 ) -> TypeResult<bool> {
447 for left_field in left {
448 let Some(right_field) = right.iter().find(|f| f.name == left_field.name) else {
449 return Ok(false);
450 };
451 if left_field.optional != right_field.optional {
452 return Ok(false);
453 }
454 if !self.unify_annotations(&left_field.type_annotation, &right_field.type_annotation)? {
455 return Ok(false);
456 }
457 }
458 if left.len() != right.len() {
459 return Ok(false);
460 }
461 Ok(true)
462 }
463
464 fn unify_annotation_sets(
465 &self,
466 left: &[TypeAnnotation],
467 right: &[TypeAnnotation],
468 ) -> TypeResult<bool> {
469 if left.len() != right.len() {
470 return Ok(false);
471 }
472
473 let mut matched = vec![false; right.len()];
474 for left_ann in left {
475 let mut found = false;
476 for (idx, right_ann) in right.iter().enumerate() {
477 if matched[idx] {
478 continue;
479 }
480 if self.unify_annotations(left_ann, right_ann)? {
481 matched[idx] = true;
482 found = true;
483 break;
484 }
485 }
486 if !found {
487 return Ok(false);
488 }
489 }
490
491 Ok(true)
492 }
493
494 fn apply_bounds(&mut self) -> TypeResult<()> {
500 let mut new_bindings: Vec<(TypeVar, Type)> = Vec::new();
501
502 for (var, constraint) in &self.bounds {
503 let resolved = self
506 .unifier
507 .apply_substitutions(&Type::Variable(var.clone()));
508
509 if let Type::Variable(_) = &resolved {
510 continue;
512 }
513
514 self.check_constraint(&resolved, constraint)?;
515
516 if let TypeConstraint::HasField(field, expected_field_type) = constraint {
519 if let Type::Variable(field_var) = expected_field_type.as_ref() {
520 let field_resolved = self
522 .unifier
523 .apply_substitutions(&Type::Variable(field_var.clone()));
524 if let Type::Variable(_) = &field_resolved {
525 if let Type::Concrete(TypeAnnotation::Object(fields)) = &resolved {
527 if let Some(found_field) = fields.iter().find(|f| f.name == *field) {
528 new_bindings.push((
529 field_var.clone(),
530 Type::Concrete(found_field.type_annotation.clone()),
531 ));
532 }
533 }
534 }
535 }
536 }
537 }
538
539 for (var, ty) in new_bindings {
541 self.unifier.bind(var, ty);
542 }
543
544 Ok(())
545 }
546
547 fn check_constraint(&self, ty: &Type, constraint: &TypeConstraint) -> TypeResult<()> {
549 match constraint {
550 TypeConstraint::Numeric => match ty {
551 Type::Concrete(TypeAnnotation::Basic(name))
552 if BuiltinTypes::is_numeric_type_name(name) =>
553 {
554 Ok(())
555 }
556 _ => Err(TypeError::ConstraintViolation(format!(
557 "{:?} is not numeric",
558 ty
559 ))),
560 },
561
562 TypeConstraint::Comparable => match ty {
563 Type::Concrete(TypeAnnotation::Basic(name))
564 if BuiltinTypes::is_numeric_type_name(name)
565 || name == "string"
566 || name == "bool" =>
567 {
568 Ok(())
569 }
570 _ => Err(TypeError::ConstraintViolation(format!(
571 "{:?} is not comparable",
572 ty
573 ))),
574 },
575
576 TypeConstraint::Iterable => match ty {
577 Type::Concrete(TypeAnnotation::Array(_)) => Ok(()),
578 Type::Concrete(TypeAnnotation::Basic(name))
579 if name == "string" || name == "rows" =>
580 {
581 Ok(())
582 }
583 _ => Err(TypeError::ConstraintViolation(format!(
584 "{:?} is not iterable",
585 ty
586 ))),
587 },
588
589 TypeConstraint::HasField(field, expected_field_type) => {
590 match ty {
591 Type::Concrete(TypeAnnotation::Object(fields)) => {
592 match fields.iter().find(|f| f.name == *field) {
593 Some(found_field) => {
594 if let Some(expected_ann) = expected_field_type.to_annotation() {
596 if self.unify_annotations(
597 &found_field.type_annotation,
598 &expected_ann,
599 )? {
600 Ok(())
601 } else {
602 Err(TypeError::ConstraintViolation(format!(
603 "field '{}' has type {:?}, expected {:?}",
604 field, found_field.type_annotation, expected_ann
605 )))
606 }
607 } else {
608 Ok(())
610 }
611 }
612 None => Err(TypeError::ConstraintViolation(format!(
613 "{:?} does not have field '{}'",
614 ty, field
615 ))),
616 }
617 }
618 Type::Concrete(TypeAnnotation::Basic(_name)) => {
619 Ok(())
628 }
629 _ => Err(TypeError::ConstraintViolation(format!(
630 "{:?} cannot have fields",
631 ty
632 ))),
633 }
634 }
635
636 TypeConstraint::Callable {
637 params: expected_params,
638 returns: expected_returns,
639 } => {
640 match ty {
641 Type::Concrete(TypeAnnotation::Function {
642 params: actual_params,
643 returns: actual_returns,
644 }) => {
645 if expected_params.len() != actual_params.len() {
647 return Err(TypeError::ConstraintViolation(format!(
648 "function expects {} parameters, got {}",
649 expected_params.len(),
650 actual_params.len()
651 )));
652 }
653
654 for (expected, actual) in expected_params.iter().zip(actual_params.iter()) {
656 if let Some(expected_ann) = expected.to_annotation() {
657 if !self
658 .unify_annotations(&expected_ann, &actual.type_annotation)?
659 {
660 return Err(TypeError::ConstraintViolation(format!(
661 "parameter type mismatch: expected {:?}, got {:?}",
662 expected_ann, actual.type_annotation
663 )));
664 }
665 }
666 }
667
668 if let Some(expected_ret_ann) = expected_returns.to_annotation() {
670 if !self.unify_annotations(actual_returns, &expected_ret_ann)? {
671 return Err(TypeError::ConstraintViolation(format!(
672 "return type mismatch: expected {:?}, got {:?}",
673 expected_ret_ann, actual_returns
674 )));
675 }
676 }
677
678 Ok(())
679 }
680 Type::Function {
681 params: actual_params,
682 returns: actual_returns,
683 } => {
684 if expected_params.len() != actual_params.len() {
685 return Err(TypeError::ConstraintViolation(format!(
686 "function expects {} parameters, got {}",
687 expected_params.len(),
688 actual_params.len()
689 )));
690 }
691 for (expected, actual) in expected_params.iter().zip(actual_params.iter()) {
693 if let (Some(e_ann), Some(a_ann)) =
694 (expected.to_annotation(), actual.to_annotation())
695 {
696 if !self.unify_annotations(&e_ann, &a_ann)? {
697 return Err(TypeError::ConstraintViolation(format!(
698 "parameter type mismatch: expected {:?}, got {:?}",
699 e_ann, a_ann
700 )));
701 }
702 }
703 }
704 if let (Some(e_ret), Some(a_ret)) = (
705 expected_returns.to_annotation(),
706 actual_returns.to_annotation(),
707 ) {
708 if !self.unify_annotations(&a_ret, &e_ret)? {
709 return Err(TypeError::ConstraintViolation(format!(
710 "return type mismatch: expected {:?}, got {:?}",
711 e_ret, a_ret
712 )));
713 }
714 }
715 Ok(())
716 }
717 _ => Err(TypeError::ConstraintViolation(format!(
718 "{:?} is not callable",
719 ty
720 ))),
721 }
722 }
723
724 TypeConstraint::OneOf(options) => {
725 for option in options {
726 if let Type::Concrete(ann) = option {
728 if let Type::Concrete(ty_ann) = ty {
729 if self.unify_annotations(ann, ty_ann).unwrap_or(false) {
730 return Ok(());
731 }
732 }
733 }
734 }
735
736 Err(TypeError::ConstraintViolation(format!(
737 "{:?} does not match any of {:?}",
738 ty, options
739 )))
740 }
741
742 TypeConstraint::Extends(base) => {
743 self.is_subtype(ty, base)
745 }
746
747 TypeConstraint::ImplementsTrait { trait_name } => {
748 match ty {
749 Type::Variable(_) => {
750 Err(TypeError::TraitBoundViolation {
753 type_name: format!("{:?}", ty),
754 trait_name: trait_name.clone(),
755 })
756 }
757 Type::Concrete(ann) => {
758 let type_name = match ann {
759 TypeAnnotation::Basic(n) | TypeAnnotation::Reference(n) => n.clone(),
760 _ => format!("{:?}", ann),
761 };
762 if self.has_trait_impl(trait_name, &type_name) {
763 Ok(())
764 } else {
765 Err(TypeError::TraitBoundViolation {
766 type_name,
767 trait_name: trait_name.clone(),
768 })
769 }
770 }
771 Type::Generic { base, .. } => {
772 let type_name = if let Type::Concrete(
773 TypeAnnotation::Reference(n) | TypeAnnotation::Basic(n),
774 ) = base.as_ref()
775 {
776 n.clone()
777 } else {
778 format!("{:?}", base)
779 };
780 if self.has_trait_impl(trait_name, &type_name) {
781 Ok(())
782 } else {
783 Err(TypeError::TraitBoundViolation {
784 type_name,
785 trait_name: trait_name.clone(),
786 })
787 }
788 }
789 _ => Err(TypeError::TraitBoundViolation {
790 type_name: format!("{:?}", ty),
791 trait_name: trait_name.clone(),
792 }),
793 }
794 }
795
796 TypeConstraint::HasMethod {
797 method_name,
798 arg_types: _,
799 return_type: _,
800 } => {
801 if let Some(method_table) = &self.method_table {
803 match ty {
804 Type::Variable(_) => Ok(()), Type::Concrete(ann) => {
806 let type_name = match ann {
807 TypeAnnotation::Basic(n) | TypeAnnotation::Reference(n) => {
808 n.clone()
809 }
810 TypeAnnotation::Array(_) => "Vec".to_string(),
811 _ => return Ok(()), };
813 if method_table.lookup(ty, method_name).is_some() {
814 Ok(())
815 } else {
816 Err(TypeError::MethodNotFound {
817 type_name,
818 method_name: method_name.clone(),
819 })
820 }
821 }
822 Type::Generic { base, .. } => {
823 if method_table.lookup(ty, method_name).is_some() {
824 Ok(())
825 } else {
826 let type_name =
827 if let Type::Concrete(TypeAnnotation::Reference(n)) =
828 base.as_ref()
829 {
830 n.clone()
831 } else {
832 format!("{:?}", base)
833 };
834 Err(TypeError::MethodNotFound {
835 type_name,
836 method_name: method_name.clone(),
837 })
838 }
839 }
840 _ => Ok(()), }
842 } else {
843 Ok(())
845 }
846 }
847 }
848 }
849
850 fn has_trait_impl(&self, trait_name: &str, type_name: &str) -> bool {
855 let key = format!("{}::{}", trait_name, type_name);
856 if self.trait_impls.contains(&key) {
857 return true;
858 }
859 if BuiltinTypes::is_integer_type_name(type_name) {
861 for widen_to in &["number", "float", "f64"] {
862 let widen_key = format!("{}::{}", trait_name, widen_to);
863 if self.trait_impls.contains(&widen_key) {
864 return true;
865 }
866 }
867 }
868 false
869 }
870
871 fn is_subtype(&self, ty: &Type, base: &Type) -> TypeResult<()> {
878 match (ty, base) {
879 (t1, t2) if t1 == t2 => Ok(()),
881
882 (_, Type::Concrete(TypeAnnotation::Any)) => Ok(()),
884
885 (Type::Variable(_), _) | (_, Type::Variable(_)) => Ok(()),
887
888 (
890 Type::Concrete(TypeAnnotation::Array(elem1)),
891 Type::Concrete(TypeAnnotation::Array(elem2)),
892 ) => {
893 let t1 = Type::Concrete(*elem1.clone());
894 let t2 = Type::Concrete(*elem2.clone());
895 self.is_subtype(&t1, &t2)
896 }
897
898 (
900 Type::Concrete(TypeAnnotation::Function {
901 params: p1,
902 returns: r1,
903 }),
904 Type::Concrete(TypeAnnotation::Function {
905 params: p2,
906 returns: r2,
907 }),
908 ) => {
909 if p1.len() != p2.len() {
911 return Err(TypeError::ConstraintViolation(format!(
912 "function parameter count mismatch: {} vs {}",
913 p1.len(),
914 p2.len()
915 )));
916 }
917
918 for (param1, param2) in p1.iter().zip(p2.iter()) {
920 let t1 = Type::Concrete(param2.type_annotation.clone());
921 let t2 = Type::Concrete(param1.type_annotation.clone());
922 self.is_subtype(&t1, &t2)?;
923 }
924
925 let ret1 = Type::Concrete(*r1.clone());
927 let ret2 = Type::Concrete(*r2.clone());
928 self.is_subtype(&ret1, &ret2)
929 }
930
931 (t, Type::Concrete(TypeAnnotation::Optional(opt_inner))) => {
933 let inner = Type::Concrete(*opt_inner.clone());
934 self.is_subtype(t, &inner)
935 }
936
937 (
939 Type::Function {
940 params: p1,
941 returns: r1,
942 },
943 Type::Function {
944 params: p2,
945 returns: r2,
946 },
947 ) => {
948 if p1.len() != p2.len() {
949 return Err(TypeError::ConstraintViolation(format!(
950 "function parameter count mismatch: {} vs {}",
951 p1.len(),
952 p2.len()
953 )));
954 }
955 for (param1, param2) in p1.iter().zip(p2.iter()) {
957 self.is_subtype(param2, param1)?;
958 }
959 self.is_subtype(r1, r2)
961 }
962
963 (Type::Concrete(ann1), Type::Concrete(ann2)) => {
965 if self.unify_annotations(ann1, ann2)? {
966 Ok(())
967 } else {
968 Err(TypeError::ConstraintViolation(format!(
969 "{:?} is not a subtype of {:?}",
970 ty, base
971 )))
972 }
973 }
974
975 _ => Err(TypeError::ConstraintViolation(format!(
977 "{:?} is not a subtype of {:?}",
978 ty, base
979 ))),
980 }
981 }
982
983 pub fn unifier(&self) -> &Unifier {
985 &self.unifier
986 }
987}
988
#[cfg(test)]
mod tests {
    use super::*;
    use shape_ast::ast::ObjectTypeField;

    // ----- fixtures ---------------------------------------------------------

    /// `TypeAnnotation::Basic(name)`.
    fn basic(name: &str) -> TypeAnnotation {
        TypeAnnotation::Basic(name.to_string())
    }

    /// The concrete type for a basic type name, e.g. `concrete("int")`.
    fn concrete(name: &str) -> Type {
        Type::Concrete(basic(name))
    }

    /// A required, unannotated object field whose type is the basic type `ty`.
    fn field(name: &str, ty: &str) -> ObjectTypeField {
        ObjectTypeField {
            name: name.to_string(),
            optional: false,
            type_annotation: basic(ty),
            annotations: vec![],
        }
    }

    /// A concrete object type with the given fields.
    fn object(fields: Vec<ObjectTypeField>) -> Type {
        Type::Concrete(TypeAnnotation::Object(fields))
    }

    /// `constraint` attached to a fresh bound variable.
    fn bounded(constraint: TypeConstraint) -> Type {
        Type::Constrained {
            var: TypeVar::fresh(),
            constraint: Box::new(constraint),
        }
    }

    /// A `HasField(name, field_var)` bound on a fresh variable.
    fn has_field(name: &str, field_var: &TypeVar) -> Type {
        bounded(TypeConstraint::HasField(
            name.to_string(),
            Box::new(Type::Variable(field_var.clone())),
        ))
    }

    /// An `ImplementsTrait { trait_name }` bound on a fresh variable.
    fn implements(trait_name: &str) -> Type {
        bounded(TypeConstraint::ImplementsTrait {
            trait_name: trait_name.to_string(),
        })
    }

    /// A solver that knows exactly the given `"Trait::Type"` implementations.
    fn solver_with_impls(keys: &[&str]) -> ConstraintSolver {
        let mut solver = ConstraintSolver::new();
        solver.set_trait_impls(keys.iter().map(|&key| key.to_string()).collect());
        solver
    }

    /// Solves `constraints` with a fresh solver; `true` on success.
    fn solves(mut constraints: Vec<(Type, Type)>) -> bool {
        ConstraintSolver::new().solve(&mut constraints).is_ok()
    }

    /// What `var` resolves to under `solver`'s current substitution.
    fn resolve(solver: &ConstraintSolver, var: TypeVar) -> Type {
        solver.unifier().apply_substitutions(&Type::Variable(var))
    }

    // ----- HasField backward propagation ------------------------------------

    #[test]
    fn test_hasfield_backward_propagation_binds_field_type() {
        let obj_var = TypeVar::fresh();
        let field_var = TypeVar::fresh();

        // `obj.x` produced `field_var`; `obj` later turns out to be `{ x: int }`.
        let mut constraints = vec![
            (Type::Variable(obj_var.clone()), has_field("x", &field_var)),
            (Type::Variable(obj_var), object(vec![field("x", "int")])),
        ];
        let mut solver = ConstraintSolver::new();
        solver.solve(&mut constraints).unwrap();

        let resolved = resolve(&solver, field_var);
        if let Type::Concrete(TypeAnnotation::Basic(name)) = &resolved {
            assert_eq!(name, "int", "field type should be int");
        } else {
            panic!(
                "Expected field_result_var to be resolved to int, got {:?}",
                resolved
            );
        }
    }

    #[test]
    fn test_hasfield_backward_propagation_multiple_fields() {
        let obj_var = TypeVar::fresh();
        let x_var = TypeVar::fresh();
        let y_var = TypeVar::fresh();

        let mut constraints = vec![
            (Type::Variable(obj_var.clone()), has_field("x", &x_var)),
            (Type::Variable(obj_var.clone()), has_field("y", &y_var)),
            (
                Type::Variable(obj_var),
                object(vec![field("x", "int"), field("y", "string")]),
            ),
        ];
        let mut solver = ConstraintSolver::new();
        solver.solve(&mut constraints).unwrap();

        let resolved_x = resolve(&solver, x_var);
        let resolved_y = resolve(&solver, y_var);

        if let Type::Concrete(TypeAnnotation::Basic(name)) = &resolved_x {
            assert_eq!(name, "int");
        } else {
            panic!("Expected x to be int, got {:?}", resolved_x);
        }
        if let Type::Concrete(TypeAnnotation::Basic(name)) = &resolved_y {
            assert_eq!(name, "string");
        } else {
            panic!("Expected y to be string, got {:?}", resolved_y);
        }
    }

    // ----- numeric bounds and widening ---------------------------------------

    #[test]
    fn test_int_constrained_numeric_succeeds() {
        assert!(solves(vec![(concrete("int"), bounded(TypeConstraint::Numeric))]));
    }

    #[test]
    fn test_numeric_widening_int_to_number() {
        assert!(solves(vec![(concrete("int"), concrete("number"))]));
    }

    #[test]
    fn test_numeric_widening_width_aware_integer_to_float_family() {
        assert!(solves(vec![(concrete("i16"), concrete("f32"))]));
    }

    #[test]
    fn test_no_widening_number_to_int() {
        assert!(!solves(vec![(concrete("number"), concrete("int"))]));
    }

    #[test]
    fn test_decimal_constrained_numeric_succeeds() {
        assert!(solves(vec![(
            concrete("decimal"),
            bounded(TypeConstraint::Numeric)
        )]));
    }

    #[test]
    fn test_comparable_accepts_int() {
        assert!(solves(vec![(
            concrete("int"),
            bounded(TypeConstraint::Comparable)
        )]));
    }

    // ----- function types -----------------------------------------------------

    #[test]
    fn test_function_type_preserves_variables() {
        let param = Type::Variable(TypeVar::fresh());
        let ret = Type::Variable(TypeVar::fresh());
        let func = BuiltinTypes::function(vec![param.clone()], ret.clone());

        if let Type::Function { params, returns } = &func {
            assert_eq!(params.len(), 1);
            assert_eq!(params[0], param);
            assert_eq!(**returns, ret);
        } else {
            panic!("Expected Type::Function, got {:?}", func);
        }
    }

    #[test]
    fn test_function_unification_binds_variables() {
        let t1 = TypeVar::fresh();
        let t2 = TypeVar::fresh();

        let generic_fn = Type::Function {
            params: vec![Type::Variable(t1.clone())],
            returns: Box::new(Type::Variable(t2.clone())),
        };
        let number_to_string = Type::Function {
            params: vec![BuiltinTypes::number()],
            returns: Box::new(BuiltinTypes::string()),
        };
        let mut solver = ConstraintSolver::new();
        solver
            .solve(&mut vec![(generic_fn, number_to_string)])
            .unwrap();

        assert_eq!(resolve(&solver, t1), BuiltinTypes::number());
        assert_eq!(resolve(&solver, t2), BuiltinTypes::string());
    }

    #[test]
    fn test_function_cross_unification_with_concrete() {
        let t1 = TypeVar::fresh();

        let annotated_fn = Type::Concrete(TypeAnnotation::Function {
            params: vec![shape_ast::ast::FunctionParam {
                name: None,
                optional: false,
                type_annotation: basic("number"),
            }],
            returns: Box::new(basic("string")),
        });
        let inferred_fn = Type::Function {
            params: vec![Type::Variable(t1.clone())],
            returns: Box::new(BuiltinTypes::string()),
        };
        let mut solver = ConstraintSolver::new();
        solver.solve(&mut vec![(inferred_fn, annotated_fn)]).unwrap();

        assert_eq!(resolve(&solver, t1), BuiltinTypes::number());
    }

    // ----- structural annotations ---------------------------------------------

    #[test]
    fn test_object_annotations_unify_structurally() {
        let point = || object(vec![field("x", "int"), field("y", "int")]);
        assert!(solves(vec![(point(), point())]));
    }

    #[test]
    fn test_intersection_annotations_unify_order_independent() {
        let obj_xy = TypeAnnotation::Object(vec![field("x", "int"), field("y", "int")]);
        let obj_z = TypeAnnotation::Object(vec![field("z", "int")]);

        let forward = TypeAnnotation::Intersection(vec![obj_xy.clone(), obj_z.clone()]);
        let reversed = TypeAnnotation::Intersection(vec![obj_z, obj_xy]);
        assert!(solves(vec![(
            Type::Concrete(forward),
            Type::Concrete(reversed)
        )]));
    }

    // ----- trait bounds ---------------------------------------------------------

    #[test]
    fn test_implements_trait_satisfied() {
        let mut solver = solver_with_impls(&["Comparable::number"]);
        let mut constraints = vec![(concrete("number"), implements("Comparable"))];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    #[test]
    fn test_implements_trait_violated() {
        let mut solver = ConstraintSolver::new();
        let mut constraints = vec![(concrete("string"), implements("Comparable"))];

        let result = solver.solve(&mut constraints);
        assert!(result.is_err());
        match result.unwrap_err() {
            TypeError::TraitBoundViolation {
                type_name,
                trait_name,
            } => {
                assert_eq!(type_name, "string");
                assert_eq!(trait_name, "Comparable");
            }
            other => panic!("Expected TraitBoundViolation, got: {:?}", other),
        }
    }

    #[test]
    fn test_implements_trait_via_variable_resolution() {
        let mut solver = solver_with_impls(&["Sortable::number"]);
        let type_var = TypeVar::fresh();

        // `T: Sortable` first, and only later `T = number`.
        let mut constraints = vec![
            (Type::Variable(type_var.clone()), implements("Sortable")),
            (Type::Variable(type_var), concrete("number")),
        ];
        assert!(
            solver.solve(&mut constraints).is_ok(),
            "T resolved to number which implements Sortable"
        );
    }
}