Skip to main content

tsz_solver/
infer_resolve.rs

1//! Inference resolution, variance analysis, and constraint strengthening.
2//!
3//! This module contains the resolution phase of type inference:
4//! - Constraint-based resolution (upper/lower bounds)
5//! - Candidate filtering and widening
6//! - Variance analysis for type parameters
7//! - Circular constraint unification (SCC/Tarjan)
8//! - Constraint strengthening and propagation
9//! - Variable fixing and substitution building
10
11use crate::infer::{
12    InferenceCandidate, InferenceContext, InferenceError, InferenceInfo, InferenceVar,
13    MAX_CONSTRAINT_ITERATIONS, MAX_TYPE_RECURSION_DEPTH,
14};
15use crate::instantiate::TypeSubstitution;
16use crate::types::{InferencePriority, TemplateSpan, TypeData, TypeId};
17use crate::widening;
18use rustc_hash::FxHashSet;
19use tsz_common::interner::Atom;
20
21struct VarianceState<'a> {
22    target_param: Atom,
23    covariant: &'a mut u32,
24    contravariant: &'a mut u32,
25}
26
27impl<'a> InferenceContext<'a> {
28    // =========================================================================
29    // Bounds Checking and Resolution
30    // =========================================================================
31
32    /// Resolve an inference variable using its collected constraints.
33    ///
34    /// Algorithm:
35    /// 1. If already unified to a concrete type, return that
36    /// 2. Otherwise, compute the best common type from lower bounds
37    /// 3. Validate against upper bounds
38    /// 4. If no lower bounds, use the constraint (upper bound) or default
39    pub fn resolve_with_constraints(
40        &mut self,
41        var: InferenceVar,
42    ) -> Result<TypeId, InferenceError> {
43        // Check if already resolved
44        if let Some(ty) = self.probe(var) {
45            return Ok(ty);
46        }
47
48        let (root, result, upper_bounds, upper_bounds_only) = self.compute_constraint_result(var);
49
50        // Validate against upper bounds
51        if !upper_bounds_only {
52            let filtered_upper_bounds = Self::filter_relevant_upper_bounds(&upper_bounds);
53            if let Some(upper) =
54                self.first_failed_upper_bound(result, &filtered_upper_bounds, |a, b| {
55                    self.is_subtype(a, b)
56                })
57            {
58                return Err(InferenceError::BoundsViolation {
59                    var,
60                    lower: result,
61                    upper,
62                });
63            }
64        }
65
66        if self.occurs_in(root, result) {
67            return Err(InferenceError::OccursCheck {
68                var: root,
69                ty: result,
70            });
71        }
72
73        // Store the result
74        self.table.union_value(
75            root,
76            InferenceInfo {
77                resolved: Some(result),
78                ..InferenceInfo::default()
79            },
80        );
81
82        Ok(result)
83    }
84
85    /// Resolve an inference variable using its collected constraints and a custom
86    /// assignability check for upper-bound validation.
87    pub fn resolve_with_constraints_by<F>(
88        &mut self,
89        var: InferenceVar,
90        is_subtype: F,
91    ) -> Result<TypeId, InferenceError>
92    where
93        F: FnMut(TypeId, TypeId) -> bool,
94    {
95        // Check if already resolved
96        if let Some(ty) = self.probe(var) {
97            return Ok(ty);
98        }
99
100        let (root, result, upper_bounds, upper_bounds_only) = self.compute_constraint_result(var);
101
102        if !upper_bounds_only {
103            let filtered_upper_bounds = Self::filter_relevant_upper_bounds(&upper_bounds);
104            if let Some(upper) =
105                self.first_failed_upper_bound(result, &filtered_upper_bounds, is_subtype)
106            {
107                return Err(InferenceError::BoundsViolation {
108                    var,
109                    lower: result,
110                    upper,
111                });
112            }
113        }
114
115        if self.occurs_in(root, result) {
116            return Err(InferenceError::OccursCheck {
117                var: root,
118                ty: result,
119            });
120        }
121
122        self.table.union_value(
123            root,
124            InferenceInfo {
125                resolved: Some(result),
126                ..InferenceInfo::default()
127            },
128        );
129
130        Ok(result)
131    }
132
133    fn filter_relevant_upper_bounds(upper_bounds: &[TypeId]) -> Vec<TypeId> {
134        upper_bounds
135            .iter()
136            .copied()
137            .filter(|&upper| !matches!(upper, TypeId::ANY | TypeId::UNKNOWN | TypeId::ERROR))
138            .collect()
139    }
140
141    fn first_failed_upper_bound<F>(
142        &self,
143        result: TypeId,
144        filtered_upper_bounds: &[TypeId],
145        mut is_subtype: F,
146    ) -> Option<TypeId>
147    where
148        F: FnMut(TypeId, TypeId) -> bool,
149    {
150        match filtered_upper_bounds {
151            [] => None,
152            [single] => (!is_subtype(result, *single)).then_some(*single),
153            many => {
154                // Building and checking a very large synthetic intersection can be
155                // more expensive than directly validating bounds one-by-one.
156                // Keep the intersection shortcut for small/medium bound sets only.
157                if many.len() <= Self::UPPER_BOUND_INTERSECTION_FAST_PATH_LIMIT {
158                    let intersection = self.interner.intersection(many.to_vec());
159                    if is_subtype(result, intersection) {
160                        return None;
161                    }
162                }
163                // For very large upper-bound sets, a single intersection check can
164                // still be profitable in the common success path (all bounds satisfy).
165                // Fall back to per-bound checks if that coarse check fails.
166                if many.len() >= Self::UPPER_BOUND_INTERSECTION_LARGE_SET_THRESHOLD
167                    && self.should_try_large_upper_bound_intersection(result, many)
168                {
169                    let intersection = self.interner.intersection(many.to_vec());
170                    if is_subtype(result, intersection) {
171                        return None;
172                    }
173                }
174                many.iter()
175                    .copied()
176                    .find(|&upper| !is_subtype(result, upper))
177            }
178        }
179    }
180
181    fn should_try_large_upper_bound_intersection(&self, result: TypeId, bounds: &[TypeId]) -> bool {
182        self.is_object_like_upper_bound(result)
183            && bounds
184                .iter()
185                .copied()
186                .all(|bound| self.is_object_like_upper_bound(bound))
187    }
188
189    fn is_object_like_upper_bound(&self, ty: TypeId) -> bool {
190        match self.interner.lookup(ty) {
191            Some(
192                TypeData::Object(_)
193                | TypeData::ObjectWithIndex(_)
194                | TypeData::Lazy(_)
195                | TypeData::Intersection(_),
196            ) => true,
197            Some(TypeData::TypeParameter(info)) => info
198                .constraint
199                .is_some_and(|constraint| self.is_object_like_upper_bound(constraint)),
200            _ => false,
201        }
202    }
203
204    fn compute_constraint_result(
205        &mut self,
206        var: InferenceVar,
207    ) -> (InferenceVar, TypeId, Vec<TypeId>, bool) {
208        let root = self.table.find(var);
209        let info = self.table.probe_value(root);
210        let target_names = self.type_param_names_for_root(root);
211        let mut upper_bounds = Vec::new();
212        let mut seen_upper_bounds = FxHashSet::default();
213        let mut candidates = info.candidates;
214        for bound in info.upper_bounds {
215            if self.occurs_in(root, bound) {
216                continue;
217            }
218            if !target_names.is_empty() && self.upper_bound_cycles_param(bound, &target_names) {
219                self.expand_cyclic_upper_bound(
220                    root,
221                    bound,
222                    &target_names,
223                    &mut candidates,
224                    &mut upper_bounds,
225                );
226                continue;
227            }
228            if seen_upper_bounds.insert(bound) {
229                upper_bounds.push(bound);
230            }
231        }
232
233        if !upper_bounds.is_empty() {
234            candidates.retain(|candidate| {
235                !matches!(
236                    candidate.type_id,
237                    TypeId::ANY | TypeId::UNKNOWN | TypeId::ERROR
238                )
239            });
240        }
241
242        // Check if this is a const type parameter to preserve literal types
243        let is_const = self.is_var_const(root);
244
245        let upper_bounds_only = candidates.is_empty() && !upper_bounds.is_empty();
246
247        let result = if !candidates.is_empty() {
248            self.resolve_from_candidates(&candidates, is_const, &upper_bounds)
249        } else if !upper_bounds.is_empty() {
250            // RESTORED: Fall back to upper bounds (constraints) when no candidates exist.
251            // This matches TypeScript: un-inferred generics default to their constraint.
252            // We use intersection in case there are multiple upper bounds (T extends A, T extends B).
253            if upper_bounds.len() == 1 {
254                upper_bounds[0]
255            } else {
256                self.interner.intersection(upper_bounds.clone())
257            }
258        } else {
259            // Only return UNKNOWN if there are NO candidates AND NO upper bounds
260            TypeId::UNKNOWN
261        };
262
263        (root, result, upper_bounds, upper_bounds_only)
264    }
265
266    /// Resolve all type parameters using constraints.
267    pub fn resolve_all_with_constraints(&mut self) -> Result<Vec<(Atom, TypeId)>, InferenceError> {
268        // CRITICAL: Strengthen inter-parameter constraints before resolution
269        // This ensures that constraints flow between dependent type parameters
270        // Example: If T extends U, and T is constrained to string, then U is also
271        // constrained to accept string (string must be assignable to U)
272        self.strengthen_constraints()?;
273
274        let type_params: Vec<_> = self.type_params.clone();
275        let mut results = Vec::new();
276
277        for (name, var, _) in type_params {
278            let ty = self.resolve_with_constraints(var)?;
279            results.push((name, ty));
280        }
281
282        Ok(results)
283    }
284
285    fn resolve_from_candidates(
286        &self,
287        candidates: &[InferenceCandidate],
288        is_const: bool,
289        upper_bounds: &[TypeId],
290    ) -> TypeId {
291        let filtered = self.filter_candidates_by_priority(candidates);
292        if filtered.is_empty() {
293            return TypeId::UNKNOWN;
294        }
295        let filtered_no_never: Vec<_> = filtered
296            .iter()
297            .filter(|c| c.type_id != TypeId::NEVER)
298            .cloned()
299            .collect();
300        if filtered_no_never.is_empty() {
301            return TypeId::NEVER;
302        }
303        // TypeScript preserves literal types when the constraint implies literals
304        // (e.g., T extends "a" | "b"). Widening "b" to string would violate the constraint.
305        let preserve_literals = is_const || self.constraint_implies_literals(upper_bounds);
306        let widened = if preserve_literals {
307            if is_const {
308                filtered_no_never
309                    .iter()
310                    .map(|c| widening::apply_const_assertion(self.interner, c.type_id))
311                    .collect()
312            } else {
313                filtered_no_never.iter().map(|c| c.type_id).collect()
314            }
315        } else {
316            self.widen_candidate_types(&filtered_no_never)
317        };
318        self.best_common_type(&widened)
319    }
320
321    /// Check if any upper bounds contain or imply literal types.
322    fn constraint_implies_literals(&self, upper_bounds: &[TypeId]) -> bool {
323        upper_bounds
324            .iter()
325            .any(|&bound| self.type_implies_literals(bound))
326    }
327
328    /// Check if a type contains literal types (directly or in unions/intersections).
329    fn type_implies_literals(&self, type_id: TypeId) -> bool {
330        match self.interner.lookup(type_id) {
331            Some(TypeData::Literal(_)) => true,
332            Some(TypeData::Union(list_id)) => {
333                let members = self.interner.type_list(list_id);
334                members.iter().any(|&m| self.type_implies_literals(m))
335            }
336            Some(TypeData::Intersection(list_id)) => {
337                let members = self.interner.type_list(list_id);
338                members.iter().any(|&m| self.type_implies_literals(m))
339            }
340            _ => false,
341        }
342    }
343
344    /// Filter candidates by priority using `InferencePriority`.
345    ///
346    /// CRITICAL FIX: In the new enum, LOWER values = HIGHER priority (processed earlier).
347    /// - `NakedTypeVariable` (1) is highest priority
348    /// - `ReturnType` (32) is lower priority
349    ///
350    /// Therefore we use `.min()` instead of `.max()` to find the highest priority candidate.
351    fn filter_candidates_by_priority(
352        &self,
353        candidates: &[InferenceCandidate],
354    ) -> Vec<InferenceCandidate> {
355        let Some(best_priority) = candidates.iter().map(|c| c.priority).min() else {
356            return Vec::new();
357        };
358        candidates
359            .iter()
360            .filter(|candidate| candidate.priority == best_priority)
361            .cloned()
362            .collect()
363    }
364
365    fn widen_candidate_types(&self, candidates: &[InferenceCandidate]) -> Vec<TypeId> {
366        candidates
367            .iter()
368            .map(|candidate| {
369                // Always widen fresh literal candidates to their base type.
370                // TypeScript widens fresh literals (0 → number, false → boolean)
371                // during inference resolution. Const type parameters are protected
372                // by the is_const check in resolve_from_candidates which uses
373                // apply_const_assertion instead of this method.
374                if candidate.is_fresh_literal {
375                    self.get_base_type(candidate.type_id)
376                        .unwrap_or(candidate.type_id)
377                } else {
378                    candidate.type_id
379                }
380            })
381            .collect()
382    }
383
384    // =========================================================================
385    // Conditional Type Inference
386    // =========================================================================
387
388    /// Infer type parameters from a conditional type.
389    /// When a type parameter appears in a conditional type, we can sometimes
390    /// infer its value from the check and extends clauses.
391    pub fn infer_from_conditional(
392        &mut self,
393        var: InferenceVar,
394        check_type: TypeId,
395        extends_type: TypeId,
396        true_type: TypeId,
397        false_type: TypeId,
398    ) {
399        // If check_type is an inference variable, try to infer from extends_type
400        if let Some(TypeData::TypeParameter(info)) = self.interner.lookup(check_type)
401            && let Some(check_var) = self.find_type_param(info.name)
402            && check_var == self.table.find(var)
403        {
404            // check_type is this variable
405            // Try to infer from extends_type as an upper bound
406            self.add_upper_bound(var, extends_type);
407        }
408
409        // Recursively infer from true/false branches
410        self.infer_from_type(var, true_type);
411        self.infer_from_type(var, false_type);
412    }
413
414    /// Infer type parameters from a type by traversing its structure.
415    fn infer_from_type(&mut self, var: InferenceVar, ty: TypeId) {
416        let root = self.table.find(var);
417
418        // Check if this type contains the inference variable
419        if !self.contains_inference_var(ty, root) {
420            return;
421        }
422
423        match self.interner.lookup(ty) {
424            Some(TypeData::TypeParameter(info)) => {
425                if let Some(param_var) = self.find_type_param(info.name)
426                    && self.table.find(param_var) == root
427                {
428                    // This type is the inference variable itself
429                    // Extract bounds from constraint if present
430                    if let Some(constraint) = info.constraint {
431                        self.add_upper_bound(var, constraint);
432                    }
433                }
434            }
435            Some(TypeData::Array(elem)) => {
436                self.infer_from_type(var, elem);
437            }
438            Some(TypeData::Tuple(elements)) => {
439                let elements = self.interner.tuple_list(elements);
440                for elem in elements.iter() {
441                    self.infer_from_type(var, elem.type_id);
442                }
443            }
444            Some(TypeData::Union(members) | TypeData::Intersection(members)) => {
445                let members = self.interner.type_list(members);
446                for &member in members.iter() {
447                    self.infer_from_type(var, member);
448                }
449            }
450            Some(TypeData::Object(shape_id)) => {
451                let shape = self.interner.object_shape(shape_id);
452                for prop in &shape.properties {
453                    self.infer_from_type(var, prop.type_id);
454                }
455            }
456            Some(TypeData::ObjectWithIndex(shape_id)) => {
457                let shape = self.interner.object_shape(shape_id);
458                for prop in &shape.properties {
459                    self.infer_from_type(var, prop.type_id);
460                }
461                if let Some(index) = shape.string_index.as_ref() {
462                    self.infer_from_type(var, index.key_type);
463                    self.infer_from_type(var, index.value_type);
464                }
465                if let Some(index) = shape.number_index.as_ref() {
466                    self.infer_from_type(var, index.key_type);
467                    self.infer_from_type(var, index.value_type);
468                }
469            }
470            Some(TypeData::Application(app_id)) => {
471                let app = self.interner.type_application(app_id);
472                self.infer_from_type(var, app.base);
473                for &arg in &app.args {
474                    self.infer_from_type(var, arg);
475                }
476            }
477            Some(TypeData::Function(shape_id)) => {
478                let shape = self.interner.function_shape(shape_id);
479                for param in &shape.params {
480                    self.infer_from_type(var, param.type_id);
481                }
482                if let Some(this_type) = shape.this_type {
483                    self.infer_from_type(var, this_type);
484                }
485                self.infer_from_type(var, shape.return_type);
486            }
487            Some(TypeData::Conditional(cond_id)) => {
488                let cond = self.interner.conditional_type(cond_id);
489                self.infer_from_conditional(
490                    var,
491                    cond.check_type,
492                    cond.extends_type,
493                    cond.true_type,
494                    cond.false_type,
495                );
496            }
497            Some(TypeData::TemplateLiteral(spans)) => {
498                // Traverse template literal spans to find inference variables
499                let spans = self.interner.template_list(spans);
500                for span in spans.iter() {
501                    if let TemplateSpan::Type(inner) = span {
502                        self.infer_from_type(var, *inner);
503                    }
504                }
505            }
506            _ => {}
507        }
508    }
509
510    /// Check if a type contains an inference variable.
511    pub(crate) fn contains_inference_var(&mut self, ty: TypeId, var: InferenceVar) -> bool {
512        let mut visited = FxHashSet::default();
513        self.contains_inference_var_inner(ty, var, &mut visited, 0)
514    }
515
516    fn contains_inference_var_inner(
517        &mut self,
518        ty: TypeId,
519        var: InferenceVar,
520        visited: &mut FxHashSet<TypeId>,
521        depth: usize,
522    ) -> bool {
523        // Safety limit to prevent infinite recursion on deeply nested or cyclic types
524        if depth > MAX_TYPE_RECURSION_DEPTH {
525            return false;
526        }
527        // Prevent infinite loops on cyclic types
528        if !visited.insert(ty) {
529            return false;
530        }
531
532        let root = self.table.find(var);
533
534        match self.interner.lookup(ty) {
535            Some(TypeData::TypeParameter(info) | TypeData::Infer(info)) => {
536                if let Some(param_var) = self.find_type_param(info.name) {
537                    self.table.find(param_var) == root
538                } else {
539                    false
540                }
541            }
542            Some(TypeData::Array(elem)) => {
543                self.contains_inference_var_inner(elem, var, visited, depth + 1)
544            }
545            Some(TypeData::Tuple(elements)) => {
546                let elements = self.interner.tuple_list(elements);
547                elements
548                    .iter()
549                    .any(|e| self.contains_inference_var_inner(e.type_id, var, visited, depth + 1))
550            }
551            Some(TypeData::Union(members) | TypeData::Intersection(members)) => {
552                let members = self.interner.type_list(members);
553                members
554                    .iter()
555                    .any(|&m| self.contains_inference_var_inner(m, var, visited, depth + 1))
556            }
557            Some(TypeData::Object(shape_id)) => {
558                let shape = self.interner.object_shape(shape_id);
559                shape
560                    .properties
561                    .iter()
562                    .any(|p| self.contains_inference_var_inner(p.type_id, var, visited, depth + 1))
563            }
564            Some(TypeData::ObjectWithIndex(shape_id)) => {
565                let shape = self.interner.object_shape(shape_id);
566                shape
567                    .properties
568                    .iter()
569                    .any(|p| self.contains_inference_var_inner(p.type_id, var, visited, depth + 1))
570                    || shape.string_index.as_ref().is_some_and(|idx| {
571                        self.contains_inference_var_inner(idx.key_type, var, visited, depth + 1)
572                            || self.contains_inference_var_inner(
573                                idx.value_type,
574                                var,
575                                visited,
576                                depth + 1,
577                            )
578                    })
579                    || shape.number_index.as_ref().is_some_and(|idx| {
580                        self.contains_inference_var_inner(idx.key_type, var, visited, depth + 1)
581                            || self.contains_inference_var_inner(
582                                idx.value_type,
583                                var,
584                                visited,
585                                depth + 1,
586                            )
587                    })
588            }
589            Some(TypeData::Application(app_id)) => {
590                let app = self.interner.type_application(app_id);
591                self.contains_inference_var_inner(app.base, var, visited, depth + 1)
592                    || app
593                        .args
594                        .iter()
595                        .any(|&arg| self.contains_inference_var_inner(arg, var, visited, depth + 1))
596            }
597            Some(TypeData::Function(shape_id)) => {
598                let shape = self.interner.function_shape(shape_id);
599                shape
600                    .params
601                    .iter()
602                    .any(|p| self.contains_inference_var_inner(p.type_id, var, visited, depth + 1))
603                    || shape.this_type.is_some_and(|t| {
604                        self.contains_inference_var_inner(t, var, visited, depth + 1)
605                    })
606                    || self.contains_inference_var_inner(shape.return_type, var, visited, depth + 1)
607            }
608            Some(TypeData::Conditional(cond_id)) => {
609                let cond = self.interner.conditional_type(cond_id);
610                self.contains_inference_var_inner(cond.check_type, var, visited, depth + 1)
611                    || self.contains_inference_var_inner(cond.extends_type, var, visited, depth + 1)
612                    || self.contains_inference_var_inner(cond.true_type, var, visited, depth + 1)
613                    || self.contains_inference_var_inner(cond.false_type, var, visited, depth + 1)
614            }
615            Some(TypeData::TemplateLiteral(spans)) => {
616                let spans = self.interner.template_list(spans);
617                spans.iter().any(|span| match span {
618                    TemplateSpan::Text(_) => false,
619                    TemplateSpan::Type(inner) => {
620                        self.contains_inference_var_inner(*inner, var, visited, depth + 1)
621                    }
622                })
623            }
624            _ => false,
625        }
626    }
627
628    // =========================================================================
629    // Variance Inference
630    // =========================================================================
631
632    /// Compute the variance of a type parameter within a type.
633    /// Returns (`covariant_count`, `contravariant_count`, `invariant_count`, `bivariant_count`)
634    pub fn compute_variance(&self, ty: TypeId, target_param: Atom) -> (u32, u32, u32, u32) {
635        let mut covariant = 0u32;
636        let mut contravariant = 0u32;
637        let invariant = 0u32;
638        let bivariant = 0u32;
639        let mut state = VarianceState {
640            target_param,
641            covariant: &mut covariant,
642            contravariant: &mut contravariant,
643        };
644
645        self.compute_variance_helper(ty, true, &mut state);
646
647        (covariant, contravariant, invariant, bivariant)
648    }
649
650    fn compute_variance_helper(
651        &self,
652        ty: TypeId,
653        polarity: bool, // true = covariant, false = contravariant
654        state: &mut VarianceState<'_>,
655    ) {
656        match self.interner.lookup(ty) {
657            Some(TypeData::TypeParameter(info)) if info.name == state.target_param => {
658                if polarity {
659                    *state.covariant += 1;
660                } else {
661                    *state.contravariant += 1;
662                }
663            }
664            Some(TypeData::Array(elem)) => {
665                self.compute_variance_helper(elem, polarity, state);
666            }
667            Some(TypeData::Tuple(elements)) => {
668                let elements = self.interner.tuple_list(elements);
669                for elem in elements.iter() {
670                    self.compute_variance_helper(elem.type_id, polarity, state);
671                }
672            }
673            Some(TypeData::Union(members) | TypeData::Intersection(members)) => {
674                let members = self.interner.type_list(members);
675                for &member in members.iter() {
676                    self.compute_variance_helper(member, polarity, state);
677                }
678            }
679            Some(TypeData::Object(shape_id)) => {
680                let shape = self.interner.object_shape(shape_id);
681                for prop in &shape.properties {
682                    // Properties are covariant in their type (read position)
683                    self.compute_variance_helper(prop.type_id, polarity, state);
684                    // Properties are contravariant in their write type (write position)
685                    if prop.write_type != prop.type_id && !prop.readonly {
686                        self.compute_variance_helper(prop.write_type, !polarity, state);
687                    }
688                }
689            }
690            Some(TypeData::ObjectWithIndex(shape_id)) => {
691                let shape = self.interner.object_shape(shape_id);
692                for prop in &shape.properties {
693                    self.compute_variance_helper(prop.type_id, polarity, state);
694                    if prop.write_type != prop.type_id && !prop.readonly {
695                        self.compute_variance_helper(prop.write_type, !polarity, state);
696                    }
697                }
698                if let Some(index) = shape.string_index.as_ref() {
699                    self.compute_variance_helper(index.value_type, polarity, state);
700                }
701                if let Some(index) = shape.number_index.as_ref() {
702                    self.compute_variance_helper(index.value_type, polarity, state);
703                }
704            }
705            Some(TypeData::Application(app_id)) => {
706                let app = self.interner.type_application(app_id);
707                // Variance depends on the generic type definition
708                // For now, assume covariant for all type arguments
709                for &arg in &app.args {
710                    self.compute_variance_helper(arg, polarity, state);
711                }
712            }
713            Some(TypeData::Function(shape_id)) => {
714                let shape = self.interner.function_shape(shape_id);
715                // Parameters are contravariant
716                for param in &shape.params {
717                    self.compute_variance_helper(param.type_id, !polarity, state);
718                }
719                // Return type is covariant
720                self.compute_variance_helper(shape.return_type, polarity, state);
721            }
722            Some(TypeData::Conditional(cond_id)) => {
723                let cond = self.interner.conditional_type(cond_id);
724                // Conditional types are invariant in their type parameters
725                self.compute_variance_helper(cond.check_type, false, state);
726                self.compute_variance_helper(cond.extends_type, false, state);
727                // But can be either in the result
728                self.compute_variance_helper(cond.true_type, polarity, state);
729                self.compute_variance_helper(cond.false_type, polarity, state);
730            }
731            _ => {}
732        }
733    }
734
735    /// Check if a type parameter is invariant at a given position.
736    pub fn is_invariant_position(&self, ty: TypeId, target_param: Atom) -> bool {
737        let (_, _, invariant, _) = self.compute_variance(ty, target_param);
738        invariant > 0
739    }
740
741    /// Check if a type parameter is bivariant at a given position.
742    pub fn is_bivariant_position(&self, ty: TypeId, target_param: Atom) -> bool {
743        let (_, _, _, bivariant) = self.compute_variance(ty, target_param);
744        bivariant > 0
745    }
746
747    /// Get the variance of a type parameter as a string.
748    pub fn get_variance(&self, ty: TypeId, target_param: Atom) -> &'static str {
749        let (covariant, contravariant, invariant, bivariant) =
750            self.compute_variance(ty, target_param);
751
752        if invariant > 0 {
753            "invariant"
754        } else if bivariant > 0 {
755            "bivariant"
756        } else if covariant > 0 && contravariant > 0 {
757            "invariant" // Both covariant and contravariant means invariant
758        } else if covariant > 0 {
759            "covariant"
760        } else if contravariant > 0 {
761            "contravariant"
762        } else {
763            "unused"
764        }
765    }
766
767    // =========================================================================
768    // Enhanced Constraint Resolution
769    // =========================================================================
770
771    /// Try to infer a type parameter from its usage context.
772    /// This implements bidirectional type inference where the context
773    /// (e.g., return type, variable declaration) provides constraints.
774    pub fn infer_from_context(
775        &mut self,
776        var: InferenceVar,
777        context_type: TypeId,
778    ) -> Result<(), InferenceError> {
779        // Add context as an upper bound
780        self.add_upper_bound(var, context_type);
781
782        // If the context type contains this inference variable,
783        // we need to solve more carefully
784        let root = self.table.find(var);
785        if self.contains_inference_var(context_type, root) {
786            // Context contains the inference variable itself
787            // This is a recursive type - we need to handle it specially
788            return Err(InferenceError::OccursCheck {
789                var: root,
790                ty: context_type,
791            });
792        }
793
794        Ok(())
795    }
796
    /// Detect and unify type parameters that form circular constraints.
    /// For example, if T extends U and U extends T, they should be unified
    /// into a single equivalence class for inference purposes.
    ///
    /// The graph's nodes are the union-find roots of this context's type
    /// parameters. An edge `A -> B` exists when one of A's upper bounds is a
    /// *naked* type parameter (not e.g. `List<T>`) of this context whose root
    /// is B. Bounds naming type parameters outside this context are ignored.
    /// Strongly connected components are found with Tarjan's algorithm. Every
    /// member of a component with more than one member is unified with that
    /// component's first member. Single-member components, including a
    /// self-loop such as `T extends T`, are left untouched.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `unify_vars`.
    fn unify_circular_constraints(&mut self) -> Result<(), InferenceError> {
        use rustc_hash::{FxHashMap, FxHashSet};

        let type_params: Vec<_> = self.type_params.clone();

        // Build adjacency list: var -> set of vars it extends (upper bounds)
        let mut graph: FxHashMap<InferenceVar, FxHashSet<InferenceVar>> = FxHashMap::default();
        // Type parameter name -> root variable; used to map `TypeParameter`
        // upper bounds back to inference variables of this context.
        let mut var_for_param: FxHashMap<Atom, InferenceVar> = FxHashMap::default();

        // Register every parameter's root as a node (even with no outgoing
        // edges) so the Tarjan driver below visits it.
        for (name, var, _) in &type_params {
            let root = self.table.find(*var);
            var_for_param.insert(*name, root);
            graph.entry(root).or_default();
        }

        // Populate edges based on upper_bounds
        for (_name, var, _) in &type_params {
            let root = self.table.find(*var);
            let info = self.table.probe_value(root);

            for &upper in &info.upper_bounds {
                // Only follow naked type parameter upper bounds (not List<T>, etc.)
                // Names missing from `var_for_param` belong to another context
                // and produce no edge.
                if let Some(TypeData::TypeParameter(param_info)) = self.interner.lookup(upper)
                    && let Some(&upper_var) = var_for_param.get(&param_info.name)
                {
                    let upper_root = self.table.find(upper_var);
                    // Add edge: root extends upper_root
                    graph.entry(root).or_default().insert(upper_root);
                }
            }
        }

        // Find SCCs using Tarjan's algorithm
        // `indices`: DFS discovery order of each node.
        // `lowlink`: smallest discovery index reachable from the node's DFS
        //            subtree via at most one edge into the current stack.
        // `stack` / `on_stack`: nodes whose SCC has not been emitted yet.
        let mut index_counter = 0;
        let mut indices: FxHashMap<InferenceVar, usize> = FxHashMap::default();
        let mut lowlink: FxHashMap<InferenceVar, usize> = FxHashMap::default();
        let mut stack: Vec<InferenceVar> = Vec::new();
        let mut on_stack: FxHashSet<InferenceVar> = FxHashSet::default();
        let mut sccs: Vec<Vec<InferenceVar>> = Vec::new();

        // Bundles the mutable Tarjan bookkeeping so the recursive helper takes
        // a single `&mut` argument.
        struct TarjanState<'a> {
            graph: &'a FxHashMap<InferenceVar, FxHashSet<InferenceVar>>,
            index_counter: &'a mut usize,
            indices: &'a mut FxHashMap<InferenceVar, usize>,
            lowlink: &'a mut FxHashMap<InferenceVar, usize>,
            stack: &'a mut Vec<InferenceVar>,
            on_stack: &'a mut FxHashSet<InferenceVar>,
            sccs: &'a mut Vec<Vec<InferenceVar>>,
        }

        // Recursive DFS step of Tarjan's algorithm. Recursion depth is bounded
        // by the number of graph nodes (one per distinct type-parameter root).
        fn strongconnect(var: InferenceVar, state: &mut TarjanState) {
            state.indices.insert(var, *state.index_counter);
            state.lowlink.insert(var, *state.index_counter);
            *state.index_counter += 1;
            state.stack.push(var);
            state.on_stack.insert(var);

            if let Some(neighbors) = state.graph.get(&var) {
                for &neighbor in neighbors {
                    if !state.indices.contains_key(&neighbor) {
                        // Unvisited: recurse, then inherit the neighbor's lowlink.
                        strongconnect(neighbor, state);
                        let neighbor_low = *state.lowlink.get(&neighbor).unwrap_or(&0);
                        // `var`'s lowlink was inserted on entry, so this cannot fail.
                        let var_low = state.lowlink.get_mut(&var).unwrap();
                        *var_low = (*var_low).min(neighbor_low);
                    } else if state.on_stack.contains(&neighbor) {
                        // Edge into a node of the SCC still being built.
                        let neighbor_idx = *state.indices.get(&neighbor).unwrap_or(&0);
                        let var_low = state.lowlink.get_mut(&var).unwrap();
                        *var_low = (*var_low).min(neighbor_idx);
                    }
                }
            }

            // `var` roots an SCC: pop the stack down to and including `var`.
            if *state.lowlink.get(&var).unwrap_or(&0) == *state.indices.get(&var).unwrap_or(&0) {
                let mut scc = Vec::new();
                loop {
                    // `var` was pushed above and is still on the stack, so the
                    // stack cannot run empty before `var` is popped.
                    let w = state.stack.pop().unwrap();
                    state.on_stack.remove(&w);
                    scc.push(w);
                    if w == var {
                        break;
                    }
                }
                state.sccs.push(scc);
            }
        }

        // Run Tarjan's on all nodes
        // Each not-yet-visited node starts a fresh DFS; the state struct is
        // rebuilt per call only to reborrow the shared bookkeeping.
        for &var in graph.keys() {
            if !indices.contains_key(&var) {
                let mut state = TarjanState {
                    graph: &graph,
                    index_counter: &mut index_counter,
                    indices: &mut indices,
                    lowlink: &mut lowlink,
                    stack: &mut stack,
                    on_stack: &mut on_stack,
                    sccs: &mut sccs,
                };
                strongconnect(var, &mut state);
            }
        }

        // Unify variables within each SCC (if SCC has >1 member)
        for scc in sccs {
            if scc.len() > 1 {
                // Unify all variables in this SCC
                // `scc[0]` is the first node popped (the most recently pushed
                // member); every other member is unified with it.
                let first = scc[0];
                for &other in &scc[1..] {
                    self.unify_vars(first, other)?;
                }
            }
        }

        Ok(())
    }
915
916    /// Strengthen constraints by analyzing relationships between type parameters.
917    /// For example, if T <: U and we know T = string, then U must be at least string.
918    pub fn strengthen_constraints(&mut self) -> Result<(), InferenceError> {
919        // Detect and unify circular constraints (SCCs)
920        // This ensures that type parameters in cycles (T extends U, U extends T)
921        // are treated as a single equivalence class for inference.
922        self.unify_circular_constraints()?;
923
924        let type_params: Vec<_> = self.type_params.clone();
925        let mut changed = true;
926        let mut iterations = 0;
927
928        // Fixed-point propagation
929        // Iterate to fixed point - continue until no new candidates are added
930        while changed && iterations < MAX_CONSTRAINT_ITERATIONS {
931            changed = false;
932            iterations += 1;
933
934            for (name, var, _) in &type_params {
935                let root = self.table.find(*var);
936
937                // We need to clone info to avoid borrow checker issues while mutating
938                // This is expensive but necessary for correctness in this design
939                let info = self.table.probe_value(root).clone();
940
941                // Propagate candidates UP the extends chain
942                // If T extends U (T <: U), then candidates of T are also candidates of U
943                for &upper in &info.upper_bounds {
944                    if self.propagate_candidates_to_upper(root, upper, *name)? {
945                        changed = true;
946                    }
947                }
948            }
949        }
950        Ok(())
951    }
952
953    /// Propagates candidates from a subtype (var) to its supertype (upper).
954    /// If `var extends upper` (var <: upper), then candidates of `var` are also candidates of `upper`.
955    fn propagate_candidates_to_upper(
956        &mut self,
957        var_root: InferenceVar,
958        upper: TypeId,
959        exclude_param: Atom,
960    ) -> Result<bool, InferenceError> {
961        // Check if 'upper' is a type parameter we are inferring
962        if let Some(TypeData::TypeParameter(info)) = self.interner.lookup(upper)
963            && info.name != exclude_param
964            && let Some(upper_var) = self.find_type_param(info.name)
965        {
966            let upper_root = self.table.find(upper_var);
967
968            // Don't propagate to self
969            if var_root == upper_root {
970                return Ok(false);
971            }
972
973            // Get candidates from the subtype (var)
974            let var_candidates = self.table.probe_value(var_root).candidates;
975
976            // Add them to the supertype (upper)
977            let mut changed = false;
978            for candidate in var_candidates {
979                // Use Circular priority to indicate this came from propagation
980                if self.add_candidate_if_new(
981                    upper_root,
982                    candidate.type_id,
983                    InferencePriority::Circular,
984                ) {
985                    changed = true;
986                }
987            }
988            return Ok(changed);
989        }
990        Ok(false)
991    }
992
993    /// Helper to track if we actually added something (for fixed-point loop)
994    fn add_candidate_if_new(
995        &mut self,
996        var: InferenceVar,
997        ty: TypeId,
998        priority: InferencePriority,
999    ) -> bool {
1000        let root = self.table.find(var);
1001        let info = self.table.probe_value(root);
1002
1003        // Check if type already exists in candidates
1004        if info.candidates.iter().any(|c| c.type_id == ty) {
1005            return false;
1006        }
1007
1008        self.add_candidate(var, ty, priority);
1009        true
1010    }
1011
1012    /// Validate that resolved types respect variance constraints.
1013    pub fn validate_variance(&mut self) -> Result<(), InferenceError> {
1014        let type_params: Vec<_> = self.type_params.clone();
1015        for (_name, var, _) in &type_params {
1016            let resolved = match self.probe(*var) {
1017                Some(ty) => ty,
1018                None => continue,
1019            };
1020
1021            // Check if this type parameter appears in its own resolved type
1022            // We use the occurs_in method which already exists and handles this
1023            if self.occurs_in(*var, resolved) {
1024                let root = self.table.find(*var);
1025                // This would be a circular reference
1026                return Err(InferenceError::OccursCheck {
1027                    var: root,
1028                    ty: resolved,
1029                });
1030            }
1031
1032            // For more advanced variance checking, we would need to know
1033            // the declared variance of each type parameter in its generic type
1034            // This is a placeholder for future enhancement
1035        }
1036
1037        Ok(())
1038    }
1039
1040    /// Fix (resolve) inference variables that have candidates from Round 1.
1041    ///
1042    /// This is called after processing non-contextual arguments to "fix" type
1043    /// variables that have enough information, before processing contextual
1044    /// arguments (like lambdas) in Round 2.
1045    ///
1046    /// The fixing process:
1047    /// 1. Finds variables with candidates but no resolved type yet
1048    /// 2. Computes their best current type from candidates
1049    /// 3. Sets the `resolved` field to prevent Round 2 from overriding
1050    ///
1051    /// Variables without candidates are NOT fixed (they might get info from Round 2).
1052    pub fn fix_current_variables(&mut self) -> Result<(), InferenceError> {
1053        let type_params: Vec<_> = self.type_params.clone();
1054
1055        for (_name, var, _is_const) in &type_params {
1056            let root = self.table.find(*var);
1057            let info = self.table.probe_value(root);
1058
1059            // Skip if already resolved
1060            if info.resolved.is_some() {
1061                continue;
1062            }
1063
1064            // Skip if no candidates yet (might get info from Round 2)
1065            if info.candidates.is_empty() {
1066                continue;
1067            }
1068
1069            // Compute the current best type from existing candidates
1070            // This uses the same logic as compute_constraint_result but doesn't
1071            // validate against upper bounds yet (that happens in final resolution)
1072            let is_const = self.is_var_const(root);
1073            let result =
1074                self.resolve_from_candidates(&info.candidates, is_const, &info.upper_bounds);
1075
1076            // Check for occurs (recursive type)
1077            if self.occurs_in(root, result) {
1078                // Don't fix variables with occurs - let them be resolved later
1079                continue;
1080            }
1081
1082            // Fix this variable by setting resolved field
1083            // This prevents Round 2 from overriding with lower-priority constraints
1084            self.table.union_value(
1085                root,
1086                InferenceInfo {
1087                    resolved: Some(result),
1088                    // Keep candidates and upper_bounds for later validation
1089                    candidates: info.candidates,
1090                    upper_bounds: info.upper_bounds,
1091                },
1092            );
1093        }
1094
1095        Ok(())
1096    }
1097
1098    /// Get the current best substitution for all type parameters.
1099    ///
1100    /// This returns a `TypeSubstitution` mapping each type parameter to its
1101    /// current best type (either resolved or the best candidate so far).
1102    /// Used in Round 2 to provide contextual types to lambda arguments.
1103    pub fn get_current_substitution(&mut self) -> TypeSubstitution {
1104        let mut subst = TypeSubstitution::new();
1105        let type_params: Vec<_> = self.type_params.clone();
1106
1107        for (name, var, _) in &type_params {
1108            let ty = match self.probe(*var) {
1109                Some(resolved) => {
1110                    tracing::trace!(
1111                        ?name,
1112                        ?var,
1113                        ?resolved,
1114                        "get_current_substitution: already resolved"
1115                    );
1116                    resolved
1117                }
1118                None => {
1119                    // Not resolved yet, try to get best candidate
1120                    let root = self.table.find(*var);
1121                    let info = self.table.probe_value(root);
1122                    tracing::trace!(
1123                        ?name, ?var,
1124                        candidates_count = info.candidates.len(),
1125                        upper_bounds_count = info.upper_bounds.len(),
1126                        upper_bounds = ?info.upper_bounds,
1127                        "get_current_substitution: not resolved"
1128                    );
1129
1130                    if !info.candidates.is_empty() {
1131                        let is_const = self.is_var_const(root);
1132                        self.resolve_from_candidates(&info.candidates, is_const, &info.upper_bounds)
1133                    } else if !info.upper_bounds.is_empty() {
1134                        // No candidates yet, but we have a constraint (upper bound).
1135                        // Use the constraint as contextual fallback so that mapped types
1136                        // like `{ [K in keyof P]: P[K] }` resolve using the constraint
1137                        // type. This matches tsc's behavior for contextual typing of
1138                        // generic call arguments when all arguments are context-sensitive.
1139                        if info.upper_bounds.len() == 1 {
1140                            info.upper_bounds[0]
1141                        } else {
1142                            self.interner.intersection(info.upper_bounds.to_vec())
1143                        }
1144                    } else {
1145                        // No info yet, use unknown as placeholder
1146                        TypeId::UNKNOWN
1147                    }
1148                }
1149            };
1150
1151            subst.insert(*name, ty);
1152        }
1153
1154        subst
1155    }
1156}