1use super::checking::MethodTable;
7use super::unification::Unifier;
8use super::*;
9use shape_ast::ast::{ObjectTypeField, TypeAnnotation};
10use std::collections::{HashMap, HashSet};
11
12fn is_array_or_vec_base(base: &Type) -> bool {
14 match base {
15 Type::Concrete(TypeAnnotation::Reference(name))
16 | Type::Concrete(TypeAnnotation::Basic(name)) => name == "Array" || name == "Vec",
17 _ => false,
18 }
19}
20
/// Solves type-equality constraints by unification and then checks the bounds
/// (`TypeConstraint`s) attached to type variables.
pub struct ConstraintSolver {
    /// Substitution store; every binding discovered while solving lands here.
    unifier: Unifier,
    /// Not read by any code shown here; only initialized in `new`.
    _deferred: Vec<(Type, Type)>,
    /// Bounds recorded when `solve_constraint` meets a `Constrained` type, checked
    /// by `apply_bounds` after solving. One entry per variable: a later bound on
    /// the same variable replaces the earlier one.
    bounds: HashMap<TypeVar, TypeConstraint>,
    /// Method table used for `HasMethod` bounds; when `None`, those bounds are
    /// accepted without a lookup.
    method_table: Option<MethodTable>,
    /// Known trait implementations, keyed as `"Trait::Type"`.
    trait_impls: HashSet<String>,
}
34
35impl Default for ConstraintSolver {
36 fn default() -> Self {
37 Self::new()
38 }
39}
40
41impl ConstraintSolver {
42 pub fn new() -> Self {
43 ConstraintSolver {
44 unifier: Unifier::new(),
45 _deferred: Vec::new(),
46 bounds: HashMap::new(),
47 method_table: None,
48 trait_impls: HashSet::new(),
49 }
50 }
51
52 pub fn set_method_table(&mut self, table: MethodTable) {
56 self.method_table = Some(table);
57 }
58
59 pub fn set_trait_impls(&mut self, impls: HashSet<String>) {
62 self.trait_impls = impls;
63 }
64
65 pub fn solve(&mut self, constraints: &mut Vec<(Type, Type)>) -> TypeResult<()> {
67 let mut unsolved = Vec::new();
69
70 for (t1, t2) in constraints.drain(..) {
71 if self.solve_constraint(t1.clone(), t2.clone()).is_err() {
72 unsolved.push((t1, t2));
74 }
75 }
76
77 let mut made_progress = true;
79 while made_progress && !unsolved.is_empty() {
80 made_progress = false;
81 let mut still_unsolved = Vec::new();
82
83 for (t1, t2) in unsolved.drain(..) {
84 if self.solve_constraint(t1.clone(), t2.clone()).is_err() {
85 still_unsolved.push((t1, t2));
86 } else {
87 made_progress = true;
88 }
89 }
90
91 unsolved = still_unsolved;
92 }
93
94 if !unsolved.is_empty() {
96 return Err(TypeError::UnsolvedConstraints(unsolved));
97 }
98
99 self.apply_bounds()?;
101
102 Ok(())
103 }
104
    /// Unifies `t1` with `t2`, recording variable bindings in the unifier and
    /// bounds in `self.bounds`.
    ///
    /// Both sides are first resolved through the current substitutions. Arm
    /// order matters: a `Constrained` side is matched before the plain-variable
    /// arm, and the `Concrete`/`Concrete` arm precedes the generic-vs-array arm.
    ///
    /// # Errors
    /// - `InfiniteType` when the occurs check fails.
    /// - `ArityMismatch` when generic argument or function parameter counts differ.
    /// - `TypeMismatch` for any other incompatible pair.
    ///
    /// This function does not roll back bindings made before an error is returned.
    fn solve_constraint(&mut self, t1: Type, t2: Type) -> TypeResult<()> {
        let t1 = self.unifier.apply_substitutions(&t1);
        let t2 = self.unifier.apply_substitutions(&t2);

        match (&t1, &t2) {
            // Same variable on both sides: trivially solved.
            (Type::Variable(v1), Type::Variable(v2)) if v1 == v2 => Ok(()),

            // Record the bound on `var`, then unify `var` itself with the other side.
            (Type::Constrained { var, constraint }, ty)
            | (ty, Type::Constrained { var, constraint }) => {
                // `insert` replaces any bound previously recorded for `var`.
                self.bounds.insert(var.clone(), *constraint.clone());

                self.solve_constraint(Type::Variable(var.clone()), ty.clone())
            }

            // Bind a variable to the other side after the occurs check.
            (Type::Variable(var), ty) | (ty, Type::Variable(var)) => {
                if self.occurs_in(var, ty) {
                    return Err(TypeError::InfiniteType(var.clone()));
                }

                self.unifier.bind(var.clone(), ty.clone());
                Ok(())
            }

            // Concrete annotations: structural unification, then numeric
            // widening of `ann1` into `ann2` as a fallback.
            (Type::Concrete(ann1), Type::Concrete(ann2)) => {
                if self.unify_annotations(ann1, ann2)? {
                    Ok(())
                } else if Self::can_numeric_widen(ann1, ann2) {
                    Ok(())
                } else {
                    Err(TypeError::TypeMismatch(
                        format!("{:?}", ann1),
                        format!("{:?}", ann2),
                    ))
                }
            }

            (Type::Generic { base: b1, args: a1 }, Type::Generic { base: b2, args: a2 }) => {
                // The bases are unified first, so their bindings persist even if
                // the argument counts turn out not to match below.
                self.solve_constraint(*b1.clone(), *b2.clone())?;

                let is_result_base = |base: &Type| {
                    matches!(
                        base,
                        Type::Concrete(TypeAnnotation::Reference(name))
                            | Type::Concrete(TypeAnnotation::Basic(name))
                            if name == "Result"
                    )
                };

                if a1.len() != a2.len() {
                    // A one-argument `Result` is allowed to meet a two-argument
                    // `Result`; only the first arguments are unified.
                    if is_result_base(&b1) && is_result_base(&b2) {
                        match (a1.len(), a2.len()) {
                            (1, 2) | (2, 1) => {
                                self.solve_constraint(a1[0].clone(), a2[0].clone())?;
                                return Ok(());
                            }
                            _ => return Err(TypeError::ArityMismatch(a1.len(), a2.len())),
                        }
                    } else {
                        return Err(TypeError::ArityMismatch(a1.len(), a2.len()));
                    }
                }

                for (arg1, arg2) in a1.iter().zip(a2.iter()) {
                    self.solve_constraint(arg1.clone(), arg2.clone())?;
                }

                Ok(())
            }

            (
                Type::Function {
                    params: p1,
                    returns: r1,
                },
                Type::Function {
                    params: p2,
                    returns: r2,
                },
            ) => {
                if p1.len() != p2.len() {
                    return Err(TypeError::ArityMismatch(p1.len(), p2.len()));
                }
                // Parameters are solved as (param2, param1), the reverse of the
                // return-type order. Concrete numeric widening is directional, so
                // parameters widen in the opposite direction from the return type.
                for (param1, param2) in p1.iter().zip(p2.iter()) {
                    self.solve_constraint(param2.clone(), param1.clone())?;
                }
                self.solve_constraint(*r1.clone(), *r2.clone())
            }

            // An inferred function type against a concrete function annotation.
            // Both orders are handled by the same body, which always passes the
            // inferred side first.
            (
                Type::Function {
                    params: fp,
                    returns: fr,
                },
                Type::Concrete(TypeAnnotation::Function {
                    params: cp,
                    returns: cr,
                }),
            )
            | (
                Type::Concrete(TypeAnnotation::Function {
                    params: cp,
                    returns: cr,
                }),
                Type::Function {
                    params: fp,
                    returns: fr,
                },
            ) => {
                if fp.len() != cp.len() {
                    return Err(TypeError::ArityMismatch(fp.len(), cp.len()));
                }
                for (f_param, c_param) in fp.iter().zip(cp.iter()) {
                    self.solve_constraint(
                        f_param.clone(),
                        Type::Concrete(c_param.type_annotation.clone()),
                    )?;
                }
                self.solve_constraint(*fr.clone(), Type::Concrete(*cr.clone()))
            }

            // `Array<T>` / `Vec<T>` with one argument is the same as the array annotation `[T]`.
            (Type::Generic { base, args }, Type::Concrete(TypeAnnotation::Array(elem)))
            | (Type::Concrete(TypeAnnotation::Array(elem)), Type::Generic { base, args })
                if args.len() == 1 && is_array_or_vec_base(base) =>
            {
                self.solve_constraint(args[0].clone(), Type::Concrete((**elem).clone()))
            }

            _ => Err(TypeError::TypeMismatch(
                format!("{:?}", t1),
                format!("{:?}", t2),
            )),
        }
    }
260
261 fn occurs_in(&self, var: &TypeVar, ty: &Type) -> bool {
263 match ty {
264 Type::Variable(v) => v == var,
265 Type::Generic { base, args } => {
266 self.occurs_in(var, base) || args.iter().any(|arg| self.occurs_in(var, arg))
267 }
268 Type::Constrained { var: v, .. } => v == var,
269 Type::Function { params, returns } => {
270 params.iter().any(|p| self.occurs_in(var, p)) || self.occurs_in(var, returns)
271 }
272 Type::Concrete(_) => false,
273 }
274 }
275
276 fn can_numeric_widen(from: &TypeAnnotation, to: &TypeAnnotation) -> bool {
283 let from_name = match from {
284 TypeAnnotation::Basic(name) | TypeAnnotation::Reference(name) => Some(name.as_str()),
285 _ => None,
286 };
287 let to_name = match to {
288 TypeAnnotation::Basic(name) | TypeAnnotation::Reference(name) => Some(name.as_str()),
289 _ => None,
290 };
291
292 match (from_name, to_name) {
293 (Some(f), Some(t)) => {
294 BuiltinTypes::is_integer_type_name(f) && BuiltinTypes::is_number_type_name(t)
295 }
296 _ => false,
297 }
298 }
299
300 fn unify_annotations(&self, ann1: &TypeAnnotation, ann2: &TypeAnnotation) -> TypeResult<bool> {
302 match (ann1, ann2) {
303 (TypeAnnotation::Basic(_), TypeAnnotation::Basic(_)) => {
305 Ok(ann1 == ann2 || Self::can_numeric_widen(ann1, ann2))
306 }
307 (TypeAnnotation::Reference(n1), TypeAnnotation::Reference(n2)) => Ok(n1 == n2),
308 (TypeAnnotation::Basic(_), TypeAnnotation::Reference(_))
309 | (TypeAnnotation::Reference(_), TypeAnnotation::Basic(_)) => {
310 Ok(ann1 == ann2 || Self::can_numeric_widen(ann1, ann2))
311 }
312
313 (TypeAnnotation::Array(e1), TypeAnnotation::Array(e2)) => {
315 self.unify_annotations(e1, e2)
316 }
317
318 (TypeAnnotation::Tuple(t1), TypeAnnotation::Tuple(t2)) => {
320 if t1.len() != t2.len() {
321 return Ok(false);
322 }
323
324 for (elem1, elem2) in t1.iter().zip(t2.iter()) {
325 if !self.unify_annotations(elem1, elem2)? {
326 return Ok(false);
327 }
328 }
329
330 Ok(true)
331 }
332
333 (TypeAnnotation::Object(f1), TypeAnnotation::Object(f2)) => {
335 self.object_fields_compatible(f1, f2)
336 }
337
338 (
340 TypeAnnotation::Function {
341 params: p1,
342 returns: r1,
343 },
344 TypeAnnotation::Function {
345 params: p2,
346 returns: r2,
347 },
348 ) => {
349 if p1.len() != p2.len() {
350 return Ok(false);
351 }
352
353 for (param1, param2) in p1.iter().zip(p2.iter()) {
354 if !self.unify_annotations(¶m1.type_annotation, ¶m2.type_annotation)? {
355 return Ok(false);
356 }
357 }
358
359 self.unify_annotations(r1, r2)
360 }
361
362 (TypeAnnotation::Union(u1), TypeAnnotation::Union(u2)) => {
365 for t1 in u1 {
367 let mut found_match = false;
368 for t2 in u2 {
369 if self.unify_annotations(t1, t2)? {
370 found_match = true;
371 break;
372 }
373 }
374 if !found_match {
375 return Ok(false);
376 }
377 }
378 for t2 in u2 {
380 let mut found_match = false;
381 for t1 in u1 {
382 if self.unify_annotations(t1, t2)? {
383 found_match = true;
384 break;
385 }
386 }
387 if !found_match {
388 return Ok(false);
389 }
390 }
391 Ok(true)
392 }
393
394 (TypeAnnotation::Union(union_types), other)
396 | (other, TypeAnnotation::Union(union_types)) => {
397 for union_type in union_types {
398 if self.unify_annotations(union_type, other)? {
399 return Ok(true);
400 }
401 }
402 Ok(false)
403 }
404
405 (TypeAnnotation::Intersection(i1), TypeAnnotation::Intersection(i2)) => {
407 self.unify_annotation_sets(i1, i2)
408 }
409
410 (TypeAnnotation::Void, TypeAnnotation::Void) => Ok(true),
412 (TypeAnnotation::Null, TypeAnnotation::Null) => Ok(true),
413 (TypeAnnotation::Undefined, TypeAnnotation::Undefined) => Ok(true),
414
415 (TypeAnnotation::Dyn(traits1), TypeAnnotation::Dyn(traits2)) => {
418 Ok(traits1.len() == traits2.len() && traits1.iter().all(|t| traits2.contains(t)))
419 }
420
421 (TypeAnnotation::Generic { name, args }, TypeAnnotation::Array(elem))
423 | (TypeAnnotation::Array(elem), TypeAnnotation::Generic { name, args })
424 if name == "Array" && args.len() == 1 =>
425 {
426 self.unify_annotations(&args[0], elem)
427 }
428
429 _ => Ok(false),
431 }
432 }
433
434 fn object_fields_compatible(
435 &self,
436 left: &[ObjectTypeField],
437 right: &[ObjectTypeField],
438 ) -> TypeResult<bool> {
439 for left_field in left {
440 let Some(right_field) = right.iter().find(|f| f.name == left_field.name) else {
441 return Ok(false);
442 };
443 if left_field.optional != right_field.optional {
444 return Ok(false);
445 }
446 if !self.unify_annotations(&left_field.type_annotation, &right_field.type_annotation)? {
447 return Ok(false);
448 }
449 }
450 if left.len() != right.len() {
451 return Ok(false);
452 }
453 Ok(true)
454 }
455
456 fn unify_annotation_sets(
457 &self,
458 left: &[TypeAnnotation],
459 right: &[TypeAnnotation],
460 ) -> TypeResult<bool> {
461 if left.len() != right.len() {
462 return Ok(false);
463 }
464
465 let mut matched = vec![false; right.len()];
466 for left_ann in left {
467 let mut found = false;
468 for (idx, right_ann) in right.iter().enumerate() {
469 if matched[idx] {
470 continue;
471 }
472 if self.unify_annotations(left_ann, right_ann)? {
473 matched[idx] = true;
474 found = true;
475 break;
476 }
477 }
478 if !found {
479 return Ok(false);
480 }
481 }
482
483 Ok(true)
484 }
485
486 fn apply_bounds(&mut self) -> TypeResult<()> {
492 let mut new_bindings: Vec<(TypeVar, Type)> = Vec::new();
493
494 for (var, constraint) in &self.bounds {
495 let resolved = self
498 .unifier
499 .apply_substitutions(&Type::Variable(var.clone()));
500
501 if let Type::Variable(_) = &resolved {
502 continue;
504 }
505
506 self.check_constraint(&resolved, constraint)?;
507
508 if let TypeConstraint::HasField(field, expected_field_type) = constraint {
511 if let Type::Variable(field_var) = expected_field_type.as_ref() {
512 let field_resolved = self
514 .unifier
515 .apply_substitutions(&Type::Variable(field_var.clone()));
516 if let Type::Variable(_) = &field_resolved {
517 if let Type::Concrete(TypeAnnotation::Object(fields)) = &resolved {
519 if let Some(found_field) = fields.iter().find(|f| f.name == *field) {
520 new_bindings.push((
521 field_var.clone(),
522 Type::Concrete(found_field.type_annotation.clone()),
523 ));
524 }
525 }
526 }
527 }
528 }
529 }
530
531 for (var, ty) in new_bindings {
533 self.unifier.bind(var, ty);
534 }
535
536 Ok(())
537 }
538
539 fn check_constraint(&self, ty: &Type, constraint: &TypeConstraint) -> TypeResult<()> {
541 match constraint {
542 TypeConstraint::Numeric => match ty {
543 Type::Concrete(TypeAnnotation::Basic(name))
544 if BuiltinTypes::is_numeric_type_name(name) =>
545 {
546 Ok(())
547 }
548 _ => Err(TypeError::ConstraintViolation(format!(
549 "{:?} is not numeric",
550 ty
551 ))),
552 },
553
554 TypeConstraint::Comparable => match ty {
555 Type::Concrete(TypeAnnotation::Basic(name))
556 if BuiltinTypes::is_numeric_type_name(name)
557 || name == "string"
558 || name == "bool" =>
559 {
560 Ok(())
561 }
562 _ => Err(TypeError::ConstraintViolation(format!(
563 "{:?} is not comparable",
564 ty
565 ))),
566 },
567
568 TypeConstraint::Iterable => match ty {
569 Type::Concrete(TypeAnnotation::Array(_)) => Ok(()),
570 Type::Concrete(TypeAnnotation::Basic(name))
571 if name == "string" || name == "rows" =>
572 {
573 Ok(())
574 }
575 _ => Err(TypeError::ConstraintViolation(format!(
576 "{:?} is not iterable",
577 ty
578 ))),
579 },
580
581 TypeConstraint::HasField(field, expected_field_type) => {
582 match ty {
583 Type::Concrete(TypeAnnotation::Object(fields)) => {
584 match fields.iter().find(|f| f.name == *field) {
585 Some(found_field) => {
586 if let Some(expected_ann) = expected_field_type.to_annotation() {
588 if self.unify_annotations(
589 &found_field.type_annotation,
590 &expected_ann,
591 )? {
592 Ok(())
593 } else {
594 Err(TypeError::ConstraintViolation(format!(
595 "field '{}' has type {:?}, expected {:?}",
596 field, found_field.type_annotation, expected_ann
597 )))
598 }
599 } else {
600 Ok(())
602 }
603 }
604 None => Err(TypeError::ConstraintViolation(format!(
605 "{:?} does not have field '{}'",
606 ty, field
607 ))),
608 }
609 }
610 Type::Concrete(TypeAnnotation::Basic(_name)) => {
611 Ok(())
620 }
621 _ => Err(TypeError::ConstraintViolation(format!(
622 "{:?} cannot have fields",
623 ty
624 ))),
625 }
626 }
627
628 TypeConstraint::Callable {
629 params: expected_params,
630 returns: expected_returns,
631 } => {
632 match ty {
633 Type::Concrete(TypeAnnotation::Function {
634 params: actual_params,
635 returns: actual_returns,
636 }) => {
637 if expected_params.len() != actual_params.len() {
639 return Err(TypeError::ConstraintViolation(format!(
640 "function expects {} parameters, got {}",
641 expected_params.len(),
642 actual_params.len()
643 )));
644 }
645
646 for (expected, actual) in expected_params.iter().zip(actual_params.iter()) {
648 if let Some(expected_ann) = expected.to_annotation() {
649 if !self
650 .unify_annotations(&expected_ann, &actual.type_annotation)?
651 {
652 return Err(TypeError::ConstraintViolation(format!(
653 "parameter type mismatch: expected {:?}, got {:?}",
654 expected_ann, actual.type_annotation
655 )));
656 }
657 }
658 }
659
660 if let Some(expected_ret_ann) = expected_returns.to_annotation() {
662 if !self.unify_annotations(actual_returns, &expected_ret_ann)? {
663 return Err(TypeError::ConstraintViolation(format!(
664 "return type mismatch: expected {:?}, got {:?}",
665 expected_ret_ann, actual_returns
666 )));
667 }
668 }
669
670 Ok(())
671 }
672 Type::Function {
673 params: actual_params,
674 returns: actual_returns,
675 } => {
676 if expected_params.len() != actual_params.len() {
677 return Err(TypeError::ConstraintViolation(format!(
678 "function expects {} parameters, got {}",
679 expected_params.len(),
680 actual_params.len()
681 )));
682 }
683 for (expected, actual) in expected_params.iter().zip(actual_params.iter()) {
685 if let (Some(e_ann), Some(a_ann)) =
686 (expected.to_annotation(), actual.to_annotation())
687 {
688 if !self.unify_annotations(&e_ann, &a_ann)? {
689 return Err(TypeError::ConstraintViolation(format!(
690 "parameter type mismatch: expected {:?}, got {:?}",
691 e_ann, a_ann
692 )));
693 }
694 }
695 }
696 if let (Some(e_ret), Some(a_ret)) = (
697 expected_returns.to_annotation(),
698 actual_returns.to_annotation(),
699 ) {
700 if !self.unify_annotations(&a_ret, &e_ret)? {
701 return Err(TypeError::ConstraintViolation(format!(
702 "return type mismatch: expected {:?}, got {:?}",
703 e_ret, a_ret
704 )));
705 }
706 }
707 Ok(())
708 }
709 _ => Err(TypeError::ConstraintViolation(format!(
710 "{:?} is not callable",
711 ty
712 ))),
713 }
714 }
715
716 TypeConstraint::OneOf(options) => {
717 for option in options {
718 if let Type::Concrete(ann) = option {
720 if let Type::Concrete(ty_ann) = ty {
721 if self.unify_annotations(ann, ty_ann).unwrap_or(false) {
722 return Ok(());
723 }
724 }
725 }
726 }
727
728 Err(TypeError::ConstraintViolation(format!(
729 "{:?} does not match any of {:?}",
730 ty, options
731 )))
732 }
733
734 TypeConstraint::Extends(base) => {
735 self.is_subtype(ty, base)
737 }
738
739 TypeConstraint::ImplementsTrait { trait_name } => {
740 match ty {
741 Type::Variable(_) => {
742 Err(TypeError::TraitBoundViolation {
745 type_name: format!("{:?}", ty),
746 trait_name: trait_name.clone(),
747 })
748 }
749 Type::Concrete(ann) => {
750 let type_name = match ann {
751 TypeAnnotation::Basic(n) | TypeAnnotation::Reference(n) => n.clone(),
752 _ => format!("{:?}", ann),
753 };
754 if self.has_trait_impl(trait_name, &type_name) {
755 Ok(())
756 } else {
757 Err(TypeError::TraitBoundViolation {
758 type_name,
759 trait_name: trait_name.clone(),
760 })
761 }
762 }
763 Type::Generic { base, .. } => {
764 let type_name = if let Type::Concrete(
765 TypeAnnotation::Reference(n) | TypeAnnotation::Basic(n),
766 ) = base.as_ref()
767 {
768 n.clone()
769 } else {
770 format!("{:?}", base)
771 };
772 if self.has_trait_impl(trait_name, &type_name) {
773 Ok(())
774 } else {
775 Err(TypeError::TraitBoundViolation {
776 type_name,
777 trait_name: trait_name.clone(),
778 })
779 }
780 }
781 _ => Err(TypeError::TraitBoundViolation {
782 type_name: format!("{:?}", ty),
783 trait_name: trait_name.clone(),
784 }),
785 }
786 }
787
788 TypeConstraint::HasMethod {
789 method_name,
790 arg_types: _,
791 return_type: _,
792 } => {
793 if let Some(method_table) = &self.method_table {
795 match ty {
796 Type::Variable(_) => Ok(()), Type::Concrete(ann) => {
798 let type_name = match ann {
799 TypeAnnotation::Basic(n) | TypeAnnotation::Reference(n) => {
800 n.clone()
801 }
802 TypeAnnotation::Array(_) => "Vec".to_string(),
803 _ => return Ok(()), };
805 if method_table.lookup(ty, method_name).is_some() {
806 Ok(())
807 } else {
808 Err(TypeError::MethodNotFound {
809 type_name,
810 method_name: method_name.clone(),
811 })
812 }
813 }
814 Type::Generic { base, .. } => {
815 if method_table.lookup(ty, method_name).is_some() {
816 Ok(())
817 } else {
818 let type_name =
819 if let Type::Concrete(TypeAnnotation::Reference(n)) =
820 base.as_ref()
821 {
822 n.clone()
823 } else {
824 format!("{:?}", base)
825 };
826 Err(TypeError::MethodNotFound {
827 type_name,
828 method_name: method_name.clone(),
829 })
830 }
831 }
832 _ => Ok(()), }
834 } else {
835 Ok(())
837 }
838 }
839 }
840 }
841
842 fn has_trait_impl(&self, trait_name: &str, type_name: &str) -> bool {
847 let key = format!("{}::{}", trait_name, type_name);
848 if self.trait_impls.contains(&key) {
849 return true;
850 }
851 if BuiltinTypes::is_integer_type_name(type_name) {
853 for widen_to in &["number", "float", "f64"] {
854 let widen_key = format!("{}::{}", trait_name, widen_to);
855 if self.trait_impls.contains(&widen_key) {
856 return true;
857 }
858 }
859 }
860 false
861 }
862
    /// Checks that `ty` is a subtype of `base`. Otherwise returns a
    /// `ConstraintViolation` describing the first mismatch found.
    ///
    /// Arm order matters. Equality and unresolved variables are accepted first,
    /// and the `Option` arm (`T <: Option<T>`) is tried before the
    /// concrete-annotation fallback.
    fn is_subtype(&self, ty: &Type, base: &Type) -> TypeResult<()> {
        match (ty, base) {
            // Every type is a subtype of itself.
            (t1, t2) if t1 == t2 => Ok(()),

            // An unresolved variable on either side is accepted without checking.
            (Type::Variable(_), _) | (_, Type::Variable(_)) => Ok(()),

            // Arrays are covariant in their element type.
            (
                Type::Concrete(TypeAnnotation::Array(elem1)),
                Type::Concrete(TypeAnnotation::Array(elem2)),
            ) => {
                let t1 = Type::Concrete(*elem1.clone());
                let t2 = Type::Concrete(*elem2.clone());
                self.is_subtype(&t1, &t2)
            }

            // Function annotations: parameters are contravariant and the return
            // type is covariant.
            (
                Type::Concrete(TypeAnnotation::Function {
                    params: p1,
                    returns: r1,
                }),
                Type::Concrete(TypeAnnotation::Function {
                    params: p2,
                    returns: r2,
                }),
            ) => {
                if p1.len() != p2.len() {
                    return Err(TypeError::ConstraintViolation(format!(
                        "function parameter count mismatch: {} vs {}",
                        p1.len(),
                        p2.len()
                    )));
                }

                for (param1, param2) in p1.iter().zip(p2.iter()) {
                    // Swapped: `base`'s parameter must be a subtype of `ty`'s parameter.
                    let t1 = Type::Concrete(param2.type_annotation.clone());
                    let t2 = Type::Concrete(param1.type_annotation.clone());
                    self.is_subtype(&t1, &t2)?;
                }

                let ret1 = Type::Concrete(*r1.clone());
                let ret2 = Type::Concrete(*r2.clone());
                self.is_subtype(&ret1, &ret2)
            }

            // `T` is accepted where `Option<T>` is expected.
            // NOTE(review): `Option<A>` against `Option<B>` with A != B also lands
            // here and compares `Option<A>` with `B`, so `Option` itself is not
            // treated covariantly.
            (t, Type::Concrete(TypeAnnotation::Generic { name, args }))
                if name == "Option" && args.len() == 1 =>
            {
                let inner = Type::Concrete(args[0].clone());
                self.is_subtype(t, &inner)
            }

            // Inferred function types: same variance rules as annotations.
            (
                Type::Function {
                    params: p1,
                    returns: r1,
                },
                Type::Function {
                    params: p2,
                    returns: r2,
                },
            ) => {
                if p1.len() != p2.len() {
                    return Err(TypeError::ConstraintViolation(format!(
                        "function parameter count mismatch: {} vs {}",
                        p1.len(),
                        p2.len()
                    )));
                }
                for (param1, param2) in p1.iter().zip(p2.iter()) {
                    self.is_subtype(param2, param1)?;
                }
                self.is_subtype(r1, r2)
            }

            // Any other pair of concrete annotations falls back to directional
            // annotation unification (`ty` as source, `base` as target).
            (Type::Concrete(ann1), Type::Concrete(ann2)) => {
                if self.unify_annotations(ann1, ann2)? {
                    Ok(())
                } else {
                    Err(TypeError::ConstraintViolation(format!(
                        "{:?} is not a subtype of {:?}",
                        ty, base
                    )))
                }
            }

            _ => Err(TypeError::ConstraintViolation(format!(
                "{:?} is not a subtype of {:?}",
                ty, base
            ))),
        }
    }
973
974 pub fn unifier(&self) -> &Unifier {
976 &self.unifier
977 }
978}
979
#[cfg(test)]
mod tests {
    use super::*;
    use shape_ast::ast::ObjectTypeField;

    /// Shorthand for a concrete basic type such as `int` or `string`.
    fn basic(name: &str) -> Type {
        Type::Concrete(TypeAnnotation::Basic(name.to_string()))
    }

    /// A required object field whose annotation is the basic type `type_name`.
    fn field(name: &str, type_name: &str) -> ObjectTypeField {
        ObjectTypeField {
            name: name.to_string(),
            optional: false,
            type_annotation: TypeAnnotation::Basic(type_name.to_string()),
            annotations: vec![],
        }
    }

    /// The bound variable `var` carrying `constraint`.
    fn constrained(var: TypeVar, constraint: TypeConstraint) -> Type {
        Type::Constrained {
            var,
            constraint: Box::new(constraint),
        }
    }

    /// `HasField(name, field_var)`: the constrained type must have a field
    /// `name` whose type is `field_var`.
    fn has_field(name: &str, field_var: TypeVar) -> TypeConstraint {
        TypeConstraint::HasField(name.to_string(), Box::new(Type::Variable(field_var)))
    }

    /// `ImplementsTrait` bound for `trait_name`.
    fn implements(trait_name: &str) -> TypeConstraint {
        TypeConstraint::ImplementsTrait {
            trait_name: trait_name.to_string(),
        }
    }

    #[test]
    fn test_hasfield_backward_propagation_binds_field_type() {
        let mut solver = ConstraintSolver::new();

        let obj_var = TypeVar::fresh();
        let field_result_var = TypeVar::fresh();
        let bound_var = TypeVar::fresh();

        // `obj` must have a field `x` of type `field_result_var`. Later `obj` is
        // found to be `{ x: int }`, and solving must propagate `int` back.
        let mut constraints = vec![
            (
                Type::Variable(obj_var.clone()),
                constrained(bound_var, has_field("x", field_result_var.clone())),
            ),
            (
                Type::Variable(obj_var),
                Type::Concrete(TypeAnnotation::Object(vec![field("x", "int")])),
            ),
        ];

        solver.solve(&mut constraints).unwrap();

        let resolved = solver
            .unifier()
            .apply_substitutions(&Type::Variable(field_result_var));
        match &resolved {
            Type::Concrete(TypeAnnotation::Basic(name)) => {
                assert_eq!(name, "int", "field type should be int");
            }
            _ => panic!(
                "Expected field_result_var to be resolved to int, got {:?}",
                resolved
            ),
        }
    }

    #[test]
    fn test_hasfield_backward_propagation_multiple_fields() {
        let mut solver = ConstraintSolver::new();

        let obj_var = TypeVar::fresh();
        let field_x_var = TypeVar::fresh();
        let field_y_var = TypeVar::fresh();
        let bound_var_x = TypeVar::fresh();
        let bound_var_y = TypeVar::fresh();

        let mut constraints = vec![
            (
                Type::Variable(obj_var.clone()),
                constrained(bound_var_x, has_field("x", field_x_var.clone())),
            ),
            (
                Type::Variable(obj_var.clone()),
                constrained(bound_var_y, has_field("y", field_y_var.clone())),
            ),
            (
                Type::Variable(obj_var),
                Type::Concrete(TypeAnnotation::Object(vec![
                    field("x", "int"),
                    field("y", "string"),
                ])),
            ),
        ];

        solver.solve(&mut constraints).unwrap();

        let resolved_x = solver
            .unifier()
            .apply_substitutions(&Type::Variable(field_x_var));
        let resolved_y = solver
            .unifier()
            .apply_substitutions(&Type::Variable(field_y_var));

        match &resolved_x {
            Type::Concrete(TypeAnnotation::Basic(name)) => assert_eq!(name, "int"),
            _ => panic!("Expected x to be int, got {:?}", resolved_x),
        }
        match &resolved_y {
            Type::Concrete(TypeAnnotation::Basic(name)) => assert_eq!(name, "string"),
            _ => panic!("Expected y to be string, got {:?}", resolved_y),
        }
    }

    #[test]
    fn test_int_constrained_numeric_succeeds() {
        let mut solver = ConstraintSolver::new();
        let bound_var = TypeVar::fresh();
        let mut constraints = vec![(basic("int"), constrained(bound_var, TypeConstraint::Numeric))];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    #[test]
    fn test_numeric_widening_int_to_number() {
        let mut solver = ConstraintSolver::new();
        let mut constraints = vec![(basic("int"), basic("number"))];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    #[test]
    fn test_numeric_widening_width_aware_integer_to_float_family() {
        let mut solver = ConstraintSolver::new();
        let mut constraints = vec![(basic("i16"), basic("f32"))];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    #[test]
    fn test_no_widening_number_to_int() {
        let mut solver = ConstraintSolver::new();
        let mut constraints = vec![(basic("number"), basic("int"))];
        assert!(solver.solve(&mut constraints).is_err());
    }

    #[test]
    fn test_decimal_constrained_numeric_succeeds() {
        let mut solver = ConstraintSolver::new();
        let bound_var = TypeVar::fresh();
        let mut constraints =
            vec![(basic("decimal"), constrained(bound_var, TypeConstraint::Numeric))];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    #[test]
    fn test_comparable_accepts_int() {
        let mut solver = ConstraintSolver::new();
        let bound_var = TypeVar::fresh();
        let mut constraints =
            vec![(basic("int"), constrained(bound_var, TypeConstraint::Comparable))];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    #[test]
    fn test_function_type_preserves_variables() {
        let param = Type::Variable(TypeVar::fresh());
        let ret = Type::Variable(TypeVar::fresh());
        match BuiltinTypes::function(vec![param.clone()], ret.clone()) {
            Type::Function { params, returns } => {
                assert_eq!(params.len(), 1);
                assert_eq!(params[0], param);
                assert_eq!(*returns, ret);
            }
            other => panic!("Expected Type::Function, got {:?}", other),
        }
    }

    #[test]
    fn test_function_unification_binds_variables() {
        let mut solver = ConstraintSolver::new();
        let t1 = TypeVar::fresh();
        let t2 = TypeVar::fresh();

        let inferred = Type::Function {
            params: vec![Type::Variable(t1.clone())],
            returns: Box::new(Type::Variable(t2.clone())),
        };
        let known = Type::Function {
            params: vec![BuiltinTypes::number()],
            returns: Box::new(BuiltinTypes::string()),
        };
        let mut constraints = vec![(inferred, known)];

        solver.solve(&mut constraints).unwrap();

        let resolve = |var: TypeVar| solver.unifier().apply_substitutions(&Type::Variable(var));
        assert_eq!(resolve(t1), BuiltinTypes::number());
        assert_eq!(resolve(t2), BuiltinTypes::string());
    }

    #[test]
    fn test_function_cross_unification_with_concrete() {
        let mut solver = ConstraintSolver::new();
        let t1 = TypeVar::fresh();

        // Concrete annotation `(number) -> string`.
        let concrete_func = Type::Concrete(TypeAnnotation::Function {
            params: vec![shape_ast::ast::FunctionParam {
                name: None,
                optional: false,
                type_annotation: TypeAnnotation::Basic("number".to_string()),
            }],
            returns: Box::new(TypeAnnotation::Basic("string".to_string())),
        });
        let inferred_func = Type::Function {
            params: vec![Type::Variable(t1.clone())],
            returns: Box::new(BuiltinTypes::string()),
        };
        let mut constraints = vec![(inferred_func, concrete_func)];

        solver.solve(&mut constraints).unwrap();

        let resolved = solver.unifier().apply_substitutions(&Type::Variable(t1));
        assert_eq!(resolved, BuiltinTypes::number());
    }

    #[test]
    fn test_object_annotations_unify_structurally() {
        let mut solver = ConstraintSolver::new();
        let point = || TypeAnnotation::Object(vec![field("x", "int"), field("y", "int")]);
        let mut constraints = vec![(Type::Concrete(point()), Type::Concrete(point()))];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    #[test]
    fn test_intersection_annotations_unify_order_independent() {
        let mut solver = ConstraintSolver::new();
        let obj_xy = TypeAnnotation::Object(vec![field("x", "int"), field("y", "int")]);
        let obj_z = TypeAnnotation::Object(vec![field("z", "int")]);

        let forward = TypeAnnotation::Intersection(vec![obj_xy.clone(), obj_z.clone()]);
        let backward = TypeAnnotation::Intersection(vec![obj_z, obj_xy]);
        let mut constraints = vec![(Type::Concrete(forward), Type::Concrete(backward))];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    #[test]
    fn test_implements_trait_satisfied() {
        let mut solver = ConstraintSolver::new();
        solver.set_trait_impls(std::iter::once("Comparable::number".to_string()).collect());

        let bound_var = TypeVar::fresh();
        let mut constraints = vec![(
            basic("number"),
            constrained(bound_var, implements("Comparable")),
        )];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    #[test]
    fn test_implements_trait_violated() {
        let mut solver = ConstraintSolver::new();
        let bound_var = TypeVar::fresh();
        let mut constraints = vec![(
            basic("string"),
            constrained(bound_var, implements("Comparable")),
        )];
        let result = solver.solve(&mut constraints);
        assert!(result.is_err());
        match result.unwrap_err() {
            TypeError::TraitBoundViolation {
                type_name,
                trait_name,
            } => {
                assert_eq!(type_name, "string");
                assert_eq!(trait_name, "Comparable");
            }
            other => panic!("Expected TraitBoundViolation, got: {:?}", other),
        }
    }

    #[test]
    fn test_implements_trait_via_variable_resolution() {
        let mut solver = ConstraintSolver::new();
        solver.set_trait_impls(std::iter::once("Sortable::number".to_string()).collect());

        let type_var = TypeVar::fresh();
        let bound_var = TypeVar::fresh();

        // `T: Sortable`, then `T = number`.
        let mut constraints = vec![
            (
                Type::Variable(type_var.clone()),
                constrained(bound_var, implements("Sortable")),
            ),
            (Type::Variable(type_var), basic("number")),
        ];
        assert!(
            solver.solve(&mut constraints).is_ok(),
            "T resolved to number which implements Sortable"
        );
    }
}