Skip to main content

shape_runtime/type_system/
constraints.rs

1//! Type Constraint Solver
2//!
3//! Solves type constraints generated during type inference to determine
4//! concrete types for type variables. The solver operates in three phases:
5//!
6//! ## Phase 1: Eager unification
7//!
8//! Each constraint `(T1, T2)` is attempted immediately via `solve_constraint`.
9//! Simple bindings (variable-to-concrete, variable-to-variable) succeed here.
10//! Constraints that fail (e.g. because a variable is not yet resolved) are
11//! deferred to the next phase.
12//!
13//! ## Phase 2: Fixed-point iteration on deferred constraints
14//!
15//! Deferred constraints are retried in a loop. Each successful resolution may
16//! unlock further deferred constraints by refining substitutions. The loop
17//! terminates when a full pass makes no progress. Any constraints still
18//! unsolved after the fixed-point are reported as `UnsolvedConstraints`.
19//!
20//! ## Phase 3: Bound application
21//!
22//! After all equality constraints are resolved, `apply_bounds` validates
23//! type variable bounds (`Numeric`, `Comparable`, `Iterable`, `HasField`,
24//! `HasMethod`, `ImplementsTrait`). `HasField` constraints additionally
25//! perform backward propagation: when a structural object field is found,
26//! the field's result type variable is bound to the actual field type.
27//!
28//! The solver delegates low-level variable binding and substitution to the
29//! `Unifier` (Robinson's algorithm with path compression).
30
31use super::checking::MethodTable;
32use super::unification::Unifier;
33use super::*;
34use shape_ast::ast::{ObjectTypeField, TypeAnnotation};
35use std::collections::{HashMap, HashSet};
36
37/// Check if a Type::Generic base is "Array" or "Vec".
38fn is_array_or_vec_base(base: &Type) -> bool {
39    match base {
40        Type::Concrete(TypeAnnotation::Reference(name)) => name == "Array" || name == "Vec",
41        Type::Concrete(TypeAnnotation::Basic(name)) => name == "Array" || name == "Vec",
42        _ => false,
43    }
44}
45
46pub struct ConstraintSolver {
47    /// Type unifier
48    unifier: Unifier,
49    /// Deferred constraints that couldn't be solved immediately.
50    /// These are handled in solve() via multiple passes.
51    _deferred: Vec<(Type, Type)>,
52    /// Type variable bounds
53    bounds: HashMap<TypeVar, TypeConstraint>,
54    /// Method table for HasMethod constraint enforcement
55    method_table: Option<MethodTable>,
56    /// Trait implementation registry: set of "TraitName::TypeName" keys
57    trait_impls: HashSet<String>,
58}
59
60impl Default for ConstraintSolver {
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66impl ConstraintSolver {
67    pub fn new() -> Self {
68        ConstraintSolver {
69            unifier: Unifier::new(),
70            _deferred: Vec::new(),
71            bounds: HashMap::new(),
72            method_table: None,
73            trait_impls: HashSet::new(),
74        }
75    }
76
77    /// Attach a method table for HasMethod constraint enforcement.
78    /// When set, HasMethod constraints are validated against this table
79    /// instead of being accepted unconditionally.
80    pub fn set_method_table(&mut self, table: MethodTable) {
81        self.method_table = Some(table);
82    }
83
84    /// Register trait implementations for ImplementsTrait constraint enforcement.
85    /// Each entry is a "TraitName::TypeName" key indicating that TypeName implements TraitName.
86    pub fn set_trait_impls(&mut self, impls: HashSet<String>) {
87        self.trait_impls = impls;
88    }
89
90    /// Solve all type constraints
91    pub fn solve(&mut self, constraints: &mut Vec<(Type, Type)>) -> TypeResult<()> {
92        // First pass: solve simple unification constraints
93        let mut unsolved = Vec::new();
94
95        for (t1, t2) in constraints.drain(..) {
96            if self.solve_constraint(t1.clone(), t2.clone()).is_err() {
97                // If we can't solve it now, defer it
98                unsolved.push((t1, t2));
99            }
100        }
101
102        // Second pass: try deferred constraints
103        let mut made_progress = true;
104        while made_progress && !unsolved.is_empty() {
105            made_progress = false;
106            let mut still_unsolved = Vec::new();
107
108            for (t1, t2) in unsolved.drain(..) {
109                if self.solve_constraint(t1.clone(), t2.clone()).is_err() {
110                    still_unsolved.push((t1, t2));
111                } else {
112                    made_progress = true;
113                }
114            }
115
116            unsolved = still_unsolved;
117        }
118
119        // Check if any constraints remain unsolved
120        if !unsolved.is_empty() {
121            return Err(TypeError::UnsolvedConstraints(unsolved));
122        }
123
124        // Apply bounds to type variables
125        self.apply_bounds()?;
126
127        Ok(())
128    }
129
    /// Solve a single equality constraint `t1 ~ t2`.
    ///
    /// Both sides are first rewritten with the unifier's current substitutions,
    /// then matched structurally, in arm order:
    /// - identical variables succeed trivially;
    /// - a `Constrained` side records its bound in `self.bounds` (replacing any
    ///   bound already recorded for that variable) and is then unified as a
    ///   plain variable;
    /// - a variable side is bound to the other side after an occurs check;
    /// - concrete annotations go through `unify_annotations`, with an extra
    ///   directional numeric-widening fallback from `t1` to `t2`;
    /// - generics, functions, and the `Array<T>`/`Vec<T>` ~ `Array(T)` cross
    ///   form recurse component-wise.
    ///
    /// # Errors
    ///
    /// `InfiniteType` when the occurs check fails, `ArityMismatch` for generic
    /// or function arity differences, and `TypeMismatch` for any other
    /// incompatible pair. Bindings made by earlier recursive calls are not
    /// rolled back when a later component fails.
    fn solve_constraint(&mut self, t1: Type, t2: Type) -> TypeResult<()> {
        // Apply current substitutions before matching to avoid overwriting
        // existing bindings (e.g., T17=string overwritten by T17=T19 during
        // Function param/return pairwise unification).
        let t1 = self.unifier.apply_substitutions(&t1);
        let t2 = self.unifier.apply_substitutions(&t2);

        match (&t1, &t2) {
            // Variable constraints
            (Type::Variable(v1), Type::Variable(v2)) if v1 == v2 => Ok(()),

            // Constrained type variables — must be matched BEFORE the general
            // Variable arm, otherwise (Variable, Constrained) pairs are caught
            // by the Variable arm and the bound is never recorded.
            (Type::Constrained { var, constraint }, ty)
            | (ty, Type::Constrained { var, constraint }) => {
                // Record the constraint. HashMap::insert replaces any bound
                // previously recorded for `var`.
                self.bounds.insert(var.clone(), *constraint.clone());

                // Unify with the underlying type
                self.solve_constraint(Type::Variable(var.clone()), ty.clone())
            }

            (Type::Variable(var), ty) | (ty, Type::Variable(var)) => {
                // Occurs check: binding `var` to a type containing `var`
                // would create an infinite type.
                if self.occurs_in(var, ty) {
                    return Err(TypeError::InfiniteType(var.clone()));
                }

                self.unifier.bind(var.clone(), ty.clone());
                Ok(())
            }

            // Concrete type constraints
            (Type::Concrete(ann1), Type::Concrete(ann2)) => {
                // unify_annotations already widens Basic/Basic pairs; the
                // explicit fallback below also covers pairs it compares by
                // name only (e.g. Reference vs Reference).
                if self.unify_annotations(ann1, ann2)? {
                    Ok(())
                } else if Self::can_numeric_widen(ann1, ann2) {
                    // Implicit numeric promotion (int → number/float)
                    Ok(())
                } else {
                    Err(TypeError::TypeMismatch(
                        format!("{:?}", ann1),
                        format!("{:?}", ann2),
                    ))
                }
            }

            // Generic type constraints
            (Type::Generic { base: b1, args: a1 }, Type::Generic { base: b2, args: a2 }) => {
                // Bases are unified before arity is compared, so a base
                // mismatch is reported ahead of an arity mismatch.
                self.solve_constraint(*b1.clone(), *b2.clone())?;

                let is_result_base = |base: &Type| match base {
                    Type::Concrete(TypeAnnotation::Reference(name)) => name == "Result",
                    Type::Concrete(TypeAnnotation::Basic(name)) => name == "Result",
                    _ => false,
                };

                if a1.len() != a2.len() {
                    if is_result_base(&b1) && is_result_base(&b2) {
                        match (a1.len(), a2.len()) {
                            // `Result<T>` is error-agnostic shorthand and should unify
                            // with `Result<T, E>` by constraining only the success type.
                            (1, 2) | (2, 1) => {
                                self.solve_constraint(a1[0].clone(), a2[0].clone())?;
                                return Ok(());
                            }
                            _ => return Err(TypeError::ArityMismatch(a1.len(), a2.len())),
                        }
                    } else {
                        return Err(TypeError::ArityMismatch(a1.len(), a2.len()));
                    }
                }

                for (arg1, arg2) in a1.iter().zip(a2.iter()) {
                    self.solve_constraint(arg1.clone(), arg2.clone())?;
                }

                Ok(())
            }

            // Function ~ Function: pairwise unify params + returns
            (
                Type::Function {
                    params: p1,
                    returns: r1,
                },
                Type::Function {
                    params: p2,
                    returns: r2,
                },
            ) => {
                if p1.len() != p2.len() {
                    return Err(TypeError::ArityMismatch(p1.len(), p2.len()));
                }
                for (param1, param2) in p1.iter().zip(p2.iter()) {
                    // Parameter compatibility is checked from observed/actual to
                    // declared/expected shape so directional numeric widening
                    // (e.g. int -> number) remains valid in call constraints.
                    self.solve_constraint(param2.clone(), param1.clone())?;
                }
                self.solve_constraint(*r1.clone(), *r2.clone())
            }

            // Cross-compatibility: Type::Function ~ Concrete(TypeAnnotation::Function)
            // NOTE(review): unlike the Function ~ Function arm, parameters here
            // are always solved as (Type::Function param, annotation param),
            // whichever side of the constraint each function was on. Confirm
            // this orientation is intended for numeric widening.
            (
                Type::Function {
                    params: fp,
                    returns: fr,
                },
                Type::Concrete(TypeAnnotation::Function {
                    params: cp,
                    returns: cr,
                }),
            )
            | (
                Type::Concrete(TypeAnnotation::Function {
                    params: cp,
                    returns: cr,
                }),
                Type::Function {
                    params: fp,
                    returns: fr,
                },
            ) => {
                if fp.len() != cp.len() {
                    return Err(TypeError::ArityMismatch(fp.len(), cp.len()));
                }
                for (f_param, c_param) in fp.iter().zip(cp.iter()) {
                    self.solve_constraint(
                        f_param.clone(),
                        Type::Concrete(c_param.type_annotation.clone()),
                    )?;
                }
                self.solve_constraint(*fr.clone(), Type::Concrete(*cr.clone()))
            }

            // Array<T> (Type::Generic with base "Array" or "Vec") ~ Concrete(Array(T))
            (Type::Generic { base, args }, Type::Concrete(TypeAnnotation::Array(elem)))
            | (Type::Concrete(TypeAnnotation::Array(elem)), Type::Generic { base, args })
                if args.len() == 1 && is_array_or_vec_base(base) =>
            {
                self.solve_constraint(args[0].clone(), Type::Concrete((**elem).clone()))
            }

            // Anything else is an outright mismatch (reported on the
            // substituted types).
            _ => Err(TypeError::TypeMismatch(
                format!("{:?}", t1),
                format!("{:?}", t2),
            )),
        }
    }
282
283    /// Check if a type variable occurs in a type (occurs check)
284    fn occurs_in(&self, var: &TypeVar, ty: &Type) -> bool {
285        match ty {
286            Type::Variable(v) => v == var,
287            Type::Generic { base, args } => {
288                self.occurs_in(var, base) || args.iter().any(|arg| self.occurs_in(var, arg))
289            }
290            Type::Constrained { var: v, .. } => v == var,
291            Type::Function { params, returns } => {
292                params.iter().any(|p| self.occurs_in(var, p)) || self.occurs_in(var, returns)
293            }
294            Type::Concrete(_) => false,
295        }
296    }
297
298    /// Check if a numeric type can widen to another (directional).
299    ///
300    /// Integer-family types (`int`, `i16`, `u32`, `byte`, ...) can widen to
301    /// number-family types (`number`, `f32`, `f64`, ...).
302    /// `number → int` does NOT widen (lossy). `decimal → number` does NOT widen
303    /// (different precision semantics).
304    fn can_numeric_widen(from: &TypeAnnotation, to: &TypeAnnotation) -> bool {
305        let from_name = match from {
306            TypeAnnotation::Basic(name) => Some(name.as_str()),
307            TypeAnnotation::Reference(name) => Some(name.as_str()),
308            _ => None,
309        };
310        let to_name = match to {
311            TypeAnnotation::Basic(name) => Some(name.as_str()),
312            TypeAnnotation::Reference(name) => Some(name.as_str()),
313            _ => None,
314        };
315
316        match (from_name, to_name) {
317            (Some(f), Some(t)) => {
318                BuiltinTypes::is_integer_type_name(f) && BuiltinTypes::is_number_type_name(t)
319            }
320            _ => false,
321        }
322    }
323
324    /// Unify two type annotations
325    fn unify_annotations(&self, ann1: &TypeAnnotation, ann2: &TypeAnnotation) -> TypeResult<bool> {
326        match (ann1, ann2) {
327            // Basic types
328            (TypeAnnotation::Basic(_), TypeAnnotation::Basic(_)) => {
329                Ok(ann1 == ann2 || Self::can_numeric_widen(ann1, ann2))
330            }
331            (TypeAnnotation::Reference(n1), TypeAnnotation::Reference(n2)) => Ok(n1 == n2),
332            (TypeAnnotation::Basic(_), TypeAnnotation::Reference(_))
333            | (TypeAnnotation::Reference(_), TypeAnnotation::Basic(_)) => {
334                Ok(ann1 == ann2 || Self::can_numeric_widen(ann1, ann2))
335            }
336
337            // Array types
338            (TypeAnnotation::Array(e1), TypeAnnotation::Array(e2)) => {
339                self.unify_annotations(e1, e2)
340            }
341
342            // Tuple types
343            (TypeAnnotation::Tuple(t1), TypeAnnotation::Tuple(t2)) => {
344                if t1.len() != t2.len() {
345                    return Ok(false);
346                }
347
348                for (elem1, elem2) in t1.iter().zip(t2.iter()) {
349                    if !self.unify_annotations(elem1, elem2)? {
350                        return Ok(false);
351                    }
352                }
353
354                Ok(true)
355            }
356
357            // Structural object types
358            (TypeAnnotation::Object(f1), TypeAnnotation::Object(f2)) => {
359                self.object_fields_compatible(f1, f2)
360            }
361
362            // Function types
363            (
364                TypeAnnotation::Function {
365                    params: p1,
366                    returns: r1,
367                },
368                TypeAnnotation::Function {
369                    params: p2,
370                    returns: r2,
371                },
372            ) => {
373                if p1.len() != p2.len() {
374                    return Ok(false);
375                }
376
377                for (param1, param2) in p1.iter().zip(p2.iter()) {
378                    if !self.unify_annotations(&param1.type_annotation, &param2.type_annotation)? {
379                        return Ok(false);
380                    }
381                }
382
383                self.unify_annotations(r1, r2)
384            }
385
386            // Union types
387            // A | B unifies with C | D if each type in one union can unify with at least one type in the other
388            (TypeAnnotation::Union(u1), TypeAnnotation::Union(u2)) => {
389                // Check that every type in u1 can unify with at least one type in u2
390                for t1 in u1 {
391                    let mut found_match = false;
392                    for t2 in u2 {
393                        if self.unify_annotations(t1, t2)? {
394                            found_match = true;
395                            break;
396                        }
397                    }
398                    if !found_match {
399                        return Ok(false);
400                    }
401                }
402                // Check that every type in u2 can unify with at least one type in u1
403                for t2 in u2 {
404                    let mut found_match = false;
405                    for t1 in u1 {
406                        if self.unify_annotations(t1, t2)? {
407                            found_match = true;
408                            break;
409                        }
410                    }
411                    if !found_match {
412                        return Ok(false);
413                    }
414                }
415                Ok(true)
416            }
417
418            // Union with non-union: A | B unifies with C if either A or B unifies with C
419            (TypeAnnotation::Union(union_types), other)
420            | (other, TypeAnnotation::Union(union_types)) => {
421                for union_type in union_types {
422                    if self.unify_annotations(union_type, other)? {
423                        return Ok(true);
424                    }
425                }
426                Ok(false)
427            }
428
429            // Intersection types (order-independent)
430            (TypeAnnotation::Intersection(i1), TypeAnnotation::Intersection(i2)) => {
431                self.unify_annotation_sets(i1, i2)
432            }
433
434            // Void, Null, Undefined
435            (TypeAnnotation::Void, TypeAnnotation::Void) => Ok(true),
436            (TypeAnnotation::Null, TypeAnnotation::Null) => Ok(true),
437            (TypeAnnotation::Undefined, TypeAnnotation::Undefined) => Ok(true),
438
439            // Trait object types: dyn Trait1 + Trait2
440            // Two trait objects unify if they have the same set of traits
441            (TypeAnnotation::Dyn(traits1), TypeAnnotation::Dyn(traits2)) => {
442                Ok(traits1.len() == traits2.len() && traits1.iter().all(|t| traits2.contains(t)))
443            }
444
445            // Array<T> (Generic) is equivalent to Vec<T> (Array)
446            (TypeAnnotation::Generic { name, args }, TypeAnnotation::Array(elem))
447            | (TypeAnnotation::Array(elem), TypeAnnotation::Generic { name, args })
448                if name == "Array" && args.len() == 1 =>
449            {
450                self.unify_annotations(&args[0], elem)
451            }
452
453            // Different types don't unify
454            _ => Ok(false),
455        }
456    }
457
458    fn object_fields_compatible(
459        &self,
460        left: &[ObjectTypeField],
461        right: &[ObjectTypeField],
462    ) -> TypeResult<bool> {
463        for left_field in left {
464            let Some(right_field) = right.iter().find(|f| f.name == left_field.name) else {
465                return Ok(false);
466            };
467            if left_field.optional != right_field.optional {
468                return Ok(false);
469            }
470            if !self.unify_annotations(&left_field.type_annotation, &right_field.type_annotation)? {
471                return Ok(false);
472            }
473        }
474        if left.len() != right.len() {
475            return Ok(false);
476        }
477        Ok(true)
478    }
479
480    fn unify_annotation_sets(
481        &self,
482        left: &[TypeAnnotation],
483        right: &[TypeAnnotation],
484    ) -> TypeResult<bool> {
485        if left.len() != right.len() {
486            return Ok(false);
487        }
488
489        let mut matched = vec![false; right.len()];
490        for left_ann in left {
491            let mut found = false;
492            for (idx, right_ann) in right.iter().enumerate() {
493                if matched[idx] {
494                    continue;
495                }
496                if self.unify_annotations(left_ann, right_ann)? {
497                    matched[idx] = true;
498                    found = true;
499                    break;
500                }
501            }
502            if !found {
503                return Ok(false);
504            }
505        }
506
507        Ok(true)
508    }
509
510    /// Apply type variable bounds, propagating resolved field types back to type variables.
511    ///
512    /// When a `HasField` constraint is satisfied and the expected field type was a
513    /// type variable, this binds that variable to the actual field type. This enables
514    /// backward propagation: `let f = |obj| obj.x; f({x: 42})` resolves `obj.x` to `int`.
515    fn apply_bounds(&mut self) -> TypeResult<()> {
516        let mut new_bindings: Vec<(TypeVar, Type)> = Vec::new();
517
518        for (var, constraint) in &self.bounds {
519            // Use apply_substitutions to follow the full variable chain
520            // (lookup only returns the direct binding, not the resolved type).
521            let resolved = self
522                .unifier
523                .apply_substitutions(&Type::Variable(var.clone()));
524
525            if let Type::Variable(_) = &resolved {
526                // Still unresolved — skip for now
527                continue;
528            }
529
530            self.check_constraint(&resolved, constraint)?;
531
532            // Backward propagation: when a HasField constraint is satisfied,
533            // bind the result type variable to the actual field type.
534            if let TypeConstraint::HasField(field, expected_field_type) = constraint {
535                if let Type::Variable(field_var) = expected_field_type.as_ref() {
536                    // Also check if the field var is already resolved
537                    let field_resolved = self
538                        .unifier
539                        .apply_substitutions(&Type::Variable(field_var.clone()));
540                    if let Type::Variable(_) = &field_resolved {
541                        // Field var still unresolved — try to bind it
542                        if let Type::Concrete(TypeAnnotation::Object(fields)) = &resolved {
543                            if let Some(found_field) = fields.iter().find(|f| f.name == *field) {
544                                new_bindings.push((
545                                    field_var.clone(),
546                                    Type::Concrete(found_field.type_annotation.clone()),
547                                ));
548                            }
549                        }
550                    }
551                }
552            }
553        }
554
555        // Apply collected bindings
556        for (var, ty) in new_bindings {
557            self.unifier.bind(var, ty);
558        }
559
560        Ok(())
561    }
562
563    /// Check if a type satisfies a constraint
564    fn check_constraint(&self, ty: &Type, constraint: &TypeConstraint) -> TypeResult<()> {
565        match constraint {
566            TypeConstraint::Comparable => match ty {
567                Type::Concrete(TypeAnnotation::Basic(name))
568                    if BuiltinTypes::is_numeric_type_name(name)
569                        || name == "string"
570                        || name == "bool" =>
571                {
572                    Ok(())
573                }
574                _ => Err(TypeError::ConstraintViolation(format!(
575                    "{:?} is not comparable",
576                    ty
577                ))),
578            },
579
580            TypeConstraint::Iterable => match ty {
581                Type::Concrete(TypeAnnotation::Array(_)) => Ok(()),
582                Type::Concrete(TypeAnnotation::Basic(name))
583                    if name == "string" || name == "rows" =>
584                {
585                    Ok(())
586                }
587                _ => Err(TypeError::ConstraintViolation(format!(
588                    "{:?} is not iterable",
589                    ty
590                ))),
591            },
592
593            TypeConstraint::HasField(field, expected_field_type) => {
594                match ty {
595                    Type::Concrete(TypeAnnotation::Object(fields)) => {
596                        match fields.iter().find(|f| f.name == *field) {
597                            Some(found_field) => {
598                                // Check that field type matches expected type
599                                if let Some(expected_ann) = expected_field_type.to_annotation() {
600                                    if self.unify_annotations(
601                                        &found_field.type_annotation,
602                                        &expected_ann,
603                                    )? {
604                                        Ok(())
605                                    } else {
606                                        Err(TypeError::ConstraintViolation(format!(
607                                            "field '{}' has type {:?}, expected {:?}",
608                                            field, found_field.type_annotation, expected_ann
609                                        )))
610                                    }
611                                } else {
612                                    // Expected type is a type variable, accept any field type
613                                    Ok(())
614                                }
615                            }
616                            None => Err(TypeError::ConstraintViolation(format!(
617                                "{:?} does not have field '{}'",
618                                ty, field
619                            ))),
620                        }
621                    }
622                    Type::Concrete(TypeAnnotation::Basic(_name)) => {
623                        // For named types, we assume property access was validated
624                        // during inference using the schema registry. If a HasField
625                        // constraint reaches here, it means the type wasn't a known
626                        // schema type during inference, so we accept it tentatively.
627                        // Runtime will do the final validation.
628                        //
629                        // Note: Previously this hardcoded "row" with OHLCV fields.
630                        // Now schema validation happens in TypeInferenceEngine::infer_property_access.
631                        Ok(())
632                    }
633                    _ => Err(TypeError::ConstraintViolation(format!(
634                        "{:?} cannot have fields",
635                        ty
636                    ))),
637                }
638            }
639
640            TypeConstraint::Callable {
641                params: expected_params,
642                returns: expected_returns,
643            } => {
644                match ty {
645                    Type::Concrete(TypeAnnotation::Function {
646                        params: actual_params,
647                        returns: actual_returns,
648                    }) => {
649                        // Check parameter count matches
650                        if expected_params.len() != actual_params.len() {
651                            return Err(TypeError::ConstraintViolation(format!(
652                                "function expects {} parameters, got {}",
653                                expected_params.len(),
654                                actual_params.len()
655                            )));
656                        }
657
658                        // Check each parameter type (contravariant: expected <: actual)
659                        for (expected, actual) in expected_params.iter().zip(actual_params.iter()) {
660                            if let Some(expected_ann) = expected.to_annotation() {
661                                if !self
662                                    .unify_annotations(&expected_ann, &actual.type_annotation)?
663                                {
664                                    return Err(TypeError::ConstraintViolation(format!(
665                                        "parameter type mismatch: expected {:?}, got {:?}",
666                                        expected_ann, actual.type_annotation
667                                    )));
668                                }
669                            }
670                        }
671
672                        // Check return type (covariant: actual <: expected)
673                        if let Some(expected_ret_ann) = expected_returns.to_annotation() {
674                            if !self.unify_annotations(actual_returns, &expected_ret_ann)? {
675                                return Err(TypeError::ConstraintViolation(format!(
676                                    "return type mismatch: expected {:?}, got {:?}",
677                                    expected_ret_ann, actual_returns
678                                )));
679                            }
680                        }
681
682                        Ok(())
683                    }
684                    Type::Function {
685                        params: actual_params,
686                        returns: actual_returns,
687                    } => {
688                        if expected_params.len() != actual_params.len() {
689                            return Err(TypeError::ConstraintViolation(format!(
690                                "function expects {} parameters, got {}",
691                                expected_params.len(),
692                                actual_params.len()
693                            )));
694                        }
695                        // Type::Function params are Type, not FunctionParam — compare directly
696                        for (expected, actual) in expected_params.iter().zip(actual_params.iter()) {
697                            if let (Some(e_ann), Some(a_ann)) =
698                                (expected.to_annotation(), actual.to_annotation())
699                            {
700                                if !self.unify_annotations(&e_ann, &a_ann)? {
701                                    return Err(TypeError::ConstraintViolation(format!(
702                                        "parameter type mismatch: expected {:?}, got {:?}",
703                                        e_ann, a_ann
704                                    )));
705                                }
706                            }
707                        }
708                        if let (Some(e_ret), Some(a_ret)) = (
709                            expected_returns.to_annotation(),
710                            actual_returns.to_annotation(),
711                        ) {
712                            if !self.unify_annotations(&a_ret, &e_ret)? {
713                                return Err(TypeError::ConstraintViolation(format!(
714                                    "return type mismatch: expected {:?}, got {:?}",
715                                    e_ret, a_ret
716                                )));
717                            }
718                        }
719                        Ok(())
720                    }
721                    _ => Err(TypeError::ConstraintViolation(format!(
722                        "{:?} is not callable",
723                        ty
724                    ))),
725                }
726            }
727
728            TypeConstraint::OneOf(options) => {
729                for option in options {
730                    // If type matches any option, constraint is satisfied
731                    if let Type::Concrete(ann) = option {
732                        if let Type::Concrete(ty_ann) = ty {
733                            if self.unify_annotations(ann, ty_ann).unwrap_or(false) {
734                                return Ok(());
735                            }
736                        }
737                    }
738                }
739
740                Err(TypeError::ConstraintViolation(format!(
741                    "{:?} does not match any of {:?}",
742                    ty, options
743                )))
744            }
745
746            TypeConstraint::Extends(base) => {
747                // Implement subtyping check
748                self.is_subtype(ty, base)
749            }
750
751            TypeConstraint::ImplementsTrait { trait_name } => {
752                match ty {
753                    Type::Variable(_) => {
754                        // Type variable not yet resolved — this is a compile error
755                        // (no deferring per Sprint 2 spec)
756                        Err(TypeError::TraitBoundViolation {
757                            type_name: format!("{:?}", ty),
758                            trait_name: trait_name.clone(),
759                        })
760                    }
761                    Type::Concrete(ann) => {
762                        let type_name = match ann {
763                            TypeAnnotation::Basic(n) => n.clone(),
764                            TypeAnnotation::Reference(n) => n.to_string(),
765                            _ => format!("{:?}", ann),
766                        };
767                        if self.has_trait_impl(trait_name, &type_name) {
768                            Ok(())
769                        } else {
770                            Err(TypeError::TraitBoundViolation {
771                                type_name,
772                                trait_name: trait_name.clone(),
773                            })
774                        }
775                    }
776                    Type::Generic { base, .. } => {
777                        let type_name = match base.as_ref() {
778                            Type::Concrete(TypeAnnotation::Reference(n)) => n.to_string(),
779                            Type::Concrete(TypeAnnotation::Basic(n)) => n.clone(),
780                            _ => format!("{:?}", base),
781                        };
782                        if self.has_trait_impl(trait_name, &type_name) {
783                            Ok(())
784                        } else {
785                            Err(TypeError::TraitBoundViolation {
786                                type_name,
787                                trait_name: trait_name.clone(),
788                            })
789                        }
790                    }
791                    _ => Err(TypeError::TraitBoundViolation {
792                        type_name: format!("{:?}", ty),
793                        trait_name: trait_name.clone(),
794                    }),
795                }
796            }
797
798            TypeConstraint::HasMethod {
799                method_name,
800                arg_types: _,
801                return_type: _,
802            } => {
803                // If we have a method table, enforce the constraint
804                if let Some(method_table) = &self.method_table {
805                    match ty {
806                        Type::Variable(_) => Ok(()), // Unresolved type var, defer
807                        Type::Concrete(ann) => {
808                            let type_name = match ann {
809                                TypeAnnotation::Basic(n) => n.clone(),
810                                TypeAnnotation::Reference(n) => n.to_string(),
811                                TypeAnnotation::Array(_) => "Vec".to_string(),
812                                _ => return Ok(()), // Complex types: accept
813                            };
814                            if method_table.lookup(ty, method_name).is_some() {
815                                Ok(())
816                            } else {
817                                Err(TypeError::MethodNotFound {
818                                    type_name,
819                                    method_name: method_name.clone(),
820                                })
821                            }
822                        }
823                        Type::Generic { base, .. } => {
824                            if method_table.lookup(ty, method_name).is_some() {
825                                Ok(())
826                            } else {
827                                let type_name =
828                                    if let Type::Concrete(TypeAnnotation::Reference(n)) =
829                                        base.as_ref()
830                                    {
831                                        n.to_string()
832                                    } else {
833                                        format!("{:?}", base)
834                                    };
835                                Err(TypeError::MethodNotFound {
836                                    type_name,
837                                    method_name: method_name.clone(),
838                                })
839                            }
840                        }
841                        _ => Ok(()), // Function, Constrained: accept
842                    }
843                } else {
844                    // No method table attached — accept all (backward compatible)
845                    Ok(())
846                }
847            }
848        }
849    }
850
851    /// Check if a type implements a trait, considering aliases and numeric widening.
852    ///
853    /// Handles three resolution strategies:
854    /// 1. Direct lookup: `"Numeric::int"` in the trait_impls set
855    /// 2. Canonical alias: `"Float"` → `"f64"`, `"byte"` → `"u8"` via runtime name table
856    /// 3. Script alias: `"i64"` → `"int"`, `"f64"` → `"number"` via script alias table
857    /// 4. Numeric widening: integer-family names can satisfy number/float/f64 impls
858    fn has_trait_impl(&self, trait_name: &str, type_name: &str) -> bool {
859        let key = format!("{}::{}", trait_name, type_name);
860        if self.trait_impls.contains(&key) {
861            return true;
862        }
863        // Try canonical runtime alias (e.g. "Float" -> "f64", "byte" -> "u8")
864        if let Some(canonical) = BuiltinTypes::canonical_numeric_runtime_name(type_name) {
865            let canon_key = format!("{}::{}", trait_name, canonical);
866            if self.trait_impls.contains(&canon_key) {
867                return true;
868            }
869        }
870        // Try script-facing alias (e.g. "i64" -> "int", "f64" -> "number")
871        if let Some(script_alias) = BuiltinTypes::canonical_script_alias(type_name) {
872            let alias_key = format!("{}::{}", trait_name, script_alias);
873            if self.trait_impls.contains(&alias_key) {
874                return true;
875            }
876        }
877        // Numeric widening: integer-family aliases can use number/float/f64 impls.
878        if BuiltinTypes::is_integer_type_name(type_name) {
879            for widen_to in &["number", "float", "f64"] {
880                let widen_key = format!("{}::{}", trait_name, widen_to);
881                if self.trait_impls.contains(&widen_key) {
882                    return true;
883                }
884            }
885        }
886        false
887    }
888
889    /// Check if ty is a subtype of base (ty <: base)
890    /// Subtyping rules:
891    /// - Same types are subtypes of each other
892    /// - Any is a supertype of everything
893    /// - Vec<A> <: Vec<B> if A <: B (covariant)
894    /// - Function<P1, R1> <: Function<P2, R2> if P2 <: P1 (contravariant params) and R1 <: R2 (covariant return)
895    fn is_subtype(&self, ty: &Type, base: &Type) -> TypeResult<()> {
896        match (ty, base) {
897            // Same types are subtypes
898            (t1, t2) if t1 == t2 => Ok(()),
899
900            // Type variables - if we can unify, it's compatible
901            (Type::Variable(_), _) | (_, Type::Variable(_)) => Ok(()),
902
903            // Array subtyping (covariant)
904            (
905                Type::Concrete(TypeAnnotation::Array(elem1)),
906                Type::Concrete(TypeAnnotation::Array(elem2)),
907            ) => {
908                let t1 = Type::Concrete(*elem1.clone());
909                let t2 = Type::Concrete(*elem2.clone());
910                self.is_subtype(&t1, &t2)
911            }
912
913            // Function subtyping (contravariant params, covariant return)
914            (
915                Type::Concrete(TypeAnnotation::Function {
916                    params: p1,
917                    returns: r1,
918                }),
919                Type::Concrete(TypeAnnotation::Function {
920                    params: p2,
921                    returns: r2,
922                }),
923            ) => {
924                // Check parameter count
925                if p1.len() != p2.len() {
926                    return Err(TypeError::ConstraintViolation(format!(
927                        "function parameter count mismatch: {} vs {}",
928                        p1.len(),
929                        p2.len()
930                    )));
931                }
932
933                // Contravariant: base params must be subtypes of ty params
934                for (param1, param2) in p1.iter().zip(p2.iter()) {
935                    let t1 = Type::Concrete(param2.type_annotation.clone());
936                    let t2 = Type::Concrete(param1.type_annotation.clone());
937                    self.is_subtype(&t1, &t2)?;
938                }
939
940                // Covariant: ty return must be subtype of base return
941                let ret1 = Type::Concrete(*r1.clone());
942                let ret2 = Type::Concrete(*r2.clone());
943                self.is_subtype(&ret1, &ret2)
944            }
945
946            // Optional subtyping: T <: Option<T>
947            (t, Type::Concrete(TypeAnnotation::Generic { name, args }))
948                if name == "Option" && args.len() == 1 =>
949            {
950                let inner = Type::Concrete(args[0].clone());
951                self.is_subtype(t, &inner)
952            }
953
954            // Type::Function subtyping (contravariant params, covariant return)
955            (
956                Type::Function {
957                    params: p1,
958                    returns: r1,
959                },
960                Type::Function {
961                    params: p2,
962                    returns: r2,
963                },
964            ) => {
965                if p1.len() != p2.len() {
966                    return Err(TypeError::ConstraintViolation(format!(
967                        "function parameter count mismatch: {} vs {}",
968                        p1.len(),
969                        p2.len()
970                    )));
971                }
972                // Contravariant params
973                for (param1, param2) in p1.iter().zip(p2.iter()) {
974                    self.is_subtype(param2, param1)?;
975                }
976                // Covariant return
977                self.is_subtype(r1, r2)
978            }
979
980            // Basic types - check if they unify
981            (Type::Concrete(ann1), Type::Concrete(ann2)) => {
982                if self.unify_annotations(ann1, ann2)? {
983                    Ok(())
984                } else {
985                    Err(TypeError::ConstraintViolation(format!(
986                        "{:?} is not a subtype of {:?}",
987                        ty, base
988                    )))
989                }
990            }
991
992            // Default: not a subtype
993            _ => Err(TypeError::ConstraintViolation(format!(
994                "{:?} is not a subtype of {:?}",
995                ty, base
996            ))),
997        }
998    }
999
1000    /// Get the unifier for applying substitutions
1001    pub fn unifier(&self) -> &Unifier {
1002        &self.unifier
1003    }
1004}
1005
#[cfg(test)]
mod tests {
    use super::*;
    use shape_ast::ast::ObjectTypeField;

    // ===== Helpers =====

    /// `Type::Concrete(TypeAnnotation::Basic(name))`.
    fn basic(name: &str) -> Type {
        Type::Concrete(TypeAnnotation::Basic(name.to_string()))
    }

    /// A required, un-annotated object field whose type is the basic type `type_name`.
    fn field(name: &str, type_name: &str) -> ObjectTypeField {
        ObjectTypeField {
            name: name.to_string(),
            optional: false,
            type_annotation: TypeAnnotation::Basic(type_name.to_string()),
            annotations: vec![],
        }
    }

    /// `constraint` attached to a fresh bound type variable.
    fn constrained(constraint: TypeConstraint) -> Type {
        Type::Constrained {
            var: TypeVar::fresh(),
            constraint: Box::new(constraint),
        }
    }

    /// An `ImplementsTrait { trait_name }` bound on a fresh variable.
    fn implements(trait_name: &str) -> Type {
        constrained(TypeConstraint::ImplementsTrait {
            trait_name: trait_name.to_string(),
        })
    }

    /// A `HasField(field_name, result)` bound on a fresh variable, where `result`
    /// is the variable that should receive the field's type.
    fn has_field(field_name: &str, result: &TypeVar) -> Type {
        constrained(TypeConstraint::HasField(
            field_name.to_string(),
            Box::new(Type::Variable(result.clone())),
        ))
    }

    /// `Numeric` impls for the builtin numeric types, keyed `"Numeric::<type>"`
    /// (mirrors what TypeEnvironment registers).
    fn numeric_trait_impls() -> std::collections::HashSet<String> {
        [
            "int", "number", "decimal", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64",
            "f32", "f64",
        ]
        .iter()
        .map(|ty| format!("Numeric::{}", ty))
        .collect()
    }

    // ===== HasField backward propagation =====

    #[test]
    fn test_hasfield_backward_propagation_binds_field_type() {
        // When a type variable carries a HasField bound and is resolved to a
        // concrete object type, the field's result variable must be bound to
        // the actual field type. This enables backward type propagation.
        let mut solver = ConstraintSolver::new();
        let obj_var = TypeVar::fresh();
        let field_result_var = TypeVar::fresh();

        let mut constraints = vec![
            // obj_var ~ Constrained { HasField("x", field_result_var) }
            (
                Type::Variable(obj_var.clone()),
                has_field("x", &field_result_var),
            ),
            // obj_var = {x: int}
            (
                Type::Variable(obj_var),
                Type::Concrete(TypeAnnotation::Object(vec![field("x", "int")])),
            ),
        ];

        solver.solve(&mut constraints).unwrap();

        // apply_bounds should have propagated `int` back to field_result_var.
        let resolved = solver
            .unifier()
            .apply_substitutions(&Type::Variable(field_result_var));
        assert_eq!(
            resolved,
            basic("int"),
            "Expected field_result_var to be resolved to int"
        );
    }

    #[test]
    fn test_hasfield_backward_propagation_multiple_fields() {
        // Several HasField bounds on the same object must all propagate.
        let mut solver = ConstraintSolver::new();
        let obj_var = TypeVar::fresh();
        let field_x_var = TypeVar::fresh();
        let field_y_var = TypeVar::fresh();

        let mut constraints = vec![
            (Type::Variable(obj_var.clone()), has_field("x", &field_x_var)),
            (Type::Variable(obj_var.clone()), has_field("y", &field_y_var)),
            // obj_var = {x: int, y: string}
            (
                Type::Variable(obj_var),
                Type::Concrete(TypeAnnotation::Object(vec![
                    field("x", "int"),
                    field("y", "string"),
                ])),
            ),
        ];

        solver.solve(&mut constraints).unwrap();

        let resolved_x = solver
            .unifier()
            .apply_substitutions(&Type::Variable(field_x_var));
        let resolved_y = solver
            .unifier()
            .apply_substitutions(&Type::Variable(field_y_var));
        assert_eq!(resolved_x, basic("int"), "Expected x to be int");
        assert_eq!(resolved_y, basic("string"), "Expected y to be string");
    }

    // ===== Fix 1: Numeric type preservation tests =====

    #[test]
    fn test_int_constrained_numeric_succeeds() {
        // Concrete(int) ~ Constrained(ImplementsTrait("Numeric")) should succeed.
        let mut solver = ConstraintSolver::new();
        solver.set_trait_impls(numeric_trait_impls());
        let mut constraints = vec![(basic("int"), implements("Numeric"))];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    #[test]
    fn test_numeric_widening_int_to_number() {
        // (int, number) should succeed via widening.
        let mut solver = ConstraintSolver::new();
        let mut constraints = vec![(basic("int"), basic("number"))];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    #[test]
    fn test_numeric_widening_width_aware_integer_to_float_family() {
        // A narrow integer (i16) should widen into the float family (f32).
        let mut solver = ConstraintSolver::new();
        let mut constraints = vec![(basic("i16"), basic("f32"))];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    #[test]
    fn test_no_widening_number_to_int() {
        // (number, int) should fail: narrowing is lossy.
        let mut solver = ConstraintSolver::new();
        let mut constraints = vec![(basic("number"), basic("int"))];
        assert!(solver.solve(&mut constraints).is_err());
    }

    #[test]
    fn test_decimal_constrained_numeric_succeeds() {
        let mut solver = ConstraintSolver::new();
        solver.set_trait_impls(numeric_trait_impls());
        let mut constraints = vec![(basic("decimal"), implements("Numeric"))];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    #[test]
    fn test_comparable_accepts_int() {
        // int should be Comparable.
        let mut solver = ConstraintSolver::new();
        let mut constraints = vec![(basic("int"), constrained(TypeConstraint::Comparable))];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    // ===== Fix 2: Type::Function tests =====

    #[test]
    fn test_function_type_preserves_variables() {
        // BuiltinTypes::function with Variable params should be Type::Function.
        let param = Type::fresh_var();
        let ret = Type::fresh_var();
        let func = BuiltinTypes::function(vec![param.clone()], ret.clone());
        match func {
            Type::Function { params, returns } => {
                assert_eq!(params.len(), 1);
                assert_eq!(params[0], param);
                assert_eq!(*returns, ret);
            }
            _ => panic!("Expected Type::Function, got {:?}", func),
        }
    }

    #[test]
    fn test_function_unification_binds_variables() {
        // (T1)->T2 ~ (number)->string should bind T1=number, T2=string.
        let mut solver = ConstraintSolver::new();
        let t1 = TypeVar::fresh();
        let t2 = TypeVar::fresh();

        let mut constraints = vec![(
            Type::Function {
                params: vec![Type::Variable(t1.clone())],
                returns: Box::new(Type::Variable(t2.clone())),
            },
            Type::Function {
                params: vec![BuiltinTypes::number()],
                returns: Box::new(BuiltinTypes::string()),
            },
        )];

        solver.solve(&mut constraints).unwrap();

        let resolved_t1 = solver.unifier().apply_substitutions(&Type::Variable(t1));
        let resolved_t2 = solver.unifier().apply_substitutions(&Type::Variable(t2));
        assert_eq!(resolved_t1, BuiltinTypes::number());
        assert_eq!(resolved_t2, BuiltinTypes::string());
    }

    #[test]
    fn test_function_cross_unification_with_concrete() {
        // Type::Function ~ Concrete(TypeAnnotation::Function) should unify.
        let mut solver = ConstraintSolver::new();
        let t1 = TypeVar::fresh();

        let concrete_func = Type::Concrete(TypeAnnotation::Function {
            params: vec![shape_ast::ast::FunctionParam {
                name: None,
                optional: false,
                type_annotation: TypeAnnotation::Basic("number".to_string()),
            }],
            returns: Box::new(TypeAnnotation::Basic("string".to_string())),
        });

        let mut constraints = vec![(
            Type::Function {
                params: vec![Type::Variable(t1.clone())],
                returns: Box::new(BuiltinTypes::string()),
            },
            concrete_func,
        )];

        solver.solve(&mut constraints).unwrap();

        let resolved = solver.unifier().apply_substitutions(&Type::Variable(t1));
        assert_eq!(resolved, BuiltinTypes::number());
    }

    #[test]
    fn test_object_annotations_unify_structurally() {
        let mut solver = ConstraintSolver::new();
        let point = || TypeAnnotation::Object(vec![field("x", "int"), field("y", "int")]);
        let mut constraints = vec![(Type::Concrete(point()), Type::Concrete(point()))];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    #[test]
    fn test_intersection_annotations_unify_order_independent() {
        let mut solver = ConstraintSolver::new();
        let obj_xy = TypeAnnotation::Object(vec![field("x", "int"), field("y", "int")]);
        let obj_z = TypeAnnotation::Object(vec![field("z", "int")]);

        let mut constraints = vec![(
            Type::Concrete(TypeAnnotation::Intersection(vec![
                obj_xy.clone(),
                obj_z.clone(),
            ])),
            Type::Concrete(TypeAnnotation::Intersection(vec![obj_z, obj_xy])),
        )];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    // ===== Sprint 2: ImplementsTrait constraint tests =====

    #[test]
    fn test_implements_trait_satisfied() {
        let mut solver = ConstraintSolver::new();
        let mut impls = std::collections::HashSet::new();
        impls.insert("Comparable::number".to_string());
        solver.set_trait_impls(impls);

        let mut constraints = vec![(basic("number"), implements("Comparable"))];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    #[test]
    fn test_implements_trait_violated() {
        // No trait impls registered, so string does not implement Comparable.
        let mut solver = ConstraintSolver::new();
        let mut constraints = vec![(basic("string"), implements("Comparable"))];
        let result = solver.solve(&mut constraints);
        assert!(result.is_err());
        match result.unwrap_err() {
            TypeError::TraitBoundViolation {
                type_name,
                trait_name,
            } => {
                assert_eq!(type_name, "string");
                assert_eq!(trait_name, "Comparable");
            }
            other => panic!("Expected TraitBoundViolation, got: {:?}", other),
        }
    }

    #[test]
    fn test_implements_trait_via_variable_resolution() {
        let mut solver = ConstraintSolver::new();
        let mut impls = std::collections::HashSet::new();
        impls.insert("Sortable::number".to_string());
        solver.set_trait_impls(impls);

        let type_var = TypeVar::fresh();
        let mut constraints = vec![
            // T: Sortable
            (Type::Variable(type_var.clone()), implements("Sortable")),
            // T = number
            (Type::Variable(type_var), basic("number")),
        ];
        assert!(
            solver.solve(&mut constraints).is_ok(),
            "T resolved to number which implements Sortable"
        );
    }
}