Skip to main content

shape_runtime/type_system/
constraints.rs

1//! Type Constraint Solver
2//!
3//! Solves type constraints generated during type inference to determine
4//! concrete types for type variables. The solver operates in three phases:
5//!
6//! ## Phase 1: Eager unification
7//!
8//! Each constraint `(T1, T2)` is attempted immediately via `solve_constraint`.
9//! Simple bindings (variable-to-concrete, variable-to-variable) succeed here.
10//! Constraints that fail (e.g. because a variable is not yet resolved) are
11//! deferred to the next phase.
12//!
13//! ## Phase 2: Fixed-point iteration on deferred constraints
14//!
15//! Deferred constraints are retried in a loop. Each successful resolution may
16//! unlock further deferred constraints by refining substitutions. The loop
17//! terminates when a full pass makes no progress. Any constraints still
18//! unsolved after the fixed-point are reported as `UnsolvedConstraints`.
19//!
20//! ## Phase 3: Bound application
21//!
22//! After all equality constraints are resolved, `apply_bounds` validates
23//! type variable bounds (`Numeric`, `Comparable`, `Iterable`, `HasField`,
24//! `HasMethod`, `ImplementsTrait`). `HasField` constraints additionally
25//! perform backward propagation: when a structural object field is found,
26//! the field's result type variable is bound to the actual field type.
27//!
28//! The solver delegates low-level variable binding and substitution to the
29//! `Unifier` (Robinson's algorithm with path compression).
30
31use super::checking::MethodTable;
32use super::unification::Unifier;
33use super::*;
34use shape_ast::ast::{ObjectTypeField, TypeAnnotation};
35use std::collections::{HashMap, HashSet};
36
37/// Check if a Type::Generic base is "Array" or "Vec".
38fn is_array_or_vec_base(base: &Type) -> bool {
39    match base {
40        Type::Concrete(TypeAnnotation::Reference(name)) => name == "Array" || name == "Vec",
41        Type::Concrete(TypeAnnotation::Basic(name)) => name == "Array" || name == "Vec",
42        _ => false,
43    }
44}
45
46pub struct ConstraintSolver {
47    /// Type unifier
48    unifier: Unifier,
49    /// Deferred constraints that couldn't be solved immediately.
50    /// These are handled in solve() via multiple passes.
51    _deferred: Vec<(Type, Type)>,
52    /// Type variable bounds
53    bounds: HashMap<TypeVar, TypeConstraint>,
54    /// Method table for HasMethod constraint enforcement
55    method_table: Option<MethodTable>,
56    /// Trait implementation registry: set of "TraitName::TypeName" keys
57    trait_impls: HashSet<String>,
58}
59
60impl Default for ConstraintSolver {
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66impl ConstraintSolver {
67    pub fn new() -> Self {
68        ConstraintSolver {
69            unifier: Unifier::new(),
70            _deferred: Vec::new(),
71            bounds: HashMap::new(),
72            method_table: None,
73            trait_impls: HashSet::new(),
74        }
75    }
76
77    /// Attach a method table for HasMethod constraint enforcement.
78    /// When set, HasMethod constraints are validated against this table
79    /// instead of being accepted unconditionally.
80    pub fn set_method_table(&mut self, table: MethodTable) {
81        self.method_table = Some(table);
82    }
83
84    /// Register trait implementations for ImplementsTrait constraint enforcement.
85    /// Each entry is a "TraitName::TypeName" key indicating that TypeName implements TraitName.
86    pub fn set_trait_impls(&mut self, impls: HashSet<String>) {
87        self.trait_impls = impls;
88    }
89
90    /// Solve all type constraints
91    pub fn solve(&mut self, constraints: &mut Vec<(Type, Type)>) -> TypeResult<()> {
92        // First pass: solve simple unification constraints
93        let mut unsolved = Vec::new();
94
95        for (t1, t2) in constraints.drain(..) {
96            if self.solve_constraint(t1.clone(), t2.clone()).is_err() {
97                // If we can't solve it now, defer it
98                unsolved.push((t1, t2));
99            }
100        }
101
102        // Second pass: try deferred constraints
103        let mut made_progress = true;
104        while made_progress && !unsolved.is_empty() {
105            made_progress = false;
106            let mut still_unsolved = Vec::new();
107
108            for (t1, t2) in unsolved.drain(..) {
109                if self.solve_constraint(t1.clone(), t2.clone()).is_err() {
110                    still_unsolved.push((t1, t2));
111                } else {
112                    made_progress = true;
113                }
114            }
115
116            unsolved = still_unsolved;
117        }
118
119        // Check if any constraints remain unsolved
120        if !unsolved.is_empty() {
121            return Err(TypeError::UnsolvedConstraints(unsolved));
122        }
123
124        // Apply bounds to type variables
125        self.apply_bounds()?;
126
127        Ok(())
128    }
129
130    /// Solve a single constraint
131    fn solve_constraint(&mut self, t1: Type, t2: Type) -> TypeResult<()> {
132        // Apply current substitutions before matching to avoid overwriting
133        // existing bindings (e.g., T17=string overwritten by T17=T19 during
134        // Function param/return pairwise unification).
135        let t1 = self.unifier.apply_substitutions(&t1);
136        let t2 = self.unifier.apply_substitutions(&t2);
137
138        match (&t1, &t2) {
139            // Variable constraints
140            (Type::Variable(v1), Type::Variable(v2)) if v1 == v2 => Ok(()),
141
142            // Constrained type variables — must be matched BEFORE the general
143            // Variable arm, otherwise (Variable, Constrained) pairs are caught
144            // by the Variable arm and the bound is never recorded.
145            (Type::Constrained { var, constraint }, ty)
146            | (ty, Type::Constrained { var, constraint }) => {
147                // Record the constraint
148                self.bounds.insert(var.clone(), *constraint.clone());
149
150                // Unify with the underlying type
151                self.solve_constraint(Type::Variable(var.clone()), ty.clone())
152            }
153
154            (Type::Variable(var), ty) | (ty, Type::Variable(var)) => {
155                // Check occurs check
156                if self.occurs_in(var, ty) {
157                    return Err(TypeError::InfiniteType(var.clone()));
158                }
159
160                self.unifier.bind(var.clone(), ty.clone());
161                Ok(())
162            }
163
164            // Concrete type constraints
165            (Type::Concrete(ann1), Type::Concrete(ann2)) => {
166                if self.unify_annotations(ann1, ann2)? {
167                    Ok(())
168                } else if Self::can_numeric_widen(ann1, ann2) {
169                    // Implicit numeric promotion (int → number/float)
170                    Ok(())
171                } else {
172                    Err(TypeError::TypeMismatch(
173                        format!("{:?}", ann1),
174                        format!("{:?}", ann2),
175                    ))
176                }
177            }
178
179            // Generic type constraints
180            (Type::Generic { base: b1, args: a1 }, Type::Generic { base: b2, args: a2 }) => {
181                self.solve_constraint(*b1.clone(), *b2.clone())?;
182
183                let is_result_base = |base: &Type| match base {
184                    Type::Concrete(TypeAnnotation::Reference(name)) => name == "Result",
185                    Type::Concrete(TypeAnnotation::Basic(name)) => name == "Result",
186                    _ => false,
187                };
188
189                if a1.len() != a2.len() {
190                    if is_result_base(&b1) && is_result_base(&b2) {
191                        match (a1.len(), a2.len()) {
192                            // `Result<T>` is error-agnostic shorthand and should unify
193                            // with `Result<T, E>` by constraining only the success type.
194                            (1, 2) | (2, 1) => {
195                                self.solve_constraint(a1[0].clone(), a2[0].clone())?;
196                                return Ok(());
197                            }
198                            _ => return Err(TypeError::ArityMismatch(a1.len(), a2.len())),
199                        }
200                    } else {
201                        return Err(TypeError::ArityMismatch(a1.len(), a2.len()));
202                    }
203                }
204
205                for (arg1, arg2) in a1.iter().zip(a2.iter()) {
206                    self.solve_constraint(arg1.clone(), arg2.clone())?;
207                }
208
209                Ok(())
210            }
211
212            // Function ~ Function: pairwise unify params + returns
213            (
214                Type::Function {
215                    params: p1,
216                    returns: r1,
217                },
218                Type::Function {
219                    params: p2,
220                    returns: r2,
221                },
222            ) => {
223                if p1.len() != p2.len() {
224                    return Err(TypeError::ArityMismatch(p1.len(), p2.len()));
225                }
226                for (param1, param2) in p1.iter().zip(p2.iter()) {
227                    // Parameter compatibility is checked from observed/actual to
228                    // declared/expected shape so directional numeric widening
229                    // (e.g. int -> number) remains valid in call constraints.
230                    self.solve_constraint(param2.clone(), param1.clone())?;
231                }
232                self.solve_constraint(*r1.clone(), *r2.clone())
233            }
234
235            // Cross-compatibility: Type::Function ~ Concrete(TypeAnnotation::Function)
236            (
237                Type::Function {
238                    params: fp,
239                    returns: fr,
240                },
241                Type::Concrete(TypeAnnotation::Function {
242                    params: cp,
243                    returns: cr,
244                }),
245            )
246            | (
247                Type::Concrete(TypeAnnotation::Function {
248                    params: cp,
249                    returns: cr,
250                }),
251                Type::Function {
252                    params: fp,
253                    returns: fr,
254                },
255            ) => {
256                if fp.len() != cp.len() {
257                    return Err(TypeError::ArityMismatch(fp.len(), cp.len()));
258                }
259                for (f_param, c_param) in fp.iter().zip(cp.iter()) {
260                    self.solve_constraint(
261                        f_param.clone(),
262                        Type::Concrete(c_param.type_annotation.clone()),
263                    )?;
264                }
265                self.solve_constraint(*fr.clone(), Type::Concrete(*cr.clone()))
266            }
267
268            // Array<T> (Type::Generic with base "Array" or "Vec") ~ Concrete(Array(T))
269            (Type::Generic { base, args }, Type::Concrete(TypeAnnotation::Array(elem)))
270            | (Type::Concrete(TypeAnnotation::Array(elem)), Type::Generic { base, args })
271                if args.len() == 1 && is_array_or_vec_base(base) =>
272            {
273                self.solve_constraint(args[0].clone(), Type::Concrete((**elem).clone()))
274            }
275
276            _ => Err(TypeError::TypeMismatch(
277                format!("{:?}", t1),
278                format!("{:?}", t2),
279            )),
280        }
281    }
282
283    /// Check if a type variable occurs in a type (occurs check)
284    fn occurs_in(&self, var: &TypeVar, ty: &Type) -> bool {
285        match ty {
286            Type::Variable(v) => v == var,
287            Type::Generic { base, args } => {
288                self.occurs_in(var, base) || args.iter().any(|arg| self.occurs_in(var, arg))
289            }
290            Type::Constrained { var: v, .. } => v == var,
291            Type::Function { params, returns } => {
292                params.iter().any(|p| self.occurs_in(var, p)) || self.occurs_in(var, returns)
293            }
294            Type::Concrete(_) => false,
295        }
296    }
297
298    /// Check if a numeric type can widen to another (directional).
299    ///
300    /// Integer-family types (`int`, `i16`, `u32`, `byte`, ...) can widen to
301    /// number-family types (`number`, `f32`, `f64`, ...).
302    /// `number → int` does NOT widen (lossy). `decimal → number` does NOT widen
303    /// (different precision semantics).
304    fn can_numeric_widen(from: &TypeAnnotation, to: &TypeAnnotation) -> bool {
305        let from_name = match from {
306            TypeAnnotation::Basic(name) => Some(name.as_str()),
307            TypeAnnotation::Reference(name) => Some(name.as_str()),
308            _ => None,
309        };
310        let to_name = match to {
311            TypeAnnotation::Basic(name) => Some(name.as_str()),
312            TypeAnnotation::Reference(name) => Some(name.as_str()),
313            _ => None,
314        };
315
316        match (from_name, to_name) {
317            (Some(f), Some(t)) => {
318                BuiltinTypes::is_integer_type_name(f) && BuiltinTypes::is_number_type_name(t)
319            }
320            _ => false,
321        }
322    }
323
324    /// Unify two type annotations
325    fn unify_annotations(&self, ann1: &TypeAnnotation, ann2: &TypeAnnotation) -> TypeResult<bool> {
326        match (ann1, ann2) {
327            // Basic types
328            (TypeAnnotation::Basic(_), TypeAnnotation::Basic(_)) => {
329                Ok(ann1 == ann2 || Self::can_numeric_widen(ann1, ann2))
330            }
331            (TypeAnnotation::Reference(n1), TypeAnnotation::Reference(n2)) => Ok(n1 == n2),
332            (TypeAnnotation::Basic(_), TypeAnnotation::Reference(_))
333            | (TypeAnnotation::Reference(_), TypeAnnotation::Basic(_)) => {
334                Ok(ann1 == ann2 || Self::can_numeric_widen(ann1, ann2))
335            }
336
337            // Array types
338            (TypeAnnotation::Array(e1), TypeAnnotation::Array(e2)) => {
339                self.unify_annotations(e1, e2)
340            }
341
342            // Tuple types
343            (TypeAnnotation::Tuple(t1), TypeAnnotation::Tuple(t2)) => {
344                if t1.len() != t2.len() {
345                    return Ok(false);
346                }
347
348                for (elem1, elem2) in t1.iter().zip(t2.iter()) {
349                    if !self.unify_annotations(elem1, elem2)? {
350                        return Ok(false);
351                    }
352                }
353
354                Ok(true)
355            }
356
357            // Structural object types
358            (TypeAnnotation::Object(f1), TypeAnnotation::Object(f2)) => {
359                self.object_fields_compatible(f1, f2)
360            }
361
362            // Function types
363            (
364                TypeAnnotation::Function {
365                    params: p1,
366                    returns: r1,
367                },
368                TypeAnnotation::Function {
369                    params: p2,
370                    returns: r2,
371                },
372            ) => {
373                if p1.len() != p2.len() {
374                    return Ok(false);
375                }
376
377                for (param1, param2) in p1.iter().zip(p2.iter()) {
378                    if !self.unify_annotations(&param1.type_annotation, &param2.type_annotation)? {
379                        return Ok(false);
380                    }
381                }
382
383                self.unify_annotations(r1, r2)
384            }
385
386            // Union types
387            // A | B unifies with C | D if each type in one union can unify with at least one type in the other
388            (TypeAnnotation::Union(u1), TypeAnnotation::Union(u2)) => {
389                // Check that every type in u1 can unify with at least one type in u2
390                for t1 in u1 {
391                    let mut found_match = false;
392                    for t2 in u2 {
393                        if self.unify_annotations(t1, t2)? {
394                            found_match = true;
395                            break;
396                        }
397                    }
398                    if !found_match {
399                        return Ok(false);
400                    }
401                }
402                // Check that every type in u2 can unify with at least one type in u1
403                for t2 in u2 {
404                    let mut found_match = false;
405                    for t1 in u1 {
406                        if self.unify_annotations(t1, t2)? {
407                            found_match = true;
408                            break;
409                        }
410                    }
411                    if !found_match {
412                        return Ok(false);
413                    }
414                }
415                Ok(true)
416            }
417
418            // Union with non-union: A | B unifies with C if either A or B unifies with C
419            (TypeAnnotation::Union(union_types), other)
420            | (other, TypeAnnotation::Union(union_types)) => {
421                for union_type in union_types {
422                    if self.unify_annotations(union_type, other)? {
423                        return Ok(true);
424                    }
425                }
426                Ok(false)
427            }
428
429            // Intersection types (order-independent)
430            (TypeAnnotation::Intersection(i1), TypeAnnotation::Intersection(i2)) => {
431                self.unify_annotation_sets(i1, i2)
432            }
433
434            // Void, Null, Undefined
435            (TypeAnnotation::Void, TypeAnnotation::Void) => Ok(true),
436            (TypeAnnotation::Null, TypeAnnotation::Null) => Ok(true),
437            (TypeAnnotation::Undefined, TypeAnnotation::Undefined) => Ok(true),
438
439            // Trait object types: dyn Trait1 + Trait2
440            // Two trait objects unify if they have the same set of traits
441            (TypeAnnotation::Dyn(traits1), TypeAnnotation::Dyn(traits2)) => {
442                Ok(traits1.len() == traits2.len() && traits1.iter().all(|t| traits2.contains(t)))
443            }
444
445            // Array<T> (Generic) is equivalent to Vec<T> (Array)
446            (TypeAnnotation::Generic { name, args }, TypeAnnotation::Array(elem))
447            | (TypeAnnotation::Array(elem), TypeAnnotation::Generic { name, args })
448                if name == "Array" && args.len() == 1 =>
449            {
450                self.unify_annotations(&args[0], elem)
451            }
452
453            // Different types don't unify
454            _ => Ok(false),
455        }
456    }
457
458    fn object_fields_compatible(
459        &self,
460        left: &[ObjectTypeField],
461        right: &[ObjectTypeField],
462    ) -> TypeResult<bool> {
463        for left_field in left {
464            let Some(right_field) = right.iter().find(|f| f.name == left_field.name) else {
465                return Ok(false);
466            };
467            if left_field.optional != right_field.optional {
468                return Ok(false);
469            }
470            if !self.unify_annotations(&left_field.type_annotation, &right_field.type_annotation)? {
471                return Ok(false);
472            }
473        }
474        if left.len() != right.len() {
475            return Ok(false);
476        }
477        Ok(true)
478    }
479
480    fn unify_annotation_sets(
481        &self,
482        left: &[TypeAnnotation],
483        right: &[TypeAnnotation],
484    ) -> TypeResult<bool> {
485        if left.len() != right.len() {
486            return Ok(false);
487        }
488
489        let mut matched = vec![false; right.len()];
490        for left_ann in left {
491            let mut found = false;
492            for (idx, right_ann) in right.iter().enumerate() {
493                if matched[idx] {
494                    continue;
495                }
496                if self.unify_annotations(left_ann, right_ann)? {
497                    matched[idx] = true;
498                    found = true;
499                    break;
500                }
501            }
502            if !found {
503                return Ok(false);
504            }
505        }
506
507        Ok(true)
508    }
509
510    /// Apply type variable bounds, propagating resolved field types back to type variables.
511    ///
512    /// When a `HasField` constraint is satisfied and the expected field type was a
513    /// type variable, this binds that variable to the actual field type. This enables
514    /// backward propagation: `let f = |obj| obj.x; f({x: 42})` resolves `obj.x` to `int`.
515    fn apply_bounds(&mut self) -> TypeResult<()> {
516        let mut new_bindings: Vec<(TypeVar, Type)> = Vec::new();
517
518        for (var, constraint) in &self.bounds {
519            // Use apply_substitutions to follow the full variable chain
520            // (lookup only returns the direct binding, not the resolved type).
521            let resolved = self
522                .unifier
523                .apply_substitutions(&Type::Variable(var.clone()));
524
525            if let Type::Variable(_) = &resolved {
526                // Still unresolved — skip for now
527                continue;
528            }
529
530            self.check_constraint(&resolved, constraint)?;
531
532            // Backward propagation: when a HasField constraint is satisfied,
533            // bind the result type variable to the actual field type.
534            if let TypeConstraint::HasField(field, expected_field_type) = constraint {
535                if let Type::Variable(field_var) = expected_field_type.as_ref() {
536                    // Also check if the field var is already resolved
537                    let field_resolved = self
538                        .unifier
539                        .apply_substitutions(&Type::Variable(field_var.clone()));
540                    if let Type::Variable(_) = &field_resolved {
541                        // Field var still unresolved — try to bind it
542                        if let Type::Concrete(TypeAnnotation::Object(fields)) = &resolved {
543                            if let Some(found_field) = fields.iter().find(|f| f.name == *field) {
544                                new_bindings.push((
545                                    field_var.clone(),
546                                    Type::Concrete(found_field.type_annotation.clone()),
547                                ));
548                            }
549                        }
550                    }
551                }
552            }
553
554            // Backward propagation: when an Indexable constraint is satisfied
555            // and the constrained variable resolved to a concrete array type,
556            // bind the carried element variable to the actual element type.
557            // This is the connective tissue that lets `a[0]` on an
558            // unannotated parameter recover its element type once `a`
559            // resolves to `Array<int>` (via callsite unification). Without
560            // this, the index access returns a disconnected fresh variable
561            // and a downstream `a[0] + b[0]` sees `unknown + unknown`.
562            if let TypeConstraint::Indexable(expected_elem_type) = constraint {
563                if let Type::Variable(elem_var) = expected_elem_type.as_ref() {
564                    let elem_resolved = self
565                        .unifier
566                        .apply_substitutions(&Type::Variable(elem_var.clone()));
567                    if let Type::Variable(_) = &elem_resolved {
568                        let actual_elem: Option<Type> = match &resolved {
569                            Type::Concrete(TypeAnnotation::Array(elem)) => {
570                                Some(Type::Concrete((**elem).clone()))
571                            }
572                            Type::Generic { base, args }
573                                if args.len() == 1 && is_array_or_vec_base(base) =>
574                            {
575                                Some(args[0].clone())
576                            }
577                            // String indexing yields a single-character string.
578                            Type::Concrete(TypeAnnotation::Basic(name))
579                                if name == "string" =>
580                            {
581                                Some(BuiltinTypes::string())
582                            }
583                            _ => None,
584                        };
585                        if let Some(elem_ty) = actual_elem {
586                            if !matches!(elem_ty, Type::Variable(_)) {
587                                new_bindings.push((elem_var.clone(), elem_ty));
588                            }
589                        }
590                    }
591                }
592            }
593        }
594
595        // Apply collected bindings
596        for (var, ty) in new_bindings {
597            self.unifier.bind(var, ty);
598        }
599
600        Ok(())
601    }
602
603    /// Check if a type satisfies a constraint
604    fn check_constraint(&self, ty: &Type, constraint: &TypeConstraint) -> TypeResult<()> {
605        match constraint {
606            TypeConstraint::Comparable => match ty {
607                Type::Concrete(TypeAnnotation::Basic(name))
608                    if BuiltinTypes::is_numeric_type_name(name)
609                        || name == "string"
610                        || name == "bool" =>
611                {
612                    Ok(())
613                }
614                _ => Err(TypeError::ConstraintViolation(format!(
615                    "{:?} is not comparable",
616                    ty
617                ))),
618            },
619
620            TypeConstraint::Iterable => match ty {
621                Type::Concrete(TypeAnnotation::Array(_)) => Ok(()),
622                Type::Concrete(TypeAnnotation::Basic(name))
623                    if name == "string" || name == "rows" =>
624                {
625                    Ok(())
626                }
627                _ => Err(TypeError::ConstraintViolation(format!(
628                    "{:?} is not iterable",
629                    ty
630                ))),
631            },
632
633            // `obj[i]` index access. The carried element type is bound by
634            // `apply_bounds` backward propagation (mirrors `HasField`); here
635            // we only validate that the resolved type supports indexing.
636            TypeConstraint::Indexable(_) => match ty {
637                Type::Concrete(TypeAnnotation::Array(_)) => Ok(()),
638                Type::Generic { base, args }
639                    if args.len() == 1 && is_array_or_vec_base(base) =>
640                {
641                    Ok(())
642                }
643                Type::Concrete(TypeAnnotation::Basic(name))
644                    if name == "string" || name == "rows" =>
645                {
646                    Ok(())
647                }
648                _ => Err(TypeError::ConstraintViolation(format!(
649                    "{:?} does not support index access",
650                    ty
651                ))),
652            },
653
654            TypeConstraint::HasField(field, expected_field_type) => {
655                match ty {
656                    Type::Concrete(TypeAnnotation::Object(fields)) => {
657                        match fields.iter().find(|f| f.name == *field) {
658                            Some(found_field) => {
659                                // Check that field type matches expected type
660                                if let Some(expected_ann) = expected_field_type.to_annotation() {
661                                    if self.unify_annotations(
662                                        &found_field.type_annotation,
663                                        &expected_ann,
664                                    )? {
665                                        Ok(())
666                                    } else {
667                                        Err(TypeError::ConstraintViolation(format!(
668                                            "field '{}' has type {:?}, expected {:?}",
669                                            field, found_field.type_annotation, expected_ann
670                                        )))
671                                    }
672                                } else {
673                                    // Expected type is a type variable, accept any field type
674                                    Ok(())
675                                }
676                            }
677                            None => Err(TypeError::ConstraintViolation(format!(
678                                "{:?} does not have field '{}'",
679                                ty, field
680                            ))),
681                        }
682                    }
683                    Type::Concrete(TypeAnnotation::Basic(_name)) => {
684                        // For named types, we assume property access was validated
685                        // during inference using the schema registry. If a HasField
686                        // constraint reaches here, it means the type wasn't a known
687                        // schema type during inference, so we accept it tentatively.
688                        // Runtime will do the final validation.
689                        //
690                        // Note: Previously this hardcoded "row" with OHLCV fields.
691                        // Now schema validation happens in TypeInferenceEngine::infer_property_access.
692                        Ok(())
693                    }
694                    _ => Err(TypeError::ConstraintViolation(format!(
695                        "{:?} cannot have fields",
696                        ty
697                    ))),
698                }
699            }
700
701            TypeConstraint::Callable {
702                params: expected_params,
703                returns: expected_returns,
704            } => {
705                match ty {
706                    Type::Concrete(TypeAnnotation::Function {
707                        params: actual_params,
708                        returns: actual_returns,
709                    }) => {
710                        // Check parameter count matches
711                        if expected_params.len() != actual_params.len() {
712                            return Err(TypeError::ConstraintViolation(format!(
713                                "function expects {} parameters, got {}",
714                                expected_params.len(),
715                                actual_params.len()
716                            )));
717                        }
718
719                        // Check each parameter type (contravariant: expected <: actual)
720                        for (expected, actual) in expected_params.iter().zip(actual_params.iter()) {
721                            if let Some(expected_ann) = expected.to_annotation() {
722                                if !self
723                                    .unify_annotations(&expected_ann, &actual.type_annotation)?
724                                {
725                                    return Err(TypeError::ConstraintViolation(format!(
726                                        "parameter type mismatch: expected {:?}, got {:?}",
727                                        expected_ann, actual.type_annotation
728                                    )));
729                                }
730                            }
731                        }
732
733                        // Check return type (covariant: actual <: expected)
734                        if let Some(expected_ret_ann) = expected_returns.to_annotation() {
735                            if !self.unify_annotations(actual_returns, &expected_ret_ann)? {
736                                return Err(TypeError::ConstraintViolation(format!(
737                                    "return type mismatch: expected {:?}, got {:?}",
738                                    expected_ret_ann, actual_returns
739                                )));
740                            }
741                        }
742
743                        Ok(())
744                    }
745                    Type::Function {
746                        params: actual_params,
747                        returns: actual_returns,
748                    } => {
749                        if expected_params.len() != actual_params.len() {
750                            return Err(TypeError::ConstraintViolation(format!(
751                                "function expects {} parameters, got {}",
752                                expected_params.len(),
753                                actual_params.len()
754                            )));
755                        }
756                        // Type::Function params are Type, not FunctionParam — compare directly
757                        for (expected, actual) in expected_params.iter().zip(actual_params.iter()) {
758                            if let (Some(e_ann), Some(a_ann)) =
759                                (expected.to_annotation(), actual.to_annotation())
760                            {
761                                if !self.unify_annotations(&e_ann, &a_ann)? {
762                                    return Err(TypeError::ConstraintViolation(format!(
763                                        "parameter type mismatch: expected {:?}, got {:?}",
764                                        e_ann, a_ann
765                                    )));
766                                }
767                            }
768                        }
769                        if let (Some(e_ret), Some(a_ret)) = (
770                            expected_returns.to_annotation(),
771                            actual_returns.to_annotation(),
772                        ) {
773                            if !self.unify_annotations(&a_ret, &e_ret)? {
774                                return Err(TypeError::ConstraintViolation(format!(
775                                    "return type mismatch: expected {:?}, got {:?}",
776                                    e_ret, a_ret
777                                )));
778                            }
779                        }
780                        Ok(())
781                    }
782                    _ => Err(TypeError::ConstraintViolation(format!(
783                        "{:?} is not callable",
784                        ty
785                    ))),
786                }
787            }
788
789            TypeConstraint::OneOf(options) => {
790                for option in options {
791                    // If type matches any option, constraint is satisfied
792                    if let Type::Concrete(ann) = option {
793                        if let Type::Concrete(ty_ann) = ty {
794                            if self.unify_annotations(ann, ty_ann).unwrap_or(false) {
795                                return Ok(());
796                            }
797                        }
798                    }
799                }
800
801                Err(TypeError::ConstraintViolation(format!(
802                    "{:?} does not match any of {:?}",
803                    ty, options
804                )))
805            }
806
807            TypeConstraint::Extends(base) => {
808                // Implement subtyping check
809                self.is_subtype(ty, base)
810            }
811
812            TypeConstraint::ImplementsTrait { trait_name } => {
813                match ty {
814                    Type::Variable(_) => {
815                        // Type variable not yet resolved — this is a compile error
816                        // (no deferring per Sprint 2 spec)
817                        Err(TypeError::TraitBoundViolation {
818                            type_name: format!("{:?}", ty),
819                            trait_name: trait_name.clone(),
820                        })
821                    }
822                    Type::Concrete(ann) => {
823                        let type_name = match ann {
824                            TypeAnnotation::Basic(n) => n.clone(),
825                            TypeAnnotation::Reference(n) => n.to_string(),
826                            _ => format!("{:?}", ann),
827                        };
828                        if self.has_trait_impl(trait_name, &type_name) {
829                            Ok(())
830                        } else {
831                            Err(TypeError::TraitBoundViolation {
832                                type_name,
833                                trait_name: trait_name.clone(),
834                            })
835                        }
836                    }
837                    Type::Generic { base, .. } => {
838                        let type_name = match base.as_ref() {
839                            Type::Concrete(TypeAnnotation::Reference(n)) => n.to_string(),
840                            Type::Concrete(TypeAnnotation::Basic(n)) => n.clone(),
841                            _ => format!("{:?}", base),
842                        };
843                        if self.has_trait_impl(trait_name, &type_name) {
844                            Ok(())
845                        } else {
846                            Err(TypeError::TraitBoundViolation {
847                                type_name,
848                                trait_name: trait_name.clone(),
849                            })
850                        }
851                    }
852                    _ => Err(TypeError::TraitBoundViolation {
853                        type_name: format!("{:?}", ty),
854                        trait_name: trait_name.clone(),
855                    }),
856                }
857            }
858
859            TypeConstraint::HasMethod {
860                method_name,
861                arg_types: _,
862                return_type: _,
863            } => {
864                // If we have a method table, enforce the constraint
865                if let Some(method_table) = &self.method_table {
866                    match ty {
867                        Type::Variable(_) => Ok(()), // Unresolved type var, defer
868                        Type::Concrete(ann) => {
869                            let type_name = match ann {
870                                TypeAnnotation::Basic(n) => n.clone(),
871                                TypeAnnotation::Reference(n) => n.to_string(),
872                                TypeAnnotation::Array(_) => "Vec".to_string(),
873                                _ => return Ok(()), // Complex types: accept
874                            };
875                            if method_table.lookup(ty, method_name).is_some() {
876                                Ok(())
877                            } else {
878                                Err(TypeError::MethodNotFound {
879                                    type_name,
880                                    method_name: method_name.clone(),
881                                })
882                            }
883                        }
884                        Type::Generic { base, .. } => {
885                            if method_table.lookup(ty, method_name).is_some() {
886                                Ok(())
887                            } else {
888                                let type_name =
889                                    if let Type::Concrete(TypeAnnotation::Reference(n)) =
890                                        base.as_ref()
891                                    {
892                                        n.to_string()
893                                    } else {
894                                        format!("{:?}", base)
895                                    };
896                                Err(TypeError::MethodNotFound {
897                                    type_name,
898                                    method_name: method_name.clone(),
899                                })
900                            }
901                        }
902                        _ => Ok(()), // Function, Constrained: accept
903                    }
904                } else {
905                    // No method table attached — accept all (backward compatible)
906                    Ok(())
907                }
908            }
909        }
910    }
911
912    /// Check if a type implements a trait, considering aliases and numeric widening.
913    ///
914    /// Handles three resolution strategies:
915    /// 1. Direct lookup: `"Numeric::int"` in the trait_impls set
916    /// 2. Canonical alias: `"Float"` → `"f64"`, `"byte"` → `"u8"` via runtime name table
917    /// 3. Script alias: `"i64"` → `"int"`, `"f64"` → `"number"` via script alias table
918    /// 4. Numeric widening: integer-family names can satisfy number/float/f64 impls
919    fn has_trait_impl(&self, trait_name: &str, type_name: &str) -> bool {
920        let key = format!("{}::{}", trait_name, type_name);
921        if self.trait_impls.contains(&key) {
922            return true;
923        }
924        // Try canonical runtime alias (e.g. "Float" -> "f64", "byte" -> "u8")
925        if let Some(canonical) = BuiltinTypes::canonical_numeric_runtime_name(type_name) {
926            let canon_key = format!("{}::{}", trait_name, canonical);
927            if self.trait_impls.contains(&canon_key) {
928                return true;
929            }
930        }
931        // Try script-facing alias (e.g. "i64" -> "int", "f64" -> "number")
932        if let Some(script_alias) = BuiltinTypes::canonical_script_alias(type_name) {
933            let alias_key = format!("{}::{}", trait_name, script_alias);
934            if self.trait_impls.contains(&alias_key) {
935                return true;
936            }
937        }
938        // Numeric widening: integer-family aliases can use number/float/f64 impls.
939        if BuiltinTypes::is_integer_type_name(type_name) {
940            for widen_to in &["number", "float", "f64"] {
941                let widen_key = format!("{}::{}", trait_name, widen_to);
942                if self.trait_impls.contains(&widen_key) {
943                    return true;
944                }
945            }
946        }
947        false
948    }
949
950    /// Check if ty is a subtype of base (ty <: base)
951    /// Subtyping rules:
952    /// - Same types are subtypes of each other
953    /// - Any is a supertype of everything
954    /// - Vec<A> <: Vec<B> if A <: B (covariant)
955    /// - Function<P1, R1> <: Function<P2, R2> if P2 <: P1 (contravariant params) and R1 <: R2 (covariant return)
956    fn is_subtype(&self, ty: &Type, base: &Type) -> TypeResult<()> {
957        match (ty, base) {
958            // Same types are subtypes
959            (t1, t2) if t1 == t2 => Ok(()),
960
961            // Type variables - if we can unify, it's compatible
962            (Type::Variable(_), _) | (_, Type::Variable(_)) => Ok(()),
963
964            // Array subtyping (covariant)
965            (
966                Type::Concrete(TypeAnnotation::Array(elem1)),
967                Type::Concrete(TypeAnnotation::Array(elem2)),
968            ) => {
969                let t1 = Type::Concrete(*elem1.clone());
970                let t2 = Type::Concrete(*elem2.clone());
971                self.is_subtype(&t1, &t2)
972            }
973
974            // Function subtyping (contravariant params, covariant return)
975            (
976                Type::Concrete(TypeAnnotation::Function {
977                    params: p1,
978                    returns: r1,
979                }),
980                Type::Concrete(TypeAnnotation::Function {
981                    params: p2,
982                    returns: r2,
983                }),
984            ) => {
985                // Check parameter count
986                if p1.len() != p2.len() {
987                    return Err(TypeError::ConstraintViolation(format!(
988                        "function parameter count mismatch: {} vs {}",
989                        p1.len(),
990                        p2.len()
991                    )));
992                }
993
994                // Contravariant: base params must be subtypes of ty params
995                for (param1, param2) in p1.iter().zip(p2.iter()) {
996                    let t1 = Type::Concrete(param2.type_annotation.clone());
997                    let t2 = Type::Concrete(param1.type_annotation.clone());
998                    self.is_subtype(&t1, &t2)?;
999                }
1000
1001                // Covariant: ty return must be subtype of base return
1002                let ret1 = Type::Concrete(*r1.clone());
1003                let ret2 = Type::Concrete(*r2.clone());
1004                self.is_subtype(&ret1, &ret2)
1005            }
1006
1007            // Optional subtyping: T <: Option<T>
1008            (t, Type::Concrete(TypeAnnotation::Generic { name, args }))
1009                if name == "Option" && args.len() == 1 =>
1010            {
1011                let inner = Type::Concrete(args[0].clone());
1012                self.is_subtype(t, &inner)
1013            }
1014
1015            // Type::Function subtyping (contravariant params, covariant return)
1016            (
1017                Type::Function {
1018                    params: p1,
1019                    returns: r1,
1020                },
1021                Type::Function {
1022                    params: p2,
1023                    returns: r2,
1024                },
1025            ) => {
1026                if p1.len() != p2.len() {
1027                    return Err(TypeError::ConstraintViolation(format!(
1028                        "function parameter count mismatch: {} vs {}",
1029                        p1.len(),
1030                        p2.len()
1031                    )));
1032                }
1033                // Contravariant params
1034                for (param1, param2) in p1.iter().zip(p2.iter()) {
1035                    self.is_subtype(param2, param1)?;
1036                }
1037                // Covariant return
1038                self.is_subtype(r1, r2)
1039            }
1040
1041            // Basic types - check if they unify
1042            (Type::Concrete(ann1), Type::Concrete(ann2)) => {
1043                if self.unify_annotations(ann1, ann2)? {
1044                    Ok(())
1045                } else {
1046                    Err(TypeError::ConstraintViolation(format!(
1047                        "{:?} is not a subtype of {:?}",
1048                        ty, base
1049                    )))
1050                }
1051            }
1052
1053            // Default: not a subtype
1054            _ => Err(TypeError::ConstraintViolation(format!(
1055                "{:?} is not a subtype of {:?}",
1056                ty, base
1057            ))),
1058        }
1059    }
1060
1061    /// Get the unifier for applying substitutions
1062    pub fn unifier(&self) -> &Unifier {
1063        &self.unifier
1064    }
1065}
1066
1067#[cfg(test)]
1068mod tests {
1069    use super::*;
1070    use crate::type_system::TypeVarGen;
1071    use shape_ast::ast::ObjectTypeField;
1072
1073    /// Test-local helper: allocate a fresh type variable from a
1074    /// per-test counter. Each test owns its own `TypeVarGen`, so IDs
1075    /// (`T0`, `T1`, ...) are deterministic and independent across tests.
1076    fn fresh_var(tvgen: &mut TypeVarGen) -> TypeVar {
1077        tvgen.fresh_var()
1078    }
1079
1080    fn fresh_type(tvgen: &mut TypeVarGen) -> Type {
1081        tvgen.fresh_type()
1082    }
1083
1084    #[test]
1085    fn test_hasfield_backward_propagation_binds_field_type() {
1086        // When a TypeVar has a HasField constraint and is resolved to a concrete
1087        // object type, the field's result type variable should be bound to the
1088        // actual field type. This enables backward type propagation.
1089        let mut solver = ConstraintSolver::new();
1090        let mut tvgen = TypeVarGen::new();
1091
1092        let obj_var = fresh_var(&mut tvgen);
1093        let field_result_var = fresh_var(&mut tvgen);
1094        let bound_var = fresh_var(&mut tvgen);
1095
1096        let mut constraints = vec![
1097            // obj_var ~ Constrained { var: bound_var, HasField("x", field_result_var) }
1098            // This records bound: bound_var → HasField("x", field_result_var)
1099            // and solves: bound_var ~ obj_var
1100            (
1101                Type::Variable(obj_var.clone()),
1102                Type::Constrained {
1103                    var: bound_var,
1104                    constraint: Box::new(TypeConstraint::HasField(
1105                        "x".to_string(),
1106                        Box::new(Type::Variable(field_result_var.clone())),
1107                    )),
1108                },
1109            ),
1110            // obj_var = {x: int}
1111            (
1112                Type::Variable(obj_var),
1113                Type::Concrete(TypeAnnotation::Object(vec![ObjectTypeField {
1114                    name: "x".to_string(),
1115                    optional: false,
1116                    type_annotation: TypeAnnotation::Basic("int".to_string()),
1117                    annotations: vec![],
1118                }])),
1119            ),
1120        ];
1121
1122        solver.solve(&mut constraints).unwrap();
1123
1124        // field_result_var should now be resolved to int via apply_bounds
1125        let resolved = solver
1126            .unifier()
1127            .apply_substitutions(&Type::Variable(field_result_var));
1128        match &resolved {
1129            Type::Concrete(TypeAnnotation::Basic(name)) => {
1130                assert_eq!(name, "int", "field type should be int");
1131            }
1132            _ => panic!(
1133                "Expected field_result_var to be resolved to int, got {:?}",
1134                resolved
1135            ),
1136        }
1137    }
1138
1139    #[test]
1140    fn test_hasfield_backward_propagation_multiple_fields() {
1141        // Test that multiple HasField constraints on the same object all propagate
1142        let mut solver = ConstraintSolver::new();
1143        let mut tvgen = TypeVarGen::new();
1144
1145        let obj_var = fresh_var(&mut tvgen);
1146        let field_x_var = fresh_var(&mut tvgen);
1147        let field_y_var = fresh_var(&mut tvgen);
1148        let bound_var_x = fresh_var(&mut tvgen);
1149        let bound_var_y = fresh_var(&mut tvgen);
1150
1151        let mut constraints = vec![
1152            // HasField("x", field_x_var)
1153            (
1154                Type::Variable(obj_var.clone()),
1155                Type::Constrained {
1156                    var: bound_var_x,
1157                    constraint: Box::new(TypeConstraint::HasField(
1158                        "x".to_string(),
1159                        Box::new(Type::Variable(field_x_var.clone())),
1160                    )),
1161                },
1162            ),
1163            // HasField("y", field_y_var)
1164            (
1165                Type::Variable(obj_var.clone()),
1166                Type::Constrained {
1167                    var: bound_var_y,
1168                    constraint: Box::new(TypeConstraint::HasField(
1169                        "y".to_string(),
1170                        Box::new(Type::Variable(field_y_var.clone())),
1171                    )),
1172                },
1173            ),
1174            // obj_var = {x: int, y: string}
1175            (
1176                Type::Variable(obj_var),
1177                Type::Concrete(TypeAnnotation::Object(vec![
1178                    ObjectTypeField {
1179                        name: "x".to_string(),
1180                        optional: false,
1181                        type_annotation: TypeAnnotation::Basic("int".to_string()),
1182                        annotations: vec![],
1183                    },
1184                    ObjectTypeField {
1185                        name: "y".to_string(),
1186                        optional: false,
1187                        type_annotation: TypeAnnotation::Basic("string".to_string()),
1188                        annotations: vec![],
1189                    },
1190                ])),
1191            ),
1192        ];
1193
1194        solver.solve(&mut constraints).unwrap();
1195
1196        let resolved_x = solver
1197            .unifier()
1198            .apply_substitutions(&Type::Variable(field_x_var));
1199        let resolved_y = solver
1200            .unifier()
1201            .apply_substitutions(&Type::Variable(field_y_var));
1202
1203        match &resolved_x {
1204            Type::Concrete(TypeAnnotation::Basic(name)) => assert_eq!(name, "int"),
1205            _ => panic!("Expected x to be int, got {:?}", resolved_x),
1206        }
1207        match &resolved_y {
1208            Type::Concrete(TypeAnnotation::Basic(name)) => assert_eq!(name, "string"),
1209            _ => panic!("Expected y to be string, got {:?}", resolved_y),
1210        }
1211    }
1212
1213    // ===== Fix 1: Numeric type preservation tests =====
1214
1215    #[test]
1216    fn test_int_constrained_numeric_succeeds() {
1217        // Concrete(int) ~ Constrained(ImplementsTrait("Numeric")) should succeed
1218        let mut solver = ConstraintSolver::new();
1219        // Inject Numeric trait impls (same as TypeEnvironment registers)
1220        let trait_impls: std::collections::HashSet<String> = [
1221            "Numeric::int", "Numeric::number", "Numeric::decimal",
1222            "Numeric::i8", "Numeric::i16", "Numeric::i32", "Numeric::i64",
1223            "Numeric::u8", "Numeric::u16", "Numeric::u32", "Numeric::u64",
1224            "Numeric::f32", "Numeric::f64",
1225        ].iter().map(|s| s.to_string()).collect();
1226        solver.set_trait_impls(trait_impls);
1227        let mut tvgen = TypeVarGen::new();
1228        let bound_var = fresh_var(&mut tvgen);
1229        let mut constraints = vec![(
1230            Type::Concrete(TypeAnnotation::Basic("int".to_string())),
1231            Type::Constrained {
1232                var: bound_var,
1233                constraint: Box::new(TypeConstraint::ImplementsTrait {
1234                    trait_name: "Numeric".to_string(),
1235                }),
1236            },
1237        )];
1238        assert!(solver.solve(&mut constraints).is_ok());
1239    }
1240
1241    #[test]
1242    fn test_numeric_widening_int_to_number() {
1243        // (Concrete(int), Concrete(number)) should succeed via widening
1244        let mut solver = ConstraintSolver::new();
1245        let mut constraints = vec![(
1246            Type::Concrete(TypeAnnotation::Basic("int".to_string())),
1247            Type::Concrete(TypeAnnotation::Basic("number".to_string())),
1248        )];
1249        assert!(solver.solve(&mut constraints).is_ok());
1250    }
1251
1252    #[test]
1253    fn test_numeric_widening_width_aware_integer_to_float_family() {
1254        let mut solver = ConstraintSolver::new();
1255        let mut constraints = vec![(
1256            Type::Concrete(TypeAnnotation::Basic("i16".to_string())),
1257            Type::Concrete(TypeAnnotation::Basic("f32".to_string())),
1258        )];
1259        assert!(solver.solve(&mut constraints).is_ok());
1260    }
1261
1262    #[test]
1263    fn test_no_widening_number_to_int() {
1264        // (Concrete(number), Concrete(int)) should fail — lossy
1265        let mut solver = ConstraintSolver::new();
1266        let mut constraints = vec![(
1267            Type::Concrete(TypeAnnotation::Basic("number".to_string())),
1268            Type::Concrete(TypeAnnotation::Basic("int".to_string())),
1269        )];
1270        assert!(solver.solve(&mut constraints).is_err());
1271    }
1272
1273    #[test]
1274    fn test_decimal_constrained_numeric_succeeds() {
1275        let mut solver = ConstraintSolver::new();
1276        let trait_impls: std::collections::HashSet<String> = [
1277            "Numeric::int", "Numeric::number", "Numeric::decimal",
1278            "Numeric::i8", "Numeric::i16", "Numeric::i32", "Numeric::i64",
1279            "Numeric::u8", "Numeric::u16", "Numeric::u32", "Numeric::u64",
1280            "Numeric::f32", "Numeric::f64",
1281        ].iter().map(|s| s.to_string()).collect();
1282        solver.set_trait_impls(trait_impls);
1283        let mut tvgen = TypeVarGen::new();
1284        let bound_var = fresh_var(&mut tvgen);
1285        let mut constraints = vec![(
1286            Type::Concrete(TypeAnnotation::Basic("decimal".to_string())),
1287            Type::Constrained {
1288                var: bound_var,
1289                constraint: Box::new(TypeConstraint::ImplementsTrait {
1290                    trait_name: "Numeric".to_string(),
1291                }),
1292            },
1293        )];
1294        assert!(solver.solve(&mut constraints).is_ok());
1295    }
1296
1297    #[test]
1298    fn test_comparable_accepts_int() {
1299        // int should be Comparable
1300        let mut solver = ConstraintSolver::new();
1301        let mut tvgen = TypeVarGen::new();
1302        let bound_var = fresh_var(&mut tvgen);
1303        let mut constraints = vec![(
1304            Type::Concrete(TypeAnnotation::Basic("int".to_string())),
1305            Type::Constrained {
1306                var: bound_var,
1307                constraint: Box::new(TypeConstraint::Comparable),
1308            },
1309        )];
1310        assert!(solver.solve(&mut constraints).is_ok());
1311    }
1312
1313    // ===== Fix 2: Type::Function tests =====
1314
1315    #[test]
1316    fn test_function_type_preserves_variables() {
1317        // BuiltinTypes::function with Variable params should be Type::Function
1318        let mut tvgen = TypeVarGen::new();
1319        let param = fresh_type(&mut tvgen);
1320        let ret = fresh_type(&mut tvgen);
1321        let func = BuiltinTypes::function(vec![param.clone()], ret.clone());
1322        match func {
1323            Type::Function { params, returns } => {
1324                assert_eq!(params.len(), 1);
1325                assert_eq!(params[0], param);
1326                assert_eq!(*returns, ret);
1327            }
1328            _ => panic!("Expected Type::Function, got {:?}", func),
1329        }
1330    }
1331
1332    #[test]
1333    fn test_function_unification_binds_variables() {
1334        // (T1)->T2 ~ (number)->string should bind T1=number, T2=string
1335        let mut solver = ConstraintSolver::new();
1336        let mut tvgen = TypeVarGen::new();
1337        let t1 = fresh_var(&mut tvgen);
1338        let t2 = fresh_var(&mut tvgen);
1339
1340        let mut constraints = vec![(
1341            Type::Function {
1342                params: vec![Type::Variable(t1.clone())],
1343                returns: Box::new(Type::Variable(t2.clone())),
1344            },
1345            Type::Function {
1346                params: vec![BuiltinTypes::number()],
1347                returns: Box::new(BuiltinTypes::string()),
1348            },
1349        )];
1350
1351        solver.solve(&mut constraints).unwrap();
1352
1353        let resolved_t1 = solver.unifier().apply_substitutions(&Type::Variable(t1));
1354        let resolved_t2 = solver.unifier().apply_substitutions(&Type::Variable(t2));
1355        assert_eq!(resolved_t1, BuiltinTypes::number());
1356        assert_eq!(resolved_t2, BuiltinTypes::string());
1357    }
1358
1359    #[test]
1360    fn test_function_cross_unification_with_concrete() {
1361        // Type::Function ~ Concrete(TypeAnnotation::Function) should unify
1362        let mut solver = ConstraintSolver::new();
1363        let mut tvgen = TypeVarGen::new();
1364        let t1 = fresh_var(&mut tvgen);
1365
1366        let concrete_func = Type::Concrete(TypeAnnotation::Function {
1367            params: vec![shape_ast::ast::FunctionParam {
1368                name: None,
1369                optional: false,
1370                type_annotation: TypeAnnotation::Basic("number".to_string()),
1371            }],
1372            returns: Box::new(TypeAnnotation::Basic("string".to_string())),
1373        });
1374
1375        let mut constraints = vec![(
1376            Type::Function {
1377                params: vec![Type::Variable(t1.clone())],
1378                returns: Box::new(BuiltinTypes::string()),
1379            },
1380            concrete_func,
1381        )];
1382
1383        solver.solve(&mut constraints).unwrap();
1384
1385        let resolved = solver.unifier().apply_substitutions(&Type::Variable(t1));
1386        assert_eq!(resolved, BuiltinTypes::number());
1387    }
1388
1389    #[test]
1390    fn test_object_annotations_unify_structurally() {
1391        let mut solver = ConstraintSolver::new();
1392        let mut constraints = vec![(
1393            Type::Concrete(TypeAnnotation::Object(vec![
1394                ObjectTypeField {
1395                    name: "x".to_string(),
1396                    optional: false,
1397                    type_annotation: TypeAnnotation::Basic("int".to_string()),
1398                    annotations: vec![],
1399                },
1400                ObjectTypeField {
1401                    name: "y".to_string(),
1402                    optional: false,
1403                    type_annotation: TypeAnnotation::Basic("int".to_string()),
1404                    annotations: vec![],
1405                },
1406            ])),
1407            Type::Concrete(TypeAnnotation::Object(vec![
1408                ObjectTypeField {
1409                    name: "x".to_string(),
1410                    optional: false,
1411                    type_annotation: TypeAnnotation::Basic("int".to_string()),
1412                    annotations: vec![],
1413                },
1414                ObjectTypeField {
1415                    name: "y".to_string(),
1416                    optional: false,
1417                    type_annotation: TypeAnnotation::Basic("int".to_string()),
1418                    annotations: vec![],
1419                },
1420            ])),
1421        )];
1422        assert!(solver.solve(&mut constraints).is_ok());
1423    }
1424
1425    #[test]
1426    fn test_intersection_annotations_unify_order_independent() {
1427        let mut solver = ConstraintSolver::new();
1428        let obj_xy = TypeAnnotation::Object(vec![
1429            ObjectTypeField {
1430                name: "x".to_string(),
1431                optional: false,
1432                type_annotation: TypeAnnotation::Basic("int".to_string()),
1433                annotations: vec![],
1434            },
1435            ObjectTypeField {
1436                name: "y".to_string(),
1437                optional: false,
1438                type_annotation: TypeAnnotation::Basic("int".to_string()),
1439                annotations: vec![],
1440            },
1441        ]);
1442        let obj_z = TypeAnnotation::Object(vec![ObjectTypeField {
1443            name: "z".to_string(),
1444            optional: false,
1445            type_annotation: TypeAnnotation::Basic("int".to_string()),
1446            annotations: vec![],
1447        }]);
1448
1449        let mut constraints = vec![(
1450            Type::Concrete(TypeAnnotation::Intersection(vec![
1451                obj_xy.clone(),
1452                obj_z.clone(),
1453            ])),
1454            Type::Concrete(TypeAnnotation::Intersection(vec![obj_z, obj_xy])),
1455        )];
1456        assert!(solver.solve(&mut constraints).is_ok());
1457    }
1458
1459    // ===== Sprint 2: ImplementsTrait constraint tests =====
1460
1461    #[test]
1462    fn test_implements_trait_satisfied() {
1463        let mut solver = ConstraintSolver::new();
1464        let mut impls = std::collections::HashSet::new();
1465        impls.insert("Comparable::number".to_string());
1466        solver.set_trait_impls(impls);
1467
1468        let mut tvgen = TypeVarGen::new();
1469        let bound_var = fresh_var(&mut tvgen);
1470        let mut constraints = vec![(
1471            Type::Concrete(TypeAnnotation::Basic("number".to_string())),
1472            Type::Constrained {
1473                var: bound_var,
1474                constraint: Box::new(TypeConstraint::ImplementsTrait {
1475                    trait_name: "Comparable".to_string(),
1476                }),
1477            },
1478        )];
1479        assert!(solver.solve(&mut constraints).is_ok());
1480    }
1481
1482    #[test]
1483    fn test_implements_trait_violated() {
1484        let mut solver = ConstraintSolver::new();
1485        // No trait impls registered — string doesn't implement Comparable
1486        let mut tvgen = TypeVarGen::new();
1487        let bound_var = fresh_var(&mut tvgen);
1488        let mut constraints = vec![(
1489            Type::Concrete(TypeAnnotation::Basic("string".to_string())),
1490            Type::Constrained {
1491                var: bound_var,
1492                constraint: Box::new(TypeConstraint::ImplementsTrait {
1493                    trait_name: "Comparable".to_string(),
1494                }),
1495            },
1496        )];
1497        let result = solver.solve(&mut constraints);
1498        assert!(result.is_err());
1499        match result.unwrap_err() {
1500            TypeError::TraitBoundViolation {
1501                type_name,
1502                trait_name,
1503            } => {
1504                assert_eq!(type_name, "string");
1505                assert_eq!(trait_name, "Comparable");
1506            }
1507            other => panic!("Expected TraitBoundViolation, got: {:?}", other),
1508        }
1509    }
1510
1511    #[test]
1512    fn test_implements_trait_via_variable_resolution() {
1513        let mut solver = ConstraintSolver::new();
1514        let mut impls = std::collections::HashSet::new();
1515        impls.insert("Sortable::number".to_string());
1516        solver.set_trait_impls(impls);
1517
1518        let mut tvgen = TypeVarGen::new();
1519        let type_var = fresh_var(&mut tvgen);
1520        let bound_var = fresh_var(&mut tvgen);
1521
1522        let mut constraints = vec![
1523            // T: Sortable
1524            (
1525                Type::Variable(type_var.clone()),
1526                Type::Constrained {
1527                    var: bound_var,
1528                    constraint: Box::new(TypeConstraint::ImplementsTrait {
1529                        trait_name: "Sortable".to_string(),
1530                    }),
1531                },
1532            ),
1533            // T = number
1534            (
1535                Type::Variable(type_var),
1536                Type::Concrete(TypeAnnotation::Basic("number".to_string())),
1537            ),
1538        ];
1539        assert!(
1540            solver.solve(&mut constraints).is_ok(),
1541            "T resolved to number which implements Sortable"
1542        );
1543    }
1544}