1use crate::inference::infer::{
12 InferenceCandidate, InferenceContext, InferenceError, InferenceInfo, InferenceVar,
13 MAX_CONSTRAINT_ITERATIONS, MAX_TYPE_RECURSION_DEPTH,
14};
15use crate::instantiation::instantiate::TypeSubstitution;
16use crate::operations::widening;
17use crate::types::{InferencePriority, TemplateSpan, TypeData, TypeId};
18use rustc_hash::FxHashSet;
19use tsz_common::interner::Atom;
20
21struct VarianceState<'a> {
22 target_param: Atom,
23 covariant: &'a mut u32,
24 contravariant: &'a mut u32,
25}
26
27impl<'a> InferenceContext<'a> {
28 pub fn resolve_with_constraints(
40 &mut self,
41 var: InferenceVar,
42 ) -> Result<TypeId, InferenceError> {
43 if let Some(ty) = self.probe(var) {
45 return Ok(ty);
46 }
47
48 let (root, result, upper_bounds, upper_bounds_only) = self.compute_constraint_result(var);
49
50 if !upper_bounds_only {
52 let filtered_upper_bounds = Self::filter_relevant_upper_bounds(&upper_bounds);
53 if let Some(upper) =
54 self.first_failed_upper_bound(result, &filtered_upper_bounds, |a, b| {
55 self.is_subtype(a, b)
56 })
57 {
58 return Err(InferenceError::BoundsViolation {
59 var,
60 lower: result,
61 upper,
62 });
63 }
64 }
65
66 if self.occurs_in(root, result) {
67 return Err(InferenceError::OccursCheck {
68 var: root,
69 ty: result,
70 });
71 }
72
73 self.table.union_value(
75 root,
76 InferenceInfo {
77 resolved: Some(result),
78 ..InferenceInfo::default()
79 },
80 );
81
82 Ok(result)
83 }
84
85 pub fn resolve_with_constraints_by<F>(
88 &mut self,
89 var: InferenceVar,
90 is_subtype: F,
91 ) -> Result<TypeId, InferenceError>
92 where
93 F: FnMut(TypeId, TypeId) -> bool,
94 {
95 if let Some(ty) = self.probe(var) {
97 return Ok(ty);
98 }
99
100 let (root, result, upper_bounds, upper_bounds_only) = self.compute_constraint_result(var);
101
102 if !upper_bounds_only {
103 let filtered_upper_bounds = Self::filter_relevant_upper_bounds(&upper_bounds);
104 if let Some(upper) =
105 self.first_failed_upper_bound(result, &filtered_upper_bounds, is_subtype)
106 {
107 return Err(InferenceError::BoundsViolation {
108 var,
109 lower: result,
110 upper,
111 });
112 }
113 }
114
115 if self.occurs_in(root, result) {
116 return Err(InferenceError::OccursCheck {
117 var: root,
118 ty: result,
119 });
120 }
121
122 self.table.union_value(
123 root,
124 InferenceInfo {
125 resolved: Some(result),
126 ..InferenceInfo::default()
127 },
128 );
129
130 Ok(result)
131 }
132
133 fn filter_relevant_upper_bounds(upper_bounds: &[TypeId]) -> Vec<TypeId> {
134 upper_bounds
135 .iter()
136 .copied()
137 .filter(|&upper| !matches!(upper, TypeId::ANY | TypeId::UNKNOWN | TypeId::ERROR))
138 .collect()
139 }
140
141 fn first_failed_upper_bound<F>(
142 &self,
143 result: TypeId,
144 filtered_upper_bounds: &[TypeId],
145 mut is_subtype: F,
146 ) -> Option<TypeId>
147 where
148 F: FnMut(TypeId, TypeId) -> bool,
149 {
150 match filtered_upper_bounds {
151 [] => None,
152 [single] => (!is_subtype(result, *single)).then_some(*single),
153 many => {
154 if many.len() <= Self::UPPER_BOUND_INTERSECTION_FAST_PATH_LIMIT {
158 let intersection = self.interner.intersection(many.to_vec());
159 if is_subtype(result, intersection) {
160 return None;
161 }
162 }
163 if many.len() >= Self::UPPER_BOUND_INTERSECTION_LARGE_SET_THRESHOLD
167 && self.should_try_large_upper_bound_intersection(result, many)
168 {
169 let intersection = self.interner.intersection(many.to_vec());
170 if is_subtype(result, intersection) {
171 return None;
172 }
173 }
174 many.iter()
175 .copied()
176 .find(|&upper| !is_subtype(result, upper))
177 }
178 }
179 }
180
181 fn should_try_large_upper_bound_intersection(&self, result: TypeId, bounds: &[TypeId]) -> bool {
182 self.is_object_like_upper_bound(result)
183 && bounds
184 .iter()
185 .copied()
186 .all(|bound| self.is_object_like_upper_bound(bound))
187 }
188
189 fn is_object_like_upper_bound(&self, ty: TypeId) -> bool {
190 match self.interner.lookup(ty) {
191 Some(
192 TypeData::Object(_)
193 | TypeData::ObjectWithIndex(_)
194 | TypeData::Lazy(_)
195 | TypeData::Intersection(_),
196 ) => true,
197 Some(TypeData::TypeParameter(info)) => info
198 .constraint
199 .is_some_and(|constraint| self.is_object_like_upper_bound(constraint)),
200 _ => false,
201 }
202 }
203
204 fn compute_constraint_result(
205 &mut self,
206 var: InferenceVar,
207 ) -> (InferenceVar, TypeId, Vec<TypeId>, bool) {
208 let root = self.table.find(var);
209 let info = self.table.probe_value(root);
210 let target_names = self.type_param_names_for_root(root);
211 let mut upper_bounds = Vec::new();
212 let mut seen_upper_bounds = FxHashSet::default();
213 let mut candidates = info.candidates;
214 for bound in info.upper_bounds {
215 if self.occurs_in(root, bound) {
216 continue;
217 }
218 if !target_names.is_empty() && self.upper_bound_cycles_param(bound, &target_names) {
219 self.expand_cyclic_upper_bound(
220 root,
221 bound,
222 &target_names,
223 &mut candidates,
224 &mut upper_bounds,
225 );
226 continue;
227 }
228 if seen_upper_bounds.insert(bound) {
229 upper_bounds.push(bound);
230 }
231 }
232
233 if !upper_bounds.is_empty() {
234 candidates.retain(|candidate| {
235 !matches!(
236 candidate.type_id,
237 TypeId::ANY | TypeId::UNKNOWN | TypeId::ERROR
238 )
239 });
240 }
241
242 let is_const = self.is_var_const(root);
244
245 let upper_bounds_only = candidates.is_empty() && !upper_bounds.is_empty();
246
247 let result = if !candidates.is_empty() {
248 self.resolve_from_candidates(&candidates, is_const, &upper_bounds)
249 } else if !upper_bounds.is_empty() {
250 if upper_bounds.len() == 1 {
254 upper_bounds[0]
255 } else {
256 self.interner.intersection(upper_bounds.clone())
257 }
258 } else {
259 TypeId::UNKNOWN
261 };
262
263 (root, result, upper_bounds, upper_bounds_only)
264 }
265
266 pub fn resolve_all_with_constraints(&mut self) -> Result<Vec<(Atom, TypeId)>, InferenceError> {
268 self.strengthen_constraints()?;
273
274 let type_params: Vec<_> = self.type_params.clone();
275 let mut results = Vec::new();
276
277 for (name, var, _) in type_params {
278 let ty = self.resolve_with_constraints(var)?;
279 results.push((name, ty));
280 }
281
282 Ok(results)
283 }
284
285 fn resolve_from_candidates(
286 &self,
287 candidates: &[InferenceCandidate],
288 is_const: bool,
289 upper_bounds: &[TypeId],
290 ) -> TypeId {
291 let filtered = self.filter_candidates_by_priority(candidates);
292 if filtered.is_empty() {
293 return TypeId::UNKNOWN;
294 }
295 let filtered_no_never: Vec<_> = filtered
296 .iter()
297 .filter(|c| c.type_id != TypeId::NEVER)
298 .cloned()
299 .collect();
300 if filtered_no_never.is_empty() {
301 return TypeId::NEVER;
302 }
303 let all_from_object_properties = filtered_no_never
304 .iter()
305 .all(|candidate| candidate.from_object_property);
306 let preserve_literals = is_const || self.constraint_implies_literals(upper_bounds);
309 let widened = if preserve_literals {
310 if is_const {
311 filtered_no_never
312 .iter()
313 .map(|c| widening::apply_const_assertion(self.interner, c.type_id))
314 .collect()
315 } else {
316 filtered_no_never.iter().map(|c| c.type_id).collect()
317 }
318 } else {
319 self.widen_candidate_types(&filtered_no_never)
320 };
321 let resolved = self.best_common_type(&widened);
322 if all_from_object_properties
323 && let Some(TypeData::Union(member_list_id)) = self.interner.lookup(resolved)
324 {
325 let member_count = self.interner.type_list(member_list_id).len();
326 if member_count > 1 {
327 let mut first_property_name = None;
328 let mut has_multiple_property_names = false;
329 for candidate in &filtered_no_never {
330 if let Some(name) = candidate.object_property_name {
331 if let Some(prev_name) = first_property_name {
332 if prev_name != name {
333 has_multiple_property_names = true;
334 break;
335 }
336 } else {
337 first_property_name = Some(name);
338 }
339 } else {
340 has_multiple_property_names = false;
341 break;
342 }
343 }
344
345 if !has_multiple_property_names {
346 return resolved;
347 }
348
349 if let Some(fallback_idx) = filtered_no_never
350 .iter()
351 .enumerate()
352 .filter_map(|(idx, candidate)| {
353 candidate.object_property_name.map(|name| {
354 (
355 self.interner.resolve_atom_ref(name),
356 candidate.object_property_index.unwrap_or(u32::MAX),
357 idx,
358 )
359 })
360 })
361 .min_by(|(name_l, index_l, _), (name_r, index_r, _)| {
362 name_l.cmp(name_r).then_with(|| index_l.cmp(index_r))
363 })
364 .map(|(_, _, fallback_idx)| fallback_idx)
365 {
366 return widened[fallback_idx];
367 }
368 return widened[0];
369 }
370 }
371 resolved
372 }
373
374 fn constraint_implies_literals(&self, upper_bounds: &[TypeId]) -> bool {
376 upper_bounds
377 .iter()
378 .any(|&bound| self.type_implies_literals(bound))
379 }
380
381 fn type_implies_literals(&self, type_id: TypeId) -> bool {
383 match self.interner.lookup(type_id) {
384 Some(TypeData::Literal(_)) => true,
385 Some(TypeData::Union(list_id)) => {
386 let members = self.interner.type_list(list_id);
387 members.iter().any(|&m| self.type_implies_literals(m))
388 }
389 Some(TypeData::Intersection(list_id)) => {
390 let members = self.interner.type_list(list_id);
391 members.iter().any(|&m| self.type_implies_literals(m))
392 }
393 _ => false,
394 }
395 }
396
397 fn filter_candidates_by_priority(
405 &self,
406 candidates: &[InferenceCandidate],
407 ) -> Vec<InferenceCandidate> {
408 let Some(best_priority) = candidates.iter().map(|c| c.priority).min() else {
409 return Vec::new();
410 };
411 candidates
412 .iter()
413 .filter(|candidate| candidate.priority == best_priority)
414 .cloned()
415 .collect()
416 }
417
418 fn widen_candidate_types(&self, candidates: &[InferenceCandidate]) -> Vec<TypeId> {
419 candidates
420 .iter()
421 .map(|candidate| {
422 if candidate.is_fresh_literal {
428 self.get_base_type(candidate.type_id)
429 .unwrap_or(candidate.type_id)
430 } else {
431 candidate.type_id
432 }
433 })
434 .collect()
435 }
436
437 pub fn infer_from_conditional(
445 &mut self,
446 var: InferenceVar,
447 check_type: TypeId,
448 extends_type: TypeId,
449 true_type: TypeId,
450 false_type: TypeId,
451 ) {
452 if let Some(TypeData::TypeParameter(info)) = self.interner.lookup(check_type)
454 && let Some(check_var) = self.find_type_param(info.name)
455 && check_var == self.table.find(var)
456 {
457 self.add_upper_bound(var, extends_type);
460 }
461
462 self.infer_from_type(var, true_type);
464 self.infer_from_type(var, false_type);
465 }
466
467 fn infer_from_type(&mut self, var: InferenceVar, ty: TypeId) {
469 let root = self.table.find(var);
470
471 if !self.contains_inference_var(ty, root) {
473 return;
474 }
475
476 match self.interner.lookup(ty) {
477 Some(TypeData::TypeParameter(info)) => {
478 if let Some(param_var) = self.find_type_param(info.name)
479 && self.table.find(param_var) == root
480 {
481 if let Some(constraint) = info.constraint {
484 self.add_upper_bound(var, constraint);
485 }
486 }
487 }
488 Some(TypeData::Array(elem)) => {
489 self.infer_from_type(var, elem);
490 }
491 Some(TypeData::Tuple(elements)) => {
492 let elements = self.interner.tuple_list(elements);
493 for elem in elements.iter() {
494 self.infer_from_type(var, elem.type_id);
495 }
496 }
497 Some(TypeData::Union(members) | TypeData::Intersection(members)) => {
498 let members = self.interner.type_list(members);
499 for &member in members.iter() {
500 self.infer_from_type(var, member);
501 }
502 }
503 Some(TypeData::Object(shape_id)) => {
504 let shape = self.interner.object_shape(shape_id);
505 for prop in &shape.properties {
506 self.infer_from_type(var, prop.type_id);
507 }
508 }
509 Some(TypeData::ObjectWithIndex(shape_id)) => {
510 let shape = self.interner.object_shape(shape_id);
511 for prop in &shape.properties {
512 self.infer_from_type(var, prop.type_id);
513 }
514 if let Some(index) = shape.string_index.as_ref() {
515 self.infer_from_type(var, index.key_type);
516 self.infer_from_type(var, index.value_type);
517 }
518 if let Some(index) = shape.number_index.as_ref() {
519 self.infer_from_type(var, index.key_type);
520 self.infer_from_type(var, index.value_type);
521 }
522 }
523 Some(TypeData::Application(app_id)) => {
524 let app = self.interner.type_application(app_id);
525 self.infer_from_type(var, app.base);
526 for &arg in &app.args {
527 self.infer_from_type(var, arg);
528 }
529 }
530 Some(TypeData::Function(shape_id)) => {
531 let shape = self.interner.function_shape(shape_id);
532 for param in &shape.params {
533 self.infer_from_type(var, param.type_id);
534 }
535 if let Some(this_type) = shape.this_type {
536 self.infer_from_type(var, this_type);
537 }
538 self.infer_from_type(var, shape.return_type);
539 }
540 Some(TypeData::Conditional(cond_id)) => {
541 let cond = self.interner.conditional_type(cond_id);
542 self.infer_from_conditional(
543 var,
544 cond.check_type,
545 cond.extends_type,
546 cond.true_type,
547 cond.false_type,
548 );
549 }
550 Some(TypeData::TemplateLiteral(spans)) => {
551 let spans = self.interner.template_list(spans);
553 for span in spans.iter() {
554 if let TemplateSpan::Type(inner) = span {
555 self.infer_from_type(var, *inner);
556 }
557 }
558 }
559 _ => {}
560 }
561 }
562
563 pub(crate) fn contains_inference_var(&mut self, ty: TypeId, var: InferenceVar) -> bool {
565 let mut visited = FxHashSet::default();
566 self.contains_inference_var_inner(ty, var, &mut visited, 0)
567 }
568
569 fn contains_inference_var_inner(
570 &mut self,
571 ty: TypeId,
572 var: InferenceVar,
573 visited: &mut FxHashSet<TypeId>,
574 depth: usize,
575 ) -> bool {
576 if depth > MAX_TYPE_RECURSION_DEPTH {
578 return false;
579 }
580 if !visited.insert(ty) {
582 return false;
583 }
584
585 let root = self.table.find(var);
586
587 match self.interner.lookup(ty) {
588 Some(TypeData::TypeParameter(info) | TypeData::Infer(info)) => {
589 if let Some(param_var) = self.find_type_param(info.name) {
590 self.table.find(param_var) == root
591 } else {
592 false
593 }
594 }
595 Some(TypeData::Array(elem)) => {
596 self.contains_inference_var_inner(elem, var, visited, depth + 1)
597 }
598 Some(TypeData::Tuple(elements)) => {
599 let elements = self.interner.tuple_list(elements);
600 elements
601 .iter()
602 .any(|e| self.contains_inference_var_inner(e.type_id, var, visited, depth + 1))
603 }
604 Some(TypeData::Union(members) | TypeData::Intersection(members)) => {
605 let members = self.interner.type_list(members);
606 members
607 .iter()
608 .any(|&m| self.contains_inference_var_inner(m, var, visited, depth + 1))
609 }
610 Some(TypeData::Object(shape_id)) => {
611 let shape = self.interner.object_shape(shape_id);
612 shape
613 .properties
614 .iter()
615 .any(|p| self.contains_inference_var_inner(p.type_id, var, visited, depth + 1))
616 }
617 Some(TypeData::ObjectWithIndex(shape_id)) => {
618 let shape = self.interner.object_shape(shape_id);
619 shape
620 .properties
621 .iter()
622 .any(|p| self.contains_inference_var_inner(p.type_id, var, visited, depth + 1))
623 || shape.string_index.as_ref().is_some_and(|idx| {
624 self.contains_inference_var_inner(idx.key_type, var, visited, depth + 1)
625 || self.contains_inference_var_inner(
626 idx.value_type,
627 var,
628 visited,
629 depth + 1,
630 )
631 })
632 || shape.number_index.as_ref().is_some_and(|idx| {
633 self.contains_inference_var_inner(idx.key_type, var, visited, depth + 1)
634 || self.contains_inference_var_inner(
635 idx.value_type,
636 var,
637 visited,
638 depth + 1,
639 )
640 })
641 }
642 Some(TypeData::Application(app_id)) => {
643 let app = self.interner.type_application(app_id);
644 self.contains_inference_var_inner(app.base, var, visited, depth + 1)
645 || app
646 .args
647 .iter()
648 .any(|&arg| self.contains_inference_var_inner(arg, var, visited, depth + 1))
649 }
650 Some(TypeData::Function(shape_id)) => {
651 let shape = self.interner.function_shape(shape_id);
652 shape
653 .params
654 .iter()
655 .any(|p| self.contains_inference_var_inner(p.type_id, var, visited, depth + 1))
656 || shape.this_type.is_some_and(|t| {
657 self.contains_inference_var_inner(t, var, visited, depth + 1)
658 })
659 || self.contains_inference_var_inner(shape.return_type, var, visited, depth + 1)
660 }
661 Some(TypeData::Conditional(cond_id)) => {
662 let cond = self.interner.conditional_type(cond_id);
663 self.contains_inference_var_inner(cond.check_type, var, visited, depth + 1)
664 || self.contains_inference_var_inner(cond.extends_type, var, visited, depth + 1)
665 || self.contains_inference_var_inner(cond.true_type, var, visited, depth + 1)
666 || self.contains_inference_var_inner(cond.false_type, var, visited, depth + 1)
667 }
668 Some(TypeData::TemplateLiteral(spans)) => {
669 let spans = self.interner.template_list(spans);
670 spans.iter().any(|span| match span {
671 TemplateSpan::Text(_) => false,
672 TemplateSpan::Type(inner) => {
673 self.contains_inference_var_inner(*inner, var, visited, depth + 1)
674 }
675 })
676 }
677 _ => false,
678 }
679 }
680
681 pub fn compute_variance(&self, ty: TypeId, target_param: Atom) -> (u32, u32, u32, u32) {
688 let mut covariant = 0u32;
689 let mut contravariant = 0u32;
690 let invariant = 0u32;
691 let bivariant = 0u32;
692 let mut state = VarianceState {
693 target_param,
694 covariant: &mut covariant,
695 contravariant: &mut contravariant,
696 };
697
698 self.compute_variance_helper(ty, true, &mut state);
699
700 (covariant, contravariant, invariant, bivariant)
701 }
702
703 fn compute_variance_helper(
704 &self,
705 ty: TypeId,
706 polarity: bool, state: &mut VarianceState<'_>,
708 ) {
709 match self.interner.lookup(ty) {
710 Some(TypeData::TypeParameter(info)) if info.name == state.target_param => {
711 if polarity {
712 *state.covariant += 1;
713 } else {
714 *state.contravariant += 1;
715 }
716 }
717 Some(TypeData::Array(elem)) => {
718 self.compute_variance_helper(elem, polarity, state);
719 }
720 Some(TypeData::Tuple(elements)) => {
721 let elements = self.interner.tuple_list(elements);
722 for elem in elements.iter() {
723 self.compute_variance_helper(elem.type_id, polarity, state);
724 }
725 }
726 Some(TypeData::Union(members) | TypeData::Intersection(members)) => {
727 let members = self.interner.type_list(members);
728 for &member in members.iter() {
729 self.compute_variance_helper(member, polarity, state);
730 }
731 }
732 Some(TypeData::Object(shape_id)) => {
733 let shape = self.interner.object_shape(shape_id);
734 for prop in &shape.properties {
735 self.compute_variance_helper(prop.type_id, polarity, state);
737 if prop.write_type != prop.type_id && !prop.readonly {
739 self.compute_variance_helper(prop.write_type, !polarity, state);
740 }
741 }
742 }
743 Some(TypeData::ObjectWithIndex(shape_id)) => {
744 let shape = self.interner.object_shape(shape_id);
745 for prop in &shape.properties {
746 self.compute_variance_helper(prop.type_id, polarity, state);
747 if prop.write_type != prop.type_id && !prop.readonly {
748 self.compute_variance_helper(prop.write_type, !polarity, state);
749 }
750 }
751 if let Some(index) = shape.string_index.as_ref() {
752 self.compute_variance_helper(index.value_type, polarity, state);
753 }
754 if let Some(index) = shape.number_index.as_ref() {
755 self.compute_variance_helper(index.value_type, polarity, state);
756 }
757 }
758 Some(TypeData::Application(app_id)) => {
759 let app = self.interner.type_application(app_id);
760 for &arg in &app.args {
763 self.compute_variance_helper(arg, polarity, state);
764 }
765 }
766 Some(TypeData::Function(shape_id)) => {
767 let shape = self.interner.function_shape(shape_id);
768 for param in &shape.params {
770 self.compute_variance_helper(param.type_id, !polarity, state);
771 }
772 self.compute_variance_helper(shape.return_type, polarity, state);
774 }
775 Some(TypeData::Conditional(cond_id)) => {
776 let cond = self.interner.conditional_type(cond_id);
777 self.compute_variance_helper(cond.check_type, false, state);
779 self.compute_variance_helper(cond.extends_type, false, state);
780 self.compute_variance_helper(cond.true_type, polarity, state);
782 self.compute_variance_helper(cond.false_type, polarity, state);
783 }
784 _ => {}
785 }
786 }
787
788 pub fn is_invariant_position(&self, ty: TypeId, target_param: Atom) -> bool {
790 let (_, _, invariant, _) = self.compute_variance(ty, target_param);
791 invariant > 0
792 }
793
794 pub fn is_bivariant_position(&self, ty: TypeId, target_param: Atom) -> bool {
796 let (_, _, _, bivariant) = self.compute_variance(ty, target_param);
797 bivariant > 0
798 }
799
800 pub fn get_variance(&self, ty: TypeId, target_param: Atom) -> &'static str {
802 let (covariant, contravariant, invariant, bivariant) =
803 self.compute_variance(ty, target_param);
804
805 if invariant > 0 {
806 "invariant"
807 } else if bivariant > 0 {
808 "bivariant"
809 } else if covariant > 0 && contravariant > 0 {
810 "invariant" } else if covariant > 0 {
812 "covariant"
813 } else if contravariant > 0 {
814 "contravariant"
815 } else {
816 "unused"
817 }
818 }
819
820 pub fn infer_from_context(
828 &mut self,
829 var: InferenceVar,
830 context_type: TypeId,
831 ) -> Result<(), InferenceError> {
832 self.add_upper_bound(var, context_type);
834
835 let root = self.table.find(var);
838 if self.contains_inference_var(context_type, root) {
839 return Err(InferenceError::OccursCheck {
842 var: root,
843 ty: context_type,
844 });
845 }
846
847 Ok(())
848 }
849
850 fn unify_circular_constraints(&mut self) -> Result<(), InferenceError> {
854 use rustc_hash::{FxHashMap, FxHashSet};
855
856 let type_params: Vec<_> = self.type_params.clone();
857
858 let mut graph: FxHashMap<InferenceVar, FxHashSet<InferenceVar>> = FxHashMap::default();
860 let mut var_for_param: FxHashMap<Atom, InferenceVar> = FxHashMap::default();
861
862 for (name, var, _) in &type_params {
863 let root = self.table.find(*var);
864 var_for_param.insert(*name, root);
865 graph.entry(root).or_default();
866 }
867
868 for (_name, var, _) in &type_params {
870 let root = self.table.find(*var);
871 let info = self.table.probe_value(root);
872
873 for &upper in &info.upper_bounds {
874 if let Some(TypeData::TypeParameter(param_info)) = self.interner.lookup(upper)
876 && let Some(&upper_var) = var_for_param.get(¶m_info.name)
877 {
878 let upper_root = self.table.find(upper_var);
879 graph.entry(root).or_default().insert(upper_root);
881 }
882 }
883 }
884
885 let mut index_counter = 0;
887 let mut indices: FxHashMap<InferenceVar, usize> = FxHashMap::default();
888 let mut lowlink: FxHashMap<InferenceVar, usize> = FxHashMap::default();
889 let mut stack: Vec<InferenceVar> = Vec::new();
890 let mut on_stack: FxHashSet<InferenceVar> = FxHashSet::default();
891 let mut sccs: Vec<Vec<InferenceVar>> = Vec::new();
892
893 struct TarjanState<'a> {
894 graph: &'a FxHashMap<InferenceVar, FxHashSet<InferenceVar>>,
895 index_counter: &'a mut usize,
896 indices: &'a mut FxHashMap<InferenceVar, usize>,
897 lowlink: &'a mut FxHashMap<InferenceVar, usize>,
898 stack: &'a mut Vec<InferenceVar>,
899 on_stack: &'a mut FxHashSet<InferenceVar>,
900 sccs: &'a mut Vec<Vec<InferenceVar>>,
901 }
902
903 fn strongconnect(var: InferenceVar, state: &mut TarjanState) {
904 state.indices.insert(var, *state.index_counter);
905 state.lowlink.insert(var, *state.index_counter);
906 *state.index_counter += 1;
907 state.stack.push(var);
908 state.on_stack.insert(var);
909
910 if let Some(neighbors) = state.graph.get(&var) {
911 for &neighbor in neighbors {
912 if !state.indices.contains_key(&neighbor) {
913 strongconnect(neighbor, state);
914 let neighbor_low = *state.lowlink.get(&neighbor).unwrap_or(&0);
915 let var_low = state.lowlink.get_mut(&var).unwrap();
916 *var_low = (*var_low).min(neighbor_low);
917 } else if state.on_stack.contains(&neighbor) {
918 let neighbor_idx = *state.indices.get(&neighbor).unwrap_or(&0);
919 let var_low = state.lowlink.get_mut(&var).unwrap();
920 *var_low = (*var_low).min(neighbor_idx);
921 }
922 }
923 }
924
925 if *state.lowlink.get(&var).unwrap_or(&0) == *state.indices.get(&var).unwrap_or(&0) {
926 let mut scc = Vec::new();
927 loop {
928 let w = state.stack.pop().unwrap();
929 state.on_stack.remove(&w);
930 scc.push(w);
931 if w == var {
932 break;
933 }
934 }
935 state.sccs.push(scc);
936 }
937 }
938
939 for &var in graph.keys() {
941 if !indices.contains_key(&var) {
942 let mut state = TarjanState {
943 graph: &graph,
944 index_counter: &mut index_counter,
945 indices: &mut indices,
946 lowlink: &mut lowlink,
947 stack: &mut stack,
948 on_stack: &mut on_stack,
949 sccs: &mut sccs,
950 };
951 strongconnect(var, &mut state);
952 }
953 }
954
955 for scc in sccs {
957 if scc.len() > 1 {
958 let first = scc[0];
960 for &other in &scc[1..] {
961 self.unify_vars(first, other)?;
962 }
963 }
964 }
965
966 Ok(())
967 }
968
969 pub fn strengthen_constraints(&mut self) -> Result<(), InferenceError> {
972 self.unify_circular_constraints()?;
976
977 let type_params: Vec<_> = self.type_params.clone();
978 let mut changed = true;
979 let mut iterations = 0;
980
981 while changed && iterations < MAX_CONSTRAINT_ITERATIONS {
984 changed = false;
985 iterations += 1;
986
987 for (name, var, _) in &type_params {
988 let root = self.table.find(*var);
989
990 let info = self.table.probe_value(root).clone();
993
994 for &upper in &info.upper_bounds {
997 if self.propagate_candidates_to_upper(root, upper, *name)? {
998 changed = true;
999 }
1000 }
1001 }
1002 }
1003 Ok(())
1004 }
1005
1006 fn propagate_candidates_to_upper(
1009 &mut self,
1010 var_root: InferenceVar,
1011 upper: TypeId,
1012 exclude_param: Atom,
1013 ) -> Result<bool, InferenceError> {
1014 if let Some(TypeData::TypeParameter(info)) = self.interner.lookup(upper)
1016 && info.name != exclude_param
1017 && let Some(upper_var) = self.find_type_param(info.name)
1018 {
1019 let upper_root = self.table.find(upper_var);
1020
1021 if var_root == upper_root {
1023 return Ok(false);
1024 }
1025
1026 let var_candidates = self.table.probe_value(var_root).candidates;
1028
1029 let mut changed = false;
1031 for candidate in var_candidates {
1032 if self.add_candidate_if_new(
1034 upper_root,
1035 candidate.type_id,
1036 InferencePriority::Circular,
1037 ) {
1038 changed = true;
1039 }
1040 }
1041 return Ok(changed);
1042 }
1043 Ok(false)
1044 }
1045
1046 fn add_candidate_if_new(
1048 &mut self,
1049 var: InferenceVar,
1050 ty: TypeId,
1051 priority: InferencePriority,
1052 ) -> bool {
1053 let root = self.table.find(var);
1054 let info = self.table.probe_value(root);
1055
1056 if info.candidates.iter().any(|c| c.type_id == ty) {
1058 return false;
1059 }
1060
1061 self.add_candidate(var, ty, priority);
1062 true
1063 }
1064
1065 pub fn validate_variance(&mut self) -> Result<(), InferenceError> {
1067 let type_params: Vec<_> = self.type_params.clone();
1068 for (_name, var, _) in &type_params {
1069 let resolved = match self.probe(*var) {
1070 Some(ty) => ty,
1071 None => continue,
1072 };
1073
1074 if self.occurs_in(*var, resolved) {
1077 let root = self.table.find(*var);
1078 return Err(InferenceError::OccursCheck {
1080 var: root,
1081 ty: resolved,
1082 });
1083 }
1084
1085 }
1089
1090 Ok(())
1091 }
1092
1093 pub fn fix_current_variables(&mut self) -> Result<(), InferenceError> {
1106 let type_params: Vec<_> = self.type_params.clone();
1107
1108 for (_name, var, _is_const) in &type_params {
1109 let root = self.table.find(*var);
1110 let info = self.table.probe_value(root);
1111
1112 if info.resolved.is_some() {
1114 continue;
1115 }
1116
1117 if info.candidates.is_empty() {
1119 continue;
1120 }
1121
1122 let is_const = self.is_var_const(root);
1126 let result =
1127 self.resolve_from_candidates(&info.candidates, is_const, &info.upper_bounds);
1128
1129 if self.occurs_in(root, result) {
1131 continue;
1133 }
1134
1135 self.table.union_value(
1138 root,
1139 InferenceInfo {
1140 resolved: Some(result),
1141 candidates: info.candidates,
1143 upper_bounds: info.upper_bounds,
1144 },
1145 );
1146 }
1147
1148 Ok(())
1149 }
1150
1151 pub fn get_current_substitution(&mut self) -> TypeSubstitution {
1157 let mut subst = TypeSubstitution::new();
1158 let type_params: Vec<_> = self.type_params.clone();
1159
1160 for (name, var, _) in &type_params {
1161 let ty = match self.probe(*var) {
1162 Some(resolved) => {
1163 tracing::trace!(
1164 ?name,
1165 ?var,
1166 ?resolved,
1167 "get_current_substitution: already resolved"
1168 );
1169 resolved
1170 }
1171 None => {
1172 let root = self.table.find(*var);
1174 let info = self.table.probe_value(root);
1175 tracing::trace!(
1176 ?name, ?var,
1177 candidates_count = info.candidates.len(),
1178 upper_bounds_count = info.upper_bounds.len(),
1179 upper_bounds = ?info.upper_bounds,
1180 "get_current_substitution: not resolved"
1181 );
1182
1183 if !info.candidates.is_empty() {
1184 let is_const = self.is_var_const(root);
1185 self.resolve_from_candidates(&info.candidates, is_const, &info.upper_bounds)
1186 } else if !info.upper_bounds.is_empty() {
1187 if info.upper_bounds.len() == 1 {
1193 info.upper_bounds[0]
1194 } else {
1195 self.interner.intersection(info.upper_bounds.to_vec())
1196 }
1197 } else {
1198 TypeId::UNKNOWN
1200 }
1201 }
1202 };
1203
1204 subst.insert(*name, ty);
1205 }
1206
1207 subst
1208 }
1209}