// tsz_solver/classes/inheritance.rs

//! Inheritance Graph Solver
//!
//! Manages the nominal inheritance relationships between classes and interfaces.
//! Provides O(1) subtype checks (once a lazily computed transitive closure has been
//! cached) and computes the Method Resolution Order (MRO) used for member lookup.
6
7use fixedbitset::FixedBitSet;
8use rustc_hash::{FxHashMap, FxHashSet};
9use std::cell::RefCell;
10use std::collections::VecDeque;
11use tsz_binder::SymbolId;
12
13/// Represents a node in the inheritance graph.
14#[derive(Debug, Clone, Default)]
15struct ClassNode {
16    /// Direct parents (extends and implements)
17    parents: Vec<SymbolId>,
18    /// Children (for invalidation/reverse lookup)
19    children: Vec<SymbolId>,
20    /// Cached transitive closure (all ancestors)
21    /// If None, it needs to be computed.
22    ancestors_bitset: Option<FixedBitSet>,
23    /// Cached Method Resolution Order (linearized ancestors)
24    mro: Option<Vec<SymbolId>>,
25}
26
27#[derive(Debug)]
28pub struct InheritanceGraph {
29    /// Map from `SymbolId` to graph node data
30    nodes: RefCell<FxHashMap<SymbolId, ClassNode>>,
31    /// Maximum `SymbolId` seen so far (for `BitSet` sizing)
32    max_symbol_id: RefCell<usize>,
33}
34
35impl Default for InheritanceGraph {
36    fn default() -> Self {
37        Self::new()
38    }
39}
40
41impl InheritanceGraph {
42    pub fn new() -> Self {
43        Self {
44            nodes: RefCell::new(FxHashMap::default()),
45            max_symbol_id: RefCell::new(0),
46        }
47    }
48
49    /// Register a class or interface and its direct parents.
50    ///
51    /// # Arguments
52    /// * `child` - The `SymbolId` of the class/interface being defined
53    /// * `parents` - List of `SymbolIds` this type extends or implements
54    pub fn add_inheritance(&self, child: SymbolId, parents: &[SymbolId]) {
55        let mut nodes = self.nodes.borrow_mut();
56        let mut max_id = self.max_symbol_id.borrow_mut();
57
58        // Update max ID for bitset sizing
59        *max_id = (*max_id).max(child.0 as usize);
60        for &p in parents {
61            *max_id = (*max_id).max(p.0 as usize);
62        }
63
64        // Register child
65        let child_node = nodes.entry(child).or_default();
66
67        // Check if edges actually changed to avoid invalidating cache unnecessarily
68        if child_node.parents == parents {
69            return;
70        }
71
72        child_node.parents = parents.to_vec();
73
74        // Invalidate caches
75        child_node.ancestors_bitset = None;
76        child_node.mro = None;
77
78        // Register reverse edges (for future invalidation logic)
79        for &parent in parents {
80            let parent_node = nodes.entry(parent).or_default();
81            if !parent_node.children.contains(&child) {
82                parent_node.children.push(child);
83            }
84        }
85    }
86
87    /// Checks if `child` is a subtype of `ancestor` nominally.
88    ///
89    /// This is an O(1) operation after the first lazy computation.
90    /// Returns `true` if `child` extends or implements `ancestor` (transitively).
91    pub fn is_derived_from(&self, child: SymbolId, ancestor: SymbolId) -> bool {
92        if child == ancestor {
93            return true;
94        }
95
96        // Fast path: check if nodes exist
97        let nodes = self.nodes.borrow();
98        if !nodes.contains_key(&child) || !nodes.contains_key(&ancestor) {
99            return false;
100        }
101        drop(nodes); // Release borrow for compute
102
103        self.ensure_transitive_closure(child);
104
105        let nodes = self.nodes.borrow();
106        if let Some(node) = nodes.get(&child)
107            && let Some(bits) = &node.ancestors_bitset
108        {
109            return bits.contains(ancestor.0 as usize);
110        }
111
112        false
113    }
114
115    /// Gets the Method Resolution Order (MRO) for a symbol.
116    ///
117    /// Returns a list of `SymbolIds` in the order they should be searched for members.
118    /// Implements a depth-first, left-to-right traversal (standard for TS/JS).
119    pub fn get_resolution_order(&self, symbol_id: SymbolId) -> Vec<SymbolId> {
120        self.ensure_mro(symbol_id);
121
122        let nodes = self.nodes.borrow();
123        if let Some(node) = nodes.get(&symbol_id)
124            && let Some(mro) = &node.mro
125        {
126            return mro.clone();
127        }
128
129        vec![symbol_id] // Fallback: just the symbol itself
130    }
131
132    /// Finds the Least Upper Bound (common ancestor) of two symbols.
133    ///
134    /// Returns the most specific symbol that both A and B inherit from.
135    /// In cases of multiple inheritance (interfaces), this might return one of several valid candidates.
136    pub fn find_common_ancestor(&self, a: SymbolId, b: SymbolId) -> Option<SymbolId> {
137        if self.is_derived_from(a, b) {
138            return Some(b);
139        }
140        if self.is_derived_from(b, a) {
141            return Some(a);
142        }
143
144        self.ensure_transitive_closure(a);
145        self.ensure_transitive_closure(b);
146
147        let nodes = self.nodes.borrow();
148        let node_a = nodes.get(&a)?;
149        let node_b = nodes.get(&b)?;
150
151        let bits_a = node_a.ancestors_bitset.as_ref()?;
152        let bits_b = node_b.ancestors_bitset.as_ref()?;
153
154        // Intersection of ancestors
155        let mut common = bits_a.clone();
156        common.intersect_with(bits_b);
157
158        // We want the "lowest" (most specific) ancestor.
159        // In a topological sort, this is usually the one with the longest path or
160        // appearing earliest in MRO.
161        // Simplified approach: Iterate A's MRO and return the first one present in B's ancestors.
162
163        drop(nodes); // Release for MRO check
164        let mro_a = self.get_resolution_order(a);
165
166        mro_a
167            .into_iter()
168            .find(|&ancestor| self.is_derived_from(b, ancestor))
169    }
170
171    /// Detects if adding an edge would create a cycle.
172    pub fn detects_cycle(&self, child: SymbolId, parent: SymbolId) -> bool {
173        // If parent is already derived from child, adding child->parent creates a cycle
174        self.is_derived_from(parent, child)
175    }
176
177    /// Get the direct parents of a symbol (for cycle detection).
178    pub fn get_parents(&self, symbol_id: SymbolId) -> Vec<SymbolId> {
179        let nodes = self.nodes.borrow();
180        if let Some(node) = nodes.get(&symbol_id) {
181            node.parents.clone()
182        } else {
183            Vec::new()
184        }
185    }
186
187    // =========================================================================
188    // Internal Lazy Computation Methods
189    // =========================================================================
190
191    /// Lazily computes the transitive closure (ancestor bitset) for a node.
192    fn ensure_transitive_closure(&self, symbol_id: SymbolId) {
193        let mut nodes = self.nodes.borrow_mut();
194
195        // If already computed, return
196        if let Some(node) = nodes.get(&symbol_id) {
197            if node.ancestors_bitset.is_some() {
198                return;
199            }
200        } else {
201            return; // Node doesn't exist
202        }
203
204        // Stack for DFS
205        let max_len = *self.max_symbol_id.borrow() + 1;
206
207        // Cycle detection set for this traversal
208        let mut path = FxHashSet::default();
209
210        self.compute_closure_recursive(symbol_id, &mut nodes, &mut path, max_len);
211    }
212
213    #[allow(clippy::only_used_in_recursion)]
214    fn compute_closure_recursive(
215        &self,
216        current: SymbolId,
217        nodes: &mut FxHashMap<SymbolId, ClassNode>,
218        path: &mut FxHashSet<SymbolId>,
219        bitset_len: usize,
220    ) {
221        if path.contains(&current) {
222            // Cycle detected, stop recursion here.
223            // In a real compiler, we might emit a diagnostic here,
224            // but the solver just wants to avoid infinite loops.
225            return;
226        }
227
228        // If already computed, we are good
229        if let Some(node) = nodes.get(&current)
230            && node.ancestors_bitset.is_some()
231        {
232            return;
233        }
234
235        path.insert(current);
236
237        // Clone parents to avoid borrowing issues during recursion
238        let parents = if let Some(node) = nodes.get(&current) {
239            node.parents.clone()
240        } else {
241            Vec::new()
242        };
243
244        let mut my_bits = FixedBitSet::with_capacity(bitset_len);
245
246        for parent in parents {
247            // Ensure parent is computed
248            self.compute_closure_recursive(parent, nodes, path, bitset_len);
249
250            // Add parent itself
251            my_bits.insert(parent.0 as usize);
252
253            // Add parent's ancestors
254            if let Some(parent_node) = nodes.get(&parent)
255                && let Some(parent_bits) = &parent_node.ancestors_bitset
256            {
257                my_bits.union_with(parent_bits);
258            }
259        }
260
261        // Save result
262        if let Some(node) = nodes.get_mut(&current) {
263            node.ancestors_bitset = Some(my_bits);
264        }
265
266        path.remove(&current);
267    }
268
269    /// Lazily computes the MRO for a node.
270    fn ensure_mro(&self, symbol_id: SymbolId) {
271        let mut nodes = self.nodes.borrow_mut();
272
273        if let Some(node) = nodes.get(&symbol_id) {
274            if node.mro.is_some() {
275                return;
276            }
277        } else {
278            return;
279        }
280
281        // Standard Depth-First Left-to-Right traversal for TypeScript
282        // (Note: Python uses C3, but TS is simpler)
283        let mut mro = Vec::new();
284        let mut visited = FxHashSet::default();
285        let mut queue = VecDeque::new();
286
287        queue.push_back(symbol_id);
288
289        while let Some(current) = queue.pop_front() {
290            if !visited.insert(current) {
291                continue;
292            }
293
294            mro.push(current);
295
296            if let Some(node) = nodes.get(&current) {
297                // Add parents to queue
298                // For class extends A implements B, C -> A, B, C
299                for parent in &node.parents {
300                    queue.push_back(*parent);
301                }
302            }
303        }
304
305        if let Some(node) = nodes.get_mut(&symbol_id) {
306            node.mro = Some(mro);
307        }
308    }
309
310    /// Clear all cached data (useful for testing or rebuilding)
311    pub fn clear(&self) {
312        self.nodes.borrow_mut().clear();
313        *self.max_symbol_id.borrow_mut() = 0;
314    }
315
316    /// Get the number of nodes in the graph
317    pub fn len(&self) -> usize {
318        self.nodes.borrow().len()
319    }
320
321    /// Check if the graph is empty
322    pub fn is_empty(&self) -> bool {
323        self.nodes.borrow().is_empty()
324    }
325}
326
327#[cfg(test)]
328#[path = "../../tests/inheritance_tests.rs"]
329mod tests;