1use super::checking::MethodTable;
32use super::unification::Unifier;
33use super::*;
34use shape_ast::ast::{ObjectTypeField, TypeAnnotation};
35use std::collections::{HashMap, HashSet};
36
37fn is_array_or_vec_base(base: &Type) -> bool {
39 match base {
40 Type::Concrete(TypeAnnotation::Reference(name)) => name == "Array" || name == "Vec",
41 Type::Concrete(TypeAnnotation::Basic(name)) => name == "Array" || name == "Vec",
42 _ => false,
43 }
44}
45
46pub struct ConstraintSolver {
47 unifier: Unifier,
49 _deferred: Vec<(Type, Type)>,
52 bounds: HashMap<TypeVar, TypeConstraint>,
54 method_table: Option<MethodTable>,
56 trait_impls: HashSet<String>,
58}
59
60impl Default for ConstraintSolver {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66impl ConstraintSolver {
67 pub fn new() -> Self {
68 ConstraintSolver {
69 unifier: Unifier::new(),
70 _deferred: Vec::new(),
71 bounds: HashMap::new(),
72 method_table: None,
73 trait_impls: HashSet::new(),
74 }
75 }
76
77 pub fn set_method_table(&mut self, table: MethodTable) {
81 self.method_table = Some(table);
82 }
83
84 pub fn set_trait_impls(&mut self, impls: HashSet<String>) {
87 self.trait_impls = impls;
88 }
89
90 pub fn solve(&mut self, constraints: &mut Vec<(Type, Type)>) -> TypeResult<()> {
92 let mut unsolved = Vec::new();
94
95 for (t1, t2) in constraints.drain(..) {
96 if self.solve_constraint(t1.clone(), t2.clone()).is_err() {
97 unsolved.push((t1, t2));
99 }
100 }
101
102 let mut made_progress = true;
104 while made_progress && !unsolved.is_empty() {
105 made_progress = false;
106 let mut still_unsolved = Vec::new();
107
108 for (t1, t2) in unsolved.drain(..) {
109 if self.solve_constraint(t1.clone(), t2.clone()).is_err() {
110 still_unsolved.push((t1, t2));
111 } else {
112 made_progress = true;
113 }
114 }
115
116 unsolved = still_unsolved;
117 }
118
119 if !unsolved.is_empty() {
121 return Err(TypeError::UnsolvedConstraints(unsolved));
122 }
123
124 self.apply_bounds()?;
126
127 Ok(())
128 }
129
130 fn solve_constraint(&mut self, t1: Type, t2: Type) -> TypeResult<()> {
132 let t1 = self.unifier.apply_substitutions(&t1);
136 let t2 = self.unifier.apply_substitutions(&t2);
137
138 match (&t1, &t2) {
139 (Type::Variable(v1), Type::Variable(v2)) if v1 == v2 => Ok(()),
141
142 (Type::Constrained { var, constraint }, ty)
146 | (ty, Type::Constrained { var, constraint }) => {
147 self.bounds.insert(var.clone(), *constraint.clone());
149
150 self.solve_constraint(Type::Variable(var.clone()), ty.clone())
152 }
153
154 (Type::Variable(var), ty) | (ty, Type::Variable(var)) => {
155 if self.occurs_in(var, ty) {
157 return Err(TypeError::InfiniteType(var.clone()));
158 }
159
160 self.unifier.bind(var.clone(), ty.clone());
161 Ok(())
162 }
163
164 (Type::Concrete(ann1), Type::Concrete(ann2)) => {
166 if self.unify_annotations(ann1, ann2)? {
167 Ok(())
168 } else if Self::can_numeric_widen(ann1, ann2) {
169 Ok(())
171 } else {
172 Err(TypeError::TypeMismatch(
173 format!("{:?}", ann1),
174 format!("{:?}", ann2),
175 ))
176 }
177 }
178
179 (Type::Generic { base: b1, args: a1 }, Type::Generic { base: b2, args: a2 }) => {
181 self.solve_constraint(*b1.clone(), *b2.clone())?;
182
183 let is_result_base = |base: &Type| match base {
184 Type::Concrete(TypeAnnotation::Reference(name)) => name == "Result",
185 Type::Concrete(TypeAnnotation::Basic(name)) => name == "Result",
186 _ => false,
187 };
188
189 if a1.len() != a2.len() {
190 if is_result_base(&b1) && is_result_base(&b2) {
191 match (a1.len(), a2.len()) {
192 (1, 2) | (2, 1) => {
195 self.solve_constraint(a1[0].clone(), a2[0].clone())?;
196 return Ok(());
197 }
198 _ => return Err(TypeError::ArityMismatch(a1.len(), a2.len())),
199 }
200 } else {
201 return Err(TypeError::ArityMismatch(a1.len(), a2.len()));
202 }
203 }
204
205 for (arg1, arg2) in a1.iter().zip(a2.iter()) {
206 self.solve_constraint(arg1.clone(), arg2.clone())?;
207 }
208
209 Ok(())
210 }
211
212 (
214 Type::Function {
215 params: p1,
216 returns: r1,
217 },
218 Type::Function {
219 params: p2,
220 returns: r2,
221 },
222 ) => {
223 if p1.len() != p2.len() {
224 return Err(TypeError::ArityMismatch(p1.len(), p2.len()));
225 }
226 for (param1, param2) in p1.iter().zip(p2.iter()) {
227 self.solve_constraint(param2.clone(), param1.clone())?;
231 }
232 self.solve_constraint(*r1.clone(), *r2.clone())
233 }
234
235 (
237 Type::Function {
238 params: fp,
239 returns: fr,
240 },
241 Type::Concrete(TypeAnnotation::Function {
242 params: cp,
243 returns: cr,
244 }),
245 )
246 | (
247 Type::Concrete(TypeAnnotation::Function {
248 params: cp,
249 returns: cr,
250 }),
251 Type::Function {
252 params: fp,
253 returns: fr,
254 },
255 ) => {
256 if fp.len() != cp.len() {
257 return Err(TypeError::ArityMismatch(fp.len(), cp.len()));
258 }
259 for (f_param, c_param) in fp.iter().zip(cp.iter()) {
260 self.solve_constraint(
261 f_param.clone(),
262 Type::Concrete(c_param.type_annotation.clone()),
263 )?;
264 }
265 self.solve_constraint(*fr.clone(), Type::Concrete(*cr.clone()))
266 }
267
268 (Type::Generic { base, args }, Type::Concrete(TypeAnnotation::Array(elem)))
270 | (Type::Concrete(TypeAnnotation::Array(elem)), Type::Generic { base, args })
271 if args.len() == 1 && is_array_or_vec_base(base) =>
272 {
273 self.solve_constraint(args[0].clone(), Type::Concrete((**elem).clone()))
274 }
275
276 _ => Err(TypeError::TypeMismatch(
277 format!("{:?}", t1),
278 format!("{:?}", t2),
279 )),
280 }
281 }
282
283 fn occurs_in(&self, var: &TypeVar, ty: &Type) -> bool {
285 match ty {
286 Type::Variable(v) => v == var,
287 Type::Generic { base, args } => {
288 self.occurs_in(var, base) || args.iter().any(|arg| self.occurs_in(var, arg))
289 }
290 Type::Constrained { var: v, .. } => v == var,
291 Type::Function { params, returns } => {
292 params.iter().any(|p| self.occurs_in(var, p)) || self.occurs_in(var, returns)
293 }
294 Type::Concrete(_) => false,
295 }
296 }
297
298 fn can_numeric_widen(from: &TypeAnnotation, to: &TypeAnnotation) -> bool {
305 let from_name = match from {
306 TypeAnnotation::Basic(name) => Some(name.as_str()),
307 TypeAnnotation::Reference(name) => Some(name.as_str()),
308 _ => None,
309 };
310 let to_name = match to {
311 TypeAnnotation::Basic(name) => Some(name.as_str()),
312 TypeAnnotation::Reference(name) => Some(name.as_str()),
313 _ => None,
314 };
315
316 match (from_name, to_name) {
317 (Some(f), Some(t)) => {
318 BuiltinTypes::is_integer_type_name(f) && BuiltinTypes::is_number_type_name(t)
319 }
320 _ => false,
321 }
322 }
323
324 fn unify_annotations(&self, ann1: &TypeAnnotation, ann2: &TypeAnnotation) -> TypeResult<bool> {
326 match (ann1, ann2) {
327 (TypeAnnotation::Basic(_), TypeAnnotation::Basic(_)) => {
329 Ok(ann1 == ann2 || Self::can_numeric_widen(ann1, ann2))
330 }
331 (TypeAnnotation::Reference(n1), TypeAnnotation::Reference(n2)) => Ok(n1 == n2),
332 (TypeAnnotation::Basic(_), TypeAnnotation::Reference(_))
333 | (TypeAnnotation::Reference(_), TypeAnnotation::Basic(_)) => {
334 Ok(ann1 == ann2 || Self::can_numeric_widen(ann1, ann2))
335 }
336
337 (TypeAnnotation::Array(e1), TypeAnnotation::Array(e2)) => {
339 self.unify_annotations(e1, e2)
340 }
341
342 (TypeAnnotation::Tuple(t1), TypeAnnotation::Tuple(t2)) => {
344 if t1.len() != t2.len() {
345 return Ok(false);
346 }
347
348 for (elem1, elem2) in t1.iter().zip(t2.iter()) {
349 if !self.unify_annotations(elem1, elem2)? {
350 return Ok(false);
351 }
352 }
353
354 Ok(true)
355 }
356
357 (TypeAnnotation::Object(f1), TypeAnnotation::Object(f2)) => {
359 self.object_fields_compatible(f1, f2)
360 }
361
362 (
364 TypeAnnotation::Function {
365 params: p1,
366 returns: r1,
367 },
368 TypeAnnotation::Function {
369 params: p2,
370 returns: r2,
371 },
372 ) => {
373 if p1.len() != p2.len() {
374 return Ok(false);
375 }
376
377 for (param1, param2) in p1.iter().zip(p2.iter()) {
378 if !self.unify_annotations(¶m1.type_annotation, ¶m2.type_annotation)? {
379 return Ok(false);
380 }
381 }
382
383 self.unify_annotations(r1, r2)
384 }
385
386 (TypeAnnotation::Union(u1), TypeAnnotation::Union(u2)) => {
389 for t1 in u1 {
391 let mut found_match = false;
392 for t2 in u2 {
393 if self.unify_annotations(t1, t2)? {
394 found_match = true;
395 break;
396 }
397 }
398 if !found_match {
399 return Ok(false);
400 }
401 }
402 for t2 in u2 {
404 let mut found_match = false;
405 for t1 in u1 {
406 if self.unify_annotations(t1, t2)? {
407 found_match = true;
408 break;
409 }
410 }
411 if !found_match {
412 return Ok(false);
413 }
414 }
415 Ok(true)
416 }
417
418 (TypeAnnotation::Union(union_types), other)
420 | (other, TypeAnnotation::Union(union_types)) => {
421 for union_type in union_types {
422 if self.unify_annotations(union_type, other)? {
423 return Ok(true);
424 }
425 }
426 Ok(false)
427 }
428
429 (TypeAnnotation::Intersection(i1), TypeAnnotation::Intersection(i2)) => {
431 self.unify_annotation_sets(i1, i2)
432 }
433
434 (TypeAnnotation::Void, TypeAnnotation::Void) => Ok(true),
436 (TypeAnnotation::Null, TypeAnnotation::Null) => Ok(true),
437 (TypeAnnotation::Undefined, TypeAnnotation::Undefined) => Ok(true),
438
439 (TypeAnnotation::Dyn(traits1), TypeAnnotation::Dyn(traits2)) => {
442 Ok(traits1.len() == traits2.len() && traits1.iter().all(|t| traits2.contains(t)))
443 }
444
445 (TypeAnnotation::Generic { name, args }, TypeAnnotation::Array(elem))
447 | (TypeAnnotation::Array(elem), TypeAnnotation::Generic { name, args })
448 if name == "Array" && args.len() == 1 =>
449 {
450 self.unify_annotations(&args[0], elem)
451 }
452
453 _ => Ok(false),
455 }
456 }
457
458 fn object_fields_compatible(
459 &self,
460 left: &[ObjectTypeField],
461 right: &[ObjectTypeField],
462 ) -> TypeResult<bool> {
463 for left_field in left {
464 let Some(right_field) = right.iter().find(|f| f.name == left_field.name) else {
465 return Ok(false);
466 };
467 if left_field.optional != right_field.optional {
468 return Ok(false);
469 }
470 if !self.unify_annotations(&left_field.type_annotation, &right_field.type_annotation)? {
471 return Ok(false);
472 }
473 }
474 if left.len() != right.len() {
475 return Ok(false);
476 }
477 Ok(true)
478 }
479
480 fn unify_annotation_sets(
481 &self,
482 left: &[TypeAnnotation],
483 right: &[TypeAnnotation],
484 ) -> TypeResult<bool> {
485 if left.len() != right.len() {
486 return Ok(false);
487 }
488
489 let mut matched = vec![false; right.len()];
490 for left_ann in left {
491 let mut found = false;
492 for (idx, right_ann) in right.iter().enumerate() {
493 if matched[idx] {
494 continue;
495 }
496 if self.unify_annotations(left_ann, right_ann)? {
497 matched[idx] = true;
498 found = true;
499 break;
500 }
501 }
502 if !found {
503 return Ok(false);
504 }
505 }
506
507 Ok(true)
508 }
509
510 fn apply_bounds(&mut self) -> TypeResult<()> {
516 let mut new_bindings: Vec<(TypeVar, Type)> = Vec::new();
517
518 for (var, constraint) in &self.bounds {
519 let resolved = self
522 .unifier
523 .apply_substitutions(&Type::Variable(var.clone()));
524
525 if let Type::Variable(_) = &resolved {
526 continue;
528 }
529
530 self.check_constraint(&resolved, constraint)?;
531
532 if let TypeConstraint::HasField(field, expected_field_type) = constraint {
535 if let Type::Variable(field_var) = expected_field_type.as_ref() {
536 let field_resolved = self
538 .unifier
539 .apply_substitutions(&Type::Variable(field_var.clone()));
540 if let Type::Variable(_) = &field_resolved {
541 if let Type::Concrete(TypeAnnotation::Object(fields)) = &resolved {
543 if let Some(found_field) = fields.iter().find(|f| f.name == *field) {
544 new_bindings.push((
545 field_var.clone(),
546 Type::Concrete(found_field.type_annotation.clone()),
547 ));
548 }
549 }
550 }
551 }
552 }
553
554 if let TypeConstraint::Indexable(expected_elem_type) = constraint {
563 if let Type::Variable(elem_var) = expected_elem_type.as_ref() {
564 let elem_resolved = self
565 .unifier
566 .apply_substitutions(&Type::Variable(elem_var.clone()));
567 if let Type::Variable(_) = &elem_resolved {
568 let actual_elem: Option<Type> = match &resolved {
569 Type::Concrete(TypeAnnotation::Array(elem)) => {
570 Some(Type::Concrete((**elem).clone()))
571 }
572 Type::Generic { base, args }
573 if args.len() == 1 && is_array_or_vec_base(base) =>
574 {
575 Some(args[0].clone())
576 }
577 Type::Concrete(TypeAnnotation::Basic(name))
579 if name == "string" =>
580 {
581 Some(BuiltinTypes::string())
582 }
583 _ => None,
584 };
585 if let Some(elem_ty) = actual_elem {
586 if !matches!(elem_ty, Type::Variable(_)) {
587 new_bindings.push((elem_var.clone(), elem_ty));
588 }
589 }
590 }
591 }
592 }
593 }
594
595 for (var, ty) in new_bindings {
597 self.unifier.bind(var, ty);
598 }
599
600 Ok(())
601 }
602
603 fn check_constraint(&self, ty: &Type, constraint: &TypeConstraint) -> TypeResult<()> {
605 match constraint {
606 TypeConstraint::Comparable => match ty {
607 Type::Concrete(TypeAnnotation::Basic(name))
608 if BuiltinTypes::is_numeric_type_name(name)
609 || name == "string"
610 || name == "bool" =>
611 {
612 Ok(())
613 }
614 _ => Err(TypeError::ConstraintViolation(format!(
615 "{:?} is not comparable",
616 ty
617 ))),
618 },
619
620 TypeConstraint::Iterable => match ty {
621 Type::Concrete(TypeAnnotation::Array(_)) => Ok(()),
622 Type::Concrete(TypeAnnotation::Basic(name))
623 if name == "string" || name == "rows" =>
624 {
625 Ok(())
626 }
627 _ => Err(TypeError::ConstraintViolation(format!(
628 "{:?} is not iterable",
629 ty
630 ))),
631 },
632
633 TypeConstraint::Indexable(_) => match ty {
637 Type::Concrete(TypeAnnotation::Array(_)) => Ok(()),
638 Type::Generic { base, args }
639 if args.len() == 1 && is_array_or_vec_base(base) =>
640 {
641 Ok(())
642 }
643 Type::Concrete(TypeAnnotation::Basic(name))
644 if name == "string" || name == "rows" =>
645 {
646 Ok(())
647 }
648 _ => Err(TypeError::ConstraintViolation(format!(
649 "{:?} does not support index access",
650 ty
651 ))),
652 },
653
654 TypeConstraint::HasField(field, expected_field_type) => {
655 match ty {
656 Type::Concrete(TypeAnnotation::Object(fields)) => {
657 match fields.iter().find(|f| f.name == *field) {
658 Some(found_field) => {
659 if let Some(expected_ann) = expected_field_type.to_annotation() {
661 if self.unify_annotations(
662 &found_field.type_annotation,
663 &expected_ann,
664 )? {
665 Ok(())
666 } else {
667 Err(TypeError::ConstraintViolation(format!(
668 "field '{}' has type {:?}, expected {:?}",
669 field, found_field.type_annotation, expected_ann
670 )))
671 }
672 } else {
673 Ok(())
675 }
676 }
677 None => Err(TypeError::ConstraintViolation(format!(
678 "{:?} does not have field '{}'",
679 ty, field
680 ))),
681 }
682 }
683 Type::Concrete(TypeAnnotation::Basic(_name)) => {
684 Ok(())
693 }
694 _ => Err(TypeError::ConstraintViolation(format!(
695 "{:?} cannot have fields",
696 ty
697 ))),
698 }
699 }
700
701 TypeConstraint::Callable {
702 params: expected_params,
703 returns: expected_returns,
704 } => {
705 match ty {
706 Type::Concrete(TypeAnnotation::Function {
707 params: actual_params,
708 returns: actual_returns,
709 }) => {
710 if expected_params.len() != actual_params.len() {
712 return Err(TypeError::ConstraintViolation(format!(
713 "function expects {} parameters, got {}",
714 expected_params.len(),
715 actual_params.len()
716 )));
717 }
718
719 for (expected, actual) in expected_params.iter().zip(actual_params.iter()) {
721 if let Some(expected_ann) = expected.to_annotation() {
722 if !self
723 .unify_annotations(&expected_ann, &actual.type_annotation)?
724 {
725 return Err(TypeError::ConstraintViolation(format!(
726 "parameter type mismatch: expected {:?}, got {:?}",
727 expected_ann, actual.type_annotation
728 )));
729 }
730 }
731 }
732
733 if let Some(expected_ret_ann) = expected_returns.to_annotation() {
735 if !self.unify_annotations(actual_returns, &expected_ret_ann)? {
736 return Err(TypeError::ConstraintViolation(format!(
737 "return type mismatch: expected {:?}, got {:?}",
738 expected_ret_ann, actual_returns
739 )));
740 }
741 }
742
743 Ok(())
744 }
745 Type::Function {
746 params: actual_params,
747 returns: actual_returns,
748 } => {
749 if expected_params.len() != actual_params.len() {
750 return Err(TypeError::ConstraintViolation(format!(
751 "function expects {} parameters, got {}",
752 expected_params.len(),
753 actual_params.len()
754 )));
755 }
756 for (expected, actual) in expected_params.iter().zip(actual_params.iter()) {
758 if let (Some(e_ann), Some(a_ann)) =
759 (expected.to_annotation(), actual.to_annotation())
760 {
761 if !self.unify_annotations(&e_ann, &a_ann)? {
762 return Err(TypeError::ConstraintViolation(format!(
763 "parameter type mismatch: expected {:?}, got {:?}",
764 e_ann, a_ann
765 )));
766 }
767 }
768 }
769 if let (Some(e_ret), Some(a_ret)) = (
770 expected_returns.to_annotation(),
771 actual_returns.to_annotation(),
772 ) {
773 if !self.unify_annotations(&a_ret, &e_ret)? {
774 return Err(TypeError::ConstraintViolation(format!(
775 "return type mismatch: expected {:?}, got {:?}",
776 e_ret, a_ret
777 )));
778 }
779 }
780 Ok(())
781 }
782 _ => Err(TypeError::ConstraintViolation(format!(
783 "{:?} is not callable",
784 ty
785 ))),
786 }
787 }
788
789 TypeConstraint::OneOf(options) => {
790 for option in options {
791 if let Type::Concrete(ann) = option {
793 if let Type::Concrete(ty_ann) = ty {
794 if self.unify_annotations(ann, ty_ann).unwrap_or(false) {
795 return Ok(());
796 }
797 }
798 }
799 }
800
801 Err(TypeError::ConstraintViolation(format!(
802 "{:?} does not match any of {:?}",
803 ty, options
804 )))
805 }
806
807 TypeConstraint::Extends(base) => {
808 self.is_subtype(ty, base)
810 }
811
812 TypeConstraint::ImplementsTrait { trait_name } => {
813 match ty {
814 Type::Variable(_) => {
815 Err(TypeError::TraitBoundViolation {
818 type_name: format!("{:?}", ty),
819 trait_name: trait_name.clone(),
820 })
821 }
822 Type::Concrete(ann) => {
823 let type_name = match ann {
824 TypeAnnotation::Basic(n) => n.clone(),
825 TypeAnnotation::Reference(n) => n.to_string(),
826 _ => format!("{:?}", ann),
827 };
828 if self.has_trait_impl(trait_name, &type_name) {
829 Ok(())
830 } else {
831 Err(TypeError::TraitBoundViolation {
832 type_name,
833 trait_name: trait_name.clone(),
834 })
835 }
836 }
837 Type::Generic { base, .. } => {
838 let type_name = match base.as_ref() {
839 Type::Concrete(TypeAnnotation::Reference(n)) => n.to_string(),
840 Type::Concrete(TypeAnnotation::Basic(n)) => n.clone(),
841 _ => format!("{:?}", base),
842 };
843 if self.has_trait_impl(trait_name, &type_name) {
844 Ok(())
845 } else {
846 Err(TypeError::TraitBoundViolation {
847 type_name,
848 trait_name: trait_name.clone(),
849 })
850 }
851 }
852 _ => Err(TypeError::TraitBoundViolation {
853 type_name: format!("{:?}", ty),
854 trait_name: trait_name.clone(),
855 }),
856 }
857 }
858
859 TypeConstraint::HasMethod {
860 method_name,
861 arg_types: _,
862 return_type: _,
863 } => {
864 if let Some(method_table) = &self.method_table {
866 match ty {
867 Type::Variable(_) => Ok(()), Type::Concrete(ann) => {
869 let type_name = match ann {
870 TypeAnnotation::Basic(n) => n.clone(),
871 TypeAnnotation::Reference(n) => n.to_string(),
872 TypeAnnotation::Array(_) => "Vec".to_string(),
873 _ => return Ok(()), };
875 if method_table.lookup(ty, method_name).is_some() {
876 Ok(())
877 } else {
878 Err(TypeError::MethodNotFound {
879 type_name,
880 method_name: method_name.clone(),
881 })
882 }
883 }
884 Type::Generic { base, .. } => {
885 if method_table.lookup(ty, method_name).is_some() {
886 Ok(())
887 } else {
888 let type_name =
889 if let Type::Concrete(TypeAnnotation::Reference(n)) =
890 base.as_ref()
891 {
892 n.to_string()
893 } else {
894 format!("{:?}", base)
895 };
896 Err(TypeError::MethodNotFound {
897 type_name,
898 method_name: method_name.clone(),
899 })
900 }
901 }
902 _ => Ok(()), }
904 } else {
905 Ok(())
907 }
908 }
909 }
910 }
911
912 fn has_trait_impl(&self, trait_name: &str, type_name: &str) -> bool {
920 let key = format!("{}::{}", trait_name, type_name);
921 if self.trait_impls.contains(&key) {
922 return true;
923 }
924 if let Some(canonical) = BuiltinTypes::canonical_numeric_runtime_name(type_name) {
926 let canon_key = format!("{}::{}", trait_name, canonical);
927 if self.trait_impls.contains(&canon_key) {
928 return true;
929 }
930 }
931 if let Some(script_alias) = BuiltinTypes::canonical_script_alias(type_name) {
933 let alias_key = format!("{}::{}", trait_name, script_alias);
934 if self.trait_impls.contains(&alias_key) {
935 return true;
936 }
937 }
938 if BuiltinTypes::is_integer_type_name(type_name) {
940 for widen_to in &["number", "float", "f64"] {
941 let widen_key = format!("{}::{}", trait_name, widen_to);
942 if self.trait_impls.contains(&widen_key) {
943 return true;
944 }
945 }
946 }
947 false
948 }
949
950 fn is_subtype(&self, ty: &Type, base: &Type) -> TypeResult<()> {
957 match (ty, base) {
958 (t1, t2) if t1 == t2 => Ok(()),
960
961 (Type::Variable(_), _) | (_, Type::Variable(_)) => Ok(()),
963
964 (
966 Type::Concrete(TypeAnnotation::Array(elem1)),
967 Type::Concrete(TypeAnnotation::Array(elem2)),
968 ) => {
969 let t1 = Type::Concrete(*elem1.clone());
970 let t2 = Type::Concrete(*elem2.clone());
971 self.is_subtype(&t1, &t2)
972 }
973
974 (
976 Type::Concrete(TypeAnnotation::Function {
977 params: p1,
978 returns: r1,
979 }),
980 Type::Concrete(TypeAnnotation::Function {
981 params: p2,
982 returns: r2,
983 }),
984 ) => {
985 if p1.len() != p2.len() {
987 return Err(TypeError::ConstraintViolation(format!(
988 "function parameter count mismatch: {} vs {}",
989 p1.len(),
990 p2.len()
991 )));
992 }
993
994 for (param1, param2) in p1.iter().zip(p2.iter()) {
996 let t1 = Type::Concrete(param2.type_annotation.clone());
997 let t2 = Type::Concrete(param1.type_annotation.clone());
998 self.is_subtype(&t1, &t2)?;
999 }
1000
1001 let ret1 = Type::Concrete(*r1.clone());
1003 let ret2 = Type::Concrete(*r2.clone());
1004 self.is_subtype(&ret1, &ret2)
1005 }
1006
1007 (t, Type::Concrete(TypeAnnotation::Generic { name, args }))
1009 if name == "Option" && args.len() == 1 =>
1010 {
1011 let inner = Type::Concrete(args[0].clone());
1012 self.is_subtype(t, &inner)
1013 }
1014
1015 (
1017 Type::Function {
1018 params: p1,
1019 returns: r1,
1020 },
1021 Type::Function {
1022 params: p2,
1023 returns: r2,
1024 },
1025 ) => {
1026 if p1.len() != p2.len() {
1027 return Err(TypeError::ConstraintViolation(format!(
1028 "function parameter count mismatch: {} vs {}",
1029 p1.len(),
1030 p2.len()
1031 )));
1032 }
1033 for (param1, param2) in p1.iter().zip(p2.iter()) {
1035 self.is_subtype(param2, param1)?;
1036 }
1037 self.is_subtype(r1, r2)
1039 }
1040
1041 (Type::Concrete(ann1), Type::Concrete(ann2)) => {
1043 if self.unify_annotations(ann1, ann2)? {
1044 Ok(())
1045 } else {
1046 Err(TypeError::ConstraintViolation(format!(
1047 "{:?} is not a subtype of {:?}",
1048 ty, base
1049 )))
1050 }
1051 }
1052
1053 _ => Err(TypeError::ConstraintViolation(format!(
1055 "{:?} is not a subtype of {:?}",
1056 ty, base
1057 ))),
1058 }
1059 }
1060
1061 pub fn unifier(&self) -> &Unifier {
1063 &self.unifier
1064 }
1065}
1066
1067#[cfg(test)]
1068mod tests {
1069 use super::*;
1070 use crate::type_system::TypeVarGen;
1071 use shape_ast::ast::ObjectTypeField;
1072
1073 fn fresh_var(tvgen: &mut TypeVarGen) -> TypeVar {
1077 tvgen.fresh_var()
1078 }
1079
1080 fn fresh_type(tvgen: &mut TypeVarGen) -> Type {
1081 tvgen.fresh_type()
1082 }
1083
1084 #[test]
1085 fn test_hasfield_backward_propagation_binds_field_type() {
1086 let mut solver = ConstraintSolver::new();
1090 let mut tvgen = TypeVarGen::new();
1091
1092 let obj_var = fresh_var(&mut tvgen);
1093 let field_result_var = fresh_var(&mut tvgen);
1094 let bound_var = fresh_var(&mut tvgen);
1095
1096 let mut constraints = vec![
1097 (
1101 Type::Variable(obj_var.clone()),
1102 Type::Constrained {
1103 var: bound_var,
1104 constraint: Box::new(TypeConstraint::HasField(
1105 "x".to_string(),
1106 Box::new(Type::Variable(field_result_var.clone())),
1107 )),
1108 },
1109 ),
1110 (
1112 Type::Variable(obj_var),
1113 Type::Concrete(TypeAnnotation::Object(vec![ObjectTypeField {
1114 name: "x".to_string(),
1115 optional: false,
1116 type_annotation: TypeAnnotation::Basic("int".to_string()),
1117 annotations: vec![],
1118 }])),
1119 ),
1120 ];
1121
1122 solver.solve(&mut constraints).unwrap();
1123
1124 let resolved = solver
1126 .unifier()
1127 .apply_substitutions(&Type::Variable(field_result_var));
1128 match &resolved {
1129 Type::Concrete(TypeAnnotation::Basic(name)) => {
1130 assert_eq!(name, "int", "field type should be int");
1131 }
1132 _ => panic!(
1133 "Expected field_result_var to be resolved to int, got {:?}",
1134 resolved
1135 ),
1136 }
1137 }
1138
1139 #[test]
1140 fn test_hasfield_backward_propagation_multiple_fields() {
1141 let mut solver = ConstraintSolver::new();
1143 let mut tvgen = TypeVarGen::new();
1144
1145 let obj_var = fresh_var(&mut tvgen);
1146 let field_x_var = fresh_var(&mut tvgen);
1147 let field_y_var = fresh_var(&mut tvgen);
1148 let bound_var_x = fresh_var(&mut tvgen);
1149 let bound_var_y = fresh_var(&mut tvgen);
1150
1151 let mut constraints = vec![
1152 (
1154 Type::Variable(obj_var.clone()),
1155 Type::Constrained {
1156 var: bound_var_x,
1157 constraint: Box::new(TypeConstraint::HasField(
1158 "x".to_string(),
1159 Box::new(Type::Variable(field_x_var.clone())),
1160 )),
1161 },
1162 ),
1163 (
1165 Type::Variable(obj_var.clone()),
1166 Type::Constrained {
1167 var: bound_var_y,
1168 constraint: Box::new(TypeConstraint::HasField(
1169 "y".to_string(),
1170 Box::new(Type::Variable(field_y_var.clone())),
1171 )),
1172 },
1173 ),
1174 (
1176 Type::Variable(obj_var),
1177 Type::Concrete(TypeAnnotation::Object(vec![
1178 ObjectTypeField {
1179 name: "x".to_string(),
1180 optional: false,
1181 type_annotation: TypeAnnotation::Basic("int".to_string()),
1182 annotations: vec![],
1183 },
1184 ObjectTypeField {
1185 name: "y".to_string(),
1186 optional: false,
1187 type_annotation: TypeAnnotation::Basic("string".to_string()),
1188 annotations: vec![],
1189 },
1190 ])),
1191 ),
1192 ];
1193
1194 solver.solve(&mut constraints).unwrap();
1195
1196 let resolved_x = solver
1197 .unifier()
1198 .apply_substitutions(&Type::Variable(field_x_var));
1199 let resolved_y = solver
1200 .unifier()
1201 .apply_substitutions(&Type::Variable(field_y_var));
1202
1203 match &resolved_x {
1204 Type::Concrete(TypeAnnotation::Basic(name)) => assert_eq!(name, "int"),
1205 _ => panic!("Expected x to be int, got {:?}", resolved_x),
1206 }
1207 match &resolved_y {
1208 Type::Concrete(TypeAnnotation::Basic(name)) => assert_eq!(name, "string"),
1209 _ => panic!("Expected y to be string, got {:?}", resolved_y),
1210 }
1211 }
1212
1213 #[test]
1216 fn test_int_constrained_numeric_succeeds() {
1217 let mut solver = ConstraintSolver::new();
1219 let trait_impls: std::collections::HashSet<String> = [
1221 "Numeric::int", "Numeric::number", "Numeric::decimal",
1222 "Numeric::i8", "Numeric::i16", "Numeric::i32", "Numeric::i64",
1223 "Numeric::u8", "Numeric::u16", "Numeric::u32", "Numeric::u64",
1224 "Numeric::f32", "Numeric::f64",
1225 ].iter().map(|s| s.to_string()).collect();
1226 solver.set_trait_impls(trait_impls);
1227 let mut tvgen = TypeVarGen::new();
1228 let bound_var = fresh_var(&mut tvgen);
1229 let mut constraints = vec![(
1230 Type::Concrete(TypeAnnotation::Basic("int".to_string())),
1231 Type::Constrained {
1232 var: bound_var,
1233 constraint: Box::new(TypeConstraint::ImplementsTrait {
1234 trait_name: "Numeric".to_string(),
1235 }),
1236 },
1237 )];
1238 assert!(solver.solve(&mut constraints).is_ok());
1239 }
1240
1241 #[test]
1242 fn test_numeric_widening_int_to_number() {
1243 let mut solver = ConstraintSolver::new();
1245 let mut constraints = vec![(
1246 Type::Concrete(TypeAnnotation::Basic("int".to_string())),
1247 Type::Concrete(TypeAnnotation::Basic("number".to_string())),
1248 )];
1249 assert!(solver.solve(&mut constraints).is_ok());
1250 }
1251
1252 #[test]
1253 fn test_numeric_widening_width_aware_integer_to_float_family() {
1254 let mut solver = ConstraintSolver::new();
1255 let mut constraints = vec![(
1256 Type::Concrete(TypeAnnotation::Basic("i16".to_string())),
1257 Type::Concrete(TypeAnnotation::Basic("f32".to_string())),
1258 )];
1259 assert!(solver.solve(&mut constraints).is_ok());
1260 }
1261
1262 #[test]
1263 fn test_no_widening_number_to_int() {
1264 let mut solver = ConstraintSolver::new();
1266 let mut constraints = vec![(
1267 Type::Concrete(TypeAnnotation::Basic("number".to_string())),
1268 Type::Concrete(TypeAnnotation::Basic("int".to_string())),
1269 )];
1270 assert!(solver.solve(&mut constraints).is_err());
1271 }
1272
1273 #[test]
1274 fn test_decimal_constrained_numeric_succeeds() {
1275 let mut solver = ConstraintSolver::new();
1276 let trait_impls: std::collections::HashSet<String> = [
1277 "Numeric::int", "Numeric::number", "Numeric::decimal",
1278 "Numeric::i8", "Numeric::i16", "Numeric::i32", "Numeric::i64",
1279 "Numeric::u8", "Numeric::u16", "Numeric::u32", "Numeric::u64",
1280 "Numeric::f32", "Numeric::f64",
1281 ].iter().map(|s| s.to_string()).collect();
1282 solver.set_trait_impls(trait_impls);
1283 let mut tvgen = TypeVarGen::new();
1284 let bound_var = fresh_var(&mut tvgen);
1285 let mut constraints = vec![(
1286 Type::Concrete(TypeAnnotation::Basic("decimal".to_string())),
1287 Type::Constrained {
1288 var: bound_var,
1289 constraint: Box::new(TypeConstraint::ImplementsTrait {
1290 trait_name: "Numeric".to_string(),
1291 }),
1292 },
1293 )];
1294 assert!(solver.solve(&mut constraints).is_ok());
1295 }
1296
1297 #[test]
1298 fn test_comparable_accepts_int() {
1299 let mut solver = ConstraintSolver::new();
1301 let mut tvgen = TypeVarGen::new();
1302 let bound_var = fresh_var(&mut tvgen);
1303 let mut constraints = vec![(
1304 Type::Concrete(TypeAnnotation::Basic("int".to_string())),
1305 Type::Constrained {
1306 var: bound_var,
1307 constraint: Box::new(TypeConstraint::Comparable),
1308 },
1309 )];
1310 assert!(solver.solve(&mut constraints).is_ok());
1311 }
1312
1313 #[test]
1316 fn test_function_type_preserves_variables() {
1317 let mut tvgen = TypeVarGen::new();
1319 let param = fresh_type(&mut tvgen);
1320 let ret = fresh_type(&mut tvgen);
1321 let func = BuiltinTypes::function(vec![param.clone()], ret.clone());
1322 match func {
1323 Type::Function { params, returns } => {
1324 assert_eq!(params.len(), 1);
1325 assert_eq!(params[0], param);
1326 assert_eq!(*returns, ret);
1327 }
1328 _ => panic!("Expected Type::Function, got {:?}", func),
1329 }
1330 }
1331
1332 #[test]
1333 fn test_function_unification_binds_variables() {
1334 let mut solver = ConstraintSolver::new();
1336 let mut tvgen = TypeVarGen::new();
1337 let t1 = fresh_var(&mut tvgen);
1338 let t2 = fresh_var(&mut tvgen);
1339
1340 let mut constraints = vec![(
1341 Type::Function {
1342 params: vec![Type::Variable(t1.clone())],
1343 returns: Box::new(Type::Variable(t2.clone())),
1344 },
1345 Type::Function {
1346 params: vec![BuiltinTypes::number()],
1347 returns: Box::new(BuiltinTypes::string()),
1348 },
1349 )];
1350
1351 solver.solve(&mut constraints).unwrap();
1352
1353 let resolved_t1 = solver.unifier().apply_substitutions(&Type::Variable(t1));
1354 let resolved_t2 = solver.unifier().apply_substitutions(&Type::Variable(t2));
1355 assert_eq!(resolved_t1, BuiltinTypes::number());
1356 assert_eq!(resolved_t2, BuiltinTypes::string());
1357 }
1358
1359 #[test]
1360 fn test_function_cross_unification_with_concrete() {
1361 let mut solver = ConstraintSolver::new();
1363 let mut tvgen = TypeVarGen::new();
1364 let t1 = fresh_var(&mut tvgen);
1365
1366 let concrete_func = Type::Concrete(TypeAnnotation::Function {
1367 params: vec![shape_ast::ast::FunctionParam {
1368 name: None,
1369 optional: false,
1370 type_annotation: TypeAnnotation::Basic("number".to_string()),
1371 }],
1372 returns: Box::new(TypeAnnotation::Basic("string".to_string())),
1373 });
1374
1375 let mut constraints = vec![(
1376 Type::Function {
1377 params: vec![Type::Variable(t1.clone())],
1378 returns: Box::new(BuiltinTypes::string()),
1379 },
1380 concrete_func,
1381 )];
1382
1383 solver.solve(&mut constraints).unwrap();
1384
1385 let resolved = solver.unifier().apply_substitutions(&Type::Variable(t1));
1386 assert_eq!(resolved, BuiltinTypes::number());
1387 }
1388
1389 #[test]
1390 fn test_object_annotations_unify_structurally() {
1391 let mut solver = ConstraintSolver::new();
1392 let mut constraints = vec![(
1393 Type::Concrete(TypeAnnotation::Object(vec![
1394 ObjectTypeField {
1395 name: "x".to_string(),
1396 optional: false,
1397 type_annotation: TypeAnnotation::Basic("int".to_string()),
1398 annotations: vec![],
1399 },
1400 ObjectTypeField {
1401 name: "y".to_string(),
1402 optional: false,
1403 type_annotation: TypeAnnotation::Basic("int".to_string()),
1404 annotations: vec![],
1405 },
1406 ])),
1407 Type::Concrete(TypeAnnotation::Object(vec![
1408 ObjectTypeField {
1409 name: "x".to_string(),
1410 optional: false,
1411 type_annotation: TypeAnnotation::Basic("int".to_string()),
1412 annotations: vec![],
1413 },
1414 ObjectTypeField {
1415 name: "y".to_string(),
1416 optional: false,
1417 type_annotation: TypeAnnotation::Basic("int".to_string()),
1418 annotations: vec![],
1419 },
1420 ])),
1421 )];
1422 assert!(solver.solve(&mut constraints).is_ok());
1423 }
1424
1425 #[test]
1426 fn test_intersection_annotations_unify_order_independent() {
1427 let mut solver = ConstraintSolver::new();
1428 let obj_xy = TypeAnnotation::Object(vec![
1429 ObjectTypeField {
1430 name: "x".to_string(),
1431 optional: false,
1432 type_annotation: TypeAnnotation::Basic("int".to_string()),
1433 annotations: vec![],
1434 },
1435 ObjectTypeField {
1436 name: "y".to_string(),
1437 optional: false,
1438 type_annotation: TypeAnnotation::Basic("int".to_string()),
1439 annotations: vec![],
1440 },
1441 ]);
1442 let obj_z = TypeAnnotation::Object(vec![ObjectTypeField {
1443 name: "z".to_string(),
1444 optional: false,
1445 type_annotation: TypeAnnotation::Basic("int".to_string()),
1446 annotations: vec![],
1447 }]);
1448
1449 let mut constraints = vec![(
1450 Type::Concrete(TypeAnnotation::Intersection(vec![
1451 obj_xy.clone(),
1452 obj_z.clone(),
1453 ])),
1454 Type::Concrete(TypeAnnotation::Intersection(vec![obj_z, obj_xy])),
1455 )];
1456 assert!(solver.solve(&mut constraints).is_ok());
1457 }
1458
1459 #[test]
1462 fn test_implements_trait_satisfied() {
1463 let mut solver = ConstraintSolver::new();
1464 let mut impls = std::collections::HashSet::new();
1465 impls.insert("Comparable::number".to_string());
1466 solver.set_trait_impls(impls);
1467
1468 let mut tvgen = TypeVarGen::new();
1469 let bound_var = fresh_var(&mut tvgen);
1470 let mut constraints = vec![(
1471 Type::Concrete(TypeAnnotation::Basic("number".to_string())),
1472 Type::Constrained {
1473 var: bound_var,
1474 constraint: Box::new(TypeConstraint::ImplementsTrait {
1475 trait_name: "Comparable".to_string(),
1476 }),
1477 },
1478 )];
1479 assert!(solver.solve(&mut constraints).is_ok());
1480 }
1481
1482 #[test]
1483 fn test_implements_trait_violated() {
1484 let mut solver = ConstraintSolver::new();
1485 let mut tvgen = TypeVarGen::new();
1487 let bound_var = fresh_var(&mut tvgen);
1488 let mut constraints = vec![(
1489 Type::Concrete(TypeAnnotation::Basic("string".to_string())),
1490 Type::Constrained {
1491 var: bound_var,
1492 constraint: Box::new(TypeConstraint::ImplementsTrait {
1493 trait_name: "Comparable".to_string(),
1494 }),
1495 },
1496 )];
1497 let result = solver.solve(&mut constraints);
1498 assert!(result.is_err());
1499 match result.unwrap_err() {
1500 TypeError::TraitBoundViolation {
1501 type_name,
1502 trait_name,
1503 } => {
1504 assert_eq!(type_name, "string");
1505 assert_eq!(trait_name, "Comparable");
1506 }
1507 other => panic!("Expected TraitBoundViolation, got: {:?}", other),
1508 }
1509 }
1510
1511 #[test]
1512 fn test_implements_trait_via_variable_resolution() {
1513 let mut solver = ConstraintSolver::new();
1514 let mut impls = std::collections::HashSet::new();
1515 impls.insert("Sortable::number".to_string());
1516 solver.set_trait_impls(impls);
1517
1518 let mut tvgen = TypeVarGen::new();
1519 let type_var = fresh_var(&mut tvgen);
1520 let bound_var = fresh_var(&mut tvgen);
1521
1522 let mut constraints = vec![
1523 (
1525 Type::Variable(type_var.clone()),
1526 Type::Constrained {
1527 var: bound_var,
1528 constraint: Box::new(TypeConstraint::ImplementsTrait {
1529 trait_name: "Sortable".to_string(),
1530 }),
1531 },
1532 ),
1533 (
1535 Type::Variable(type_var),
1536 Type::Concrete(TypeAnnotation::Basic("number".to_string())),
1537 ),
1538 ];
1539 assert!(
1540 solver.solve(&mut constraints).is_ok(),
1541 "T resolved to number which implements Sortable"
1542 );
1543 }
1544}