1use super::checking::MethodTable;
32use super::unification::Unifier;
33use super::*;
34use shape_ast::ast::{ObjectTypeField, TypeAnnotation};
35use std::collections::{HashMap, HashSet};
36
37fn is_array_or_vec_base(base: &Type) -> bool {
39 match base {
40 Type::Concrete(TypeAnnotation::Reference(name)) => name == "Array" || name == "Vec",
41 Type::Concrete(TypeAnnotation::Basic(name)) => name == "Array" || name == "Vec",
42 _ => false,
43 }
44}
45
/// Solves pairs of types that must agree, and records bounds attached to type
/// variables so they can be validated once the variables are resolved.
pub struct ConstraintSolver {
    /// Holds the variable bindings produced while solving; every substitution
    /// lookup and every new binding goes through it.
    unifier: Unifier,
    /// Initialised empty by `new`; not read or written anywhere else in the
    /// code shown here (hence the leading underscore).
    _deferred: Vec<(Type, Type)>,
    /// Bound recorded for a variable whenever a `Type::Constrained` is
    /// encountered while solving. A later insert for the same variable
    /// replaces the earlier bound.
    bounds: HashMap<TypeVar, TypeConstraint>,
    /// Lookup table used by `HasMethod` checks; while it is `None` those
    /// checks succeed unconditionally.
    method_table: Option<MethodTable>,
    /// Known trait implementations, stored as `"Trait::Type"` keys.
    trait_impls: HashSet<String>,
}
59
60impl Default for ConstraintSolver {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66impl ConstraintSolver {
67 pub fn new() -> Self {
68 ConstraintSolver {
69 unifier: Unifier::new(),
70 _deferred: Vec::new(),
71 bounds: HashMap::new(),
72 method_table: None,
73 trait_impls: HashSet::new(),
74 }
75 }
76
77 pub fn set_method_table(&mut self, table: MethodTable) {
81 self.method_table = Some(table);
82 }
83
84 pub fn set_trait_impls(&mut self, impls: HashSet<String>) {
87 self.trait_impls = impls;
88 }
89
90 pub fn solve(&mut self, constraints: &mut Vec<(Type, Type)>) -> TypeResult<()> {
92 let mut unsolved = Vec::new();
94
95 for (t1, t2) in constraints.drain(..) {
96 if self.solve_constraint(t1.clone(), t2.clone()).is_err() {
97 unsolved.push((t1, t2));
99 }
100 }
101
102 let mut made_progress = true;
104 while made_progress && !unsolved.is_empty() {
105 made_progress = false;
106 let mut still_unsolved = Vec::new();
107
108 for (t1, t2) in unsolved.drain(..) {
109 if self.solve_constraint(t1.clone(), t2.clone()).is_err() {
110 still_unsolved.push((t1, t2));
111 } else {
112 made_progress = true;
113 }
114 }
115
116 unsolved = still_unsolved;
117 }
118
119 if !unsolved.is_empty() {
121 return Err(TypeError::UnsolvedConstraints(unsolved));
122 }
123
124 self.apply_bounds()?;
126
127 Ok(())
128 }
129
130 fn solve_constraint(&mut self, t1: Type, t2: Type) -> TypeResult<()> {
132 let t1 = self.unifier.apply_substitutions(&t1);
136 let t2 = self.unifier.apply_substitutions(&t2);
137
138 match (&t1, &t2) {
139 (Type::Variable(v1), Type::Variable(v2)) if v1 == v2 => Ok(()),
141
142 (Type::Constrained { var, constraint }, ty)
146 | (ty, Type::Constrained { var, constraint }) => {
147 self.bounds.insert(var.clone(), *constraint.clone());
149
150 self.solve_constraint(Type::Variable(var.clone()), ty.clone())
152 }
153
154 (Type::Variable(var), ty) | (ty, Type::Variable(var)) => {
155 if self.occurs_in(var, ty) {
157 return Err(TypeError::InfiniteType(var.clone()));
158 }
159
160 self.unifier.bind(var.clone(), ty.clone());
161 Ok(())
162 }
163
164 (Type::Concrete(ann1), Type::Concrete(ann2)) => {
166 if self.unify_annotations(ann1, ann2)? {
167 Ok(())
168 } else if Self::can_numeric_widen(ann1, ann2) {
169 Ok(())
171 } else {
172 Err(TypeError::TypeMismatch(
173 format!("{:?}", ann1),
174 format!("{:?}", ann2),
175 ))
176 }
177 }
178
179 (Type::Generic { base: b1, args: a1 }, Type::Generic { base: b2, args: a2 }) => {
181 self.solve_constraint(*b1.clone(), *b2.clone())?;
182
183 let is_result_base = |base: &Type| match base {
184 Type::Concrete(TypeAnnotation::Reference(name)) => name == "Result",
185 Type::Concrete(TypeAnnotation::Basic(name)) => name == "Result",
186 _ => false,
187 };
188
189 if a1.len() != a2.len() {
190 if is_result_base(&b1) && is_result_base(&b2) {
191 match (a1.len(), a2.len()) {
192 (1, 2) | (2, 1) => {
195 self.solve_constraint(a1[0].clone(), a2[0].clone())?;
196 return Ok(());
197 }
198 _ => return Err(TypeError::ArityMismatch(a1.len(), a2.len())),
199 }
200 } else {
201 return Err(TypeError::ArityMismatch(a1.len(), a2.len()));
202 }
203 }
204
205 for (arg1, arg2) in a1.iter().zip(a2.iter()) {
206 self.solve_constraint(arg1.clone(), arg2.clone())?;
207 }
208
209 Ok(())
210 }
211
212 (
214 Type::Function {
215 params: p1,
216 returns: r1,
217 },
218 Type::Function {
219 params: p2,
220 returns: r2,
221 },
222 ) => {
223 if p1.len() != p2.len() {
224 return Err(TypeError::ArityMismatch(p1.len(), p2.len()));
225 }
226 for (param1, param2) in p1.iter().zip(p2.iter()) {
227 self.solve_constraint(param2.clone(), param1.clone())?;
231 }
232 self.solve_constraint(*r1.clone(), *r2.clone())
233 }
234
235 (
237 Type::Function {
238 params: fp,
239 returns: fr,
240 },
241 Type::Concrete(TypeAnnotation::Function {
242 params: cp,
243 returns: cr,
244 }),
245 )
246 | (
247 Type::Concrete(TypeAnnotation::Function {
248 params: cp,
249 returns: cr,
250 }),
251 Type::Function {
252 params: fp,
253 returns: fr,
254 },
255 ) => {
256 if fp.len() != cp.len() {
257 return Err(TypeError::ArityMismatch(fp.len(), cp.len()));
258 }
259 for (f_param, c_param) in fp.iter().zip(cp.iter()) {
260 self.solve_constraint(
261 f_param.clone(),
262 Type::Concrete(c_param.type_annotation.clone()),
263 )?;
264 }
265 self.solve_constraint(*fr.clone(), Type::Concrete(*cr.clone()))
266 }
267
268 (Type::Generic { base, args }, Type::Concrete(TypeAnnotation::Array(elem)))
270 | (Type::Concrete(TypeAnnotation::Array(elem)), Type::Generic { base, args })
271 if args.len() == 1 && is_array_or_vec_base(base) =>
272 {
273 self.solve_constraint(args[0].clone(), Type::Concrete((**elem).clone()))
274 }
275
276 _ => Err(TypeError::TypeMismatch(
277 format!("{:?}", t1),
278 format!("{:?}", t2),
279 )),
280 }
281 }
282
283 fn occurs_in(&self, var: &TypeVar, ty: &Type) -> bool {
285 match ty {
286 Type::Variable(v) => v == var,
287 Type::Generic { base, args } => {
288 self.occurs_in(var, base) || args.iter().any(|arg| self.occurs_in(var, arg))
289 }
290 Type::Constrained { var: v, .. } => v == var,
291 Type::Function { params, returns } => {
292 params.iter().any(|p| self.occurs_in(var, p)) || self.occurs_in(var, returns)
293 }
294 Type::Concrete(_) => false,
295 }
296 }
297
298 fn can_numeric_widen(from: &TypeAnnotation, to: &TypeAnnotation) -> bool {
305 let from_name = match from {
306 TypeAnnotation::Basic(name) => Some(name.as_str()),
307 TypeAnnotation::Reference(name) => Some(name.as_str()),
308 _ => None,
309 };
310 let to_name = match to {
311 TypeAnnotation::Basic(name) => Some(name.as_str()),
312 TypeAnnotation::Reference(name) => Some(name.as_str()),
313 _ => None,
314 };
315
316 match (from_name, to_name) {
317 (Some(f), Some(t)) => {
318 BuiltinTypes::is_integer_type_name(f) && BuiltinTypes::is_number_type_name(t)
319 }
320 _ => false,
321 }
322 }
323
324 fn unify_annotations(&self, ann1: &TypeAnnotation, ann2: &TypeAnnotation) -> TypeResult<bool> {
326 match (ann1, ann2) {
327 (TypeAnnotation::Basic(_), TypeAnnotation::Basic(_)) => {
329 Ok(ann1 == ann2 || Self::can_numeric_widen(ann1, ann2))
330 }
331 (TypeAnnotation::Reference(n1), TypeAnnotation::Reference(n2)) => Ok(n1 == n2),
332 (TypeAnnotation::Basic(_), TypeAnnotation::Reference(_))
333 | (TypeAnnotation::Reference(_), TypeAnnotation::Basic(_)) => {
334 Ok(ann1 == ann2 || Self::can_numeric_widen(ann1, ann2))
335 }
336
337 (TypeAnnotation::Array(e1), TypeAnnotation::Array(e2)) => {
339 self.unify_annotations(e1, e2)
340 }
341
342 (TypeAnnotation::Tuple(t1), TypeAnnotation::Tuple(t2)) => {
344 if t1.len() != t2.len() {
345 return Ok(false);
346 }
347
348 for (elem1, elem2) in t1.iter().zip(t2.iter()) {
349 if !self.unify_annotations(elem1, elem2)? {
350 return Ok(false);
351 }
352 }
353
354 Ok(true)
355 }
356
357 (TypeAnnotation::Object(f1), TypeAnnotation::Object(f2)) => {
359 self.object_fields_compatible(f1, f2)
360 }
361
362 (
364 TypeAnnotation::Function {
365 params: p1,
366 returns: r1,
367 },
368 TypeAnnotation::Function {
369 params: p2,
370 returns: r2,
371 },
372 ) => {
373 if p1.len() != p2.len() {
374 return Ok(false);
375 }
376
377 for (param1, param2) in p1.iter().zip(p2.iter()) {
378 if !self.unify_annotations(¶m1.type_annotation, ¶m2.type_annotation)? {
379 return Ok(false);
380 }
381 }
382
383 self.unify_annotations(r1, r2)
384 }
385
386 (TypeAnnotation::Union(u1), TypeAnnotation::Union(u2)) => {
389 for t1 in u1 {
391 let mut found_match = false;
392 for t2 in u2 {
393 if self.unify_annotations(t1, t2)? {
394 found_match = true;
395 break;
396 }
397 }
398 if !found_match {
399 return Ok(false);
400 }
401 }
402 for t2 in u2 {
404 let mut found_match = false;
405 for t1 in u1 {
406 if self.unify_annotations(t1, t2)? {
407 found_match = true;
408 break;
409 }
410 }
411 if !found_match {
412 return Ok(false);
413 }
414 }
415 Ok(true)
416 }
417
418 (TypeAnnotation::Union(union_types), other)
420 | (other, TypeAnnotation::Union(union_types)) => {
421 for union_type in union_types {
422 if self.unify_annotations(union_type, other)? {
423 return Ok(true);
424 }
425 }
426 Ok(false)
427 }
428
429 (TypeAnnotation::Intersection(i1), TypeAnnotation::Intersection(i2)) => {
431 self.unify_annotation_sets(i1, i2)
432 }
433
434 (TypeAnnotation::Void, TypeAnnotation::Void) => Ok(true),
436 (TypeAnnotation::Null, TypeAnnotation::Null) => Ok(true),
437 (TypeAnnotation::Undefined, TypeAnnotation::Undefined) => Ok(true),
438
439 (TypeAnnotation::Dyn(traits1), TypeAnnotation::Dyn(traits2)) => {
442 Ok(traits1.len() == traits2.len() && traits1.iter().all(|t| traits2.contains(t)))
443 }
444
445 (TypeAnnotation::Generic { name, args }, TypeAnnotation::Array(elem))
447 | (TypeAnnotation::Array(elem), TypeAnnotation::Generic { name, args })
448 if name == "Array" && args.len() == 1 =>
449 {
450 self.unify_annotations(&args[0], elem)
451 }
452
453 _ => Ok(false),
455 }
456 }
457
458 fn object_fields_compatible(
459 &self,
460 left: &[ObjectTypeField],
461 right: &[ObjectTypeField],
462 ) -> TypeResult<bool> {
463 for left_field in left {
464 let Some(right_field) = right.iter().find(|f| f.name == left_field.name) else {
465 return Ok(false);
466 };
467 if left_field.optional != right_field.optional {
468 return Ok(false);
469 }
470 if !self.unify_annotations(&left_field.type_annotation, &right_field.type_annotation)? {
471 return Ok(false);
472 }
473 }
474 if left.len() != right.len() {
475 return Ok(false);
476 }
477 Ok(true)
478 }
479
480 fn unify_annotation_sets(
481 &self,
482 left: &[TypeAnnotation],
483 right: &[TypeAnnotation],
484 ) -> TypeResult<bool> {
485 if left.len() != right.len() {
486 return Ok(false);
487 }
488
489 let mut matched = vec![false; right.len()];
490 for left_ann in left {
491 let mut found = false;
492 for (idx, right_ann) in right.iter().enumerate() {
493 if matched[idx] {
494 continue;
495 }
496 if self.unify_annotations(left_ann, right_ann)? {
497 matched[idx] = true;
498 found = true;
499 break;
500 }
501 }
502 if !found {
503 return Ok(false);
504 }
505 }
506
507 Ok(true)
508 }
509
510 fn apply_bounds(&mut self) -> TypeResult<()> {
516 let mut new_bindings: Vec<(TypeVar, Type)> = Vec::new();
517
518 for (var, constraint) in &self.bounds {
519 let resolved = self
522 .unifier
523 .apply_substitutions(&Type::Variable(var.clone()));
524
525 if let Type::Variable(_) = &resolved {
526 continue;
528 }
529
530 self.check_constraint(&resolved, constraint)?;
531
532 if let TypeConstraint::HasField(field, expected_field_type) = constraint {
535 if let Type::Variable(field_var) = expected_field_type.as_ref() {
536 let field_resolved = self
538 .unifier
539 .apply_substitutions(&Type::Variable(field_var.clone()));
540 if let Type::Variable(_) = &field_resolved {
541 if let Type::Concrete(TypeAnnotation::Object(fields)) = &resolved {
543 if let Some(found_field) = fields.iter().find(|f| f.name == *field) {
544 new_bindings.push((
545 field_var.clone(),
546 Type::Concrete(found_field.type_annotation.clone()),
547 ));
548 }
549 }
550 }
551 }
552 }
553 }
554
555 for (var, ty) in new_bindings {
557 self.unifier.bind(var, ty);
558 }
559
560 Ok(())
561 }
562
    /// Checks that the resolved type `ty` satisfies `constraint`.
    ///
    /// Returns `Ok(())` when the bound holds. Otherwise the error depends on
    /// the constraint kind: `ConstraintViolation` for `Comparable`,
    /// `Iterable`, `HasField`, `Callable`, `OneOf` and `Extends`;
    /// `TraitBoundViolation` for `ImplementsTrait`; `MethodNotFound` for
    /// `HasMethod`.
    fn check_constraint(&self, ty: &Type, constraint: &TypeConstraint) -> TypeResult<()> {
        match constraint {
            // Only `Basic` names qualify: numeric type names, "string" and "bool".
            TypeConstraint::Comparable => match ty {
                Type::Concrete(TypeAnnotation::Basic(name))
                    if BuiltinTypes::is_numeric_type_name(name)
                        || name == "string"
                        || name == "bool" =>
                {
                    Ok(())
                }
                _ => Err(TypeError::ConstraintViolation(format!(
                    "{:?} is not comparable",
                    ty
                ))),
            },

            // Array annotations and the basic names "string" and "rows" are iterable.
            TypeConstraint::Iterable => match ty {
                Type::Concrete(TypeAnnotation::Array(_)) => Ok(()),
                Type::Concrete(TypeAnnotation::Basic(name))
                    if name == "string" || name == "rows" =>
                {
                    Ok(())
                }
                _ => Err(TypeError::ConstraintViolation(format!(
                    "{:?} is not iterable",
                    ty
                ))),
            },

            TypeConstraint::HasField(field, expected_field_type) => {
                match ty {
                    Type::Concrete(TypeAnnotation::Object(fields)) => {
                        match fields.iter().find(|f| f.name == *field) {
                            Some(found_field) => {
                                // The field type is compared only when the
                                // expectation converts to an annotation;
                                // otherwise the field's presence is enough.
                                if let Some(expected_ann) = expected_field_type.to_annotation() {
                                    if self.unify_annotations(
                                        &found_field.type_annotation,
                                        &expected_ann,
                                    )? {
                                        Ok(())
                                    } else {
                                        Err(TypeError::ConstraintViolation(format!(
                                            "field '{}' has type {:?}, expected {:?}",
                                            field, found_field.type_annotation, expected_ann
                                        )))
                                    }
                                } else {
                                    Ok(())
                                }
                            }
                            None => Err(TypeError::ConstraintViolation(format!(
                                "{:?} does not have field '{}'",
                                ty, field
                            ))),
                        }
                    }
                    // Any `Basic` named type is accepted here without
                    // inspecting its fields.
                    Type::Concrete(TypeAnnotation::Basic(_name)) => {
                        Ok(())
                    }
                    _ => Err(TypeError::ConstraintViolation(format!(
                        "{:?} cannot have fields",
                        ty
                    ))),
                }
            }

            TypeConstraint::Callable {
                params: expected_params,
                returns: expected_returns,
            } => {
                match ty {
                    // Callee given as a concrete function annotation.
                    Type::Concrete(TypeAnnotation::Function {
                        params: actual_params,
                        returns: actual_returns,
                    }) => {
                        if expected_params.len() != actual_params.len() {
                            return Err(TypeError::ConstraintViolation(format!(
                                "function expects {} parameters, got {}",
                                expected_params.len(),
                                actual_params.len()
                            )));
                        }

                        // Expected parameters that do not convert to an
                        // annotation are not checked.
                        for (expected, actual) in expected_params.iter().zip(actual_params.iter()) {
                            if let Some(expected_ann) = expected.to_annotation() {
                                if !self
                                    .unify_annotations(&expected_ann, &actual.type_annotation)?
                                {
                                    return Err(TypeError::ConstraintViolation(format!(
                                        "parameter type mismatch: expected {:?}, got {:?}",
                                        expected_ann, actual.type_annotation
                                    )));
                                }
                            }
                        }

                        // The return type is compared with the actual type on
                        // the left: (actual, expected).
                        if let Some(expected_ret_ann) = expected_returns.to_annotation() {
                            if !self.unify_annotations(actual_returns, &expected_ret_ann)? {
                                return Err(TypeError::ConstraintViolation(format!(
                                    "return type mismatch: expected {:?}, got {:?}",
                                    expected_ret_ann, actual_returns
                                )));
                            }
                        }

                        Ok(())
                    }
                    // Callee given as an inferred function type. A pair is
                    // compared only when both sides convert to annotations.
                    Type::Function {
                        params: actual_params,
                        returns: actual_returns,
                    } => {
                        if expected_params.len() != actual_params.len() {
                            return Err(TypeError::ConstraintViolation(format!(
                                "function expects {} parameters, got {}",
                                expected_params.len(),
                                actual_params.len()
                            )));
                        }
                        for (expected, actual) in expected_params.iter().zip(actual_params.iter()) {
                            if let (Some(e_ann), Some(a_ann)) =
                                (expected.to_annotation(), actual.to_annotation())
                            {
                                if !self.unify_annotations(&e_ann, &a_ann)? {
                                    return Err(TypeError::ConstraintViolation(format!(
                                        "parameter type mismatch: expected {:?}, got {:?}",
                                        e_ann, a_ann
                                    )));
                                }
                            }
                        }
                        if let (Some(e_ret), Some(a_ret)) = (
                            expected_returns.to_annotation(),
                            actual_returns.to_annotation(),
                        ) {
                            if !self.unify_annotations(&a_ret, &e_ret)? {
                                return Err(TypeError::ConstraintViolation(format!(
                                    "return type mismatch: expected {:?}, got {:?}",
                                    e_ret, a_ret
                                )));
                            }
                        }
                        Ok(())
                    }
                    _ => Err(TypeError::ConstraintViolation(format!(
                        "{:?} is not callable",
                        ty
                    ))),
                }
            }

            // Satisfied when `ty` is concrete and compatible with at least one
            // concrete option. Non-concrete options or a non-concrete `ty`
            // never match, and errors from the comparison count as "no match".
            TypeConstraint::OneOf(options) => {
                for option in options {
                    if let Type::Concrete(ann) = option {
                        if let Type::Concrete(ty_ann) = ty {
                            if self.unify_annotations(ann, ty_ann).unwrap_or(false) {
                                return Ok(());
                            }
                        }
                    }
                }

                Err(TypeError::ConstraintViolation(format!(
                    "{:?} does not match any of {:?}",
                    ty, options
                )))
            }

            TypeConstraint::Extends(base) => {
                self.is_subtype(ty, base)
            }

            TypeConstraint::ImplementsTrait { trait_name } => {
                match ty {
                    // A variable that reaches this point is reported as a violation.
                    Type::Variable(_) => {
                        Err(TypeError::TraitBoundViolation {
                            type_name: format!("{:?}", ty),
                            trait_name: trait_name.clone(),
                        })
                    }
                    Type::Concrete(ann) => {
                        // Plain names are looked up as written; any other
                        // annotation uses its debug rendering as the type name.
                        let type_name = match ann {
                            TypeAnnotation::Basic(n) => n.clone(),
                            TypeAnnotation::Reference(n) => n.to_string(),
                            _ => format!("{:?}", ann),
                        };
                        if self.has_trait_impl(trait_name, &type_name) {
                            Ok(())
                        } else {
                            Err(TypeError::TraitBoundViolation {
                                type_name,
                                trait_name: trait_name.clone(),
                            })
                        }
                    }
                    // For generics only the base name is looked up; the type
                    // arguments are ignored.
                    Type::Generic { base, .. } => {
                        let type_name = match base.as_ref() {
                            Type::Concrete(TypeAnnotation::Reference(n)) => n.to_string(),
                            Type::Concrete(TypeAnnotation::Basic(n)) => n.clone(),
                            _ => format!("{:?}", base),
                        };
                        if self.has_trait_impl(trait_name, &type_name) {
                            Ok(())
                        } else {
                            Err(TypeError::TraitBoundViolation {
                                type_name,
                                trait_name: trait_name.clone(),
                            })
                        }
                    }
                    _ => Err(TypeError::TraitBoundViolation {
                        type_name: format!("{:?}", ty),
                        trait_name: trait_name.clone(),
                    }),
                }
            }

            // Only the method's existence is checked; argument and return
            // types are ignored. Without a method table the check passes.
            TypeConstraint::HasMethod {
                method_name,
                arg_types: _,
                return_type: _,
            } => {
                if let Some(method_table) = &self.method_table {
                    match ty {
                        Type::Variable(_) => Ok(()),
                        Type::Concrete(ann) => {
                            // `type_name` is used only for the error message;
                            // the lookup itself receives `ty`. Annotations
                            // other than plain names and arrays are accepted
                            // without a lookup.
                            let type_name = match ann {
                                TypeAnnotation::Basic(n) => n.clone(),
                                TypeAnnotation::Reference(n) => n.to_string(),
                                TypeAnnotation::Array(_) => "Vec".to_string(),
                                _ => return Ok(()),
                            };
                            if method_table.lookup(ty, method_name).is_some() {
                                Ok(())
                            } else {
                                Err(TypeError::MethodNotFound {
                                    type_name,
                                    method_name: method_name.clone(),
                                })
                            }
                        }
                        Type::Generic { base, .. } => {
                            if method_table.lookup(ty, method_name).is_some() {
                                Ok(())
                            } else {
                                // Name the base in the error: the reference
                                // name when there is one, else a debug rendering.
                                let type_name =
                                    if let Type::Concrete(TypeAnnotation::Reference(n)) =
                                        base.as_ref()
                                    {
                                        n.to_string()
                                    } else {
                                        format!("{:?}", base)
                                    };
                                Err(TypeError::MethodNotFound {
                                    type_name,
                                    method_name: method_name.clone(),
                                })
                            }
                        }
                        _ => Ok(()),
                    }
                } else {
                    Ok(())
                }
            }
        }
    }
850
851 fn has_trait_impl(&self, trait_name: &str, type_name: &str) -> bool {
859 let key = format!("{}::{}", trait_name, type_name);
860 if self.trait_impls.contains(&key) {
861 return true;
862 }
863 if let Some(canonical) = BuiltinTypes::canonical_numeric_runtime_name(type_name) {
865 let canon_key = format!("{}::{}", trait_name, canonical);
866 if self.trait_impls.contains(&canon_key) {
867 return true;
868 }
869 }
870 if let Some(script_alias) = BuiltinTypes::canonical_script_alias(type_name) {
872 let alias_key = format!("{}::{}", trait_name, script_alias);
873 if self.trait_impls.contains(&alias_key) {
874 return true;
875 }
876 }
877 if BuiltinTypes::is_integer_type_name(type_name) {
879 for widen_to in &["number", "float", "f64"] {
880 let widen_key = format!("{}::{}", trait_name, widen_to);
881 if self.trait_impls.contains(&widen_key) {
882 return true;
883 }
884 }
885 }
886 false
887 }
888
889 fn is_subtype(&self, ty: &Type, base: &Type) -> TypeResult<()> {
896 match (ty, base) {
897 (t1, t2) if t1 == t2 => Ok(()),
899
900 (Type::Variable(_), _) | (_, Type::Variable(_)) => Ok(()),
902
903 (
905 Type::Concrete(TypeAnnotation::Array(elem1)),
906 Type::Concrete(TypeAnnotation::Array(elem2)),
907 ) => {
908 let t1 = Type::Concrete(*elem1.clone());
909 let t2 = Type::Concrete(*elem2.clone());
910 self.is_subtype(&t1, &t2)
911 }
912
913 (
915 Type::Concrete(TypeAnnotation::Function {
916 params: p1,
917 returns: r1,
918 }),
919 Type::Concrete(TypeAnnotation::Function {
920 params: p2,
921 returns: r2,
922 }),
923 ) => {
924 if p1.len() != p2.len() {
926 return Err(TypeError::ConstraintViolation(format!(
927 "function parameter count mismatch: {} vs {}",
928 p1.len(),
929 p2.len()
930 )));
931 }
932
933 for (param1, param2) in p1.iter().zip(p2.iter()) {
935 let t1 = Type::Concrete(param2.type_annotation.clone());
936 let t2 = Type::Concrete(param1.type_annotation.clone());
937 self.is_subtype(&t1, &t2)?;
938 }
939
940 let ret1 = Type::Concrete(*r1.clone());
942 let ret2 = Type::Concrete(*r2.clone());
943 self.is_subtype(&ret1, &ret2)
944 }
945
946 (t, Type::Concrete(TypeAnnotation::Generic { name, args }))
948 if name == "Option" && args.len() == 1 =>
949 {
950 let inner = Type::Concrete(args[0].clone());
951 self.is_subtype(t, &inner)
952 }
953
954 (
956 Type::Function {
957 params: p1,
958 returns: r1,
959 },
960 Type::Function {
961 params: p2,
962 returns: r2,
963 },
964 ) => {
965 if p1.len() != p2.len() {
966 return Err(TypeError::ConstraintViolation(format!(
967 "function parameter count mismatch: {} vs {}",
968 p1.len(),
969 p2.len()
970 )));
971 }
972 for (param1, param2) in p1.iter().zip(p2.iter()) {
974 self.is_subtype(param2, param1)?;
975 }
976 self.is_subtype(r1, r2)
978 }
979
980 (Type::Concrete(ann1), Type::Concrete(ann2)) => {
982 if self.unify_annotations(ann1, ann2)? {
983 Ok(())
984 } else {
985 Err(TypeError::ConstraintViolation(format!(
986 "{:?} is not a subtype of {:?}",
987 ty, base
988 )))
989 }
990 }
991
992 _ => Err(TypeError::ConstraintViolation(format!(
994 "{:?} is not a subtype of {:?}",
995 ty, base
996 ))),
997 }
998 }
999
1000 pub fn unifier(&self) -> &Unifier {
1002 &self.unifier
1003 }
1004}
1005
#[cfg(test)]
mod tests {
    use super::*;
    use shape_ast::ast::ObjectTypeField;

    // ---- fixture helpers -------------------------------------------------

    /// Concrete type for the basic name `name`.
    fn basic(name: &str) -> Type {
        Type::Concrete(TypeAnnotation::Basic(name.to_string()))
    }

    /// Required object field `name` of basic type `ty`, with no annotations.
    fn field(name: &str, ty: &str) -> ObjectTypeField {
        ObjectTypeField {
            name: name.to_string(),
            optional: false,
            type_annotation: TypeAnnotation::Basic(ty.to_string()),
            annotations: vec![],
        }
    }

    /// Object annotation made of the given fields.
    fn object(fields: Vec<ObjectTypeField>) -> TypeAnnotation {
        TypeAnnotation::Object(fields)
    }

    /// Constrained type attaching `constraint` to `var`.
    fn constrained(var: TypeVar, constraint: TypeConstraint) -> Type {
        Type::Constrained {
            var,
            constraint: Box::new(constraint),
        }
    }

    /// `HasField` bound whose expected field type is the variable `field_var`.
    fn has_field(name: &str, field_var: &TypeVar) -> TypeConstraint {
        TypeConstraint::HasField(
            name.to_string(),
            Box::new(Type::Variable(field_var.clone())),
        )
    }

    /// `ImplementsTrait` bound for `trait_name`.
    fn implements(trait_name: &str) -> TypeConstraint {
        TypeConstraint::ImplementsTrait {
            trait_name: trait_name.to_string(),
        }
    }

    /// Trait-impl set registering `Numeric` for the numeric type names.
    fn numeric_impls() -> std::collections::HashSet<String> {
        [
            "Numeric::int", "Numeric::number", "Numeric::decimal",
            "Numeric::i8", "Numeric::i16", "Numeric::i32", "Numeric::i64",
            "Numeric::u8", "Numeric::u16", "Numeric::u32", "Numeric::u64",
            "Numeric::f32", "Numeric::f64",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect()
    }

    /// What `var` resolves to after solving.
    fn resolve(solver: &ConstraintSolver, var: TypeVar) -> Type {
        solver.unifier().apply_substitutions(&Type::Variable(var))
    }

    // ---- HasField backward propagation ----------------------------------

    #[test]
    fn test_hasfield_backward_propagation_binds_field_type() {
        let mut solver = ConstraintSolver::new();

        let obj_var = TypeVar::fresh();
        let field_result_var = TypeVar::fresh();
        let bound_var = TypeVar::fresh();

        let mut constraints = vec![
            (
                Type::Variable(obj_var.clone()),
                constrained(bound_var, has_field("x", &field_result_var)),
            ),
            (
                Type::Variable(obj_var),
                Type::Concrete(object(vec![field("x", "int")])),
            ),
        ];

        solver.solve(&mut constraints).unwrap();

        let resolved = resolve(&solver, field_result_var);
        match &resolved {
            Type::Concrete(TypeAnnotation::Basic(name)) => {
                assert_eq!(name, "int", "field type should be int");
            }
            _ => panic!(
                "Expected field_result_var to be resolved to int, got {:?}",
                resolved
            ),
        }
    }

    #[test]
    fn test_hasfield_backward_propagation_multiple_fields() {
        let mut solver = ConstraintSolver::new();

        let obj_var = TypeVar::fresh();
        let field_x_var = TypeVar::fresh();
        let field_y_var = TypeVar::fresh();
        let bound_var_x = TypeVar::fresh();
        let bound_var_y = TypeVar::fresh();

        let mut constraints = vec![
            (
                Type::Variable(obj_var.clone()),
                constrained(bound_var_x, has_field("x", &field_x_var)),
            ),
            (
                Type::Variable(obj_var.clone()),
                constrained(bound_var_y, has_field("y", &field_y_var)),
            ),
            (
                Type::Variable(obj_var),
                Type::Concrete(object(vec![field("x", "int"), field("y", "string")])),
            ),
        ];

        solver.solve(&mut constraints).unwrap();

        let resolved_x = resolve(&solver, field_x_var);
        let resolved_y = resolve(&solver, field_y_var);

        match &resolved_x {
            Type::Concrete(TypeAnnotation::Basic(name)) => assert_eq!(name, "int"),
            _ => panic!("Expected x to be int, got {:?}", resolved_x),
        }
        match &resolved_y {
            Type::Concrete(TypeAnnotation::Basic(name)) => assert_eq!(name, "string"),
            _ => panic!("Expected y to be string, got {:?}", resolved_y),
        }
    }

    // ---- numeric bounds and widening ------------------------------------

    #[test]
    fn test_int_constrained_numeric_succeeds() {
        let mut solver = ConstraintSolver::new();
        solver.set_trait_impls(numeric_impls());
        let bound_var = TypeVar::fresh();
        let mut constraints = vec![(
            basic("int"),
            constrained(bound_var, implements("Numeric")),
        )];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    #[test]
    fn test_numeric_widening_int_to_number() {
        let mut solver = ConstraintSolver::new();
        let mut constraints = vec![(basic("int"), basic("number"))];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    #[test]
    fn test_numeric_widening_width_aware_integer_to_float_family() {
        let mut solver = ConstraintSolver::new();
        let mut constraints = vec![(basic("i16"), basic("f32"))];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    #[test]
    fn test_no_widening_number_to_int() {
        let mut solver = ConstraintSolver::new();
        let mut constraints = vec![(basic("number"), basic("int"))];
        assert!(solver.solve(&mut constraints).is_err());
    }

    #[test]
    fn test_decimal_constrained_numeric_succeeds() {
        let mut solver = ConstraintSolver::new();
        solver.set_trait_impls(numeric_impls());
        let bound_var = TypeVar::fresh();
        let mut constraints = vec![(
            basic("decimal"),
            constrained(bound_var, implements("Numeric")),
        )];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    #[test]
    fn test_comparable_accepts_int() {
        let mut solver = ConstraintSolver::new();
        let bound_var = TypeVar::fresh();
        let mut constraints = vec![(
            basic("int"),
            constrained(bound_var, TypeConstraint::Comparable),
        )];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    // ---- function types ---------------------------------------------------

    #[test]
    fn test_function_type_preserves_variables() {
        let param = Type::fresh_var();
        let ret = Type::fresh_var();
        let func = BuiltinTypes::function(vec![param.clone()], ret.clone());
        match func {
            Type::Function { params, returns } => {
                assert_eq!(params.len(), 1);
                assert_eq!(params[0], param);
                assert_eq!(*returns, ret);
            }
            _ => panic!("Expected Type::Function, got {:?}", func),
        }
    }

    #[test]
    fn test_function_unification_binds_variables() {
        let mut solver = ConstraintSolver::new();
        let t1 = TypeVar::fresh();
        let t2 = TypeVar::fresh();

        let mut constraints = vec![(
            Type::Function {
                params: vec![Type::Variable(t1.clone())],
                returns: Box::new(Type::Variable(t2.clone())),
            },
            Type::Function {
                params: vec![BuiltinTypes::number()],
                returns: Box::new(BuiltinTypes::string()),
            },
        )];

        solver.solve(&mut constraints).unwrap();

        let resolved_t1 = resolve(&solver, t1);
        let resolved_t2 = resolve(&solver, t2);
        assert_eq!(resolved_t1, BuiltinTypes::number());
        assert_eq!(resolved_t2, BuiltinTypes::string());
    }

    #[test]
    fn test_function_cross_unification_with_concrete() {
        let mut solver = ConstraintSolver::new();
        let t1 = TypeVar::fresh();

        let concrete_func = Type::Concrete(TypeAnnotation::Function {
            params: vec![shape_ast::ast::FunctionParam {
                name: None,
                optional: false,
                type_annotation: TypeAnnotation::Basic("number".to_string()),
            }],
            returns: Box::new(TypeAnnotation::Basic("string".to_string())),
        });

        let mut constraints = vec![(
            Type::Function {
                params: vec![Type::Variable(t1.clone())],
                returns: Box::new(BuiltinTypes::string()),
            },
            concrete_func,
        )];

        solver.solve(&mut constraints).unwrap();

        let resolved = resolve(&solver, t1);
        assert_eq!(resolved, BuiltinTypes::number());
    }

    // ---- structural annotations ------------------------------------------

    #[test]
    fn test_object_annotations_unify_structurally() {
        let mut solver = ConstraintSolver::new();
        let mut constraints = vec![(
            Type::Concrete(object(vec![field("x", "int"), field("y", "int")])),
            Type::Concrete(object(vec![field("x", "int"), field("y", "int")])),
        )];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    #[test]
    fn test_intersection_annotations_unify_order_independent() {
        let mut solver = ConstraintSolver::new();
        let obj_xy = object(vec![field("x", "int"), field("y", "int")]);
        let obj_z = object(vec![field("z", "int")]);

        let mut constraints = vec![(
            Type::Concrete(TypeAnnotation::Intersection(vec![
                obj_xy.clone(),
                obj_z.clone(),
            ])),
            Type::Concrete(TypeAnnotation::Intersection(vec![obj_z, obj_xy])),
        )];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    // ---- trait bounds -------------------------------------------------------

    #[test]
    fn test_implements_trait_satisfied() {
        let mut solver = ConstraintSolver::new();
        let mut impls = std::collections::HashSet::new();
        impls.insert("Comparable::number".to_string());
        solver.set_trait_impls(impls);

        let bound_var = TypeVar::fresh();
        let mut constraints = vec![(
            basic("number"),
            constrained(bound_var, implements("Comparable")),
        )];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    #[test]
    fn test_implements_trait_violated() {
        let mut solver = ConstraintSolver::new();
        let bound_var = TypeVar::fresh();
        let mut constraints = vec![(
            basic("string"),
            constrained(bound_var, implements("Comparable")),
        )];
        let result = solver.solve(&mut constraints);
        assert!(result.is_err());
        match result.unwrap_err() {
            TypeError::TraitBoundViolation {
                type_name,
                trait_name,
            } => {
                assert_eq!(type_name, "string");
                assert_eq!(trait_name, "Comparable");
            }
            other => panic!("Expected TraitBoundViolation, got: {:?}", other),
        }
    }

    #[test]
    fn test_implements_trait_via_variable_resolution() {
        let mut solver = ConstraintSolver::new();
        let mut impls = std::collections::HashSet::new();
        impls.insert("Sortable::number".to_string());
        solver.set_trait_impls(impls);

        let type_var = TypeVar::fresh();
        let bound_var = TypeVar::fresh();

        let mut constraints = vec![
            (
                Type::Variable(type_var.clone()),
                constrained(bound_var, implements("Sortable")),
            ),
            (Type::Variable(type_var), basic("number")),
        ];
        assert!(
            solver.solve(&mut constraints).is_ok(),
            "T resolved to number which implements Sortable"
        );
    }
}