Skip to main content

tsz_solver/inference/
infer_resolve.rs

1//! Inference resolution, variance analysis, and constraint strengthening.
2//!
3//! This module contains the resolution phase of type inference:
4//! - Constraint-based resolution (upper/lower bounds)
5//! - Candidate filtering and widening
6//! - Variance analysis for type parameters
7//! - Circular constraint unification (SCC/Tarjan)
8//! - Constraint strengthening and propagation
9//! - Variable fixing and substitution building
10
11use crate::inference::infer::{
12    InferenceCandidate, InferenceContext, InferenceError, InferenceInfo, InferenceVar,
13    MAX_CONSTRAINT_ITERATIONS, MAX_TYPE_RECURSION_DEPTH,
14};
15use crate::instantiation::instantiate::TypeSubstitution;
16use crate::operations::widening;
17use crate::types::{InferencePriority, TemplateSpan, TypeData, TypeId};
18use rustc_hash::FxHashSet;
19use tsz_common::interner::Atom;
20
/// Mutable accumulator threaded through the variance traversal.
///
/// Carries the name of the type parameter being analysed together with
/// borrowed counters owned by the caller; the traversal bumps the counters
/// in place each time it meets the target parameter.
struct VarianceState<'a> {
    /// Name of the type parameter whose occurrences are being counted.
    target_param: Atom,
    /// Incremented for every occurrence met while the polarity is covariant.
    covariant: &'a mut u32,
    /// Incremented for every occurrence met while the polarity is contravariant.
    contravariant: &'a mut u32,
}
26
27impl<'a> InferenceContext<'a> {
28    // =========================================================================
29    // Bounds Checking and Resolution
30    // =========================================================================
31
32    /// Resolve an inference variable using its collected constraints.
33    ///
34    /// Algorithm:
35    /// 1. If already unified to a concrete type, return that
36    /// 2. Otherwise, compute the best common type from lower bounds
37    /// 3. Validate against upper bounds
38    /// 4. If no lower bounds, use the constraint (upper bound) or default
39    pub fn resolve_with_constraints(
40        &mut self,
41        var: InferenceVar,
42    ) -> Result<TypeId, InferenceError> {
43        // Check if already resolved
44        if let Some(ty) = self.probe(var) {
45            return Ok(ty);
46        }
47
48        let (root, result, upper_bounds, upper_bounds_only) = self.compute_constraint_result(var);
49
50        // Validate against upper bounds
51        if !upper_bounds_only {
52            let filtered_upper_bounds = Self::filter_relevant_upper_bounds(&upper_bounds);
53            if let Some(upper) =
54                self.first_failed_upper_bound(result, &filtered_upper_bounds, |a, b| {
55                    self.is_subtype(a, b)
56                })
57            {
58                return Err(InferenceError::BoundsViolation {
59                    var,
60                    lower: result,
61                    upper,
62                });
63            }
64        }
65
66        if self.occurs_in(root, result) {
67            return Err(InferenceError::OccursCheck {
68                var: root,
69                ty: result,
70            });
71        }
72
73        // Store the result
74        self.table.union_value(
75            root,
76            InferenceInfo {
77                resolved: Some(result),
78                ..InferenceInfo::default()
79            },
80        );
81
82        Ok(result)
83    }
84
85    /// Resolve an inference variable using its collected constraints and a custom
86    /// assignability check for upper-bound validation.
87    pub fn resolve_with_constraints_by<F>(
88        &mut self,
89        var: InferenceVar,
90        is_subtype: F,
91    ) -> Result<TypeId, InferenceError>
92    where
93        F: FnMut(TypeId, TypeId) -> bool,
94    {
95        // Check if already resolved
96        if let Some(ty) = self.probe(var) {
97            return Ok(ty);
98        }
99
100        let (root, result, upper_bounds, upper_bounds_only) = self.compute_constraint_result(var);
101
102        if !upper_bounds_only {
103            let filtered_upper_bounds = Self::filter_relevant_upper_bounds(&upper_bounds);
104            if let Some(upper) =
105                self.first_failed_upper_bound(result, &filtered_upper_bounds, is_subtype)
106            {
107                return Err(InferenceError::BoundsViolation {
108                    var,
109                    lower: result,
110                    upper,
111                });
112            }
113        }
114
115        if self.occurs_in(root, result) {
116            return Err(InferenceError::OccursCheck {
117                var: root,
118                ty: result,
119            });
120        }
121
122        self.table.union_value(
123            root,
124            InferenceInfo {
125                resolved: Some(result),
126                ..InferenceInfo::default()
127            },
128        );
129
130        Ok(result)
131    }
132
133    fn filter_relevant_upper_bounds(upper_bounds: &[TypeId]) -> Vec<TypeId> {
134        upper_bounds
135            .iter()
136            .copied()
137            .filter(|&upper| !matches!(upper, TypeId::ANY | TypeId::UNKNOWN | TypeId::ERROR))
138            .collect()
139    }
140
141    fn first_failed_upper_bound<F>(
142        &self,
143        result: TypeId,
144        filtered_upper_bounds: &[TypeId],
145        mut is_subtype: F,
146    ) -> Option<TypeId>
147    where
148        F: FnMut(TypeId, TypeId) -> bool,
149    {
150        match filtered_upper_bounds {
151            [] => None,
152            [single] => (!is_subtype(result, *single)).then_some(*single),
153            many => {
154                // Building and checking a very large synthetic intersection can be
155                // more expensive than directly validating bounds one-by-one.
156                // Keep the intersection shortcut for small/medium bound sets only.
157                if many.len() <= Self::UPPER_BOUND_INTERSECTION_FAST_PATH_LIMIT {
158                    let intersection = self.interner.intersection(many.to_vec());
159                    if is_subtype(result, intersection) {
160                        return None;
161                    }
162                }
163                // For very large upper-bound sets, a single intersection check can
164                // still be profitable in the common success path (all bounds satisfy).
165                // Fall back to per-bound checks if that coarse check fails.
166                if many.len() >= Self::UPPER_BOUND_INTERSECTION_LARGE_SET_THRESHOLD
167                    && self.should_try_large_upper_bound_intersection(result, many)
168                {
169                    let intersection = self.interner.intersection(many.to_vec());
170                    if is_subtype(result, intersection) {
171                        return None;
172                    }
173                }
174                many.iter()
175                    .copied()
176                    .find(|&upper| !is_subtype(result, upper))
177            }
178        }
179    }
180
181    fn should_try_large_upper_bound_intersection(&self, result: TypeId, bounds: &[TypeId]) -> bool {
182        self.is_object_like_upper_bound(result)
183            && bounds
184                .iter()
185                .copied()
186                .all(|bound| self.is_object_like_upper_bound(bound))
187    }
188
189    fn is_object_like_upper_bound(&self, ty: TypeId) -> bool {
190        match self.interner.lookup(ty) {
191            Some(
192                TypeData::Object(_)
193                | TypeData::ObjectWithIndex(_)
194                | TypeData::Lazy(_)
195                | TypeData::Intersection(_),
196            ) => true,
197            Some(TypeData::TypeParameter(info)) => info
198                .constraint
199                .is_some_and(|constraint| self.is_object_like_upper_bound(constraint)),
200            _ => false,
201        }
202    }
203
204    fn compute_constraint_result(
205        &mut self,
206        var: InferenceVar,
207    ) -> (InferenceVar, TypeId, Vec<TypeId>, bool) {
208        let root = self.table.find(var);
209        let info = self.table.probe_value(root);
210        let target_names = self.type_param_names_for_root(root);
211        let mut upper_bounds = Vec::new();
212        let mut seen_upper_bounds = FxHashSet::default();
213        let mut candidates = info.candidates;
214        for bound in info.upper_bounds {
215            if self.occurs_in(root, bound) {
216                continue;
217            }
218            if !target_names.is_empty() && self.upper_bound_cycles_param(bound, &target_names) {
219                self.expand_cyclic_upper_bound(
220                    root,
221                    bound,
222                    &target_names,
223                    &mut candidates,
224                    &mut upper_bounds,
225                );
226                continue;
227            }
228            if seen_upper_bounds.insert(bound) {
229                upper_bounds.push(bound);
230            }
231        }
232
233        if !upper_bounds.is_empty() {
234            candidates.retain(|candidate| {
235                !matches!(
236                    candidate.type_id,
237                    TypeId::ANY | TypeId::UNKNOWN | TypeId::ERROR
238                )
239            });
240        }
241
242        // Check if this is a const type parameter to preserve literal types
243        let is_const = self.is_var_const(root);
244
245        let upper_bounds_only = candidates.is_empty() && !upper_bounds.is_empty();
246
247        let result = if !candidates.is_empty() {
248            self.resolve_from_candidates(&candidates, is_const, &upper_bounds)
249        } else if !upper_bounds.is_empty() {
250            // RESTORED: Fall back to upper bounds (constraints) when no candidates exist.
251            // This matches TypeScript: un-inferred generics default to their constraint.
252            // We use intersection in case there are multiple upper bounds (T extends A, T extends B).
253            if upper_bounds.len() == 1 {
254                upper_bounds[0]
255            } else {
256                self.interner.intersection(upper_bounds.clone())
257            }
258        } else {
259            // Only return UNKNOWN if there are NO candidates AND NO upper bounds
260            TypeId::UNKNOWN
261        };
262
263        (root, result, upper_bounds, upper_bounds_only)
264    }
265
266    /// Resolve all type parameters using constraints.
267    pub fn resolve_all_with_constraints(&mut self) -> Result<Vec<(Atom, TypeId)>, InferenceError> {
268        // CRITICAL: Strengthen inter-parameter constraints before resolution
269        // This ensures that constraints flow between dependent type parameters
270        // Example: If T extends U, and T is constrained to string, then U is also
271        // constrained to accept string (string must be assignable to U)
272        self.strengthen_constraints()?;
273
274        let type_params: Vec<_> = self.type_params.clone();
275        let mut results = Vec::new();
276
277        for (name, var, _) in type_params {
278            let ty = self.resolve_with_constraints(var)?;
279            results.push((name, ty));
280        }
281
282        Ok(results)
283    }
284
285    fn resolve_from_candidates(
286        &self,
287        candidates: &[InferenceCandidate],
288        is_const: bool,
289        upper_bounds: &[TypeId],
290    ) -> TypeId {
291        let filtered = self.filter_candidates_by_priority(candidates);
292        if filtered.is_empty() {
293            return TypeId::UNKNOWN;
294        }
295        let filtered_no_never: Vec<_> = filtered
296            .iter()
297            .filter(|c| c.type_id != TypeId::NEVER)
298            .cloned()
299            .collect();
300        if filtered_no_never.is_empty() {
301            return TypeId::NEVER;
302        }
303        let all_from_object_properties = filtered_no_never
304            .iter()
305            .all(|candidate| candidate.from_object_property);
306        // TypeScript preserves literal types when the constraint implies literals
307        // (e.g., T extends "a" | "b"). Widening "b" to string would violate the constraint.
308        let preserve_literals = is_const || self.constraint_implies_literals(upper_bounds);
309        let widened = if preserve_literals {
310            if is_const {
311                filtered_no_never
312                    .iter()
313                    .map(|c| widening::apply_const_assertion(self.interner, c.type_id))
314                    .collect()
315            } else {
316                filtered_no_never.iter().map(|c| c.type_id).collect()
317            }
318        } else {
319            self.widen_candidate_types(&filtered_no_never)
320        };
321        let resolved = self.best_common_type(&widened);
322        if all_from_object_properties
323            && let Some(TypeData::Union(member_list_id)) = self.interner.lookup(resolved)
324        {
325            let member_count = self.interner.type_list(member_list_id).len();
326            if member_count > 1 {
327                let mut first_property_name = None;
328                let mut has_multiple_property_names = false;
329                for candidate in &filtered_no_never {
330                    if let Some(name) = candidate.object_property_name {
331                        if let Some(prev_name) = first_property_name {
332                            if prev_name != name {
333                                has_multiple_property_names = true;
334                                break;
335                            }
336                        } else {
337                            first_property_name = Some(name);
338                        }
339                    } else {
340                        has_multiple_property_names = false;
341                        break;
342                    }
343                }
344
345                if !has_multiple_property_names {
346                    return resolved;
347                }
348
349                if let Some(fallback_idx) = filtered_no_never
350                    .iter()
351                    .enumerate()
352                    .filter_map(|(idx, candidate)| {
353                        candidate.object_property_name.map(|name| {
354                            (
355                                self.interner.resolve_atom_ref(name),
356                                candidate.object_property_index.unwrap_or(u32::MAX),
357                                idx,
358                            )
359                        })
360                    })
361                    .min_by(|(name_l, index_l, _), (name_r, index_r, _)| {
362                        name_l.cmp(name_r).then_with(|| index_l.cmp(index_r))
363                    })
364                    .map(|(_, _, fallback_idx)| fallback_idx)
365                {
366                    return widened[fallback_idx];
367                }
368                return widened[0];
369            }
370        }
371        resolved
372    }
373
374    /// Check if any upper bounds contain or imply literal types.
375    fn constraint_implies_literals(&self, upper_bounds: &[TypeId]) -> bool {
376        upper_bounds
377            .iter()
378            .any(|&bound| self.type_implies_literals(bound))
379    }
380
381    /// Check if a type contains literal types (directly or in unions/intersections).
382    fn type_implies_literals(&self, type_id: TypeId) -> bool {
383        match self.interner.lookup(type_id) {
384            Some(TypeData::Literal(_)) => true,
385            Some(TypeData::Union(list_id)) => {
386                let members = self.interner.type_list(list_id);
387                members.iter().any(|&m| self.type_implies_literals(m))
388            }
389            Some(TypeData::Intersection(list_id)) => {
390                let members = self.interner.type_list(list_id);
391                members.iter().any(|&m| self.type_implies_literals(m))
392            }
393            _ => false,
394        }
395    }
396
397    /// Filter candidates by priority using `InferencePriority`.
398    ///
399    /// CRITICAL FIX: In the new enum, LOWER values = HIGHER priority (processed earlier).
400    /// - `NakedTypeVariable` (1) is highest priority
401    /// - `ReturnType` (32) is lower priority
402    ///
403    /// Therefore we use `.min()` instead of `.max()` to find the highest priority candidate.
404    fn filter_candidates_by_priority(
405        &self,
406        candidates: &[InferenceCandidate],
407    ) -> Vec<InferenceCandidate> {
408        let Some(best_priority) = candidates.iter().map(|c| c.priority).min() else {
409            return Vec::new();
410        };
411        candidates
412            .iter()
413            .filter(|candidate| candidate.priority == best_priority)
414            .cloned()
415            .collect()
416    }
417
418    fn widen_candidate_types(&self, candidates: &[InferenceCandidate]) -> Vec<TypeId> {
419        candidates
420            .iter()
421            .map(|candidate| {
422                // Always widen fresh literal candidates to their base type.
423                // TypeScript widens fresh literals (0 → number, false → boolean)
424                // during inference resolution. Const type parameters are protected
425                // by the is_const check in resolve_from_candidates which uses
426                // apply_const_assertion instead of this method.
427                if candidate.is_fresh_literal {
428                    self.get_base_type(candidate.type_id)
429                        .unwrap_or(candidate.type_id)
430                } else {
431                    candidate.type_id
432                }
433            })
434            .collect()
435    }
436
437    // =========================================================================
438    // Conditional Type Inference
439    // =========================================================================
440
441    /// Infer type parameters from a conditional type.
442    /// When a type parameter appears in a conditional type, we can sometimes
443    /// infer its value from the check and extends clauses.
444    pub fn infer_from_conditional(
445        &mut self,
446        var: InferenceVar,
447        check_type: TypeId,
448        extends_type: TypeId,
449        true_type: TypeId,
450        false_type: TypeId,
451    ) {
452        // If check_type is an inference variable, try to infer from extends_type
453        if let Some(TypeData::TypeParameter(info)) = self.interner.lookup(check_type)
454            && let Some(check_var) = self.find_type_param(info.name)
455            && check_var == self.table.find(var)
456        {
457            // check_type is this variable
458            // Try to infer from extends_type as an upper bound
459            self.add_upper_bound(var, extends_type);
460        }
461
462        // Recursively infer from true/false branches
463        self.infer_from_type(var, true_type);
464        self.infer_from_type(var, false_type);
465    }
466
467    /// Infer type parameters from a type by traversing its structure.
468    fn infer_from_type(&mut self, var: InferenceVar, ty: TypeId) {
469        let root = self.table.find(var);
470
471        // Check if this type contains the inference variable
472        if !self.contains_inference_var(ty, root) {
473            return;
474        }
475
476        match self.interner.lookup(ty) {
477            Some(TypeData::TypeParameter(info)) => {
478                if let Some(param_var) = self.find_type_param(info.name)
479                    && self.table.find(param_var) == root
480                {
481                    // This type is the inference variable itself
482                    // Extract bounds from constraint if present
483                    if let Some(constraint) = info.constraint {
484                        self.add_upper_bound(var, constraint);
485                    }
486                }
487            }
488            Some(TypeData::Array(elem)) => {
489                self.infer_from_type(var, elem);
490            }
491            Some(TypeData::Tuple(elements)) => {
492                let elements = self.interner.tuple_list(elements);
493                for elem in elements.iter() {
494                    self.infer_from_type(var, elem.type_id);
495                }
496            }
497            Some(TypeData::Union(members) | TypeData::Intersection(members)) => {
498                let members = self.interner.type_list(members);
499                for &member in members.iter() {
500                    self.infer_from_type(var, member);
501                }
502            }
503            Some(TypeData::Object(shape_id)) => {
504                let shape = self.interner.object_shape(shape_id);
505                for prop in &shape.properties {
506                    self.infer_from_type(var, prop.type_id);
507                }
508            }
509            Some(TypeData::ObjectWithIndex(shape_id)) => {
510                let shape = self.interner.object_shape(shape_id);
511                for prop in &shape.properties {
512                    self.infer_from_type(var, prop.type_id);
513                }
514                if let Some(index) = shape.string_index.as_ref() {
515                    self.infer_from_type(var, index.key_type);
516                    self.infer_from_type(var, index.value_type);
517                }
518                if let Some(index) = shape.number_index.as_ref() {
519                    self.infer_from_type(var, index.key_type);
520                    self.infer_from_type(var, index.value_type);
521                }
522            }
523            Some(TypeData::Application(app_id)) => {
524                let app = self.interner.type_application(app_id);
525                self.infer_from_type(var, app.base);
526                for &arg in &app.args {
527                    self.infer_from_type(var, arg);
528                }
529            }
530            Some(TypeData::Function(shape_id)) => {
531                let shape = self.interner.function_shape(shape_id);
532                for param in &shape.params {
533                    self.infer_from_type(var, param.type_id);
534                }
535                if let Some(this_type) = shape.this_type {
536                    self.infer_from_type(var, this_type);
537                }
538                self.infer_from_type(var, shape.return_type);
539            }
540            Some(TypeData::Conditional(cond_id)) => {
541                let cond = self.interner.conditional_type(cond_id);
542                self.infer_from_conditional(
543                    var,
544                    cond.check_type,
545                    cond.extends_type,
546                    cond.true_type,
547                    cond.false_type,
548                );
549            }
550            Some(TypeData::TemplateLiteral(spans)) => {
551                // Traverse template literal spans to find inference variables
552                let spans = self.interner.template_list(spans);
553                for span in spans.iter() {
554                    if let TemplateSpan::Type(inner) = span {
555                        self.infer_from_type(var, *inner);
556                    }
557                }
558            }
559            _ => {}
560        }
561    }
562
563    /// Check if a type contains an inference variable.
564    pub(crate) fn contains_inference_var(&mut self, ty: TypeId, var: InferenceVar) -> bool {
565        let mut visited = FxHashSet::default();
566        self.contains_inference_var_inner(ty, var, &mut visited, 0)
567    }
568
569    fn contains_inference_var_inner(
570        &mut self,
571        ty: TypeId,
572        var: InferenceVar,
573        visited: &mut FxHashSet<TypeId>,
574        depth: usize,
575    ) -> bool {
576        // Safety limit to prevent infinite recursion on deeply nested or cyclic types
577        if depth > MAX_TYPE_RECURSION_DEPTH {
578            return false;
579        }
580        // Prevent infinite loops on cyclic types
581        if !visited.insert(ty) {
582            return false;
583        }
584
585        let root = self.table.find(var);
586
587        match self.interner.lookup(ty) {
588            Some(TypeData::TypeParameter(info) | TypeData::Infer(info)) => {
589                if let Some(param_var) = self.find_type_param(info.name) {
590                    self.table.find(param_var) == root
591                } else {
592                    false
593                }
594            }
595            Some(TypeData::Array(elem)) => {
596                self.contains_inference_var_inner(elem, var, visited, depth + 1)
597            }
598            Some(TypeData::Tuple(elements)) => {
599                let elements = self.interner.tuple_list(elements);
600                elements
601                    .iter()
602                    .any(|e| self.contains_inference_var_inner(e.type_id, var, visited, depth + 1))
603            }
604            Some(TypeData::Union(members) | TypeData::Intersection(members)) => {
605                let members = self.interner.type_list(members);
606                members
607                    .iter()
608                    .any(|&m| self.contains_inference_var_inner(m, var, visited, depth + 1))
609            }
610            Some(TypeData::Object(shape_id)) => {
611                let shape = self.interner.object_shape(shape_id);
612                shape
613                    .properties
614                    .iter()
615                    .any(|p| self.contains_inference_var_inner(p.type_id, var, visited, depth + 1))
616            }
617            Some(TypeData::ObjectWithIndex(shape_id)) => {
618                let shape = self.interner.object_shape(shape_id);
619                shape
620                    .properties
621                    .iter()
622                    .any(|p| self.contains_inference_var_inner(p.type_id, var, visited, depth + 1))
623                    || shape.string_index.as_ref().is_some_and(|idx| {
624                        self.contains_inference_var_inner(idx.key_type, var, visited, depth + 1)
625                            || self.contains_inference_var_inner(
626                                idx.value_type,
627                                var,
628                                visited,
629                                depth + 1,
630                            )
631                    })
632                    || shape.number_index.as_ref().is_some_and(|idx| {
633                        self.contains_inference_var_inner(idx.key_type, var, visited, depth + 1)
634                            || self.contains_inference_var_inner(
635                                idx.value_type,
636                                var,
637                                visited,
638                                depth + 1,
639                            )
640                    })
641            }
642            Some(TypeData::Application(app_id)) => {
643                let app = self.interner.type_application(app_id);
644                self.contains_inference_var_inner(app.base, var, visited, depth + 1)
645                    || app
646                        .args
647                        .iter()
648                        .any(|&arg| self.contains_inference_var_inner(arg, var, visited, depth + 1))
649            }
650            Some(TypeData::Function(shape_id)) => {
651                let shape = self.interner.function_shape(shape_id);
652                shape
653                    .params
654                    .iter()
655                    .any(|p| self.contains_inference_var_inner(p.type_id, var, visited, depth + 1))
656                    || shape.this_type.is_some_and(|t| {
657                        self.contains_inference_var_inner(t, var, visited, depth + 1)
658                    })
659                    || self.contains_inference_var_inner(shape.return_type, var, visited, depth + 1)
660            }
661            Some(TypeData::Conditional(cond_id)) => {
662                let cond = self.interner.conditional_type(cond_id);
663                self.contains_inference_var_inner(cond.check_type, var, visited, depth + 1)
664                    || self.contains_inference_var_inner(cond.extends_type, var, visited, depth + 1)
665                    || self.contains_inference_var_inner(cond.true_type, var, visited, depth + 1)
666                    || self.contains_inference_var_inner(cond.false_type, var, visited, depth + 1)
667            }
668            Some(TypeData::TemplateLiteral(spans)) => {
669                let spans = self.interner.template_list(spans);
670                spans.iter().any(|span| match span {
671                    TemplateSpan::Text(_) => false,
672                    TemplateSpan::Type(inner) => {
673                        self.contains_inference_var_inner(*inner, var, visited, depth + 1)
674                    }
675                })
676            }
677            _ => false,
678        }
679    }
680
681    // =========================================================================
682    // Variance Inference
683    // =========================================================================
684
685    /// Compute the variance of a type parameter within a type.
686    /// Returns (`covariant_count`, `contravariant_count`, `invariant_count`, `bivariant_count`)
687    pub fn compute_variance(&self, ty: TypeId, target_param: Atom) -> (u32, u32, u32, u32) {
688        let mut covariant = 0u32;
689        let mut contravariant = 0u32;
690        let invariant = 0u32;
691        let bivariant = 0u32;
692        let mut state = VarianceState {
693            target_param,
694            covariant: &mut covariant,
695            contravariant: &mut contravariant,
696        };
697
698        self.compute_variance_helper(ty, true, &mut state);
699
700        (covariant, contravariant, invariant, bivariant)
701    }
702
703    fn compute_variance_helper(
704        &self,
705        ty: TypeId,
706        polarity: bool, // true = covariant, false = contravariant
707        state: &mut VarianceState<'_>,
708    ) {
709        match self.interner.lookup(ty) {
710            Some(TypeData::TypeParameter(info)) if info.name == state.target_param => {
711                if polarity {
712                    *state.covariant += 1;
713                } else {
714                    *state.contravariant += 1;
715                }
716            }
717            Some(TypeData::Array(elem)) => {
718                self.compute_variance_helper(elem, polarity, state);
719            }
720            Some(TypeData::Tuple(elements)) => {
721                let elements = self.interner.tuple_list(elements);
722                for elem in elements.iter() {
723                    self.compute_variance_helper(elem.type_id, polarity, state);
724                }
725            }
726            Some(TypeData::Union(members) | TypeData::Intersection(members)) => {
727                let members = self.interner.type_list(members);
728                for &member in members.iter() {
729                    self.compute_variance_helper(member, polarity, state);
730                }
731            }
732            Some(TypeData::Object(shape_id)) => {
733                let shape = self.interner.object_shape(shape_id);
734                for prop in &shape.properties {
735                    // Properties are covariant in their type (read position)
736                    self.compute_variance_helper(prop.type_id, polarity, state);
737                    // Properties are contravariant in their write type (write position)
738                    if prop.write_type != prop.type_id && !prop.readonly {
739                        self.compute_variance_helper(prop.write_type, !polarity, state);
740                    }
741                }
742            }
743            Some(TypeData::ObjectWithIndex(shape_id)) => {
744                let shape = self.interner.object_shape(shape_id);
745                for prop in &shape.properties {
746                    self.compute_variance_helper(prop.type_id, polarity, state);
747                    if prop.write_type != prop.type_id && !prop.readonly {
748                        self.compute_variance_helper(prop.write_type, !polarity, state);
749                    }
750                }
751                if let Some(index) = shape.string_index.as_ref() {
752                    self.compute_variance_helper(index.value_type, polarity, state);
753                }
754                if let Some(index) = shape.number_index.as_ref() {
755                    self.compute_variance_helper(index.value_type, polarity, state);
756                }
757            }
758            Some(TypeData::Application(app_id)) => {
759                let app = self.interner.type_application(app_id);
760                // Variance depends on the generic type definition
761                // For now, assume covariant for all type arguments
762                for &arg in &app.args {
763                    self.compute_variance_helper(arg, polarity, state);
764                }
765            }
766            Some(TypeData::Function(shape_id)) => {
767                let shape = self.interner.function_shape(shape_id);
768                // Parameters are contravariant
769                for param in &shape.params {
770                    self.compute_variance_helper(param.type_id, !polarity, state);
771                }
772                // Return type is covariant
773                self.compute_variance_helper(shape.return_type, polarity, state);
774            }
775            Some(TypeData::Conditional(cond_id)) => {
776                let cond = self.interner.conditional_type(cond_id);
777                // Conditional types are invariant in their type parameters
778                self.compute_variance_helper(cond.check_type, false, state);
779                self.compute_variance_helper(cond.extends_type, false, state);
780                // But can be either in the result
781                self.compute_variance_helper(cond.true_type, polarity, state);
782                self.compute_variance_helper(cond.false_type, polarity, state);
783            }
784            _ => {}
785        }
786    }
787
788    /// Check if a type parameter is invariant at a given position.
789    pub fn is_invariant_position(&self, ty: TypeId, target_param: Atom) -> bool {
790        let (_, _, invariant, _) = self.compute_variance(ty, target_param);
791        invariant > 0
792    }
793
794    /// Check if a type parameter is bivariant at a given position.
795    pub fn is_bivariant_position(&self, ty: TypeId, target_param: Atom) -> bool {
796        let (_, _, _, bivariant) = self.compute_variance(ty, target_param);
797        bivariant > 0
798    }
799
800    /// Get the variance of a type parameter as a string.
801    pub fn get_variance(&self, ty: TypeId, target_param: Atom) -> &'static str {
802        let (covariant, contravariant, invariant, bivariant) =
803            self.compute_variance(ty, target_param);
804
805        if invariant > 0 {
806            "invariant"
807        } else if bivariant > 0 {
808            "bivariant"
809        } else if covariant > 0 && contravariant > 0 {
810            "invariant" // Both covariant and contravariant means invariant
811        } else if covariant > 0 {
812            "covariant"
813        } else if contravariant > 0 {
814            "contravariant"
815        } else {
816            "unused"
817        }
818    }
819
820    // =========================================================================
821    // Enhanced Constraint Resolution
822    // =========================================================================
823
824    /// Try to infer a type parameter from its usage context.
825    /// This implements bidirectional type inference where the context
826    /// (e.g., return type, variable declaration) provides constraints.
827    pub fn infer_from_context(
828        &mut self,
829        var: InferenceVar,
830        context_type: TypeId,
831    ) -> Result<(), InferenceError> {
832        // Add context as an upper bound
833        self.add_upper_bound(var, context_type);
834
835        // If the context type contains this inference variable,
836        // we need to solve more carefully
837        let root = self.table.find(var);
838        if self.contains_inference_var(context_type, root) {
839            // Context contains the inference variable itself
840            // This is a recursive type - we need to handle it specially
841            return Err(InferenceError::OccursCheck {
842                var: root,
843                ty: context_type,
844            });
845        }
846
847        Ok(())
848    }
849
850    /// Detect and unify type parameters that form circular constraints.
851    /// For example, if T extends U and U extends T, they should be unified
852    /// into a single equivalence class for inference purposes.
853    fn unify_circular_constraints(&mut self) -> Result<(), InferenceError> {
854        use rustc_hash::{FxHashMap, FxHashSet};
855
856        let type_params: Vec<_> = self.type_params.clone();
857
858        // Build adjacency list: var -> set of vars it extends (upper bounds)
859        let mut graph: FxHashMap<InferenceVar, FxHashSet<InferenceVar>> = FxHashMap::default();
860        let mut var_for_param: FxHashMap<Atom, InferenceVar> = FxHashMap::default();
861
862        for (name, var, _) in &type_params {
863            let root = self.table.find(*var);
864            var_for_param.insert(*name, root);
865            graph.entry(root).or_default();
866        }
867
868        // Populate edges based on upper_bounds
869        for (_name, var, _) in &type_params {
870            let root = self.table.find(*var);
871            let info = self.table.probe_value(root);
872
873            for &upper in &info.upper_bounds {
874                // Only follow naked type parameter upper bounds (not List<T>, etc.)
875                if let Some(TypeData::TypeParameter(param_info)) = self.interner.lookup(upper)
876                    && let Some(&upper_var) = var_for_param.get(&param_info.name)
877                {
878                    let upper_root = self.table.find(upper_var);
879                    // Add edge: root extends upper_root
880                    graph.entry(root).or_default().insert(upper_root);
881                }
882            }
883        }
884
885        // Find SCCs using Tarjan's algorithm
886        let mut index_counter = 0;
887        let mut indices: FxHashMap<InferenceVar, usize> = FxHashMap::default();
888        let mut lowlink: FxHashMap<InferenceVar, usize> = FxHashMap::default();
889        let mut stack: Vec<InferenceVar> = Vec::new();
890        let mut on_stack: FxHashSet<InferenceVar> = FxHashSet::default();
891        let mut sccs: Vec<Vec<InferenceVar>> = Vec::new();
892
893        struct TarjanState<'a> {
894            graph: &'a FxHashMap<InferenceVar, FxHashSet<InferenceVar>>,
895            index_counter: &'a mut usize,
896            indices: &'a mut FxHashMap<InferenceVar, usize>,
897            lowlink: &'a mut FxHashMap<InferenceVar, usize>,
898            stack: &'a mut Vec<InferenceVar>,
899            on_stack: &'a mut FxHashSet<InferenceVar>,
900            sccs: &'a mut Vec<Vec<InferenceVar>>,
901        }
902
903        fn strongconnect(var: InferenceVar, state: &mut TarjanState) {
904            state.indices.insert(var, *state.index_counter);
905            state.lowlink.insert(var, *state.index_counter);
906            *state.index_counter += 1;
907            state.stack.push(var);
908            state.on_stack.insert(var);
909
910            if let Some(neighbors) = state.graph.get(&var) {
911                for &neighbor in neighbors {
912                    if !state.indices.contains_key(&neighbor) {
913                        strongconnect(neighbor, state);
914                        let neighbor_low = *state.lowlink.get(&neighbor).unwrap_or(&0);
915                        let var_low = state.lowlink.get_mut(&var).unwrap();
916                        *var_low = (*var_low).min(neighbor_low);
917                    } else if state.on_stack.contains(&neighbor) {
918                        let neighbor_idx = *state.indices.get(&neighbor).unwrap_or(&0);
919                        let var_low = state.lowlink.get_mut(&var).unwrap();
920                        *var_low = (*var_low).min(neighbor_idx);
921                    }
922                }
923            }
924
925            if *state.lowlink.get(&var).unwrap_or(&0) == *state.indices.get(&var).unwrap_or(&0) {
926                let mut scc = Vec::new();
927                loop {
928                    let w = state.stack.pop().unwrap();
929                    state.on_stack.remove(&w);
930                    scc.push(w);
931                    if w == var {
932                        break;
933                    }
934                }
935                state.sccs.push(scc);
936            }
937        }
938
939        // Run Tarjan's on all nodes
940        for &var in graph.keys() {
941            if !indices.contains_key(&var) {
942                let mut state = TarjanState {
943                    graph: &graph,
944                    index_counter: &mut index_counter,
945                    indices: &mut indices,
946                    lowlink: &mut lowlink,
947                    stack: &mut stack,
948                    on_stack: &mut on_stack,
949                    sccs: &mut sccs,
950                };
951                strongconnect(var, &mut state);
952            }
953        }
954
955        // Unify variables within each SCC (if SCC has >1 member)
956        for scc in sccs {
957            if scc.len() > 1 {
958                // Unify all variables in this SCC
959                let first = scc[0];
960                for &other in &scc[1..] {
961                    self.unify_vars(first, other)?;
962                }
963            }
964        }
965
966        Ok(())
967    }
968
969    /// Strengthen constraints by analyzing relationships between type parameters.
970    /// For example, if T <: U and we know T = string, then U must be at least string.
971    pub fn strengthen_constraints(&mut self) -> Result<(), InferenceError> {
972        // Detect and unify circular constraints (SCCs)
973        // This ensures that type parameters in cycles (T extends U, U extends T)
974        // are treated as a single equivalence class for inference.
975        self.unify_circular_constraints()?;
976
977        let type_params: Vec<_> = self.type_params.clone();
978        let mut changed = true;
979        let mut iterations = 0;
980
981        // Fixed-point propagation
982        // Iterate to fixed point - continue until no new candidates are added
983        while changed && iterations < MAX_CONSTRAINT_ITERATIONS {
984            changed = false;
985            iterations += 1;
986
987            for (name, var, _) in &type_params {
988                let root = self.table.find(*var);
989
990                // We need to clone info to avoid borrow checker issues while mutating
991                // This is expensive but necessary for correctness in this design
992                let info = self.table.probe_value(root).clone();
993
994                // Propagate candidates UP the extends chain
995                // If T extends U (T <: U), then candidates of T are also candidates of U
996                for &upper in &info.upper_bounds {
997                    if self.propagate_candidates_to_upper(root, upper, *name)? {
998                        changed = true;
999                    }
1000                }
1001            }
1002        }
1003        Ok(())
1004    }
1005
1006    /// Propagates candidates from a subtype (var) to its supertype (upper).
1007    /// If `var extends upper` (var <: upper), then candidates of `var` are also candidates of `upper`.
1008    fn propagate_candidates_to_upper(
1009        &mut self,
1010        var_root: InferenceVar,
1011        upper: TypeId,
1012        exclude_param: Atom,
1013    ) -> Result<bool, InferenceError> {
1014        // Check if 'upper' is a type parameter we are inferring
1015        if let Some(TypeData::TypeParameter(info)) = self.interner.lookup(upper)
1016            && info.name != exclude_param
1017            && let Some(upper_var) = self.find_type_param(info.name)
1018        {
1019            let upper_root = self.table.find(upper_var);
1020
1021            // Don't propagate to self
1022            if var_root == upper_root {
1023                return Ok(false);
1024            }
1025
1026            // Get candidates from the subtype (var)
1027            let var_candidates = self.table.probe_value(var_root).candidates;
1028
1029            // Add them to the supertype (upper)
1030            let mut changed = false;
1031            for candidate in var_candidates {
1032                // Use Circular priority to indicate this came from propagation
1033                if self.add_candidate_if_new(
1034                    upper_root,
1035                    candidate.type_id,
1036                    InferencePriority::Circular,
1037                ) {
1038                    changed = true;
1039                }
1040            }
1041            return Ok(changed);
1042        }
1043        Ok(false)
1044    }
1045
1046    /// Helper to track if we actually added something (for fixed-point loop)
1047    fn add_candidate_if_new(
1048        &mut self,
1049        var: InferenceVar,
1050        ty: TypeId,
1051        priority: InferencePriority,
1052    ) -> bool {
1053        let root = self.table.find(var);
1054        let info = self.table.probe_value(root);
1055
1056        // Check if type already exists in candidates
1057        if info.candidates.iter().any(|c| c.type_id == ty) {
1058            return false;
1059        }
1060
1061        self.add_candidate(var, ty, priority);
1062        true
1063    }
1064
1065    /// Validate that resolved types respect variance constraints.
1066    pub fn validate_variance(&mut self) -> Result<(), InferenceError> {
1067        let type_params: Vec<_> = self.type_params.clone();
1068        for (_name, var, _) in &type_params {
1069            let resolved = match self.probe(*var) {
1070                Some(ty) => ty,
1071                None => continue,
1072            };
1073
1074            // Check if this type parameter appears in its own resolved type
1075            // We use the occurs_in method which already exists and handles this
1076            if self.occurs_in(*var, resolved) {
1077                let root = self.table.find(*var);
1078                // This would be a circular reference
1079                return Err(InferenceError::OccursCheck {
1080                    var: root,
1081                    ty: resolved,
1082                });
1083            }
1084
1085            // For more advanced variance checking, we would need to know
1086            // the declared variance of each type parameter in its generic type
1087            // This is a placeholder for future enhancement
1088        }
1089
1090        Ok(())
1091    }
1092
1093    /// Fix (resolve) inference variables that have candidates from Round 1.
1094    ///
1095    /// This is called after processing non-contextual arguments to "fix" type
1096    /// variables that have enough information, before processing contextual
1097    /// arguments (like lambdas) in Round 2.
1098    ///
1099    /// The fixing process:
1100    /// 1. Finds variables with candidates but no resolved type yet
1101    /// 2. Computes their best current type from candidates
1102    /// 3. Sets the `resolved` field to prevent Round 2 from overriding
1103    ///
1104    /// Variables without candidates are NOT fixed (they might get info from Round 2).
1105    pub fn fix_current_variables(&mut self) -> Result<(), InferenceError> {
1106        let type_params: Vec<_> = self.type_params.clone();
1107
1108        for (_name, var, _is_const) in &type_params {
1109            let root = self.table.find(*var);
1110            let info = self.table.probe_value(root);
1111
1112            // Skip if already resolved
1113            if info.resolved.is_some() {
1114                continue;
1115            }
1116
1117            // Skip if no candidates yet (might get info from Round 2)
1118            if info.candidates.is_empty() {
1119                continue;
1120            }
1121
1122            // Compute the current best type from existing candidates
1123            // This uses the same logic as compute_constraint_result but doesn't
1124            // validate against upper bounds yet (that happens in final resolution)
1125            let is_const = self.is_var_const(root);
1126            let result =
1127                self.resolve_from_candidates(&info.candidates, is_const, &info.upper_bounds);
1128
1129            // Check for occurs (recursive type)
1130            if self.occurs_in(root, result) {
1131                // Don't fix variables with occurs - let them be resolved later
1132                continue;
1133            }
1134
1135            // Fix this variable by setting resolved field
1136            // This prevents Round 2 from overriding with lower-priority constraints
1137            self.table.union_value(
1138                root,
1139                InferenceInfo {
1140                    resolved: Some(result),
1141                    // Keep candidates and upper_bounds for later validation
1142                    candidates: info.candidates,
1143                    upper_bounds: info.upper_bounds,
1144                },
1145            );
1146        }
1147
1148        Ok(())
1149    }
1150
1151    /// Get the current best substitution for all type parameters.
1152    ///
1153    /// This returns a `TypeSubstitution` mapping each type parameter to its
1154    /// current best type (either resolved or the best candidate so far).
1155    /// Used in Round 2 to provide contextual types to lambda arguments.
1156    pub fn get_current_substitution(&mut self) -> TypeSubstitution {
1157        let mut subst = TypeSubstitution::new();
1158        let type_params: Vec<_> = self.type_params.clone();
1159
1160        for (name, var, _) in &type_params {
1161            let ty = match self.probe(*var) {
1162                Some(resolved) => {
1163                    tracing::trace!(
1164                        ?name,
1165                        ?var,
1166                        ?resolved,
1167                        "get_current_substitution: already resolved"
1168                    );
1169                    resolved
1170                }
1171                None => {
1172                    // Not resolved yet, try to get best candidate
1173                    let root = self.table.find(*var);
1174                    let info = self.table.probe_value(root);
1175                    tracing::trace!(
1176                        ?name, ?var,
1177                        candidates_count = info.candidates.len(),
1178                        upper_bounds_count = info.upper_bounds.len(),
1179                        upper_bounds = ?info.upper_bounds,
1180                        "get_current_substitution: not resolved"
1181                    );
1182
1183                    if !info.candidates.is_empty() {
1184                        let is_const = self.is_var_const(root);
1185                        self.resolve_from_candidates(&info.candidates, is_const, &info.upper_bounds)
1186                    } else if !info.upper_bounds.is_empty() {
1187                        // No candidates yet, but we have a constraint (upper bound).
1188                        // Use the constraint as contextual fallback so that mapped types
1189                        // like `{ [K in keyof P]: P[K] }` resolve using the constraint
1190                        // type. This matches tsc's behavior for contextual typing of
1191                        // generic call arguments when all arguments are context-sensitive.
1192                        if info.upper_bounds.len() == 1 {
1193                            info.upper_bounds[0]
1194                        } else {
1195                            self.interner.intersection(info.upper_bounds.to_vec())
1196                        }
1197                    } else {
1198                        // No info yet, use unknown as placeholder
1199                        TypeId::UNKNOWN
1200                    }
1201                }
1202            };
1203
1204            subst.insert(*name, ty);
1205        }
1206
1207        subst
1208    }
1209}