Skip to main content

shape_runtime/type_system/
constraints.rs

//! Type Constraint Solver
//!
//! Solves type constraints generated during type inference to determine
//! concrete types for type variables. The solver operates in three phases:
//!
//! ## Phase 1: Eager unification
//!
//! Each constraint `(T1, T2)` is attempted immediately via `solve_constraint`.
//! Simple bindings (variable-to-concrete, variable-to-variable) succeed here.
//! Constraints that fail (e.g. because a variable is not yet resolved) are
//! deferred to the next phase.
//!
//! ## Phase 2: Fixed-point iteration on deferred constraints
//!
//! Deferred constraints are retried in a loop. Each successful resolution may
//! unlock further deferred constraints by refining substitutions. The loop
//! terminates when a full pass makes no progress. Any constraints still
//! unsolved after the fixed point are reported as `UnsolvedConstraints`.
//!
//! ## Phase 3: Bound application
//!
//! After all equality constraints are resolved, `apply_bounds` validates
//! type variable bounds (`Numeric`, `Comparable`, `Iterable`, `HasField`,
//! `HasMethod`, `ImplementsTrait`). `HasField` constraints additionally
//! perform backward propagation: when a structural object field is found,
//! the field's result type variable is bound to the actual field type.
//!
//! The solver delegates low-level variable binding and substitution to the
//! `Unifier` (Robinson's algorithm with path compression).
30
31use super::checking::MethodTable;
32use super::unification::Unifier;
33use super::*;
34use shape_ast::ast::{ObjectTypeField, TypeAnnotation};
35use std::collections::{HashMap, HashSet};
36
37/// Check if a Type::Generic base is "Array" or "Vec".
38fn is_array_or_vec_base(base: &Type) -> bool {
39    match base {
40        Type::Concrete(TypeAnnotation::Reference(name)) => name == "Array" || name == "Vec",
41        Type::Concrete(TypeAnnotation::Basic(name)) => name == "Array" || name == "Vec",
42        _ => false,
43    }
44}
45
46pub struct ConstraintSolver {
47    /// Type unifier
48    unifier: Unifier,
49    /// Deferred constraints that couldn't be solved immediately.
50    /// These are handled in solve() via multiple passes.
51    _deferred: Vec<(Type, Type)>,
52    /// Type variable bounds
53    bounds: HashMap<TypeVar, TypeConstraint>,
54    /// Method table for HasMethod constraint enforcement
55    method_table: Option<MethodTable>,
56    /// Trait implementation registry: set of "TraitName::TypeName" keys
57    trait_impls: HashSet<String>,
58}
59
60impl Default for ConstraintSolver {
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66impl ConstraintSolver {
67    pub fn new() -> Self {
68        ConstraintSolver {
69            unifier: Unifier::new(),
70            _deferred: Vec::new(),
71            bounds: HashMap::new(),
72            method_table: None,
73            trait_impls: HashSet::new(),
74        }
75    }
76
77    /// Attach a method table for HasMethod constraint enforcement.
78    /// When set, HasMethod constraints are validated against this table
79    /// instead of being accepted unconditionally.
80    pub fn set_method_table(&mut self, table: MethodTable) {
81        self.method_table = Some(table);
82    }
83
84    /// Register trait implementations for ImplementsTrait constraint enforcement.
85    /// Each entry is a "TraitName::TypeName" key indicating that TypeName implements TraitName.
86    pub fn set_trait_impls(&mut self, impls: HashSet<String>) {
87        self.trait_impls = impls;
88    }
89
90    /// Solve all type constraints
91    pub fn solve(&mut self, constraints: &mut Vec<(Type, Type)>) -> TypeResult<()> {
92        // First pass: solve simple unification constraints
93        let mut unsolved = Vec::new();
94
95        for (t1, t2) in constraints.drain(..) {
96            if self.solve_constraint(t1.clone(), t2.clone()).is_err() {
97                // If we can't solve it now, defer it
98                unsolved.push((t1, t2));
99            }
100        }
101
102        // Second pass: try deferred constraints
103        let mut made_progress = true;
104        while made_progress && !unsolved.is_empty() {
105            made_progress = false;
106            let mut still_unsolved = Vec::new();
107
108            for (t1, t2) in unsolved.drain(..) {
109                if self.solve_constraint(t1.clone(), t2.clone()).is_err() {
110                    still_unsolved.push((t1, t2));
111                } else {
112                    made_progress = true;
113                }
114            }
115
116            unsolved = still_unsolved;
117        }
118
119        // Check if any constraints remain unsolved
120        if !unsolved.is_empty() {
121            return Err(TypeError::UnsolvedConstraints(unsolved));
122        }
123
124        // Apply bounds to type variables
125        self.apply_bounds()?;
126
127        Ok(())
128    }
129
130    /// Solve a single constraint
131    fn solve_constraint(&mut self, t1: Type, t2: Type) -> TypeResult<()> {
132        // Apply current substitutions before matching to avoid overwriting
133        // existing bindings (e.g., T17=string overwritten by T17=T19 during
134        // Function param/return pairwise unification).
135        let t1 = self.unifier.apply_substitutions(&t1);
136        let t2 = self.unifier.apply_substitutions(&t2);
137
138        match (&t1, &t2) {
139            // Variable constraints
140            (Type::Variable(v1), Type::Variable(v2)) if v1 == v2 => Ok(()),
141
142            // Constrained type variables — must be matched BEFORE the general
143            // Variable arm, otherwise (Variable, Constrained) pairs are caught
144            // by the Variable arm and the bound is never recorded.
145            (Type::Constrained { var, constraint }, ty)
146            | (ty, Type::Constrained { var, constraint }) => {
147                // Record the constraint
148                self.bounds.insert(var.clone(), *constraint.clone());
149
150                // Unify with the underlying type
151                self.solve_constraint(Type::Variable(var.clone()), ty.clone())
152            }
153
154            (Type::Variable(var), ty) | (ty, Type::Variable(var)) => {
155                // Check occurs check
156                if self.occurs_in(var, ty) {
157                    return Err(TypeError::InfiniteType(var.clone()));
158                }
159
160                self.unifier.bind(var.clone(), ty.clone());
161                Ok(())
162            }
163
164            // Concrete type constraints
165            (Type::Concrete(ann1), Type::Concrete(ann2)) => {
166                if self.unify_annotations(ann1, ann2)? {
167                    Ok(())
168                } else if Self::can_numeric_widen(ann1, ann2) {
169                    // Implicit numeric promotion (int → number/float)
170                    Ok(())
171                } else {
172                    Err(TypeError::TypeMismatch(
173                        format!("{:?}", ann1),
174                        format!("{:?}", ann2),
175                    ))
176                }
177            }
178
179            // Generic type constraints
180            (Type::Generic { base: b1, args: a1 }, Type::Generic { base: b2, args: a2 }) => {
181                self.solve_constraint(*b1.clone(), *b2.clone())?;
182
183                let is_result_base = |base: &Type| match base {
184                    Type::Concrete(TypeAnnotation::Reference(name)) => name == "Result",
185                    Type::Concrete(TypeAnnotation::Basic(name)) => name == "Result",
186                    _ => false,
187                };
188
189                if a1.len() != a2.len() {
190                    if is_result_base(&b1) && is_result_base(&b2) {
191                        match (a1.len(), a2.len()) {
192                            // `Result<T>` is error-agnostic shorthand and should unify
193                            // with `Result<T, E>` by constraining only the success type.
194                            (1, 2) | (2, 1) => {
195                                self.solve_constraint(a1[0].clone(), a2[0].clone())?;
196                                return Ok(());
197                            }
198                            _ => return Err(TypeError::ArityMismatch(a1.len(), a2.len())),
199                        }
200                    } else {
201                        return Err(TypeError::ArityMismatch(a1.len(), a2.len()));
202                    }
203                }
204
205                for (arg1, arg2) in a1.iter().zip(a2.iter()) {
206                    self.solve_constraint(arg1.clone(), arg2.clone())?;
207                }
208
209                Ok(())
210            }
211
212            // Function ~ Function: pairwise unify params + returns
213            (
214                Type::Function {
215                    params: p1,
216                    returns: r1,
217                },
218                Type::Function {
219                    params: p2,
220                    returns: r2,
221                },
222            ) => {
223                if p1.len() != p2.len() {
224                    return Err(TypeError::ArityMismatch(p1.len(), p2.len()));
225                }
226                for (param1, param2) in p1.iter().zip(p2.iter()) {
227                    // Parameter compatibility is checked from observed/actual to
228                    // declared/expected shape so directional numeric widening
229                    // (e.g. int -> number) remains valid in call constraints.
230                    self.solve_constraint(param2.clone(), param1.clone())?;
231                }
232                self.solve_constraint(*r1.clone(), *r2.clone())
233            }
234
235            // Cross-compatibility: Type::Function ~ Concrete(TypeAnnotation::Function)
236            (
237                Type::Function {
238                    params: fp,
239                    returns: fr,
240                },
241                Type::Concrete(TypeAnnotation::Function {
242                    params: cp,
243                    returns: cr,
244                }),
245            )
246            | (
247                Type::Concrete(TypeAnnotation::Function {
248                    params: cp,
249                    returns: cr,
250                }),
251                Type::Function {
252                    params: fp,
253                    returns: fr,
254                },
255            ) => {
256                if fp.len() != cp.len() {
257                    return Err(TypeError::ArityMismatch(fp.len(), cp.len()));
258                }
259                for (f_param, c_param) in fp.iter().zip(cp.iter()) {
260                    self.solve_constraint(
261                        f_param.clone(),
262                        Type::Concrete(c_param.type_annotation.clone()),
263                    )?;
264                }
265                self.solve_constraint(*fr.clone(), Type::Concrete(*cr.clone()))
266            }
267
268            // Array<T> (Type::Generic with base "Array" or "Vec") ~ Concrete(Array(T))
269            (Type::Generic { base, args }, Type::Concrete(TypeAnnotation::Array(elem)))
270            | (Type::Concrete(TypeAnnotation::Array(elem)), Type::Generic { base, args })
271                if args.len() == 1 && is_array_or_vec_base(base) =>
272            {
273                self.solve_constraint(args[0].clone(), Type::Concrete((**elem).clone()))
274            }
275
276            _ => Err(TypeError::TypeMismatch(
277                format!("{:?}", t1),
278                format!("{:?}", t2),
279            )),
280        }
281    }
282
283    /// Check if a type variable occurs in a type (occurs check)
284    fn occurs_in(&self, var: &TypeVar, ty: &Type) -> bool {
285        match ty {
286            Type::Variable(v) => v == var,
287            Type::Generic { base, args } => {
288                self.occurs_in(var, base) || args.iter().any(|arg| self.occurs_in(var, arg))
289            }
290            Type::Constrained { var: v, .. } => v == var,
291            Type::Function { params, returns } => {
292                params.iter().any(|p| self.occurs_in(var, p)) || self.occurs_in(var, returns)
293            }
294            Type::Concrete(_) => false,
295        }
296    }
297
298    /// Check if a numeric type can widen to another (directional).
299    ///
300    /// Integer-family types (`int`, `i16`, `u32`, `byte`, ...) can widen to
301    /// number-family types (`number`, `f32`, `f64`, ...).
302    /// `number → int` does NOT widen (lossy). `decimal → number` does NOT widen
303    /// (different precision semantics).
304    fn can_numeric_widen(from: &TypeAnnotation, to: &TypeAnnotation) -> bool {
305        let from_name = match from {
306            TypeAnnotation::Basic(name) => Some(name.as_str()),
307            TypeAnnotation::Reference(name) => Some(name.as_str()),
308            _ => None,
309        };
310        let to_name = match to {
311            TypeAnnotation::Basic(name) => Some(name.as_str()),
312            TypeAnnotation::Reference(name) => Some(name.as_str()),
313            _ => None,
314        };
315
316        match (from_name, to_name) {
317            (Some(f), Some(t)) => {
318                BuiltinTypes::is_integer_type_name(f) && BuiltinTypes::is_number_type_name(t)
319            }
320            _ => false,
321        }
322    }
323
324    /// Unify two type annotations
325    fn unify_annotations(&self, ann1: &TypeAnnotation, ann2: &TypeAnnotation) -> TypeResult<bool> {
326        match (ann1, ann2) {
327            // Basic types
328            (TypeAnnotation::Basic(_), TypeAnnotation::Basic(_)) => {
329                Ok(ann1 == ann2 || Self::can_numeric_widen(ann1, ann2))
330            }
331            (TypeAnnotation::Reference(n1), TypeAnnotation::Reference(n2)) => Ok(n1 == n2),
332            (TypeAnnotation::Basic(_), TypeAnnotation::Reference(_))
333            | (TypeAnnotation::Reference(_), TypeAnnotation::Basic(_)) => {
334                Ok(ann1 == ann2 || Self::can_numeric_widen(ann1, ann2))
335            }
336
337            // Array types
338            (TypeAnnotation::Array(e1), TypeAnnotation::Array(e2)) => {
339                self.unify_annotations(e1, e2)
340            }
341
342            // Tuple types
343            (TypeAnnotation::Tuple(t1), TypeAnnotation::Tuple(t2)) => {
344                if t1.len() != t2.len() {
345                    return Ok(false);
346                }
347
348                for (elem1, elem2) in t1.iter().zip(t2.iter()) {
349                    if !self.unify_annotations(elem1, elem2)? {
350                        return Ok(false);
351                    }
352                }
353
354                Ok(true)
355            }
356
357            // Structural object types
358            (TypeAnnotation::Object(f1), TypeAnnotation::Object(f2)) => {
359                self.object_fields_compatible(f1, f2)
360            }
361
362            // Function types
363            (
364                TypeAnnotation::Function {
365                    params: p1,
366                    returns: r1,
367                },
368                TypeAnnotation::Function {
369                    params: p2,
370                    returns: r2,
371                },
372            ) => {
373                if p1.len() != p2.len() {
374                    return Ok(false);
375                }
376
377                for (param1, param2) in p1.iter().zip(p2.iter()) {
378                    if !self.unify_annotations(&param1.type_annotation, &param2.type_annotation)? {
379                        return Ok(false);
380                    }
381                }
382
383                self.unify_annotations(r1, r2)
384            }
385
386            // Union types
387            // A | B unifies with C | D if each type in one union can unify with at least one type in the other
388            (TypeAnnotation::Union(u1), TypeAnnotation::Union(u2)) => {
389                // Check that every type in u1 can unify with at least one type in u2
390                for t1 in u1 {
391                    let mut found_match = false;
392                    for t2 in u2 {
393                        if self.unify_annotations(t1, t2)? {
394                            found_match = true;
395                            break;
396                        }
397                    }
398                    if !found_match {
399                        return Ok(false);
400                    }
401                }
402                // Check that every type in u2 can unify with at least one type in u1
403                for t2 in u2 {
404                    let mut found_match = false;
405                    for t1 in u1 {
406                        if self.unify_annotations(t1, t2)? {
407                            found_match = true;
408                            break;
409                        }
410                    }
411                    if !found_match {
412                        return Ok(false);
413                    }
414                }
415                Ok(true)
416            }
417
418            // Union with non-union: A | B unifies with C if either A or B unifies with C
419            (TypeAnnotation::Union(union_types), other)
420            | (other, TypeAnnotation::Union(union_types)) => {
421                for union_type in union_types {
422                    if self.unify_annotations(union_type, other)? {
423                        return Ok(true);
424                    }
425                }
426                Ok(false)
427            }
428
429            // Intersection types (order-independent)
430            (TypeAnnotation::Intersection(i1), TypeAnnotation::Intersection(i2)) => {
431                self.unify_annotation_sets(i1, i2)
432            }
433
434            // Void, Null, Undefined
435            (TypeAnnotation::Void, TypeAnnotation::Void) => Ok(true),
436            (TypeAnnotation::Null, TypeAnnotation::Null) => Ok(true),
437            (TypeAnnotation::Undefined, TypeAnnotation::Undefined) => Ok(true),
438
439            // Trait object types: dyn Trait1 + Trait2
440            // Two trait objects unify if they have the same set of traits
441            (TypeAnnotation::Dyn(traits1), TypeAnnotation::Dyn(traits2)) => {
442                Ok(traits1.len() == traits2.len() && traits1.iter().all(|t| traits2.contains(t)))
443            }
444
445            // Array<T> (Generic) is equivalent to Vec<T> (Array)
446            (TypeAnnotation::Generic { name, args }, TypeAnnotation::Array(elem))
447            | (TypeAnnotation::Array(elem), TypeAnnotation::Generic { name, args })
448                if name == "Array" && args.len() == 1 =>
449            {
450                self.unify_annotations(&args[0], elem)
451            }
452
453            // Different types don't unify
454            _ => Ok(false),
455        }
456    }
457
458    fn object_fields_compatible(
459        &self,
460        left: &[ObjectTypeField],
461        right: &[ObjectTypeField],
462    ) -> TypeResult<bool> {
463        for left_field in left {
464            let Some(right_field) = right.iter().find(|f| f.name == left_field.name) else {
465                return Ok(false);
466            };
467            if left_field.optional != right_field.optional {
468                return Ok(false);
469            }
470            if !self.unify_annotations(&left_field.type_annotation, &right_field.type_annotation)? {
471                return Ok(false);
472            }
473        }
474        if left.len() != right.len() {
475            return Ok(false);
476        }
477        Ok(true)
478    }
479
480    fn unify_annotation_sets(
481        &self,
482        left: &[TypeAnnotation],
483        right: &[TypeAnnotation],
484    ) -> TypeResult<bool> {
485        if left.len() != right.len() {
486            return Ok(false);
487        }
488
489        let mut matched = vec![false; right.len()];
490        for left_ann in left {
491            let mut found = false;
492            for (idx, right_ann) in right.iter().enumerate() {
493                if matched[idx] {
494                    continue;
495                }
496                if self.unify_annotations(left_ann, right_ann)? {
497                    matched[idx] = true;
498                    found = true;
499                    break;
500                }
501            }
502            if !found {
503                return Ok(false);
504            }
505        }
506
507        Ok(true)
508    }
509
510    /// Apply type variable bounds, propagating resolved field types back to type variables.
511    ///
512    /// When a `HasField` constraint is satisfied and the expected field type was a
513    /// type variable, this binds that variable to the actual field type. This enables
514    /// backward propagation: `let f = |obj| obj.x; f({x: 42})` resolves `obj.x` to `int`.
515    fn apply_bounds(&mut self) -> TypeResult<()> {
516        let mut new_bindings: Vec<(TypeVar, Type)> = Vec::new();
517
518        for (var, constraint) in &self.bounds {
519            // Use apply_substitutions to follow the full variable chain
520            // (lookup only returns the direct binding, not the resolved type).
521            let resolved = self
522                .unifier
523                .apply_substitutions(&Type::Variable(var.clone()));
524
525            if let Type::Variable(_) = &resolved {
526                // Still unresolved — skip for now
527                continue;
528            }
529
530            self.check_constraint(&resolved, constraint)?;
531
532            // Backward propagation: when a HasField constraint is satisfied,
533            // bind the result type variable to the actual field type.
534            if let TypeConstraint::HasField(field, expected_field_type) = constraint {
535                if let Type::Variable(field_var) = expected_field_type.as_ref() {
536                    // Also check if the field var is already resolved
537                    let field_resolved = self
538                        .unifier
539                        .apply_substitutions(&Type::Variable(field_var.clone()));
540                    if let Type::Variable(_) = &field_resolved {
541                        // Field var still unresolved — try to bind it
542                        if let Type::Concrete(TypeAnnotation::Object(fields)) = &resolved {
543                            if let Some(found_field) = fields.iter().find(|f| f.name == *field) {
544                                new_bindings.push((
545                                    field_var.clone(),
546                                    Type::Concrete(found_field.type_annotation.clone()),
547                                ));
548                            }
549                        }
550                    }
551                }
552            }
553        }
554
555        // Apply collected bindings
556        for (var, ty) in new_bindings {
557            self.unifier.bind(var, ty);
558        }
559
560        Ok(())
561    }
562
563    /// Check if a type satisfies a constraint
564    fn check_constraint(&self, ty: &Type, constraint: &TypeConstraint) -> TypeResult<()> {
565        match constraint {
566            TypeConstraint::Comparable => match ty {
567                Type::Concrete(TypeAnnotation::Basic(name))
568                    if BuiltinTypes::is_numeric_type_name(name)
569                        || name == "string"
570                        || name == "bool" =>
571                {
572                    Ok(())
573                }
574                _ => Err(TypeError::ConstraintViolation(format!(
575                    "{:?} is not comparable",
576                    ty
577                ))),
578            },
579
580            TypeConstraint::Iterable => match ty {
581                Type::Concrete(TypeAnnotation::Array(_)) => Ok(()),
582                Type::Concrete(TypeAnnotation::Basic(name))
583                    if name == "string" || name == "rows" =>
584                {
585                    Ok(())
586                }
587                _ => Err(TypeError::ConstraintViolation(format!(
588                    "{:?} is not iterable",
589                    ty
590                ))),
591            },
592
593            TypeConstraint::HasField(field, expected_field_type) => {
594                match ty {
595                    Type::Concrete(TypeAnnotation::Object(fields)) => {
596                        match fields.iter().find(|f| f.name == *field) {
597                            Some(found_field) => {
598                                // Check that field type matches expected type
599                                if let Some(expected_ann) = expected_field_type.to_annotation() {
600                                    if self.unify_annotations(
601                                        &found_field.type_annotation,
602                                        &expected_ann,
603                                    )? {
604                                        Ok(())
605                                    } else {
606                                        Err(TypeError::ConstraintViolation(format!(
607                                            "field '{}' has type {:?}, expected {:?}",
608                                            field, found_field.type_annotation, expected_ann
609                                        )))
610                                    }
611                                } else {
612                                    // Expected type is a type variable, accept any field type
613                                    Ok(())
614                                }
615                            }
616                            None => Err(TypeError::ConstraintViolation(format!(
617                                "{:?} does not have field '{}'",
618                                ty, field
619                            ))),
620                        }
621                    }
622                    Type::Concrete(TypeAnnotation::Basic(_name)) => {
623                        // For named types, we assume property access was validated
624                        // during inference using the schema registry. If a HasField
625                        // constraint reaches here, it means the type wasn't a known
626                        // schema type during inference, so we accept it tentatively.
627                        // Runtime will do the final validation.
628                        //
629                        // Note: Previously this hardcoded "row" with OHLCV fields.
630                        // Now schema validation happens in TypeInferenceEngine::infer_property_access.
631                        Ok(())
632                    }
633                    _ => Err(TypeError::ConstraintViolation(format!(
634                        "{:?} cannot have fields",
635                        ty
636                    ))),
637                }
638            }
639
640            TypeConstraint::Callable {
641                params: expected_params,
642                returns: expected_returns,
643            } => {
644                match ty {
645                    Type::Concrete(TypeAnnotation::Function {
646                        params: actual_params,
647                        returns: actual_returns,
648                    }) => {
649                        // Check parameter count matches
650                        if expected_params.len() != actual_params.len() {
651                            return Err(TypeError::ConstraintViolation(format!(
652                                "function expects {} parameters, got {}",
653                                expected_params.len(),
654                                actual_params.len()
655                            )));
656                        }
657
658                        // Check each parameter type (contravariant: expected <: actual)
659                        for (expected, actual) in expected_params.iter().zip(actual_params.iter()) {
660                            if let Some(expected_ann) = expected.to_annotation() {
661                                if !self
662                                    .unify_annotations(&expected_ann, &actual.type_annotation)?
663                                {
664                                    return Err(TypeError::ConstraintViolation(format!(
665                                        "parameter type mismatch: expected {:?}, got {:?}",
666                                        expected_ann, actual.type_annotation
667                                    )));
668                                }
669                            }
670                        }
671
672                        // Check return type (covariant: actual <: expected)
673                        if let Some(expected_ret_ann) = expected_returns.to_annotation() {
674                            if !self.unify_annotations(actual_returns, &expected_ret_ann)? {
675                                return Err(TypeError::ConstraintViolation(format!(
676                                    "return type mismatch: expected {:?}, got {:?}",
677                                    expected_ret_ann, actual_returns
678                                )));
679                            }
680                        }
681
682                        Ok(())
683                    }
684                    Type::Function {
685                        params: actual_params,
686                        returns: actual_returns,
687                    } => {
688                        if expected_params.len() != actual_params.len() {
689                            return Err(TypeError::ConstraintViolation(format!(
690                                "function expects {} parameters, got {}",
691                                expected_params.len(),
692                                actual_params.len()
693                            )));
694                        }
695                        // Type::Function params are Type, not FunctionParam — compare directly
696                        for (expected, actual) in expected_params.iter().zip(actual_params.iter()) {
697                            if let (Some(e_ann), Some(a_ann)) =
698                                (expected.to_annotation(), actual.to_annotation())
699                            {
700                                if !self.unify_annotations(&e_ann, &a_ann)? {
701                                    return Err(TypeError::ConstraintViolation(format!(
702                                        "parameter type mismatch: expected {:?}, got {:?}",
703                                        e_ann, a_ann
704                                    )));
705                                }
706                            }
707                        }
708                        if let (Some(e_ret), Some(a_ret)) = (
709                            expected_returns.to_annotation(),
710                            actual_returns.to_annotation(),
711                        ) {
712                            if !self.unify_annotations(&a_ret, &e_ret)? {
713                                return Err(TypeError::ConstraintViolation(format!(
714                                    "return type mismatch: expected {:?}, got {:?}",
715                                    e_ret, a_ret
716                                )));
717                            }
718                        }
719                        Ok(())
720                    }
721                    _ => Err(TypeError::ConstraintViolation(format!(
722                        "{:?} is not callable",
723                        ty
724                    ))),
725                }
726            }
727
728            TypeConstraint::OneOf(options) => {
729                for option in options {
730                    // If type matches any option, constraint is satisfied
731                    if let Type::Concrete(ann) = option {
732                        if let Type::Concrete(ty_ann) = ty {
733                            if self.unify_annotations(ann, ty_ann).unwrap_or(false) {
734                                return Ok(());
735                            }
736                        }
737                    }
738                }
739
740                Err(TypeError::ConstraintViolation(format!(
741                    "{:?} does not match any of {:?}",
742                    ty, options
743                )))
744            }
745
746            TypeConstraint::Extends(base) => {
747                // Implement subtyping check
748                self.is_subtype(ty, base)
749            }
750
751            TypeConstraint::ImplementsTrait { trait_name } => {
752                match ty {
753                    Type::Variable(_) => {
754                        // Type variable not yet resolved — this is a compile error
755                        // (no deferring per Sprint 2 spec)
756                        Err(TypeError::TraitBoundViolation {
757                            type_name: format!("{:?}", ty),
758                            trait_name: trait_name.clone(),
759                        })
760                    }
761                    Type::Concrete(ann) => {
762                        let type_name = match ann {
763                            TypeAnnotation::Basic(n) => n.clone(),
764                            TypeAnnotation::Reference(n) => n.to_string(),
765                            _ => format!("{:?}", ann),
766                        };
767                        if self.has_trait_impl(trait_name, &type_name) {
768                            Ok(())
769                        } else {
770                            Err(TypeError::TraitBoundViolation {
771                                type_name,
772                                trait_name: trait_name.clone(),
773                            })
774                        }
775                    }
776                    Type::Generic { base, .. } => {
777                        let type_name = match base.as_ref() {
778                            Type::Concrete(TypeAnnotation::Reference(n)) => n.to_string(),
779                            Type::Concrete(TypeAnnotation::Basic(n)) => n.clone(),
780                            _ => format!("{:?}", base),
781                        };
782                        if self.has_trait_impl(trait_name, &type_name) {
783                            Ok(())
784                        } else {
785                            Err(TypeError::TraitBoundViolation {
786                                type_name,
787                                trait_name: trait_name.clone(),
788                            })
789                        }
790                    }
791                    _ => Err(TypeError::TraitBoundViolation {
792                        type_name: format!("{:?}", ty),
793                        trait_name: trait_name.clone(),
794                    }),
795                }
796            }
797
798            TypeConstraint::HasMethod {
799                method_name,
800                arg_types: _,
801                return_type: _,
802            } => {
803                // If we have a method table, enforce the constraint
804                if let Some(method_table) = &self.method_table {
805                    match ty {
806                        Type::Variable(_) => Ok(()), // Unresolved type var, defer
807                        Type::Concrete(ann) => {
808                            let type_name = match ann {
809                                TypeAnnotation::Basic(n) => n.clone(),
810                                TypeAnnotation::Reference(n) => n.to_string(),
811                                TypeAnnotation::Array(_) => "Vec".to_string(),
812                                _ => return Ok(()), // Complex types: accept
813                            };
814                            if method_table.lookup(ty, method_name).is_some() {
815                                Ok(())
816                            } else {
817                                Err(TypeError::MethodNotFound {
818                                    type_name,
819                                    method_name: method_name.clone(),
820                                })
821                            }
822                        }
823                        Type::Generic { base, .. } => {
824                            if method_table.lookup(ty, method_name).is_some() {
825                                Ok(())
826                            } else {
827                                let type_name =
828                                    if let Type::Concrete(TypeAnnotation::Reference(n)) =
829                                        base.as_ref()
830                                    {
831                                        n.to_string()
832                                    } else {
833                                        format!("{:?}", base)
834                                    };
835                                Err(TypeError::MethodNotFound {
836                                    type_name,
837                                    method_name: method_name.clone(),
838                                })
839                            }
840                        }
841                        _ => Ok(()), // Function, Constrained: accept
842                    }
843                } else {
844                    // No method table attached — accept all (backward compatible)
845                    Ok(())
846                }
847            }
848        }
849    }
850
851    /// Check if a type implements a trait, considering aliases and numeric widening.
852    ///
853    /// Handles three resolution strategies:
854    /// 1. Direct lookup: `"Numeric::int"` in the trait_impls set
855    /// 2. Canonical alias: `"Float"` → `"f64"`, `"byte"` → `"u8"` via runtime name table
856    /// 3. Script alias: `"i64"` → `"int"`, `"f64"` → `"number"` via script alias table
857    /// 4. Numeric widening: integer-family names can satisfy number/float/f64 impls
858    fn has_trait_impl(&self, trait_name: &str, type_name: &str) -> bool {
859        let key = format!("{}::{}", trait_name, type_name);
860        if self.trait_impls.contains(&key) {
861            return true;
862        }
863        // Try canonical runtime alias (e.g. "Float" -> "f64", "byte" -> "u8")
864        if let Some(canonical) = BuiltinTypes::canonical_numeric_runtime_name(type_name) {
865            let canon_key = format!("{}::{}", trait_name, canonical);
866            if self.trait_impls.contains(&canon_key) {
867                return true;
868            }
869        }
870        // Try script-facing alias (e.g. "i64" -> "int", "f64" -> "number")
871        if let Some(script_alias) = BuiltinTypes::canonical_script_alias(type_name) {
872            let alias_key = format!("{}::{}", trait_name, script_alias);
873            if self.trait_impls.contains(&alias_key) {
874                return true;
875            }
876        }
877        // Numeric widening: integer-family aliases can use number/float/f64 impls.
878        if BuiltinTypes::is_integer_type_name(type_name) {
879            for widen_to in &["number", "float", "f64"] {
880                let widen_key = format!("{}::{}", trait_name, widen_to);
881                if self.trait_impls.contains(&widen_key) {
882                    return true;
883                }
884            }
885        }
886        false
887    }
888
889    /// Check if ty is a subtype of base (ty <: base)
890    /// Subtyping rules:
891    /// - Same types are subtypes of each other
892    /// - Any is a supertype of everything
893    /// - Vec<A> <: Vec<B> if A <: B (covariant)
894    /// - Function<P1, R1> <: Function<P2, R2> if P2 <: P1 (contravariant params) and R1 <: R2 (covariant return)
895    fn is_subtype(&self, ty: &Type, base: &Type) -> TypeResult<()> {
896        match (ty, base) {
897            // Same types are subtypes
898            (t1, t2) if t1 == t2 => Ok(()),
899
900            // Type variables - if we can unify, it's compatible
901            (Type::Variable(_), _) | (_, Type::Variable(_)) => Ok(()),
902
903            // Array subtyping (covariant)
904            (
905                Type::Concrete(TypeAnnotation::Array(elem1)),
906                Type::Concrete(TypeAnnotation::Array(elem2)),
907            ) => {
908                let t1 = Type::Concrete(*elem1.clone());
909                let t2 = Type::Concrete(*elem2.clone());
910                self.is_subtype(&t1, &t2)
911            }
912
913            // Function subtyping (contravariant params, covariant return)
914            (
915                Type::Concrete(TypeAnnotation::Function {
916                    params: p1,
917                    returns: r1,
918                }),
919                Type::Concrete(TypeAnnotation::Function {
920                    params: p2,
921                    returns: r2,
922                }),
923            ) => {
924                // Check parameter count
925                if p1.len() != p2.len() {
926                    return Err(TypeError::ConstraintViolation(format!(
927                        "function parameter count mismatch: {} vs {}",
928                        p1.len(),
929                        p2.len()
930                    )));
931                }
932
933                // Contravariant: base params must be subtypes of ty params
934                for (param1, param2) in p1.iter().zip(p2.iter()) {
935                    let t1 = Type::Concrete(param2.type_annotation.clone());
936                    let t2 = Type::Concrete(param1.type_annotation.clone());
937                    self.is_subtype(&t1, &t2)?;
938                }
939
940                // Covariant: ty return must be subtype of base return
941                let ret1 = Type::Concrete(*r1.clone());
942                let ret2 = Type::Concrete(*r2.clone());
943                self.is_subtype(&ret1, &ret2)
944            }
945
946            // Optional subtyping: T <: Option<T>
947            (t, Type::Concrete(TypeAnnotation::Generic { name, args }))
948                if name == "Option" && args.len() == 1 =>
949            {
950                let inner = Type::Concrete(args[0].clone());
951                self.is_subtype(t, &inner)
952            }
953
954            // Type::Function subtyping (contravariant params, covariant return)
955            (
956                Type::Function {
957                    params: p1,
958                    returns: r1,
959                },
960                Type::Function {
961                    params: p2,
962                    returns: r2,
963                },
964            ) => {
965                if p1.len() != p2.len() {
966                    return Err(TypeError::ConstraintViolation(format!(
967                        "function parameter count mismatch: {} vs {}",
968                        p1.len(),
969                        p2.len()
970                    )));
971                }
972                // Contravariant params
973                for (param1, param2) in p1.iter().zip(p2.iter()) {
974                    self.is_subtype(param2, param1)?;
975                }
976                // Covariant return
977                self.is_subtype(r1, r2)
978            }
979
980            // Basic types - check if they unify
981            (Type::Concrete(ann1), Type::Concrete(ann2)) => {
982                if self.unify_annotations(ann1, ann2)? {
983                    Ok(())
984                } else {
985                    Err(TypeError::ConstraintViolation(format!(
986                        "{:?} is not a subtype of {:?}",
987                        ty, base
988                    )))
989                }
990            }
991
992            // Default: not a subtype
993            _ => Err(TypeError::ConstraintViolation(format!(
994                "{:?} is not a subtype of {:?}",
995                ty, base
996            ))),
997        }
998    }
999
1000    /// Get the unifier for applying substitutions
1001    pub fn unifier(&self) -> &Unifier {
1002        &self.unifier
1003    }
1004}
1005
#[cfg(test)]
mod tests {
    use super::*;
    use crate::type_system::TypeVarGen;
    use shape_ast::ast::ObjectTypeField;

    /// Test-local helper: allocate a fresh type variable from a
    /// per-test counter. Each test owns its own `TypeVarGen`, so IDs
    /// (`T0`, `T1`, ...) are deterministic and independent across tests.
    fn fresh_var(tvgen: &mut TypeVarGen) -> TypeVar {
        tvgen.fresh_var()
    }

    fn fresh_type(tvgen: &mut TypeVarGen) -> Type {
        tvgen.fresh_type()
    }

    /// Test-local fixture: a trait-impl set registering `Numeric` for every
    /// built-in numeric type name. It is intended to match the set
    /// `TypeEnvironment` registers. Shared so the numeric tests cannot drift
    /// apart.
    fn numeric_trait_impls() -> std::collections::HashSet<String> {
        [
            "Numeric::int",
            "Numeric::number",
            "Numeric::decimal",
            "Numeric::i8",
            "Numeric::i16",
            "Numeric::i32",
            "Numeric::i64",
            "Numeric::u8",
            "Numeric::u16",
            "Numeric::u32",
            "Numeric::u64",
            "Numeric::f32",
            "Numeric::f64",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect()
    }

    #[test]
    fn test_hasfield_backward_propagation_binds_field_type() {
        // When a TypeVar has a HasField constraint and is resolved to a concrete
        // object type, the field's result type variable should be bound to the
        // actual field type. This enables backward type propagation.
        let mut solver = ConstraintSolver::new();
        let mut tvgen = TypeVarGen::new();

        let obj_var = fresh_var(&mut tvgen);
        let field_result_var = fresh_var(&mut tvgen);
        let bound_var = fresh_var(&mut tvgen);

        let mut constraints = vec![
            // obj_var ~ Constrained { var: bound_var, HasField("x", field_result_var) }
            // This records bound: bound_var → HasField("x", field_result_var)
            // and solves: bound_var ~ obj_var
            (
                Type::Variable(obj_var.clone()),
                Type::Constrained {
                    var: bound_var,
                    constraint: Box::new(TypeConstraint::HasField(
                        "x".to_string(),
                        Box::new(Type::Variable(field_result_var.clone())),
                    )),
                },
            ),
            // obj_var = {x: int}
            (
                Type::Variable(obj_var),
                Type::Concrete(TypeAnnotation::Object(vec![ObjectTypeField {
                    name: "x".to_string(),
                    optional: false,
                    type_annotation: TypeAnnotation::Basic("int".to_string()),
                    annotations: vec![],
                }])),
            ),
        ];

        solver.solve(&mut constraints).unwrap();

        // field_result_var should now be resolved to int via apply_bounds
        let resolved = solver
            .unifier()
            .apply_substitutions(&Type::Variable(field_result_var));
        match &resolved {
            Type::Concrete(TypeAnnotation::Basic(name)) => {
                assert_eq!(name, "int", "field type should be int");
            }
            _ => panic!(
                "Expected field_result_var to be resolved to int, got {:?}",
                resolved
            ),
        }
    }

    #[test]
    fn test_hasfield_backward_propagation_multiple_fields() {
        // Test that multiple HasField constraints on the same object all propagate
        let mut solver = ConstraintSolver::new();
        let mut tvgen = TypeVarGen::new();

        let obj_var = fresh_var(&mut tvgen);
        let field_x_var = fresh_var(&mut tvgen);
        let field_y_var = fresh_var(&mut tvgen);
        let bound_var_x = fresh_var(&mut tvgen);
        let bound_var_y = fresh_var(&mut tvgen);

        let mut constraints = vec![
            // HasField("x", field_x_var)
            (
                Type::Variable(obj_var.clone()),
                Type::Constrained {
                    var: bound_var_x,
                    constraint: Box::new(TypeConstraint::HasField(
                        "x".to_string(),
                        Box::new(Type::Variable(field_x_var.clone())),
                    )),
                },
            ),
            // HasField("y", field_y_var)
            (
                Type::Variable(obj_var.clone()),
                Type::Constrained {
                    var: bound_var_y,
                    constraint: Box::new(TypeConstraint::HasField(
                        "y".to_string(),
                        Box::new(Type::Variable(field_y_var.clone())),
                    )),
                },
            ),
            // obj_var = {x: int, y: string}
            (
                Type::Variable(obj_var),
                Type::Concrete(TypeAnnotation::Object(vec![
                    ObjectTypeField {
                        name: "x".to_string(),
                        optional: false,
                        type_annotation: TypeAnnotation::Basic("int".to_string()),
                        annotations: vec![],
                    },
                    ObjectTypeField {
                        name: "y".to_string(),
                        optional: false,
                        type_annotation: TypeAnnotation::Basic("string".to_string()),
                        annotations: vec![],
                    },
                ])),
            ),
        ];

        solver.solve(&mut constraints).unwrap();

        let resolved_x = solver
            .unifier()
            .apply_substitutions(&Type::Variable(field_x_var));
        let resolved_y = solver
            .unifier()
            .apply_substitutions(&Type::Variable(field_y_var));

        match &resolved_x {
            Type::Concrete(TypeAnnotation::Basic(name)) => assert_eq!(name, "int"),
            _ => panic!("Expected x to be int, got {:?}", resolved_x),
        }
        match &resolved_y {
            Type::Concrete(TypeAnnotation::Basic(name)) => assert_eq!(name, "string"),
            _ => panic!("Expected y to be string, got {:?}", resolved_y),
        }
    }

    // ===== Fix 1: Numeric type preservation tests =====

    #[test]
    fn test_int_constrained_numeric_succeeds() {
        // Concrete(int) ~ Constrained(ImplementsTrait("Numeric")) should succeed
        let mut solver = ConstraintSolver::new();
        // Inject Numeric trait impls (same as TypeEnvironment registers)
        solver.set_trait_impls(numeric_trait_impls());
        let mut tvgen = TypeVarGen::new();
        let bound_var = fresh_var(&mut tvgen);
        let mut constraints = vec![(
            Type::Concrete(TypeAnnotation::Basic("int".to_string())),
            Type::Constrained {
                var: bound_var,
                constraint: Box::new(TypeConstraint::ImplementsTrait {
                    trait_name: "Numeric".to_string(),
                }),
            },
        )];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    #[test]
    fn test_numeric_widening_int_to_number() {
        // (Concrete(int), Concrete(number)) should succeed via widening
        let mut solver = ConstraintSolver::new();
        let mut constraints = vec![(
            Type::Concrete(TypeAnnotation::Basic("int".to_string())),
            Type::Concrete(TypeAnnotation::Basic("number".to_string())),
        )];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    #[test]
    fn test_numeric_widening_width_aware_integer_to_float_family() {
        let mut solver = ConstraintSolver::new();
        let mut constraints = vec![(
            Type::Concrete(TypeAnnotation::Basic("i16".to_string())),
            Type::Concrete(TypeAnnotation::Basic("f32".to_string())),
        )];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    #[test]
    fn test_no_widening_number_to_int() {
        // (Concrete(number), Concrete(int)) should fail — lossy
        let mut solver = ConstraintSolver::new();
        let mut constraints = vec![(
            Type::Concrete(TypeAnnotation::Basic("number".to_string())),
            Type::Concrete(TypeAnnotation::Basic("int".to_string())),
        )];
        assert!(solver.solve(&mut constraints).is_err());
    }

    #[test]
    fn test_decimal_constrained_numeric_succeeds() {
        let mut solver = ConstraintSolver::new();
        solver.set_trait_impls(numeric_trait_impls());
        let mut tvgen = TypeVarGen::new();
        let bound_var = fresh_var(&mut tvgen);
        let mut constraints = vec![(
            Type::Concrete(TypeAnnotation::Basic("decimal".to_string())),
            Type::Constrained {
                var: bound_var,
                constraint: Box::new(TypeConstraint::ImplementsTrait {
                    trait_name: "Numeric".to_string(),
                }),
            },
        )];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    #[test]
    fn test_comparable_accepts_int() {
        // int should be Comparable
        let mut solver = ConstraintSolver::new();
        let mut tvgen = TypeVarGen::new();
        let bound_var = fresh_var(&mut tvgen);
        let mut constraints = vec![(
            Type::Concrete(TypeAnnotation::Basic("int".to_string())),
            Type::Constrained {
                var: bound_var,
                constraint: Box::new(TypeConstraint::Comparable),
            },
        )];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    // ===== Fix 2: Type::Function tests =====

    #[test]
    fn test_function_type_preserves_variables() {
        // BuiltinTypes::function with Variable params should be Type::Function
        let mut tvgen = TypeVarGen::new();
        let param = fresh_type(&mut tvgen);
        let ret = fresh_type(&mut tvgen);
        let func = BuiltinTypes::function(vec![param.clone()], ret.clone());
        match func {
            Type::Function { params, returns } => {
                assert_eq!(params.len(), 1);
                assert_eq!(params[0], param);
                assert_eq!(*returns, ret);
            }
            _ => panic!("Expected Type::Function, got {:?}", func),
        }
    }

    #[test]
    fn test_function_unification_binds_variables() {
        // (T1)->T2 ~ (number)->string should bind T1=number, T2=string
        let mut solver = ConstraintSolver::new();
        let mut tvgen = TypeVarGen::new();
        let t1 = fresh_var(&mut tvgen);
        let t2 = fresh_var(&mut tvgen);

        let mut constraints = vec![(
            Type::Function {
                params: vec![Type::Variable(t1.clone())],
                returns: Box::new(Type::Variable(t2.clone())),
            },
            Type::Function {
                params: vec![BuiltinTypes::number()],
                returns: Box::new(BuiltinTypes::string()),
            },
        )];

        solver.solve(&mut constraints).unwrap();

        let resolved_t1 = solver.unifier().apply_substitutions(&Type::Variable(t1));
        let resolved_t2 = solver.unifier().apply_substitutions(&Type::Variable(t2));
        assert_eq!(resolved_t1, BuiltinTypes::number());
        assert_eq!(resolved_t2, BuiltinTypes::string());
    }

    #[test]
    fn test_function_cross_unification_with_concrete() {
        // Type::Function ~ Concrete(TypeAnnotation::Function) should unify
        let mut solver = ConstraintSolver::new();
        let mut tvgen = TypeVarGen::new();
        let t1 = fresh_var(&mut tvgen);

        let concrete_func = Type::Concrete(TypeAnnotation::Function {
            params: vec![shape_ast::ast::FunctionParam {
                name: None,
                optional: false,
                type_annotation: TypeAnnotation::Basic("number".to_string()),
            }],
            returns: Box::new(TypeAnnotation::Basic("string".to_string())),
        });

        let mut constraints = vec![(
            Type::Function {
                params: vec![Type::Variable(t1.clone())],
                returns: Box::new(BuiltinTypes::string()),
            },
            concrete_func,
        )];

        solver.solve(&mut constraints).unwrap();

        let resolved = solver.unifier().apply_substitutions(&Type::Variable(t1));
        assert_eq!(resolved, BuiltinTypes::number());
    }

    #[test]
    fn test_object_annotations_unify_structurally() {
        let mut solver = ConstraintSolver::new();
        let mut constraints = vec![(
            Type::Concrete(TypeAnnotation::Object(vec![
                ObjectTypeField {
                    name: "x".to_string(),
                    optional: false,
                    type_annotation: TypeAnnotation::Basic("int".to_string()),
                    annotations: vec![],
                },
                ObjectTypeField {
                    name: "y".to_string(),
                    optional: false,
                    type_annotation: TypeAnnotation::Basic("int".to_string()),
                    annotations: vec![],
                },
            ])),
            Type::Concrete(TypeAnnotation::Object(vec![
                ObjectTypeField {
                    name: "x".to_string(),
                    optional: false,
                    type_annotation: TypeAnnotation::Basic("int".to_string()),
                    annotations: vec![],
                },
                ObjectTypeField {
                    name: "y".to_string(),
                    optional: false,
                    type_annotation: TypeAnnotation::Basic("int".to_string()),
                    annotations: vec![],
                },
            ])),
        )];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    #[test]
    fn test_intersection_annotations_unify_order_independent() {
        let mut solver = ConstraintSolver::new();
        let obj_xy = TypeAnnotation::Object(vec![
            ObjectTypeField {
                name: "x".to_string(),
                optional: false,
                type_annotation: TypeAnnotation::Basic("int".to_string()),
                annotations: vec![],
            },
            ObjectTypeField {
                name: "y".to_string(),
                optional: false,
                type_annotation: TypeAnnotation::Basic("int".to_string()),
                annotations: vec![],
            },
        ]);
        let obj_z = TypeAnnotation::Object(vec![ObjectTypeField {
            name: "z".to_string(),
            optional: false,
            type_annotation: TypeAnnotation::Basic("int".to_string()),
            annotations: vec![],
        }]);

        let mut constraints = vec![(
            Type::Concrete(TypeAnnotation::Intersection(vec![
                obj_xy.clone(),
                obj_z.clone(),
            ])),
            Type::Concrete(TypeAnnotation::Intersection(vec![obj_z, obj_xy])),
        )];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    // ===== Sprint 2: ImplementsTrait constraint tests =====

    #[test]
    fn test_implements_trait_satisfied() {
        let mut solver = ConstraintSolver::new();
        let mut impls = std::collections::HashSet::new();
        impls.insert("Comparable::number".to_string());
        solver.set_trait_impls(impls);

        let mut tvgen = TypeVarGen::new();
        let bound_var = fresh_var(&mut tvgen);
        let mut constraints = vec![(
            Type::Concrete(TypeAnnotation::Basic("number".to_string())),
            Type::Constrained {
                var: bound_var,
                constraint: Box::new(TypeConstraint::ImplementsTrait {
                    trait_name: "Comparable".to_string(),
                }),
            },
        )];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    #[test]
    fn test_implements_trait_violated() {
        let mut solver = ConstraintSolver::new();
        // No trait impls registered — string doesn't implement Comparable
        let mut tvgen = TypeVarGen::new();
        let bound_var = fresh_var(&mut tvgen);
        let mut constraints = vec![(
            Type::Concrete(TypeAnnotation::Basic("string".to_string())),
            Type::Constrained {
                var: bound_var,
                constraint: Box::new(TypeConstraint::ImplementsTrait {
                    trait_name: "Comparable".to_string(),
                }),
            },
        )];
        let result = solver.solve(&mut constraints);
        assert!(result.is_err());
        match result.unwrap_err() {
            TypeError::TraitBoundViolation {
                type_name,
                trait_name,
            } => {
                assert_eq!(type_name, "string");
                assert_eq!(trait_name, "Comparable");
            }
            other => panic!("Expected TraitBoundViolation, got: {:?}", other),
        }
    }

    #[test]
    fn test_implements_trait_via_variable_resolution() {
        let mut solver = ConstraintSolver::new();
        let mut impls = std::collections::HashSet::new();
        impls.insert("Sortable::number".to_string());
        solver.set_trait_impls(impls);

        let mut tvgen = TypeVarGen::new();
        let type_var = fresh_var(&mut tvgen);
        let bound_var = fresh_var(&mut tvgen);

        let mut constraints = vec![
            // T: Sortable
            (
                Type::Variable(type_var.clone()),
                Type::Constrained {
                    var: bound_var,
                    constraint: Box::new(TypeConstraint::ImplementsTrait {
                        trait_name: "Sortable".to_string(),
                    }),
                },
            ),
            // T = number
            (
                Type::Variable(type_var),
                Type::Concrete(TypeAnnotation::Basic("number".to_string())),
            ),
        ];
        assert!(
            solver.solve(&mut constraints).is_ok(),
            "T resolved to number which implements Sortable"
        );
    }
}