Skip to main content

shape_runtime/type_system/
constraints.rs

//! Type Constraint Solver
//!
//! Solves type constraints generated during type inference
//! to determine concrete types for type variables.
5
6use super::checking::MethodTable;
7use super::unification::Unifier;
8use super::*;
9use shape_ast::ast::{ObjectTypeField, TypeAnnotation};
10use std::collections::{HashMap, HashSet};
11
12/// Check if a Type::Generic base is "Array" or "Vec".
13fn is_array_or_vec_base(base: &Type) -> bool {
14    match base {
15        Type::Concrete(TypeAnnotation::Reference(name))
16        | Type::Concrete(TypeAnnotation::Basic(name)) => name == "Array" || name == "Vec",
17        _ => false,
18    }
19}
20
21pub struct ConstraintSolver {
22    /// Type unifier
23    unifier: Unifier,
24    /// Deferred constraints that couldn't be solved immediately.
25    /// These are handled in solve() via multiple passes.
26    _deferred: Vec<(Type, Type)>,
27    /// Type variable bounds
28    bounds: HashMap<TypeVar, TypeConstraint>,
29    /// Method table for HasMethod constraint enforcement
30    method_table: Option<MethodTable>,
31    /// Trait implementation registry: set of "TraitName::TypeName" keys
32    trait_impls: HashSet<String>,
33}
34
35impl Default for ConstraintSolver {
36    fn default() -> Self {
37        Self::new()
38    }
39}
40
41impl ConstraintSolver {
42    pub fn new() -> Self {
43        ConstraintSolver {
44            unifier: Unifier::new(),
45            _deferred: Vec::new(),
46            bounds: HashMap::new(),
47            method_table: None,
48            trait_impls: HashSet::new(),
49        }
50    }
51
52    /// Attach a method table for HasMethod constraint enforcement.
53    /// When set, HasMethod constraints are validated against this table
54    /// instead of being accepted unconditionally.
55    pub fn set_method_table(&mut self, table: MethodTable) {
56        self.method_table = Some(table);
57    }
58
59    /// Register trait implementations for ImplementsTrait constraint enforcement.
60    /// Each entry is a "TraitName::TypeName" key indicating that TypeName implements TraitName.
61    pub fn set_trait_impls(&mut self, impls: HashSet<String>) {
62        self.trait_impls = impls;
63    }
64
65    /// Solve all type constraints
66    pub fn solve(&mut self, constraints: &mut Vec<(Type, Type)>) -> TypeResult<()> {
67        // First pass: solve simple unification constraints
68        let mut unsolved = Vec::new();
69
70        for (t1, t2) in constraints.drain(..) {
71            if self.solve_constraint(t1.clone(), t2.clone()).is_err() {
72                // If we can't solve it now, defer it
73                unsolved.push((t1, t2));
74            }
75        }
76
77        // Second pass: try deferred constraints
78        let mut made_progress = true;
79        while made_progress && !unsolved.is_empty() {
80            made_progress = false;
81            let mut still_unsolved = Vec::new();
82
83            for (t1, t2) in unsolved.drain(..) {
84                if self.solve_constraint(t1.clone(), t2.clone()).is_err() {
85                    still_unsolved.push((t1, t2));
86                } else {
87                    made_progress = true;
88                }
89            }
90
91            unsolved = still_unsolved;
92        }
93
94        // Check if any constraints remain unsolved
95        if !unsolved.is_empty() {
96            return Err(TypeError::UnsolvedConstraints(unsolved));
97        }
98
99        // Apply bounds to type variables
100        self.apply_bounds()?;
101
102        Ok(())
103    }
104
105    /// Solve a single constraint
106    fn solve_constraint(&mut self, t1: Type, t2: Type) -> TypeResult<()> {
107        // Apply current substitutions before matching to avoid overwriting
108        // existing bindings (e.g., T17=string overwritten by T17=T19 during
109        // Function param/return pairwise unification).
110        let t1 = self.unifier.apply_substitutions(&t1);
111        let t2 = self.unifier.apply_substitutions(&t2);
112
113        match (&t1, &t2) {
114            // Variable constraints
115            (Type::Variable(v1), Type::Variable(v2)) if v1 == v2 => Ok(()),
116
117            // Constrained type variables — must be matched BEFORE the general
118            // Variable arm, otherwise (Variable, Constrained) pairs are caught
119            // by the Variable arm and the bound is never recorded.
120            (Type::Constrained { var, constraint }, ty)
121            | (ty, Type::Constrained { var, constraint }) => {
122                // Record the constraint
123                self.bounds.insert(var.clone(), *constraint.clone());
124
125                // Unify with the underlying type
126                self.solve_constraint(Type::Variable(var.clone()), ty.clone())
127            }
128
129            (Type::Variable(var), ty) | (ty, Type::Variable(var)) => {
130                // Check occurs check
131                if self.occurs_in(var, ty) {
132                    return Err(TypeError::InfiniteType(var.clone()));
133                }
134
135                self.unifier.bind(var.clone(), ty.clone());
136                Ok(())
137            }
138
139            // Concrete type constraints
140            (Type::Concrete(ann1), Type::Concrete(ann2)) => {
141                if self.unify_annotations(ann1, ann2)? {
142                    Ok(())
143                } else if Self::can_numeric_widen(ann1, ann2) {
144                    // Implicit numeric promotion (int → number/float)
145                    Ok(())
146                } else {
147                    Err(TypeError::TypeMismatch(
148                        format!("{:?}", ann1),
149                        format!("{:?}", ann2),
150                    ))
151                }
152            }
153
154            // Generic type constraints
155            (Type::Generic { base: b1, args: a1 }, Type::Generic { base: b2, args: a2 }) => {
156                self.solve_constraint(*b1.clone(), *b2.clone())?;
157
158                let is_result_base = |base: &Type| {
159                    matches!(
160                        base,
161                        Type::Concrete(TypeAnnotation::Reference(name))
162                            | Type::Concrete(TypeAnnotation::Basic(name))
163                            if name == "Result"
164                    )
165                };
166
167                if a1.len() != a2.len() {
168                    if is_result_base(&b1) && is_result_base(&b2) {
169                        match (a1.len(), a2.len()) {
170                            // `Result<T>` is error-agnostic shorthand and should unify
171                            // with `Result<T, E>` by constraining only the success type.
172                            (1, 2) | (2, 1) => {
173                                self.solve_constraint(a1[0].clone(), a2[0].clone())?;
174                                return Ok(());
175                            }
176                            _ => return Err(TypeError::ArityMismatch(a1.len(), a2.len())),
177                        }
178                    } else {
179                        return Err(TypeError::ArityMismatch(a1.len(), a2.len()));
180                    }
181                }
182
183                for (arg1, arg2) in a1.iter().zip(a2.iter()) {
184                    self.solve_constraint(arg1.clone(), arg2.clone())?;
185                }
186
187                Ok(())
188            }
189
190            // Function ~ Function: pairwise unify params + returns
191            (
192                Type::Function {
193                    params: p1,
194                    returns: r1,
195                },
196                Type::Function {
197                    params: p2,
198                    returns: r2,
199                },
200            ) => {
201                if p1.len() != p2.len() {
202                    return Err(TypeError::ArityMismatch(p1.len(), p2.len()));
203                }
204                for (param1, param2) in p1.iter().zip(p2.iter()) {
205                    // Parameter compatibility is checked from observed/actual to
206                    // declared/expected shape so directional numeric widening
207                    // (e.g. int -> number) remains valid in call constraints.
208                    self.solve_constraint(param2.clone(), param1.clone())?;
209                }
210                self.solve_constraint(*r1.clone(), *r2.clone())
211            }
212
213            // Cross-compatibility: Type::Function ~ Concrete(TypeAnnotation::Function)
214            (
215                Type::Function {
216                    params: fp,
217                    returns: fr,
218                },
219                Type::Concrete(TypeAnnotation::Function {
220                    params: cp,
221                    returns: cr,
222                }),
223            )
224            | (
225                Type::Concrete(TypeAnnotation::Function {
226                    params: cp,
227                    returns: cr,
228                }),
229                Type::Function {
230                    params: fp,
231                    returns: fr,
232                },
233            ) => {
234                if fp.len() != cp.len() {
235                    return Err(TypeError::ArityMismatch(fp.len(), cp.len()));
236                }
237                for (f_param, c_param) in fp.iter().zip(cp.iter()) {
238                    self.solve_constraint(
239                        f_param.clone(),
240                        Type::Concrete(c_param.type_annotation.clone()),
241                    )?;
242                }
243                self.solve_constraint(*fr.clone(), Type::Concrete(*cr.clone()))
244            }
245
246            // Array<T> (Type::Generic with base "Array" or "Vec") ~ Concrete(Array(T))
247            (Type::Generic { base, args }, Type::Concrete(TypeAnnotation::Array(elem)))
248            | (Type::Concrete(TypeAnnotation::Array(elem)), Type::Generic { base, args })
249                if args.len() == 1 && is_array_or_vec_base(base) =>
250            {
251                self.solve_constraint(args[0].clone(), Type::Concrete((**elem).clone()))
252            }
253
254            _ => Err(TypeError::TypeMismatch(
255                format!("{:?}", t1),
256                format!("{:?}", t2),
257            )),
258        }
259    }
260
261    /// Check if a type variable occurs in a type (occurs check)
262    fn occurs_in(&self, var: &TypeVar, ty: &Type) -> bool {
263        match ty {
264            Type::Variable(v) => v == var,
265            Type::Generic { base, args } => {
266                self.occurs_in(var, base) || args.iter().any(|arg| self.occurs_in(var, arg))
267            }
268            Type::Constrained { var: v, .. } => v == var,
269            Type::Function { params, returns } => {
270                params.iter().any(|p| self.occurs_in(var, p)) || self.occurs_in(var, returns)
271            }
272            Type::Concrete(_) => false,
273        }
274    }
275
276    /// Check if a numeric type can widen to another (directional).
277    ///
278    /// Integer-family types (`int`, `i16`, `u32`, `byte`, ...) can widen to
279    /// number-family types (`number`, `f32`, `f64`, ...).
280    /// `number → int` does NOT widen (lossy). `decimal → number` does NOT widen
281    /// (different precision semantics).
282    fn can_numeric_widen(from: &TypeAnnotation, to: &TypeAnnotation) -> bool {
283        let from_name = match from {
284            TypeAnnotation::Basic(name) | TypeAnnotation::Reference(name) => Some(name.as_str()),
285            _ => None,
286        };
287        let to_name = match to {
288            TypeAnnotation::Basic(name) | TypeAnnotation::Reference(name) => Some(name.as_str()),
289            _ => None,
290        };
291
292        match (from_name, to_name) {
293            (Some(f), Some(t)) => {
294                BuiltinTypes::is_integer_type_name(f) && BuiltinTypes::is_number_type_name(t)
295            }
296            _ => false,
297        }
298    }
299
300    /// Unify two type annotations
301    fn unify_annotations(&self, ann1: &TypeAnnotation, ann2: &TypeAnnotation) -> TypeResult<bool> {
302        match (ann1, ann2) {
303            // Basic types
304            (TypeAnnotation::Basic(_), TypeAnnotation::Basic(_)) => {
305                Ok(ann1 == ann2 || Self::can_numeric_widen(ann1, ann2))
306            }
307            (TypeAnnotation::Reference(n1), TypeAnnotation::Reference(n2)) => Ok(n1 == n2),
308            (TypeAnnotation::Basic(_), TypeAnnotation::Reference(_))
309            | (TypeAnnotation::Reference(_), TypeAnnotation::Basic(_)) => {
310                Ok(ann1 == ann2 || Self::can_numeric_widen(ann1, ann2))
311            }
312
313            // Array types
314            (TypeAnnotation::Array(e1), TypeAnnotation::Array(e2)) => {
315                self.unify_annotations(e1, e2)
316            }
317
318            // Tuple types
319            (TypeAnnotation::Tuple(t1), TypeAnnotation::Tuple(t2)) => {
320                if t1.len() != t2.len() {
321                    return Ok(false);
322                }
323
324                for (elem1, elem2) in t1.iter().zip(t2.iter()) {
325                    if !self.unify_annotations(elem1, elem2)? {
326                        return Ok(false);
327                    }
328                }
329
330                Ok(true)
331            }
332
333            // Structural object types
334            (TypeAnnotation::Object(f1), TypeAnnotation::Object(f2)) => {
335                self.object_fields_compatible(f1, f2)
336            }
337
338            // Function types
339            (
340                TypeAnnotation::Function {
341                    params: p1,
342                    returns: r1,
343                },
344                TypeAnnotation::Function {
345                    params: p2,
346                    returns: r2,
347                },
348            ) => {
349                if p1.len() != p2.len() {
350                    return Ok(false);
351                }
352
353                for (param1, param2) in p1.iter().zip(p2.iter()) {
354                    if !self.unify_annotations(&param1.type_annotation, &param2.type_annotation)? {
355                        return Ok(false);
356                    }
357                }
358
359                self.unify_annotations(r1, r2)
360            }
361
362            // Union types
363            // A | B unifies with C | D if each type in one union can unify with at least one type in the other
364            (TypeAnnotation::Union(u1), TypeAnnotation::Union(u2)) => {
365                // Check that every type in u1 can unify with at least one type in u2
366                for t1 in u1 {
367                    let mut found_match = false;
368                    for t2 in u2 {
369                        if self.unify_annotations(t1, t2)? {
370                            found_match = true;
371                            break;
372                        }
373                    }
374                    if !found_match {
375                        return Ok(false);
376                    }
377                }
378                // Check that every type in u2 can unify with at least one type in u1
379                for t2 in u2 {
380                    let mut found_match = false;
381                    for t1 in u1 {
382                        if self.unify_annotations(t1, t2)? {
383                            found_match = true;
384                            break;
385                        }
386                    }
387                    if !found_match {
388                        return Ok(false);
389                    }
390                }
391                Ok(true)
392            }
393
394            // Union with non-union: A | B unifies with C if either A or B unifies with C
395            (TypeAnnotation::Union(union_types), other)
396            | (other, TypeAnnotation::Union(union_types)) => {
397                for union_type in union_types {
398                    if self.unify_annotations(union_type, other)? {
399                        return Ok(true);
400                    }
401                }
402                Ok(false)
403            }
404
405            // Intersection types (order-independent)
406            (TypeAnnotation::Intersection(i1), TypeAnnotation::Intersection(i2)) => {
407                self.unify_annotation_sets(i1, i2)
408            }
409
410            // Void, Null, Undefined
411            (TypeAnnotation::Void, TypeAnnotation::Void) => Ok(true),
412            (TypeAnnotation::Null, TypeAnnotation::Null) => Ok(true),
413            (TypeAnnotation::Undefined, TypeAnnotation::Undefined) => Ok(true),
414
415            // Trait object types: dyn Trait1 + Trait2
416            // Two trait objects unify if they have the same set of traits
417            (TypeAnnotation::Dyn(traits1), TypeAnnotation::Dyn(traits2)) => {
418                Ok(traits1.len() == traits2.len() && traits1.iter().all(|t| traits2.contains(t)))
419            }
420
421            // Array<T> (Generic) is equivalent to Vec<T> (Array)
422            (TypeAnnotation::Generic { name, args }, TypeAnnotation::Array(elem))
423            | (TypeAnnotation::Array(elem), TypeAnnotation::Generic { name, args })
424                if name == "Array" && args.len() == 1 =>
425            {
426                self.unify_annotations(&args[0], elem)
427            }
428
429            // Different types don't unify
430            _ => Ok(false),
431        }
432    }
433
434    fn object_fields_compatible(
435        &self,
436        left: &[ObjectTypeField],
437        right: &[ObjectTypeField],
438    ) -> TypeResult<bool> {
439        for left_field in left {
440            let Some(right_field) = right.iter().find(|f| f.name == left_field.name) else {
441                return Ok(false);
442            };
443            if left_field.optional != right_field.optional {
444                return Ok(false);
445            }
446            if !self.unify_annotations(&left_field.type_annotation, &right_field.type_annotation)? {
447                return Ok(false);
448            }
449        }
450        if left.len() != right.len() {
451            return Ok(false);
452        }
453        Ok(true)
454    }
455
456    fn unify_annotation_sets(
457        &self,
458        left: &[TypeAnnotation],
459        right: &[TypeAnnotation],
460    ) -> TypeResult<bool> {
461        if left.len() != right.len() {
462            return Ok(false);
463        }
464
465        let mut matched = vec![false; right.len()];
466        for left_ann in left {
467            let mut found = false;
468            for (idx, right_ann) in right.iter().enumerate() {
469                if matched[idx] {
470                    continue;
471                }
472                if self.unify_annotations(left_ann, right_ann)? {
473                    matched[idx] = true;
474                    found = true;
475                    break;
476                }
477            }
478            if !found {
479                return Ok(false);
480            }
481        }
482
483        Ok(true)
484    }
485
486    /// Apply type variable bounds, propagating resolved field types back to type variables.
487    ///
488    /// When a `HasField` constraint is satisfied and the expected field type was a
489    /// type variable, this binds that variable to the actual field type. This enables
490    /// backward propagation: `let f = |obj| obj.x; f({x: 42})` resolves `obj.x` to `int`.
491    fn apply_bounds(&mut self) -> TypeResult<()> {
492        let mut new_bindings: Vec<(TypeVar, Type)> = Vec::new();
493
494        for (var, constraint) in &self.bounds {
495            // Use apply_substitutions to follow the full variable chain
496            // (lookup only returns the direct binding, not the resolved type).
497            let resolved = self
498                .unifier
499                .apply_substitutions(&Type::Variable(var.clone()));
500
501            if let Type::Variable(_) = &resolved {
502                // Still unresolved — skip for now
503                continue;
504            }
505
506            self.check_constraint(&resolved, constraint)?;
507
508            // Backward propagation: when a HasField constraint is satisfied,
509            // bind the result type variable to the actual field type.
510            if let TypeConstraint::HasField(field, expected_field_type) = constraint {
511                if let Type::Variable(field_var) = expected_field_type.as_ref() {
512                    // Also check if the field var is already resolved
513                    let field_resolved = self
514                        .unifier
515                        .apply_substitutions(&Type::Variable(field_var.clone()));
516                    if let Type::Variable(_) = &field_resolved {
517                        // Field var still unresolved — try to bind it
518                        if let Type::Concrete(TypeAnnotation::Object(fields)) = &resolved {
519                            if let Some(found_field) = fields.iter().find(|f| f.name == *field) {
520                                new_bindings.push((
521                                    field_var.clone(),
522                                    Type::Concrete(found_field.type_annotation.clone()),
523                                ));
524                            }
525                        }
526                    }
527                }
528            }
529        }
530
531        // Apply collected bindings
532        for (var, ty) in new_bindings {
533            self.unifier.bind(var, ty);
534        }
535
536        Ok(())
537    }
538
539    /// Check if a type satisfies a constraint
540    fn check_constraint(&self, ty: &Type, constraint: &TypeConstraint) -> TypeResult<()> {
541        match constraint {
542            TypeConstraint::Numeric => match ty {
543                Type::Concrete(TypeAnnotation::Basic(name))
544                    if BuiltinTypes::is_numeric_type_name(name) =>
545                {
546                    Ok(())
547                }
548                _ => Err(TypeError::ConstraintViolation(format!(
549                    "{:?} is not numeric",
550                    ty
551                ))),
552            },
553
554            TypeConstraint::Comparable => match ty {
555                Type::Concrete(TypeAnnotation::Basic(name))
556                    if BuiltinTypes::is_numeric_type_name(name)
557                        || name == "string"
558                        || name == "bool" =>
559                {
560                    Ok(())
561                }
562                _ => Err(TypeError::ConstraintViolation(format!(
563                    "{:?} is not comparable",
564                    ty
565                ))),
566            },
567
568            TypeConstraint::Iterable => match ty {
569                Type::Concrete(TypeAnnotation::Array(_)) => Ok(()),
570                Type::Concrete(TypeAnnotation::Basic(name))
571                    if name == "string" || name == "rows" =>
572                {
573                    Ok(())
574                }
575                _ => Err(TypeError::ConstraintViolation(format!(
576                    "{:?} is not iterable",
577                    ty
578                ))),
579            },
580
581            TypeConstraint::HasField(field, expected_field_type) => {
582                match ty {
583                    Type::Concrete(TypeAnnotation::Object(fields)) => {
584                        match fields.iter().find(|f| f.name == *field) {
585                            Some(found_field) => {
586                                // Check that field type matches expected type
587                                if let Some(expected_ann) = expected_field_type.to_annotation() {
588                                    if self.unify_annotations(
589                                        &found_field.type_annotation,
590                                        &expected_ann,
591                                    )? {
592                                        Ok(())
593                                    } else {
594                                        Err(TypeError::ConstraintViolation(format!(
595                                            "field '{}' has type {:?}, expected {:?}",
596                                            field, found_field.type_annotation, expected_ann
597                                        )))
598                                    }
599                                } else {
600                                    // Expected type is a type variable, accept any field type
601                                    Ok(())
602                                }
603                            }
604                            None => Err(TypeError::ConstraintViolation(format!(
605                                "{:?} does not have field '{}'",
606                                ty, field
607                            ))),
608                        }
609                    }
610                    Type::Concrete(TypeAnnotation::Basic(_name)) => {
611                        // For named types, we assume property access was validated
612                        // during inference using the schema registry. If a HasField
613                        // constraint reaches here, it means the type wasn't a known
614                        // schema type during inference, so we accept it tentatively.
615                        // Runtime will do the final validation.
616                        //
617                        // Note: Previously this hardcoded "row" with OHLCV fields.
618                        // Now schema validation happens in TypeInferenceEngine::infer_property_access.
619                        Ok(())
620                    }
621                    _ => Err(TypeError::ConstraintViolation(format!(
622                        "{:?} cannot have fields",
623                        ty
624                    ))),
625                }
626            }
627
628            TypeConstraint::Callable {
629                params: expected_params,
630                returns: expected_returns,
631            } => {
632                match ty {
633                    Type::Concrete(TypeAnnotation::Function {
634                        params: actual_params,
635                        returns: actual_returns,
636                    }) => {
637                        // Check parameter count matches
638                        if expected_params.len() != actual_params.len() {
639                            return Err(TypeError::ConstraintViolation(format!(
640                                "function expects {} parameters, got {}",
641                                expected_params.len(),
642                                actual_params.len()
643                            )));
644                        }
645
646                        // Check each parameter type (contravariant: expected <: actual)
647                        for (expected, actual) in expected_params.iter().zip(actual_params.iter()) {
648                            if let Some(expected_ann) = expected.to_annotation() {
649                                if !self
650                                    .unify_annotations(&expected_ann, &actual.type_annotation)?
651                                {
652                                    return Err(TypeError::ConstraintViolation(format!(
653                                        "parameter type mismatch: expected {:?}, got {:?}",
654                                        expected_ann, actual.type_annotation
655                                    )));
656                                }
657                            }
658                        }
659
660                        // Check return type (covariant: actual <: expected)
661                        if let Some(expected_ret_ann) = expected_returns.to_annotation() {
662                            if !self.unify_annotations(actual_returns, &expected_ret_ann)? {
663                                return Err(TypeError::ConstraintViolation(format!(
664                                    "return type mismatch: expected {:?}, got {:?}",
665                                    expected_ret_ann, actual_returns
666                                )));
667                            }
668                        }
669
670                        Ok(())
671                    }
672                    Type::Function {
673                        params: actual_params,
674                        returns: actual_returns,
675                    } => {
676                        if expected_params.len() != actual_params.len() {
677                            return Err(TypeError::ConstraintViolation(format!(
678                                "function expects {} parameters, got {}",
679                                expected_params.len(),
680                                actual_params.len()
681                            )));
682                        }
683                        // Type::Function params are Type, not FunctionParam — compare directly
684                        for (expected, actual) in expected_params.iter().zip(actual_params.iter()) {
685                            if let (Some(e_ann), Some(a_ann)) =
686                                (expected.to_annotation(), actual.to_annotation())
687                            {
688                                if !self.unify_annotations(&e_ann, &a_ann)? {
689                                    return Err(TypeError::ConstraintViolation(format!(
690                                        "parameter type mismatch: expected {:?}, got {:?}",
691                                        e_ann, a_ann
692                                    )));
693                                }
694                            }
695                        }
696                        if let (Some(e_ret), Some(a_ret)) = (
697                            expected_returns.to_annotation(),
698                            actual_returns.to_annotation(),
699                        ) {
700                            if !self.unify_annotations(&a_ret, &e_ret)? {
701                                return Err(TypeError::ConstraintViolation(format!(
702                                    "return type mismatch: expected {:?}, got {:?}",
703                                    e_ret, a_ret
704                                )));
705                            }
706                        }
707                        Ok(())
708                    }
709                    _ => Err(TypeError::ConstraintViolation(format!(
710                        "{:?} is not callable",
711                        ty
712                    ))),
713                }
714            }
715
716            TypeConstraint::OneOf(options) => {
717                for option in options {
718                    // If type matches any option, constraint is satisfied
719                    if let Type::Concrete(ann) = option {
720                        if let Type::Concrete(ty_ann) = ty {
721                            if self.unify_annotations(ann, ty_ann).unwrap_or(false) {
722                                return Ok(());
723                            }
724                        }
725                    }
726                }
727
728                Err(TypeError::ConstraintViolation(format!(
729                    "{:?} does not match any of {:?}",
730                    ty, options
731                )))
732            }
733
734            TypeConstraint::Extends(base) => {
735                // Implement subtyping check
736                self.is_subtype(ty, base)
737            }
738
739            TypeConstraint::ImplementsTrait { trait_name } => {
740                match ty {
741                    Type::Variable(_) => {
742                        // Type variable not yet resolved — this is a compile error
743                        // (no deferring per Sprint 2 spec)
744                        Err(TypeError::TraitBoundViolation {
745                            type_name: format!("{:?}", ty),
746                            trait_name: trait_name.clone(),
747                        })
748                    }
749                    Type::Concrete(ann) => {
750                        let type_name = match ann {
751                            TypeAnnotation::Basic(n) | TypeAnnotation::Reference(n) => n.clone(),
752                            _ => format!("{:?}", ann),
753                        };
754                        if self.has_trait_impl(trait_name, &type_name) {
755                            Ok(())
756                        } else {
757                            Err(TypeError::TraitBoundViolation {
758                                type_name,
759                                trait_name: trait_name.clone(),
760                            })
761                        }
762                    }
763                    Type::Generic { base, .. } => {
764                        let type_name = if let Type::Concrete(
765                            TypeAnnotation::Reference(n) | TypeAnnotation::Basic(n),
766                        ) = base.as_ref()
767                        {
768                            n.clone()
769                        } else {
770                            format!("{:?}", base)
771                        };
772                        if self.has_trait_impl(trait_name, &type_name) {
773                            Ok(())
774                        } else {
775                            Err(TypeError::TraitBoundViolation {
776                                type_name,
777                                trait_name: trait_name.clone(),
778                            })
779                        }
780                    }
781                    _ => Err(TypeError::TraitBoundViolation {
782                        type_name: format!("{:?}", ty),
783                        trait_name: trait_name.clone(),
784                    }),
785                }
786            }
787
788            TypeConstraint::HasMethod {
789                method_name,
790                arg_types: _,
791                return_type: _,
792            } => {
793                // If we have a method table, enforce the constraint
794                if let Some(method_table) = &self.method_table {
795                    match ty {
796                        Type::Variable(_) => Ok(()), // Unresolved type var, defer
797                        Type::Concrete(ann) => {
798                            let type_name = match ann {
799                                TypeAnnotation::Basic(n) | TypeAnnotation::Reference(n) => {
800                                    n.clone()
801                                }
802                                TypeAnnotation::Array(_) => "Vec".to_string(),
803                                _ => return Ok(()), // Complex types: accept
804                            };
805                            if method_table.lookup(ty, method_name).is_some() {
806                                Ok(())
807                            } else {
808                                Err(TypeError::MethodNotFound {
809                                    type_name,
810                                    method_name: method_name.clone(),
811                                })
812                            }
813                        }
814                        Type::Generic { base, .. } => {
815                            if method_table.lookup(ty, method_name).is_some() {
816                                Ok(())
817                            } else {
818                                let type_name =
819                                    if let Type::Concrete(TypeAnnotation::Reference(n)) =
820                                        base.as_ref()
821                                    {
822                                        n.clone()
823                                    } else {
824                                        format!("{:?}", base)
825                                    };
826                                Err(TypeError::MethodNotFound {
827                                    type_name,
828                                    method_name: method_name.clone(),
829                                })
830                            }
831                        }
832                        _ => Ok(()), // Function, Constrained: accept
833                    }
834                } else {
835                    // No method table attached — accept all (backward compatible)
836                    Ok(())
837                }
838            }
839        }
840    }
841
842    /// Check if a type implements a trait, considering numeric widening.
843    ///
844    /// For example, `int` satisfies a trait bound if the trait is implemented for `number`,
845    /// since `int` can widen to `number` in the type system.
846    fn has_trait_impl(&self, trait_name: &str, type_name: &str) -> bool {
847        let key = format!("{}::{}", trait_name, type_name);
848        if self.trait_impls.contains(&key) {
849            return true;
850        }
851        // Numeric widening: integer-family aliases can use number/float/f64 impls.
852        if BuiltinTypes::is_integer_type_name(type_name) {
853            for widen_to in &["number", "float", "f64"] {
854                let widen_key = format!("{}::{}", trait_name, widen_to);
855                if self.trait_impls.contains(&widen_key) {
856                    return true;
857                }
858            }
859        }
860        false
861    }
862
863    /// Check if ty is a subtype of base (ty <: base)
864    /// Subtyping rules:
865    /// - Same types are subtypes of each other
866    /// - Any is a supertype of everything
867    /// - Vec<A> <: Vec<B> if A <: B (covariant)
868    /// - Function<P1, R1> <: Function<P2, R2> if P2 <: P1 (contravariant params) and R1 <: R2 (covariant return)
869    fn is_subtype(&self, ty: &Type, base: &Type) -> TypeResult<()> {
870        match (ty, base) {
871            // Same types are subtypes
872            (t1, t2) if t1 == t2 => Ok(()),
873
874            // Type variables - if we can unify, it's compatible
875            (Type::Variable(_), _) | (_, Type::Variable(_)) => Ok(()),
876
877            // Array subtyping (covariant)
878            (
879                Type::Concrete(TypeAnnotation::Array(elem1)),
880                Type::Concrete(TypeAnnotation::Array(elem2)),
881            ) => {
882                let t1 = Type::Concrete(*elem1.clone());
883                let t2 = Type::Concrete(*elem2.clone());
884                self.is_subtype(&t1, &t2)
885            }
886
887            // Function subtyping (contravariant params, covariant return)
888            (
889                Type::Concrete(TypeAnnotation::Function {
890                    params: p1,
891                    returns: r1,
892                }),
893                Type::Concrete(TypeAnnotation::Function {
894                    params: p2,
895                    returns: r2,
896                }),
897            ) => {
898                // Check parameter count
899                if p1.len() != p2.len() {
900                    return Err(TypeError::ConstraintViolation(format!(
901                        "function parameter count mismatch: {} vs {}",
902                        p1.len(),
903                        p2.len()
904                    )));
905                }
906
907                // Contravariant: base params must be subtypes of ty params
908                for (param1, param2) in p1.iter().zip(p2.iter()) {
909                    let t1 = Type::Concrete(param2.type_annotation.clone());
910                    let t2 = Type::Concrete(param1.type_annotation.clone());
911                    self.is_subtype(&t1, &t2)?;
912                }
913
914                // Covariant: ty return must be subtype of base return
915                let ret1 = Type::Concrete(*r1.clone());
916                let ret2 = Type::Concrete(*r2.clone());
917                self.is_subtype(&ret1, &ret2)
918            }
919
920            // Optional subtyping: T <: Option<T>
921            (t, Type::Concrete(TypeAnnotation::Generic { name, args }))
922                if name == "Option" && args.len() == 1 =>
923            {
924                let inner = Type::Concrete(args[0].clone());
925                self.is_subtype(t, &inner)
926            }
927
928            // Type::Function subtyping (contravariant params, covariant return)
929            (
930                Type::Function {
931                    params: p1,
932                    returns: r1,
933                },
934                Type::Function {
935                    params: p2,
936                    returns: r2,
937                },
938            ) => {
939                if p1.len() != p2.len() {
940                    return Err(TypeError::ConstraintViolation(format!(
941                        "function parameter count mismatch: {} vs {}",
942                        p1.len(),
943                        p2.len()
944                    )));
945                }
946                // Contravariant params
947                for (param1, param2) in p1.iter().zip(p2.iter()) {
948                    self.is_subtype(param2, param1)?;
949                }
950                // Covariant return
951                self.is_subtype(r1, r2)
952            }
953
954            // Basic types - check if they unify
955            (Type::Concrete(ann1), Type::Concrete(ann2)) => {
956                if self.unify_annotations(ann1, ann2)? {
957                    Ok(())
958                } else {
959                    Err(TypeError::ConstraintViolation(format!(
960                        "{:?} is not a subtype of {:?}",
961                        ty, base
962                    )))
963                }
964            }
965
966            // Default: not a subtype
967            _ => Err(TypeError::ConstraintViolation(format!(
968                "{:?} is not a subtype of {:?}",
969                ty, base
970            ))),
971        }
972    }
973
974    /// Get the unifier for applying substitutions
975    pub fn unifier(&self) -> &Unifier {
976        &self.unifier
977    }
978}
979
#[cfg(test)]
mod tests {
    use super::*;
    use shape_ast::ast::ObjectTypeField;

    // ===== Fixture helpers =====
    //
    // Small constructors for the type/AST values the tests build repeatedly.
    // Each helper yields exactly the value the equivalent inline literal would.

    /// `TypeAnnotation::Basic(name)`.
    fn basic(name: &str) -> TypeAnnotation {
        TypeAnnotation::Basic(name.to_string())
    }

    /// `Type::Concrete(TypeAnnotation::Basic(name))`.
    fn concrete(name: &str) -> Type {
        Type::Concrete(basic(name))
    }

    /// Required (`optional: false`), un-annotated object field `name: ty`.
    fn field(name: &str, ty: &str) -> ObjectTypeField {
        ObjectTypeField {
            name: name.to_string(),
            optional: false,
            type_annotation: basic(ty),
            annotations: vec![],
        }
    }

    /// Object annotation whose fields are the given `(name, basic type)` pairs, in order.
    fn object(fields: &[(&str, &str)]) -> TypeAnnotation {
        TypeAnnotation::Object(fields.iter().map(|&(name, ty)| field(name, ty)).collect())
    }

    /// `Type::Constrained` attaching `constraint` to the bound variable `var`.
    fn constrained(var: TypeVar, constraint: TypeConstraint) -> Type {
        Type::Constrained {
            var,
            constraint: Box::new(constraint),
        }
    }

    /// `HasField(name, Variable(result))`.
    fn has_field(name: &str, result: &TypeVar) -> TypeConstraint {
        TypeConstraint::HasField(name.to_string(), Box::new(Type::Variable(result.clone())))
    }

    /// `ImplementsTrait { trait_name }`.
    fn implements(trait_name: &str) -> TypeConstraint {
        TypeConstraint::ImplementsTrait {
            trait_name: trait_name.to_string(),
        }
    }

    /// A solver pre-loaded with the given `"Trait::Type"` implementation keys.
    fn solver_with_impls(keys: &[&str]) -> ConstraintSolver {
        let mut solver = ConstraintSolver::new();
        solver.set_trait_impls(
            keys.iter()
                .map(|&key| key.to_string())
                .collect::<std::collections::HashSet<String>>(),
        );
        solver
    }

    /// Run a fresh solver over `constraints` and report whether solving succeeded.
    fn solves(mut constraints: Vec<(Type, Type)>) -> bool {
        ConstraintSolver::new().solve(&mut constraints).is_ok()
    }

    // ===== HasField backward propagation =====

    #[test]
    fn test_hasfield_backward_propagation_binds_field_type() {
        // Once a type variable carrying a HasField constraint is resolved to a
        // concrete object type, the field's result variable must be bound to
        // the declared field type (backward type propagation).
        let mut solver = ConstraintSolver::new();

        let obj_var = TypeVar::fresh();
        let field_result_var = TypeVar::fresh();
        let bound_var = TypeVar::fresh();

        let mut constraints = vec![
            // obj_var ~ Constrained { bound_var: HasField("x", field_result_var) }
            // records the HasField bound on bound_var and unifies bound_var ~ obj_var.
            (
                Type::Variable(obj_var.clone()),
                constrained(bound_var, has_field("x", &field_result_var)),
            ),
            // obj_var = { x: int }
            (
                Type::Variable(obj_var),
                Type::Concrete(object(&[("x", "int")])),
            ),
        ];

        solver.solve(&mut constraints).unwrap();

        // field_result_var should now resolve to int (via apply_bounds).
        let resolved = solver
            .unifier()
            .apply_substitutions(&Type::Variable(field_result_var));
        match &resolved {
            Type::Concrete(TypeAnnotation::Basic(name)) => {
                assert_eq!(name, "int", "field type should be int");
            }
            _ => panic!(
                "Expected field_result_var to be resolved to int, got {:?}",
                resolved
            ),
        }
    }

    #[test]
    fn test_hasfield_backward_propagation_multiple_fields() {
        // Every HasField constraint on the same object must propagate its field type.
        let mut solver = ConstraintSolver::new();

        let obj_var = TypeVar::fresh();
        let field_x_var = TypeVar::fresh();
        let field_y_var = TypeVar::fresh();
        let bound_var_x = TypeVar::fresh();
        let bound_var_y = TypeVar::fresh();

        let mut constraints = vec![
            (
                Type::Variable(obj_var.clone()),
                constrained(bound_var_x, has_field("x", &field_x_var)),
            ),
            (
                Type::Variable(obj_var.clone()),
                constrained(bound_var_y, has_field("y", &field_y_var)),
            ),
            // obj_var = { x: int, y: string }
            (
                Type::Variable(obj_var),
                Type::Concrete(object(&[("x", "int"), ("y", "string")])),
            ),
        ];

        solver.solve(&mut constraints).unwrap();

        let resolve = |var: TypeVar| solver.unifier().apply_substitutions(&Type::Variable(var));
        let resolved_x = resolve(field_x_var);
        let resolved_y = resolve(field_y_var);

        match &resolved_x {
            Type::Concrete(TypeAnnotation::Basic(name)) => assert_eq!(name, "int"),
            _ => panic!("Expected x to be int, got {:?}", resolved_x),
        }
        match &resolved_y {
            Type::Concrete(TypeAnnotation::Basic(name)) => assert_eq!(name, "string"),
            _ => panic!("Expected y to be string, got {:?}", resolved_y),
        }
    }

    // ===== Fix 1: Numeric type preservation tests =====

    #[test]
    fn test_int_constrained_numeric_succeeds() {
        // Concrete(int) ~ Constrained(Numeric) should succeed.
        let bound_var = TypeVar::fresh();
        assert!(solves(vec![(
            concrete("int"),
            constrained(bound_var, TypeConstraint::Numeric),
        )]));
    }

    #[test]
    fn test_numeric_widening_int_to_number() {
        // int may widen to number.
        assert!(solves(vec![(concrete("int"), concrete("number"))]));
    }

    #[test]
    fn test_numeric_widening_width_aware_integer_to_float_family() {
        // A narrow integer may widen into the float family.
        assert!(solves(vec![(concrete("i16"), concrete("f32"))]));
    }

    #[test]
    fn test_no_widening_number_to_int() {
        // number -> int would be lossy, so it must be rejected.
        assert!(!solves(vec![(concrete("number"), concrete("int"))]));
    }

    #[test]
    fn test_decimal_constrained_numeric_succeeds() {
        let bound_var = TypeVar::fresh();
        assert!(solves(vec![(
            concrete("decimal"),
            constrained(bound_var, TypeConstraint::Numeric),
        )]));
    }

    #[test]
    fn test_comparable_accepts_int() {
        // int is Comparable.
        let bound_var = TypeVar::fresh();
        assert!(solves(vec![(
            concrete("int"),
            constrained(bound_var, TypeConstraint::Comparable),
        )]));
    }

    // ===== Fix 2: Type::Function tests =====

    #[test]
    fn test_function_type_preserves_variables() {
        // BuiltinTypes::function with Variable params must stay a Type::Function.
        let param = Type::Variable(TypeVar::fresh());
        let ret = Type::Variable(TypeVar::fresh());
        let func = BuiltinTypes::function(vec![param.clone()], ret.clone());
        if let Type::Function { params, returns } = func {
            assert_eq!(params.len(), 1);
            assert_eq!(params[0], param);
            assert_eq!(*returns, ret);
        } else {
            panic!("Expected Type::Function, got {:?}", func);
        }
    }

    #[test]
    fn test_function_unification_binds_variables() {
        // (T1) -> T2 ~ (number) -> string binds T1 = number and T2 = string.
        let mut solver = ConstraintSolver::new();
        let t1 = TypeVar::fresh();
        let t2 = TypeVar::fresh();

        let generic_fn = Type::Function {
            params: vec![Type::Variable(t1.clone())],
            returns: Box::new(Type::Variable(t2.clone())),
        };
        let concrete_fn = Type::Function {
            params: vec![BuiltinTypes::number()],
            returns: Box::new(BuiltinTypes::string()),
        };
        let mut constraints = vec![(generic_fn, concrete_fn)];

        solver.solve(&mut constraints).unwrap();

        let resolved_t1 = solver.unifier().apply_substitutions(&Type::Variable(t1));
        let resolved_t2 = solver.unifier().apply_substitutions(&Type::Variable(t2));
        assert_eq!(resolved_t1, BuiltinTypes::number());
        assert_eq!(resolved_t2, BuiltinTypes::string());
    }

    #[test]
    fn test_function_cross_unification_with_concrete() {
        // Type::Function ~ Concrete(TypeAnnotation::Function) must unify.
        let mut solver = ConstraintSolver::new();
        let t1 = TypeVar::fresh();

        let annotated_fn = Type::Concrete(TypeAnnotation::Function {
            params: vec![shape_ast::ast::FunctionParam {
                name: None,
                optional: false,
                type_annotation: basic("number"),
            }],
            returns: Box::new(basic("string")),
        });
        let inferred_fn = Type::Function {
            params: vec![Type::Variable(t1.clone())],
            returns: Box::new(BuiltinTypes::string()),
        };
        let mut constraints = vec![(inferred_fn, annotated_fn)];

        solver.solve(&mut constraints).unwrap();

        let resolved = solver.unifier().apply_substitutions(&Type::Variable(t1));
        assert_eq!(resolved, BuiltinTypes::number());
    }

    #[test]
    fn test_object_annotations_unify_structurally() {
        // Two separately built but identical object annotations unify.
        let point = || Type::Concrete(object(&[("x", "int"), ("y", "int")]));
        assert!(solves(vec![(point(), point())]));
    }

    #[test]
    fn test_intersection_annotations_unify_order_independent() {
        // A & B unifies with B & A.
        let obj_xy = object(&[("x", "int"), ("y", "int")]);
        let obj_z = object(&[("z", "int")]);

        let forward = Type::Concrete(TypeAnnotation::Intersection(vec![
            obj_xy.clone(),
            obj_z.clone(),
        ]));
        let reversed = Type::Concrete(TypeAnnotation::Intersection(vec![obj_z, obj_xy]));
        assert!(solves(vec![(forward, reversed)]));
    }

    // ===== Sprint 2: ImplementsTrait constraint tests =====

    #[test]
    fn test_implements_trait_satisfied() {
        let mut solver = solver_with_impls(&["Comparable::number"]);

        let bound_var = TypeVar::fresh();
        let mut constraints = vec![(
            concrete("number"),
            constrained(bound_var, implements("Comparable")),
        )];
        assert!(solver.solve(&mut constraints).is_ok());
    }

    #[test]
    fn test_implements_trait_violated() {
        // No trait impls are registered, so string cannot implement Comparable.
        let mut solver = ConstraintSolver::new();
        let bound_var = TypeVar::fresh();
        let mut constraints = vec![(
            concrete("string"),
            constrained(bound_var, implements("Comparable")),
        )];
        let result = solver.solve(&mut constraints);
        assert!(result.is_err());
        match result.unwrap_err() {
            TypeError::TraitBoundViolation {
                type_name,
                trait_name,
            } => {
                assert_eq!(type_name, "string");
                assert_eq!(trait_name, "Comparable");
            }
            other => panic!("Expected TraitBoundViolation, got: {:?}", other),
        }
    }

    #[test]
    fn test_implements_trait_via_variable_resolution() {
        let mut solver = solver_with_impls(&["Sortable::number"]);

        let type_var = TypeVar::fresh();
        let bound_var = TypeVar::fresh();

        let mut constraints = vec![
            // T: Sortable
            (
                Type::Variable(type_var.clone()),
                constrained(bound_var, implements("Sortable")),
            ),
            // T = number
            (Type::Variable(type_var), concrete("number")),
        ];
        assert!(
            solver.solve(&mut constraints).is_ok(),
            "T resolved to number which implements Sortable"
        );
    }
}