// tsz_solver/classes/inheritance.rs
use fixedbitset::FixedBitSet;
use rustc_hash::{FxHashMap, FxHashSet};
use std::cell::RefCell;
use std::collections::VecDeque;
use std::mem;
use tsz_binder::SymbolId;

13#[derive(Debug, Clone, Default)]
15struct ClassNode {
16 parents: Vec<SymbolId>,
18 children: Vec<SymbolId>,
20 ancestors_bitset: Option<FixedBitSet>,
23 mro: Option<Vec<SymbolId>>,
25}
26
27#[derive(Debug)]
28pub struct InheritanceGraph {
29 nodes: RefCell<FxHashMap<SymbolId, ClassNode>>,
31 max_symbol_id: RefCell<usize>,
33}
34
35impl Default for InheritanceGraph {
36 fn default() -> Self {
37 Self::new()
38 }
39}
40
41impl InheritanceGraph {
42 pub fn new() -> Self {
43 Self {
44 nodes: RefCell::new(FxHashMap::default()),
45 max_symbol_id: RefCell::new(0),
46 }
47 }
48
49 pub fn add_inheritance(&self, child: SymbolId, parents: &[SymbolId]) {
55 let mut nodes = self.nodes.borrow_mut();
56 let mut max_id = self.max_symbol_id.borrow_mut();
57
58 *max_id = (*max_id).max(child.0 as usize);
60 for &p in parents {
61 *max_id = (*max_id).max(p.0 as usize);
62 }
63
64 let child_node = nodes.entry(child).or_default();
66
67 if child_node.parents == parents {
69 return;
70 }
71
72 child_node.parents = parents.to_vec();
73
74 child_node.ancestors_bitset = None;
76 child_node.mro = None;
77
78 for &parent in parents {
80 let parent_node = nodes.entry(parent).or_default();
81 if !parent_node.children.contains(&child) {
82 parent_node.children.push(child);
83 }
84 }
85 }
86
87 pub fn is_derived_from(&self, child: SymbolId, ancestor: SymbolId) -> bool {
92 if child == ancestor {
93 return true;
94 }
95
96 let nodes = self.nodes.borrow();
98 if !nodes.contains_key(&child) || !nodes.contains_key(&ancestor) {
99 return false;
100 }
101 drop(nodes); self.ensure_transitive_closure(child);
104
105 let nodes = self.nodes.borrow();
106 if let Some(node) = nodes.get(&child)
107 && let Some(bits) = &node.ancestors_bitset
108 {
109 return bits.contains(ancestor.0 as usize);
110 }
111
112 false
113 }
114
115 pub fn get_resolution_order(&self, symbol_id: SymbolId) -> Vec<SymbolId> {
120 self.ensure_mro(symbol_id);
121
122 let nodes = self.nodes.borrow();
123 if let Some(node) = nodes.get(&symbol_id)
124 && let Some(mro) = &node.mro
125 {
126 return mro.clone();
127 }
128
129 vec![symbol_id] }
131
132 pub fn find_common_ancestor(&self, a: SymbolId, b: SymbolId) -> Option<SymbolId> {
137 if self.is_derived_from(a, b) {
138 return Some(b);
139 }
140 if self.is_derived_from(b, a) {
141 return Some(a);
142 }
143
144 self.ensure_transitive_closure(a);
145 self.ensure_transitive_closure(b);
146
147 let nodes = self.nodes.borrow();
148 let node_a = nodes.get(&a)?;
149 let node_b = nodes.get(&b)?;
150
151 let bits_a = node_a.ancestors_bitset.as_ref()?;
152 let bits_b = node_b.ancestors_bitset.as_ref()?;
153
154 let mut common = bits_a.clone();
156 common.intersect_with(bits_b);
157
158 drop(nodes); let mro_a = self.get_resolution_order(a);
165
166 mro_a
167 .into_iter()
168 .find(|&ancestor| self.is_derived_from(b, ancestor))
169 }
170
171 pub fn detects_cycle(&self, child: SymbolId, parent: SymbolId) -> bool {
173 self.is_derived_from(parent, child)
175 }
176
177 pub fn get_parents(&self, symbol_id: SymbolId) -> Vec<SymbolId> {
179 let nodes = self.nodes.borrow();
180 if let Some(node) = nodes.get(&symbol_id) {
181 node.parents.clone()
182 } else {
183 Vec::new()
184 }
185 }
186
187 fn ensure_transitive_closure(&self, symbol_id: SymbolId) {
193 let mut nodes = self.nodes.borrow_mut();
194
195 if let Some(node) = nodes.get(&symbol_id) {
197 if node.ancestors_bitset.is_some() {
198 return;
199 }
200 } else {
201 return; }
203
204 let max_len = *self.max_symbol_id.borrow() + 1;
206
207 let mut path = FxHashSet::default();
209
210 self.compute_closure_recursive(symbol_id, &mut nodes, &mut path, max_len);
211 }
212
213 #[allow(clippy::only_used_in_recursion)]
214 fn compute_closure_recursive(
215 &self,
216 current: SymbolId,
217 nodes: &mut FxHashMap<SymbolId, ClassNode>,
218 path: &mut FxHashSet<SymbolId>,
219 bitset_len: usize,
220 ) {
221 if path.contains(¤t) {
222 return;
226 }
227
228 if let Some(node) = nodes.get(¤t)
230 && node.ancestors_bitset.is_some()
231 {
232 return;
233 }
234
235 path.insert(current);
236
237 let parents = if let Some(node) = nodes.get(¤t) {
239 node.parents.clone()
240 } else {
241 Vec::new()
242 };
243
244 let mut my_bits = FixedBitSet::with_capacity(bitset_len);
245
246 for parent in parents {
247 self.compute_closure_recursive(parent, nodes, path, bitset_len);
249
250 my_bits.insert(parent.0 as usize);
252
253 if let Some(parent_node) = nodes.get(&parent)
255 && let Some(parent_bits) = &parent_node.ancestors_bitset
256 {
257 my_bits.union_with(parent_bits);
258 }
259 }
260
261 if let Some(node) = nodes.get_mut(¤t) {
263 node.ancestors_bitset = Some(my_bits);
264 }
265
266 path.remove(¤t);
267 }
268
269 fn ensure_mro(&self, symbol_id: SymbolId) {
271 let mut nodes = self.nodes.borrow_mut();
272
273 if let Some(node) = nodes.get(&symbol_id) {
274 if node.mro.is_some() {
275 return;
276 }
277 } else {
278 return;
279 }
280
281 let mut mro = Vec::new();
284 let mut visited = FxHashSet::default();
285 let mut queue = VecDeque::new();
286
287 queue.push_back(symbol_id);
288
289 while let Some(current) = queue.pop_front() {
290 if !visited.insert(current) {
291 continue;
292 }
293
294 mro.push(current);
295
296 if let Some(node) = nodes.get(¤t) {
297 for parent in &node.parents {
300 queue.push_back(*parent);
301 }
302 }
303 }
304
305 if let Some(node) = nodes.get_mut(&symbol_id) {
306 node.mro = Some(mro);
307 }
308 }
309
310 pub fn clear(&self) {
312 self.nodes.borrow_mut().clear();
313 *self.max_symbol_id.borrow_mut() = 0;
314 }
315
316 pub fn len(&self) -> usize {
318 self.nodes.borrow().len()
319 }
320
321 pub fn is_empty(&self) -> bool {
323 self.nodes.borrow().is_empty()
324 }
325}
326
// Unit tests live outside `src`; `#[path]` points the module at that file,
// relative to this source file's directory.
#[cfg(test)]
#[path = "../../tests/inheritance_tests.rs"]
mod tests;