1use crate::define_index;
10use crate::symbol::{FileId, SymbolId};
11use crate::SymbolKind;
12use serde::{Deserialize, Serialize};
13use slotmap::SecondaryMap;
14use smallvec::SmallVec;
15use std::collections::{HashMap, HashSet};
16
// Declares `EdgeId` via the crate's `define_index!` macro. In this file the
// id is built with `EdgeId::from_raw(u32)` and read back with `as_usize()`,
// where it serves as a position in `CodeGraphV2::edges`.
define_index! {
    pub struct EdgeId;
}
25
// Declares `MatchExprId` via the crate's `define_index!` macro. In this file
// the id is built with `MatchExprId::from_raw(u32)` and read back with
// `as_usize()`, where it serves as a position in `CodeGraphV2::match_exprs`.
define_index! {
    pub struct MatchExprId;
}
30
/// Relationship carried by a directed edge of [`CodeGraphV2`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CodeEdgeV2 {
    /// Containment; `children_of` reads the edge target as the child and
    /// `parent_of` reads the edge source as the parent.
    Contains,
    /// Call; `callers_of` reads the edge source as the caller and
    /// `callees_of` reads the edge target as the callee.
    Calls,
    /// Implementation; `implementors_of` reads the edge source as the
    /// implementor of the target.
    Implements,
}
53
/// One directed, typed edge stored in [`CodeGraphV2`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct EdgeData {
    /// Symbol the edge starts at.
    pub from: SymbolId,
    /// Symbol the edge points to.
    pub to: SymbolId,
    /// Relationship this edge expresses.
    pub kind: CodeEdgeV2,
}
64
/// Record describing a single match expression tracked by [`CodeGraphV2`].
///
/// The owning function is not stored here; it is the key under which the
/// record's id is filed in the graph's match-expression index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct MatchExprDataV2 {
    /// File associated with the expression (presumably where it appears — confirm).
    pub file_id: FileId,
    /// Enum symbol the expression is associated with; `match_exprs_for_enum`
    /// filters on this field.
    pub enum_id: SymbolId,
    /// Position of the expression (presumably a byte offset in the file — confirm).
    pub offset: u32,
    /// Line of the expression (base of the numbering is not visible here — confirm).
    pub line: u32,
}
84
/// Directed multigraph over symbols with typed edges plus a few secondary
/// indices (symbols by kind, crate roots, match expressions per function).
///
/// Edges live in an append-only list and are referenced from per-symbol
/// adjacency lists by [`EdgeId`].
// NOTE(review): only `Serialize` is derived here, not `Deserialize`.
#[derive(Clone, Default, Serialize)]
pub struct CodeGraphV2 {
    /// Every edge ever added; an `EdgeId` is a position in this list.
    edges: Vec<EdgeData>,

    /// Per symbol, ids of the edges that start at it.
    outgoing: SecondaryMap<SymbolId, SmallVec<[EdgeId; 4]>>,
    /// Per symbol, ids of the edges that end at it.
    incoming: SecondaryMap<SymbolId, SmallVec<[EdgeId; 4]>>,

    /// Set of symbols that are nodes of the graph (the value is unused).
    nodes: SecondaryMap<SymbolId, ()>,

    /// Symbols grouped by kind, filled through `add_to_kind_index`.
    by_kind: HashMap<SymbolKind, SmallVec<[SymbolId; 16]>>,
    /// Symbols registered as crate roots, without duplicates.
    crate_roots: SmallVec<[SymbolId; 4]>,

    /// Per function symbol, ids of the match expressions recorded for it.
    match_expr_index: SecondaryMap<SymbolId, SmallVec<[MatchExprId; 2]>>,
    /// Every match-expression record; a `MatchExprId` is a position in this list.
    match_exprs: Vec<MatchExprDataV2>,
}
144
145impl CodeGraphV2 {
146 pub fn new() -> Self {
148 Self::default()
149 }
150
151 pub fn with_capacity(_nodes: usize, edges: usize) -> Self {
153 Self {
154 edges: Vec::with_capacity(edges),
155 outgoing: SecondaryMap::new(),
156 incoming: SecondaryMap::new(),
157 nodes: SecondaryMap::new(),
158 by_kind: HashMap::new(),
159 crate_roots: SmallVec::new(),
160 match_expr_index: SecondaryMap::new(),
161 match_exprs: Vec::new(),
162 }
163 }
164
165 pub fn add_node(&mut self, id: SymbolId) -> bool {
173 if self.nodes.contains_key(id) {
174 return false;
175 }
176 self.nodes.insert(id, ());
177 true
178 }
179
180 #[inline]
182 pub fn contains(&self, id: SymbolId) -> bool {
183 self.nodes.contains_key(id)
184 }
185
186 pub fn remove_node(&mut self, id: SymbolId) -> bool {
190 if self.nodes.remove(id).is_none() {
191 return false;
192 }
193
194 self.outgoing.remove(id);
196 self.incoming.remove(id);
197
198 for symbols in self.by_kind.values_mut() {
200 symbols.retain(|s| *s != id);
201 }
202
203 self.crate_roots.retain(|s| *s != id);
205
206 self.match_expr_index.remove(id);
208
209 true
214 }
215
216 pub fn clear_outgoing_edges(&mut self, id: SymbolId) {
221 if let Some(edge_ids) = self.outgoing.remove(id) {
222 for edge_id in edge_ids.iter().copied() {
224 if let Some(edge) = self.edges.get(edge_id.as_usize()) {
225 let target = edge.to;
226 if let Some(incoming) = self.incoming.get_mut(target) {
227 incoming.retain(|eid| *eid != edge_id);
228 }
229 }
230 }
231 }
232 }
233
234 pub fn add_edge(&mut self, from: SymbolId, to: SymbolId, kind: CodeEdgeV2) -> EdgeId {
242 self.add_node(from);
244 self.add_node(to);
245
246 let edge_id = EdgeId::from_raw(self.edges.len() as u32);
248 self.edges.push(EdgeData { from, to, kind });
249
250 self.outgoing
252 .entry(from)
253 .expect("caller must supply a SymbolId already present in the SlotMap")
254 .or_default()
255 .push(edge_id);
256 self.incoming
257 .entry(to)
258 .expect("caller must supply a SymbolId already present in the SlotMap")
259 .or_default()
260 .push(edge_id);
261
262 edge_id
263 }
264
265 #[inline]
267 pub fn edge(&self, id: EdgeId) -> Option<&EdgeData> {
268 self.edges.get(id.as_usize())
269 }
270
271 pub fn has_edge(&self, from: SymbolId, to: SymbolId, kind: CodeEdgeV2) -> bool {
273 self.outgoing
274 .get(from)
275 .map(|edges| {
276 edges.iter().any(|&eid| {
277 self.edges
278 .get(eid.as_usize())
279 .map(|e| e.to == to && e.kind == kind)
280 .unwrap_or(false)
281 })
282 })
283 .unwrap_or(false)
284 }
285
286 pub fn add_crate_root(&mut self, id: SymbolId) {
292 self.add_node(id);
293 if !self.crate_roots.contains(&id) {
294 self.crate_roots.push(id);
295 }
296 }
297
298 #[inline]
300 pub fn crate_roots(&self) -> &[SymbolId] {
301 &self.crate_roots
302 }
303
304 pub fn add_to_kind_index(&mut self, id: SymbolId, kind: SymbolKind) {
310 let symbols = self.by_kind.entry(kind).or_default();
311 if !symbols.contains(&id) {
312 symbols.push(id);
313 }
314 }
315
316 pub fn iter_by_kind(&self, kind: SymbolKind) -> impl Iterator<Item = SymbolId> + '_ {
318 self.by_kind
319 .get(&kind)
320 .into_iter()
321 .flat_map(|v| v.iter().copied())
322 }
323
324 pub fn outgoing_edges(&self, id: SymbolId) -> impl Iterator<Item = &EdgeData> + '_ {
330 self.outgoing
331 .get(id)
332 .into_iter()
333 .flat_map(|edges| edges.iter())
334 .filter_map(|&eid| self.edges.get(eid.as_usize()))
335 }
336
337 pub fn incoming_edges(&self, id: SymbolId) -> impl Iterator<Item = &EdgeData> + '_ {
339 self.incoming
340 .get(id)
341 .into_iter()
342 .flat_map(|edges| edges.iter())
343 .filter_map(|&eid| self.edges.get(eid.as_usize()))
344 }
345
346 pub fn callers_of(&self, id: SymbolId) -> impl Iterator<Item = SymbolId> + '_ {
348 let mut seen = HashSet::new();
349 self.incoming_edges(id)
350 .filter(|e| e.kind == CodeEdgeV2::Calls)
351 .map(|e| e.from)
352 .filter(move |&id| seen.insert(id))
353 }
354
355 pub fn callees_of(&self, id: SymbolId) -> impl Iterator<Item = SymbolId> + '_ {
357 let mut seen = HashSet::new();
358 self.outgoing_edges(id)
359 .filter(|e| e.kind == CodeEdgeV2::Calls)
360 .map(|e| e.to)
361 .filter(move |&id| seen.insert(id))
362 }
363
364 pub fn implementors_of(&self, trait_id: SymbolId) -> impl Iterator<Item = SymbolId> + '_ {
366 self.incoming_edges(trait_id)
367 .filter(|e| e.kind == CodeEdgeV2::Implements)
368 .map(|e| e.from)
369 }
370
371 pub fn children_of(&self, parent_id: SymbolId) -> impl Iterator<Item = SymbolId> + '_ {
373 self.outgoing_edges(parent_id)
374 .filter(|e| e.kind == CodeEdgeV2::Contains)
375 .map(|e| e.to)
376 }
377
378 pub fn parent_of(&self, id: SymbolId) -> Option<SymbolId> {
380 self.incoming_edges(id)
381 .find(|e| e.kind == CodeEdgeV2::Contains)
382 .map(|e| e.from)
383 }
384
385 pub fn reference_count(&self, id: SymbolId) -> usize {
387 self.incoming_edges(id)
388 .filter(|e| e.kind == CodeEdgeV2::Calls)
389 .count()
390 }
391
392 pub fn impl_count(&self, id: SymbolId) -> usize {
394 self.incoming_edges(id)
395 .filter(|e| e.kind == CodeEdgeV2::Implements)
396 .count()
397 }
398
399 pub fn add_match_expr(&mut self, function_id: SymbolId, data: MatchExprDataV2) -> MatchExprId {
405 let expr_id = MatchExprId::from_raw(self.match_exprs.len() as u32);
406 self.match_exprs.push(data);
407
408 self.match_expr_index
409 .entry(function_id)
410 .expect("caller must supply a function SymbolId already present in the SlotMap")
411 .or_default()
412 .push(expr_id);
413
414 expr_id
415 }
416
417 pub fn match_exprs_in(
419 &self,
420 function_id: SymbolId,
421 ) -> impl Iterator<Item = &MatchExprDataV2> + '_ {
422 self.match_expr_index
423 .get(function_id)
424 .into_iter()
425 .flat_map(|ids| ids.iter())
426 .filter_map(|&id| self.match_exprs.get(id.as_usize()))
427 }
428
429 pub fn match_exprs_for_enum(
431 &self,
432 enum_id: SymbolId,
433 ) -> impl Iterator<Item = (SymbolId, &MatchExprDataV2)> + '_ {
434 self.match_expr_index
435 .iter()
436 .flat_map(move |(func_id, ids)| {
437 ids.iter()
438 .filter_map(|&id| self.match_exprs.get(id.as_usize()))
439 .filter(move |data| data.enum_id == enum_id)
440 .map(move |data| (func_id, data))
441 })
442 }
443
444 pub fn match_expr_count(&self) -> usize {
446 self.match_exprs.len()
447 }
448
449 #[inline]
455 pub fn node_count(&self) -> usize {
456 self.nodes.len()
457 }
458
459 #[inline]
461 pub fn edge_count(&self) -> usize {
462 self.edges.len()
463 }
464
465 #[inline]
467 pub fn is_empty(&self) -> bool {
468 self.nodes.is_empty()
469 }
470
471 pub fn callers_chain(&self, start: SymbolId, max_depth: usize) -> Vec<ChainNode> {
486 self.traverse_chain(start, max_depth, ChainDirection::Callers)
487 }
488
489 pub fn callees_chain(&self, start: SymbolId, max_depth: usize) -> Vec<ChainNode> {
500 self.traverse_chain(start, max_depth, ChainDirection::Callees)
501 }
502
503 fn traverse_chain(
505 &self,
506 start: SymbolId,
507 max_depth: usize,
508 direction: ChainDirection,
509 ) -> Vec<ChainNode> {
510 use std::collections::{HashSet, VecDeque};
511
512 let mut result = Vec::new();
513 let mut visited = HashSet::new();
514 let mut queue = VecDeque::new();
515
516 visited.insert(start);
517 queue.push_back((start, 0usize));
518
519 while let Some((current, depth)) = queue.pop_front() {
520 if depth > 0 {
521 result.push(ChainNode {
522 symbol: current,
523 depth,
524 });
525 }
526
527 if depth >= max_depth {
528 continue;
529 }
530
531 let neighbors: Vec<SymbolId> = match direction {
532 ChainDirection::Callers => self.callers_of(current).collect(),
533 ChainDirection::Callees => self.callees_of(current).collect(),
534 ChainDirection::TypeUsers | ChainDirection::TypeDeps => {
535 unreachable!("TypeUsers/TypeDeps must use TypeFlowGraphV2")
536 }
537 };
538
539 for neighbor in neighbors {
540 if !visited.contains(&neighbor) {
541 visited.insert(neighbor);
542 queue.push_back((neighbor, depth + 1));
543 }
544 }
545 }
546
547 result
548 }
549
550 pub fn analyze_chain(
552 &self,
553 start: SymbolId,
554 max_depth: usize,
555 direction: ChainDirection,
556 ) -> ChainResult {
557 let nodes = self.traverse_chain(start, max_depth, direction);
558
559 let mut by_depth: HashMap<usize, usize> = HashMap::new();
560 for node in &nodes {
561 *by_depth.entry(node.depth).or_default() += 1;
562 }
563
564 let max_actual_depth = nodes.iter().map(|n| n.depth).max().unwrap_or(0);
565
566 ChainResult {
567 start,
568 direction,
569 max_depth,
570 nodes,
571 max_actual_depth,
572 by_depth,
573 }
574 }
575}
576
/// Direction of a chain traversal.
///
/// `CodeGraphV2` itself only walks `Callers` and `Callees`; its traversal
/// rejects the two type-flow variants with the message
/// "TypeUsers/TypeDeps must use TypeFlowGraphV2".
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ChainDirection {
    /// Walk from a symbol to the symbols that call it.
    Callers,
    /// Walk from a symbol to the symbols it calls.
    Callees,
    /// Type-flow direction; displayed as `type_users`, not handled by `CodeGraphV2`.
    TypeUsers,
    /// Type-flow direction; displayed as `type_deps`, not handled by `CodeGraphV2`.
    TypeDeps,
}
595
596impl std::fmt::Display for ChainDirection {
597 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
598 match self {
599 ChainDirection::Callers => write!(f, "callers"),
600 ChainDirection::Callees => write!(f, "callees"),
601 ChainDirection::TypeUsers => write!(f, "type_users"),
602 ChainDirection::TypeDeps => write!(f, "type_deps"),
603 }
604 }
605}
606
/// One symbol reached by a chain traversal.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ChainNode {
    /// The reached symbol.
    pub symbol: SymbolId,
    /// Number of hops from the traversal start at which the symbol was first
    /// reached; always at least 1 for nodes produced by `CodeGraphV2`.
    pub depth: usize,
}
615
/// Summary of a chain traversal, as produced by `CodeGraphV2::analyze_chain`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChainResult {
    /// Symbol the traversal started from (not included in `nodes`).
    pub start: SymbolId,
    /// Direction that was followed.
    pub direction: ChainDirection,
    /// Depth limit that was requested.
    pub max_depth: usize,
    /// Reached symbols in visiting order.
    pub nodes: Vec<ChainNode>,
    /// Largest depth found among `nodes`; 0 when `nodes` is empty.
    pub max_actual_depth: usize,
    /// Number of reached symbols per depth; depths with no symbols have no entry.
    pub by_depth: HashMap<usize, usize>,
}
634
635impl ChainResult {
636 pub fn total_count(&self) -> usize {
638 self.nodes.len()
639 }
640
641 pub fn at_depth(&self, depth: usize) -> impl Iterator<Item = &ChainNode> {
643 self.nodes.iter().filter(move |n| n.depth == depth)
644 }
645
646 pub fn is_empty(&self) -> bool {
648 self.nodes.is_empty()
649 }
650
651 pub fn symbols(&self) -> impl Iterator<Item = SymbolId> + '_ {
653 self.nodes.iter().map(|n| n.symbol)
654 }
655}
656
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol::{SymbolPath, SymbolRegistry};

    /// Parses `path` and registers it under `kind`, panicking if either step fails.
    fn register(registry: &mut SymbolRegistry, path: &str, kind: SymbolKind) -> SymbolId {
        let parsed = SymbolPath::parse(path).unwrap();
        registry.register(parsed, kind).unwrap()
    }

    /// Adds a `Calls` edge between every consecutive pair of `path`, in order.
    fn link_calls(graph: &mut CodeGraphV2, path: &[SymbolId]) {
        for pair in path.windows(2) {
            graph.add_edge(pair[0], pair[1], CodeEdgeV2::Calls);
        }
    }

    /// Returns the depth recorded for `symbol`, panicking when it is not in `chain`.
    fn depth_of(chain: &[ChainNode], symbol: SymbolId) -> usize {
        chain.iter().find(|node| node.symbol == symbol).unwrap().depth
    }

    /// Registry with one struct (`foo::Bar`) and two functions (`foo::baz`,
    /// `foo::qux`); the ids are returned in that order.
    fn setup() -> (SymbolRegistry, SymbolId, SymbolId, SymbolId) {
        let mut registry = SymbolRegistry::new();
        let bar = register(&mut registry, "foo::Bar", SymbolKind::Struct);
        let baz = register(&mut registry, "foo::baz", SymbolKind::Function);
        let qux = register(&mut registry, "foo::qux", SymbolKind::Function);
        (registry, bar, baz, qux)
    }

    #[test]
    fn test_add_node() {
        let (_, node, _, _) = setup();
        let mut graph = CodeGraphV2::new();

        // First insertion reports "new", the second reports "already present".
        assert!(graph.add_node(node));
        assert!(!graph.add_node(node));
        assert!(graph.contains(node));
    }

    #[test]
    fn test_add_edge() {
        let (_, parent, child, _) = setup();
        let mut graph = CodeGraphV2::new();

        graph.add_edge(parent, child, CodeEdgeV2::Contains);

        // Both endpoints become nodes and the edge is directional.
        assert!(graph.contains(parent));
        assert!(graph.contains(child));
        assert!(graph.has_edge(parent, child, CodeEdgeV2::Contains));
        assert!(!graph.has_edge(child, parent, CodeEdgeV2::Contains));
    }

    #[test]
    fn test_callers_of() {
        let (_, first_caller, second_caller, callee) = setup();
        let mut graph = CodeGraphV2::new();

        graph.add_edge(first_caller, callee, CodeEdgeV2::Calls);
        graph.add_edge(second_caller, callee, CodeEdgeV2::Calls);

        let callers: Vec<_> = graph.callers_of(callee).collect();
        assert_eq!(callers.len(), 2);
        assert!(callers.contains(&first_caller));
        assert!(callers.contains(&second_caller));
    }

    #[test]
    fn test_children_of() {
        let (_, parent, first_child, second_child) = setup();
        let mut graph = CodeGraphV2::new();

        graph.add_edge(parent, first_child, CodeEdgeV2::Contains);
        graph.add_edge(parent, second_child, CodeEdgeV2::Contains);

        let children: Vec<_> = graph.children_of(parent).collect();
        assert_eq!(children.len(), 2);
    }

    #[test]
    fn test_parent_of() {
        let (_, parent, child, _) = setup();
        let mut graph = CodeGraphV2::new();

        graph.add_edge(parent, child, CodeEdgeV2::Contains);

        assert_eq!(graph.parent_of(child), Some(parent));
        assert_eq!(graph.parent_of(parent), None);
    }

    #[test]
    fn test_remove_node() {
        let (_, removed, kept, _) = setup();
        let mut graph = CodeGraphV2::new();

        graph.add_edge(removed, kept, CodeEdgeV2::Calls);
        assert_eq!(graph.node_count(), 2);

        assert!(graph.remove_node(removed));
        assert_eq!(graph.node_count(), 1);
        assert!(!graph.contains(removed));
        assert!(graph.contains(kept));
    }

    #[test]
    fn test_kind_index() {
        let (_, a_struct, first_fn, second_fn) = setup();
        let mut graph = CodeGraphV2::new();

        for &id in &[a_struct, first_fn, second_fn] {
            graph.add_node(id);
        }

        graph.add_to_kind_index(a_struct, SymbolKind::Struct);
        graph.add_to_kind_index(first_fn, SymbolKind::Function);
        graph.add_to_kind_index(second_fn, SymbolKind::Function);

        let structs: Vec<_> = graph.iter_by_kind(SymbolKind::Struct).collect();
        assert_eq!(structs.len(), 1);

        let functions: Vec<_> = graph.iter_by_kind(SymbolKind::Function).collect();
        assert_eq!(functions.len(), 2);
    }

    /// Registry with five functions `test::fn_a` .. `test::fn_e`; the ids are
    /// returned in alphabetical order.
    fn setup_chain() -> (
        SymbolRegistry,
        SymbolId,
        SymbolId,
        SymbolId,
        SymbolId,
        SymbolId,
    ) {
        let mut registry = SymbolRegistry::new();
        let a = register(&mut registry, "test::fn_a", SymbolKind::Function);
        let b = register(&mut registry, "test::fn_b", SymbolKind::Function);
        let c = register(&mut registry, "test::fn_c", SymbolKind::Function);
        let d = register(&mut registry, "test::fn_d", SymbolKind::Function);
        let e = register(&mut registry, "test::fn_e", SymbolKind::Function);
        (registry, a, b, c, d, e)
    }

    #[test]
    fn test_callers_chain_simple() {
        let (_, a, b, c, d, _) = setup_chain();
        let mut graph = CodeGraphV2::new();

        // a -> b -> c -> d
        link_calls(&mut graph, &[a, b, c, d]);

        let chain = graph.callers_chain(d, 10);
        assert_eq!(chain.len(), 3);

        assert_eq!(depth_of(&chain, c), 1);
        assert_eq!(depth_of(&chain, b), 2);
        assert_eq!(depth_of(&chain, a), 3);
    }

    #[test]
    fn test_callees_chain_simple() {
        let (_, a, b, c, d, _) = setup_chain();
        let mut graph = CodeGraphV2::new();

        // a -> b -> c -> d
        link_calls(&mut graph, &[a, b, c, d]);

        let chain = graph.callees_chain(a, 10);
        assert_eq!(chain.len(), 3);

        assert_eq!(depth_of(&chain, b), 1);
        assert_eq!(depth_of(&chain, c), 2);
        assert_eq!(depth_of(&chain, d), 3);
    }

    #[test]
    fn test_chain_with_max_depth() {
        let (_, a, b, c, d, _) = setup_chain();
        let mut graph = CodeGraphV2::new();

        link_calls(&mut graph, &[a, b, c, d]);

        // A limit of two hops reaches b and c, but stops before d.
        let chain = graph.callees_chain(a, 2);
        assert_eq!(chain.len(), 2);

        let symbols: Vec<_> = chain.iter().map(|node| node.symbol).collect();
        assert!(symbols.contains(&b));
        assert!(symbols.contains(&c));
        assert!(!symbols.contains(&d));
    }

    #[test]
    fn test_chain_with_cycle() {
        let (_, a, b, c, _, _) = setup_chain();
        let mut graph = CodeGraphV2::new();

        // a -> b -> c -> a
        link_calls(&mut graph, &[a, b, c, a]);

        // The start symbol is never reported, so only b and c remain.
        let chain = graph.callees_chain(a, 10);
        assert_eq!(chain.len(), 2);
    }

    #[test]
    fn test_analyze_chain() {
        let (_, a, b, c, d, e) = setup_chain();
        let mut graph = CodeGraphV2::new();

        // a -> b -> c -> d -> e
        link_calls(&mut graph, &[a, b, c, d, e]);

        let result = graph.analyze_chain(a, 10, ChainDirection::Callees);

        assert_eq!(result.start, a);
        assert_eq!(result.direction, ChainDirection::Callees);
        assert_eq!(result.total_count(), 4);
        assert_eq!(result.max_actual_depth, 4);

        // A straight line has exactly one symbol per depth.
        for depth in 1..=4 {
            assert_eq!(*result.by_depth.get(&depth).unwrap_or(&0), 1);
        }
    }
}